use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;
use parking_lot::RwLock;
use super::hash_token;
use super::principal::{AuthOutcome, Principal};
use super::scopes::ScopeSet;
/// Errors returned by an [`AuthProvider`] when it does not produce an [`AuthOutcome`].
///
/// Rejected credentials are not errors: they are reported as `AuthOutcome` variants
/// (`InMemoryAuthProvider::authenticate`, for instance, always returns `Ok`).
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The backend is unavailable; the payload is a free-form detail message.
    #[error("auth backend unavailable: {0}")]
    BackendUnavailable(String),
    /// The backend reported a failure; the payload is a free-form detail message.
    #[error("auth backend error: {0}")]
    Backend(String),
}
/// Something that can classify a presented token string as an [`AuthOutcome`].
///
/// The `Send + Sync` bound lets a provider be shared across tasks, e.g. as an
/// `Arc<dyn AuthProvider>`.
#[async_trait::async_trait]
pub trait AuthProvider: Send + Sync {
    /// Authenticates `presented`, the raw token string supplied by the caller.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when the backend fails or is unavailable. A token that is
    /// unknown, revoked or expired is presumably reported as `Ok` with the matching
    /// [`AuthOutcome`] variant (this is what the in-memory implementation does).
    async fn authenticate(&self, presented: &str) -> Result<AuthOutcome, AuthError>;
}
/// Server-side record for one issued token, stored under the token's hash.
#[derive(Debug, Clone, PartialEq)]
struct TokenRecord {
    /// Identity the token authenticates as; cloned into `AuthOutcome::Authenticated`.
    principal: Principal,
    /// `hash_token(raw)`; duplicates the key this record is stored under in `by_hash`.
    hash: String,
    /// Set by `InMemoryAuthProvider::revoke`. Checked before expiry during authentication.
    revoked: bool,
    /// `None` means the token never expires; otherwise it is rejected once
    /// `SystemTime::now() >= expires_at`.
    expires_at: Option<SystemTime>,
}
/// [`AuthProvider`] backed by a process-local hash map of token records.
///
/// Only token hashes are kept; the raw token is handed back once by the `issue_token*`
/// methods. Cloning is cheap and every clone shares the same store, because the state
/// lives behind an `Arc<RwLock<_>>`.
#[derive(Default, Clone)]
pub struct InMemoryAuthProvider {
    inner: Arc<RwLock<InMemoryState>>,
}
/// Lock-protected contents of an [`InMemoryAuthProvider`].
#[derive(Default)]
struct InMemoryState {
    /// Token records keyed by `hash_token(raw_token)`.
    by_hash: HashMap<String, TokenRecord>,
}
impl InMemoryAuthProvider {
    /// Creates an empty provider with no issued tokens.
    pub fn new() -> Self {
        Self::default()
    }

    /// Issues a new, non-expiring token for `principal` and returns the raw token.
    ///
    /// Only the token's hash is stored, so the returned value cannot be recovered from
    /// the provider later.
    pub fn issue_token(&self, principal: Principal) -> String {
        self.store_new_token(principal, None)
    }

    /// Issues a new token for `principal` that is rejected as expired once
    /// `SystemTime::now() >= expires_at`, and returns the raw token.
    pub fn issue_token_with_expiry(&self, principal: Principal, expires_at: SystemTime) -> String {
        self.store_new_token(principal, Some(expires_at))
    }

    /// Marks the token `raw` as revoked.
    ///
    /// Returns `true` if the token is known to this provider, including one that is
    /// already revoked or expired. Returns `false` otherwise. Records are kept, not
    /// removed, so a revoked token is still reported as `Revoked` rather than unknown.
    pub fn revoke(&self, raw: &str) -> bool {
        let hash = hash_token(raw);
        let mut state = self.inner.write();
        let Some(rec) = state.by_hash.get_mut(&hash) else {
            return false;
        };
        rec.revoked = true;
        true
    }

    /// Number of stored token records, including revoked and expired ones (this type
    /// never removes records).
    #[must_use]
    pub fn token_count(&self) -> usize {
        self.inner.read().by_hash.len()
    }

    /// Shared issuance path for both `issue_token*` entry points.
    ///
    /// Generates a raw token, stores its record under the token hash and returns the
    /// raw token. Generation and hashing happen before the write lock is taken, which
    /// keeps the critical section to the single insert.
    fn store_new_token(&self, principal: Principal, expires_at: Option<SystemTime>) -> String {
        let raw = super::generate_token();
        let hash = hash_token(&raw);
        let rec = TokenRecord {
            principal,
            hash: hash.clone(),
            revoked: false,
            expires_at,
        };
        self.inner.write().by_hash.insert(hash, rec);
        raw
    }
}
#[async_trait::async_trait]
impl AuthProvider for InMemoryAuthProvider {
    /// Looks up `presented` by its hash and classifies the matching record.
    ///
    /// Checks run in this order:
    /// 1. An unknown token gives `Unauthenticated`.
    /// 2. A revoked token gives `Revoked`.
    /// 3. A token at or past its expiry instant gives `Expired`.
    /// 4. Any other token gives `Authenticated` with a clone of the stored principal.
    ///
    /// This implementation never returns `Err`.
    async fn authenticate(&self, presented: &str) -> Result<AuthOutcome, AuthError> {
        let key = hash_token(presented);
        let state = self.inner.read();
        let verdict = match state.by_hash.get(&key) {
            None => AuthOutcome::Unauthenticated,
            Some(record) if record.revoked => AuthOutcome::Revoked {
                id: record.principal.id.clone(),
            },
            // The clock is only consulted for non-revoked tokens that carry an expiry.
            Some(record)
                if matches!(record.expires_at, Some(deadline) if SystemTime::now() >= deadline) =>
            {
                AuthOutcome::Expired {
                    id: record.principal.id.clone(),
                }
            }
            Some(record) => AuthOutcome::Authenticated(record.principal.clone()),
        };
        Ok(verdict)
    }
}
#[cfg(test)]
mod tests {
    use super::super::scopes::Scope;
    use super::*;
    use std::time::Duration;

    /// Builds a principal with id `id` and only the `Read` scope.
    fn p(id: &str) -> Principal {
        Principal::new(id).with_scopes(ScopeSet::from_iter([Scope::Read]))
    }

    #[tokio::test]
    async fn authenticate_known_token_returns_principal() {
        let prov = InMemoryAuthProvider::new();
        let raw = prov.issue_token(p("alice"));
        let out = prov.authenticate(&raw).await.unwrap();
        match out {
            AuthOutcome::Authenticated(pp) => assert_eq!(pp.id, "alice"),
            other => panic!("expected Authenticated, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn authenticate_unknown_token_returns_unauthenticated() {
        let prov = InMemoryAuthProvider::new();
        prov.issue_token(p("alice"));
        let out = prov.authenticate("ydb_nonsense").await.unwrap();
        assert!(matches!(out, AuthOutcome::Unauthenticated));
    }

    #[tokio::test]
    async fn revoked_token_is_distinguished_from_unknown() {
        let prov = InMemoryAuthProvider::new();
        let raw = prov.issue_token(p("alice"));
        assert!(prov.revoke(&raw));
        let out = prov.authenticate(&raw).await.unwrap();
        assert!(matches!(out, AuthOutcome::Revoked { id } if id == "alice"));
    }

    #[tokio::test]
    async fn expired_token_returns_expired_outcome() {
        let prov = InMemoryAuthProvider::new();
        let past = SystemTime::now() - Duration::from_secs(60);
        let raw = prov.issue_token_with_expiry(p("alice"), past);
        let out = prov.authenticate(&raw).await.unwrap();
        assert!(matches!(out, AuthOutcome::Expired { id } if id == "alice"));
    }

    #[tokio::test]
    async fn revoked_takes_precedence_over_expired() {
        // A token that is both revoked and expired must report `Revoked`.
        let prov = InMemoryAuthProvider::new();
        let past = SystemTime::now() - Duration::from_secs(60);
        let raw = prov.issue_token_with_expiry(p("alice"), past);
        assert!(prov.revoke(&raw));
        let out = prov.authenticate(&raw).await.unwrap();
        assert!(matches!(out, AuthOutcome::Revoked { id } if id == "alice"));
    }

    #[tokio::test]
    async fn future_expiry_still_authenticates() {
        let prov = InMemoryAuthProvider::new();
        let future = SystemTime::now() + Duration::from_secs(60);
        let raw = prov.issue_token_with_expiry(p("alice"), future);
        let out = prov.authenticate(&raw).await.unwrap();
        assert!(matches!(out, AuthOutcome::Authenticated(_)));
    }

    #[tokio::test]
    async fn revoke_unknown_token_returns_false() {
        let prov = InMemoryAuthProvider::new();
        assert!(!prov.revoke("ydb_nonsense"));
    }

    #[test]
    fn token_count_reflects_inserts() {
        let prov = InMemoryAuthProvider::new();
        assert_eq!(prov.token_count(), 0);
        prov.issue_token(p("a"));
        prov.issue_token(p("b"));
        assert_eq!(prov.token_count(), 2);
    }

    #[tokio::test]
    async fn raw_token_is_never_stored() {
        let prov = InMemoryAuthProvider::new();
        let raw = prov.issue_token(p("alice"));
        {
            // Inspect the store directly (this child module can see private fields).
            // The record must be keyed and tagged by the hash, never by the raw token.
            // The guard is scoped so it is released before `authenticate` reads again.
            let state = prov.inner.read();
            assert!(!state.by_hash.contains_key(&raw));
            assert!(state.by_hash.values().all(|rec| rec.hash != raw));
            let hash = hash_token(&raw);
            assert_eq!(
                state.by_hash.get(&hash).map(|rec| rec.hash.as_str()),
                Some(hash.as_str())
            );
        }
        let ok = prov.authenticate(&raw).await.unwrap();
        assert!(matches!(ok, AuthOutcome::Authenticated(_)));
        let no = prov.authenticate("ydb_garbage").await.unwrap();
        assert!(matches!(no, AuthOutcome::Unauthenticated));
    }

    #[tokio::test]
    async fn provider_is_dyn_dispatchable() {
        let prov: Arc<dyn AuthProvider> = Arc::new(InMemoryAuthProvider::new());
        let out = prov.authenticate("x").await.unwrap();
        assert!(matches!(out, AuthOutcome::Unauthenticated));
    }
}