use async_trait::async_trait;
use dashmap::DashMap;
use parking_lot::Mutex;
use std::sync::Arc;
use crate::auth::credential::AuthCredential;
use crate::error::Result;
/// Async storage backend for [`AuthCredential`] values, addressed by an
/// `(app_name, user_id, key)` triple.
///
/// Implementors must be `Send + Sync + Debug + 'static` (the supertrait
/// bounds), so a service can be shared across tasks.
#[async_trait]
pub trait CredentialService: Send + Sync + std::fmt::Debug + 'static {
    /// Looks up the credential stored under `(app_name, user_id, key)`.
    ///
    /// The implementations in this file return `Ok(None)` when nothing is
    /// stored for the triple.
    ///
    /// # Errors
    /// Implementation-defined; the implementations in this file never fail.
    async fn load(
        &self,
        app_name: &str,
        user_id: &str,
        key: &str,
    ) -> Result<Option<AuthCredential>>;
    /// Stores a copy of `value` under `(app_name, user_id, key)`.
    ///
    /// The implementations in this file replace any previously stored value.
    ///
    /// # Errors
    /// Implementation-defined; the implementations in this file never fail.
    async fn save(
        &self,
        app_name: &str,
        user_id: &str,
        key: &str,
        value: &AuthCredential,
    ) -> Result<()>;
    /// Removes the credential stored under `(app_name, user_id, key)`.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    /// NOTE(review): implementors that actually hold state should override
    /// this; otherwise a delete reports success while the credential stays.
    async fn delete(&self, _app_name: &str, _user_id: &str, _key: &str) -> Result<()> {
        Ok(())
    }
}
/// [`CredentialService`] that keeps credentials in a process-local `DashMap`
/// keyed by the `(app_name, user_id, key)` triple.
///
/// Nothing is persisted: the stored credentials live only as long as this
/// value does.
#[derive(Default, Debug)]
pub struct InMemoryCredentialService {
    /// Stored credentials, keyed by `(app_name, user_id, key)`.
    by_key: DashMap<(String, String, String), AuthCredential>,
}

impl InMemoryCredentialService {
    /// Creates an empty service; same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the owned composite map key for one credential slot.
    fn slot(app_name: &str, user_id: &str, key: &str) -> (String, String, String) {
        (app_name.to_owned(), user_id.to_owned(), key.to_owned())
    }
}

#[async_trait]
impl CredentialService for InMemoryCredentialService {
    /// Returns a clone of the stored credential, or `None` if the slot is empty.
    async fn load(
        &self,
        app_name: &str,
        user_id: &str,
        key: &str,
    ) -> Result<Option<AuthCredential>> {
        let slot = Self::slot(app_name, user_id, key);
        let found = self.by_key.get(&slot).map(|entry| entry.value().clone());
        Ok(found)
    }

    /// Stores a clone of `value`, replacing whatever occupied the slot.
    async fn save(
        &self,
        app_name: &str,
        user_id: &str,
        key: &str,
        value: &AuthCredential,
    ) -> Result<()> {
        let slot = Self::slot(app_name, user_id, key);
        self.by_key.insert(slot, value.clone());
        Ok(())
    }

    /// Removes the slot if present; removing a missing slot is not an error.
    async fn delete(&self, app_name: &str, user_id: &str, key: &str) -> Result<()> {
        self.by_key.remove(&Self::slot(app_name, user_id, key));
        Ok(())
    }
}
/// Map type guarded by [`SessionStateCredentialService`]'s mutex.
type OverlayMap = std::collections::HashMap<String, AuthCredential>;

/// [`CredentialService`] that keeps credentials in a process-local overlay
/// keyed by `temp:adk_auth:`-prefixed strings built by `state_key`.
///
/// NOTE(review): nothing in this file copies these entries into an actual
/// session state; the overlay is private to this value.
#[derive(Debug)]
pub struct SessionStateCredentialService {
    /// Credentials keyed by the escaped string from `state_key`.
    ///
    /// Every operation takes this one mutex, so the map inside it is a plain
    /// `HashMap`. A concurrent map here would only add a second layer of
    /// locking without allowing more concurrency.
    /// NOTE(review): the `Arc` is never cloned in this file; kept so the
    /// field layout stays the same.
    overlay: Arc<Mutex<OverlayMap>>,
}

impl Default for SessionStateCredentialService {
    fn default() -> Self {
        Self {
            overlay: Arc::new(Mutex::new(OverlayMap::new())),
        }
    }
}

impl SessionStateCredentialService {
    /// Creates an empty service; same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the overlay key `temp:adk_auth:{app}:{user}:{key}`.
    ///
    /// `app` and `user` are escaped (`\` -> `\\`, `:` -> `\:`) so the
    /// unescaped `:` separators are unambiguous. Without escaping,
    /// `("a:b", "c")` and `("a", "b:c")` would produce the same key and could
    /// read, overwrite, or delete each other's credentials. `key` is the last
    /// component, so it needs no escaping. Segments containing neither `:`
    /// nor `\` appear verbatim.
    fn state_key(app: &str, user: &str, key: &str) -> String {
        format!(
            "temp:adk_auth:{}:{}:{key}",
            Self::escape_segment(app),
            Self::escape_segment(user)
        )
    }

    /// Prefixes every `\` and `:` in `segment` with a `\`.
    fn escape_segment(segment: &str) -> String {
        let mut escaped = String::with_capacity(segment.len());
        for ch in segment.chars() {
            if matches!(ch, '\\' | ':') {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }
}

#[async_trait]
impl CredentialService for SessionStateCredentialService {
    /// Returns a clone of the credential stored for the triple, if any.
    async fn load(
        &self,
        app_name: &str,
        user_id: &str,
        key: &str,
    ) -> Result<Option<AuthCredential>> {
        // Build the key before locking so the critical section is only the lookup.
        let state_key = Self::state_key(app_name, user_id, key);
        // The guard is a temporary dropped at the end of this statement.
        let found = self.overlay.lock().get(&state_key).cloned();
        Ok(found)
    }

    /// Stores a clone of `value`, replacing any previous entry for the triple.
    async fn save(
        &self,
        app_name: &str,
        user_id: &str,
        key: &str,
        value: &AuthCredential,
    ) -> Result<()> {
        let state_key = Self::state_key(app_name, user_id, key);
        let value = value.clone();
        self.overlay.lock().insert(state_key, value);
        Ok(())
    }

    /// Removes the entry for the triple; a missing entry is not an error.
    async fn delete(&self, app_name: &str, user_id: &str, key: &str) -> Result<()> {
        let state_key = Self::state_key(app_name, user_id, key);
        self.overlay.lock().remove(&state_key);
        Ok(())
    }
}
/// Returns `key` prefixed with `temp:adk_auth:`.
///
/// The result carries no app or user component; it is just the prefix
/// followed by `key`, verbatim.
#[must_use]
pub fn session_state_key(key: &str) -> String {
    const PREFIX: &str = "temp:adk_auth:";
    let mut state_key = String::with_capacity(PREFIX.len() + key.len());
    state_key.push_str(PREFIX);
    state_key.push_str(key);
    state_key
}
/// Converts `cred` into a [`serde_json::Value`].
///
/// Serialization errors are not reported: a failure yields
/// `serde_json::Value::Null`.
/// NOTE(review): callers cannot tell a failed serialization apart from a
/// credential that legitimately serializes to `null`.
#[must_use]
pub fn render_to_state(cred: &AuthCredential) -> serde_json::Value {
    match serde_json::to_value(cred) {
        Ok(rendered) => rendered,
        Err(_) => serde_json::Value::Null,
    }
}
#[cfg(test)]
mod tests {
    //! Round-trip, delete, and scoping tests for both credential services.
    use super::*;

    #[tokio::test]
    async fn in_memory_save_load_delete() {
        let svc = InMemoryCredentialService::new();
        let c = AuthCredential::api_key("hello");
        svc.save("app", "u", "k", &c).await.unwrap();
        assert_eq!(svc.load("app", "u", "k").await.unwrap(), Some(c.clone()));
        svc.delete("app", "u", "k").await.unwrap();
        assert_eq!(svc.load("app", "u", "k").await.unwrap(), None);
        // Deleting a slot that was never written is not an error.
        svc.delete("app", "u", "missing").await.unwrap();
    }

    #[tokio::test]
    async fn in_memory_is_scoped_per_app_and_user() {
        let svc = InMemoryCredentialService::new();
        let c = AuthCredential::api_key("hello");
        svc.save("app", "alice", "k", &c).await.unwrap();
        assert_eq!(svc.load("app", "bob", "k").await.unwrap(), None);
        assert_eq!(svc.load("other", "alice", "k").await.unwrap(), None);
        assert_eq!(svc.load("app", "alice", "k").await.unwrap(), Some(c));
    }

    #[tokio::test]
    async fn session_state_save_load() {
        let svc = SessionStateCredentialService::new();
        let c = AuthCredential::bearer("tok");
        svc.save("app", "u", "k", &c).await.unwrap();
        assert_eq!(svc.load("app", "u", "k").await.unwrap(), Some(c));
    }

    #[tokio::test]
    async fn session_state_delete_removes_entry() {
        let svc = SessionStateCredentialService::new();
        let c = AuthCredential::bearer("tok");
        svc.save("app", "u", "k", &c).await.unwrap();
        svc.delete("app", "u", "k").await.unwrap();
        assert_eq!(svc.load("app", "u", "k").await.unwrap(), None);
        // Deleting a slot that was never written is not an error.
        svc.delete("app", "u", "missing").await.unwrap();
    }

    #[tokio::test]
    async fn session_state_is_scoped_per_app_and_user() {
        let svc = SessionStateCredentialService::new();
        let c = AuthCredential::bearer("tok");
        svc.save("app", "alice", "k", &c).await.unwrap();
        assert_eq!(svc.load("app", "bob", "k").await.unwrap(), None);
        assert_eq!(svc.load("other", "alice", "k").await.unwrap(), None);
        assert_eq!(svc.load("app", "alice", "k").await.unwrap(), Some(c));
    }
}