use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use super::OAuth2Provider;
/// In-memory registry of OAuth2 providers, keyed by `(tenant, provider name)`.
///
/// The map lives behind an `Arc<RwLock<_>>`, so every clone of a registry
/// shares (and observes writes to) the same set of providers.
#[derive(Default, Clone)]
pub struct OAuth2Registry {
    /// `(tenant, provider name)` -> shared provider configuration.
    inner: Arc<RwLock<HashMap<(String, String), Arc<OAuth2Provider>>>>,
}
impl OAuth2Registry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::default(),
        }
    }

    /// Registers `provider` for `tenant`, keyed by the provider's own `name`.
    ///
    /// Any provider previously stored under the same `(tenant, name)` key is replaced.
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    pub fn register(&self, tenant: impl Into<String>, provider: OAuth2Provider) {
        let key = (tenant.into(), provider.name.clone());
        let mut map = self.inner.write().expect("registry poisoned");
        map.insert(key, Arc::new(provider));
    }

    /// Removes the provider stored under `(tenant, name)`.
    ///
    /// Returns `true` if an entry was present and removed, `false` otherwise.
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    pub fn deregister(&self, tenant: &str, name: &str) -> bool {
        let key = (tenant.to_owned(), name.to_owned());
        let mut map = self.inner.write().expect("registry poisoned");
        map.remove(&key).is_some()
    }

    /// Looks up the provider stored under `(tenant, name)`.
    ///
    /// Returns a shared handle to the provider, or `None` if no entry exists.
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    #[must_use]
    pub fn get(&self, tenant: &str, name: &str) -> Option<Arc<OAuth2Provider>> {
        let key = (tenant.to_owned(), name.to_owned());
        let map = self.inner.read().expect("registry poisoned");
        map.get(&key).map(Arc::clone)
    }

    /// Returns every registered `(tenant, provider name)` key, in unspecified order.
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    #[must_use]
    pub fn list(&self) -> Vec<(String, String)> {
        let map = self.inner.read().expect("registry poisoned");
        map.keys()
            .map(|(t, n)| (t.clone(), n.clone()))
            .collect()
    }

    /// Returns the names of all providers registered for `tenant`, in unspecified order.
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    #[must_use]
    pub fn list_for_tenant(&self, tenant: &str) -> Vec<String> {
        let map = self.inner.read().expect("registry poisoned");
        map.keys()
            .filter_map(|(t, n)| (t == tenant).then(|| n.clone()))
            .collect()
    }

    /// Removes every provider for every tenant.
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    pub fn clear(&self) {
        let mut map = self.inner.write().expect("registry poisoned");
        map.clear();
    }
}
/// Asynchronous source of provider configurations, consulted by
/// [`CachedRegistry`] when a `(tenant, provider name)` pair is not cached.
#[async_trait::async_trait]
pub trait ProviderLoader: Send + Sync + 'static {
    /// Loads the provider called `provider_name` for `tenant`.
    ///
    /// [`CachedRegistry::get`] treats `Ok(None)` as "not found" and does not
    /// cache it.
    ///
    /// # Errors
    /// Any `Err` returned here is propagated unchanged by [`CachedRegistry::get`].
    async fn load(&self, tenant: &str, provider_name: &str)
        -> Result<Option<OAuth2Provider>, super::OAuthError>;
}
/// A read-through cache of OAuth2 providers.
///
/// Lookups are served from [`OAuth2Registry`]. Misses are filled by asking the
/// [`ProviderLoader`].
pub struct CachedRegistry<L: ProviderLoader> {
    /// Providers that have already been loaded (or registered directly).
    pub registry: OAuth2Registry,
    /// Loader consulted on a cache miss.
    pub loader: Arc<L>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `L: Clone` bound, yet both fields are clonable for any `L`. The registry
// shares its map through an `Arc`, and the loader is already behind an `Arc`.
impl<L: ProviderLoader> Clone for CachedRegistry<L> {
    fn clone(&self) -> Self {
        Self {
            registry: self.registry.clone(),
            loader: Arc::clone(&self.loader),
        }
    }
}
impl<L: ProviderLoader> CachedRegistry<L> {
    /// Creates an empty cache backed by `loader`.
    pub fn new(loader: L) -> Self {
        Self {
            registry: OAuth2Registry::new(),
            loader: Arc::new(loader),
        }
    }

    /// Returns the provider `name` for `tenant`, loading and caching it on a miss.
    ///
    /// A loaded provider is cached under the *requested* `(tenant, name)` key and
    /// the freshly created handle is returned directly. The result is therefore
    /// returned even if the loader hands back a provider whose `name` field
    /// differs, and even if [`Self::invalidate`] runs concurrently. `Ok(None)`
    /// from the loader is not cached. Two concurrent misses on the same key may
    /// both call the loader; the later insert replaces the earlier one.
    ///
    /// # Errors
    /// Propagates any error returned by [`ProviderLoader::load`].
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    pub async fn get(
        &self,
        tenant: &str,
        name: &str,
    ) -> Result<Option<Arc<OAuth2Provider>>, super::OAuthError> {
        if let Some(cached) = self.registry.get(tenant, name) {
            return Ok(Some(cached));
        }
        let provider = match self.loader.load(tenant, name).await? {
            Some(loaded) => Arc::new(loaded),
            None => return Ok(None),
        };
        // Insert under the requested key and return our own handle rather than
        // re-reading the map. A re-read would miss when the provider's `name`
        // differs from `name`, or when an `invalidate` lands in between.
        // The lock guard is a statement temporary, so it is never held across an `.await`.
        self.registry
            .inner
            .write()
            .expect("registry poisoned")
            .insert((tenant.to_owned(), name.to_owned()), Arc::clone(&provider));
        Ok(Some(provider))
    }

    /// Drops the cached entry for `(tenant, name)`, if any, so the next
    /// [`Self::get`] for that key consults the loader again.
    ///
    /// # Panics
    /// Panics if the registry lock is poisoned.
    pub fn invalidate(&self, tenant: &str, name: &str) {
        self.registry.deregister(tenant, name);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth2::providers;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn register_and_lookup_in_memory() {
        let reg = OAuth2Registry::new();
        reg.register("acme", providers::google("cid", "csec", "https://acme/cb"));
        let p = reg.get("acme", "google").unwrap();
        assert_eq!(p.name, "google");
        assert_eq!(p.client_id, "cid");
    }

    #[test]
    fn unknown_tenant_or_provider_returns_none() {
        let reg = OAuth2Registry::new();
        reg.register("acme", providers::google("cid", "csec", "https://acme/cb"));
        assert!(reg.get("other", "google").is_none());
        assert!(reg.get("acme", "github").is_none());
    }

    #[test]
    fn list_for_tenant_returns_only_that_tenant() {
        let reg = OAuth2Registry::new();
        reg.register("a", providers::google("cid", "csec", "https://a/cb"));
        reg.register("a", providers::github("cid", "csec", "https://a/cb"));
        reg.register("b", providers::google("cid", "csec", "https://b/cb"));
        let mut names = reg.list_for_tenant("a");
        names.sort();
        assert_eq!(names, vec!["github".to_owned(), "google".to_owned()]);
        assert_eq!(reg.list_for_tenant("b"), vec!["google".to_owned()]);
    }

    #[test]
    fn deregister_removes_entry() {
        let reg = OAuth2Registry::new();
        reg.register("acme", providers::google("cid", "csec", "https://acme/cb"));
        assert!(reg.deregister("acme", "google"));
        assert!(reg.get("acme", "google").is_none());
        assert!(!reg.deregister("acme", "google"));
    }

    #[test]
    fn clear_empties_registry() {
        let reg = OAuth2Registry::new();
        reg.register("acme", providers::google("cid", "csec", "https://acme/cb"));
        reg.register("acme", providers::github("cid", "csec", "https://acme/cb"));
        reg.clear();
        assert!(reg.list().is_empty());
    }

    #[test]
    fn registry_clone_shares_state() {
        let reg = OAuth2Registry::new();
        let clone = reg.clone();
        reg.register("acme", providers::google("cid", "csec", "https://acme/cb"));
        assert!(clone.get("acme", "google").is_some());
    }

    /// Loader that only knows `("acme", "google")`, built from `provider`'s credentials.
    struct StaticLoader {
        provider: OAuth2Provider,
    }

    #[async_trait::async_trait]
    impl ProviderLoader for StaticLoader {
        async fn load(
            &self,
            tenant: &str,
            name: &str,
        ) -> Result<Option<OAuth2Provider>, super::super::OAuthError> {
            if tenant == "acme" && name == "google" {
                Ok(Some(providers::google(
                    self.provider.client_id.clone(),
                    self.provider.client_secret.clone(),
                    self.provider.redirect_uri.clone(),
                )))
            } else {
                Ok(None)
            }
        }
    }

    /// Loader that counts how often it is called, to observe cache hits and misses.
    #[derive(Default)]
    struct CountingLoader {
        calls: AtomicUsize,
    }

    #[async_trait::async_trait]
    impl ProviderLoader for CountingLoader {
        async fn load(
            &self,
            tenant: &str,
            name: &str,
        ) -> Result<Option<OAuth2Provider>, super::super::OAuthError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if tenant == "acme" && name == "google" {
                Ok(Some(providers::google("cid", "csec", "https://acme/cb")))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn cached_registry_loads_on_miss() {
        let cached = CachedRegistry::new(StaticLoader {
            provider: providers::google("cid", "csec", "https://acme/cb"),
        });
        let p = cached.get("acme", "google").await.unwrap().unwrap();
        assert_eq!(p.client_id, "cid");
        assert_eq!(cached.registry.list().len(), 1);
    }

    #[tokio::test]
    async fn cached_registry_returns_none_for_unknown() {
        let cached = CachedRegistry::new(StaticLoader {
            provider: providers::google("cid", "csec", "https://acme/cb"),
        });
        assert!(cached.get("other", "google").await.unwrap().is_none());
        // A "not found" result must not leave anything behind in the cache.
        assert!(cached.registry.list().is_empty());
    }

    #[tokio::test]
    async fn invalidate_drops_cache_entry() {
        let cached = CachedRegistry::new(StaticLoader {
            provider: providers::google("cid", "csec", "https://acme/cb"),
        });
        cached.get("acme", "google").await.unwrap();
        cached.invalidate("acme", "google");
        assert!(cached.registry.get("acme", "google").is_none());
    }

    #[tokio::test]
    async fn cached_registry_reuses_cache_until_invalidated() {
        let cached = CachedRegistry::new(CountingLoader::default());
        let calls = || cached.loader.calls.load(Ordering::SeqCst);

        cached.get("acme", "google").await.unwrap().unwrap();
        cached.get("acme", "google").await.unwrap().unwrap();
        // Second lookup must be served from the cache.
        assert_eq!(calls(), 1);

        cached.invalidate("acme", "google");
        cached.get("acme", "google").await.unwrap().unwrap();
        // Invalidation forces exactly one fresh load.
        assert_eq!(calls(), 2);
    }
}