use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
use sui_id_store::Database;
/// Derives the lowercased origin (`scheme://host[:port]`) of an absolute URI.
///
/// Only the scheme and authority are kept:
/// - The path, query (`?…`) and fragment (`#…`) are dropped.
/// - Any `user[:password]@` prefix is dropped, since it is not part of an origin.
/// - The scheme's default port (`:80` for `http`, `:443` for `https`) is removed,
///   so the result compares equal to the `Origin` value browsers send, which
///   omits it.
///
/// Returns `None` when the input has no `://` separator, an empty scheme, or an
/// empty host.
fn origin_from_uri(uri: &str) -> Option<String> {
    let (scheme, rest) = uri.split_once("://")?;
    if scheme.is_empty() {
        return None;
    }
    let scheme = scheme.to_lowercase();

    // The authority runs until the first path, query, or fragment delimiter.
    let authority = match rest.find(|c: char| matches!(c, '/' | '?' | '#')) {
        Some(end) => &rest[..end],
        None => rest,
    };

    // Credentials are never part of an origin. Split on the last `@` so an
    // unescaped `@` inside a password cannot end up in the host.
    let lowered = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port)
        .to_lowercase();

    // Browsers omit the scheme's default port when serialising an origin.
    let default_port = match scheme.as_str() {
        "http" => Some(":80"),
        "https" => Some(":443"),
        _ => None,
    };
    let host_port = default_port
        .and_then(|port| lowered.strip_suffix(port))
        .unwrap_or(lowered.as_str());

    // An origin needs a host: reject `https:///cb` and `https://:8080/cb`.
    if host_port.is_empty() || host_port.starts_with(':') {
        return None;
    }
    Some(format!("{}://{}", scheme, host_port))
}
/// In-memory set of lowercased origins derived from the redirect URIs of every
/// non-deleted client.
///
/// The whole set is replaced by [`RedirectOriginsCache::rebuild`] and queried
/// with [`RedirectOriginsCache::contains`].
#[derive(Default, Debug)]
pub struct RedirectOriginsCache {
    inner: RwLock<HashSet<String>>,
}

impl RedirectOriginsCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Default::default()
    }

    /// Reloads the origin set from the clients stored in `db`.
    ///
    /// Deleted clients are skipped, as are redirect URIs from which
    /// `origin_from_uri` cannot derive an origin. The new set is computed
    /// before the write lock is taken, so the lock is held only for the swap.
    ///
    /// # Errors
    ///
    /// Propagates the store error if listing clients fails. In that case the
    /// previously cached set is left in place.
    pub async fn rebuild(&self, db: &Database) -> Result<(), sui_id_store::StoreError> {
        let clients = sui_id_store::repos::clients::list(db).await?;

        let mut fresh: HashSet<String> = HashSet::new();
        for client in clients.iter() {
            if !client.is_deleted {
                fresh.extend(
                    client
                        .redirect_uris
                        .iter()
                        .filter_map(|uri| origin_from_uri(uri)),
                );
            }
        }

        *self.inner.write().await = fresh;
        Ok(())
    }

    /// Returns `true` if `origin` (compared after lowercasing) is in the cache.
    pub async fn contains(&self, origin: &str) -> bool {
        let needle = origin.to_lowercase();
        let set = self.inner.read().await;
        set.contains(needle.as_str())
    }

    /// Number of cached origins (test-only helper).
    #[cfg(test)]
    pub async fn len(&self) -> usize {
        let set = self.inner.read().await;
        set.len()
    }
}
/// A signing key as held in [`JwksCache`], copied out of the key store.
#[derive(Clone, Debug)]
pub struct CachedSigningKey {
    /// Key identifier: the store's key id rendered as a string.
    pub kid: String,
    /// Algorithm name as recorded in the key store (e.g. `"EdDSA"`).
    pub algorithm: String,
    /// Raw public-key bytes as stored; no encoding is applied here.
    pub public_key_bytes: Vec<u8>,
}
/// In-memory copy of the active signing keys, replaced wholesale by
/// [`JwksCache::rebuild`].
#[derive(Default, Debug)]
pub struct JwksCache {
    inner: RwLock<Vec<CachedSigningKey>>,
}

impl JwksCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Default::default()
    }

    /// Reloads the cache from the keys returned by
    /// `signing_keys::list_active`, preserving their order.
    ///
    /// The new list is built before the write lock is taken.
    ///
    /// # Errors
    ///
    /// Propagates the store error if listing keys fails. In that case the
    /// previously cached keys are left in place.
    pub async fn rebuild(&self, db: &Database) -> Result<(), sui_id_store::StoreError> {
        let keys = sui_id_store::repos::signing_keys::list_active(db).await?;

        let mut fresh = Vec::new();
        for key in keys {
            let kid = key.id.to_string();
            fresh.push(CachedSigningKey {
                kid,
                algorithm: key.algorithm,
                public_key_bytes: key.public_key,
            });
        }

        *self.inner.write().await = fresh;
        Ok(())
    }

    /// Returns an owned copy of the cached keys. The read lock is released
    /// before returning, so the caller holds no lock.
    pub async fn snapshot(&self) -> Vec<CachedSigningKey> {
        let keys = self.inner.read().await;
        keys.to_vec()
    }

    /// Number of cached keys.
    pub async fn len(&self) -> usize {
        let keys = self.inner.read().await;
        keys.len()
    }

    /// Returns `true` when no keys are cached.
    pub async fn is_empty(&self) -> bool {
        let keys = self.inner.read().await;
        keys.is_empty()
    }
}
/// Bundle of the process-wide caches.
#[derive(Default, Debug)]
pub struct Caches {
    pub redirect_origins: RedirectOriginsCache,
    pub jwks: JwksCache,
}

impl Caches {
    /// Creates a bundle whose caches are all empty.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a bundle and populates every cache from `db`.
    ///
    /// The redirect-origin cache is loaded first, then the JWKS cache.
    ///
    /// # Errors
    ///
    /// Returns the first store error encountered. If the redirect-origin load
    /// fails, the JWKS load is not attempted.
    pub async fn build(db: &Database) -> Result<Arc<Self>, sui_id_store::StoreError> {
        let caches = Self::new();
        caches.redirect_origins.rebuild(db).await?;
        caches.jwks.rebuild(db).await?;
        Ok(Arc::new(caches))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pure string parsing with no awaits, so a plain `#[test]` suffices and no
    // async runtime is spun up.
    #[test]
    fn origin_extraction_works() {
        assert_eq!(
            origin_from_uri("https://app.example.com/callback"),
            Some("https://app.example.com".into())
        );
        assert_eq!(
            origin_from_uri("http://localhost:3000/callback"),
            Some("http://localhost:3000".into())
        );
        assert_eq!(
            origin_from_uri("HTTPS://App.Example.Com/cb"),
            Some("https://app.example.com".into())
        );
        assert_eq!(origin_from_uri("not-a-url"), None);
    }

    #[tokio::test]
    async fn redirect_origins_cache_contains() {
        let cache = RedirectOriginsCache::new();
        {
            let mut guard = cache.inner.write().await;
            guard.insert("https://app.example.com".into());
            guard.insert("http://localhost:3000".into());
        }
        assert_eq!(cache.len().await, 2);
        assert!(cache.contains("https://app.example.com").await);
        // `contains` lowercases its argument before the lookup.
        assert!(cache.contains("HTTPS://APP.EXAMPLE.COM").await);
        assert!(!cache.contains("https://evil.com").await);
    }

    #[tokio::test]
    async fn jwks_cache_snapshot_is_cloned() {
        let cache = JwksCache::new();
        assert!(cache.is_empty().await);
        {
            let mut guard = cache.inner.write().await;
            guard.push(CachedSigningKey {
                kid: "k1".into(),
                algorithm: "EdDSA".into(),
                public_key_bytes: vec![0u8; 32],
            });
        }
        assert_eq!(cache.len().await, 1);

        let snap = cache.snapshot().await;
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].kid, "k1");

        // The snapshot is an independent copy: clearing the cache afterwards
        // must not affect it.
        cache.inner.write().await.clear();
        assert!(cache.is_empty().await);
        assert_eq!(snap.len(), 1);
    }
}