use std::collections::HashMap;
use std::sync::RwLock;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use webauthn_rs::prelude::*;
use super::credential::CredentialStore;
use super::registration::UserVerification;
/// Tunable options for an authentication ceremony.
#[derive(Clone, Debug)]
pub struct AuthenticationConfig {
    /// Requested user-verification policy.
    pub user_verification: UserVerification,
    /// How long a pending ceremony stays valid, in milliseconds.
    pub timeout_ms: u32,
    /// Whether a ceremony may be started with an empty credential list.
    pub allow_empty_credentials: bool,
}
impl Default for AuthenticationConfig {
    /// Baseline settings: verification preferred, a 60-second window, and a
    /// non-empty credential list required.
    fn default() -> Self {
        let user_verification = UserVerification::Preferred;
        let timeout_ms = 60_000;
        let allow_empty_credentials = false;

        Self {
            user_verification,
            timeout_ms,
            allow_empty_credentials,
        }
    }
}
impl AuthenticationConfig {
    /// Preset that differs from the defaults only by requiring user verification.
    pub fn high_security() -> Self {
        Self {
            user_verification: UserVerification::Required,
            ..Self::default()
        }
    }

    /// Preset that differs from the defaults only by permitting an empty
    /// credential list when a ceremony is started.
    pub fn discoverable() -> Self {
        Self {
            allow_empty_credentials: true,
            ..Self::default()
        }
    }
}
/// Server-side state of an in-flight authentication ceremony.
///
/// Created when a ceremony starts and required again to finish it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticationState {
    /// User the ceremony was started for, if one was supplied.
    pub user_id: Option<String>,
    /// Ceremony state produced by `Webauthn::start_passkey_authentication`.
    pub passkey_authentication: PasskeyAuthentication,
    /// Instant at which the ceremony was started.
    pub created_at: DateTime<Utc>,
    /// Instant after which the ceremony is treated as expired.
    pub expires_at: DateTime<Utc>,
}
impl AuthenticationState {
    /// Returns `true` once the current UTC time is strictly later than
    /// `expires_at`.
    pub fn is_expired(&self) -> bool {
        let now = Utc::now();
        self.expires_at < now
    }
}
/// Outcome of a successfully completed authentication ceremony.
#[derive(Clone, Debug)]
pub struct WebAuthnAuthenticationResult {
    /// URL-safe, unpadded Base64 encoding of the credential ID that was used.
    pub credential_id: String,
    /// User the ceremony is attributed to.
    pub user_id: String,
    /// Whether the authenticator reported user verification.
    pub user_verified: bool,
    /// Counter value reported for this authentication.
    pub counter: u32,
    /// Instant at which the result was produced.
    pub authenticated_at: DateTime<Utc>,
}
/// Drives passkey authentication ceremonies on top of a borrowed [`Webauthn`]
/// instance.
pub struct AuthenticationManager<'a> {
    /// WebAuthn engine that all ceremony calls are delegated to.
    webauthn: &'a Webauthn,
    /// Options applied to ceremonies started by this manager.
    config: AuthenticationConfig,
}
impl<'a> AuthenticationManager<'a> {
pub fn new(webauthn: &'a Webauthn) -> Self {
Self {
webauthn,
config: AuthenticationConfig::default(),
}
}
pub fn with_config(webauthn: &'a Webauthn, config: AuthenticationConfig) -> Self {
Self { webauthn, config }
}
pub fn start_authentication(
&self,
user_id: Option<String>,
credentials: Vec<Passkey>,
) -> Result<(RequestChallengeResponse, AuthenticationState), AuthenticationError> {
if credentials.is_empty() && !self.config.allow_empty_credentials {
return Err(AuthenticationError::NoCredentials);
}
let (rcr, passkey_authentication) = self
.webauthn
.start_passkey_authentication(&credentials)
.map_err(|e| AuthenticationError::WebAuthnError(e.to_string()))?;
let now = Utc::now();
let expires_at = now + chrono::Duration::milliseconds(i64::from(self.config.timeout_ms));
let state = AuthenticationState {
user_id,
passkey_authentication,
created_at: now,
expires_at,
};
Ok((rcr, state))
}
pub fn finish_authentication(
&self,
state: &AuthenticationState,
response: &PublicKeyCredential,
credentials: &[Passkey],
) -> Result<(WebAuthnAuthenticationResult, Option<Passkey>), AuthenticationError> {
if state.is_expired() {
return Err(AuthenticationError::SessionExpired);
}
let auth_result = self
.webauthn
.finish_passkey_authentication(response, &state.passkey_authentication)
.map_err(|e| AuthenticationError::WebAuthnError(e.to_string()))?;
let cred_id_bytes = auth_result.cred_id();
let credential_id = base64_url_encode(cred_id_bytes.as_ref());
let updated_passkey = credentials
.iter()
.find(|c| c.cred_id() == cred_id_bytes)
.cloned()
.map(|mut pk| {
pk.update_credential(&auth_result);
pk
});
let user_id = state
.user_id
.clone()
.unwrap_or_else(|| credential_id.clone());
let result = WebAuthnAuthenticationResult {
credential_id,
user_id,
user_verified: auth_result.user_verified(),
counter: auth_result.counter(),
authenticated_at: Utc::now(),
};
Ok((result, updated_passkey))
}
pub async fn start_authentication_with_store<S: CredentialStore>(
&self,
user_id: impl Into<String>,
store: &S,
) -> Result<(RequestChallengeResponse, AuthenticationState), AuthenticationError> {
let user_id = user_id.into();
let credentials = store.get_passkeys_for_user(&user_id).await;
if credentials.is_empty() {
return Err(AuthenticationError::NoCredentials);
}
self.start_authentication(Some(user_id), credentials)
}
pub async fn finish_authentication_and_update<S: CredentialStore>(
&self,
state: &AuthenticationState,
response: &PublicKeyCredential,
store: &S,
) -> Result<WebAuthnAuthenticationResult, AuthenticationError> {
let user_id = state
.user_id
.as_ref()
.ok_or(AuthenticationError::MissingUserId)?;
let credentials = store.get_passkeys_for_user(user_id).await;
let (result, updated_passkey) =
self.finish_authentication(state, response, &credentials)?;
if let Some(passkey) = updated_passkey {
if let Some(mut stored) = store.find_by_id(&result.credential_id).await {
stored.update_passkey(passkey);
stored.record_use();
store
.update(stored)
.await
.map_err(|e| AuthenticationError::StorageError(e.to_string()))?;
}
}
Ok(result)
}
}
/// Storage for pending [`AuthenticationState`] values, keyed by session ID.
#[async_trait]
pub trait AuthenticationStateStore: Send + Sync {
    /// Stores `state` under `session_id`.
    async fn save_state(
        &self,
        session_id: &str,
        state: AuthenticationState,
    ) -> Result<(), AuthenticationError>;

    /// Hands back the state stored under `session_id`, if any.
    ///
    /// The name suggests the entry is removed as part of the call; the
    /// in-memory implementation in this file does remove it.
    async fn take_state(&self, session_id: &str) -> Option<AuthenticationState>;

    /// Discards states that have expired.
    async fn cleanup_expired(&self);
}
/// [`AuthenticationStateStore`] backed by a lock-protected in-process map.
#[derive(Default, Debug)]
pub struct InMemoryAuthenticationStateStore {
    /// Pending ceremony states keyed by session ID.
    states: RwLock<HashMap<String, AuthenticationState>>,
}
impl InMemoryAuthenticationStateStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl AuthenticationStateStore for InMemoryAuthenticationStateStore {
async fn save_state(
&self,
session_id: &str,
state: AuthenticationState,
) -> Result<(), AuthenticationError> {
if let Ok(mut states) = self.states.write() {
states.insert(session_id.to_string(), state);
}
Ok(())
}
async fn take_state(&self, session_id: &str) -> Option<AuthenticationState> {
self.states
.write()
.ok()
.and_then(|mut states| states.remove(session_id))
}
async fn cleanup_expired(&self) {
if let Ok(mut states) = self.states.write() {
states.retain(|_, state| !state.is_expired());
}
}
}
/// Failures reported by the authentication layer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthenticationError {
    /// Error text from the underlying WebAuthn library.
    WebAuthnError(String),
    /// The pending ceremony is past its expiry time.
    SessionExpired,
    /// No credentials were available to authenticate with.
    NoCredentials,
    /// The referenced credential could not be found.
    CredentialNotFound,
    /// The ceremony state carries no user ID where one is required.
    MissingUserId,
    /// Error text from the credential or state store.
    StorageError(String),
    /// The credential has been revoked.
    CredentialRevoked,
    /// The received counter is behind the stored one (possible cloned
    /// authenticator).
    CounterRollback {
        /// Counter value currently on record.
        stored: u32,
        /// Counter value reported by the authenticator.
        received: u32,
    },
}
impl std::fmt::Display for AuthenticationError {
    /// Writes the user-facing message for each error variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants that carry no data map to fixed messages.
            Self::SessionExpired => f.write_str("认证会话已过期"),
            Self::NoCredentials => f.write_str("没有可用的凭证"),
            Self::CredentialNotFound => f.write_str("凭证未找到"),
            Self::MissingUserId => f.write_str("缺少用户 ID"),
            Self::CredentialRevoked => f.write_str("凭证已被撤销"),
            // Variants with a payload interpolate it into the message.
            Self::WebAuthnError(detail) => write!(f, "WebAuthn 错误: {}", detail),
            Self::StorageError(detail) => write!(f, "存储错误: {}", detail),
            Self::CounterRollback { stored, received } => write!(
                f,
                "检测到计数器回滚(可能的克隆攻击):存储值={}, 收到值={}",
                stored, received
            ),
        }
    }
}
// Empty impl: the trait's default methods are used, so no underlying `source`
// is exposed. Makes the type usable wherever `std::error::Error` is required.
impl std::error::Error for AuthenticationError {}
/// Encodes `data` with the `base64` crate's `URL_SAFE_NO_PAD` engine
/// (URL-safe alphabet, no `=` padding).
fn base64_url_encode(data: &[u8]) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine as _;

    URL_SAFE_NO_PAD.encode(data)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses the documented baseline values.
    #[test]
    fn test_authentication_config_default() {
        let config = AuthenticationConfig::default();
        assert_eq!(config.user_verification, UserVerification::Preferred);
        assert_eq!(config.timeout_ms, 60000);
        assert!(!config.allow_empty_credentials);
    }

    /// `high_security` changes only the verification policy.
    #[test]
    fn test_authentication_config_high_security() {
        let config = AuthenticationConfig::high_security();
        assert_eq!(config.user_verification, UserVerification::Required);
        assert_eq!(config.timeout_ms, 60000);
        assert!(!config.allow_empty_credentials);
    }

    /// `discoverable` changes only the empty-credentials switch.
    #[test]
    fn test_authentication_config_discoverable() {
        let config = AuthenticationConfig::discoverable();
        assert!(config.allow_empty_credentials);
        assert_eq!(config.user_verification, UserVerification::Preferred);
        assert_eq!(config.timeout_ms, 60000);
    }

    /// Display output for fixed and payload-carrying variants.
    #[test]
    fn test_authentication_error_display() {
        assert_eq!(
            AuthenticationError::SessionExpired.to_string(),
            "认证会话已过期"
        );
        assert_eq!(
            AuthenticationError::NoCredentials.to_string(),
            "没有可用的凭证"
        );
        assert_eq!(
            AuthenticationError::WebAuthnError("boom".to_string()).to_string(),
            "WebAuthn 错误: boom"
        );
        assert_eq!(
            AuthenticationError::StorageError("disk".to_string()).to_string(),
            "存储错误: disk"
        );
        assert_eq!(
            AuthenticationError::CounterRollback {
                stored: 10,
                received: 5
            }
            .to_string(),
            "检测到计数器回滚(可能的克隆攻击):存储值=10, 收到值=5"
        );
    }

    /// An empty store yields nothing, and cleaning it up is harmless.
    #[tokio::test]
    async fn test_in_memory_authentication_state_store() {
        let store = InMemoryAuthenticationStateStore::new();
        assert!(store.take_state("nonexistent").await.is_none());

        store.cleanup_expired().await;
        assert!(store.take_state("nonexistent").await.is_none());
    }

    /// The encoder uses the URL-safe alphabet and emits no padding.
    #[test]
    fn test_base64_url_encode() {
        let data = b"hello world";
        let encoded = base64_url_encode(data);
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert!(!encoded.contains('='));
        // Standard Base64 would give "aGVsbG8gd29ybGQ=" here.
        assert_eq!(encoded, "aGVsbG8gd29ybGQ");

        // 0xfb 0xff encodes to "+/8=" in standard Base64, so this input
        // actually exercises the alphabet substitution and the padding removal.
        assert_eq!(base64_url_encode(&[0xfb, 0xff]), "-_8");

        assert_eq!(base64_url_encode(&[]), "");
    }
}