promocrypt_core/
generator.rs

//! Code generation using HMAC-SHA256.
//!
//! This module provides functions for generating promotional codes
//! from counter values using HMAC-SHA256 and the Damm check digit algorithm.
5
6use ring::hmac;
7use serde::{Deserialize, Serialize};
8
9use crate::alphabet::Alphabet;
10use crate::damm::DammTable;
11
12/// Index-based check digit position.
13///
14/// The check digit can be placed at any position in the code using positive
15/// or negative indexing:
16/// - `0`: Start (first position)
17/// - `-1`: End (last position, default)
18/// - Positive N: Position N from start
19/// - Negative -N: Position N from end
20///
21/// # Examples
22/// ```
23/// use promocrypt_core::CheckPosition;
24///
25/// let pos = CheckPosition::End;
26/// assert_eq!(pos.to_index(10), 9);
27///
28/// let pos = CheckPosition::Start;
29/// assert_eq!(pos.to_index(10), 0);
30///
31/// let pos = CheckPosition::Index(4);
32/// assert_eq!(pos.to_index(10), 4);
33///
34/// let pos = CheckPosition::Index(-3);
35/// assert_eq!(pos.to_index(10), 7);
36/// ```
37#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(untagged)]
39pub enum CheckPosition {
40    /// Check digit at the start of the code (index 0)
41    #[serde(rename = "start")]
42    Start,
43    /// Check digit at the end of the code (index -1)
44    #[serde(rename = "end")]
45    #[default]
46    End,
47    /// Check digit at specific index (positive or negative)
48    Index(i8),
49}
50
51impl CheckPosition {
52    /// Create from raw index value.
53    pub fn new(index: i8) -> Self {
54        match index {
55            0 => CheckPosition::Start,
56            -1 => CheckPosition::End,
57            n => CheckPosition::Index(n),
58        }
59    }
60
61    /// Get actual index given total code length (including check digit).
62    pub fn to_index(&self, total_length: usize) -> usize {
63        match self {
64            CheckPosition::Start => 0,
65            CheckPosition::End => total_length.saturating_sub(1),
66            CheckPosition::Index(idx) => {
67                if *idx >= 0 {
68                    (*idx as usize).min(total_length.saturating_sub(1))
69                } else {
70                    let from_end = (-*idx) as usize;
71                    total_length.saturating_sub(from_end)
72                }
73            }
74        }
75    }
76
77    /// Get raw index value.
78    pub fn raw(&self) -> i8 {
79        match self {
80            CheckPosition::Start => 0,
81            CheckPosition::End => -1,
82            CheckPosition::Index(n) => *n,
83        }
84    }
85
86    /// Parse from string (for CLI and config).
87    pub fn parse_str(s: &str) -> Option<Self> {
88        match s.to_lowercase().as_str() {
89            "start" | "beginning" | "front" | "0" => Some(CheckPosition::Start),
90            "end" | "back" | "tail" | "-1" => Some(CheckPosition::End),
91            _ => {
92                // Try parsing as integer
93                s.parse::<i8>().ok().map(CheckPosition::new)
94            }
95        }
96    }
97}
98
99impl std::str::FromStr for CheckPosition {
100    type Err = String;
101
102    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
103        Self::parse_str(s).ok_or_else(|| format!("Invalid check position: {}", s))
104    }
105}
106
107impl std::fmt::Display for CheckPosition {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        match self {
110            CheckPosition::Start => write!(f, "start"),
111            CheckPosition::End => write!(f, "end"),
112            CheckPosition::Index(n) => write!(f, "{}", n),
113        }
114    }
115}
116
117/// Code format options for prefix, suffix, and separators.
118///
119/// # Examples
120/// ```
121/// use promocrypt_core::CodeFormat;
122///
123/// let format = CodeFormat::new()
124///     .with_prefix("PROMO-")
125///     .with_suffix("-2024")
126///     .with_separator('-', vec![4, 8]);
127///
128/// assert_eq!(format.format("A3KF7NP2XM"), "PROMO-A3KF-7NP2-XM-2024");
129/// ```
130#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
131pub struct CodeFormat {
132    /// Optional prefix prepended to code
133    #[serde(skip_serializing_if = "Option::is_none")]
134    pub prefix: Option<String>,
135    /// Optional suffix appended to code
136    #[serde(skip_serializing_if = "Option::is_none")]
137    pub suffix: Option<String>,
138    /// Separator character
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub separator: Option<char>,
141    /// Positions where separator is inserted (0-based, before the character at that position)
142    #[serde(default, skip_serializing_if = "Vec::is_empty")]
143    pub separator_positions: Vec<usize>,
144}
145
146impl CodeFormat {
147    /// Create a new empty code format.
148    pub fn new() -> Self {
149        Self::default()
150    }
151
152    /// Set prefix.
153    pub fn with_prefix(mut self, prefix: &str) -> Self {
154        self.prefix = Some(prefix.to_string());
155        self
156    }
157
158    /// Set suffix.
159    pub fn with_suffix(mut self, suffix: &str) -> Self {
160        self.suffix = Some(suffix.to_string());
161        self
162    }
163
164    /// Set separator and positions.
165    pub fn with_separator(mut self, sep: char, positions: Vec<usize>) -> Self {
166        self.separator = Some(sep);
167        self.separator_positions = positions;
168        self
169    }
170
171    /// Apply formatting to base code.
172    pub fn format(&self, base_code: &str) -> String {
173        let chars: Vec<char> = base_code.chars().collect();
174        let sep_count = self
175            .separator_positions
176            .iter()
177            .filter(|&&p| p > 0 && p < chars.len())
178            .count();
179        let prefix_len = self.prefix.as_ref().map(|p| p.len()).unwrap_or(0);
180        let suffix_len = self.suffix.as_ref().map(|s| s.len()).unwrap_or(0);
181
182        let mut result = String::with_capacity(prefix_len + chars.len() + sep_count + suffix_len);
183
184        // Add prefix
185        if let Some(ref prefix) = self.prefix {
186            result.push_str(prefix);
187        }
188
189        // Add base code with separators
190        if let Some(sep) = self.separator {
191            for (i, c) in chars.iter().enumerate() {
192                if i > 0 && self.separator_positions.contains(&i) {
193                    result.push(sep);
194                }
195                result.push(*c);
196            }
197        } else {
198            result.push_str(base_code);
199        }
200
201        // Add suffix
202        if let Some(ref suffix) = self.suffix {
203            result.push_str(suffix);
204        }
205
206        result
207    }
208
209    /// Strip formatting from code to get base code.
210    pub fn strip(&self, formatted_code: &str) -> Option<String> {
211        let mut code = formatted_code.to_string();
212
213        // Strip prefix
214        if let Some(ref prefix) = self.prefix {
215            code = code.strip_prefix(prefix)?.to_string();
216        }
217
218        // Strip suffix
219        if let Some(ref suffix) = self.suffix {
220            code = code.strip_suffix(suffix)?.to_string();
221        }
222
223        // Strip separators
224        if let Some(sep) = self.separator {
225            code = code.chars().filter(|&c| c != sep).collect();
226        }
227
228        Some(code)
229    }
230
231    /// Calculate expected total length of formatted code.
232    pub fn total_length(&self, base_length: usize) -> usize {
233        let prefix_len = self.prefix.as_ref().map(|p| p.len()).unwrap_or(0);
234        let suffix_len = self.suffix.as_ref().map(|s| s.len()).unwrap_or(0);
235        let sep_count = self
236            .separator_positions
237            .iter()
238            .filter(|&&p| p > 0 && p < base_length)
239            .count();
240
241        prefix_len + base_length + sep_count + suffix_len
242    }
243
244    /// Check if any formatting is applied.
245    pub fn has_formatting(&self) -> bool {
246        self.prefix.is_some() || self.suffix.is_some() || self.separator.is_some()
247    }
248}
249
250/// Generate a single promotional code from a counter value.
251///
252/// # Arguments
253/// * `secret_key` - 32-byte secret key for HMAC
254/// * `counter` - Counter value (unique per code)
255/// * `alphabet` - Character set for the code
256/// * `code_length` - Number of random characters (check digit added separately)
257/// * `check_position` - Where to place the check digit (index-based)
258/// * `damm_table` - Damm table for check digit calculation
259///
260/// # Returns
261/// Generated code string with length = code_length + 1 (including check digit)
262///
263/// # Example
264///
265/// ```
266/// use promocrypt_core::{generate_code, Alphabet, DammTable, CheckPosition};
267///
268/// let secret = [0u8; 32];
269/// let alphabet = Alphabet::default_alphabet();
270/// let damm = DammTable::new(alphabet.len());
271///
272/// let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);
273/// assert_eq!(code.len(), 10);
274/// ```
275pub fn generate_code(
276    secret_key: &[u8; 32],
277    counter: u64,
278    alphabet: &Alphabet,
279    code_length: usize,
280    check_position: CheckPosition,
281    damm_table: &DammTable,
282) -> String {
283    // 1. HMAC-SHA256 of counter
284    let key = hmac::Key::new(hmac::HMAC_SHA256, secret_key);
285    let signature = hmac::sign(&key, &counter.to_le_bytes());
286
287    // 2. Convert first 8 bytes to u64
288    let hash_bytes: [u8; 8] = signature.as_ref()[0..8].try_into().unwrap();
289    let mut value = u64::from_le_bytes(hash_bytes);
290
291    // 3. Convert to base-N characters
292    let base = alphabet.len() as u64;
293    let mut chars = Vec::with_capacity(code_length);
294
295    for _ in 0..code_length {
296        let index = (value % base) as usize;
297        chars.push(alphabet.char_at(index));
298        value /= base;
299    }
300
301    // 4. Get check position index
302    let total_length = code_length + 1;
303    let check_index = check_position.to_index(total_length);
304
305    // 5. Build code with check digit placeholder to calculate check
306    // For any position, we need to calculate check so that Damm validation passes
307    let check = if check_index == 0 {
308        // Check at start: find X such that processing [X, chars...] = 0
309        damm_table.calculate_for_start(&chars, alphabet)
310    } else if check_index >= code_length {
311        // Check at end: process chars and get final check digit
312        damm_table.calculate(&chars, alphabet)
313    } else {
314        // Check in middle: need to find check digit for middle position
315        // We calculate check digit such that Damm validation passes
316        damm_table.calculate_for_position(&chars, alphabet, check_index)
317    };
318
319    // 6. Build final code with check digit at specified position
320    let mut result = String::with_capacity(total_length);
321
322    for i in 0..total_length {
323        if i == check_index {
324            result.push(check);
325        } else {
326            // Adjust index for chars array based on check position
327            let char_idx = if i < check_index { i } else { i - 1 };
328            if char_idx < chars.len() {
329                result.push(chars[char_idx]);
330            }
331        }
332    }
333
334    result
335}
336
337/// Generate a batch of promotional codes.
338///
339/// # Arguments
340/// * `secret_key` - 32-byte secret key for HMAC
341/// * `start_counter` - Starting counter value
342/// * `count` - Number of codes to generate
343/// * `alphabet` - Character set for the codes
344/// * `code_length` - Number of random characters per code
345/// * `check_position` - Where to place the check digit
346/// * `damm_table` - Damm table for check digit calculation
347///
348/// # Returns
349/// Vector of generated codes
350///
351/// # Example
352///
353/// ```
354/// use promocrypt_core::{generate_batch, Alphabet, DammTable, CheckPosition};
355///
356/// let secret = [0u8; 32];
357/// let alphabet = Alphabet::default_alphabet();
358/// let damm = DammTable::new(alphabet.len());
359///
360/// let codes = generate_batch(&secret, 0, 100, &alphabet, 9, CheckPosition::End, &damm);
361/// assert_eq!(codes.len(), 100);
362/// ```
363pub fn generate_batch(
364    secret_key: &[u8; 32],
365    start_counter: u64,
366    count: usize,
367    alphabet: &Alphabet,
368    code_length: usize,
369    check_position: CheckPosition,
370    damm_table: &DammTable,
371) -> Vec<String> {
372    (0..count)
373        .map(|i| {
374            generate_code(
375                secret_key,
376                start_counter + i as u64,
377                alphabet,
378                code_length,
379                check_position,
380                damm_table,
381            )
382        })
383        .collect()
384}
385
386/// Generate codes into a pre-allocated vector for better performance.
387///
388/// This is useful when generating very large batches where allocation
389/// overhead matters.
390#[allow(clippy::too_many_arguments)]
391pub fn generate_batch_into(
392    secret_key: &[u8; 32],
393    start_counter: u64,
394    count: usize,
395    alphabet: &Alphabet,
396    code_length: usize,
397    check_position: CheckPosition,
398    damm_table: &DammTable,
399    output: &mut Vec<String>,
400) {
401    output.clear();
402    output.reserve(count);
403
404    for i in 0..count {
405        output.push(generate_code(
406            secret_key,
407            start_counter + i as u64,
408            alphabet,
409            code_length,
410            check_position,
411            damm_table,
412        ));
413    }
414}
415
416/// Iterator for generating codes lazily.
417///
418/// Useful for streaming large batches without holding all codes in memory.
419pub struct CodeGenerator<'a> {
420    secret_key: &'a [u8; 32],
421    alphabet: &'a Alphabet,
422    code_length: usize,
423    check_position: CheckPosition,
424    damm_table: &'a DammTable,
425    current_counter: u64,
426    end_counter: u64,
427}
428
429impl<'a> CodeGenerator<'a> {
430    /// Create a new code generator iterator.
431    pub fn new(
432        secret_key: &'a [u8; 32],
433        start_counter: u64,
434        count: usize,
435        alphabet: &'a Alphabet,
436        code_length: usize,
437        check_position: CheckPosition,
438        damm_table: &'a DammTable,
439    ) -> Self {
440        Self {
441            secret_key,
442            alphabet,
443            code_length,
444            check_position,
445            damm_table,
446            current_counter: start_counter,
447            end_counter: start_counter + count as u64,
448        }
449    }
450}
451
452impl Iterator for CodeGenerator<'_> {
453    type Item = String;
454
455    fn next(&mut self) -> Option<Self::Item> {
456        if self.current_counter >= self.end_counter {
457            return None;
458        }
459
460        let code = generate_code(
461            self.secret_key,
462            self.current_counter,
463            self.alphabet,
464            self.code_length,
465            self.check_position,
466            self.damm_table,
467        );
468
469        self.current_counter += 1;
470        Some(code)
471    }
472
473    fn size_hint(&self) -> (usize, Option<usize>) {
474        let remaining = (self.end_counter - self.current_counter) as usize;
475        (remaining, Some(remaining))
476    }
477}
478
479impl ExactSizeIterator for CodeGenerator<'_> {}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixtures: the default alphabet, its Damm table, and a fixed secret.
    fn setup() -> (Alphabet, DammTable, [u8; 32]) {
        let alphabet = Alphabet::default_alphabet();
        let damm = DammTable::new(alphabet.len());
        let secret = [42u8; 32];
        (alphabet, damm, secret)
    }

    #[test]
    fn test_generate_code_length() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(code.len(), 10); // 9 random + 1 check
    }

    #[test]
    fn test_generate_code_deterministic() {
        let (alphabet, damm, secret) = setup();

        let code1 = generate_code(&secret, 12345, &alphabet, 9, CheckPosition::End, &damm);
        let code2 = generate_code(&secret, 12345, &alphabet, 9, CheckPosition::End, &damm);

        assert_eq!(code1, code2);
    }

    #[test]
    fn test_different_counters_different_codes() {
        let (alphabet, damm, secret) = setup();

        let code1 = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);
        let code2 = generate_code(&secret, 1, &alphabet, 9, CheckPosition::End, &damm);

        assert_ne!(code1, code2);
    }

    #[test]
    fn test_different_secrets_different_codes() {
        let (alphabet, damm, _) = setup();

        let secret1 = [1u8; 32];
        let secret2 = [2u8; 32];

        let code1 = generate_code(&secret1, 0, &alphabet, 9, CheckPosition::End, &damm);
        let code2 = generate_code(&secret2, 0, &alphabet, 9, CheckPosition::End, &damm);

        assert_ne!(code1, code2);
    }

    #[test]
    fn test_check_position_start() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::Start, &damm);
        assert_eq!(code.len(), 10);

        // Verify it validates correctly
        assert!(damm.validate(&code, &alphabet));
    }

    #[test]
    fn test_check_position_end() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(code.len(), 10);

        // Verify it validates correctly
        assert!(damm.validate(&code, &alphabet));
    }

    #[test]
    fn test_all_characters_in_alphabet() {
        let (alphabet, damm, secret) = setup();

        let code = generate_code(&secret, 0, &alphabet, 9, CheckPosition::End, &damm);

        for c in code.chars() {
            assert!(alphabet.contains(c), "Character '{}' not in alphabet", c);
        }
    }

    #[test]
    fn test_generate_batch() {
        let (alphabet, damm, secret) = setup();

        let codes = generate_batch(&secret, 0, 100, &alphabet, 9, CheckPosition::End, &damm);

        assert_eq!(codes.len(), 100);

        // All codes should be unique
        let unique: std::collections::HashSet<_> = codes.iter().collect();
        assert_eq!(unique.len(), 100);

        // All codes should be valid
        for code in &codes {
            assert!(damm.validate(code, &alphabet));
        }
    }

    #[test]
    fn test_generate_batch_into_matches_batch() {
        let (alphabet, damm, secret) = setup();

        // Pre-existing contents must be replaced, not appended to.
        let mut out = vec!["stale".to_string()];
        generate_batch_into(&secret, 7, 20, &alphabet, 9, CheckPosition::End, &damm, &mut out);

        let batch = generate_batch(&secret, 7, 20, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(out, batch);
    }

    #[test]
    fn test_code_generator_iterator() {
        let (alphabet, damm, secret) = setup();

        let code_gen = CodeGenerator::new(&secret, 0, 10, &alphabet, 9, CheckPosition::End, &damm);

        let codes: Vec<_> = code_gen.collect();
        assert_eq!(codes.len(), 10);

        // Should match batch generation
        let batch = generate_batch(&secret, 0, 10, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(codes, batch);
    }

    #[test]
    fn test_code_generator_nonzero_start() {
        let (alphabet, damm, secret) = setup();

        let codes: Vec<_> =
            CodeGenerator::new(&secret, 1000, 5, &alphabet, 9, CheckPosition::End, &damm)
                .collect();
        let batch = generate_batch(&secret, 1000, 5, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(codes, batch);
    }

    #[test]
    fn test_code_generator_exact_size() {
        let (alphabet, damm, secret) = setup();

        let code_gen = CodeGenerator::new(&secret, 0, 50, &alphabet, 9, CheckPosition::End, &damm);
        assert_eq!(code_gen.len(), 50);
    }

    #[test]
    fn test_various_code_lengths() {
        let (alphabet, damm, secret) = setup();

        for length in [4, 6, 8, 9, 12, 16, 20] {
            let code = generate_code(&secret, 0, &alphabet, length, CheckPosition::End, &damm);
            assert_eq!(code.len(), length + 1);
            assert!(damm.validate(&code, &alphabet));
        }
    }

    #[test]
    fn test_check_position_from_str() {
        assert_eq!(
            CheckPosition::parse_str("start"),
            Some(CheckPosition::Start)
        );
        assert_eq!(CheckPosition::parse_str("end"), Some(CheckPosition::End));
        assert_eq!(
            CheckPosition::parse_str("beginning"),
            Some(CheckPosition::Start)
        );
        assert_eq!(CheckPosition::parse_str("invalid"), None);
    }

    #[test]
    fn test_check_position_parse_numeric() {
        assert_eq!(CheckPosition::parse_str("0"), Some(CheckPosition::Start));
        assert_eq!(CheckPosition::parse_str("-1"), Some(CheckPosition::End));
        assert_eq!(CheckPosition::parse_str("3"), Some(CheckPosition::Index(3)));
        assert_eq!(CheckPosition::parse_str("END"), Some(CheckPosition::End));
        // Outside the i8 range.
        assert_eq!(CheckPosition::parse_str("200"), None);
    }

    #[test]
    fn test_check_position_to_index() {
        assert_eq!(CheckPosition::End.to_index(10), 9);
        assert_eq!(CheckPosition::Start.to_index(10), 0);
        assert_eq!(CheckPosition::Index(4).to_index(10), 4);
        assert_eq!(CheckPosition::Index(-3).to_index(10), 7);
        // Out-of-range indices clamp to the nearest end of the code.
        assert_eq!(CheckPosition::Index(50).to_index(10), 9);
        assert_eq!(CheckPosition::Index(-50).to_index(10), 0);
    }

    #[test]
    fn test_check_position_display_round_trip() {
        for pos in [
            CheckPosition::Start,
            CheckPosition::End,
            CheckPosition::Index(3),
            CheckPosition::Index(-4),
        ] {
            assert_eq!(pos.to_string().parse::<CheckPosition>(), Ok(pos));
        }
    }

    #[test]
    fn test_code_format_round_trip() {
        let format = CodeFormat::new()
            .with_prefix("PROMO-")
            .with_suffix("-2024")
            .with_separator('-', vec![4, 8]);

        let formatted = format.format("A3KF7NP2XM");
        assert_eq!(formatted, "PROMO-A3KF-7NP2-XM-2024");
        assert_eq!(format.total_length(10), formatted.len());
        assert_eq!(format.strip(&formatted), Some("A3KF7NP2XM".to_string()));
        assert_eq!(format.strip("OTHER-A3KF-7NP2-XM-2024"), None);
    }
}