#![allow(deprecated)]
use crate::{AuthenticationBackend, AuthenticationError, SimpleUser, User};
use reinhardt_http::Request;
use std::collections::HashMap;
use std::sync::Arc;
use subtle::ConstantTimeEq;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::USER_ID_NAMESPACE;
/// Authentication backend that checks time-based one-time passwords (TOTP).
///
/// Each user is registered with a shared secret. Requests are authenticated by
/// comparing a 6-digit HMAC-SHA-256 code against the code derived from that secret.
pub struct MFAAuthentication {
    /// Issuer name embedded in generated `otpauth://` provisioning URLs.
    issuer: String,
    /// TOTP secrets keyed by username. Stored as strings; decoded as unpadded
    /// base32 when a code is verified.
    secrets: Arc<Mutex<HashMap<String, String>>>,
    /// Length of one TOTP time step, in seconds (30 unless overridden via the builder).
    time_window: u64,
}
impl MFAAuthentication {
    /// Number of digits in every TOTP code this backend issues or accepts.
    const TOTP_DIGITS: u32 = 6;

    /// Creates a backend whose provisioning URLs name `issuer`.
    ///
    /// The TOTP time step defaults to 30 seconds and no users are registered.
    pub fn new(issuer: impl Into<String>) -> Self {
        Self {
            issuer: issuer.into(),
            secrets: Arc::new(Mutex::new(HashMap::new())),
            time_window: 30,
        }
    }

    /// Sets the TOTP time step, in seconds (builder style).
    ///
    /// The same value is advertised as `period` by
    /// [`generate_totp_url`](Self::generate_totp_url). A value of `0` is accepted
    /// here, but every later verification then fails with
    /// [`AuthenticationError::InvalidCredentials`] instead of dividing by zero.
    pub fn time_window(mut self, seconds: u64) -> Self {
        self.time_window = seconds;
        self
    }

    /// Registers `secret` for `username`, replacing any previous secret.
    ///
    /// The secret is stored verbatim. Verification decodes it as unpadded
    /// base32, so a malformed secret only surfaces when a code is checked.
    pub async fn register_user(&self, username: impl Into<String>, secret: impl Into<String>) {
        let mut secrets = self.secrets.lock().await;
        secrets.insert(username.into(), secret.into());
    }

    /// Builds an `otpauth://totp/` key URI for provisioning an authenticator app.
    ///
    /// The URI explicitly declares `algorithm=SHA256`, the digit count and the
    /// configured `period`. Authenticator apps default to SHA-1 with a 30-second
    /// period, which this backend would never accept. The issuer, account name
    /// and secret are percent-encoded, so spaces, `:` or `@` cannot corrupt the
    /// URI.
    pub fn generate_totp_url(&self, username: &str, secret: &str) -> String {
        let issuer = Self::percent_encode(&self.issuer);
        format!(
            "otpauth://totp/{issuer}:{account}?secret={secret}&issuer={issuer}&algorithm=SHA256&digits={digits}&period={period}",
            issuer = issuer,
            account = Self::percent_encode(username),
            secret = Self::percent_encode(secret),
            digits = Self::TOTP_DIGITS,
            period = self.time_window,
        )
    }

    /// Checks `code` against `username`'s TOTP at the current system time.
    ///
    /// Codes from the previous, current and next time step are accepted, to
    /// tolerate clock skew. The comparison itself is constant-time.
    ///
    /// NOTE(review): no record of used steps is kept, so a code can be
    /// replayed for as long as it stays inside the acceptance window.
    ///
    /// # Errors
    ///
    /// - [`AuthenticationError::UserNotFound`] if `username` has no registered secret.
    /// - [`AuthenticationError::InvalidCredentials`] if the stored secret is not
    ///   valid unpadded base32, or the time window is `0`.
    pub async fn verify_totp(
        &self,
        username: &str,
        code: &str,
    ) -> Result<bool, AuthenticationError> {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        self.verify_totp_at(username, code, now).await
    }

    /// Same as [`verify_totp`](Self::verify_totp), but evaluated at an explicit
    /// Unix time (in seconds) so the result is deterministic.
    async fn verify_totp_at(
        &self,
        username: &str,
        code: &str,
        unix_time: u64,
    ) -> Result<bool, AuthenticationError> {
        // Copy the secret out so the map lock is not held while hashing.
        let secret = self
            .secrets
            .lock()
            .await
            .get(username)
            .cloned()
            .ok_or(AuthenticationError::UserNotFound)?;
        let secret_bytes = data_encoding::BASE32_NOPAD
            .decode(secret.as_bytes())
            .map_err(|_| AuthenticationError::InvalidCredentials)?;

        // A zero-length step would divide by zero, both here and inside totp_lite.
        let Some(current_step) = unix_time.checked_div(self.time_window) else {
            return Err(AuthenticationError::InvalidCredentials);
        };
        // Previous, current and next step; the out-of-range neighbours are skipped.
        let candidate_steps = [
            current_step.checked_sub(1),
            Some(current_step),
            current_step.checked_add(1),
        ];
        for step in candidate_steps.iter().flatten() {
            // totp_lite derives the counter as `time / step` itself, so it must be
            // given a Unix timestamp inside the step, not the step index.
            let Some(step_start) = step.checked_mul(self.time_window) else {
                continue;
            };
            let expected = totp_lite::totp_custom::<totp_lite::Sha256>(
                self.time_window,
                Self::TOTP_DIGITS,
                &secret_bytes,
                step_start,
            );
            if bool::from(expected.as_bytes().ct_eq(code.as_bytes())) {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Returns a copy of the secret registered for `username`, if any.
    pub async fn get_secret(&self, username: &str) -> Option<String> {
        let secrets = self.secrets.lock().await;
        secrets.get(username).cloned()
    }

    /// Percent-encodes every byte except RFC 3986 unreserved characters
    /// (`A-Z a-z 0-9 - . _ ~`).
    fn percent_encode(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }

    /// Builds the user object returned for a registered or authenticated username.
    ///
    /// The id is a UUIDv5 of the username under `USER_ID_NAMESPACE`, so the same
    /// username always maps to the same id.
    fn simple_user(username: &str) -> Box<dyn User> {
        Box::new(SimpleUser {
            id: Uuid::new_v5(&USER_ID_NAMESPACE, username.as_bytes()),
            username: username.to_string(),
            email: String::new(),
            is_active: true,
            is_admin: false,
            is_staff: false,
            is_superuser: false,
        })
    }
}

impl Default for MFAAuthentication {
    /// Equivalent to `MFAAuthentication::new("Reinhardt")`.
    fn default() -> Self {
        Self::new("Reinhardt")
    }
}

#[async_trait::async_trait]
impl AuthenticationBackend for MFAAuthentication {
    /// Authenticates a request from its `X-Username` and `X-MFA-Code` headers.
    ///
    /// Returns `Ok(None)` when either header is missing or cannot be read as a
    /// string, leaving the request to other backends.
    ///
    /// # Errors
    ///
    /// - [`AuthenticationError::UserNotFound`] (propagated from `verify_totp`) if
    ///   the username has no registered secret.
    /// - [`AuthenticationError::InvalidCredentials`] if the code does not match,
    ///   or the stored secret or time window is unusable.
    async fn authenticate(
        &self,
        request: &Request,
    ) -> Result<Option<Box<dyn User>>, AuthenticationError> {
        let username = request
            .headers
            .get("X-Username")
            .and_then(|v| v.to_str().ok());
        let code = request
            .headers
            .get("X-MFA-Code")
            .and_then(|v| v.to_str().ok());
        match (username, code) {
            (Some(user), Some(mfa_code)) => {
                if self.verify_totp(user, mfa_code).await? {
                    Ok(Some(Self::simple_user(user)))
                } else {
                    Err(AuthenticationError::InvalidCredentials)
                }
            }
            _ => Ok(None),
        }
    }

    /// Returns the user for `user_id` (the username used at registration), or
    /// `Ok(None)` if no secret is registered under that name.
    async fn get_user(&self, user_id: &str) -> Result<Option<Box<dyn User>>, AuthenticationError> {
        let registered = self.secrets.lock().await.contains_key(user_id);
        Ok(registered.then(|| Self::simple_user(user_id)))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method};
    use rstest::rstest;

    /// Unpadded base32 secret shared by most tests.
    const SECRET: &str = "JBSWY3DPEHPK3PXP";

    /// RFC 6238 Appendix B SHA-256 seed (ASCII "12345678901234567890123456789012")
    /// encoded as unpadded base32.
    const RFC6238_SHA256_SECRET: &str = "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQGEZA";

    /// Fixed mid-step Unix time, so generated and verified codes always agree.
    const FIXED_TIME: u64 = 1_700_000_015;

    fn secret_bytes(secret: &str) -> Vec<u8> {
        data_encoding::BASE32_NOPAD
            .decode(secret.as_bytes())
            .unwrap()
    }

    /// 6-digit SHA-256 TOTP with a 30-second step at `unix_time`, as an
    /// authenticator app would compute it.
    fn sha256_code(secret: &str, unix_time: u64) -> String {
        totp_lite::totp_custom::<totp_lite::Sha256>(30, 6, &secret_bytes(secret), unix_time)
    }

    #[rstest]
    #[tokio::test]
    async fn test_mfa_registration() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        let secrets = mfa.secrets.lock().await;
        assert!(secrets.contains_key("alice"));
    }

    #[rstest]
    fn test_generate_totp_url() {
        let mfa = MFAAuthentication::new("TestApp");
        let url = mfa.generate_totp_url("alice", "SECRET");
        assert!(url.contains("otpauth://totp/"));
        assert!(url.contains("alice"));
        assert!(url.contains("SECRET"));
        assert!(url.contains("TestApp"));
        assert!(url.contains("algorithm=SHA256"));
        assert!(url.contains("digits=6"));
        assert!(url.contains("period=30"));
    }

    #[rstest]
    fn test_generate_totp_url_encodes_label_and_declares_period() {
        let mfa = MFAAuthentication::new("Test App").time_window(60);
        let url = mfa.generate_totp_url("alice@example.com", SECRET);
        assert_eq!(
            url,
            "otpauth://totp/Test%20App:alice%40example.com?secret=JBSWY3DPEHPK3PXP&issuer=Test%20App&algorithm=SHA256&digits=6&period=60"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_uses_sha256() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        let code = sha256_code(SECRET, FIXED_TIME);
        let result = mfa.verify_totp_at("alice", &code, FIXED_TIME).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_matches_rfc6238_sha256_vectors() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("rfc", RFC6238_SHA256_SECRET).await;
        // Last six digits of the 8-digit SHA-256 values in RFC 6238 Appendix B.
        for (unix_time, code) in [
            (59, "119246"),
            (1_111_111_109, "084774"),
            (1_234_567_890, "819424"),
            (2_000_000_000, "698825"),
        ] {
            let result = mfa.verify_totp_at("rfc", code, unix_time).await;
            assert!(result.unwrap(), "RFC 6238 vector at T={} must verify", unix_time);
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_rejects_code_outside_skew_window() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("rfc", RFC6238_SHA256_SECRET).await;
        // "119246" belongs to step 1; at T=150 only steps 4..=6 are accepted.
        let result = mfa.verify_totp_at("rfc", "119246", 150).await;
        assert!(!result.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_rejects_sha1_code() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        let totp_sha1 =
            totp_lite::totp_custom::<totp_lite::Sha1>(30, 6, &secret_bytes(SECRET), FIXED_TIME);
        let totp_sha256 = sha256_code(SECRET, FIXED_TIME);
        if totp_sha1 != totp_sha256 {
            let result = mfa.verify_totp_at("alice", &totp_sha1, FIXED_TIME).await;
            assert!(result.is_ok());
            assert!(!result.unwrap());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_time_skew_tolerance() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        let totp_prev = sha256_code(SECRET, FIXED_TIME - 30);
        let result = mfa.verify_totp_at("alice", &totp_prev, FIXED_TIME).await;
        assert!(result.is_ok());
        assert!(
            result.unwrap(),
            "Previous time step TOTP should be accepted"
        );
        let totp_next = sha256_code(SECRET, FIXED_TIME + 30);
        let result = mfa.verify_totp_at("alice", &totp_next, FIXED_TIME).await;
        assert!(result.is_ok());
        assert!(result.unwrap(), "Next time step TOTP should be accepted");
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_honors_custom_time_window() {
        let mfa = MFAAuthentication::new("TestApp").time_window(60);
        mfa.register_user("alice", SECRET).await;
        let code =
            totp_lite::totp_custom::<totp_lite::Sha256>(60, 6, &secret_bytes(SECRET), FIXED_TIME);
        let result = mfa.verify_totp_at("alice", &code, FIXED_TIME).await;
        assert!(result.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_zero_time_window_is_an_error() {
        let mfa = MFAAuthentication::new("TestApp").time_window(0);
        mfa.register_user("alice", SECRET).await;
        let result = mfa.verify_totp_at("alice", "123456", FIXED_TIME).await;
        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_malformed_secret_is_an_error() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", "not-base32!").await;
        let result = mfa.verify_totp("alice", "123456").await;
        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_invalid_code() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        let result = mfa.verify_totp("alice", "000000").await;
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    #[rstest]
    #[tokio::test]
    async fn test_verify_totp_unregistered_user() {
        let mfa = MFAAuthentication::new("TestApp");
        let result = mfa.verify_totp("alice", "123456").await;
        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_mfa_authentication_with_valid_code() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        let current_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Generated from the Unix time itself, exactly as an authenticator app does.
        let totp = sha256_code(SECRET, current_time);
        let mut headers = HeaderMap::new();
        headers.insert("X-Username", "alice".parse().unwrap());
        headers.insert("X-MFA-Code", totp.parse().unwrap());
        let request = Request::builder()
            .method(Method::GET)
            .uri("/")
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap();
        let result = mfa.authenticate(&request).await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().get_username(), "alice");
    }

    #[rstest]
    #[tokio::test]
    async fn test_mfa_authentication_without_headers() {
        let mfa = MFAAuthentication::new("TestApp");
        let request = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Bytes::new())
            .build()
            .unwrap();
        let result = mfa.authenticate(&request).await.unwrap();
        assert!(result.is_none());
    }

    #[rstest]
    fn test_time_window_configuration() {
        let mfa = MFAAuthentication::new("TestApp").time_window(60);
        assert_eq!(mfa.time_window, 60);
    }

    #[rstest]
    #[tokio::test]
    async fn test_get_user_same_username_produces_same_id() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        let user1 = mfa.get_user("alice").await.unwrap().unwrap();
        let user2 = mfa.get_user("alice").await.unwrap().unwrap();
        assert_eq!(
            user1.id(),
            user2.id(),
            "same username must produce the same UUID"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_user_id_is_deterministic_uuidv5() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        let user = mfa.get_user("alice").await.unwrap().unwrap();
        let id = Uuid::parse_str(&user.id()).unwrap();
        assert_eq!(id.get_version_num(), 5, "user ID must be UUIDv5");
        assert_eq!(
            id.get_variant(),
            uuid::Variant::RFC4122,
            "user ID must use RFC 4122 variant"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_get_user_unregistered_returns_none() {
        let mfa = MFAAuthentication::new("TestApp");
        let result = mfa.get_user("nonexistent_user").await.unwrap();
        assert!(result.is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn test_get_user_different_usernames_produce_different_ids() {
        let mfa = MFAAuthentication::new("TestApp");
        mfa.register_user("alice", SECRET).await;
        mfa.register_user("bob", "KRSXG5CTMVRXEZLUKN").await;
        let user_a = mfa.get_user("alice").await.unwrap().unwrap();
        let user_b = mfa.get_user("bob").await.unwrap().unwrap();
        assert_ne!(
            user_a.id(),
            user_b.id(),
            "different usernames must produce different UUIDs"
        );
    }
}