use std::collections::HashMap;
use std::sync::RwLock;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use webauthn_rs::prelude::*;
use super::credential::{CredentialStore, StoredCredential};
fn extract_credential_ids(passkeys: &[Passkey]) -> Vec<CredentialID> {
passkeys.iter().map(|p| p.cred_id().clone()).collect()
}
/// User-verification preference stored in [`RegistrationConfig`].
///
/// The derived `Default` yields [`UserVerification::Preferred`].
///
/// NOTE(review): the variant names appear to mirror the WebAuthn
/// `userVerification` requirement values; nothing in this file maps them onto
/// the WebAuthn library, so confirm how callers interpret them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UserVerification {
/// The default variant (marked `#[default]`).
#[default]
Preferred,
/// The value selected by [`RegistrationConfig::high_security`].
Required,
/// Not selected by any preset in this file.
Discouraged,
}
/// Settings for a [`RegistrationManager`].
///
/// Of these fields, the registration code in this file reads only
/// `exclude_credentials`, `timeout_ms` and `max_credentials_per_user`.
/// NOTE(review): `user_verification`, `authenticator_attachment` and
/// `require_resident_key` are stored but not forwarded to the WebAuthn library
/// anywhere in this file — confirm whether another module consumes them.
#[derive(Debug, Clone)]
pub struct RegistrationConfig {
/// Requested user-verification preference.
pub user_verification: UserVerification,
/// Optional restriction on the authenticator attachment; `None` leaves it unrestricted.
pub authenticator_attachment: Option<AuthenticatorAttachment>,
/// Whether a resident key is requested.
pub require_resident_key: bool,
/// When `true`, `start_registration` passes the IDs of the user's existing
/// passkeys to the WebAuthn library as the exclusion list.
pub exclude_credentials: bool,
/// Lifetime of a registration session in milliseconds; used to compute
/// `RegistrationState::expires_at`.
pub timeout_ms: u32,
/// Maximum number of credentials per user. `start_registration` fails with
/// `RegistrationError::MaxCredentialsReached` once the supplied list of
/// existing credentials has at least this many entries; `None` disables the check.
pub max_credentials_per_user: Option<usize>,
}
impl Default for RegistrationConfig {
    /// Baseline configuration: user verification preferred, no attachment
    /// restriction, no resident key, existing credentials excluded, a
    /// 60-second session timeout and at most 10 credentials per user.
    fn default() -> Self {
        // Session lifetime in milliseconds (one minute).
        const DEFAULT_TIMEOUT_MS: u32 = 60_000;
        // Per-user credential cap applied when the caller does not override it.
        const DEFAULT_MAX_CREDENTIALS: usize = 10;

        Self {
            user_verification: UserVerification::Preferred,
            authenticator_attachment: None,
            require_resident_key: false,
            exclude_credentials: true,
            timeout_ms: DEFAULT_TIMEOUT_MS,
            max_credentials_per_user: Some(DEFAULT_MAX_CREDENTIALS),
        }
    }
}
impl RegistrationConfig {
    /// Preset that sets the attachment to `AuthenticatorAttachment::Platform`
    /// and requests a resident key; every other field keeps its default.
    pub fn platform_only() -> Self {
        Self {
            authenticator_attachment: Some(AuthenticatorAttachment::Platform),
            require_resident_key: true,
            ..Self::default()
        }
    }

    /// Preset that sets the attachment to
    /// `AuthenticatorAttachment::CrossPlatform`; every other field keeps its
    /// default.
    pub fn cross_platform_only() -> Self {
        Self {
            authenticator_attachment: Some(AuthenticatorAttachment::CrossPlatform),
            ..Self::default()
        }
    }

    /// Preset that sets user verification to `UserVerification::Required` and
    /// requests a resident key; every other field keeps its default.
    pub fn high_security() -> Self {
        Self {
            user_verification: UserVerification::Required,
            require_resident_key: true,
            ..Self::default()
        }
    }
}
/// Server-side state of a registration ceremony that has been started but not
/// yet finished.
///
/// Produced by `RegistrationManager::start_registration` and handed back to
/// `RegistrationManager::finish_registration`. Derives `Serialize` and
/// `Deserialize`, so it can be persisted between the two steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistrationState {
/// User identifier exactly as supplied to `start_registration`.
pub user_id: String,
/// Username supplied to `start_registration`.
pub username: String,
/// Display name supplied to `start_registration`.
pub display_name: String,
/// Name passed to `StoredCredential::new` when the ceremony finishes.
pub credential_name: String,
/// In-progress ceremony state returned by the WebAuthn library; required to
/// finish the registration.
pub passkey_registration: PasskeyRegistration,
/// UTC time at which the ceremony was started.
pub created_at: DateTime<Utc>,
/// UTC time after which `is_expired` reports `true`.
pub expires_at: DateTime<Utc>,
}
impl RegistrationState {
    /// Returns `true` once the current UTC time is strictly later than
    /// `expires_at`; a state whose deadline equals "now" is still valid.
    pub fn is_expired(&self) -> bool {
        let now = Utc::now();
        self.expires_at < now
    }
}
/// Runs the passkey registration ceremony against a borrowed [`Webauthn`]
/// instance, applying a [`RegistrationConfig`].
pub struct RegistrationManager<'a> {
/// WebAuthn instance used to start and finish ceremonies; borrowed for `'a`.
webauthn: &'a Webauthn,
/// Settings consulted when a registration is started.
config: RegistrationConfig,
}
impl<'a> RegistrationManager<'a> {
pub fn new(webauthn: &'a Webauthn) -> Self {
Self {
webauthn,
config: RegistrationConfig::default(),
}
}
pub fn with_config(webauthn: &'a Webauthn, config: RegistrationConfig) -> Self {
Self { webauthn, config }
}
pub fn start_registration(
&self,
user_id: impl Into<String>,
username: impl Into<String>,
display_name: impl Into<String>,
credential_name: impl Into<String>,
existing_credentials: Option<Vec<Passkey>>,
) -> Result<(CreationChallengeResponse, RegistrationState), RegistrationError> {
let user_id = user_id.into();
let username = username.into();
let display_name = display_name.into();
let credential_name = credential_name.into();
if let Some(max) = self.config.max_credentials_per_user
&& let Some(ref creds) = existing_credentials
&& creds.len() >= max
{
return Err(RegistrationError::MaxCredentialsReached(max));
}
let user_uuid = parse_or_generate_uuid(&user_id);
let exclude_creds: Option<Vec<CredentialID>> = if self.config.exclude_credentials {
existing_credentials
.as_ref()
.map(|c| extract_credential_ids(c))
} else {
None
};
let (ccr, passkey_registration) = self
.webauthn
.start_passkey_registration(user_uuid, &username, &display_name, exclude_creds)
.map_err(|e| RegistrationError::WebAuthnError(e.to_string()))?;
let now = Utc::now();
let expires_at = now + chrono::Duration::milliseconds(i64::from(self.config.timeout_ms));
let state = RegistrationState {
user_id,
username,
display_name,
credential_name,
passkey_registration,
created_at: now,
expires_at,
};
Ok((ccr, state))
}
pub fn finish_registration(
&self,
state: &RegistrationState,
response: &RegisterPublicKeyCredential,
) -> Result<StoredCredential, RegistrationError> {
if state.is_expired() {
return Err(RegistrationError::SessionExpired);
}
let passkey = self
.webauthn
.finish_passkey_registration(response, &state.passkey_registration)
.map_err(|e| RegistrationError::WebAuthnError(e.to_string()))?;
let credential = StoredCredential::new(&state.user_id, passkey, &state.credential_name);
Ok(credential)
}
pub async fn start_registration_with_store<S: CredentialStore>(
&self,
user_id: impl Into<String>,
username: impl Into<String>,
display_name: impl Into<String>,
credential_name: impl Into<String>,
store: &S,
) -> Result<(CreationChallengeResponse, RegistrationState), RegistrationError> {
let user_id = user_id.into();
let existing = store.get_passkeys_for_user(&user_id).await;
let existing = if existing.is_empty() {
None
} else {
Some(existing)
};
self.start_registration(user_id, username, display_name, credential_name, existing)
}
pub async fn finish_registration_and_save<S: CredentialStore>(
&self,
state: &RegistrationState,
response: &RegisterPublicKeyCredential,
store: &S,
) -> Result<StoredCredential, RegistrationError> {
let credential = self.finish_registration(state, response)?;
store
.save(credential.clone())
.await
.map_err(|e| RegistrationError::StorageError(e.to_string()))?;
Ok(credential)
}
}
/// Storage for in-flight [`RegistrationState`] values, keyed by a session ID.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait RegistrationStateStore: Send + Sync {
/// Stores `state` under `session_id`.
///
/// Returns a [`RegistrationError`] if the state could not be stored.
async fn save_state(
&self,
session_id: &str,
state: RegistrationState,
) -> Result<(), RegistrationError>;
/// Returns the state saved under `session_id`, or `None` if there is none.
///
/// The in-memory implementation in this file removes the entry as it
/// returns it, so each saved state can be taken only once.
async fn take_state(&self, session_id: &str) -> Option<RegistrationState>;
/// Discards stored states that have expired.
async fn cleanup_expired(&self);
}
/// [`RegistrationStateStore`] that keeps states in a `HashMap` guarded by a
/// `std::sync::RwLock`.
///
/// The derived `Default` produces an empty store.
#[derive(Debug, Default)]
pub struct InMemoryRegistrationStateStore {
/// Saved states keyed by session ID.
states: RwLock<HashMap<String, RegistrationState>>,
}
impl InMemoryRegistrationStateStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl RegistrationStateStore for InMemoryRegistrationStateStore {
async fn save_state(
&self,
session_id: &str,
state: RegistrationState,
) -> Result<(), RegistrationError> {
if let Ok(mut states) = self.states.write() {
states.insert(session_id.to_string(), state);
}
Ok(())
}
async fn take_state(&self, session_id: &str) -> Option<RegistrationState> {
self.states
.write()
.ok()
.and_then(|mut states| states.remove(session_id))
}
async fn cleanup_expired(&self) {
if let Ok(mut states) = self.states.write() {
states.retain(|_, state| !state.is_expired());
}
}
}
/// Errors reported by the registration flow.
///
/// The `Display` text of every variant is a user-facing message in Chinese.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistrationError {
    /// The WebAuthn library reported a failure; carries its message.
    WebAuthnError(String),
    /// The registration state had expired when the ceremony was finished.
    SessionExpired,
    /// The user already holds the maximum number of credentials (the payload).
    MaxCredentialsReached(usize),
    /// A storage backend reported a failure; carries its message.
    StorageError(String),
    /// The given user ID is invalid. Not constructed anywhere in this file.
    InvalidUserId(String),
    /// The credential already exists. Not constructed anywhere in this file.
    CredentialExists,
}

impl std::fmt::Display for RegistrationError {
    /// Writes the message for this error; formatter width and fill are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Constant messages need no formatting machinery.
            Self::SessionExpired => f.write_str("注册会话已过期"),
            Self::CredentialExists => f.write_str("凭证已存在"),
            // Messages that embed a payload.
            Self::WebAuthnError(detail) => write!(f, "WebAuthn 错误: {}", detail),
            Self::StorageError(detail) => write!(f, "存储错误: {}", detail),
            Self::InvalidUserId(user_id) => write!(f, "无效的用户 ID: {}", user_id),
            Self::MaxCredentialsReached(limit) => {
                write!(f, "已达到最大凭证数量限制 ({})", limit)
            }
        }
    }
}

impl std::error::Error for RegistrationError {}
/// Maps an arbitrary user identifier to a UUID.
///
/// If `input` already parses as a UUID, that value is returned unchanged.
/// Otherwise the UUID is derived from the first 16 bytes of the SHA-256 digest
/// of `input`, so the same identifier always yields the same UUID.
fn parse_or_generate_uuid(input: &str) -> Uuid {
    use sha2::{Digest, Sha256};

    Uuid::parse_str(input).unwrap_or_else(|_| {
        let digest = Sha256::digest(input.as_bytes());

        let mut raw = [0u8; 16];
        raw.copy_from_slice(&digest[..16]);

        // Overwrite the version nibble with 4 and the variant bits with `10`
        // so the result has the layout of an RFC 4122 UUID.
        raw[6] = (raw[6] & 0x0f) | 0x40;
        raw[8] = (raw[8] & 0x3f) | 0x80;

        Uuid::from_bytes(raw)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The baseline configuration exposes the documented defaults.
    #[test]
    fn test_registration_config_default() {
        let RegistrationConfig {
            timeout_ms,
            max_credentials_per_user,
            exclude_credentials,
            ..
        } = RegistrationConfig::default();

        assert_eq!(timeout_ms, 60000);
        assert_eq!(max_credentials_per_user, Some(10));
        assert!(exclude_credentials);
    }

    /// The platform preset pins the attachment and requests a resident key.
    #[test]
    fn test_registration_config_platform_only() {
        let preset = RegistrationConfig::platform_only();

        assert!(matches!(
            preset.authenticator_attachment,
            Some(AuthenticatorAttachment::Platform)
        ));
        assert!(preset.require_resident_key);
    }

    /// The high-security preset requires user verification.
    #[test]
    fn test_registration_config_high_security() {
        let preset = RegistrationConfig::high_security();

        assert_eq!(preset.user_verification, UserVerification::Required);
    }

    /// Valid UUIDs pass through; other identifiers map deterministically.
    #[test]
    fn test_parse_or_generate_uuid() {
        // A well-formed UUID string is returned as-is.
        let canonical = "550e8400-e29b-41d4-a716-446655440000";
        assert_eq!(parse_or_generate_uuid(canonical).to_string(), canonical);

        // The same non-UUID identifier always yields the same UUID.
        let first = parse_or_generate_uuid("user@example.com");
        let second = parse_or_generate_uuid("user@example.com");
        assert_eq!(first, second);

        // A different identifier yields a different UUID.
        let different = parse_or_generate_uuid("other@example.com");
        assert_ne!(first, different);
    }

    /// Error messages render the expected user-facing text.
    #[test]
    fn test_registration_error_display() {
        let expired = RegistrationError::SessionExpired;
        assert_eq!(expired.to_string(), "注册会话已过期");

        let limit_reached = RegistrationError::MaxCredentialsReached(5);
        assert_eq!(limit_reached.to_string(), "已达到最大凭证数量限制 (5)");
    }

    /// Taking a session that was never saved yields nothing.
    #[tokio::test]
    async fn test_in_memory_registration_state_store() {
        let store = InMemoryRegistrationStateStore::new();

        let missing = store.take_state("nonexistent").await;
        assert!(missing.is_none());
    }

    /// `Preferred` is the default user-verification level.
    #[test]
    fn test_user_verification_default() {
        assert_eq!(UserVerification::default(), UserVerification::Preferred);
    }
}