use async_trait::async_trait;
use chrono::Utc;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::errors::AppError;
use crate::models::sso::{SsoAuthState, SsoProvider};
/// Storage abstraction for SSO provider configurations and the transient
/// authorization states (each carrying an `expires_at`) that accompany them.
///
/// Every operation is async and fallible with [`AppError`]; implementors must be
/// `Send + Sync` so a single repository can be shared across tasks.
#[async_trait]
pub trait SsoRepository: Send + Sync {
    // ---- provider writes & point lookups -------------------------------------

    /// Stores `provider` and returns it.
    async fn create_provider(&self, provider: SsoProvider) -> Result<SsoProvider, AppError>;

    /// Looks up a provider by id; `Ok(None)` when no provider has that id.
    async fn find_provider_by_id(&self, id: Uuid) -> Result<Option<SsoProvider>, AppError>;

    /// Returns every provider whose `org_id` equals `org_id`.
    async fn find_providers_by_org(&self, org_id: Uuid) -> Result<Vec<SsoProvider>, AppError>;

    /// Returns an enabled provider belonging to `org_id`, or `Ok(None)` if there is none.
    async fn find_enabled_provider_for_org(
        &self,
        org_id: Uuid,
    ) -> Result<Option<SsoProvider>, AppError>;

    // ---- unpaged listings ----------------------------------------------------

    /// Returns the providers of all organizations.
    async fn list_all_providers(&self) -> Result<Vec<SsoProvider>, AppError>;

    /// Returns the providers of a single organization.
    async fn list_providers_for_org(&self, org_id: Uuid) -> Result<Vec<SsoProvider>, AppError>;

    // ---- paged listings & counts ---------------------------------------------

    /// Returns at most `limit` providers (all organizations) after skipping `offset`.
    async fn list_all_providers_paged(
        &self,
        limit: u32,
        offset: u32,
    ) -> Result<Vec<SsoProvider>, AppError>;

    /// Returns at most `limit` providers of `org_id` after skipping `offset`.
    async fn list_providers_for_org_paged(
        &self,
        org_id: Uuid,
        limit: u32,
        offset: u32,
    ) -> Result<Vec<SsoProvider>, AppError>;

    /// Counts providers across all organizations.
    async fn count_all_providers(&self) -> Result<u64, AppError>;

    /// Counts the providers of `org_id`.
    async fn count_providers_for_org(&self, org_id: Uuid) -> Result<u64, AppError>;

    /// Returns at most `limit` providers belonging to any of `org_ids`, after skipping `offset`.
    async fn list_providers_for_orgs_paged(
        &self,
        org_ids: &[Uuid],
        limit: u32,
        offset: u32,
    ) -> Result<Vec<SsoProvider>, AppError>;

    /// Counts providers belonging to any of `org_ids`.
    async fn count_providers_for_orgs(&self, org_ids: &[Uuid]) -> Result<u64, AppError>;

    // ---- provider mutation ---------------------------------------------------

    /// Replaces the stored provider that has the same id as `provider`.
    async fn update_provider(&self, provider: SsoProvider) -> Result<SsoProvider, AppError>;

    /// Removes the provider with `id`.
    async fn delete_provider(&self, id: Uuid) -> Result<(), AppError>;

    // ---- auth states ---------------------------------------------------------

    /// Saves an auth state for later lookup by its `state_id`.
    async fn store_auth_state(&self, state: SsoAuthState) -> Result<(), AppError>;

    /// Reads an auth state without removing it.
    async fn get_auth_state(&self, state_id: Uuid) -> Result<Option<SsoAuthState>, AppError>;

    /// Removes an auth state and returns it; intended for one-time use.
    async fn consume_auth_state(&self, state_id: Uuid) -> Result<Option<SsoAuthState>, AppError>;

    /// Purges expired auth states and returns how many were removed.
    async fn delete_expired_states(&self) -> Result<u64, AppError>;
}
/// Process-local [`SsoRepository`] that keeps everything in two `HashMap`s,
/// each guarded by its own async `RwLock`. Nothing is persisted.
pub struct InMemorySsoRepository {
    /// Provider configurations, keyed by the provider's `id`.
    providers: RwLock<HashMap<Uuid, SsoProvider>>,
    /// Pending auth states, keyed by the state's `state_id`.
    auth_states: RwLock<HashMap<Uuid, SsoAuthState>>,
}

impl InMemorySsoRepository {
    /// Builds a repository that starts with no providers and no auth states.
    pub fn new() -> Self {
        let providers = RwLock::new(HashMap::new());
        let auth_states = RwLock::new(HashMap::new());
        Self {
            providers,
            auth_states,
        }
    }
}

impl Default for InMemorySsoRepository {
    /// Equivalent to [`InMemorySsoRepository::new`].
    fn default() -> Self {
        InMemorySsoRepository::new()
    }
}
/// Orders providers by organization, then name, then id.
///
/// The trailing id comparison makes this a total order over distinct providers,
/// so paged listings stay consistent even when two providers of one organization
/// share a name (the backing `HashMap` iterates in an unspecified order).
fn cmp_by_org_then_name(a: &SsoProvider, b: &SsoProvider) -> std::cmp::Ordering {
    a.org_id
        .cmp(&b.org_id)
        .then_with(|| a.name.cmp(&b.name))
        .then_with(|| a.id.cmp(&b.id))
}

/// Orders providers by name, then id; used when all providers share one organization.
fn cmp_by_name(a: &SsoProvider, b: &SsoProvider) -> std::cmp::Ordering {
    a.name.cmp(&b.name).then_with(|| a.id.cmp(&b.id))
}

/// Returns at most `limit` items starting at index `offset`; empty when `offset`
/// is at or past the end of `items`.
fn paginate<T>(items: Vec<T>, limit: u32, offset: u32) -> Vec<T> {
    items
        .into_iter()
        .skip(offset as usize)
        .take(limit as usize)
        .collect()
}

#[async_trait]
impl SsoRepository for InMemorySsoRepository {
    // Inserts under `provider.id`. An existing provider with the same id is
    // replaced, not rejected.
    async fn create_provider(&self, provider: SsoProvider) -> Result<SsoProvider, AppError> {
        let mut providers = self.providers.write().await;
        providers.insert(provider.id, provider.clone());
        Ok(provider)
    }

    async fn find_provider_by_id(&self, id: Uuid) -> Result<Option<SsoProvider>, AppError> {
        let providers = self.providers.read().await;
        Ok(providers.get(&id).cloned())
    }

    // Result order is unspecified (HashMap iteration order).
    async fn find_providers_by_org(&self, org_id: Uuid) -> Result<Vec<SsoProvider>, AppError> {
        let providers = self.providers.read().await;
        Ok(providers
            .values()
            .filter(|p| p.org_id == org_id)
            .cloned()
            .collect())
    }

    // NOTE(review): if an organization has several enabled providers, which one is
    // returned depends on HashMap iteration order.
    async fn find_enabled_provider_for_org(
        &self,
        org_id: Uuid,
    ) -> Result<Option<SsoProvider>, AppError> {
        let providers = self.providers.read().await;
        Ok(providers
            .values()
            .find(|p| p.org_id == org_id && p.enabled)
            .cloned())
    }

    // Result order is unspecified (HashMap iteration order).
    async fn list_all_providers(&self) -> Result<Vec<SsoProvider>, AppError> {
        let providers = self.providers.read().await;
        Ok(providers.values().cloned().collect())
    }

    async fn list_providers_for_org(&self, org_id: Uuid) -> Result<Vec<SsoProvider>, AppError> {
        self.find_providers_by_org(org_id).await
    }

    // Pages over all providers ordered by (org_id, name, id).
    async fn list_all_providers_paged(
        &self,
        limit: u32,
        offset: u32,
    ) -> Result<Vec<SsoProvider>, AppError> {
        // The read guard is a temporary and is released before sorting.
        let mut result: Vec<SsoProvider> =
            self.providers.read().await.values().cloned().collect();
        result.sort_unstable_by(cmp_by_org_then_name);
        Ok(paginate(result, limit, offset))
    }

    // Pages over one organization's providers ordered by (name, id).
    async fn list_providers_for_org_paged(
        &self,
        org_id: Uuid,
        limit: u32,
        offset: u32,
    ) -> Result<Vec<SsoProvider>, AppError> {
        let mut result: Vec<SsoProvider> = self
            .providers
            .read()
            .await
            .values()
            .filter(|p| p.org_id == org_id)
            .cloned()
            .collect();
        result.sort_unstable_by(cmp_by_name);
        Ok(paginate(result, limit, offset))
    }

    async fn count_all_providers(&self) -> Result<u64, AppError> {
        let providers = self.providers.read().await;
        Ok(providers.len() as u64)
    }

    async fn count_providers_for_org(&self, org_id: Uuid) -> Result<u64, AppError> {
        let providers = self.providers.read().await;
        Ok(providers.values().filter(|p| p.org_id == org_id).count() as u64)
    }

    // Pages over providers of any org in `org_ids`, ordered by (org_id, name, id).
    async fn list_providers_for_orgs_paged(
        &self,
        org_ids: &[Uuid],
        limit: u32,
        offset: u32,
    ) -> Result<Vec<SsoProvider>, AppError> {
        // Nothing can match an empty org list; skip taking the lock.
        if org_ids.is_empty() {
            return Ok(Vec::new());
        }
        // Set membership keeps the filter O(1) per provider regardless of org count.
        let org_set: HashSet<Uuid> = org_ids.iter().copied().collect();
        let mut result: Vec<SsoProvider> = self
            .providers
            .read()
            .await
            .values()
            .filter(|p| org_set.contains(&p.org_id))
            .cloned()
            .collect();
        result.sort_unstable_by(cmp_by_org_then_name);
        Ok(paginate(result, limit, offset))
    }

    async fn count_providers_for_orgs(&self, org_ids: &[Uuid]) -> Result<u64, AppError> {
        if org_ids.is_empty() {
            return Ok(0);
        }
        let org_set: HashSet<Uuid> = org_ids.iter().copied().collect();
        let providers = self.providers.read().await;
        Ok(providers
            .values()
            .filter(|p| org_set.contains(&p.org_id))
            .count() as u64)
    }

    // Fails with `AppError::NotFound` when no provider has `provider.id`.
    async fn update_provider(&self, provider: SsoProvider) -> Result<SsoProvider, AppError> {
        let mut providers = self.providers.write().await;
        match providers.entry(provider.id) {
            Entry::Occupied(mut entry) => {
                entry.insert(provider.clone());
                Ok(provider)
            }
            Entry::Vacant(_) => Err(AppError::NotFound("Provider not found".into())),
        }
    }

    // Succeeds whether or not the provider existed.
    async fn delete_provider(&self, id: Uuid) -> Result<(), AppError> {
        let mut providers = self.providers.write().await;
        providers.remove(&id);
        Ok(())
    }

    // Stores under `state.state_id`, replacing any state with the same id.
    async fn store_auth_state(&self, state: SsoAuthState) -> Result<(), AppError> {
        let mut states = self.auth_states.write().await;
        states.insert(state.state_id, state);
        Ok(())
    }

    // Non-destructive read. A state is usable until `expires_at` has passed;
    // expired states are reported as absent but left in place.
    async fn get_auth_state(&self, state_id: Uuid) -> Result<Option<SsoAuthState>, AppError> {
        let states = self.auth_states.read().await;
        let now = Utc::now();
        Ok(states
            .get(&state_id)
            .filter(|s| s.expires_at >= now)
            .cloned())
    }

    // Removes the state unconditionally so it can never be consumed twice. An
    // expired state is discarded and reported as absent.
    async fn consume_auth_state(&self, state_id: Uuid) -> Result<Option<SsoAuthState>, AppError> {
        let mut states = self.auth_states.write().await;
        let now = Utc::now();
        Ok(states.remove(&state_id).filter(|s| s.expires_at >= now))
    }

    // Drops every state whose `expires_at` is strictly before now, in one pass.
    async fn delete_expired_states(&self) -> Result<u64, AppError> {
        let mut states = self.auth_states.write().await;
        let now = Utc::now();
        let before = states.len();
        states.retain(|_, s| s.expires_at >= now);
        Ok((before - states.len()) as u64)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider in `org_id` named `name`; the remaining string fields
    /// are derived from the name and are not inspected by these tests.
    fn provider(org_id: Uuid, name: &str) -> SsoProvider {
        SsoProvider::new(
            org_id,
            name.into(),
            format!("https://{}.example.com", name.to_lowercase()),
            format!("client-{}", name.to_lowercase()),
            "secret".into(),
        )
    }

    /// Builds an auth state with the same arguments the existing tests use for a
    /// state that is still valid.
    fn fresh_auth_state() -> SsoAuthState {
        SsoAuthState::new(
            Uuid::new_v4(),
            Uuid::new_v4(),
            "verifier".into(),
            "nonce".into(),
            None,
            300,
        )
    }

    #[tokio::test]
    async fn test_create_and_find_provider() {
        let repo = InMemorySsoRepository::new();
        let org_id = Uuid::new_v4();
        let provider = SsoProvider::new(
            org_id,
            "Okta".into(),
            "https://dev.okta.com".into(),
            "client-id".into(),
            "secret".into(),
        );
        let provider_id = provider.id;
        repo.create_provider(provider).await.unwrap();
        let found = repo.find_provider_by_id(provider_id).await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "Okta");
    }

    #[tokio::test]
    async fn test_find_providers_by_org() {
        let repo = InMemorySsoRepository::new();
        let org_id = Uuid::new_v4();
        let p1 = SsoProvider::new(
            org_id,
            "Okta".into(),
            "https://okta.com".into(),
            "id1".into(),
            "s1".into(),
        );
        let p2 = SsoProvider::new(
            org_id,
            "Azure AD".into(),
            "https://azure.com".into(),
            "id2".into(),
            "s2".into(),
        );
        repo.create_provider(p1).await.unwrap();
        repo.create_provider(p2).await.unwrap();
        let providers = repo.find_providers_by_org(org_id).await.unwrap();
        assert_eq!(providers.len(), 2);
    }

    #[tokio::test]
    async fn test_list_providers_paged_and_count() {
        let repo = InMemorySsoRepository::new();
        let org_a = Uuid::new_v4();
        let org_b = Uuid::new_v4();
        repo.create_provider(provider(org_a, "Alpha")).await.unwrap();
        repo.create_provider(provider(org_a, "Beta")).await.unwrap();
        repo.create_provider(provider(org_b, "Gamma")).await.unwrap();

        let total = repo.count_all_providers().await.unwrap();
        assert_eq!(total, 3);
        let page = repo.list_all_providers_paged(2, 0).await.unwrap();
        assert_eq!(page.len(), 2);

        let org_total = repo.count_providers_for_org(org_a).await.unwrap();
        assert_eq!(org_total, 2);
        let org_page = repo
            .list_providers_for_org_paged(org_a, 1, 1)
            .await
            .unwrap();
        assert_eq!(org_page.len(), 1);
        // Within one org the listing is name-ordered, so offset 1 is the second name.
        assert_eq!(org_page[0].name, "Beta");
    }

    #[tokio::test]
    async fn test_org_page_is_sorted_and_offset_past_end_is_empty() {
        let repo = InMemorySsoRepository::new();
        let org = Uuid::new_v4();
        for name in ["Charlie", "Alpha", "Bravo"] {
            repo.create_provider(provider(org, name)).await.unwrap();
        }
        let names: Vec<String> = repo
            .list_providers_for_org_paged(org, 10, 0)
            .await
            .unwrap()
            .into_iter()
            .map(|p| p.name)
            .collect();
        assert_eq!(names, vec!["Alpha", "Bravo", "Charlie"]);

        let past_end = repo.list_providers_for_org_paged(org, 10, 5).await.unwrap();
        assert!(past_end.is_empty());
    }

    #[tokio::test]
    async fn test_paging_visits_every_provider_exactly_once() {
        let repo = InMemorySsoRepository::new();
        let org_a = Uuid::new_v4();
        let org_b = Uuid::new_v4();
        // Duplicate names inside one org exercise the tie-breaking order.
        for (org, name) in [(org_a, "Same"), (org_a, "Same"), (org_b, "Other")] {
            repo.create_provider(provider(org, name)).await.unwrap();
        }
        let mut seen = HashSet::new();
        for offset in 0..3 {
            let page = repo.list_all_providers_paged(1, offset).await.unwrap();
            assert_eq!(page.len(), 1);
            assert!(seen.insert(page[0].id), "provider returned on two pages");
        }
        assert_eq!(seen.len(), 3);
    }

    #[tokio::test]
    async fn test_list_and_count_providers_for_orgs() {
        let repo = InMemorySsoRepository::new();
        let org_a = Uuid::new_v4();
        let org_b = Uuid::new_v4();
        let org_c = Uuid::new_v4();
        for (org, name) in [
            (org_a, "Alpha"),
            (org_a, "Bravo"),
            (org_b, "Charlie"),
            (org_c, "Delta"),
        ] {
            repo.create_provider(provider(org, name)).await.unwrap();
        }
        let owned_orgs = vec![org_a, org_b];
        let total = repo.count_providers_for_orgs(&owned_orgs).await.unwrap();
        assert_eq!(total, 3);
        let page = repo
            .list_providers_for_orgs_paged(&owned_orgs, 10, 0)
            .await
            .unwrap();
        assert_eq!(page.len(), 3);
        assert!(page.iter().all(|p| p.org_id == org_a || p.org_id == org_b));

        let none = repo.list_providers_for_orgs_paged(&[], 10, 0).await.unwrap();
        assert!(none.is_empty());
        assert_eq!(repo.count_providers_for_orgs(&[]).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_update_missing_provider_is_not_found() {
        let repo = InMemorySsoRepository::new();
        let result = repo
            .update_provider(provider(Uuid::new_v4(), "Ghost"))
            .await;
        assert!(matches!(result, Err(AppError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_auth_state_expiration() {
        let repo = InMemorySsoRepository::new();
        let mut state = SsoAuthState::new(
            Uuid::new_v4(),
            Uuid::new_v4(),
            "verifier".into(),
            "nonce".into(),
            None,
            -10,
        );
        // Pin expiry explicitly in the past rather than relying on how the
        // constructor interprets its last argument.
        state.expires_at = Utc::now() - chrono::Duration::seconds(10);
        let state_id = state.state_id;
        repo.store_auth_state(state).await.unwrap();
        let result = repo.consume_auth_state(state_id).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_get_auth_state_does_not_consume() {
        let repo = InMemorySsoRepository::new();
        let mut state = fresh_auth_state();
        state.access_code = Some("ACCESS123".into());
        state.referral = Some("REFERRAL1".into());
        let state_id = state.state_id;
        repo.store_auth_state(state).await.unwrap();

        let fetched = repo.get_auth_state(state_id).await.unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.access_code.as_deref(), Some("ACCESS123"));
        assert_eq!(fetched.referral.as_deref(), Some("REFERRAL1"));

        let consumed = repo.consume_auth_state(state_id).await.unwrap();
        assert!(consumed.is_some());
        let missing = repo.get_auth_state(state_id).await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_delete_expired_states_removes_only_expired() {
        let repo = InMemorySsoRepository::new();
        let mut expired = fresh_auth_state();
        expired.expires_at = Utc::now() - chrono::Duration::seconds(10);
        let live = fresh_auth_state();
        let live_id = live.state_id;
        repo.store_auth_state(expired).await.unwrap();
        repo.store_auth_state(live).await.unwrap();

        assert_eq!(repo.delete_expired_states().await.unwrap(), 1);
        assert!(repo.get_auth_state(live_id).await.unwrap().is_some());
        // A second sweep finds nothing left to remove.
        assert_eq!(repo.delete_expired_states().await.unwrap(), 0);
    }
}