use std::collections::HashMap;
use std::sync::RwLock;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use webauthn_rs::prelude::*;
/// A WebAuthn passkey persisted together with its owner, a display name, and usage bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCredential {
    /// URL-safe, unpadded base64 encoding of the passkey's credential id, as computed by
    /// `StoredCredential::new`. The in-memory store uses this value as its map key.
    pub credential_id: String,
    /// Identifier of the user that owns this credential.
    pub user_id: String,
    /// The underlying `webauthn_rs` passkey.
    pub passkey: Passkey,
    /// Caller-supplied display name for the credential.
    pub name: String,
    /// Creation timestamp (UTC); `StoredCredential::new` sets it to `Utc::now()`.
    pub created_at: DateTime<Utc>,
    /// Timestamp (UTC) of the most recent `record_use` call; `None` if never recorded.
    pub last_used_at: Option<DateTime<Utc>>,
    /// Usage counter incremented by `record_use`.
    pub use_count: u64,
    /// Set by `revoke`; `is_valid` returns `false` once this is `true`.
    pub revoked: bool,
    /// Optional free-form authenticator description, set via `with_authenticator_type`.
    pub authenticator_type: Option<String>,
    /// Arbitrary string key/value annotations. `#[serde(default)]` lets older serialized
    /// records without this field deserialize with an empty map.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
impl StoredCredential {
    /// Wraps a passkey as a new, unrevoked credential owned by `user_id`.
    ///
    /// `credential_id` is the passkey's credential id encoded as URL-safe base64 without
    /// padding. `created_at` is the current UTC time, `use_count` starts at zero, and
    /// `metadata` starts empty.
    pub fn new(user_id: impl Into<String>, passkey: Passkey, name: impl Into<String>) -> Self {
        // Derive the id while `passkey` is still borrowable; it is moved into `Self` below.
        let cred_id = passkey.cred_id();
        let credential_id = base64_url_encode(cred_id.as_ref());
        Self {
            credential_id,
            user_id: user_id.into(),
            passkey,
            name: name.into(),
            created_at: Utc::now(),
            last_used_at: None,
            use_count: 0,
            revoked: false,
            authenticator_type: None,
            metadata: HashMap::new(),
        }
    }

    /// Builder-style setter for `authenticator_type`; consumes and returns `self`.
    #[must_use]
    pub fn with_authenticator_type(mut self, auth_type: impl Into<String>) -> Self {
        self.authenticator_type = Some(auth_type.into());
        self
    }

    /// Builder-style insert into `metadata`; an existing `key` is overwritten.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Records one use: stamps `last_used_at` with the current UTC time and bumps `use_count`.
    pub fn record_use(&mut self) {
        self.last_used_at = Some(Utc::now());
        // Saturate rather than overflow: a deserialized record may already carry a huge
        // counter, and `+= 1` would panic in debug builds or wrap to 0 in release.
        self.use_count = self.use_count.saturating_add(1);
    }

    /// Marks the credential as revoked; `is_valid` returns `false` afterwards.
    pub fn revoke(&mut self) {
        self.revoked = true;
    }

    /// Returns `true` unless the credential has been revoked.
    pub fn is_valid(&self) -> bool {
        !self.revoked
    }

    /// Replaces the stored passkey wholesale (e.g. after its state changed).
    pub fn update_passkey(&mut self, passkey: Passkey) {
        self.passkey = passkey;
    }
}
#[async_trait]
pub trait CredentialStore: Send + Sync {
async fn save(&self, credential: StoredCredential) -> Result<(), CredentialStoreError>;
async fn find_by_id(&self, credential_id: &str) -> Option<StoredCredential>;
async fn find_by_user(&self, user_id: &str) -> Vec<StoredCredential>;
async fn get_passkeys_for_user(&self, user_id: &str) -> Vec<Passkey> {
self.find_by_user(user_id)
.await
.into_iter()
.filter(|c| c.is_valid())
.map(|c| c.passkey)
.collect()
}
async fn update(&self, credential: StoredCredential) -> Result<(), CredentialStoreError>;
async fn delete(&self, credential_id: &str) -> Result<bool, CredentialStoreError>;
async fn delete_by_user(&self, user_id: &str) -> Result<usize, CredentialStoreError>;
async fn list(&self) -> Vec<StoredCredential>;
async fn count_by_user(&self, user_id: &str) -> usize {
self.find_by_user(user_id).await.len()
}
}
/// Failure modes of [`CredentialStore`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CredentialStoreError {
    /// A credential with the same id is already stored.
    AlreadyExists,
    /// No credential with the requested id exists.
    NotFound,
    /// The backing storage failed; carries a human-readable description.
    StorageError(String),
}

impl std::fmt::Display for CredentialStoreError {
    /// Writes the user-facing (Chinese) message: "credential already exists",
    /// "credential not found", or "storage error: <detail>".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::StorageError(detail) => write!(f, "存储错误: {}", detail),
            Self::NotFound => f.write_str("凭证未找到"),
            Self::AlreadyExists => f.write_str("凭证已存在"),
        }
    }
}

impl std::error::Error for CredentialStoreError {}
/// A [`CredentialStore`] that keeps every credential in a lock-guarded in-memory map.
#[derive(Debug, Default)]
pub struct InMemoryCredentialStore {
    /// Credentials keyed by their `credential_id`.
    credentials: RwLock<HashMap<String, StoredCredential>>,
}

impl InMemoryCredentialStore {
    /// Creates an empty store (same result as `InMemoryCredentialStore::default()`).
    pub fn new() -> Self {
        let empty = HashMap::new();
        Self {
            credentials: RwLock::new(empty),
        }
    }
}
#[async_trait]
impl CredentialStore for InMemoryCredentialStore {
    /// Inserts `credential`, refusing to overwrite an entry with the same `credential_id`.
    ///
    /// # Errors
    /// * [`CredentialStoreError::AlreadyExists`] if the id is already present.
    /// * [`CredentialStoreError::StorageError`] if the lock is poisoned.
    async fn save(&self, credential: StoredCredential) -> Result<(), CredentialStoreError> {
        use std::collections::hash_map::Entry;

        let mut credentials = self.credentials.write().map_err(lock_error)?;
        // One hash lookup via the entry API instead of `contains_key` followed by `insert`.
        match credentials.entry(credential.credential_id.clone()) {
            Entry::Occupied(_) => Err(CredentialStoreError::AlreadyExists),
            Entry::Vacant(slot) => {
                slot.insert(credential);
                Ok(())
            }
        }
    }

    /// Returns a clone of the credential with `credential_id`.
    ///
    /// A poisoned lock is reported as `None` because the trait signature has no error channel.
    async fn find_by_id(&self, credential_id: &str) -> Option<StoredCredential> {
        self.credentials
            .read()
            .ok()
            .and_then(|creds| creds.get(credential_id).cloned())
    }

    /// Returns clones of every credential whose `user_id` matches, revoked ones included.
    ///
    /// A poisoned lock yields an empty `Vec`.
    async fn find_by_user(&self, user_id: &str) -> Vec<StoredCredential> {
        self.credentials
            .read()
            .map(|creds| {
                creds
                    .values()
                    .filter(|c| c.user_id == user_id)
                    .cloned()
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Replaces the stored credential that has the same `credential_id`.
    ///
    /// # Errors
    /// * [`CredentialStoreError::NotFound`] if no such credential exists.
    /// * [`CredentialStoreError::StorageError`] if the lock is poisoned.
    async fn update(&self, credential: StoredCredential) -> Result<(), CredentialStoreError> {
        let mut credentials = self.credentials.write().map_err(lock_error)?;
        // `get_mut` checks existence and yields the slot in one lookup, with no key clone.
        let slot = credentials
            .get_mut(&credential.credential_id)
            .ok_or(CredentialStoreError::NotFound)?;
        *slot = credential;
        Ok(())
    }

    /// Removes the credential with `credential_id`; returns whether one was present.
    ///
    /// # Errors
    /// [`CredentialStoreError::StorageError`] if the lock is poisoned.
    async fn delete(&self, credential_id: &str) -> Result<bool, CredentialStoreError> {
        let mut credentials = self.credentials.write().map_err(lock_error)?;
        Ok(credentials.remove(credential_id).is_some())
    }

    /// Removes every credential owned by `user_id` and returns how many were removed.
    ///
    /// # Errors
    /// [`CredentialStoreError::StorageError`] if the lock is poisoned.
    async fn delete_by_user(&self, user_id: &str) -> Result<usize, CredentialStoreError> {
        let mut credentials = self.credentials.write().map_err(lock_error)?;
        let before = credentials.len();
        // In-place filter; no temporary list of keys and no per-key re-lookup.
        credentials.retain(|_, credential| credential.user_id != user_id);
        Ok(before - credentials.len())
    }

    /// Returns clones of all stored credentials; a poisoned lock yields an empty `Vec`.
    async fn list(&self) -> Vec<StoredCredential> {
        self.credentials
            .read()
            .map(|creds| creds.values().cloned().collect())
            .unwrap_or_default()
    }
}

/// Converts a lock-poisoning error into [`CredentialStoreError::StorageError`], keeping its
/// message text.
fn lock_error<E: std::fmt::Display>(err: E) -> CredentialStoreError {
    CredentialStoreError::StorageError(err.to_string())
}
/// Encodes `data` as URL-safe base64 (`-`/`_` alphabet) with the trailing `=` padding omitted.
fn base64_url_encode(data: &[u8]) -> String {
    // The `Engine` trait only needs to be in scope to call `encode`.
    use base64::Engine as _;
    let engine = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
    engine.encode(data)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_credential_store_error_display() {
        assert_eq!(
            CredentialStoreError::AlreadyExists.to_string(),
            "凭证已存在"
        );
        assert_eq!(CredentialStoreError::NotFound.to_string(), "凭证未找到");
        assert_eq!(
            CredentialStoreError::StorageError("test".to_string()).to_string(),
            "存储错误: test"
        );
    }

    #[test]
    fn test_base64_url_encode() {
        // Pin exact output: checking only for absent characters would also accept an
        // encoder that returns an empty string.
        let encoded = base64_url_encode(b"hello world");
        assert_eq!(encoded, "aGVsbG8gd29ybGQ");
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert!(!encoded.contains('='));
        // 0xFB 0xFF maps to sextets 62 and 63, i.e. '+' and '/' in standard base64,
        // which the URL-safe alphabet replaces with '-' and '_'.
        assert_eq!(base64_url_encode(&[0xfb, 0xff]), "-_8");
        assert_eq!(base64_url_encode(&[]), "");
    }

    #[tokio::test]
    async fn test_in_memory_store_basic() {
        let store = InMemoryCredentialStore::new();
        assert!(store.list().await.is_empty());
        assert!(store.find_by_id("missing").await.is_none());
        assert!(store.find_by_user("alice").await.is_empty());
        assert!(store.get_passkeys_for_user("alice").await.is_empty());
        assert_eq!(store.count_by_user("alice").await, 0);
        assert_eq!(store.delete("missing").await, Ok(false));
        assert_eq!(store.delete_by_user("alice").await, Ok(0));
    }
}