promocrypt_core/
damm.rs

1//! Damm check digit algorithm.
2//!
3//! Implements the Damm algorithm for generating and validating check digits.
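//! Tables built by `DammTable::new` detect every single-character error. Adjacent
//! transpositions are detected for odd alphabet sizes; for even sizes, swapping two
//! characters whose alphabet indices differ by exactly `size / 2` goes undetected.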
5
6use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
7use serde::{Deserialize, Serialize};
8
9use crate::alphabet::Alphabet;
10use crate::error::{PromocryptError, Result};
11
/// Damm quasi-group table for check digit calculation.
///
/// Tables built by [`DammTable::new`] have the following properties:
/// - Each row is a permutation of `0..size`
/// - Each column is a permutation of `0..size`
/// - Diagonal is all zeros: `table[i][i] = 0`
///
/// The table is `table[a][b] = (b - a) mod size`. Consequently
/// `table[a][b] != table[b][a]` for `a != b` holds only when `size` is odd; for
/// even `size` the entries coincide whenever `a` and `b` differ by exactly
/// `size / 2`. This is why adjacent transpositions of such pairs are not
/// detected for even sizes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DammTable {
    /// Row-major `size × size` table: `table[interim][digit]` is the next interim digit.
    table: Vec<Vec<u8>>,
    /// Alphabet size; equals `table.len()` and the length of every row.
    size: usize,
}
24
25impl DammTable {
26    /// Generate a quasi-group table for the given alphabet size.
27    ///
28    /// The generated table guarantees:
29    /// - 100% detection of single-character errors
30    /// - 100% detection of adjacent transposition errors
31    ///
32    /// # Arguments
33    /// * `size` - The alphabet size (10-62)
34    ///
35    /// # Examples
36    ///
37    /// ```
38    /// use promocrypt_core::DammTable;
39    ///
40    /// let table = DammTable::new(26);
41    /// assert_eq!(table.size(), 26);
42    /// ```
43    pub fn new(size: usize) -> Self {
44        assert!((10..=62).contains(&size), "Size must be between 10 and 62");
45
46        let table = Self::generate_quasi_group(size);
47
48        Self { table, size }
49    }
50
51    /// Generate a quasi-group with the required properties.
52    ///
53    /// Uses a construction that guarantees:
54    /// - Each row is a permutation of 0..n
55    /// - Each column is a permutation of 0..n
56    /// - Diagonal is all zeros: table\[i\]\[i\] = 0
57    #[allow(clippy::needless_range_loop)]
58    fn generate_quasi_group(n: usize) -> Vec<Vec<u8>> {
59        let mut table = vec![vec![0u8; n]; n];
60
61        // Use a proven construction for Latin squares with zero diagonal.
62        // Method: Start with standard addition table, then apply a permutation
63        // to make the diagonal all zeros.
64        //
65        // Standard Latin square: table[i][j] = (i + j) mod n
66        // Diagonal values: table[i][i] = 2i mod n
67        //
68        // To get zeros on diagonal, we apply a value permutation.
69        // If we define perm(x) = (x - 2*row_index) mod n, then:
70        // perm(table[i][j]) = (i + j - 2i) mod n = (j - i) mod n
71        // For diagonal: perm(2i mod n) = (i - i) mod n = 0
72
73        // However, applying different permutations per row breaks the Latin property.
74        // Instead, use a global transformation.
75
76        // Correct approach: Use table[i][j] = (i + j) mod n, then swap values
77        // within each row to place 0 on diagonal.
78
79        // Step 1: Create base Latin square
80        for i in 0..n {
81            for j in 0..n {
82                table[i][j] = ((i + j) % n) as u8;
83            }
84        }
85
86        // Step 2: For each row i, swap position of value 0 with diagonal position
87        // In row i, value 0 is at column (n - i) % n
88        // We want value 0 at column i (diagonal)
89        // So swap columns (n-i)%n and i in each row
90
91        // Actually, we need to be more careful. Swapping within rows preserves
92        // row permutation, but breaks column permutation.
93
94        // Better approach: Create a Latin square using a derangement
95        // table[i][j] = (i + j + 1) mod n  -- this gives diagonal = (2i+1) mod n, never 0 for n>1
96        // Then we swap the diagonal with where 0 appears
97
98        // Even better: Use the fact that for a Latin square L, if we apply
99        // a permutation π to rows and σ to columns, we get another Latin square.
100        // We want L'[i][i] = 0 for all i.
101
102        // Simplest correct approach that works:
103        // For row i, create permutation: start at some value, go sequentially,
104        // but ensure 0 is at position i.
105
106        // Let's use: row i has values arranged as:
107        // position j gets value ((j - i + n) % n) if j != i
108        // position i (diagonal) gets 0
109
110        for i in 0..n {
111            for j in 0..n {
112                if i == j {
113                    table[i][j] = 0;
114                } else {
115                    // Value at position j in row i
116                    // We want each row and column to be a permutation
117                    // Use: value = 1 + ((j - i - 1 + n) % (n-1)) for j > i
118                    //      value = 1 + ((j - i + n) % (n-1)) for j < i
119                    // This ensures 0 only on diagonal, and rest are 1..n-1
120
121                    let diff = ((j as isize - i as isize).rem_euclid(n as isize)) as usize;
122                    // diff is in range 1..n-1 for j != i (since we skip j==i case)
123                    table[i][j] = diff as u8;
124                }
125            }
126        }
127
128        // Verify: In row i, for j from 0 to n-1:
129        // - j=i: value is 0
130        // - j≠i: value is (j-i) mod n, which is 1,2,...,n-1 (excluding 0)
131        // So row i has exactly one 0 (at diagonal) and values 1..n-1 at other positions
132        // Wait, this means each row only has values 0..n-1 once? Let's check:
133        // For row 0: j=0->0, j=1->1, j=2->2, ..., j=n-1->n-1 ✓
134        // For row 1: j=0->n-1, j=1->0, j=2->1, j=3->2, ..., j=n-1->n-2 ✓
135
136        // For column j, value at row i is:
137        // - i=j: value is 0
138        // - i≠j: value is (j-i) mod n
139        // For column 0: i=0->0, i=1->n-1, i=2->n-2, ..., i=n-1->1
140        // This is 0, n-1, n-2, ..., 1 which is a permutation of 0..n-1 ✓
141
142        table
143    }
144
145    /// Get the size of the table (alphabet size).
146    #[inline]
147    pub fn size(&self) -> usize {
148        self.size
149    }
150
151    /// Get table value at position.
152    #[inline]
153    pub fn get(&self, row: usize, col: usize) -> u8 {
154        self.table[row][col]
155    }
156
157    /// Calculate check digit for a sequence of characters (check digit at end).
158    ///
159    /// # Arguments
160    /// * `chars` - The characters to calculate check digit for (without existing check digit)
161    /// * `alphabet` - The alphabet used for character→index mapping
162    ///
163    /// # Returns
164    /// The check digit character from the alphabet.
165    pub fn calculate(&self, chars: &[char], alphabet: &Alphabet) -> char {
166        let mut interim = 0usize;
167
168        for &c in chars {
169            if let Some(index) = alphabet.index_of(c) {
170                interim = self.table[interim][index] as usize;
171            }
172        }
173
174        alphabet.char_at(interim)
175    }
176
177    /// Calculate check digit for start position.
178    ///
179    /// Finds a check digit X such that processing [X, chars...] results in 0.
180    ///
181    /// # Arguments
182    /// * `chars` - The characters (check digit will be prepended)
183    /// * `alphabet` - The alphabet used for character→index mapping
184    ///
185    /// # Returns
186    /// The check digit character from the alphabet.
187    pub fn calculate_for_start(&self, chars: &[char], alphabet: &Alphabet) -> char {
188        // We need to find X such that process([X] + chars) = 0
189        // Try each possible check digit
190        for check_idx in 0..self.size {
191            let mut interim = self.table[0][check_idx] as usize;
192
193            for &c in chars {
194                if let Some(index) = alphabet.index_of(c) {
195                    interim = self.table[interim][index] as usize;
196                }
197            }
198
199            if interim == 0 {
200                return alphabet.char_at(check_idx);
201            }
202        }
203
204        // Fallback (should never happen with valid quasi-group)
205        alphabet.char_at(0)
206    }
207
208    /// Calculate check digit for arbitrary position in the code.
209    ///
210    /// Finds a check digit X such that inserting X at `position` in `chars`
211    /// results in a valid code (Damm validation returns 0).
212    ///
213    /// # Arguments
214    /// * `chars` - The characters (check digit will be inserted at position)
215    /// * `alphabet` - The alphabet used for character→index mapping
216    /// * `position` - Position where check digit will be inserted (0-based)
217    ///
218    /// # Returns
219    /// The check digit character from the alphabet.
220    pub fn calculate_for_position(
221        &self,
222        chars: &[char],
223        alphabet: &Alphabet,
224        position: usize,
225    ) -> char {
226        // We need to find X such that process(chars[0..position] + [X] + chars[position..]) = 0
227
228        // First, process chars up to the check digit position
229        let mut interim_before = 0usize;
230        for &c in chars.iter().take(position) {
231            if let Some(index) = alphabet.index_of(c) {
232                interim_before = self.table[interim_before][index] as usize;
233            }
234        }
235
236        // For each possible check digit, calculate what state we'd be in after it
237        // Then process the remaining chars and see if we end up at 0
238        for check_idx in 0..self.size {
239            let mut interim = self.table[interim_before][check_idx] as usize;
240
241            // Process remaining characters
242            for &c in chars.iter().skip(position) {
243                if let Some(index) = alphabet.index_of(c) {
244                    interim = self.table[interim][index] as usize;
245                }
246            }
247
248            if interim == 0 {
249                return alphabet.char_at(check_idx);
250            }
251        }
252
253        // Fallback (should never happen with valid quasi-group)
254        alphabet.char_at(0)
255    }
256
257    /// Calculate check digit for a string.
258    pub fn calculate_str(&self, s: &str, alphabet: &Alphabet) -> char {
259        let chars: Vec<char> = s.chars().collect();
260        self.calculate(&chars, alphabet)
261    }
262
263    /// Validate a code including its check digit.
264    ///
265    /// The check digit is valid if processing the entire code (including check digit)
266    /// results in 0.
267    ///
268    /// # Arguments
269    /// * `code` - The full code including check digit
270    /// * `alphabet` - The alphabet used
271    ///
272    /// # Returns
273    /// `true` if the check digit is valid, `false` otherwise.
274    pub fn validate(&self, code: &str, alphabet: &Alphabet) -> bool {
275        let mut interim = 0usize;
276
277        for c in code.chars() {
278            match alphabet.index_of(c) {
279                Some(index) => {
280                    interim = self.table[interim][index] as usize;
281                }
282                None => return false, // Invalid character
283            }
284        }
285
286        interim == 0
287    }
288
289    /// Serialize the table to bytes for storage.
290    pub fn to_bytes(&self) -> Vec<u8> {
291        let mut bytes = Vec::with_capacity(4 + self.size * self.size);
292
293        // Store size as u32
294        bytes.extend_from_slice(&(self.size as u32).to_le_bytes());
295
296        // Store table data row by row
297        for row in &self.table {
298            bytes.extend_from_slice(row);
299        }
300
301        bytes
302    }
303
304    /// Deserialize table from bytes.
305    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
306        if bytes.len() < 4 {
307            return Err(PromocryptError::InvalidFileFormatDetails(
308                "Damm table too short".to_string(),
309            ));
310        }
311
312        let size = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
313
314        if !(10..=62).contains(&size) {
315            return Err(PromocryptError::InvalidFileFormatDetails(format!(
316                "Invalid Damm table size: {}",
317                size
318            )));
319        }
320
321        let expected_len = 4 + size * size;
322        if bytes.len() < expected_len {
323            return Err(PromocryptError::InvalidFileFormatDetails(format!(
324                "Damm table data too short: expected {}, got {}",
325                expected_len,
326                bytes.len()
327            )));
328        }
329
330        let mut table = Vec::with_capacity(size);
331        let data = &bytes[4..];
332
333        for i in 0..size {
334            let row_start = i * size;
335            let row_end = row_start + size;
336            table.push(data[row_start..row_end].to_vec());
337        }
338
339        Ok(Self { table, size })
340    }
341
342    /// Serialize to base64 string for JSON storage.
343    pub fn to_base64(&self) -> String {
344        BASE64.encode(self.to_bytes())
345    }
346
347    /// Deserialize from base64 string.
348    pub fn from_base64(s: &str) -> Result<Self> {
349        let bytes = BASE64.decode(s).map_err(|e| {
350            PromocryptError::InvalidFileFormatDetails(format!("Invalid base64: {}", e))
351        })?;
352        Self::from_bytes(&bytes)
353    }
354}
355
356/// Custom serialization for DammTable using base64
357impl Serialize for DammTable {
358    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
359    where
360        S: serde::Serializer,
361    {
362        serializer.serialize_str(&self.to_base64())
363    }
364}
365
366impl<'de> Deserialize<'de> for DammTable {
367    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
368    where
369        D: serde::Deserializer<'de>,
370    {
371        let s = String::deserialize(deserializer)?;
372        DammTable::from_base64(&s).map_err(serde::de::Error::custom)
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Every alphabet size accepted by `DammTable::new`.
    const ALL_SIZES: std::ops::RangeInclusive<usize> = 10..=62;

    /// `0..size` as bytes: the sorted contents expected of every row and column.
    fn identity(size: usize) -> Vec<u8> {
        (0..size).map(|x| x as u8).collect()
    }

    #[test]
    fn test_new_table() {
        let table = DammTable::new(26);
        assert_eq!(table.size(), 26);
    }

    #[test]
    #[should_panic(expected = "Size must be between 10 and 62")]
    fn test_new_rejects_too_small_size() {
        let _ = DammTable::new(9);
    }

    #[test]
    #[should_panic(expected = "Size must be between 10 and 62")]
    fn test_new_rejects_too_large_size() {
        let _ = DammTable::new(63);
    }

    #[test]
    fn test_diagonal_zeros() {
        for size in ALL_SIZES {
            let table = DammTable::new(size);
            for i in 0..size {
                assert_eq!(
                    table.get(i, i),
                    0,
                    "size {}: diagonal at {} should be 0",
                    size,
                    i
                );
            }
        }
    }

    #[test]
    fn test_calculate_and_validate() {
        let alphabet = Alphabet::default_alphabet();
        let table = DammTable::new(alphabet.len());

        // Calculate check digit for "ABCD"
        let chars: Vec<char> = "ABCD".chars().collect();
        let check = table.calculate(&chars, &alphabet);

        // Validate full code
        let mut code = String::from("ABCD");
        code.push(check);

        assert!(table.validate(&code, &alphabet));
    }

    #[test]
    fn test_calculate_for_position_round_trip() {
        let alphabet = Alphabet::default_alphabet();
        let table = DammTable::new(alphabet.len());
        let chars: Vec<char> = "ABCDEFGH".chars().collect();

        // Every insertion point, including "past the end", must yield a valid code.
        for position in 0..=chars.len() + 1 {
            let check = table.calculate_for_position(&chars, &alphabet, position);
            let mut code = chars.clone();
            code.insert(position.min(chars.len()), check);
            let code: String = code.into_iter().collect();
            assert!(
                table.validate(&code, &alphabet),
                "position {}: {}",
                position,
                code
            );
        }

        assert_eq!(
            table.calculate_for_start(&chars, &alphabet),
            table.calculate_for_position(&chars, &alphabet, 0)
        );
        assert_eq!(
            table.calculate(&chars, &alphabet),
            table.calculate_for_position(&chars, &alphabet, chars.len())
        );
    }

    #[test]
    fn test_single_error_detection() {
        let alphabet = Alphabet::default_alphabet();
        let table = DammTable::new(alphabet.len());

        // Create valid code
        let chars: Vec<char> = "ABCDEFGH".chars().collect();
        let check = table.calculate(&chars, &alphabet);
        let mut valid_code = String::from("ABCDEFGH");
        valid_code.push(check);

        assert!(table.validate(&valid_code, &alphabet));

        // Modify one character - should be detected
        let mut modified = valid_code.chars().collect::<Vec<_>>();
        if modified[4] == 'E' {
            modified[4] = 'F';
        } else {
            modified[4] = 'E';
        }
        let modified_code: String = modified.iter().collect();

        assert!(!table.validate(&modified_code, &alphabet));
    }

    #[test]
    fn test_adjacent_transposition_detection() {
        // Swapping adjacent digits a, b is detected unless the size is even and
        // the two indices differ by exactly size / 2 (the construction's blind spot).
        for size in ALL_SIZES {
            let table = DammTable::new(size);
            for c in 0..size {
                for a in 0..size {
                    for b in 0..size {
                        if a == b {
                            continue;
                        }
                        let ab = table.get(table.get(c, a) as usize, b);
                        let ba = table.get(table.get(c, b) as usize, a);
                        let blind_spot = size % 2 == 0 && (a + size - b) % size == size / 2;
                        assert_eq!(
                            ab == ba,
                            blind_spot,
                            "size {}, interim {}, digits {} and {}",
                            size,
                            c,
                            a,
                            b
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_serialization() {
        let table = DammTable::new(26);
        let bytes = table.to_bytes();
        let restored = DammTable::from_bytes(&bytes).unwrap();
        assert_eq!(table, restored);
    }

    #[test]
    fn test_from_bytes_rejects_bad_header() {
        assert!(DammTable::from_bytes(&[1, 2]).is_err());
        assert!(DammTable::from_bytes(&9u32.to_le_bytes()).is_err());
        assert!(DammTable::from_bytes(&63u32.to_le_bytes()).is_err());

        let mut truncated = DammTable::new(10).to_bytes();
        truncated.pop();
        assert!(DammTable::from_bytes(&truncated).is_err());
    }

    #[test]
    fn test_base64_serialization() {
        let table = DammTable::new(26);
        let b64 = table.to_base64();
        let restored = DammTable::from_base64(&b64).unwrap();
        assert_eq!(table, restored);
    }

    #[test]
    fn test_serde() {
        let table = DammTable::new(26);
        let json = serde_json::to_string(&table).unwrap();
        let restored: DammTable = serde_json::from_str(&json).unwrap();
        assert_eq!(table, restored);
    }

    #[test]
    fn test_various_sizes() {
        for size in ALL_SIZES {
            let table = DammTable::new(size);
            assert_eq!(table.size(), size);
        }
    }

    #[test]
    fn test_row_permutation() {
        for size in ALL_SIZES {
            let table = DammTable::new(size);
            let expected = identity(size);
            for i in 0..size {
                let mut values: Vec<u8> = (0..size).map(|j| table.get(i, j)).collect();
                values.sort_unstable();
                assert_eq!(
                    values, expected,
                    "size {}: row {} is not a permutation of 0..{}",
                    size, i, size
                );
            }
        }
    }

    #[test]
    fn test_column_permutation() {
        for size in ALL_SIZES {
            let table = DammTable::new(size);
            let expected = identity(size);
            for j in 0..size {
                let mut values: Vec<u8> = (0..size).map(|i| table.get(i, j)).collect();
                values.sort_unstable();
                assert_eq!(
                    values, expected,
                    "size {}: column {} is not a permutation of 0..{}",
                    size, j, size
                );
            }
        }
    }
}