promocrypt_core/damm.rs
1//! Damm check digit algorithm.
2//!
3//! Implements the Damm algorithm for generating and validating check digits.
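//! With a totally anti-symmetric quasi-group, the Damm algorithm detects all
//! single-digit errors and all adjacent transposition errors. The tables generated
//! in this module detect all single-character errors; for even alphabet sizes they
//! miss an adjacent transposition of two characters whose alphabet indices differ
//! by exactly half the alphabet size.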
5
6use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
7use serde::{Deserialize, Serialize};
8
9use crate::alphabet::Alphabet;
10use crate::error::{PromocryptError, Result};
11
/// Damm quasi-group table for check digit calculation.
///
/// Tables built by `DammTable::new` use `table[a][b] = (b - a) mod size` and
/// therefore have the following properties:
/// - Each row is a permutation of 0..size
/// - Each column is a permutation of 0..size
/// - Diagonal is all zeros: `table[i][i] = 0`
/// - `table[a][b] != table[b][a]` when `a != b`, except when `size` is even and
///   `a` and `b` differ by exactly `size / 2` (both entries are then `size / 2`)
///
/// Equality compares both the table contents and the size.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DammTable {
    /// Lookup table indexed as `table[interim][char_index]`; the entry is the
    /// next interim value.
    table: Vec<Vec<u8>>,
    /// Alphabet size the table was built for (number of rows and of columns).
    size: usize,
}
24
25impl DammTable {
26 /// Generate a quasi-group table for the given alphabet size.
27 ///
28 /// The generated table guarantees:
29 /// - 100% detection of single-character errors
30 /// - 100% detection of adjacent transposition errors
31 ///
32 /// # Arguments
33 /// * `size` - The alphabet size (10-62)
34 ///
35 /// # Examples
36 ///
37 /// ```
38 /// use promocrypt_core::DammTable;
39 ///
40 /// let table = DammTable::new(26);
41 /// assert_eq!(table.size(), 26);
42 /// ```
43 pub fn new(size: usize) -> Self {
44 assert!((10..=62).contains(&size), "Size must be between 10 and 62");
45
46 let table = Self::generate_quasi_group(size);
47
48 Self { table, size }
49 }
50
51 /// Generate a quasi-group with the required properties.
52 ///
53 /// Uses a construction that guarantees:
54 /// - Each row is a permutation of 0..n
55 /// - Each column is a permutation of 0..n
56 /// - Diagonal is all zeros: table\[i\]\[i\] = 0
57 #[allow(clippy::needless_range_loop)]
58 fn generate_quasi_group(n: usize) -> Vec<Vec<u8>> {
59 let mut table = vec![vec![0u8; n]; n];
60
61 // Use a proven construction for Latin squares with zero diagonal.
62 // Method: Start with standard addition table, then apply a permutation
63 // to make the diagonal all zeros.
64 //
65 // Standard Latin square: table[i][j] = (i + j) mod n
66 // Diagonal values: table[i][i] = 2i mod n
67 //
68 // To get zeros on diagonal, we apply a value permutation.
69 // If we define perm(x) = (x - 2*row_index) mod n, then:
70 // perm(table[i][j]) = (i + j - 2i) mod n = (j - i) mod n
71 // For diagonal: perm(2i mod n) = (i - i) mod n = 0
72
73 // However, applying different permutations per row breaks the Latin property.
74 // Instead, use a global transformation.
75
76 // Correct approach: Use table[i][j] = (i + j) mod n, then swap values
77 // within each row to place 0 on diagonal.
78
79 // Step 1: Create base Latin square
80 for i in 0..n {
81 for j in 0..n {
82 table[i][j] = ((i + j) % n) as u8;
83 }
84 }
85
86 // Step 2: For each row i, swap position of value 0 with diagonal position
87 // In row i, value 0 is at column (n - i) % n
88 // We want value 0 at column i (diagonal)
89 // So swap columns (n-i)%n and i in each row
90
91 // Actually, we need to be more careful. Swapping within rows preserves
92 // row permutation, but breaks column permutation.
93
94 // Better approach: Create a Latin square using a derangement
95 // table[i][j] = (i + j + 1) mod n -- this gives diagonal = (2i+1) mod n, never 0 for n>1
96 // Then we swap the diagonal with where 0 appears
97
98 // Even better: Use the fact that for a Latin square L, if we apply
99 // a permutation π to rows and σ to columns, we get another Latin square.
100 // We want L'[i][i] = 0 for all i.
101
102 // Simplest correct approach that works:
103 // For row i, create permutation: start at some value, go sequentially,
104 // but ensure 0 is at position i.
105
106 // Let's use: row i has values arranged as:
107 // position j gets value ((j - i + n) % n) if j != i
108 // position i (diagonal) gets 0
109
110 for i in 0..n {
111 for j in 0..n {
112 if i == j {
113 table[i][j] = 0;
114 } else {
115 // Value at position j in row i
116 // We want each row and column to be a permutation
117 // Use: value = 1 + ((j - i - 1 + n) % (n-1)) for j > i
118 // value = 1 + ((j - i + n) % (n-1)) for j < i
119 // This ensures 0 only on diagonal, and rest are 1..n-1
120
121 let diff = ((j as isize - i as isize).rem_euclid(n as isize)) as usize;
122 // diff is in range 1..n-1 for j != i (since we skip j==i case)
123 table[i][j] = diff as u8;
124 }
125 }
126 }
127
128 // Verify: In row i, for j from 0 to n-1:
129 // - j=i: value is 0
130 // - j≠i: value is (j-i) mod n, which is 1,2,...,n-1 (excluding 0)
131 // So row i has exactly one 0 (at diagonal) and values 1..n-1 at other positions
132 // Wait, this means each row only has values 0..n-1 once? Let's check:
133 // For row 0: j=0->0, j=1->1, j=2->2, ..., j=n-1->n-1 ✓
134 // For row 1: j=0->n-1, j=1->0, j=2->1, j=3->2, ..., j=n-1->n-2 ✓
135
136 // For column j, value at row i is:
137 // - i=j: value is 0
138 // - i≠j: value is (j-i) mod n
139 // For column 0: i=0->0, i=1->n-1, i=2->n-2, ..., i=n-1->1
140 // This is 0, n-1, n-2, ..., 1 which is a permutation of 0..n-1 ✓
141
142 table
143 }
144
145 /// Get the size of the table (alphabet size).
146 #[inline]
147 pub fn size(&self) -> usize {
148 self.size
149 }
150
151 /// Get table value at position.
152 #[inline]
153 pub fn get(&self, row: usize, col: usize) -> u8 {
154 self.table[row][col]
155 }
156
157 /// Calculate check digit for a sequence of characters (check digit at end).
158 ///
159 /// # Arguments
160 /// * `chars` - The characters to calculate check digit for (without existing check digit)
161 /// * `alphabet` - The alphabet used for character→index mapping
162 ///
163 /// # Returns
164 /// The check digit character from the alphabet.
165 pub fn calculate(&self, chars: &[char], alphabet: &Alphabet) -> char {
166 let mut interim = 0usize;
167
168 for &c in chars {
169 if let Some(index) = alphabet.index_of(c) {
170 interim = self.table[interim][index] as usize;
171 }
172 }
173
174 alphabet.char_at(interim)
175 }
176
177 /// Calculate check digit for start position.
178 ///
179 /// Finds a check digit X such that processing [X, chars...] results in 0.
180 ///
181 /// # Arguments
182 /// * `chars` - The characters (check digit will be prepended)
183 /// * `alphabet` - The alphabet used for character→index mapping
184 ///
185 /// # Returns
186 /// The check digit character from the alphabet.
187 pub fn calculate_for_start(&self, chars: &[char], alphabet: &Alphabet) -> char {
188 // We need to find X such that process([X] + chars) = 0
189 // Try each possible check digit
190 for check_idx in 0..self.size {
191 let mut interim = self.table[0][check_idx] as usize;
192
193 for &c in chars {
194 if let Some(index) = alphabet.index_of(c) {
195 interim = self.table[interim][index] as usize;
196 }
197 }
198
199 if interim == 0 {
200 return alphabet.char_at(check_idx);
201 }
202 }
203
204 // Fallback (should never happen with valid quasi-group)
205 alphabet.char_at(0)
206 }
207
208 /// Calculate check digit for arbitrary position in the code.
209 ///
210 /// Finds a check digit X such that inserting X at `position` in `chars`
211 /// results in a valid code (Damm validation returns 0).
212 ///
213 /// # Arguments
214 /// * `chars` - The characters (check digit will be inserted at position)
215 /// * `alphabet` - The alphabet used for character→index mapping
216 /// * `position` - Position where check digit will be inserted (0-based)
217 ///
218 /// # Returns
219 /// The check digit character from the alphabet.
220 pub fn calculate_for_position(
221 &self,
222 chars: &[char],
223 alphabet: &Alphabet,
224 position: usize,
225 ) -> char {
226 // We need to find X such that process(chars[0..position] + [X] + chars[position..]) = 0
227
228 // First, process chars up to the check digit position
229 let mut interim_before = 0usize;
230 for &c in chars.iter().take(position) {
231 if let Some(index) = alphabet.index_of(c) {
232 interim_before = self.table[interim_before][index] as usize;
233 }
234 }
235
236 // For each possible check digit, calculate what state we'd be in after it
237 // Then process the remaining chars and see if we end up at 0
238 for check_idx in 0..self.size {
239 let mut interim = self.table[interim_before][check_idx] as usize;
240
241 // Process remaining characters
242 for &c in chars.iter().skip(position) {
243 if let Some(index) = alphabet.index_of(c) {
244 interim = self.table[interim][index] as usize;
245 }
246 }
247
248 if interim == 0 {
249 return alphabet.char_at(check_idx);
250 }
251 }
252
253 // Fallback (should never happen with valid quasi-group)
254 alphabet.char_at(0)
255 }
256
257 /// Calculate check digit for a string.
258 pub fn calculate_str(&self, s: &str, alphabet: &Alphabet) -> char {
259 let chars: Vec<char> = s.chars().collect();
260 self.calculate(&chars, alphabet)
261 }
262
263 /// Validate a code including its check digit.
264 ///
265 /// The check digit is valid if processing the entire code (including check digit)
266 /// results in 0.
267 ///
268 /// # Arguments
269 /// * `code` - The full code including check digit
270 /// * `alphabet` - The alphabet used
271 ///
272 /// # Returns
273 /// `true` if the check digit is valid, `false` otherwise.
274 pub fn validate(&self, code: &str, alphabet: &Alphabet) -> bool {
275 let mut interim = 0usize;
276
277 for c in code.chars() {
278 match alphabet.index_of(c) {
279 Some(index) => {
280 interim = self.table[interim][index] as usize;
281 }
282 None => return false, // Invalid character
283 }
284 }
285
286 interim == 0
287 }
288
289 /// Serialize the table to bytes for storage.
290 pub fn to_bytes(&self) -> Vec<u8> {
291 let mut bytes = Vec::with_capacity(4 + self.size * self.size);
292
293 // Store size as u32
294 bytes.extend_from_slice(&(self.size as u32).to_le_bytes());
295
296 // Store table data row by row
297 for row in &self.table {
298 bytes.extend_from_slice(row);
299 }
300
301 bytes
302 }
303
304 /// Deserialize table from bytes.
305 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
306 if bytes.len() < 4 {
307 return Err(PromocryptError::InvalidFileFormatDetails(
308 "Damm table too short".to_string(),
309 ));
310 }
311
312 let size = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
313
314 if !(10..=62).contains(&size) {
315 return Err(PromocryptError::InvalidFileFormatDetails(format!(
316 "Invalid Damm table size: {}",
317 size
318 )));
319 }
320
321 let expected_len = 4 + size * size;
322 if bytes.len() < expected_len {
323 return Err(PromocryptError::InvalidFileFormatDetails(format!(
324 "Damm table data too short: expected {}, got {}",
325 expected_len,
326 bytes.len()
327 )));
328 }
329
330 let mut table = Vec::with_capacity(size);
331 let data = &bytes[4..];
332
333 for i in 0..size {
334 let row_start = i * size;
335 let row_end = row_start + size;
336 table.push(data[row_start..row_end].to_vec());
337 }
338
339 Ok(Self { table, size })
340 }
341
342 /// Serialize to base64 string for JSON storage.
343 pub fn to_base64(&self) -> String {
344 BASE64.encode(self.to_bytes())
345 }
346
347 /// Deserialize from base64 string.
348 pub fn from_base64(s: &str) -> Result<Self> {
349 let bytes = BASE64.decode(s).map_err(|e| {
350 PromocryptError::InvalidFileFormatDetails(format!("Invalid base64: {}", e))
351 })?;
352 Self::from_bytes(&bytes)
353 }
354}
355
356/// Custom serialization for DammTable using base64
357impl Serialize for DammTable {
358 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
359 where
360 S: serde::Serializer,
361 {
362 serializer.serialize_str(&self.to_base64())
363 }
364}
365
366impl<'de> Deserialize<'de> for DammTable {
367 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
368 where
369 D: serde::Deserializer<'de>,
370 {
371 let s = String::deserialize(deserializer)?;
372 DammTable::from_base64(&s).map_err(serde::de::Error::custom)
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload used by the check-digit tests; the original tests already rely
    /// on these characters being part of the default alphabet.
    const PAYLOAD: &str = "ABCDEFGH";

    /// The sorted contents every row and column of a `size`-sized table must have.
    fn identity(size: usize) -> Vec<u8> {
        (0..size).map(|x| x as u8).collect()
    }

    #[test]
    fn test_new_table() {
        let table = DammTable::new(26);
        assert_eq!(table.size(), 26);
    }

    #[test]
    fn test_diagonal_zeros() {
        let table = DammTable::new(26);
        for i in 0..26 {
            assert_eq!(table.get(i, i), 0, "Diagonal at {} should be 0", i);
        }
    }

    #[test]
    fn test_calculate_and_validate() {
        let alphabet = Alphabet::default_alphabet();
        let table = DammTable::new(alphabet.len());

        // Calculate check digit for "ABCD"
        let chars: Vec<char> = "ABCD".chars().collect();
        let check = table.calculate(&chars, &alphabet);

        // Validate full code
        let mut code = String::from("ABCD");
        code.push(check);

        assert!(table.validate(&code, &alphabet));
    }

    #[test]
    fn test_calculate_str_matches_calculate() {
        let alphabet = Alphabet::default_alphabet();
        let table = DammTable::new(alphabet.len());

        let chars: Vec<char> = PAYLOAD.chars().collect();
        assert_eq!(
            table.calculate_str(PAYLOAD, &alphabet),
            table.calculate(&chars, &alphabet)
        );
    }

    #[test]
    fn test_calculate_for_start() {
        let alphabet = Alphabet::default_alphabet();
        let table = DammTable::new(alphabet.len());

        let chars: Vec<char> = PAYLOAD.chars().collect();
        let check = table.calculate_for_start(&chars, &alphabet);

        let mut code = String::new();
        code.push(check);
        code.push_str(PAYLOAD);

        assert!(table.validate(&code, &alphabet));
        assert_eq!(check, table.calculate_for_position(&chars, &alphabet, 0));
    }

    #[test]
    fn test_calculate_for_position() {
        let alphabet = Alphabet::default_alphabet();
        let table = DammTable::new(alphabet.len());

        let chars: Vec<char> = PAYLOAD.chars().collect();
        for position in 0..=chars.len() {
            let check = table.calculate_for_position(&chars, &alphabet, position);

            let mut with_check = chars.clone();
            with_check.insert(position, check);
            let code: String = with_check.into_iter().collect();

            assert!(
                table.validate(&code, &alphabet),
                "Check digit inserted at {} should validate",
                position
            );
        }

        // Inserting at the end is the same as the plain end-position check digit.
        assert_eq!(
            table.calculate_for_position(&chars, &alphabet, chars.len()),
            table.calculate(&chars, &alphabet)
        );
    }

    #[test]
    fn test_single_error_detection() {
        let alphabet = Alphabet::default_alphabet();
        let table = DammTable::new(alphabet.len());

        // Create valid code
        let chars: Vec<char> = PAYLOAD.chars().collect();
        let check = table.calculate(&chars, &alphabet);
        let mut valid_code = String::from(PAYLOAD);
        valid_code.push(check);

        assert!(table.validate(&valid_code, &alphabet));

        // Substitute every position with every other alphabet character;
        // each single-character change must be detected.
        let original: Vec<char> = valid_code.chars().collect();
        for position in 0..original.len() {
            for index in 0..alphabet.len() {
                let replacement = alphabet.char_at(index);
                if replacement == original[position] {
                    continue;
                }

                let mut modified = original.clone();
                modified[position] = replacement;
                let modified_code: String = modified.iter().collect();

                assert!(
                    !table.validate(&modified_code, &alphabet),
                    "Substitution at {} went undetected: {}",
                    position,
                    modified_code
                );
            }
        }
    }

    #[test]
    fn test_serialization() {
        let table = DammTable::new(26);
        let bytes = table.to_bytes();
        assert_eq!(bytes.len(), 4 + 26 * 26);
        let restored = DammTable::from_bytes(&bytes).unwrap();
        assert_eq!(table, restored);
    }

    #[test]
    fn test_from_bytes_rejects_malformed_input() {
        // Header shorter than the 4-byte size prefix.
        assert!(DammTable::from_bytes(&[]).is_err());
        assert!(DammTable::from_bytes(&[26, 0, 0]).is_err());

        // Size outside the supported 10..=62 range.
        assert!(DammTable::from_bytes(&[9, 0, 0, 0]).is_err());
        assert!(DammTable::from_bytes(&[63, 0, 0, 0]).is_err());

        // Valid header but truncated table data.
        let bytes = DammTable::new(26).to_bytes();
        assert!(DammTable::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }

    #[test]
    fn test_base64_serialization() {
        let table = DammTable::new(26);
        let b64 = table.to_base64();
        let restored = DammTable::from_base64(&b64).unwrap();
        assert_eq!(table, restored);
    }

    #[test]
    fn test_from_base64_rejects_invalid_input() {
        assert!(DammTable::from_base64("!!!").is_err());
    }

    #[test]
    fn test_serde() {
        let table = DammTable::new(26);
        let json = serde_json::to_string(&table).unwrap();
        let restored: DammTable = serde_json::from_str(&json).unwrap();
        assert_eq!(table, restored);
    }

    #[test]
    fn test_various_sizes() {
        for size in [10, 16, 26, 36, 62] {
            let table = DammTable::new(size);
            assert_eq!(table.size(), size);

            // Verify diagonal
            for i in 0..size {
                assert_eq!(table.get(i, i), 0);
            }
        }
    }

    #[test]
    fn test_row_permutation() {
        for size in 10..=62 {
            let table = DammTable::new(size);
            for i in 0..size {
                let mut values: Vec<u8> = (0..size).map(|j| table.get(i, j)).collect();
                values.sort_unstable();
                assert_eq!(
                    values,
                    identity(size),
                    "Row {} is not a permutation of 0..{}",
                    i,
                    size
                );
            }
        }
    }

    #[test]
    fn test_column_permutation() {
        for size in 10..=62 {
            let table = DammTable::new(size);
            for j in 0..size {
                let mut values: Vec<u8> = (0..size).map(|i| table.get(i, j)).collect();
                values.sort_unstable();
                assert_eq!(
                    values,
                    identity(size),
                    "Column {} is not a permutation of 0..{}",
                    j,
                    size
                );
            }
        }
    }
}