use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uvb_core::TenantId;
use uvb_storage_api::{EnrollmentError, EnrollmentRecord, EnrollmentStore};
/// An [`EnrollmentStore`] that keeps every enrollment in a process-local
/// `HashMap`, keyed by the record's `id`.
///
/// Nothing is persisted; all data is lost when the last handle is dropped.
/// The map sits behind an `Arc<RwLock<..>>`, so cloning the store is cheap and
/// every clone reads and writes the same underlying map.
#[derive(Clone)]
pub struct InMemoryEnrollmentStore {
    /// Enrollment records indexed by their `id` field.
    enrollments: Arc<RwLock<HashMap<String, EnrollmentRecord>>>,
}
impl InMemoryEnrollmentStore {
    /// Builds an empty store that contains no enrollments.
    pub fn new() -> Self {
        let empty_map = HashMap::new();
        let enrollments = Arc::new(RwLock::new(empty_map));
        Self { enrollments }
    }
}
impl Default for InMemoryEnrollmentStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl EnrollmentStore for InMemoryEnrollmentStore {
    /// Stores `record` under its `id` and returns that id.
    ///
    /// If a record with the same id is already present it is replaced, and a
    /// warning is logged so the overwrite does not go unnoticed.
    // NOTE(review): the only error variant visible here is `NotFound`; if the
    // trait contract expects duplicate ids to be rejected, add a dedicated
    // variant and return it instead of overwriting.
    async fn create(&self, record: EnrollmentRecord) -> Result<String, EnrollmentError> {
        let id = record.id.clone();
        tracing::info!(
            enrollment_id = %id,
            user_id = %record.user_id,
            tenant_id = %record.tenant_id,
            factor_id = %record.factor_id,
            "Storing enrollment in memory"
        );
        // `record` is owned, so move it into the map rather than cloning it.
        // The write guard is a temporary and is released at the end of this
        // statement, before any further logging.
        let previous = self.enrollments.write().await.insert(id.clone(), record);
        if previous.is_some() {
            tracing::warn!(
                enrollment_id = %id,
                "Replaced existing enrollment with the same id"
            );
        }
        Ok(id)
    }

    /// Returns a clone of the enrollment stored under `id`, or `None` if there
    /// is none. Never returns an error.
    async fn get(&self, id: &str) -> Result<Option<EnrollmentRecord>, EnrollmentError> {
        Ok(self.enrollments.read().await.get(id).cloned())
    }

    /// Reports whether `user_id` has an enrollment for `factor_id` in
    /// `tenant_id` whose status is `EnrollmentStatus::Active`.
    ///
    /// Scans every stored enrollment linearly. Never returns an error.
    async fn is_enrolled(
        &self,
        user_id: &str,
        tenant_id: &TenantId,
        factor_id: &str,
    ) -> Result<bool, EnrollmentError> {
        // Evaluate inside a scope so the read lock is dropped before logging.
        let (enrolled, count) = {
            let enrollments = self.enrollments.read().await;
            let enrolled = enrollments.values().any(|e| {
                e.user_id == user_id
                    && &e.tenant_id == tenant_id
                    && e.factor_id == factor_id
                    && e.status == uvb_storage_api::EnrollmentStatus::Active
            });
            (enrolled, enrollments.len())
        };
        tracing::info!(
            user_id = %user_id,
            tenant_id = %tenant_id,
            factor_id = %factor_id,
            enrolled = enrolled,
            total_enrollments = count,
            "Checking enrollment status"
        );
        Ok(enrolled)
    }

    /// Returns clones of every enrollment (of any status) belonging to
    /// `user_id` in `tenant_id`.
    ///
    /// The order of the returned records is unspecified, because it follows
    /// `HashMap` iteration order. Never returns an error.
    async fn list_by_user(
        &self,
        user_id: &str,
        tenant_id: &TenantId,
    ) -> Result<Vec<EnrollmentRecord>, EnrollmentError> {
        let enrollments = self.enrollments.read().await;
        Ok(enrollments
            .values()
            .filter(|e| e.user_id == user_id && &e.tenant_id == tenant_id)
            .cloned()
            .collect())
    }

    /// Replaces the stored enrollment that has the same `id` as `record`.
    ///
    /// # Errors
    /// Returns `EnrollmentError::NotFound` if no enrollment with that id
    /// exists. The store is left unchanged in that case.
    async fn update(&self, record: EnrollmentRecord) -> Result<(), EnrollmentError> {
        let mut enrollments = self.enrollments.write().await;
        // A single lookup replaces the old `contains_key` + `insert` pair, and
        // reusing the existing slot avoids cloning the key.
        let slot = enrollments
            .get_mut(&record.id)
            .ok_or(EnrollmentError::NotFound)?;
        *slot = record;
        Ok(())
    }

    /// Removes the enrollment stored under `id`.
    ///
    /// # Errors
    /// Returns `EnrollmentError::NotFound` if no such enrollment exists.
    async fn delete(&self, id: &str) -> Result<(), EnrollmentError> {
        self.enrollments
            .write()
            .await
            .remove(id)
            .ok_or(EnrollmentError::NotFound)?;
        Ok(())
    }

    /// Increments the enrollment's `use_count` and sets `last_used_at` to the
    /// current system time.
    ///
    /// # Errors
    /// Returns `EnrollmentError::NotFound` if no enrollment is stored under `id`.
    async fn record_usage(&self, id: &str) -> Result<(), EnrollmentError> {
        let mut enrollments = self.enrollments.write().await;
        let enrollment = enrollments.get_mut(id).ok_or(EnrollmentError::NotFound)?;
        enrollment.use_count += 1;
        enrollment.last_used_at = Some(std::time::SystemTime::now());
        Ok(())
    }
}