1use std::fmt;
2
3use aes::Aes256;
4use fpe::ff1::{BinaryNumeralString, FF1};
5use hkdf::Hkdf;
6use hmac::{Hmac, Mac};
7use sha2::Sha256;
8use uuid::Uuid;
9
10use crate::Config;
11
12type HmacSha256 = Hmac<Sha256>;
13
/// Errors returned when decoding an encoded ID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The base62 tail of the string could not be decoded.
    DecodingFailed,
    /// FF1 rejected the ciphertext during decryption.
    DecryptionFailed,
    /// FF1 rejected the plaintext during encryption.
    EncryptionFailed,
    /// The truncated HMAC tag did not match the payload.
    IncorrectMAC,
    /// The payload has an impossible length for the configured codec.
    InvalidDataLength,
    /// The string's prefix (up to and including the last `_`) differs from the codec's.
    InvalidPrefix { received: String, expected: String },
    /// The byte expected to terminate the payload was not the sentinel.
    SentinelMismatch { received: u8, expected: u8 },
}
25
26impl fmt::Display for Error {
27 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
28 match self {
29 Error::DecodingFailed => {
30 write!(f, "Decoding string failed")
31 }
32 Error::DecryptionFailed => {
33 write!(f, "FF1 decryption failed")
34 }
35 Error::EncryptionFailed => {
36 write!(f, "FF1 encryption failed")
37 }
38 Error::IncorrectMAC => {
39 write!(f, "Incorrect MAC")
40 }
41 Error::InvalidDataLength => {
42 write!(f, "Invalid data length")
43 }
44 Error::SentinelMismatch { received, expected } => {
45 write!(f, "Sentinel byte was {received}, expected {expected}")
46 }
47 Error::InvalidPrefix { received, expected } => {
48 write!(f, "Prefix was {received}, expected {expected}")
49 }
50 }
51 }
52}
53
54impl From<base62::DecodeError> for Error {
55 fn from(_: base62::DecodeError) -> Error {
56 Error::DecodingFailed
57 }
58}
59
60impl std::error::Error for Error {}
61
/// Capacity, in bytes, of the `u128` that encoded payloads are packed into.
const MAX_BUFFER: usize = std::mem::size_of::<u128>();
64
/// Marker byte written immediately after a payload shorter than
/// `MAX_BUFFER`, so the payload's length can be recovered when decoding.
const SENTINEL: u8 = 0x01;
67
68pub struct Codec {
70 ff1: FF1<Aes256>,
71 hmac: HmacSha256,
72 hmac_length: usize,
73 prefix: String,
74 zero_pad_length: usize,
75}
76
77impl Codec {
78 pub fn new(name: &str, config: &Config) -> Codec {
103 let hkdf = Hkdf::<Sha256>::new(None, &config.key);
104 let mut ff1_key = [0u8; 32];
105 let mut hmac_key = [0u8; 32];
106 hkdf.expand(format!("{name}/ff1").as_bytes(), &mut ff1_key)
107 .expect("Length 32 should be valid");
108 hkdf.expand(format!("{name}/hmac").as_bytes(), &mut hmac_key)
109 .expect("Length 32 should be valid");
110 Codec {
111 ff1: FF1::<Aes256>::new(&ff1_key, 2).expect("Radix 2 should be valid"),
112 hmac: HmacSha256::new_from_slice(&hmac_key).expect("Key length 32 should be valid"),
113 hmac_length: config.hmac_length as usize,
114 prefix: format!("{name}_"),
115 zero_pad_length: config.zero_pad_length as usize,
116 }
117 }
118
119 pub fn encode(&self, num: u64) -> String {
145 let encoded = base62::encode(self.encode_u128(num));
146 format!("{}{}", self.prefix, encoded)
147 }
148
149 fn encode_u128(&self, num: u64) -> u128 {
152 let bytes = encrypt_number(
153 &self.ff1,
154 &self.hmac,
155 self.hmac_length,
156 self.zero_pad_length,
157 num,
158 );
159 let mut num_array = [0u8; MAX_BUFFER];
160 num_array[..bytes.len()].copy_from_slice(&bytes);
161 if bytes.len() < num_array.len() {
162 num_array[bytes.len()] = SENTINEL;
163 }
164 u128::from_le_bytes(num_array)
165 }
166
167 pub fn encode_uuid(&self, num: u64) -> Uuid {
170 let vec = encrypt_number(&self.ff1, &self.hmac, 8, 8, num);
172 let num = u128::from_le_bytes(vec.try_into().expect("Should have exactly 16 bytes"));
173 Uuid::from_u128_le(num)
174 }
175
176 pub fn decode_uuid(&self, uuid: Uuid) -> Result<u64, Error> {
200 let bytes = uuid.to_u128_le().to_le_bytes();
201 decrypt_number(&self.ff1, &self.hmac, 8, 8, &bytes)
202 }
203
204 pub fn decode(&self, encoded: &str) -> Result<u64, Error> {
230 let received = match encoded.rfind('_') {
232 None => "".to_string(),
233 Some(i) => encoded[..i + 1].to_string(),
234 };
235 if received != self.prefix {
236 let expected = self.prefix.clone();
237 return Err(Error::InvalidPrefix { received, expected });
238 }
239
240 let tail = &encoded[self.prefix.len()..];
241 let num = base62::decode(tail).map_err(Error::from)?;
242 let num_array = num.to_le_bytes();
243
244 let length;
245 if self.hmac_length + self.zero_pad_length < MAX_BUFFER {
246 length = last_nonzero(&num_array);
247 if num_array[length] != SENTINEL {
248 return Err(Error::SentinelMismatch {
249 received: num_array[length],
250 expected: SENTINEL,
251 });
252 }
253 } else {
254 length = MAX_BUFFER;
255 }
256
257 decrypt_number(
258 &self.ff1,
259 &self.hmac,
260 self.hmac_length,
261 self.zero_pad_length,
262 &num_array[..length],
263 )
264 }
265}
266
/// Returns the index of the last non-zero byte in `bytes`.
///
/// Returns 0 when every byte is zero or `bytes` is empty, so callers must
/// inspect the byte at that index to tell those cases apart.
fn last_nonzero(bytes: &[u8]) -> usize {
    bytes
        .iter()
        .enumerate()
        .rev()
        .find_map(|(idx, &byte)| (byte != 0).then_some(idx))
        .unwrap_or_default()
}
270
/// Returns the little-endian bytes of `num` with trailing zero bytes dropped.
/// The result always keeps at least one byte and at least `min_length` bytes.
///
/// # Panics
/// Panics if `min_length` is greater than 8.
fn num_to_le_vec(num: u64, min_length: usize) -> Vec<u8> {
    // Whole zero bytes at the high end; zero itself still occupies one byte.
    let leading_zero_bytes = (num.leading_zeros() / 8) as usize;
    let significant = (8 - leading_zero_bytes).max(1);
    let keep = significant.max(min_length);
    num.to_le_bytes()[..keep].to_vec()
}
278
/// Interprets up to eight little-endian bytes as a `u64`; missing high bytes
/// are treated as zero.
///
/// # Panics
/// Panics if `bytes` is longer than eight bytes.
fn le_vec_to_num(bytes: &[u8]) -> u64 {
    assert!(bytes.len() <= 8);
    // Walk from the most significant byte down, shifting earlier bytes up.
    bytes
        .iter()
        .rev()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte))
}
284
285fn encrypt_number(
286 ff1: &FF1<Aes256>,
287 hmac: &HmacSha256,
288 hmac_length: usize,
289 zero_pad_length: usize,
290 num: u64,
291) -> Vec<u8> {
292 let pt = num_to_le_vec(num, zero_pad_length);
294 let encrypted_num = ff1
295 .encrypt(&[], &BinaryNumeralString::from_bytes_le(&pt))
296 .expect("Radix 2 should be valid")
297 .to_bytes_le();
298
299 let mut hmac: HmacSha256 = hmac.clone();
301 hmac.update(&encrypted_num);
302 let truncated_mac = &hmac.finalize().into_bytes()[..hmac_length];
303
304 let mut result = encrypted_num.to_vec();
306 result.extend_from_slice(truncated_mac);
307
308 result
309}
310
311fn decrypt_number(
312 ff1: &FF1<Aes256>,
313 hmac: &HmacSha256,
314 hmac_length: usize,
315 zero_pad_length: usize,
316 encrypted_data: &[u8],
317) -> Result<u64, Error> {
318 if encrypted_data.len() < hmac_length + zero_pad_length {
319 return Err(Error::InvalidDataLength);
320 }
321 let (encrypted_num, received_mac) = encrypted_data.split_at(encrypted_data.len() - hmac_length);
322
323 let mut hmac_clone: HmacSha256 = hmac.clone();
325 hmac_clone.update(encrypted_num);
326 let truncated_mac = &hmac_clone.finalize().into_bytes()[..hmac_length];
327 if truncated_mac != received_mac {
328 return Err(Error::IncorrectMAC);
329 }
330
331 let decrypted_num = ff1
333 .decrypt(&[], &BinaryNumeralString::from_bytes_le(encrypted_num))
334 .map_err(|_| Error::DecryptionFailed)?;
335
336 let num: u64 = le_vec_to_num(&decrypted_num.to_bytes_le());
338 Ok(num)
339}
340
// Unit tests for `Codec`: golden encodings for fixed keys and configs,
// randomized round-trips, decode error cases, and thread-local config
// isolation.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    // Golden vectors for the default `Config`. Any change to key derivation,
    // FF1/HMAC usage, or payload packing changes these strings.
    #[test]
    fn test_defaults() {
        let codec = Codec::new("test", &Config::new(b"Test key here"));
        let test_cases = vec![
            (0, "test_g1HdsEGpXp5"),
            (1, "test_bTPc8uxHEwv"),
            (2, "test_dZ0iJdcLBgB"),
            (123, "test_hHLBCl4rZ3u"),
            (u64::MAX, "test_20cMzlnhTkILdJzWt"),
        ];

        // Each vector must encode to the expected string and decode back.
        for (input, expected) in test_cases {
            assert_eq!(codec.encode(input), expected);
            assert_eq!(codec.decode(expected).unwrap(), input);
        }
    }

    // Golden vectors for the UUID form, which always uses 8-byte MAC and
    // 8-byte padding regardless of config.
    #[test]
    fn test_uuid() {
        let codec = Codec::new("test", &Config::new(b"Test key here"));
        let test_cases = [
            (0, "59142369-adeb-8ef9-a1be-28f61c05d4d6"),
            (1, "93196956-2d32-d8d2-54f7-9a86fc765f3a"),
            (2, "3c10f25c-005e-6f6f-87a9-781efe02d14d"),
            (123, "571fd9d5-e133-f7b0-b0df-f444e4dd1127"),
            (u64::MAX, "a3b06cf5-dd4d-3f09-4000-9d3519d4d6c2"),
        ];

        for &(input, expected) in &test_cases {
            let uuid = Uuid::parse_str(expected).unwrap();
            assert_eq!(codec.encode_uuid(input), uuid);

            let decoded = codec.decode_uuid(uuid).unwrap();
            assert_eq!(decoded, input, "Failed to decode UUID for input: {}", input);
        }
    }

    // UUID encode/decode must round-trip for arbitrary u64 values.
    #[test]
    fn test_uuid_roundtrip() {
        let codec = Codec::new("test", &Config::new(b"Test key here"));
        let mut rng = rand::rng();

        for _ in 0..1_000 {
            let number: u64 = rng.random();
            let uuid = codec.encode_uuid(number);
            let decoded = codec.decode_uuid(uuid).expect("Decoding failed");

            assert_eq!(decoded, number, "Failed at number: {}", number);
        }
    }

    // hmac_length 8 + zero_pad_length 8 fills the 16-byte buffer, so these
    // encodings carry no sentinel byte.
    #[test]
    fn test_long() {
        let config = Config::new(b"Test key here")
            .hmac_length(8)
            .unwrap()
            .zero_pad_length(8)
            .unwrap();
        let codec = Codec::new("test", &config);
        assert_eq!(codec.encode(0), "test_6XNFaHOCeuIBNvRT4pIrVZ");
        assert_eq!(codec.encode(1), "test_1m9BJW23Jk5hSIlfPxoboZ");
        assert_eq!(codec.encode(2), "test_2MpvWPgnp5j1dIqFnJVOjU");
        assert_eq!(codec.encode(123), "test_1BirgT1ZJhfSsKFLgxA5gt");
        assert_eq!(codec.encode(u64::MAX), "test_5vegfyOLrrmwtgznQByI4J");
        assert_eq!(codec.decode("test_6XNFaHOCeuIBNvRT4pIrVZ").unwrap(), 0);
        assert_eq!(codec.decode("test_1m9BJW23Jk5hSIlfPxoboZ").unwrap(), 1);
        assert_eq!(codec.decode("test_2MpvWPgnp5j1dIqFnJVOjU").unwrap(), 2);
        assert_eq!(codec.decode("test_1BirgT1ZJhfSsKFLgxA5gt").unwrap(), 123);
        assert_eq!(
            codec.decode("test_5vegfyOLrrmwtgznQByI4J").unwrap(),
            u64::MAX
        );
    }

    // hmac_length 0 disables authentication, giving the shortest encodings.
    #[test]
    fn test_short() {
        let config = Config::new(b"Test key here")
            .hmac_length(0)
            .unwrap()
            .zero_pad_length(3)
            .unwrap();
        let codec = Codec::new("test", &config);
        assert_eq!(codec.encode(0), "test_1zG8O");
        assert_eq!(codec.encode(1), "test_1R8PN");
        assert_eq!(codec.encode(2), "test_1nzgo");
        assert_eq!(codec.encode(123), "test_1YqNT");
        assert_eq!(codec.encode(u64::MAX), "test_Mlu72Yai97j");
        assert_eq!(codec.decode("test_1zG8O").unwrap(), 0);
        assert_eq!(codec.decode("test_1R8PN").unwrap(), 1);
        assert_eq!(codec.decode("test_1nzgo").unwrap(), 2);
        assert_eq!(codec.decode("test_1YqNT").unwrap(), 123);
        assert_eq!(codec.decode("test_Mlu72Yai97j").unwrap(), u64::MAX);

        // Without a MAC, an arbitrary well-formed string still decodes to some number.
        assert_eq!(codec.decode("test_1helloall").unwrap(), 20580488769766);
    }

    // Each decode failure mode maps to its specific `Error` variant.
    #[test]
    fn test_decode_errors() {
        let codec = Codec::new("test", &Config::new(b"Test key here"));

        // No '_' at all: the received prefix is empty.
        assert_eq!(
            codec.decode("hHLBCl4rZ3u"),
            Err(Error::InvalidPrefix {
                received: "".to_string(),
                expected: "test_".to_string()
            })
        );

        assert_eq!(
            codec.decode("_hHLBCl4rZ3u"),
            Err(Error::InvalidPrefix {
                received: "_".to_string(),
                expected: "test_".to_string()
            })
        );

        assert_eq!(
            codec.decode("wrong_hHLBCl4rZ3u"),
            Err(Error::InvalidPrefix {
                received: "wrong_".to_string(),
                expected: "test_".to_string()
            })
        );

        // Bumping the leading base62 digit changes the byte where the sentinel should be.
        assert_eq!(
            codec.decode("test_iHLBCl4rZ3u"),
            Err(Error::SentinelMismatch {
                received: 2,
                expected: SENTINEL,
            })
        );

        // Tampering with lower digits keeps the sentinel but breaks the MAC.
        assert_eq!(codec.decode("test_hHLBCl4rZ3v"), Err(Error::IncorrectMAC));
        assert_eq!(codec.decode("test_hHMBCl4rZ3u"), Err(Error::IncorrectMAC));

        // '+' is not a base62 digit.
        assert_eq!(codec.decode("test_hHLBCl+rZ3u"), Err(Error::DecodingFailed));

        // Control: the untampered string decodes correctly.
        assert_eq!(codec.decode("test_hHLBCl4rZ3u"), Ok(123));
    }

    // Encode/decode must round-trip for arbitrary u64 values under the default config.
    #[test]
    fn test_random_roundtrips() {
        let codec = Codec::new("test", &Config::new(b"Test key here"));
        let mut rng = rand::rng();

        for _ in 0..10_000 {
            let number: u64 = rng.random();
            let encoded = codec.encode(number);
            let decoded = codec.decode(&encoded).expect("Decoding failed");

            assert_eq!(decoded, number, "Failed at number: {}", number);
        }
    }

    // A thread-local config must override the global one only on the thread
    // that set it.
    // NOTE(review): `Config::set_global` mutates process-wide state; the other
    // tests here pass an explicit `Config`, so presumably they are unaffected.
    // Confirm this if tests relying on the global config are added.
    #[test]
    fn test_thread_local_config_isolation() {
        use crate::{Config, Field, TypeMarker};
        use std::thread;

        #[derive(Clone, Copy, Debug)]
        pub struct TestIdMarker;
        impl TypeMarker for TestIdMarker {
            fn name() -> &'static str {
                "test"
            }
        }

        type TestId = Field<TestIdMarker>;

        Config::set_global(Config::new(b"global-key-16bytes"));

        // Presumably encodes with the global config on this thread; verify
        // against `Field`'s `Display` impl.
        let id = TestId::from(123);
        let global_encoded = id.to_string();

        let handle = thread::spawn(|| {
            // Override only for the spawned thread.
            Config::set_thread_local(Config::new(b"thread-key-16bytes"));

            let id = TestId::from(123);
            id.to_string()
        });

        let thread_result = handle.join().unwrap();

        // Different keys should give different encodings of the same ID.
        assert_ne!(
            global_encoded, thread_result,
            "Thread-local config should produce different encoding than global config"
        );

        // The spawned thread's override must not leak back into this thread.
        let main_again = TestId::from(123).to_string();
        assert_eq!(
            global_encoded, main_again,
            "Main thread should still use global config"
        );
    }
}