use anyhow::{Context, Result};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use ring::{hmac, pbkdf2, rand as ring_rand};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::path::Path;
use std::sync::Arc;
/// PBKDF2-HMAC-SHA256 work factor shared by password hashing and password verification.
///
/// The iteration count is not encoded in the stored hash string (only salt + derived key are),
/// so changing this value invalidates every existing hash. NOTE(review): OWASP currently
/// recommends 600,000 iterations for PBKDF2-HMAC-SHA256; raising it needs a migration plan.
const PBKDF2_ITERATIONS: u32 = 100 * 1_000;
/// Number of random salt bytes placed at the front of every stored hash.
const SALT_LENGTH: usize = 16;
/// Number of PBKDF2 output bytes stored right after the salt.
const CREDENTIAL_LENGTH: usize = 32;
/// Token lifetime used when the caller does not configure one: one day, in seconds.
const DEFAULT_TOKEN_EXPIRY_SECS: u64 = 24 * 60 * 60;
/// One entry of the credentials file: a username and its salted password hash.
///
/// `password_hash` is the URL-safe, unpadded base64 encoding of a random salt followed by the
/// PBKDF2-HMAC-SHA256 output, in the layout produced by `AuthService::hash_password`.
///
/// `Debug` is implemented by hand so the hash never ends up in logs, where it could be
/// cracked offline.
#[derive(Clone, Serialize, Deserialize)]
pub struct UserCredential {
    pub username: String,
    pub password_hash: String,
}

impl std::fmt::Debug for UserCredential {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserCredential")
            .field("username", &self.username)
            .field("password_hash", &format_args!("<redacted>"))
            .finish()
    }
}
/// Claims carried inside a signed token; serialized to JSON and then base64-encoded.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
struct TokenPayload {
    /// Subject: the username the token was issued to.
    pub sub: String,
    /// Issued-at time, in whole seconds since the Unix epoch.
    pub iat: u64,
    /// Expiry time, in whole seconds since the Unix epoch.
    pub exp: u64,
    /// Unique token identifier (a random v4 UUID rendered as a string).
    pub jti: String,
}
/// Password verification plus issuance and validation of HMAC-SHA256-signed tokens.
///
/// Cloning is cheap: the credential table and the signing key sit behind `Arc`s and are
/// shared between clones rather than copied.
#[derive(Clone)]
pub struct AuthService {
    /// Known users keyed by username; populated once when the service is constructed.
    credentials: Arc<HashMap<String, UserCredential>>,
    /// Key used both to sign new tokens and to verify presented ones.
    signing_key: Arc<hmac::Key>,
    /// Lifetime, in seconds, given to every newly generated token.
    token_expiry_secs: u64,
    /// Whether authentication is switched on; reported as-is by `is_enabled`.
    enabled: bool,
}
/// Hand-written so neither the HMAC signing key nor any password hash is printed; only the
/// number of loaded credentials is shown. The trailing `..` in the output marks the omitted
/// signing key.
impl std::fmt::Debug for AuthService {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthService")
            .field("credentials_count", &self.credentials.len())
            .field("token_expiry_secs", &self.token_expiry_secs)
            .field("enabled", &self.enabled)
            .finish_non_exhaustive()
    }
}
impl AuthService {
pub fn new(
credentials_file: Option<&Path>,
token_secret: Option<&str>,
token_expiry_secs: Option<u64>,
enabled: bool,
) -> Result<Self> {
let credentials = if let Some(path) = credentials_file {
Self::load_credentials(path)?
} else {
HashMap::new()
};
let signing_key = if let Some(secret) = token_secret {
hmac::Key::new(hmac::HMAC_SHA256, secret.as_bytes())
} else {
let rng = ring_rand::SystemRandom::new();
hmac::Key::generate(hmac::HMAC_SHA256, &rng)
.map_err(|_| anyhow::anyhow!("Failed to generate signing key"))?
};
Ok(Self {
credentials: Arc::new(credentials),
signing_key: Arc::new(signing_key),
token_expiry_secs: token_expiry_secs.unwrap_or(DEFAULT_TOKEN_EXPIRY_SECS),
enabled,
})
}
pub fn disabled() -> Self {
Self {
credentials: Arc::new(HashMap::new()),
signing_key: Arc::new(hmac::Key::new(hmac::HMAC_SHA256, b"disabled-auth-not-used")),
token_expiry_secs: DEFAULT_TOKEN_EXPIRY_SECS,
enabled: false,
}
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
fn load_credentials(path: &Path) -> Result<HashMap<String, UserCredential>> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read credentials file: {}", path.display()))?;
let credentials: Vec<UserCredential> = serde_json::from_str(&content)
.with_context(|| format!("Failed to parse credentials file: {}", path.display()))?;
let mut map = HashMap::new();
for cred in credentials {
map.insert(cred.username.clone(), cred);
}
tracing::info!("Loaded {} user credentials", map.len());
Ok(map)
}
pub fn verify_password(&self, username: &str, password: &str) -> bool {
let Some(credential) = self.credentials.get(username) else {
return false;
};
let Ok(stored_bytes) = URL_SAFE_NO_PAD.decode(&credential.password_hash) else {
tracing::warn!("Invalid base64 in password hash for user: {}", username);
return false;
};
if stored_bytes.len() != SALT_LENGTH + CREDENTIAL_LENGTH {
tracing::warn!("Invalid password hash length for user: {}", username);
return false;
}
let (salt, stored_hash) = stored_bytes.split_at(SALT_LENGTH);
pbkdf2::verify(
pbkdf2::PBKDF2_HMAC_SHA256,
NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
salt,
password.as_bytes(),
stored_hash,
)
.is_ok()
}
pub fn generate_token(&self, username: &str) -> Result<String> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.context("System time before Unix epoch")?
.as_secs();
let payload = TokenPayload {
sub: username.to_string(),
iat: now,
exp: now + self.token_expiry_secs,
jti: uuid::Uuid::new_v4().to_string(),
};
let payload_json = serde_json::to_string(&payload)?;
let payload_b64 = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
let signature = hmac::sign(&self.signing_key, payload_b64.as_bytes());
let signature_b64 = URL_SAFE_NO_PAD.encode(signature.as_ref());
Ok(format!("{}.{}", payload_b64, signature_b64))
}
pub fn validate_token(&self, token: &str) -> Option<String> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 2 {
return None;
}
let payload_b64 = parts[0];
let signature_b64 = parts[1];
let Ok(signature_bytes) = URL_SAFE_NO_PAD.decode(signature_b64) else {
return None;
};
if hmac::verify(&self.signing_key, payload_b64.as_bytes(), &signature_bytes).is_err() {
return None;
}
let Ok(payload_json) = URL_SAFE_NO_PAD.decode(payload_b64) else {
return None;
};
let Ok(payload): Result<TokenPayload, _> = serde_json::from_slice(&payload_json) else {
return None;
};
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.ok()?
.as_secs();
if now > payload.exp {
return None;
}
Some(payload.sub)
}
pub fn hash_password(password: &str) -> Result<String> {
let rng = ring_rand::SystemRandom::new();
let mut salt = [0u8; SALT_LENGTH];
ring_rand::SecureRandom::fill(&rng, &mut salt)
.map_err(|_| anyhow::anyhow!("Failed to generate salt"))?;
let mut derived_key = [0u8; CREDENTIAL_LENGTH];
pbkdf2::derive(
pbkdf2::PBKDF2_HMAC_SHA256,
NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
&salt,
password.as_bytes(),
&mut derived_key,
);
let mut combined = Vec::with_capacity(SALT_LENGTH + CREDENTIAL_LENGTH);
combined.extend_from_slice(&salt);
combined.extend_from_slice(&derived_key);
Ok(URL_SAFE_NO_PAD.encode(&combined))
}
}
/// Login request body: the credentials to exchange for a token.
///
/// `Debug` is implemented by hand so the plaintext password is never printed, for example
/// when a request is logged with `{:?}`.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenRequest {
    pub username: String,
    pub password: String,
}

impl std::fmt::Debug for TokenRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenRequest")
            .field("username", &self.username)
            .field("password", &format_args!("<redacted>"))
            .finish()
    }
}
/// Successful login response carrying an issued token.
///
/// `Debug` is implemented by hand so the bearer token, a usable credential until it expires,
/// is never printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    pub token: String,
    pub token_type: String,
    pub expires_in: u64,
}

impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("token", &format_args!("<redacted>"))
            .field("token_type", &self.token_type)
            .field("expires_in", &self.expires_in)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Builds an enabled service (secret `test_secret`, one-hour expiry) whose credentials
    /// file contains exactly `credentials`.
    fn service_with_credentials(credentials: &[UserCredential]) -> AuthService {
        let mut file = NamedTempFile::new().unwrap();
        write!(file, "{}", serde_json::to_string(credentials).unwrap()).unwrap();
        // `AuthService::new` reads the file eagerly, so the temp file may be dropped afterwards.
        AuthService::new(Some(file.path()), Some("test_secret"), Some(3600), true).unwrap()
    }

    #[test]
    fn test_password_hashing() {
        let password = "my_secret_password";
        let hash = AuthService::hash_password(password).unwrap();
        let decoded = URL_SAFE_NO_PAD.decode(&hash).unwrap();
        assert_eq!(decoded.len(), SALT_LENGTH + CREDENTIAL_LENGTH);
    }

    #[test]
    fn test_password_hash_is_salted() {
        let first = AuthService::hash_password("same_password").unwrap();
        let second = AuthService::hash_password("same_password").unwrap();
        assert_ne!(first, second, "a fresh random salt must be used for every hash");
    }

    #[test]
    fn test_password_verification() {
        let password = "test_password_123";
        let auth = service_with_credentials(&[UserCredential {
            username: "testuser".to_string(),
            password_hash: AuthService::hash_password(password).unwrap(),
        }]);
        assert!(auth.verify_password("testuser", password));
        assert!(!auth.verify_password("testuser", "wrong_password"));
        assert!(!auth.verify_password("unknown", password));
    }

    #[test]
    fn test_malformed_stored_hash_is_rejected() {
        let auth = service_with_credentials(&[
            UserCredential {
                username: "badbase64".to_string(),
                password_hash: "!!not base64!!".to_string(),
            },
            UserCredential {
                username: "shorthash".to_string(),
                password_hash: URL_SAFE_NO_PAD.encode(b"too short"),
            },
        ]);
        assert!(!auth.verify_password("badbase64", "anything"));
        assert!(!auth.verify_password("shorthash", "anything"));
    }

    #[test]
    fn test_token_generation_and_validation() {
        let auth = AuthService::new(None, Some("test_secret"), Some(3600), true).unwrap();
        let token = auth.generate_token("testuser").unwrap();
        let username = auth.validate_token(&token);
        assert_eq!(username, Some("testuser".to_string()));
        assert!(auth.validate_token("invalid.token").is_none());
        assert!(auth.validate_token("notavalidtoken").is_none());
    }

    #[test]
    fn test_token_with_extra_segment_rejected() {
        let auth = AuthService::new(None, Some("test_secret"), Some(3600), true).unwrap();
        let token = auth.generate_token("testuser").unwrap();
        assert!(auth.validate_token(&format!("{token}.extra")).is_none());
        assert!(auth.validate_token("").is_none());
    }

    #[test]
    fn test_token_from_other_secret_rejected() {
        let issuer = AuthService::new(None, Some("secret_a"), Some(3600), true).unwrap();
        let verifier = AuthService::new(None, Some("secret_b"), Some(3600), true).unwrap();
        let token = issuer.generate_token("testuser").unwrap();
        assert!(verifier.validate_token(&token).is_none());
    }

    #[test]
    fn test_expired_token() {
        let auth = AuthService::new(None, Some("test_secret"), Some(0), true).unwrap();
        let token = auth.generate_token("testuser").unwrap();
        // Timestamps have one-second granularity, so wait past the issuing second.
        std::thread::sleep(std::time::Duration::from_millis(1100));
        assert!(auth.validate_token(&token).is_none());
    }

    #[test]
    fn test_disabled_auth() {
        let auth = AuthService::disabled();
        assert!(!auth.is_enabled());
        assert!(!auth.verify_password("anyone", "anything"));
    }

    #[test]
    fn test_token_tampering() {
        let auth = AuthService::new(None, Some("test_secret"), Some(3600), true).unwrap();
        let token = auth.generate_token("testuser").unwrap();
        let (_, signature) = token.split_once('.').unwrap();
        let tampered_payload = URL_SAFE_NO_PAD
            .encode(b"{\"sub\":\"admin\",\"iat\":0,\"exp\":9999999999,\"jti\":\"fake\"}");
        let tampered_token = format!("{tampered_payload}.{signature}");
        assert!(auth.validate_token(&tampered_token).is_none());
    }
}