luhn/
lib.rs

//! ## Luhn
//!
//! This crate contains an implementation of the [Luhn checksum
//! algorithm](https://en.wikipedia.org/wiki/Luhn_mod_N_algorithm).  For more
//! information, see the documentation on the `Luhn` type.
6use std::collections::HashSet;
7use std::convert::AsRef;
8
9
/// The error type for this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LuhnError {
    /// The given alphabet has a duplicated character.
    NotUnique(char),

    /// The input string has a character that is invalid for the alphabet.
    InvalidCharacter(char),

    /// The alphabet or input string was empty, or the input was too short to
    /// validate (for example, a lone check character).
    EmptyString,
}

impl std::fmt::Display for LuhnError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            LuhnError::NotUnique(ch) => {
                write!(f, "alphabet contains duplicate character {:?}", ch)
            }
            LuhnError::InvalidCharacter(ch) => {
                write!(f, "character {:?} is not in the alphabet", ch)
            }
            LuhnError::EmptyString => f.write_str("alphabet or input string is empty or too short"),
        }
    }
}

// Lets callers propagate `LuhnError` through `Box<dyn Error>` and similar.
impl std::error::Error for LuhnError {}
22
/// `Luhn` generates and validates Luhn check characters over a fixed
/// alphabet.
///
/// Construct one with `Luhn::new`. The alphabet is stored in sorted order, and
/// each character's code point is its index in that sorted order (not its
/// position in the string originally passed to `new`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Luhn {
    /// The alphabet's characters, sorted and free of duplicates.
    alphabet: Vec<char>,
}
29
30impl Luhn {
31    /// Create a new Luhn instance from anything that can be coerced to a
32    /// `&str`.
33    pub fn new<S>(alphabet: S) -> Result<Luhn, LuhnError>
34        where S: AsRef<str>
35    {
36        let mut chars = alphabet.as_ref().chars().collect::<Vec<char>>();
37        if chars.len() < 1 {
38            return Err(LuhnError::EmptyString);
39        }
40
41        // Need to sort so binary_search works.
42        chars.sort();
43
44        // Validate uniqueness
45        let mut charset = HashSet::new();
46        for ch in chars.iter() {
47            if charset.contains(ch) {
48                return Err(LuhnError::NotUnique(*ch));
49            }
50
51            charset.insert(*ch);
52        }
53
54        Ok(Luhn { alphabet: chars })
55    }
56
57    #[inline]
58    fn codepoint_from_character(&self, ch: char) -> Result<usize, LuhnError> {
59        match self.alphabet.binary_search(&ch) {
60            Ok(idx) => Ok(idx),
61            Err(_) => Err(LuhnError::InvalidCharacter(ch)),
62        }
63    }
64
65    #[inline]
66    fn character_from_codepoint(&self, cp: usize) -> char {
67        self.alphabet[cp]
68    }
69
70    /// Given an input string, generate the Luhn character.
71    ///
72    /// Returns an error if the input string is empty, or contains a character
73    /// that is not in the input alphabet.
74    pub fn generate<S>(&self, s: S) -> Result<char, LuhnError>
75        where S: AsRef<str>
76    {
77        let s = s.as_ref();
78        if s.len() == 0 {
79            return Err(LuhnError::EmptyString);
80        }
81
82        let mut factor = 1;
83        let mut sum = 0;
84        let n = self.alphabet.len();
85
86        // Note: this is by-and-large a transliteration of the algorithm in the
87        // Wikipedia article into Rust:
88        //   https://en.wikipedia.org/wiki/Luhn_mod_N_algorithm
89        for ch in s.chars() {
90            let codepoint = try!(self.codepoint_from_character(ch));
91
92            let mut addend = factor * codepoint;
93            factor = if factor == 2 {
94                1
95            } else {
96                2
97            };
98            addend = (addend / n) + (addend % n);
99            sum += addend;
100        }
101
102        let remainder = sum % n;
103        let check_codepoint = (n - remainder) % n;
104
105        Ok(self.character_from_codepoint(check_codepoint))
106    }
107
108    /// Validates a Luhn check character.  This assumes that the final character
109    /// of the input string is the Luhn character, and it will validate that the
110    /// remainder of the string is correct.
111    pub fn validate<S>(&self, s: S) -> Result<bool, LuhnError>
112        where S: AsRef<str>
113    {
114        let s = s.as_ref();
115        if s.len() <= 1 {
116            return Err(LuhnError::EmptyString);
117        }
118
119        // Extract the check character and remainder of the string.
120        // TODO: can we do this without allocating a new String?
121        let head = s.char_indices()
122                    .take_while(|&(index, _)| index < s.len() - 1)
123                    .map(|(_, ch)| ch)
124                    .collect::<String>();
125        let luhn = s.chars().last().unwrap();
126
127        let expected = try!(self.generate(head));
128        Ok(luhn == expected)
129    }
130
131    /// Validates a Luhn check character.  This is the same as the `validate`
132    /// method, but allows providing the Luhn check character out-of-band from
133    /// the input to validate.
134    pub fn validate_with<S>(&self, s: S, check: char) -> Result<bool, LuhnError>
135        where S: AsRef<str>
136    {
137        let s = s.as_ref();
138        if s.len() <= 1 {
139            return Err(LuhnError::EmptyString);
140        }
141
142        let expected = try!(self.generate(s));
143        Ok(check == expected)
144    }
145}
146
147
#[cfg(test)]
mod tests {
    extern crate rand;

    use self::rand::{Isaac64Rng, Rng, SeedableRng, sample, thread_rng};

    use super::{Luhn, LuhnError};

    #[test]
    fn test_generate() {
        // Base 6
        let l = Luhn::new("abcdef").expect("valid alphabet");
        assert_eq!(l.generate("abcdef"), Ok('e'));

        // Base 10 (the worked example from the Wikipedia article)
        let l = Luhn::new("0123456789").expect("valid alphabet");
        assert_eq!(l.generate("7992739871"), Ok('3'));
    }

    #[test]
    fn test_invalid_alphabet() {
        assert_eq!(Luhn::new("abcdea").unwrap_err(), LuhnError::NotUnique('a'));
    }

    #[test]
    fn test_invalid_input() {
        let l = Luhn::new("abcdef").expect("valid alphabet");

        // The first character outside the alphabet is the one reported.
        assert_eq!(l.generate("012345"), Err(LuhnError::InvalidCharacter('0')));
    }

    #[test]
    fn test_validate() {
        let l = Luhn::new("abcdef").expect("valid alphabet");

        assert!(l.validate("abcdefe").unwrap());
        assert!(!l.validate("abcdefd").unwrap());
    }

    #[test]
    fn test_empty_strings() {
        // Alphabet must have at least one character.
        assert_eq!(Luhn::new("").unwrap_err(), LuhnError::EmptyString);

        let l = Luhn::new("abcdef").expect("valid alphabet");

        // Cannot generate on an empty string.
        assert_eq!(l.generate("").unwrap_err(), LuhnError::EmptyString);

        // Cannot validate a string of length 1 (since the last character is the check digit).
        assert_eq!(l.validate("a").unwrap_err(), LuhnError::EmptyString);
    }

    #[test]
    fn test_validate_with() {
        let l = Luhn::new("abcdef").expect("valid alphabet");

        assert!(l.validate_with("abcdef", 'e').unwrap());
        assert!(!l.validate_with("abcdef", 'd').unwrap());
    }

    #[test]
    fn test_longer_input() {
        // This test caught an out-of-bounds error; also pin the actual result.
        let l = Luhn::new("abcdef").expect("valid alphabet");
        assert_eq!(l.generate("aabbccdd"), Ok('f'));
    }

    #[test]
    fn test_random_input() {
        const NUM_TESTS: usize = 10000;
        const PRINTABLE: &'static str = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTU\
                                         VWXYZ";
        let printable_chars = PRINTABLE.chars().collect::<Vec<char>>();

        // Generate a random seed and print it so a failing run can be replayed.
        let seed: u64 = thread_rng().gen();
        println!("Seed for this run: {}", seed);

        // Create the seedable RNG with this seed.
        let mut rng = Isaac64Rng::from_seed(&[seed]);

        for i in 0..NUM_TESTS {
            // Random alphabet size in 1..=62; `gen_range`'s upper bound is
            // exclusive, hence the `+ 1` so the full alphabet is also covered.
            let alphabet_size: u8 = rng.gen_range(1, printable_chars.len() as u8 + 1);

            // Create the alphabet by taking this many characters from our
            // printable characters Vec.
            let chars = sample(&mut rng, &printable_chars, alphabet_size as usize)
                            .into_iter()
                            .cloned()
                            .collect::<Vec<char>>();
            let alphabet = chars.iter().cloned().collect::<String>();

            // Generate a random input length.
            let input_length: u16 = rng.gen_range(1, 1024);

            // Generate this many random characters.
            let input = (0..input_length)
                            .map(|_| *rng.choose(&*chars).unwrap())
                            .collect::<String>();

            // Generation must succeed for any input drawn from the alphabet.
            let l = Luhn::new(&alphabet).expect("valid alphabet");
            let check = match l.generate(&input) {
                Ok(ch) => ch,
                Err(e) => {
                    println!("Alphabet = {}", alphabet);
                    println!("Input = {}", input);
                    panic!("{}: Unexpected error: {:?}", i, e);
                }
            };

            // Round trip: the input followed by its check character validates.
            let mut with_check = input.clone();
            with_check.push(check);
            assert!(l.validate(&with_check) == Ok(true),
                    "{}: check {:?} failed validation (alphabet = {}, input = {})",
                    i, check, alphabet, input);
        }
    }
}