acton_htmx/middleware/
csrf.rs

//! CSRF middleware for protection against Cross-Site Request Forgery attacks
//!
//! Provides middleware that validates CSRF tokens on state-changing requests
//! (every method except GET, HEAD, OPTIONS, and TRACE — e.g. POST, PUT, DELETE,
//! PATCH). Integrates with the `CsrfManagerAgent` for token storage and validation.
//!
//! # Security Features
//!
//! - Automatic token validation on unsafe (state-changing) HTTP methods
//! - Token rotation after successful validation (delegated to the `CsrfManagerAgent`;
//!   this middleware does not rotate tokens itself)
//! - 403 Forbidden response on validation failure
//! - Token read from a configurable request header (form-field extraction is not
//!   yet implemented)
//! - Session-based token storage
14
15use crate::agents::{CsrfToken, ValidateToken};
16use crate::auth::session::SessionId;
17use crate::state::ActonHtmxState;
18use acton_reactive::prelude::{AgentHandle, AgentHandleInterface};
19use axum::{
20    body::Body,
21    extract::Request,
22    http::{Method, StatusCode},
23    response::{IntoResponse, Response},
24};
25use std::sync::Arc;
26use std::task::{Context, Poll};
27use std::time::Duration;
28use tower::{Layer, Service};
29
/// Default name of the HTTP request header that carries the CSRF token.
pub const CSRF_HEADER_NAME: &str = "x-csrf-token";

/// Default name of the HTML form field intended to carry the CSRF token.
pub const CSRF_FORM_FIELD: &str = "_csrf_token";
35
/// Settings consumed by [`CsrfLayer`] / [`CsrfMiddleware`].
///
/// Start from [`CsrfConfig::new`] (or `Default`) and add exemptions with the
/// `skip_path` / `skip_paths` builder methods.
#[derive(Clone, Debug)]
pub struct CsrfConfig {
    /// Request header the CSRF token is read from (default: `"x-csrf-token"`).
    pub header_name: String,
    /// Form field name for the CSRF token (default: `"_csrf_token"`).
    ///
    /// NOTE: the middleware in this module currently reads the token only from
    /// `header_name`; this field is not consulted during validation.
    pub form_field: String,
    /// Timeout, in milliseconds, for communication with the CSRF manager agent
    /// (default: 100). An expired timeout is treated as a validation failure.
    pub agent_timeout_ms: u64,
    /// Request paths that bypass CSRF validation (e.g. webhooks, health checks).
    ///
    /// Matching is exact equality against the request URI path — no prefixes,
    /// patterns, or trailing-slash normalisation.
    pub skip_paths: Vec<String>,
}
48
49impl Default for CsrfConfig {
50    fn default() -> Self {
51        Self {
52            header_name: CSRF_HEADER_NAME.to_string(),
53            form_field: CSRF_FORM_FIELD.to_string(),
54            agent_timeout_ms: 100,
55            skip_paths: vec![],
56        }
57    }
58}
59
60impl CsrfConfig {
61    /// Create new CSRF config with default values
62    #[must_use]
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    /// Add a path to skip CSRF validation
68    #[must_use]
69    pub fn skip_path(mut self, path: impl Into<String>) -> Self {
70        self.skip_paths.push(path.into());
71        self
72    }
73
74    /// Add multiple paths to skip CSRF validation
75    #[must_use]
76    pub fn skip_paths(mut self, paths: Vec<String>) -> Self {
77        self.skip_paths.extend(paths);
78        self
79    }
80}
81
82/// Layer for CSRF middleware
83///
84/// Requires both `SessionId` and CSRF manager to be available.
85#[derive(Clone)]
86pub struct CsrfLayer {
87    config: CsrfConfig,
88    csrf_manager: AgentHandle,
89}
90
91impl std::fmt::Debug for CsrfLayer {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("CsrfLayer")
94            .field("config", &self.config)
95            .field("csrf_manager", &"AgentHandle")
96            .finish()
97    }
98}
99
100impl CsrfLayer {
101    /// Create new CSRF layer with CSRF manager from state
102    #[must_use]
103    pub fn new(state: &ActonHtmxState) -> Self {
104        Self {
105            config: CsrfConfig::default(),
106            csrf_manager: state.csrf_manager().clone(),
107        }
108    }
109
110    /// Create CSRF layer with custom configuration
111    #[must_use]
112    pub fn with_config(state: &ActonHtmxState, config: CsrfConfig) -> Self {
113        Self {
114            config,
115            csrf_manager: state.csrf_manager().clone(),
116        }
117    }
118
119    /// Create CSRF layer from an existing agent handle
120    #[must_use]
121    pub fn from_handle(csrf_manager: AgentHandle) -> Self {
122        Self {
123            config: CsrfConfig::default(),
124            csrf_manager,
125        }
126    }
127
128    /// Create CSRF layer from handle with custom configuration
129    #[must_use]
130    pub const fn from_handle_with_config(csrf_manager: AgentHandle, config: CsrfConfig) -> Self {
131        Self {
132            config,
133            csrf_manager,
134        }
135    }
136}
137
138impl<S> Layer<S> for CsrfLayer {
139    type Service = CsrfMiddleware<S>;
140
141    fn layer(&self, inner: S) -> Self::Service {
142        CsrfMiddleware {
143            inner,
144            config: Arc::new(self.config.clone()),
145            csrf_manager: self.csrf_manager.clone(),
146        }
147    }
148}
149
150/// CSRF middleware that validates tokens on state-changing requests
151///
152/// Automatically validates CSRF tokens from the `CsrfManagerAgent` on
153/// POST, PUT, DELETE, and PATCH requests.
154#[derive(Clone)]
155pub struct CsrfMiddleware<S> {
156    inner: S,
157    config: Arc<CsrfConfig>,
158    csrf_manager: AgentHandle,
159}
160
161impl<S: std::fmt::Debug> std::fmt::Debug for CsrfMiddleware<S> {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        f.debug_struct("CsrfMiddleware")
164            .field("inner", &self.inner)
165            .field("config", &self.config)
166            .field("csrf_manager", &"AgentHandle")
167            .finish()
168    }
169}
170
171impl<S> Service<Request> for CsrfMiddleware<S>
172where
173    S: Service<Request, Response = Response<Body>> + Clone + Send + 'static,
174    S::Future: Send + 'static,
175{
176    type Response = Response<Body>;
177    type Error = S::Error;
178    type Future = std::pin::Pin<
179        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
180    >;
181
182    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183        self.inner.poll_ready(cx)
184    }
185
186    fn call(&mut self, req: Request) -> Self::Future {
187        let config = self.config.clone();
188        let csrf_manager = self.csrf_manager.clone();
189        let mut inner = self.inner.clone();
190        let timeout = Duration::from_millis(config.agent_timeout_ms);
191
192        // Skip CSRF validation for idempotent methods
193        if is_method_safe(req.method()) {
194            return Box::pin(inner.call(req));
195        }
196
197        // Skip CSRF validation for configured paths
198        let path = req.uri().path().to_string();
199        if config.skip_paths.iter().any(|skip| skip == &path) {
200            return Box::pin(inner.call(req));
201        }
202
203        // Get session ID from request extensions (set by SessionMiddleware)
204        let Some(session_id) = req.extensions().get::<SessionId>().cloned() else {
205            tracing::warn!("CSRF middleware requires SessionMiddleware to be applied first");
206            return Box::pin(async move {
207                Ok(csrf_validation_error(
208                    "Session not found - ensure SessionMiddleware is applied",
209                ))
210            });
211        };
212
213        // Extract CSRF token from request header
214        let Some(token) = extract_csrf_token(&req, &config) else {
215            let method = req.method().clone();
216            tracing::warn!("CSRF token missing for {} {}", method, path);
217            return Box::pin(async move { Ok(csrf_validation_error("CSRF token missing")) });
218        };
219
220        Box::pin(async move {
221            // Validate token with CSRF manager
222            let (validate_request, rx) = ValidateToken::new(session_id, token);
223            csrf_manager.send(validate_request).await;
224
225            let is_valid = match tokio::time::timeout(timeout, rx).await {
226                Ok(Ok(valid)) => valid,
227                Ok(Err(_)) => {
228                    tracing::error!("CSRF validation channel error");
229                    false
230                }
231                Err(_) => {
232                    tracing::error!("CSRF validation timeout");
233                    false
234                }
235            };
236
237            if !is_valid {
238                tracing::warn!("CSRF token validation failed");
239                return Ok(csrf_validation_error("CSRF token validation failed"));
240            }
241
242            // Token validated - proceed with request
243            inner.call(req).await
244        })
245    }
246}
247
248/// Check if HTTP method is considered safe (doesn't modify state)
249const fn is_method_safe(method: &Method) -> bool {
250    matches!(
251        *method,
252        Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE
253    )
254}
255
256/// Extract CSRF token from request (header or form data)
257fn extract_csrf_token(req: &Request, config: &CsrfConfig) -> Option<CsrfToken> {
258    // First, try to get token from header
259    if let Some(token_value) = req.headers().get(&config.header_name) {
260        if let Ok(token_str) = token_value.to_str() {
261            return Some(CsrfToken::from_string(token_str.to_string()));
262        }
263    }
264
265    // If not in header, check if it's form data
266    // Note: This is a simplified implementation. In production, you'd want to
267    // properly parse form data without consuming the body.
268    // For now, we'll just check the header.
269
270    None
271}
272
273/// Create a 403 Forbidden response for CSRF validation failure
274fn csrf_validation_error(message: &str) -> Response<Body> {
275    let body = if cfg!(debug_assertions) {
276        // In development, provide detailed error message
277        format!("CSRF validation failed: {message}")
278    } else {
279        // In production, use generic error message
280        "Forbidden".to_string()
281    };
282
283    (StatusCode::FORBIDDEN, body).into_response()
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_csrf_config_default() {
        let config = CsrfConfig::default();
        assert_eq!(config.header_name, CSRF_HEADER_NAME);
        assert_eq!(config.form_field, CSRF_FORM_FIELD);
        assert_eq!(config.agent_timeout_ms, 100);
        assert!(config.skip_paths.is_empty());
    }

    #[test]
    fn test_csrf_config_skip_path() {
        let config = CsrfConfig::new().skip_path("/webhooks/github");
        assert_eq!(config.skip_paths.len(), 1);
        assert_eq!(config.skip_paths[0], "/webhooks/github");
    }

    #[test]
    fn test_csrf_config_skip_paths() {
        let config = CsrfConfig::new().skip_paths(vec![
            "/health".to_string(),
            "/metrics".to_string(),
        ]);
        assert_eq!(config.skip_paths.len(), 2);
        assert!(config.skip_paths.contains(&"/health".to_string()));
        assert!(config.skip_paths.contains(&"/metrics".to_string()));
    }

    #[test]
    fn test_csrf_config_skip_builders_chain_in_order() {
        let config = CsrfConfig::new()
            .skip_path("/webhooks/github")
            .skip_paths(vec!["/health".to_string()]);
        assert_eq!(
            config.skip_paths,
            vec!["/webhooks/github".to_string(), "/health".to_string()]
        );
    }

    #[test]
    fn test_is_method_safe() {
        assert!(is_method_safe(&Method::GET));
        assert!(is_method_safe(&Method::HEAD));
        assert!(is_method_safe(&Method::OPTIONS));
        assert!(is_method_safe(&Method::TRACE));

        assert!(!is_method_safe(&Method::POST));
        assert!(!is_method_safe(&Method::PUT));
        assert!(!is_method_safe(&Method::DELETE));
        assert!(!is_method_safe(&Method::PATCH));
    }

    #[test]
    fn test_connect_and_extension_methods_are_not_safe() {
        assert!(!is_method_safe(&Method::CONNECT));
        let purge = Method::from_bytes(b"PURGE").expect("PURGE is a valid method token");
        assert!(!is_method_safe(&purge));
    }

    #[test]
    fn test_csrf_validation_error_is_forbidden() {
        let response = csrf_validation_error("boom");
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn test_extract_csrf_token_from_default_header() {
        let config = CsrfConfig::default();
        let req = axum::http::Request::builder()
            .header(CSRF_HEADER_NAME, "abc123")
            .body(Body::empty())
            .expect("valid request");
        assert!(extract_csrf_token(&req, &config).is_some());
    }

    #[test]
    fn test_extract_csrf_token_missing_header() {
        let config = CsrfConfig::default();
        let req = axum::http::Request::builder()
            .body(Body::empty())
            .expect("valid request");
        assert!(extract_csrf_token(&req, &config).is_none());
    }

    #[test]
    fn test_extract_csrf_token_honours_custom_header_name() {
        let custom = CsrfConfig {
            header_name: "x-custom-csrf".to_string(),
            ..CsrfConfig::default()
        };
        let req = axum::http::Request::builder()
            .header("x-custom-csrf", "abc123")
            .body(Body::empty())
            .expect("valid request");
        assert!(extract_csrf_token(&req, &custom).is_some());
        // The default header name is not present on this request.
        assert!(extract_csrf_token(&req, &CsrfConfig::default()).is_none());
    }
}