Skip to main content

ferro_rs/csrf/
middleware.rs

1//! CSRF protection middleware
2
3use crate::http::{HttpResponse, Response};
4use crate::middleware::{Middleware, Next};
5use crate::session::get_csrf_token;
6use crate::Request;
7use async_trait::async_trait;
8
/// CSRF protection middleware
///
/// Validates CSRF tokens on state-changing requests (POST, PUT, PATCH, DELETE).
///
/// # Token Sources
///
/// The middleware looks for the CSRF token in the following request headers,
/// in order:
/// 1. `X-CSRF-TOKEN` header (used by Inertia.js)
/// 2. `X-XSRF-TOKEN` header (Laravel convention)
///
/// Only the first header that is present is compared against the session
/// token. The request body is not inspected, so a traditional `_token` form
/// field is currently **not** accepted as a token source.
///
/// # Responses
///
/// * No CSRF token available from the session: `500` with a JSON message.
/// * Missing or mismatched token: `419` (Laravel convention) with a JSON message.
///
/// # Usage
///
/// ```rust,ignore
/// use ferro_rs::{global_middleware, CsrfMiddleware};
///
/// global_middleware!(CsrfMiddleware::new());
/// ```
#[derive(Debug, Clone)]
pub struct CsrfMiddleware {
    /// HTTP methods that require CSRF validation
    protected_methods: Vec<&'static str>,
    /// Path patterns to exclude from CSRF validation (e.g., webhooks).
    /// A trailing `*` turns a pattern into a prefix match; otherwise it must
    /// equal the request path exactly.
    except: Vec<String>,
}

impl CsrfMiddleware {
    /// Create a new CSRF middleware with default settings.
    ///
    /// Protects `POST`, `PUT`, `PATCH` and `DELETE`; no paths are excluded.
    pub fn new() -> Self {
        Self {
            protected_methods: vec!["POST", "PUT", "PATCH", "DELETE"],
            except: Vec::new(),
        }
    }

    /// Add paths to exclude from CSRF validation.
    ///
    /// Useful for webhooks or API endpoints that use other authentication.
    /// A pattern ending in `*` excludes every path that starts with the text
    /// before the `*`; any other pattern must match the request path exactly.
    ///
    /// Repeated calls accumulate patterns; earlier exclusions are kept.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let csrf = CsrfMiddleware::new()
    ///     .except(vec!["/webhooks/*", "/api/external/*"]);
    /// ```
    #[must_use]
    pub fn except(mut self, paths: Vec<impl Into<String>>) -> Self {
        // Extend rather than overwrite so chained calls don't silently drop
        // previously configured exclusions.
        self.except.extend(paths.into_iter().map(|p| p.into()));
        self
    }

    /// Check if a path should be excluded from CSRF validation.
    ///
    /// Returns `true` when any configured pattern matches `path` (prefix match
    /// for patterns ending in `*`, exact match otherwise).
    fn is_excluded(&self, path: &str) -> bool {
        self.except
            .iter()
            .any(|pattern| match pattern.strip_suffix('*') {
                Some(prefix) => path.starts_with(prefix),
                None => pattern == path,
            })
    }
}

impl Default for CsrfMiddleware {
    /// Equivalent to [`CsrfMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
81
82#[async_trait]
83impl Middleware for CsrfMiddleware {
84    async fn handle(&self, request: Request, next: Next) -> Response {
85        let method = request.method().as_str();
86
87        // Only validate state-changing requests
88        if !self.protected_methods.contains(&method) {
89            return next(request).await;
90        }
91
92        // Check if path is excluded
93        if self.is_excluded(request.path()) {
94            return next(request).await;
95        }
96
97        // Get expected token from session
98        let expected_token = match get_csrf_token() {
99            Some(token) => token,
100            None => {
101                return Err(HttpResponse::json(serde_json::json!({
102                    "message": "Session not found. CSRF validation failed."
103                }))
104                .status(500));
105            }
106        };
107
108        // Get provided token from request
109        // Check headers first (Inertia.js and AJAX)
110        let provided_token = request
111            .header("X-CSRF-TOKEN")
112            .or_else(|| request.header("X-XSRF-TOKEN"))
113            .map(|s| s.to_string());
114
115        match provided_token {
116            Some(token) if constant_time_compare(&token, &expected_token) => {
117                // Token is valid
118                next(request).await
119            }
120            _ => {
121                // Token mismatch or missing
122                // Return 419 status (Laravel convention)
123                Err(HttpResponse::json(serde_json::json!({
124                    "message": "CSRF token mismatch."
125                }))
126                .status(419))
127            }
128        }
129    }
130}
131
/// Compare two strings without stopping at the first differing byte.
///
/// Every byte pair is XOR-ed and the differences are OR-ed into a single
/// accumulator, so the loop always walks the whole input. This keeps response
/// time from revealing how much of a guessed token is correct.
///
/// Inputs of different lengths are rejected immediately. Only the length is
/// observable this way, not the content.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() == rhs.len() {
        let mut diff = 0u8;
        for (x, y) in lhs.iter().zip(rhs) {
            diff |= x ^ y;
        }
        diff == 0
    } else {
        false
    }
}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of (left, right, expected equality) cases for the comparison helper.
    #[test]
    fn test_constant_time_compare() {
        let cases = [
            ("abc123", "abc123", true),
            ("abc123", "abc124", false),
            ("abc123", "abc12", false),
            ("", "a", false),
        ];
        for (a, b, expected) in cases {
            assert_eq!(
                constant_time_compare(a, b),
                expected,
                "constant_time_compare({a:?}, {b:?})"
            );
        }
    }

    /// Wildcard patterns match by prefix; plain patterns must match exactly.
    #[test]
    fn test_is_excluded() {
        let csrf = CsrfMiddleware::new().except(vec!["/webhooks/*", "/api/public"]);

        for path in ["/webhooks/stripe", "/webhooks/github/events", "/api/public"] {
            assert!(csrf.is_excluded(path), "{path} should be excluded");
        }
        for path in ["/api/private", "/login"] {
            assert!(!csrf.is_excluded(path), "{path} should not be excluded");
        }
    }
}