1use ring::hmac;
7use serde::{Deserialize, Serialize};
8
9use crate::alphabet::Alphabet;
10use crate::damm::DammTable;
11
12#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(untagged)]
39pub enum CheckPosition {
40 #[serde(rename = "start")]
42 Start,
43 #[serde(rename = "end")]
45 #[default]
46 End,
47 Index(i8),
49}
50
51impl CheckPosition {
52 pub fn new(index: i8) -> Self {
54 match index {
55 0 => CheckPosition::Start,
56 -1 => CheckPosition::End,
57 n => CheckPosition::Index(n),
58 }
59 }
60
61 pub fn to_index(&self, total_length: usize) -> usize {
63 match self {
64 CheckPosition::Start => 0,
65 CheckPosition::End => total_length.saturating_sub(1),
66 CheckPosition::Index(idx) => {
67 if *idx >= 0 {
68 (*idx as usize).min(total_length.saturating_sub(1))
69 } else {
70 let from_end = (-*idx) as usize;
71 total_length.saturating_sub(from_end)
72 }
73 }
74 }
75 }
76
77 pub fn raw(&self) -> i8 {
79 match self {
80 CheckPosition::Start => 0,
81 CheckPosition::End => -1,
82 CheckPosition::Index(n) => *n,
83 }
84 }
85
86 pub fn parse_str(s: &str) -> Option<Self> {
88 match s.to_lowercase().as_str() {
89 "start" | "beginning" | "front" | "0" => Some(CheckPosition::Start),
90 "end" | "back" | "tail" | "-1" => Some(CheckPosition::End),
91 _ => {
92 s.parse::<i8>().ok().map(CheckPosition::new)
94 }
95 }
96 }
97}
98
99impl std::str::FromStr for CheckPosition {
100 type Err = String;
101
102 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
103 Self::parse_str(s).ok_or_else(|| format!("Invalid check position: {}", s))
104 }
105}
106
107impl std::fmt::Display for CheckPosition {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 match self {
110 CheckPosition::Start => write!(f, "start"),
111 CheckPosition::End => write!(f, "end"),
112 CheckPosition::Index(n) => write!(f, "{}", n),
113 }
114 }
115}
116
/// Presentation rules layered over a raw generated code: an optional prefix and suffix, plus an
/// optional separator character inserted at fixed character positions.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct CodeFormat {
    /// Text prepended to the code; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefix: Option<String>,
    /// Text appended to the code; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Character inserted at each entry of `separator_positions`; `None` means `format` inserts
    /// no separators at all.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub separator: Option<char>,
    /// Character indices into the unformatted code before which `separator` is inserted.
    /// Defaults to empty when absent from input and is omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub separator_positions: Vec<usize>,
}
145
146impl CodeFormat {
147 pub fn new() -> Self {
149 Self::default()
150 }
151
152 pub fn with_prefix(mut self, prefix: &str) -> Self {
154 self.prefix = Some(prefix.to_string());
155 self
156 }
157
158 pub fn with_suffix(mut self, suffix: &str) -> Self {
160 self.suffix = Some(suffix.to_string());
161 self
162 }
163
164 pub fn with_separator(mut self, sep: char, positions: Vec<usize>) -> Self {
166 self.separator = Some(sep);
167 self.separator_positions = positions;
168 self
169 }
170
171 pub fn format(&self, base_code: &str) -> String {
173 let chars: Vec<char> = base_code.chars().collect();
174 let sep_count = self
175 .separator_positions
176 .iter()
177 .filter(|&&p| p > 0 && p < chars.len())
178 .count();
179 let prefix_len = self.prefix.as_ref().map(|p| p.len()).unwrap_or(0);
180 let suffix_len = self.suffix.as_ref().map(|s| s.len()).unwrap_or(0);
181
182 let mut result = String::with_capacity(prefix_len + chars.len() + sep_count + suffix_len);
183
184 if let Some(ref prefix) = self.prefix {
186 result.push_str(prefix);
187 }
188
189 if let Some(sep) = self.separator {
191 for (i, c) in chars.iter().enumerate() {
192 if i > 0 && self.separator_positions.contains(&i) {
193 result.push(sep);
194 }
195 result.push(*c);
196 }
197 } else {
198 result.push_str(base_code);
199 }
200
201 if let Some(ref suffix) = self.suffix {
203 result.push_str(suffix);
204 }
205
206 result
207 }
208
209 pub fn strip(&self, formatted_code: &str) -> Option<String> {
211 let mut code = formatted_code.to_string();
212
213 if let Some(ref prefix) = self.prefix {
215 code = code.strip_prefix(prefix)?.to_string();
216 }
217
218 if let Some(ref suffix) = self.suffix {
220 code = code.strip_suffix(suffix)?.to_string();
221 }
222
223 if let Some(sep) = self.separator {
225 code = code.chars().filter(|&c| c != sep).collect();
226 }
227
228 Some(code)
229 }
230
231 pub fn total_length(&self, base_length: usize) -> usize {
233 let prefix_len = self.prefix.as_ref().map(|p| p.len()).unwrap_or(0);
234 let suffix_len = self.suffix.as_ref().map(|s| s.len()).unwrap_or(0);
235 let sep_count = self
236 .separator_positions
237 .iter()
238 .filter(|&&p| p > 0 && p < base_length)
239 .count();
240
241 prefix_len + base_length + sep_count + suffix_len
242 }
243
244 pub fn has_formatting(&self) -> bool {
246 self.prefix.is_some() || self.suffix.is_some() || self.separator.is_some()
247 }
248}
249
250pub fn generate_code(
276 secret_key: &[u8; 32],
277 counter: u64,
278 alphabet: &Alphabet,
279 code_length: usize,
280 check_position: CheckPosition,
281 damm_table: &DammTable,
282) -> String {
283 let key = hmac::Key::new(hmac::HMAC_SHA256, secret_key);
285 let signature = hmac::sign(&key, &counter.to_le_bytes());
286
287 let hash_bytes: [u8; 8] = signature.as_ref()[0..8].try_into().unwrap();
289 let mut value = u64::from_le_bytes(hash_bytes);
290
291 let base = alphabet.len() as u64;
293 let mut chars = Vec::with_capacity(code_length);
294
295 for _ in 0..code_length {
296 let index = (value % base) as usize;
297 chars.push(alphabet.char_at(index));
298 value /= base;
299 }
300
301 let total_length = code_length + 1;
303 let check_index = check_position.to_index(total_length);
304
305 let check = if check_index == 0 {
308 damm_table.calculate_for_start(&chars, alphabet)
310 } else if check_index >= code_length {
311 damm_table.calculate(&chars, alphabet)
313 } else {
314 damm_table.calculate_for_position(&chars, alphabet, check_index)
317 };
318
319 let mut result = String::with_capacity(total_length);
321
322 for i in 0..total_length {
323 if i == check_index {
324 result.push(check);
325 } else {
326 let char_idx = if i < check_index { i } else { i - 1 };
328 if char_idx < chars.len() {
329 result.push(chars[char_idx]);
330 }
331 }
332 }
333
334 result
335}
336
337pub fn generate_batch(
364 secret_key: &[u8; 32],
365 start_counter: u64,
366 count: usize,
367 alphabet: &Alphabet,
368 code_length: usize,
369 check_position: CheckPosition,
370 damm_table: &DammTable,
371) -> Vec<String> {
372 (0..count)
373 .map(|i| {
374 generate_code(
375 secret_key,
376 start_counter + i as u64,
377 alphabet,
378 code_length,
379 check_position,
380 damm_table,
381 )
382 })
383 .collect()
384}
385
386#[allow(clippy::too_many_arguments)]
391pub fn generate_batch_into(
392 secret_key: &[u8; 32],
393 start_counter: u64,
394 count: usize,
395 alphabet: &Alphabet,
396 code_length: usize,
397 check_position: CheckPosition,
398 damm_table: &DammTable,
399 output: &mut Vec<String>,
400) {
401 output.clear();
402 output.reserve(count);
403
404 for i in 0..count {
405 output.push(generate_code(
406 secret_key,
407 start_counter + i as u64,
408 alphabet,
409 code_length,
410 check_position,
411 damm_table,
412 ));
413 }
414}
415
/// Lazy iterator over codes for a contiguous counter range; each item is one
/// [`generate_code`] call.
pub struct CodeGenerator<'a> {
    /// HMAC key used for every generated code.
    secret_key: &'a [u8; 32],
    alphabet: &'a Alphabet,
    /// Number of non-check characters per code.
    code_length: usize,
    check_position: CheckPosition,
    damm_table: &'a DammTable,
    /// Counter for the next code to be yielded.
    current_counter: u64,
    /// Exclusive upper bound of the counter range.
    end_counter: u64,
}
428
429impl<'a> CodeGenerator<'a> {
430 pub fn new(
432 secret_key: &'a [u8; 32],
433 start_counter: u64,
434 count: usize,
435 alphabet: &'a Alphabet,
436 code_length: usize,
437 check_position: CheckPosition,
438 damm_table: &'a DammTable,
439 ) -> Self {
440 Self {
441 secret_key,
442 alphabet,
443 code_length,
444 check_position,
445 damm_table,
446 current_counter: start_counter,
447 end_counter: start_counter + count as u64,
448 }
449 }
450}
451
452impl Iterator for CodeGenerator<'_> {
453 type Item = String;
454
455 fn next(&mut self) -> Option<Self::Item> {
456 if self.current_counter >= self.end_counter {
457 return None;
458 }
459
460 let code = generate_code(
461 self.secret_key,
462 self.current_counter,
463 self.alphabet,
464 self.code_length,
465 self.check_position,
466 self.damm_table,
467 );
468
469 self.current_counter += 1;
470 Some(code)
471 }
472
473 fn size_hint(&self) -> (usize, Option<usize>) {
474 let remaining = (self.end_counter - self.current_counter) as usize;
475 (remaining, Some(remaining))
476 }
477}
478
479impl ExactSizeIterator for CodeGenerator<'_> {}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Default alphabet, its Damm table, and a fixed secret shared by the tests.
    fn setup() -> (Alphabet, DammTable, [u8; 32]) {
        let alphabet = Alphabet::default_alphabet();
        let damm = DammTable::new(alphabet.len());
        let secret = [42u8; 32];
        (alphabet, damm, secret)
    }

    #[test]
    fn test_generate_code_length() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(code.len(), 10);
    }

    #[test]
    fn test_generate_code_deterministic() {
        let (alphabet, damm, secret) = setup();

        let code1 = generate_code(&secret, 12345, &alphabet, 9, CheckPosition::End, &damm);
        let code2 = generate_code(&secret, 12345, &alphabet, 9, CheckPosition::End, &damm);

        assert_eq!(code1, code2);
    }

    #[test]
    fn test_different_counters_different_codes() {
        let (alphabet, damm, secret) = setup();

        let code1 = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);
        let code2 = generate_code(&secret, 1, &alphabet, 9, CheckPosition::End, &damm);

        assert_ne!(code1, code2);
    }

    #[test]
    fn test_different_secrets_different_codes() {
        let (alphabet, damm, _) = setup();

        let secret1 = [1u8; 32];
        let secret2 = [2u8; 32];

        let code1 = generate_code(&secret1, 0, &alphabet, 9, CheckPosition::End, &damm);
        let code2 = generate_code(&secret2, 0, &alphabet, 9, CheckPosition::End, &damm);

        assert_ne!(code1, code2);
    }

    #[test]
    fn test_check_position_start() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::Start, &damm);
        assert_eq!(code.len(), 10);

        assert!(damm.validate(&code, &alphabet));
    }

    #[test]
    fn test_check_position_end() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(code.len(), 10);

        assert!(damm.validate(&code, &alphabet));
    }

    #[test]
    fn test_check_position_index_length() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::Index(4), &damm);
        assert_eq!(code.chars().count(), 10);
    }

    #[test]
    fn test_all_characters_in_alphabet() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);

        for c in code.chars() {
            assert!(alphabet.contains(c), "Character '{}' not in alphabet", c);
        }
    }

    #[test]
    fn test_generate_batch() {
        let (alphabet, damm, secret) = setup();

        let codes = generate_batch(&secret, 0, 100, &alphabet, 9, CheckPosition::End, &damm);

        assert_eq!(codes.len(), 100);

        let unique: std::collections::HashSet<_> = codes.iter().collect();
        assert_eq!(unique.len(), 100);

        for code in &codes {
            assert!(damm.validate(code, &alphabet));
        }
    }

    #[test]
    fn test_generate_batch_into_matches_batch() {
        let (alphabet, damm, secret) = setup();

        // Pre-existing contents must be discarded.
        let mut out = vec!["stale".to_string()];
        generate_batch_into(&secret, 7, 5, &alphabet, 9, CheckPosition::End, &damm, &mut out);
        let batch = generate_batch(&secret, 7, 5, &alphabet, 9, CheckPosition::End, &damm);

        assert_eq!(out, batch);
    }

    #[test]
    fn test_code_generator_iterator() {
        let (alphabet, damm, secret) = setup();

        let code_gen = CodeGenerator::new(&secret, 0, 10, &alphabet, 9, CheckPosition::End, &damm);

        let codes: Vec<_> = code_gen.collect();
        assert_eq!(codes.len(), 10);

        let batch = generate_batch(&secret, 0, 10, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(codes, batch);
    }

    #[test]
    fn test_code_generator_exact_size() {
        let (alphabet, damm, secret) = setup();

        let code_gen = CodeGenerator::new(&secret, 0, 50, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(code_gen.len(), 50);
    }

    #[test]
    fn test_code_generator_len_shrinks() {
        let (alphabet, damm, secret) = setup();

        let mut code_gen =
            CodeGenerator::new(&secret, 0, 3, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(code_gen.len(), 3);
        assert!(code_gen.next().is_some());
        assert_eq!(code_gen.len(), 2);
        assert_eq!(code_gen.by_ref().count(), 2);
        assert_eq!(code_gen.len(), 0);
        assert!(code_gen.next().is_none());
    }

    #[test]
    fn test_various_code_lengths() {
        let (alphabet, damm, secret) = setup();

        for length in [4, 6, 8, 9, 12, 16, 20] {
            let code = generate_code(&secret, 0, &alphabet, length, CheckPosition::End, &damm);
            assert_eq!(code.len(), length + 1);
            assert!(damm.validate(&code, &alphabet));
        }
    }

    #[test]
    fn test_check_position_from_str() {
        assert_eq!(
            CheckPosition::parse_str("start"),
            Some(CheckPosition::Start)
        );
        assert_eq!(CheckPosition::parse_str("end"), Some(CheckPosition::End));
        assert_eq!(
            CheckPosition::parse_str("beginning"),
            Some(CheckPosition::Start)
        );
        assert_eq!(CheckPosition::parse_str("START"), Some(CheckPosition::Start));
        assert_eq!(CheckPosition::parse_str("invalid"), None);
    }

    #[test]
    fn test_check_position_new_normalizes() {
        assert_eq!(CheckPosition::new(0), CheckPosition::Start);
        assert_eq!(CheckPosition::new(-1), CheckPosition::End);
        assert_eq!(CheckPosition::new(3), CheckPosition::Index(3));
        assert_eq!(CheckPosition::new(-3).raw(), -3);
    }

    #[test]
    fn test_check_position_to_index() {
        assert_eq!(CheckPosition::Start.to_index(10), 0);
        assert_eq!(CheckPosition::End.to_index(10), 9);
        assert_eq!(CheckPosition::Index(3).to_index(10), 3);
        assert_eq!(CheckPosition::Index(50).to_index(10), 9);
        assert_eq!(CheckPosition::Index(-2).to_index(10), 8);
        assert_eq!(CheckPosition::Index(-50).to_index(10), 0);
        assert_eq!(CheckPosition::End.to_index(0), 0);
    }

    #[test]
    fn test_check_position_display_round_trip() {
        for pos in [
            CheckPosition::Start,
            CheckPosition::End,
            CheckPosition::Index(3),
            CheckPosition::Index(-4),
        ] {
            let parsed: CheckPosition = pos.to_string().parse().unwrap();
            assert_eq!(parsed, pos);
        }
        assert!("sideways".parse::<CheckPosition>().is_err());
    }

    #[test]
    fn test_code_format_round_trip() {
        let format = CodeFormat::new()
            .with_prefix("PRE-")
            .with_suffix("-X")
            .with_separator('-', vec![3, 6]);

        let formatted = format.format("ABCDEFGHI");
        assert_eq!(formatted, "PRE-ABC-DEF-GHI-X");
        assert_eq!(format.total_length(9), formatted.len());
        assert_eq!(format.strip(&formatted), Some("ABCDEFGHI".to_string()));
        assert_eq!(format.strip("XYZ-ABC-X"), None);
        assert!(format.has_formatting());
    }

    #[test]
    fn test_code_format_plain_is_identity() {
        let format = CodeFormat::new();

        assert!(!format.has_formatting());
        assert_eq!(format.format("ABC"), "ABC");
        assert_eq!(format.strip("ABC"), Some("ABC".to_string()));
        assert_eq!(format.total_length(3), 3);
    }
}