use acton_reactive::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::{oneshot, Mutex};
use super::types::{OAuthProvider, OAuthState};
pub type ResponseChannel<T> = Arc<Mutex<Option<oneshot::Sender<T>>>>;
/// Actor model that owns every outstanding OAuth2 `state` token.
///
/// Entries are keyed by the token string and hold the [`OAuthState`] that was
/// issued for it.
#[derive(Clone, Debug, Default)]
pub struct OAuth2Agent {
    states: HashMap<String, OAuthState>,
}

impl OAuth2Agent {
    /// Removes every stored state whose `expires_at` is not strictly later
    /// than the current wall-clock time.
    fn cleanup_expired(&mut self) {
        let cutoff = SystemTime::now();
        self.states.retain(|_token, entry| entry.expires_at > cutoff);
    }
}
/// Request asking the agent to mint a fresh state token for `provider`.
///
/// The newly stored [`OAuthState`] is sent back through `response_tx`.
#[derive(Clone, Debug)]
pub struct GenerateState {
    /// Provider the authorization flow is being started for.
    pub provider: OAuthProvider,
    /// Reply slot; the replying side takes the sender out (at most once).
    pub response_tx: ResponseChannel<OAuthState>,
}

impl GenerateState {
    /// Builds the request together with the receiver that will get the reply.
    #[must_use]
    pub fn new(provider: OAuthProvider) -> (Self, oneshot::Receiver<OAuthState>) {
        let (sender, receiver) = oneshot::channel();
        let slot = Arc::new(Mutex::new(Some(sender)));
        let request = Self {
            provider,
            response_tx: slot,
        };
        (request, receiver)
    }
}
/// Request asking the agent whether `token` is a known state token.
///
/// The reply is `Some(state)` for a stored token, `None` otherwise.
#[derive(Clone, Debug)]
pub struct ValidateState {
    /// Token value to look up.
    pub token: String,
    /// Reply slot; the replying side takes the sender out (at most once).
    pub response_tx: ResponseChannel<Option<OAuthState>>,
}

impl ValidateState {
    /// Builds the request together with the receiver that will get the reply.
    #[must_use]
    pub fn new(token: String) -> (Self, oneshot::Receiver<Option<OAuthState>>) {
        let (sender, receiver) = oneshot::channel();
        let slot = Arc::new(Mutex::new(Some(sender)));
        let request = Self {
            token,
            response_tx: slot,
        };
        (request, receiver)
    }
}
/// Request asking the agent to forget `token`, if it is stored.
#[derive(Clone, Debug)]
pub struct RemoveState {
    /// Token to drop from the agent's state map.
    pub token: String,
}

/// Request asking the agent to sweep out expired state tokens.
#[derive(Clone, Debug)]
pub struct CleanupExpired;
impl OAuth2Agent {
    /// Creates and starts the `oauth2_manager` agent and returns its handle.
    ///
    /// Registered handlers:
    /// - [`GenerateState`]: sweeps expired entries, mints and stores a new
    ///   [`OAuthState`] for the requested provider, and replies with it.
    /// - [`ValidateState`]: replies with a clone of the stored state when the
    ///   token is known and not expired, otherwise `None`. The token is NOT
    ///   removed here; send [`RemoveState`] to consume it.
    /// - [`RemoveState`]: removes the token if present.
    /// - [`CleanupExpired`]: sweeps expired entries and logs how many were dropped.
    ///
    /// # Errors
    ///
    /// Returns an error if the agent's `Ern` identifier or its `AgentConfig`
    /// cannot be constructed.
    pub async fn spawn(runtime: &mut AgentRuntime) -> anyhow::Result<AgentHandle> {
        let config = AgentConfig::new(Ern::with_root("oauth2_manager")?, None, None)?;
        let mut builder = runtime.new_agent_with_config::<Self>(config).await;

        builder
            .mutate_on::<GenerateState>(|agent, envelope| {
                let response_tx = envelope.message().response_tx.clone();
                let provider = envelope.message().provider;

                agent.model.cleanup_expired();
                let state = OAuthState::generate(provider);
                agent.model.states.insert(state.token.clone(), state.clone());

                tracing::debug!(
                    provider = ?provider,
                    token = %state.token,
                    "Generated OAuth2 state token"
                );

                AgentReply::from_async(async move {
                    let mut guard = response_tx.lock().await;
                    if let Some(tx) = guard.take() {
                        // The requester may have dropped its receiver; nothing to do then.
                        let _ = tx.send(state);
                    }
                })
            })
            .mutate_on::<ValidateState>(|agent, envelope| {
                let token = envelope.message().token.clone();
                let response_tx = envelope.message().response_tx.clone();

                // Look the token up BEFORE sweeping. cleanup_expired() removes
                // expired entries, so checking afterwards made the "expired"
                // branch unreachable. An expired token was then silently
                // indistinguishable from an unknown one.
                let state = match agent.model.states.get(&token) {
                    Some(state) if state.is_expired() => {
                        tracing::warn!(token = %token, "OAuth2 state token expired");
                        None
                    }
                    Some(state) => {
                        tracing::debug!(
                            token = %token,
                            provider = ?state.provider,
                            "Validated OAuth2 state token"
                        );
                        Some(state.clone())
                    }
                    None => {
                        tracing::warn!(token = %token, "Unknown OAuth2 state token");
                        None
                    }
                };

                agent.model.cleanup_expired();

                AgentReply::from_async(async move {
                    let mut guard = response_tx.lock().await;
                    if let Some(tx) = guard.take() {
                        // The requester may have dropped its receiver; nothing to do then.
                        let _ = tx.send(state);
                    }
                })
            })
            .mutate_on::<RemoveState>(|agent, envelope| {
                let token = envelope.message().token.clone();
                if agent.model.states.remove(&token).is_some() {
                    tracing::debug!(token = %token, "Removed OAuth2 state token");
                }
                AgentReply::immediate()
            })
            .mutate_on::<CleanupExpired>(|agent, _envelope| {
                let before = agent.model.states.len();
                agent.model.cleanup_expired();
                // cleanup_expired only removes entries, so this cannot underflow.
                let removed = before - agent.model.states.len();
                if removed > 0 {
                    tracing::debug!(
                        removed = removed,
                        remaining = agent.model.states.len(),
                        "Cleaned up expired OAuth2 state tokens"
                    );
                }
                AgentReply::immediate()
            })
            .after_start(|_agent| async {
                tracing::info!("OAuth2 manager agent started");
            })
            .after_stop(|agent| {
                // Read the count synchronously; the returned future must not borrow the agent.
                let token_count = agent.model.states.len();
                async move {
                    tracing::info!(
                        tokens = token_count,
                        "OAuth2 manager agent stopped"
                    );
                }
            });

        Ok(builder.start().await)
    }
}