acton_htmx/oauth2/
agent.rs

1//! OAuth2 state management agent
2//!
3//! This module provides an acton-reactive agent for managing OAuth2 state tokens
4//! and preventing CSRF attacks during the OAuth2 flow.
5
6use acton_reactive::prelude::*;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::SystemTime;
10use tokio::sync::{oneshot, Mutex};
11
12use super::types::{OAuthProvider, OAuthState};
13
14/// Type alias for response channels (web handler pattern)
15pub type ResponseChannel<T> = Arc<Mutex<Option<oneshot::Sender<T>>>>;
16
17/// OAuth2 state management agent
18///
19/// This agent stores and validates OAuth2 state tokens to prevent CSRF attacks.
20/// State tokens are ephemeral and expire after 10 minutes.
21#[derive(Debug, Default, Clone)]
22pub struct OAuth2Agent {
23    /// Map of state tokens to their metadata
24    states: HashMap<String, OAuthState>,
25}
26
27impl OAuth2Agent {
28    /// Clean up expired state tokens
29    fn cleanup_expired(&mut self) {
30        let now = SystemTime::now();
31        self.states.retain(|_, state| state.expires_at > now);
32    }
33}
34
35/// Message to generate a new OAuth2 state token (web handler)
36#[derive(Debug, Clone)]
37pub struct GenerateState {
38    /// Provider for this state
39    pub provider: OAuthProvider,
40    /// Response channel
41    pub response_tx: ResponseChannel<OAuthState>,
42}
43
44impl GenerateState {
45    /// Create a new generate state request with response channel
46    #[must_use]
47    pub fn new(provider: OAuthProvider) -> (Self, oneshot::Receiver<OAuthState>) {
48        let (tx, rx) = oneshot::channel();
49        (
50            Self {
51                provider,
52                response_tx: Arc::new(Mutex::new(Some(tx))),
53            },
54            rx,
55        )
56    }
57}
58
59/// Message to validate an OAuth2 state token (web handler)
60#[derive(Debug, Clone)]
61pub struct ValidateState {
62    /// State token to validate
63    pub token: String,
64    /// Response channel
65    pub response_tx: ResponseChannel<Option<OAuthState>>,
66}
67
68impl ValidateState {
69    /// Create a new validate state request with response channel
70    #[must_use]
71    pub fn new(token: String) -> (Self, oneshot::Receiver<Option<OAuthState>>) {
72        let (tx, rx) = oneshot::channel();
73        (
74            Self {
75                token,
76                response_tx: Arc::new(Mutex::new(Some(tx))),
77            },
78            rx,
79        )
80    }
81}
82
/// Request asking the agent to forget a state token, typically after the callback
/// that presented it has been handled successfully.
///
/// Fire-and-forget: the message carries no reply slot.
#[derive(Clone, Debug)]
pub struct RemoveState {
    /// Token string to delete from the agent's store.
    pub token: String,
}
89
/// Request asking the agent to purge every state token whose expiry has passed.
///
/// Carries no data and expects no reply.
#[derive(Clone, Debug)]
pub struct CleanupExpired;
93
94impl OAuth2Agent {
95    /// Spawn OAuth2 manager agent
96    ///
97    /// # Errors
98    ///
99    /// Returns error if agent configuration or spawning fails
100    pub async fn spawn(runtime: &mut AgentRuntime) -> anyhow::Result<AgentHandle> {
101        let config = AgentConfig::new(Ern::with_root("oauth2_manager")?, None, None)?;
102
103        let mut builder = runtime.new_agent_with_config::<Self>(config).await;
104
105        // Configure handlers using mutate_on (all operations mutate state)
106        builder
107            .mutate_on::<GenerateState>(|agent, envelope| {
108                let response_tx = envelope.message().response_tx.clone();
109                let provider = envelope.message().provider;
110
111                // Clean up expired tokens periodically
112                agent.model.cleanup_expired();
113
114                // Generate and store state token
115                let state = OAuthState::generate(provider);
116                agent.model.states.insert(state.token.clone(), state.clone());
117
118                tracing::debug!(
119                    provider = ?provider,
120                    token = %state.token,
121                    "Generated OAuth2 state token"
122                );
123
124                AgentReply::from_async(async move {
125                    let mut guard = response_tx.lock().await;
126                    if let Some(tx) = guard.take() {
127                        let _ = tx.send(state);
128                    }
129                })
130            })
131            .mutate_on::<ValidateState>(|agent, envelope| {
132                let token = envelope.message().token.clone();
133                let response_tx = envelope.message().response_tx.clone();
134
135                // Clean up expired tokens
136                agent.model.cleanup_expired();
137
138                // Validate state token
139                let state = agent.model.states.get(&token).and_then(|state| {
140                    if state.is_expired() {
141                        tracing::warn!(token = %token, "OAuth2 state token expired");
142                        None
143                    } else {
144                        tracing::debug!(
145                            token = %token,
146                            provider = ?state.provider,
147                            "Validated OAuth2 state token"
148                        );
149                        Some(state.clone())
150                    }
151                });
152
153                AgentReply::from_async(async move {
154                    let mut guard = response_tx.lock().await;
155                    if let Some(tx) = guard.take() {
156                        let _ = tx.send(state);
157                    }
158                })
159            })
160            .mutate_on::<RemoveState>(|agent, envelope| {
161                let token = envelope.message().token.clone();
162
163                if agent.model.states.remove(&token).is_some() {
164                    tracing::debug!(token = %token, "Removed OAuth2 state token");
165                }
166
167                AgentReply::immediate()
168            })
169            .mutate_on::<CleanupExpired>(|agent, _envelope| {
170                let before = agent.model.states.len();
171                agent.model.cleanup_expired();
172                let removed = before - agent.model.states.len();
173
174                if removed > 0 {
175                    tracing::debug!(
176                        removed = removed,
177                        remaining = agent.model.states.len(),
178                        "Cleaned up expired OAuth2 state tokens"
179                    );
180                }
181
182                AgentReply::immediate()
183            })
184            .after_start(|_agent| async {
185                tracing::info!("OAuth2 manager agent started");
186            })
187            .after_stop(|agent| {
188                let token_count = agent.model.states.len();
189                async move {
190                    tracing::info!(
191                        tokens = token_count,
192                        "OAuth2 manager agent stopped"
193                    );
194                }
195            });
196
197        Ok(builder.start().await)
198    }
199}