acton_htmx/oauth2/
agent.rs1use acton_reactive::prelude::*;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::SystemTime;
10use tokio::sync::{oneshot, Mutex};
11
12use super::types::{OAuthProvider, OAuthState};
13
14pub type ResponseChannel<T> = Arc<Mutex<Option<oneshot::Sender<T>>>>;
16
17#[derive(Debug, Default, Clone)]
22pub struct OAuth2Agent {
23 states: HashMap<String, OAuthState>,
25}
26
27impl OAuth2Agent {
28 fn cleanup_expired(&mut self) {
30 let now = SystemTime::now();
31 self.states.retain(|_, state| state.expires_at > now);
32 }
33}
34
35#[derive(Debug, Clone)]
37pub struct GenerateState {
38 pub provider: OAuthProvider,
40 pub response_tx: ResponseChannel<OAuthState>,
42}
43
44impl GenerateState {
45 #[must_use]
47 pub fn new(provider: OAuthProvider) -> (Self, oneshot::Receiver<OAuthState>) {
48 let (tx, rx) = oneshot::channel();
49 (
50 Self {
51 provider,
52 response_tx: Arc::new(Mutex::new(Some(tx))),
53 },
54 rx,
55 )
56 }
57}
58
59#[derive(Debug, Clone)]
61pub struct ValidateState {
62 pub token: String,
64 pub response_tx: ResponseChannel<Option<OAuthState>>,
66}
67
68impl ValidateState {
69 #[must_use]
71 pub fn new(token: String) -> (Self, oneshot::Receiver<Option<OAuthState>>) {
72 let (tx, rx) = oneshot::channel();
73 (
74 Self {
75 token,
76 response_tx: Arc::new(Mutex::new(Some(tx))),
77 },
78 rx,
79 )
80 }
81}
82
/// Request asking the agent to delete `token` from its store.
///
/// The message carries no reply channel. Removing an unknown token is a no-op.
#[derive(Clone, Debug)]
pub struct RemoveState {
    /// Token string to remove.
    pub token: String,
}

/// Request asking the agent to purge all expired state tokens.
///
/// It can be sent periodically. The agent also sweeps expired entries while
/// handling `GenerateState` and `ValidateState`.
#[derive(Clone, Debug)]
pub struct CleanupExpired;

94impl OAuth2Agent {
95 pub async fn spawn(runtime: &mut AgentRuntime) -> anyhow::Result<AgentHandle> {
101 let config = AgentConfig::new(Ern::with_root("oauth2_manager")?, None, None)?;
102
103 let mut builder = runtime.new_agent_with_config::<Self>(config).await;
104
105 builder
107 .mutate_on::<GenerateState>(|agent, envelope| {
108 let response_tx = envelope.message().response_tx.clone();
109 let provider = envelope.message().provider;
110
111 agent.model.cleanup_expired();
113
114 let state = OAuthState::generate(provider);
116 agent.model.states.insert(state.token.clone(), state.clone());
117
118 tracing::debug!(
119 provider = ?provider,
120 token = %state.token,
121 "Generated OAuth2 state token"
122 );
123
124 AgentReply::from_async(async move {
125 let mut guard = response_tx.lock().await;
126 if let Some(tx) = guard.take() {
127 let _ = tx.send(state);
128 }
129 })
130 })
131 .mutate_on::<ValidateState>(|agent, envelope| {
132 let token = envelope.message().token.clone();
133 let response_tx = envelope.message().response_tx.clone();
134
135 agent.model.cleanup_expired();
137
138 let state = agent.model.states.get(&token).and_then(|state| {
140 if state.is_expired() {
141 tracing::warn!(token = %token, "OAuth2 state token expired");
142 None
143 } else {
144 tracing::debug!(
145 token = %token,
146 provider = ?state.provider,
147 "Validated OAuth2 state token"
148 );
149 Some(state.clone())
150 }
151 });
152
153 AgentReply::from_async(async move {
154 let mut guard = response_tx.lock().await;
155 if let Some(tx) = guard.take() {
156 let _ = tx.send(state);
157 }
158 })
159 })
160 .mutate_on::<RemoveState>(|agent, envelope| {
161 let token = envelope.message().token.clone();
162
163 if agent.model.states.remove(&token).is_some() {
164 tracing::debug!(token = %token, "Removed OAuth2 state token");
165 }
166
167 AgentReply::immediate()
168 })
169 .mutate_on::<CleanupExpired>(|agent, _envelope| {
170 let before = agent.model.states.len();
171 agent.model.cleanup_expired();
172 let removed = before - agent.model.states.len();
173
174 if removed > 0 {
175 tracing::debug!(
176 removed = removed,
177 remaining = agent.model.states.len(),
178 "Cleaned up expired OAuth2 state tokens"
179 );
180 }
181
182 AgentReply::immediate()
183 })
184 .after_start(|_agent| async {
185 tracing::info!("OAuth2 manager agent started");
186 })
187 .after_stop(|agent| {
188 let token_count = agent.model.states.len();
189 async move {
190 tracing::info!(
191 tokens = token_count,
192 "OAuth2 manager agent stopped"
193 );
194 }
195 });
196
197 Ok(builder.start().await)
198 }
199}