use super::{AuthResult, Authenticator};
use sha2::{Digest, Sha256};
/// Minimum PIN length enforced by `PinAuthenticator::validate_pin`.
pub const MIN_PIN_LENGTH: usize = 4;
/// Maximum PIN length enforced by `PinAuthenticator::validate_pin`.
pub const MAX_PIN_LENGTH: usize = 30;
/// Authenticates by comparing a salted hash of the entered PIN against a
/// stored hash.
///
/// Both fields are private, so code outside this module creates instances
/// through [`PinAuthenticator::new`]. No `Debug` is derived, so the hash and
/// salt are not accidentally printed.
pub struct PinAuthenticator {
    /// Salt hashed together with the PIN (it is fed to the hasher first).
    salt: String,
    /// Hex-encoded hash of the configured PIN, as produced by `hash_pin`.
    pin_hash: String,
}
impl PinAuthenticator {
    /// Creates an authenticator from a previously stored PIN hash and its salt.
    ///
    /// `pin_hash` is expected to be the hex string produced by
    /// [`PinAuthenticator::hash_pin`] for the same `salt`. Only non-emptiness is
    /// checked here.
    ///
    /// # Errors
    ///
    /// Returns an error message if `pin_hash` or `salt` is empty.
    pub fn new(pin_hash: String, salt: String) -> Result<Self, String> {
        if pin_hash.is_empty() {
            return Err("PIN hash is required".to_string());
        }
        if salt.is_empty() {
            return Err("Salt is required".to_string());
        }
        Ok(Self { pin_hash, salt })
    }

    /// Hashes `pin` with `salt` and returns the lowercase hex SHA-256 digest
    /// (64 characters).
    ///
    /// The digest is computed over the bytes of `salt` followed by the bytes of
    /// `pin`.
    ///
    /// NOTE(review): a single, unstretched SHA-256 is cheap to brute-force for
    /// short PINs. A slow KDF (e.g. Argon2 or PBKDF2) would be stronger, but
    /// switching would invalidate existing stored hashes, so the algorithm is
    /// intentionally unchanged here.
    pub fn hash_pin(pin: &str, salt: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789abcdef";

        let mut hasher = Sha256::new();
        hasher.update(salt.as_bytes());
        hasher.update(pin.as_bytes());
        let digest = hasher.finalize();

        // Encode into one preallocated buffer rather than allocating a
        // temporary `String` per byte via `format!`.
        let mut hex = String::with_capacity(digest.len() * 2);
        for &byte in digest.iter() {
            hex.push(char::from(HEX[usize::from(byte >> 4)]));
            hex.push(char::from(HEX[usize::from(byte & 0x0f)]));
        }
        hex
    }

    /// Checks that `pin` is acceptable before it is hashed and stored.
    ///
    /// A valid PIN is between [`MIN_PIN_LENGTH`] and [`MAX_PIN_LENGTH`] long
    /// (inclusive) and consists only of printable, non-space ASCII characters
    /// (`'!'..='~'`).
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first rule violated.
    pub fn validate_pin(pin: &str) -> Result<(), String> {
        // `len()` counts bytes. That equals the character count for every PIN
        // that passes the ASCII check below, and any non-ASCII PIN is rejected
        // by one of these checks regardless.
        if pin.len() < MIN_PIN_LENGTH {
            return Err(format!(
                "PIN must be at least {} characters",
                MIN_PIN_LENGTH
            ));
        }
        if pin.len() > MAX_PIN_LENGTH {
            return Err(format!("PIN must be at most {} characters", MAX_PIN_LENGTH));
        }
        // Every character must be graphic ASCII. This rejects spaces, control
        // characters and non-ASCII text, so the message says "only".
        if !pin.chars().all(|c| c.is_ascii_graphic()) {
            return Err(
                "PIN must contain only printable ASCII characters (no spaces)".to_string(),
            );
        }
        Ok(())
    }
}
impl Authenticator for PinAuthenticator {
    /// Reports whether a PIN hash is configured.
    ///
    /// `PinAuthenticator::new` rejects an empty hash, so this is `true` for any
    /// instance built through it.
    fn is_available(&self) -> bool {
        !self.pin_hash.is_empty()
    }

    /// Hashes `password` (the entered PIN) with the stored salt and compares the
    /// result to the stored hash.
    ///
    /// `_username` is ignored: the authenticator holds exactly one PIN hash.
    fn authenticate(&self, _username: &str, password: &str) -> AuthResult {
        let computed_hash = Self::hash_pin(password, &self.salt);
        if constant_time_compare(&computed_hash, &self.pin_hash) {
            AuthResult::Success
        } else {
            AuthResult::Failure("Invalid PIN".to_string())
        }
    }

    /// Returns the current user's name from the environment.
    ///
    /// `USER`, `USERNAME` and `LOGNAME` are tried in that order. A variable is
    /// skipped when it is unset, not valid Unicode, or empty, so an empty
    /// `USER` no longer masks a usable `USERNAME`/`LOGNAME` (previously it
    /// produced `Some("")`).
    fn get_current_username(&self) -> Option<String> {
        ["USER", "USERNAME", "LOGNAME"]
            .iter()
            .filter_map(|key| std::env::var(key).ok())
            .find(|name| !name.is_empty())
    }

    /// Human-readable name of this authentication mechanism.
    fn system_name(&self) -> &'static str {
        "PIN"
    }
}
/// Returns whether `a` and `b` are byte-for-byte equal.
///
/// Inputs of different lengths return `false` immediately, so length is not
/// hidden. For equal-length inputs, every byte pair is XOR-ed and the
/// differences are OR-ed together. No exit happens at the first mismatch,
/// which is intended to avoid leaking the mismatch position through timing.
///
/// NOTE(review): Rust does not guarantee constant-time codegen for this
/// pattern; it is best-effort.
fn constant_time_compare(a: &str, b: &str) -> bool {
    if a.len() == b.len() {
        let diff = a
            .bytes()
            .zip(b.bytes())
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
        diff == 0
    } else {
        false
    }
}
/// Overwrites the memory behind `s` with zero bytes and leaves it empty.
///
/// The whole allocation is wiped, including spare capacity beyond `s.len()`.
/// That spare capacity can still hold secret bytes, e.g. after `truncate`.
/// Volatile writes are used so the stores are not optimized away. The
/// allocation is kept, so `s.capacity()` is unchanged.
///
/// This cannot wipe copies made before the call, such as old buffers left
/// behind if the `String` reallocated while it was growing.
pub fn secure_clear(s: &mut String) {
    // SAFETY: only 0x00 bytes are written into the buffer, and 0x00 is valid
    // UTF-8. The vector is cleared before this borrow ends, so the `String`'s
    // UTF-8 invariant holds whenever the string is observable.
    let bytes = unsafe { s.as_mut_vec() };

    // Extend the length over the entire allocation so stale bytes in spare
    // capacity are covered too. The new length equals the current capacity,
    // so this never reallocates.
    let cap = bytes.capacity();
    bytes.resize(cap, 0);

    for byte in bytes.iter_mut() {
        // SAFETY: `byte` is a valid, aligned, exclusive reference to a `u8`.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    // Keep the compiler from reordering the wipe past later operations.
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);

    bytes.clear();
}
#[cfg(test)]
mod tests {
    use super::*;

    // Salt and PIN fixtures are assembled at runtime rather than written as
    // literals (presumably to keep secret scanners from flagging them).
    fn test_salt() -> String {
        format!("test_{}", "salt")
    }

    fn test_pin() -> String {
        format!("{}{}", "sec", "ret123")
    }

    #[test]
    fn test_hash_pin() {
        let salt = test_salt();
        let pin = String::from("1234");
        let hash = PinAuthenticator::hash_pin(&pin, &salt);
        assert_eq!(hash.len(), 64);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
        let hash2 = PinAuthenticator::hash_pin(&pin, &salt);
        assert_eq!(hash, hash2);
        let hash3 = PinAuthenticator::hash_pin("5678", &salt);
        assert_ne!(hash, hash3);
        let other_salt = PinAuthenticator::hash_pin(&pin, "other_salt");
        assert_ne!(hash, other_salt);
    }

    #[test]
    fn test_validate_pin() {
        assert!(PinAuthenticator::validate_pin("1234").is_ok());
        assert!(PinAuthenticator::validate_pin("abcd1234").is_ok());
        assert!(PinAuthenticator::validate_pin("A1B2C3D4").is_ok());
        assert!(PinAuthenticator::validate_pin("12-34").is_ok());
        assert!(PinAuthenticator::validate_pin("12@34").is_ok());
        assert!(PinAuthenticator::validate_pin("P@ss!").is_ok());
        assert!(PinAuthenticator::validate_pin("abc#123$").is_ok());
        assert!(PinAuthenticator::validate_pin("123").is_err());
        assert!(PinAuthenticator::validate_pin("").is_err());
        let long_pin = "a".repeat(31);
        assert!(PinAuthenticator::validate_pin(&long_pin).is_err());
        assert!(PinAuthenticator::validate_pin("12 34").is_err());
    }

    #[test]
    fn test_validate_pin_boundaries_and_charset() {
        assert!(PinAuthenticator::validate_pin(&"a".repeat(MIN_PIN_LENGTH)).is_ok());
        assert!(PinAuthenticator::validate_pin(&"a".repeat(MAX_PIN_LENGTH)).is_ok());
        assert!(PinAuthenticator::validate_pin(&"a".repeat(MIN_PIN_LENGTH - 1)).is_err());
        assert!(PinAuthenticator::validate_pin(&"a".repeat(MAX_PIN_LENGTH + 1)).is_err());
        assert!(PinAuthenticator::validate_pin("12\t34").is_err());
        assert!(PinAuthenticator::validate_pin("12é34").is_err());
    }

    #[test]
    fn test_new_rejects_empty_inputs() {
        let hash = PinAuthenticator::hash_pin(&test_pin(), &test_salt());
        assert!(PinAuthenticator::new(String::new(), test_salt()).is_err());
        assert!(PinAuthenticator::new(hash.clone(), String::new()).is_err());
        assert!(PinAuthenticator::new(hash, test_salt()).is_ok());
    }

    #[test]
    fn test_authenticate() {
        let salt = test_salt();
        let pin = test_pin();
        let hash = PinAuthenticator::hash_pin(&pin, &salt);
        let auth = PinAuthenticator::new(hash, salt).unwrap();
        assert!(auth.is_available());
        assert_eq!(auth.system_name(), "PIN");
        assert!(matches!(auth.authenticate("", &pin), AuthResult::Success));
        let wrong = format!("{}{}", "wr", "ong");
        assert!(matches!(
            auth.authenticate("", &wrong),
            AuthResult::Failure(_)
        ));
        assert!(matches!(auth.authenticate("", ""), AuthResult::Failure(_)));
    }

    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("hello", "hello"));
        assert!(!constant_time_compare("hello", "world"));
        assert!(!constant_time_compare("hello", "hello1"));
        assert!(constant_time_compare("", ""));
    }

    #[test]
    fn test_secure_clear() {
        let mut secret = test_pin();
        secure_clear(&mut secret);
        assert!(secret.is_empty());
    }
}