use std::collections::HashMap;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use axum::extract::{ConnectInfo, State};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use chrono::{DateTime, Duration, Utc};
use rand::distributions::Alphanumeric;
use rand::Rng;
use webauthn_rs::prelude::{
AuthenticationResult, Passkey, PasskeyAuthentication, PasskeyRegistration, Url,
};
use webauthn_rs::{Webauthn, WebauthnBuilder};
pub fn generate_otp() -> String {
rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(8)
.map(|c| (c as char).to_ascii_uppercase())
.collect()
}
pub fn generate_token() -> String {
let bytes: [u8; 32] = rand::thread_rng().gen();
bytes.iter().map(|b| format!("{b:02x}")).collect()
}
struct OtpEntry {
otp: String,
created_at: DateTime<Utc>,
}
#[derive(Clone)]
pub struct OtpStore(Arc<Mutex<HashMap<String, OtpEntry>>>);
pub enum OtpError {
NotFound,
Expired,
Invalid,
}
impl OtpStore {
pub fn new() -> Self {
OtpStore(Arc::new(Mutex::new(HashMap::new())))
}
pub fn insert(&self, username: &str, otp: String) {
let mut map = self.0.lock().unwrap();
map.insert(
username.to_string(),
OtpEntry {
otp,
created_at: Utc::now(),
},
);
}
pub fn validate(&self, username: &str, otp: &str) -> Result<(), OtpError> {
let mut map = self.0.lock().unwrap();
let entry = map.get(username).ok_or(OtpError::NotFound)?;
if Utc::now() - entry.created_at > Duration::minutes(5) {
map.remove(username);
return Err(OtpError::Expired);
}
if entry.otp != otp {
return Err(OtpError::Invalid);
}
map.remove(username);
Ok(())
}
}
#[derive(serde::Serialize, serde::Deserialize, Clone)]
struct SessionEntry {
token: String,
username: String,
#[serde(default)]
device_hint: Option<String>,
#[serde(default = "Utc::now")]
last_seen: DateTime<Utc>,
expires_at: DateTime<Utc>,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct SessionsFile {
sessions: Vec<SessionEntry>,
}
#[derive(Clone)]
pub struct SessionStore {
inner: Arc<Mutex<HashMap<String, SessionEntry>>>,
path: PathBuf,
}
impl SessionStore {
pub fn load(path: PathBuf) -> Self {
let inner = if path.as_os_str().is_empty() || !path.exists() {
HashMap::new()
} else {
match std::fs::read_to_string(&path)
.ok()
.and_then(|s| serde_json::from_str::<SessionsFile>(&s).ok())
{
Some(file) => file.sessions.into_iter().map(|e| (e.token.clone(), e)).collect(),
None => {
eprintln!("warn: could not parse {}", path.display());
HashMap::new()
}
}
};
SessionStore {
inner: Arc::new(Mutex::new(inner)),
path,
}
}
pub fn insert(&self, token: String, username: String, device_hint: Option<String>) {
let now = Utc::now();
let expires_at = now + Duration::days(7);
let entry = SessionEntry {
token: token.clone(),
username,
device_hint,
last_seen: now,
expires_at,
};
self.inner.lock().unwrap().insert(token, entry);
self.save();
}
pub fn lookup(&self, token: &str) -> Option<String> {
let mut map = self.inner.lock().unwrap();
match map.get(token) {
None => None,
Some(entry) => {
if Utc::now() > entry.expires_at {
map.remove(token);
None
} else {
Some(entry.username.clone())
}
}
}
}
fn save(&self) {
if self.path.as_os_str().is_empty() {
return;
}
let sessions: Vec<SessionEntry> = self.inner.lock().unwrap().values().cloned().collect();
let file = SessionsFile { sessions };
let tmp_path = self.path.with_extension("json.tmp");
if let Ok(json) = serde_json::to_string_pretty(&file) {
if std::fs::write(&tmp_path, &json).is_ok() {
let _ = std::fs::rename(&tmp_path, &self.path);
}
}
}
}
#[derive(serde::Serialize, Clone)]
pub struct SessionInfo {
pub username: String,
pub device_hint: Option<String>,
pub last_seen: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
#[derive(serde::Deserialize)]
pub struct RevokeRequest {
pub username: Option<String>,
pub device: Option<String>,
#[serde(default)]
pub all: bool,
}
#[derive(serde::Serialize)]
pub struct RevokeResponse {
pub revoked: usize,
}
impl SessionStore {
pub fn list_active(&self) -> Vec<SessionInfo> {
let map = self.inner.lock().unwrap();
let now = Utc::now();
map.values()
.filter(|e| e.expires_at > now)
.map(|e| SessionInfo {
username: e.username.clone(),
device_hint: e.device_hint.clone(),
last_seen: e.last_seen,
expires_at: e.expires_at,
})
.collect()
}
pub fn revoke(&self, req: &RevokeRequest) -> usize {
let mut map = self.inner.lock().unwrap();
let before = map.len();
if req.all {
map.clear();
} else {
map.retain(|_, e| {
let username_match = req.username.as_deref().map(|u| e.username != u).unwrap_or(false);
let device_match = req.device.as_deref().map(|d| e.device_hint.as_deref() != Some(d)).unwrap_or(false);
if req.device.is_some() {
username_match || device_match
} else {
username_match
}
});
}
let revoked = before - map.len();
if revoked > 0 {
drop(map);
self.save();
}
revoked
}
}
pub struct RegistrationSession {
pub username: String,
pub passkey_reg: PasskeyRegistration,
}
pub struct AuthenticationSession {
pub username: String,
pub passkey_auth: PasskeyAuthentication,
pub created_at: Instant,
}
pub struct WebauthnState {
pub webauthn: Webauthn,
pub pending: Arc<Mutex<HashMap<String, RegistrationSession>>>,
pub pending_auth: Arc<Mutex<HashMap<String, AuthenticationSession>>>,
}
impl WebauthnState {
pub fn new(origin: &str) -> anyhow::Result<Self> {
let origin_url = Url::parse(origin)
.map_err(|e| anyhow::anyhow!("invalid origin URL: {e}"))?;
let rp_id = origin_url
.host_str()
.ok_or_else(|| anyhow::anyhow!("origin URL has no host"))?
.to_string();
let webauthn = WebauthnBuilder::new(&rp_id, &origin_url)
.map_err(|e| anyhow::anyhow!("WebauthnBuilder::new failed: {e:?}"))?
.build()
.map_err(|e| anyhow::anyhow!("WebauthnBuilder::build failed: {e:?}"))?;
Ok(WebauthnState {
webauthn,
pending: Arc::new(Mutex::new(HashMap::new())),
pending_auth: Arc::new(Mutex::new(HashMap::new())),
})
}
}
#[derive(serde::Serialize, serde::Deserialize)]
struct CredentialsFile {
credentials: HashMap<String, Vec<Passkey>>,
}
#[derive(Clone)]
pub struct CredentialStore {
inner: Arc<Mutex<HashMap<String, Vec<Passkey>>>>,
path: PathBuf,
}
impl CredentialStore {
pub fn load(path: PathBuf) -> Self {
let inner = if !path.exists() {
HashMap::new()
} else {
match std::fs::read_to_string(&path)
.ok()
.and_then(|s| serde_json::from_str::<CredentialsFile>(&s).ok())
{
Some(file) => file.credentials,
None => {
eprintln!("warn: could not parse {}", path.display());
HashMap::new()
}
}
};
CredentialStore {
inner: Arc::new(Mutex::new(inner)),
path,
}
}
pub fn insert(&self, username: &str, passkey: Passkey) {
{
let mut map = self.inner.lock().unwrap();
map.entry(username.to_string()).or_default().push(passkey);
}
self.save();
}
pub fn get(&self, username: &str) -> Option<Vec<Passkey>> {
let map = self.inner.lock().unwrap();
let v = map.get(username)?;
if v.is_empty() { None } else { Some(v.clone()) }
}
pub fn update_credential(&self, username: &str, auth_result: &AuthenticationResult) {
{
let mut map = self.inner.lock().unwrap();
if let Some(passkeys) = map.get_mut(username) {
for passkey in passkeys.iter_mut() {
passkey.update_credential(auth_result);
}
}
}
self.save();
}
fn save(&self) {
let map = self.inner.lock().unwrap();
let file = CredentialsFile {
credentials: map.clone(),
};
drop(map);
let tmp = self.path.with_extension("json.tmp");
if let Ok(json) = serde_json::to_string_pretty(&file) {
if std::fs::write(&tmp, &json).is_ok() {
let _ = std::fs::rename(&tmp, &self.path);
}
}
}
#[cfg(test)]
pub fn credential_count(&self, username: &str) -> usize {
self.inner
.lock()
.unwrap()
.get(username)
.map(|v| v.len())
.unwrap_or(0)
}
}
pub fn find_session_username(
headers: &axum::http::HeaderMap,
session_store: &SessionStore,
) -> Option<String> {
let cookie_header = headers.get(axum::http::header::COOKIE)?.to_str().ok()?;
for part in cookie_header.split(';') {
if let Ok(c) = cookie::Cookie::parse(part.trim().to_owned()) {
if c.name() == "__Host-apm-session" {
return session_store.lookup(c.value());
}
}
}
None
}
pub async fn require_auth(
State(state): State<Arc<crate::AppState>>,
req: axum::extract::Request,
next: axum::middleware::Next,
) -> Response {
let connect_info = req.extensions().get::<ConnectInfo<SocketAddr>>().copied();
if connect_info.is_none() || crate::is_localhost(connect_info) {
return next.run(req).await;
}
if find_session_username(req.headers(), &state.session_store).is_some() {
return next.run(req).await;
}
if req.uri().path().starts_with("/api/") {
(
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({"error": "unauthorized"})),
)
.into_response()
} else {
axum::response::Redirect::temporary("/login").into_response()
}
}
pub async fn otp_handler(
State(state): State<Arc<crate::AppState>>,
connect_info: Option<ConnectInfo<SocketAddr>>,
body: axum::body::Bytes,
) -> Response {
if !crate::is_localhost(connect_info) {
return StatusCode::FORBIDDEN.into_response();
}
let parsed: serde_json::Value = match serde_json::from_slice(&body) {
Ok(v) => v,
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
};
let username = match parsed.get("username").and_then(|v| v.as_str()) {
Some(u) if !u.is_empty() => u.to_string(),
_ => return StatusCode::BAD_REQUEST.into_response(),
};
let otp = generate_otp();
state.otp_store.insert(&username, otp.clone());
Json(serde_json::json!({"otp": otp})).into_response()
}
pub async fn register_page_handler() -> Response {
(
StatusCode::OK,
[(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
include_str!("register.html"),
)
.into_response()
}
pub async fn register_challenge_handler(
State(state): State<Arc<crate::AppState>>,
body: axum::body::Bytes,
) -> Response {
let req: crate::models::RegisterChallengeRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(_) => return (StatusCode::BAD_REQUEST, "missing or invalid fields").into_response(),
};
if req.username.is_empty() || req.otp.is_empty() {
return (StatusCode::BAD_REQUEST, "missing or invalid fields").into_response();
}
match state.otp_store.validate(&req.username, &req.otp) {
Ok(()) => {}
Err(OtpError::NotFound) => {
return (StatusCode::BAD_REQUEST, "invalid OTP").into_response()
}
Err(OtpError::Expired) => {
return (StatusCode::BAD_REQUEST, "OTP expired").into_response()
}
Err(OtpError::Invalid) => {
return (StatusCode::BAD_REQUEST, "invalid OTP").into_response()
}
}
let user_id = uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, req.username.as_bytes());
let (ccr, passkey_reg) = match state
.webauthn_state
.webauthn
.start_passkey_registration(user_id, &req.username, &req.username, None)
{
Ok(pair) => pair,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("webauthn error: {e:?}"),
)
.into_response()
}
};
let reg_id = {
let bytes: [u8; 16] = rand::Rng::gen(&mut rand::thread_rng());
bytes.iter().map(|b| format!("{b:02x}")).collect::<String>()
};
{
let mut pending = state.webauthn_state.pending.lock().unwrap();
pending.insert(
reg_id.clone(),
RegistrationSession {
username: req.username,
passkey_reg,
},
);
}
let public_key = match serde_json::to_value(&ccr.public_key) {
Ok(v) => v,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("serialization error: {e}"),
)
.into_response()
}
};
Json(crate::models::RegisterChallengeResponse { reg_id, public_key }).into_response()
}
pub async fn register_complete_handler(
State(state): State<Arc<crate::AppState>>,
Json(req): Json<crate::models::RegisterCompleteRequest>,
) -> Response {
let session = {
let mut pending = state.webauthn_state.pending.lock().unwrap();
pending.remove(&req.reg_id)
};
let session = match session {
Some(s) => s,
None => return (StatusCode::BAD_REQUEST, "unknown reg_id").into_response(),
};
let passkey = match state
.webauthn_state
.webauthn
.finish_passkey_registration(&req.response, &session.passkey_reg)
{
Ok(p) => p,
Err(_) => return (StatusCode::BAD_REQUEST, "invalid WebAuthn response").into_response(),
};
state.credential_store.insert(&session.username, passkey);
let token = generate_token();
state.session_store.insert(token.clone(), session.username, None);
let cookie = format!(
"__Host-apm-session={token}; HttpOnly; Secure; SameSite=Lax; Path=/; Max-Age=604800"
);
(
StatusCode::OK,
[(axum::http::header::SET_COOKIE, cookie)],
Json(serde_json::json!({"status": "ok"})),
)
.into_response()
}
pub async fn login_page_handler() -> Response {
(
StatusCode::OK,
[(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
include_str!("login.html"),
)
.into_response()
}
pub async fn login_challenge_handler(
State(state): State<Arc<crate::AppState>>,
body: axum::body::Bytes,
) -> Response {
let req: crate::models::LoginChallengeRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(_) => return (StatusCode::BAD_REQUEST, "missing or invalid fields").into_response(),
};
if req.username.is_empty() {
return (StatusCode::BAD_REQUEST, "missing or invalid fields").into_response();
}
let credentials = match state.credential_store.get(&req.username) {
Some(c) => c,
None => return (StatusCode::BAD_REQUEST, "no credentials for user").into_response(),
};
let (rcr, passkey_auth) = match state
.webauthn_state
.webauthn
.start_passkey_authentication(&credentials)
{
Ok(pair) => pair,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("webauthn error: {e:?}"),
)
.into_response()
}
};
let login_id = generate_token();
{
let mut pending = state.webauthn_state.pending_auth.lock().unwrap();
pending.insert(
login_id.clone(),
AuthenticationSession {
username: req.username,
passkey_auth,
created_at: Instant::now(),
},
);
}
let public_key = match serde_json::to_value(&rcr.public_key) {
Ok(v) => v,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("serialization error: {e}"),
)
.into_response()
}
};
Json(crate::models::LoginChallengeResponse { login_id, public_key }).into_response()
}
pub async fn login_complete_handler(
State(state): State<Arc<crate::AppState>>,
Json(req): Json<crate::models::LoginCompleteRequest>,
) -> Response {
let session = {
let mut pending = state.webauthn_state.pending_auth.lock().unwrap();
pending.remove(&req.login_id)
};
let session = match session {
Some(s) => s,
None => return (StatusCode::BAD_REQUEST, "unknown login_id").into_response(),
};
if session.created_at.elapsed() >= std::time::Duration::from_secs(300) {
return (StatusCode::BAD_REQUEST, "login session expired").into_response();
}
let credentials = match state.credential_store.get(&session.username) {
Some(c) => c,
None => return (StatusCode::BAD_REQUEST, "no credentials for user").into_response(),
};
let _ = credentials;
let auth_result = match state
.webauthn_state
.webauthn
.finish_passkey_authentication(&req.response, &session.passkey_auth)
{
Ok(r) => r,
Err(_) => return (StatusCode::BAD_REQUEST, "invalid WebAuthn response").into_response(),
};
state.credential_store.update_credential(&session.username, &auth_result);
let token = generate_token();
state.session_store.insert(token.clone(), session.username, None);
let cookie = format!(
"__Host-apm-session={token}; HttpOnly; Secure; SameSite=Lax; Path=/; Max-Age=604800"
);
(
StatusCode::OK,
[(axum::http::header::SET_COOKIE, cookie)],
Json(serde_json::json!({"status": "ok"})),
)
.into_response()
}
pub async fn list_sessions_handler(
State(state): State<Arc<crate::AppState>>,
connect_info: Option<ConnectInfo<SocketAddr>>,
) -> Response {
if !crate::is_localhost(connect_info) {
return StatusCode::FORBIDDEN.into_response();
}
Json(state.session_store.list_active()).into_response()
}
pub async fn revoke_sessions_handler(
State(state): State<Arc<crate::AppState>>,
connect_info: Option<ConnectInfo<SocketAddr>>,
Json(req): Json<RevokeRequest>,
) -> Response {
if !crate::is_localhost(connect_info) {
return StatusCode::FORBIDDEN.into_response();
}
let revoked = state.session_store.revoke(&req);
Json(RevokeResponse { revoked }).into_response()
}
#[cfg(test)]
impl OtpStore {
pub fn insert_at(&self, username: &str, otp: String, created_at: DateTime<Utc>) {
self.0
.lock()
.unwrap()
.insert(username.to_string(), OtpEntry { otp, created_at });
}
}
#[cfg(test)]
impl SessionStore {
pub fn insert_expiring_at(&self, token: String, username: String, expires_at: DateTime<Utc>) {
let entry = SessionEntry {
token: token.clone(),
username,
device_hint: None,
last_seen: Utc::now(),
expires_at,
};
self.inner.lock().unwrap().insert(token, entry);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn otp_insert_and_validate_happy_path() {
let store = OtpStore::new();
store.insert("alice", "ABCD1234".to_string());
assert!(store.validate("alice", "ABCD1234").is_ok());
}
#[test]
fn otp_validate_expired() {
let store = OtpStore::new();
let old = Utc::now() - Duration::minutes(6);
store.insert_at("alice", "ABCD1234".to_string(), old);
assert!(matches!(store.validate("alice", "ABCD1234"), Err(OtpError::Expired)));
}
#[test]
fn otp_validate_wrong_value() {
let store = OtpStore::new();
store.insert("alice", "ABCD1234".to_string());
assert!(matches!(store.validate("alice", "WRONGVAL"), Err(OtpError::Invalid)));
}
#[test]
fn otp_validate_twice_second_fails() {
let store = OtpStore::new();
store.insert("alice", "ABCD1234".to_string());
assert!(store.validate("alice", "ABCD1234").is_ok());
assert!(matches!(store.validate("alice", "ABCD1234"), Err(OtpError::NotFound)));
}
#[test]
fn session_insert_and_lookup() {
let store = SessionStore::load(PathBuf::new());
store.insert("tok1".to_string(), "alice".to_string(), None);
assert_eq!(store.lookup("tok1"), Some("alice".to_string()));
}
#[test]
fn session_lookup_expired() {
let store = SessionStore::load(PathBuf::new());
let expired = Utc::now() - Duration::days(8);
store.insert_expiring_at("tok2".to_string(), "alice".to_string(), expired);
assert_eq!(store.lookup("tok2"), None);
}
#[test]
fn session_survives_reload() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("sessions.json");
let store = SessionStore::load(path.clone());
store.insert("tok3".to_string(), "bob".to_string(), None);
drop(store);
let store2 = SessionStore::load(path);
assert_eq!(store2.lookup("tok3"), Some("bob".to_string()));
}
#[test]
fn otp_replace_existing() {
let store = OtpStore::new();
store.insert("alice", "FIRST123".to_string());
store.insert("alice", "SECOND45".to_string());
assert!(matches!(store.validate("alice", "FIRST123"), Err(OtpError::Invalid)));
assert!(store.validate("alice", "SECOND45").is_ok());
}
#[test]
fn list_active_excludes_expired() {
let store = SessionStore::load(PathBuf::new());
let expired = Utc::now() - Duration::days(8);
store.insert_expiring_at("tok-exp".to_string(), "eve".to_string(), expired);
store.insert("tok-ok".to_string(), "alice".to_string(), None);
let active = store.list_active();
assert_eq!(active.len(), 1);
assert_eq!(active[0].username, "alice");
}
#[test]
fn revoke_all_clears_store() {
let store = SessionStore::load(PathBuf::new());
store.insert("t1".to_string(), "alice".to_string(), None);
store.insert("t2".to_string(), "bob".to_string(), None);
let req = RevokeRequest { username: None, device: None, all: true };
let count = store.revoke(&req);
assert_eq!(count, 2);
assert!(store.list_active().is_empty());
}
#[test]
fn revoke_by_username_filters_correctly() {
let store = SessionStore::load(PathBuf::new());
store.insert("t1".to_string(), "alice".to_string(), None);
store.insert("t2".to_string(), "alice".to_string(), None);
store.insert("t3".to_string(), "bob".to_string(), None);
let req = RevokeRequest { username: Some("alice".to_string()), device: None, all: false };
let count = store.revoke(&req);
assert_eq!(count, 2);
let remaining = store.list_active();
assert_eq!(remaining.len(), 1);
assert_eq!(remaining[0].username, "bob");
}
#[test]
fn revoke_by_username_and_device() {
let store = SessionStore::load(PathBuf::new());
store.insert("t1".to_string(), "alice".to_string(), Some("MacBook".to_string()));
store.insert("t2".to_string(), "alice".to_string(), Some("iPhone".to_string()));
let req = RevokeRequest {
username: Some("alice".to_string()),
device: Some("MacBook".to_string()),
all: false,
};
let count = store.revoke(&req);
assert_eq!(count, 1);
let remaining = store.list_active();
assert_eq!(remaining.len(), 1);
assert_eq!(remaining[0].device_hint.as_deref(), Some("iPhone"));
}
#[test]
fn new_with_localhost_origin_succeeds() {
let state = WebauthnState::new("http://localhost:3000");
assert!(state.is_ok(), "expected Ok, got {:?}", state.err().map(|e| e.to_string()));
}
#[test]
fn new_with_invalid_origin_fails() {
assert!(WebauthnState::new("not-a-url").is_err());
}
#[test]
fn authentication_session_ttl_check() {
let old = Instant::now() - std::time::Duration::from_secs(360);
assert!(old.elapsed() >= std::time::Duration::from_secs(300), "session 6 min old should be expired");
}
#[test]
fn load_absent_file_returns_empty() {
let store = CredentialStore::load(PathBuf::from("/tmp/nonexistent_credentials_test.json"));
assert_eq!(store.credential_count("alice"), 0);
}
#[test]
fn insert_and_reload_round_trip() {
use webauthn_rs::WebauthnBuilder;
use webauthn_rs::prelude::Url;
use uuid::Uuid;
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("credentials.json");
let origin = Url::parse("http://localhost:3000").unwrap();
let webauthn = WebauthnBuilder::new("localhost", &origin)
.unwrap()
.build()
.unwrap();
let user_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, b"alice");
let (ccr, reg_state) = webauthn
.start_passkey_registration(user_id, "alice", "alice", None)
.unwrap();
let _ = ccr;
let store = CredentialStore::load(path.clone());
assert_eq!(store.credential_count("alice"), 0);
drop(store);
let _ = reg_state;
let store2 = CredentialStore::load(path);
assert_eq!(store2.credential_count("alice"), 0);
}
}