use std::collections::HashSet;
use crate::error::Result;
use crate::password::{hash_password, verify_password};
use crate::random::generate_random_bytes;
/// Settings that control how recovery codes are generated and stored.
///
/// Construct it with [`RecoveryConfig::new`] (or `RecoveryConfig::default()`) and refine it with
/// the `with_*` builder methods, which validate their arguments, or start from a preset such as
/// [`RecoveryConfig::high_security`] or [`RecoveryConfig::simple`]. The fields are public, so
/// values assigned to them directly bypass the builder validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecoveryConfig {
    /// Number of codes in a generated set.
    pub code_count: usize,
    /// Number of characters in each group of a code.
    pub group_length: usize,
    /// Number of groups in each code.
    pub group_count: usize,
    /// Character placed between adjacent groups.
    pub separator: char,
    /// Whether codes are hashed for storage (`true`) or kept as plain text (`false`).
    pub hash_codes: bool,
}

impl Default for RecoveryConfig {
    /// Ten codes made of two four-character groups joined by `-`, with hashing enabled.
    fn default() -> Self {
        Self {
            code_count: 10,
            group_length: 4,
            group_count: 2,
            separator: '-',
            hash_codes: true,
        }
    }
}

impl RecoveryConfig {
    /// Creates the default configuration; equivalent to `RecoveryConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how many codes are generated per set.
    ///
    /// # Panics
    ///
    /// Panics if `count` is not in `1..=20`.
    pub fn with_code_count(mut self, count: usize) -> Self {
        // Same range-check form as the other validated setters below.
        assert!(
            (1..=20).contains(&count),
            "code count must be between 1 and 20"
        );
        self.code_count = count;
        self
    }

    /// Sets the number of characters per group.
    ///
    /// # Panics
    ///
    /// Panics if `length` is not in `4..=8`.
    pub fn with_group_length(mut self, length: usize) -> Self {
        assert!(
            (4..=8).contains(&length),
            "group length must be between 4 and 8"
        );
        self.group_length = length;
        self
    }

    /// Sets the number of groups per code.
    ///
    /// # Panics
    ///
    /// Panics if `count` is not in `1..=4`.
    pub fn with_group_count(mut self, count: usize) -> Self {
        assert!(
            (1..=4).contains(&count),
            "group count must be between 1 and 4"
        );
        self.group_count = count;
        self
    }

    /// Sets the character placed between groups.
    ///
    /// The value is not validated. Verification strips only non-alphanumeric characters from
    /// its input, so an alphanumeric separator would become part of the code that must be typed.
    pub fn with_separator(mut self, separator: char) -> Self {
        self.separator = separator;
        self
    }

    /// Enables (`true`) or disables (`false`) hashing of generated codes.
    pub fn with_hash(mut self, hash: bool) -> Self {
        self.hash_codes = hash;
        self
    }

    /// Preset: 16 codes of three five-character groups joined by `-`, hashed.
    pub fn high_security() -> Self {
        Self {
            code_count: 16,
            group_length: 5,
            group_count: 3,
            separator: '-',
            hash_codes: true,
        }
    }

    /// Preset: 5 codes of two four-character groups joined by `-`, stored unhashed.
    pub fn simple() -> Self {
        Self {
            code_count: 5,
            group_length: 4,
            group_count: 2,
            separator: '-',
            hash_codes: false,
        }
    }
}
/// A freshly generated batch of recovery codes.
///
/// Note: the derived `Debug` output includes `plain_codes` in clear text, so avoid logging
/// this value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecoveryCodeSet {
    /// Human-readable codes (groups joined by the configured separator) to hand to the user.
    pub plain_codes: Vec<String>,
    /// Storage form of the codes, one entry per plain code. This is the list that
    /// `RecoveryCodeManager::verify` takes as `stored_codes`. When hashing is disabled it is
    /// identical to `plain_codes`.
    pub hashed_codes: Vec<String>,
    /// Unix timestamp in seconds (UTC) recorded when the set was generated.
    pub generated_at: i64,
}
/// Summary of how many recovery codes of a set have been consumed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecoveryCodeStatus {
    /// Number of codes in a complete set.
    pub total: usize,
    /// Number of codes already consumed.
    pub used: usize,
    /// Number of codes still available.
    pub remaining: usize,
}
/// Character set for generated recovery codes.
///
/// It omits the easily confused characters `0`, `1`, `I` and `O`. Its 32 entries divide 256
/// evenly, so reducing a random byte with `% len` introduces no modulo bias (assuming
/// `generate_random_bytes` yields uniformly distributed bytes).
const RECOVERY_CODE_CHARSET: &[u8] = b"23456789ABCDEFGHJKLMNPQRSTUVWXYZ";

/// Generates, verifies and reports on recovery codes according to a [`RecoveryConfig`].
#[derive(Debug, Clone)]
pub struct RecoveryCodeManager {
    config: RecoveryConfig,
}

impl RecoveryCodeManager {
    /// Creates a manager that uses `config`.
    pub fn new(config: RecoveryConfig) -> Self {
        Self { config }
    }

    /// Creates a manager that uses `RecoveryConfig::default()`.
    pub fn default_manager() -> Self {
        Self::new(RecoveryConfig::default())
    }

    /// Generates a new set of distinct recovery codes.
    ///
    /// Each entry in `plain_codes` is `group_count` groups of `group_length` characters joined
    /// by the configured separator. When `hash_codes` is enabled, `hashed_codes[i]` is the
    /// `hash_password` hash of the normalized form of `plain_codes[i]`. Otherwise
    /// `hashed_codes` is a copy of `plain_codes`. `generated_at` is the current UTC Unix
    /// timestamp in seconds.
    ///
    /// # Errors
    ///
    /// Propagates errors from `generate_random_bytes` and `hash_password`.
    ///
    /// # Panics
    ///
    /// Panics if the configuration can produce fewer distinct codes than `code_count`. For
    /// example, `group_count == 0` with `code_count > 1`, which is reachable because the
    /// config fields are public. Without this check the duplicate-rejection loop would never
    /// terminate. Configurations that stay within the ranges enforced by the `RecoveryConfig`
    /// builder methods never trigger it.
    pub fn generate(&self) -> Result<RecoveryCodeSet> {
        let wanted = self.config.code_count;
        // `None` means the code space exceeds `usize::MAX`, which is always large enough.
        if let Some(space) = self.code_space_size() {
            assert!(
                space >= wanted,
                "recovery code configuration can only produce {} distinct codes, but {} were requested",
                space,
                wanted
            );
        }

        let mut plain_codes = Vec::with_capacity(wanted);
        let mut seen = HashSet::with_capacity(wanted);
        while plain_codes.len() < wanted {
            let code = self.generate_single_code()?;
            // Retry on the (rare) duplicate so every code in the set is unique.
            if seen.insert(code.clone()) {
                plain_codes.push(code);
            }
        }

        let hashed_codes = if self.config.hash_codes {
            self.hash_codes(&plain_codes)?
        } else {
            plain_codes.clone()
        };

        Ok(RecoveryCodeSet {
            plain_codes,
            hashed_codes,
            generated_at: chrono::Utc::now().timestamp(),
        })
    }

    /// Looks up `code` in `stored_codes` and returns the index of the first match.
    ///
    /// The input is normalized first: non-alphanumeric characters such as separators and
    /// spaces are dropped, and ASCII letters are upper-cased. So `abcd efgh` and `ABCD-EFGH`
    /// normalize to the same string.
    ///
    /// - With `hash_codes` enabled, each stored entry is checked with `verify_password`.
    /// - Otherwise each stored entry is normalized the same way and compared with
    ///   `constant_time_eq`.
    ///
    /// Returns `Ok(None)` when nothing matches.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `verify_password`.
    pub fn verify(&self, code: &str, stored_codes: &[String]) -> Result<Option<usize>> {
        let normalized = self.normalize_code(code);
        if self.config.hash_codes {
            for (index, hashed) in stored_codes.iter().enumerate() {
                if verify_password(&normalized, hashed)? {
                    return Ok(Some(index));
                }
            }
            Ok(None)
        } else {
            Ok(stored_codes.iter().position(|stored| {
                let stored_normalized = self.normalize_code(stored);
                constant_time_eq(normalized.as_bytes(), stored_normalized.as_bytes())
            }))
        }
    }

    /// Verifies `code` and, on success, returns `stored_codes` with the matched entry removed.
    ///
    /// Returns `(true, remaining)` on a match and `(false, unchanged copy)` otherwise.
    /// `stored_codes` itself is not modified; persisting the returned list is up to the caller.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`RecoveryCodeManager::verify`].
    pub fn verify_and_consume(
        &self,
        code: &str,
        stored_codes: &[String],
    ) -> Result<(bool, Vec<String>)> {
        match self.verify(code, stored_codes)? {
            Some(index) => {
                let mut remaining = stored_codes.to_vec();
                remaining.remove(index);
                Ok((true, remaining))
            }
            None => Ok((false, stored_codes.to_vec())),
        }
    }

    /// Reports how many codes of the configured set are used and how many remain.
    ///
    /// `remaining` is `stored_codes.len()`. `total` is the configured `code_count`, widened to
    /// `remaining` if the stored list is larger (e.g. a set generated under an older, larger
    /// configuration), so that `used + remaining == total` always holds.
    pub fn get_status(&self, stored_codes: &[String]) -> RecoveryCodeStatus {
        let remaining = stored_codes.len();
        let total = self.config.code_count.max(remaining);
        RecoveryCodeStatus {
            total,
            used: total - remaining,
            remaining,
        }
    }

    /// Generates a brand-new set; identical to [`RecoveryCodeManager::generate`].
    ///
    /// # Errors
    ///
    /// Same as [`RecoveryCodeManager::generate`].
    pub fn regenerate(&self) -> Result<RecoveryCodeSet> {
        self.generate()
    }

    /// Formats `codes` as a numbered list, one code per line (`" 1. CODE"`).
    ///
    /// Numbers start at 1 and are right-aligned to the widest number, with a minimum width of
    /// two columns.
    pub fn format_for_display(&self, codes: &[String]) -> String {
        // Size the number column from the largest index so lists past 99 entries stay aligned.
        let width = codes.len().to_string().len().max(2);
        codes
            .iter()
            .enumerate()
            .map(|(i, code)| format!("{:>width$}. {}", i + 1, code, width = width))
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Returns the configuration used by this manager.
    pub fn config(&self) -> &RecoveryConfig {
        &self.config
    }

    /// Number of code characters per code, excluding separators.
    fn total_chars(&self) -> usize {
        self.config.group_length.saturating_mul(self.config.group_count)
    }

    /// Number of distinct codes the configuration can produce, or `None` if it exceeds
    /// `usize::MAX`.
    fn code_space_size(&self) -> Option<usize> {
        (0..self.total_chars())
            .try_fold(1usize, |space, _| space.checked_mul(RECOVERY_CODE_CHARSET.len()))
    }

    /// Builds one random code: `total_chars` characters drawn from the charset, with the
    /// separator inserted after every `group_length` characters.
    fn generate_single_code(&self) -> Result<String> {
        let group_length = self.config.group_length;
        let total_chars = self.total_chars();
        let bytes = generate_random_bytes(total_chars)?;

        // One separator sits between each pair of adjacent groups. `total_chars > 0` implies
        // `group_length > 0`, so the division is safe. Unlike `group_count - 1`, this never
        // underflows when `group_count == 0`.
        let separator_count = if total_chars == 0 {
            0
        } else {
            (total_chars - 1) / group_length
        };
        let mut code =
            String::with_capacity(total_chars + separator_count * self.config.separator.len_utf8());

        for (i, byte) in bytes.iter().enumerate() {
            if i > 0 && i % group_length == 0 {
                code.push(self.config.separator);
            }
            let char_index = usize::from(*byte) % RECOVERY_CODE_CHARSET.len();
            code.push(char::from(RECOVERY_CODE_CHARSET[char_index]));
        }
        Ok(code)
    }

    /// Hashes the normalized form of every code, stopping at the first hashing error.
    fn hash_codes(&self, codes: &[String]) -> Result<Vec<String>> {
        codes
            .iter()
            .map(|code| {
                let normalized = self.normalize_code(code);
                hash_password(&normalized)
            })
            .collect()
    }

    /// Canonical form used for hashing and comparison.
    ///
    /// It keeps only alphanumeric characters and upper-cases ASCII letters. Full Unicode
    /// `to_uppercase` is avoided on purpose: it can map non-ASCII input onto code characters
    /// (U+017F 'ſ' becomes 'S') and can change string length. Generated codes are pure ASCII.
    fn normalize_code(&self, code: &str) -> String {
        code.chars()
            .filter(|c| c.is_alphanumeric())
            .map(|c| c.to_ascii_uppercase())
            .collect()
    }
}
/// Compares two byte slices for equality without stopping at the first differing byte.
///
/// Every byte pair is XOR-ed and OR-accumulated, so the work done depends on the slice length
/// rather than on where the first mismatch occurs (best effort; the compiler gives no formal
/// guarantee). Slices of different lengths are rejected up front, so the length itself is not
/// hidden.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
pub fn generate_recovery_codes() -> Result<RecoveryCodeSet> {
RecoveryCodeManager::default_manager().generate()
}
/// Generates a recovery code set with `count` codes and otherwise default settings.
///
/// # Panics
///
/// Panics if `count` is outside `1..=20`, as enforced by `RecoveryConfig::with_code_count`.
///
/// # Errors
///
/// Propagates any error from `RecoveryCodeManager::generate`.
pub fn generate_recovery_codes_with_count(count: usize) -> Result<RecoveryCodeSet> {
    let manager = RecoveryCodeManager::new(RecoveryConfig::new().with_code_count(count));
    manager.generate()
}
// Unit tests for recovery-code configuration, generation, verification, consumption and
// status reporting. Tests on the default config exercise the hashed path
// (hash_password / verify_password); tests using `with_hash(false)` exercise plain comparison.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_recovery_config_default() {
let config = RecoveryConfig::default();
assert_eq!(config.code_count, 10);
assert_eq!(config.group_length, 4);
assert_eq!(config.group_count, 2);
assert_eq!(config.separator, '-');
assert!(config.hash_codes);
}
#[test]
fn test_recovery_config_builder() {
let config = RecoveryConfig::new()
.with_code_count(8)
.with_group_length(5)
.with_group_count(3)
.with_separator('_')
.with_hash(false);
assert_eq!(config.code_count, 8);
assert_eq!(config.group_length, 5);
assert_eq!(config.group_count, 3);
assert_eq!(config.separator, '_');
assert!(!config.hash_codes);
}
#[test]
fn test_generate_recovery_codes() {
let manager = RecoveryCodeManager::default_manager();
let code_set = manager.generate().unwrap();
assert_eq!(code_set.plain_codes.len(), 10);
assert_eq!(code_set.hashed_codes.len(), 10);
for code in &code_set.plain_codes {
// Default layout: 2 groups of 4 characters joined by one '-' => 9 characters.
assert_eq!(code.len(), 9); assert!(code.contains('-'));
}
}
#[test]
fn test_codes_are_unique() {
let manager = RecoveryCodeManager::default_manager();
let code_set = manager.generate().unwrap();
// generate() rejects duplicates, so every plain code must be distinct.
let unique: HashSet<_> = code_set.plain_codes.iter().collect();
assert_eq!(unique.len(), code_set.plain_codes.len());
}
#[test]
fn test_verify_plain_code() {
// With hashing disabled, hashed_codes is a plain copy of plain_codes.
let config = RecoveryConfig::default().with_hash(false);
let manager = RecoveryCodeManager::new(config);
let code_set = manager.generate().unwrap();
let result = manager
.verify(&code_set.plain_codes[0], &code_set.hashed_codes)
.unwrap();
assert_eq!(result, Some(0));
// "INVALID-CODE" normalizes to 11 characters, so it cannot match an 8-character code.
let result = manager
.verify("INVALID-CODE", &code_set.hashed_codes)
.unwrap();
assert!(result.is_none());
}
#[test]
fn test_verify_hashed_code() {
// The default config hashes codes; the first plain code must still verify at index 0.
let manager = RecoveryCodeManager::default_manager();
let code_set = manager.generate().unwrap();
let result = manager
.verify(&code_set.plain_codes[0], &code_set.hashed_codes)
.unwrap();
assert_eq!(result, Some(0));
}
#[test]
fn test_verify_case_insensitive() {
// Normalization upper-cases input, so a lowercase copy of a code must still match.
let config = RecoveryConfig::default().with_hash(false);
let manager = RecoveryCodeManager::new(config);
let code_set = manager.generate().unwrap();
let lowercase = code_set.plain_codes[0].to_lowercase();
let result = manager.verify(&lowercase, &code_set.hashed_codes).unwrap();
assert_eq!(result, Some(0));
}
#[test]
fn test_verify_without_separator() {
// Normalization drops the separator, so the bare characters must match.
let config = RecoveryConfig::default().with_hash(false);
let manager = RecoveryCodeManager::new(config);
let code_set = manager.generate().unwrap();
let without_sep = code_set.plain_codes[0].replace('-', "");
let result = manager
.verify(&without_sep, &code_set.hashed_codes)
.unwrap();
assert_eq!(result, Some(0));
}
#[test]
fn test_verify_and_consume() {
// A consumed code is removed from the returned list and cannot be used a second time.
let config = RecoveryConfig::default().with_hash(false);
let manager = RecoveryCodeManager::new(config);
let code_set = manager.generate().unwrap();
let (valid, remaining) = manager
.verify_and_consume(&code_set.plain_codes[0], &code_set.hashed_codes)
.unwrap();
assert!(valid);
assert_eq!(remaining.len(), 9);
let (valid, _) = manager
.verify_and_consume(&code_set.plain_codes[0], &remaining)
.unwrap();
assert!(!valid);
}
#[test]
fn test_get_status() {
let manager = RecoveryCodeManager::default_manager();
let code_set = manager.generate().unwrap();
let status = manager.get_status(&code_set.hashed_codes);
assert_eq!(status.total, 10);
assert_eq!(status.used, 0);
assert_eq!(status.remaining, 10);
// Simulate three consumed codes by keeping only the first seven.
let partial = code_set.hashed_codes[..7].to_vec();
let status = manager.get_status(&partial);
assert_eq!(status.remaining, 7);
assert_eq!(status.used, 3);
}
#[test]
fn test_format_for_display() {
// Entry numbers are right-aligned to two columns, e.g. " 1. CODE".
let config = RecoveryConfig::default().with_code_count(3);
let manager = RecoveryCodeManager::new(config);
let code_set = manager.generate().unwrap();
let display = manager.format_for_display(&code_set.plain_codes);
assert!(display.contains(" 1. "));
assert!(display.contains(" 2. "));
assert!(display.contains(" 3. "));
}
#[test]
fn test_config_presets() {
let high_sec = RecoveryConfig::high_security();
assert_eq!(high_sec.code_count, 16);
assert_eq!(high_sec.group_length, 5);
assert_eq!(high_sec.group_count, 3);
let simple = RecoveryConfig::simple();
assert_eq!(simple.code_count, 5);
assert!(!simple.hash_codes);
}
#[test]
fn test_convenience_functions() {
let code_set = generate_recovery_codes().unwrap();
assert_eq!(code_set.plain_codes.len(), 10);
let code_set = generate_recovery_codes_with_count(5).unwrap();
assert_eq!(code_set.plain_codes.len(), 5);
}
#[test]
fn test_no_confusing_characters() {
// '0', 'O', 'I' and '1' are absent from the code charset; 'l' cannot appear because
// generated codes contain only upper-case letters and digits.
let manager = RecoveryCodeManager::default_manager();
let code_set = manager.generate().unwrap();
let confusing = ['0', 'O', 'I', 'l', '1'];
for code in &code_set.plain_codes {
for ch in confusing {
assert!(
!code.contains(ch),
"Code {} contains confusing character {}",
code,
ch
);
}
}
}
#[test]
fn test_regenerate() {
// Probabilistic: two independently generated sets being identical is astronomically unlikely.
let manager = RecoveryCodeManager::default_manager();
let code_set1 = manager.generate().unwrap();
let code_set2 = manager.regenerate().unwrap();
assert_ne!(code_set1.plain_codes, code_set2.plain_codes);
}
#[test]
fn test_different_group_configs() {
// 3 groups of 5 characters plus 2 separators => 17 characters.
let config = RecoveryConfig::default()
.with_group_length(5)
.with_group_count(3);
let manager = RecoveryCodeManager::new(config);
let code_set = manager.generate().unwrap();
for code in &code_set.plain_codes {
assert_eq!(code.len(), 17);
assert_eq!(code.matches('-').count(), 2);
}
}
#[test]
fn test_custom_separator() {
let config = RecoveryConfig::default().with_separator('_');
let manager = RecoveryCodeManager::new(config);
let code_set = manager.generate().unwrap();
for code in &code_set.plain_codes {
assert!(code.contains('_'));
assert!(!code.contains('-'));
}
}
// The builder setters panic on out-of-range values; the expected messages pin their wording.
#[test]
#[should_panic(expected = "code count must be between 1 and 20")]
fn test_invalid_code_count() {
RecoveryConfig::default().with_code_count(0);
}
#[test]
#[should_panic(expected = "group length must be between 4 and 8")]
fn test_invalid_group_length() {
RecoveryConfig::default().with_group_length(3);
}
#[test]
#[should_panic(expected = "group count must be between 1 and 4")]
fn test_invalid_group_count() {
RecoveryConfig::default().with_group_count(5);
}
}