1use std::collections::HashSet;
7use std::convert::AsRef;
8
9
/// Errors produced while building a [`Luhn`] alphabet or while computing or
/// validating a check character.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LuhnError {
    /// The alphabet passed to [`Luhn::new`] contains this character more than once.
    NotUnique(char),

    /// The input contains this character, which is not part of the alphabet.
    InvalidCharacter(char),

    /// An empty alphabet or input was supplied, or a string handed to a
    /// validation method was too short for the requested operation.
    EmptyString,
}

impl ::std::fmt::Display for LuhnError {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        match *self {
            LuhnError::NotUnique(ch) => {
                write!(f, "character {:?} appears more than once in the alphabet", ch)
            }
            LuhnError::InvalidCharacter(ch) => {
                write!(f, "character {:?} is not in the alphabet", ch)
            }
            LuhnError::EmptyString => write!(f, "string is empty or too short"),
        }
    }
}

// `Display` + `Debug` are all `Error` needs; implementing it lets callers box
// this type as `dyn Error` or propagate it with `?` into generic error types.
impl ::std::error::Error for LuhnError {}
22
/// A Luhn mod N checksum engine over a caller-supplied alphabet.
///
/// The alphabet is kept sorted and free of duplicates (both enforced by
/// [`Luhn::new`]); a character's code point is its index in that sorted list.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Luhn {
    /// Sorted, duplicate-free alphabet; the index of a character is its code point.
    alphabet: Vec<char>,
}
29
30impl Luhn {
31 pub fn new<S>(alphabet: S) -> Result<Luhn, LuhnError>
34 where S: AsRef<str>
35 {
36 let mut chars = alphabet.as_ref().chars().collect::<Vec<char>>();
37 if chars.len() < 1 {
38 return Err(LuhnError::EmptyString);
39 }
40
41 chars.sort();
43
44 let mut charset = HashSet::new();
46 for ch in chars.iter() {
47 if charset.contains(ch) {
48 return Err(LuhnError::NotUnique(*ch));
49 }
50
51 charset.insert(*ch);
52 }
53
54 Ok(Luhn { alphabet: chars })
55 }
56
57 #[inline]
58 fn codepoint_from_character(&self, ch: char) -> Result<usize, LuhnError> {
59 match self.alphabet.binary_search(&ch) {
60 Ok(idx) => Ok(idx),
61 Err(_) => Err(LuhnError::InvalidCharacter(ch)),
62 }
63 }
64
65 #[inline]
66 fn character_from_codepoint(&self, cp: usize) -> char {
67 self.alphabet[cp]
68 }
69
70 pub fn generate<S>(&self, s: S) -> Result<char, LuhnError>
75 where S: AsRef<str>
76 {
77 let s = s.as_ref();
78 if s.len() == 0 {
79 return Err(LuhnError::EmptyString);
80 }
81
82 let mut factor = 1;
83 let mut sum = 0;
84 let n = self.alphabet.len();
85
86 for ch in s.chars() {
90 let codepoint = try!(self.codepoint_from_character(ch));
91
92 let mut addend = factor * codepoint;
93 factor = if factor == 2 {
94 1
95 } else {
96 2
97 };
98 addend = (addend / n) + (addend % n);
99 sum += addend;
100 }
101
102 let remainder = sum % n;
103 let check_codepoint = (n - remainder) % n;
104
105 Ok(self.character_from_codepoint(check_codepoint))
106 }
107
108 pub fn validate<S>(&self, s: S) -> Result<bool, LuhnError>
112 where S: AsRef<str>
113 {
114 let s = s.as_ref();
115 if s.len() <= 1 {
116 return Err(LuhnError::EmptyString);
117 }
118
119 let head = s.char_indices()
122 .take_while(|&(index, _)| index < s.len() - 1)
123 .map(|(_, ch)| ch)
124 .collect::<String>();
125 let luhn = s.chars().last().unwrap();
126
127 let expected = try!(self.generate(head));
128 Ok(luhn == expected)
129 }
130
131 pub fn validate_with<S>(&self, s: S, check: char) -> Result<bool, LuhnError>
135 where S: AsRef<str>
136 {
137 let s = s.as_ref();
138 if s.len() <= 1 {
139 return Err(LuhnError::EmptyString);
140 }
141
142 let expected = try!(self.generate(s));
143 Ok(check == expected)
144 }
145}
146
147
#[cfg(test)]
mod tests {
    extern crate rand;

    use self::rand::{Isaac64Rng, Rng, SeedableRng, sample, thread_rng};

    use super::{Luhn, LuhnError};

    /// Builds a `Luhn` for an alphabet the test expects to be valid, surfacing
    /// the actual error if it is not.
    fn luhn(alphabet: &str) -> Luhn {
        Luhn::new(alphabet).expect("valid alphabet")
    }

    #[test]
    fn test_generate() {
        assert_eq!(luhn("abcdef").generate("abcdef"), Ok('e'));
        assert_eq!(luhn("0123456789").generate("7992739871"), Ok('3'));
    }

    #[test]
    fn test_invalid_alphabet() {
        assert_eq!(Luhn::new("abcdea").unwrap_err(), LuhnError::NotUnique('a'));
    }

    #[test]
    fn test_invalid_input() {
        assert_eq!(
            luhn("abcdef").generate("012345"),
            Err(LuhnError::InvalidCharacter('0'))
        );
    }

    #[test]
    fn test_validate() {
        let l = luhn("abcdef");

        assert_eq!(l.validate("abcdefe"), Ok(true));
        assert_eq!(l.validate("abcdefd"), Ok(false));
    }

    #[test]
    fn test_empty_strings() {
        assert_eq!(Luhn::new("").unwrap_err(), LuhnError::EmptyString);

        let l = luhn("abcdef");

        assert_eq!(l.generate("").unwrap_err(), LuhnError::EmptyString);

        assert_eq!(l.validate("a").unwrap_err(), LuhnError::EmptyString);
    }

    #[test]
    fn test_validate_with() {
        let l = luhn("abcdef");

        assert_eq!(l.validate_with("abcdef", 'e'), Ok(true));
        assert_eq!(l.validate_with("abcdef", 'd'), Ok(false));
    }

    #[test]
    fn test_longer_input() {
        // Previously the result was discarded, so this test asserted nothing.
        assert_eq!(luhn("abcdef").generate("aabbccdd"), Ok('f'));
    }

    #[test]
    fn test_random_input() {
        const NUM_TESTS: usize = 10000;
        const PRINTABLE: &'static str = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTU\
                                         VWXYZ";
        let printable_chars = PRINTABLE.chars().collect::<Vec<char>>();

        let seed: u64 = thread_rng().gen();
        // Print the seed so a failing run can be reproduced.
        println!("Seed for this run: {}", seed);

        let mut rng = Isaac64Rng::from_seed(&[seed]);

        for i in 0..NUM_TESTS {
            // `gen_range`'s upper bound is exclusive; `+ 1` allows the full alphabet.
            let alphabet_size: u8 = rng.gen_range(1, printable_chars.len() as u8 + 1);

            let chars = sample(&mut rng, &printable_chars, alphabet_size as usize)
                .into_iter()
                .cloned()
                .collect::<Vec<char>>();
            let alphabet = chars.iter().cloned().collect::<String>();

            let input_length: u16 = rng.gen_range(1, 1024);

            let input = (0..input_length)
                .map(|_| *rng.choose(&*chars).unwrap())
                .collect::<String>();

            let l = luhn(&alphabet);
            match l.generate(&input) {
                // Appending the generated check character must always validate.
                Ok(check) => {
                    let full = format!("{}{}", input, check);
                    assert_eq!(
                        l.validate(&full),
                        Ok(true),
                        "{}: alphabet = {}, input = {}",
                        i,
                        alphabet,
                        input
                    );
                }
                Err(e) => {
                    println!("Alphabet = {}", alphabet);
                    println!("Input = {}", input);
                    panic!("{}: Unexpected error: {:?}", i, e);
                }
            }
        }
    }
}