Skip to main content

actix_security_core/http/security/
csrf.rs

1//! CSRF (Cross-Site Request Forgery) Protection.
2//!
3//! # Spring Security Equivalent
4//! Similar to Spring Security's CSRF protection with `CsrfFilter`.
5//!
6//! # Features
7//! - Token-based CSRF protection
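//! - Session-based token storage out of the box; other backends (e.g. cookies) can be
//!   plugged in through the `CsrfTokenRepository` trait
//! - Configurable ignored paths and methods
//! - Token submission via a request header (AJAX) or a query parameter (forms)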
11//!
12//! # Example
13//! ```rust,ignore
14//! use actix_security_core::http::security::csrf::{CsrfProtection, CsrfConfig};
15//!
16//! // Create CSRF protection middleware
17//! let csrf = CsrfProtection::new(CsrfConfig::default());
18//!
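//! // actix executes middleware in reverse registration order, so the session
//! // middleware is registered after `csrf` to have the session loaded first.
//! App::new()
//!     .wrap(csrf)  // Add CSRF protection
//!     .wrap(session_middleware)
//!     .wrap(security_transform);
//!
//! // In templates, include the CSRF token. Note that the middleware reads the
//! // token only from the header or the query string, not from the request body,
//! // so for POST forms append it to the action URL (`?_csrf={{csrf_token}}`)
//! // rather than relying solely on a hidden input:
//! // <input type="hidden" name="_csrf" value="{{csrf_token}}">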
26//!
27//! // For AJAX, send the token in a header
28//! // X-CSRF-TOKEN: {{csrf_token}}
29//! ```
30
31use actix_session::SessionExt;
32use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
33use actix_web::http::Method;
34use actix_web::{body::EitherBody, Error, HttpMessage, HttpResponse};
35use futures_util::future::{ok, LocalBoxFuture, Ready};
36use rand::Rng;
37use regex::Regex;
38use std::rc::Rc;
39use std::sync::Arc;
40
41// =============================================================================
42// CSRF Token
43// =============================================================================
44
/// A CSRF token together with the names under which clients submit it.
///
/// # Spring Security Equivalent
/// Similar to `CsrfToken` in Spring Security.
#[derive(Debug, Clone)]
pub struct CsrfToken {
    /// The secret token value.
    pub token: String,
    /// Request header name clients (e.g. AJAX code) use to send the token.
    pub header_name: String,
    /// Request parameter name clients use to send the token.
    pub parameter_name: String,
}

impl CsrfToken {
    /// Wrap `token` using the default names: header `X-CSRF-TOKEN`, parameter `_csrf`.
    pub fn new(token: String) -> Self {
        Self::with_names(token, "X-CSRF-TOKEN", "_csrf")
    }

    /// Wrap `token` using the given header and parameter names.
    pub fn with_names(token: String, header_name: &str, parameter_name: &str) -> Self {
        let header_name = header_name.to_owned();
        let parameter_name = parameter_name.to_owned();
        CsrfToken {
            token,
            header_name,
            parameter_name,
        }
    }

    /// The raw token value.
    pub fn value(&self) -> &str {
        self.token.as_str()
    }

    /// The header name under which the token is expected.
    pub fn header_name(&self) -> &str {
        self.header_name.as_str()
    }

    /// The parameter name under which the token is expected.
    pub fn parameter_name(&self) -> &str {
        self.parameter_name.as_str()
    }
}
93
94// =============================================================================
95// CSRF Token Repository Trait
96// =============================================================================
97
98/// Trait for storing and retrieving CSRF tokens.
99///
100/// # Spring Security Equivalent
101/// Similar to `CsrfTokenRepository` in Spring Security.
102pub trait CsrfTokenRepository: Send + Sync {
103    /// Generate a new CSRF token.
104    fn generate_token(&self) -> CsrfToken;
105
106    /// Save token to storage.
107    fn save_token(&self, req: &ServiceRequest, token: &CsrfToken) -> Result<(), CsrfError>;
108
109    /// Load token from storage.
110    fn load_token(&self, req: &ServiceRequest) -> Option<CsrfToken>;
111}
112
113// =============================================================================
114// Session CSRF Token Repository
115// =============================================================================
116
/// CSRF token repository backed by the actix session.
///
/// The token value is kept as a string under a configurable session key.
///
/// # Spring Security Equivalent
/// Similar to `HttpSessionCsrfTokenRepository` in Spring Security.
#[derive(Clone)]
pub struct SessionCsrfTokenRepository {
    /// Key under which the token string is stored in the session.
    session_key: String,
    /// Header name attached to tokens produced by this repository.
    header_name: String,
    /// Parameter name attached to tokens produced by this repository.
    parameter_name: String,
}
132
133impl Default for SessionCsrfTokenRepository {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl SessionCsrfTokenRepository {
140    /// Create a new session-based repository.
141    pub fn new() -> Self {
142        Self {
143            session_key: "CSRF_TOKEN".to_string(),
144            header_name: "X-CSRF-TOKEN".to_string(),
145            parameter_name: "_csrf".to_string(),
146        }
147    }
148
149    /// Set the session key.
150    pub fn session_key(mut self, key: &str) -> Self {
151        self.session_key = key.to_string();
152        self
153    }
154
155    /// Set the header name.
156    pub fn header_name(mut self, name: &str) -> Self {
157        self.header_name = name.to_string();
158        self
159    }
160
161    /// Set the parameter name.
162    pub fn parameter_name(mut self, name: &str) -> Self {
163        self.parameter_name = name.to_string();
164        self
165    }
166
167    /// Generate a random token value.
168    fn generate_token_value(&self) -> String {
169        let mut rng = rand::thread_rng();
170        let bytes: [u8; 32] = rng.gen();
171        hex::encode(&bytes)
172    }
173}
174
175impl CsrfTokenRepository for SessionCsrfTokenRepository {
176    fn generate_token(&self) -> CsrfToken {
177        CsrfToken::with_names(
178            self.generate_token_value(),
179            &self.header_name,
180            &self.parameter_name,
181        )
182    }
183
184    fn save_token(&self, req: &ServiceRequest, token: &CsrfToken) -> Result<(), CsrfError> {
185        let session = req.get_session();
186        session
187            .insert(&self.session_key, &token.token)
188            .map_err(|e| CsrfError::StorageError(e.to_string()))
189    }
190
191    fn load_token(&self, req: &ServiceRequest) -> Option<CsrfToken> {
192        let session = req.get_session();
193        session
194            .get::<String>(&self.session_key)
195            .ok()
196            .flatten()
197            .map(|token| CsrfToken::with_names(token, &self.header_name, &self.parameter_name))
198    }
199}
200
201// =============================================================================
202// CSRF Configuration
203// =============================================================================
204
205/// CSRF protection configuration.
206///
207/// # Spring Security Equivalent
208/// Similar to `CsrfConfigurer` in Spring Security.
209#[derive(Clone)]
210pub struct CsrfConfig {
211    /// Token repository
212    repository: Arc<dyn CsrfTokenRepository>,
213    /// Methods that require CSRF protection
214    protected_methods: Vec<Method>,
215    /// Paths to ignore (regex patterns)
216    ignored_paths: Vec<Regex>,
217    /// Header name for the token
218    header_name: String,
219    /// Parameter name for the token
220    parameter_name: String,
221}
222
223impl Default for CsrfConfig {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
229impl CsrfConfig {
230    /// Create a new CSRF configuration with default settings.
231    ///
232    /// By default:
233    /// - Uses session-based token storage
234    /// - Protects POST, PUT, DELETE, PATCH methods
235    /// - Token header: X-CSRF-TOKEN
236    /// - Token parameter: _csrf
237    pub fn new() -> Self {
238        Self {
239            repository: Arc::new(SessionCsrfTokenRepository::new()),
240            protected_methods: vec![Method::POST, Method::PUT, Method::DELETE, Method::PATCH],
241            ignored_paths: Vec::new(),
242            header_name: "X-CSRF-TOKEN".to_string(),
243            parameter_name: "_csrf".to_string(),
244        }
245    }
246
247    /// Set a custom token repository.
248    pub fn repository<R: CsrfTokenRepository + 'static>(mut self, repository: R) -> Self {
249        self.repository = Arc::new(repository);
250        self
251    }
252
253    /// Set the methods that require CSRF protection.
254    pub fn protected_methods(mut self, methods: Vec<Method>) -> Self {
255        self.protected_methods = methods;
256        self
257    }
258
259    /// Add a path pattern to ignore.
260    ///
261    /// # Example
262    /// ```rust,ignore
263    /// let config = CsrfConfig::new()
264    ///     .ignore_path("/api/.*")  // Ignore all API paths
265    ///     .ignore_path("/webhook");
266    /// ```
267    pub fn ignore_path(mut self, pattern: &str) -> Self {
268        if let Ok(regex) = Regex::new(pattern) {
269            self.ignored_paths.push(regex);
270        }
271        self
272    }
273
274    /// Set the header name for the token.
275    pub fn header_name(mut self, name: &str) -> Self {
276        self.header_name = name.to_string();
277        self
278    }
279
280    /// Set the parameter name for the token.
281    pub fn parameter_name(mut self, name: &str) -> Self {
282        self.parameter_name = name.to_string();
283        self
284    }
285
286    /// Check if a path should be ignored.
287    fn is_path_ignored(&self, path: &str) -> bool {
288        self.ignored_paths.iter().any(|regex| regex.is_match(path))
289    }
290
291    /// Check if a method requires CSRF protection.
292    fn requires_protection(&self, method: &Method) -> bool {
293        self.protected_methods.contains(method)
294    }
295}
296
297// =============================================================================
298// CSRF Protection Middleware
299// =============================================================================
300
301/// CSRF protection middleware.
302///
303/// # Spring Security Equivalent
304/// Similar to `CsrfFilter` in Spring Security.
305///
306/// # Behavior
307/// 1. For safe methods (GET, HEAD, OPTIONS, TRACE): Generate and store token
308/// 2. For state-changing methods (POST, PUT, DELETE, PATCH): Validate token
309/// 3. Token is available in request extensions as `CsrfToken`
310///
311/// # Example
312/// ```rust,ignore
313/// use actix_security_core::http::security::csrf::{CsrfProtection, CsrfConfig};
314///
315/// App::new()
316///     .wrap(CsrfProtection::new(CsrfConfig::default()))
317/// ```
318#[derive(Clone)]
319pub struct CsrfProtection {
320    config: CsrfConfig,
321}
322
323impl CsrfProtection {
324    /// Create new CSRF protection with the given configuration.
325    pub fn new(config: CsrfConfig) -> Self {
326        Self { config }
327    }
328
329    /// Create with default configuration.
330    pub fn default_config() -> Self {
331        Self::new(CsrfConfig::default())
332    }
333}
334
335impl<S, B> Transform<S, ServiceRequest> for CsrfProtection
336where
337    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
338    S::Future: 'static,
339    B: 'static,
340{
341    type Response = ServiceResponse<EitherBody<B>>;
342    type Error = Error;
343    type Transform = CsrfMiddleware<S>;
344    type InitError = ();
345    type Future = Ready<Result<Self::Transform, Self::InitError>>;
346
347    fn new_transform(&self, service: S) -> Self::Future {
348        ok(CsrfMiddleware {
349            service: Rc::new(service),
350            config: self.config.clone(),
351        })
352    }
353}
354
355/// CSRF middleware service.
356pub struct CsrfMiddleware<S> {
357    service: Rc<S>,
358    config: CsrfConfig,
359}
360
361impl<S, B> Service<ServiceRequest> for CsrfMiddleware<S>
362where
363    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
364    S::Future: 'static,
365    B: 'static,
366{
367    type Response = ServiceResponse<EitherBody<B>>;
368    type Error = Error;
369    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
370
371    forward_ready!(service);
372
373    fn call(&self, req: ServiceRequest) -> Self::Future {
374        let service = self.service.clone();
375        let config = self.config.clone();
376
377        Box::pin(async move {
378            let path = req.path().to_string();
379            let method = req.method().clone();
380
381            // Check if path is ignored
382            if config.is_path_ignored(&path) {
383                let res = service.call(req).await?;
384                return Ok(res.map_into_left_body());
385            }
386
387            // Load or generate token
388            let token = match config.repository.load_token(&req) {
389                Some(token) => token,
390                None => {
391                    let token = config.repository.generate_token();
392                    let _ = config.repository.save_token(&req, &token);
393                    token
394                }
395            };
396
397            // Store token in request extensions for handlers to use
398            req.extensions_mut().insert(token.clone());
399
400            // Check if method requires CSRF validation
401            if config.requires_protection(&method) {
402                // Get token from request (header or parameter)
403                let request_token = get_token_from_request(&req, &config);
404
405                match request_token {
406                    Some(submitted_token) if submitted_token == token.token => {
407                        // Token valid, proceed
408                        let res = service.call(req).await?;
409                        Ok(res.map_into_left_body())
410                    }
411                    Some(_) => {
412                        // Token mismatch
413                        let response = HttpResponse::Forbidden()
414                            .body("CSRF token mismatch")
415                            .map_into_right_body();
416                        Ok(req.into_response(response))
417                    }
418                    None => {
419                        // No token provided
420                        let response = HttpResponse::Forbidden()
421                            .body("CSRF token missing")
422                            .map_into_right_body();
423                        Ok(req.into_response(response))
424                    }
425                }
426            } else {
427                // Safe method, no validation needed
428                let res = service.call(req).await?;
429                Ok(res.map_into_left_body())
430            }
431        })
432    }
433}
434
435/// Extract CSRF token from request (header or query parameter).
436fn get_token_from_request(req: &ServiceRequest, config: &CsrfConfig) -> Option<String> {
437    // Try header first
438    if let Some(header_value) = req.headers().get(&config.header_name) {
439        if let Ok(token) = header_value.to_str() {
440            return Some(token.to_string());
441        }
442    }
443
444    // Try query string
445    let query_string = req.query_string();
446    let param_prefix = format!("{}=", config.parameter_name);
447    for pair in query_string.split('&') {
448        if pair.starts_with(&param_prefix) {
449            return Some(pair[param_prefix.len()..].to_string());
450        }
451    }
452
453    None
454}
455
456// =============================================================================
457// CSRF Error
458// =============================================================================
459
/// Errors produced by CSRF token handling.
#[derive(Debug)]
pub enum CsrfError {
    /// No token was submitted.
    MissingToken,
    /// The submitted token is malformed.
    InvalidToken,
    /// The submitted token differs from the stored one.
    TokenMismatch,
    /// The token repository failed; carries the underlying error message.
    StorageError(String),
}

impl std::fmt::Display for CsrfError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::MissingToken => "CSRF token missing",
            Self::InvalidToken => "Invalid CSRF token",
            Self::TokenMismatch => "CSRF token mismatch",
            Self::StorageError(detail) => return write!(f, "CSRF storage error: {}", detail),
        };
        f.write_str(message)
    }
}

impl std::error::Error for CsrfError {}
485
486// =============================================================================
487// Helper for hex encoding (simple implementation)
488// =============================================================================
489
mod hex {
    /// Lowercase hexadecimal alphabet, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encode `bytes` as a lowercase hexadecimal string, two characters per byte
    /// (high nibble first).
    pub fn encode(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        bytes.iter().for_each(|&b| {
            out.push(char::from(DIGITS[usize::from(b >> 4)]));
            out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
        });
        out
    }
}
502
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_csrf_token() {
        let csrf = CsrfToken::new(String::from("test-token"));
        assert_eq!(
            (csrf.value(), csrf.header_name(), csrf.parameter_name()),
            ("test-token", "X-CSRF-TOKEN", "_csrf")
        );
    }

    #[test]
    fn test_csrf_token_custom_names() {
        let csrf =
            CsrfToken::with_names(String::from("test-token"), "X-Custom-CSRF", "csrf_token");
        assert_eq!(csrf.header_name(), "X-Custom-CSRF");
        assert_eq!(csrf.parameter_name(), "csrf_token");
    }

    #[test]
    fn test_csrf_config_default() {
        let cfg = CsrfConfig::default();
        assert_eq!(cfg.header_name, "X-CSRF-TOKEN");
        assert_eq!(cfg.parameter_name, "_csrf");
        for m in &[Method::POST, Method::PUT, Method::DELETE, Method::PATCH] {
            assert!(cfg.protected_methods.contains(m), "{} should be protected", m);
        }
        assert!(!cfg.protected_methods.contains(&Method::GET));
    }

    #[test]
    fn test_csrf_config_ignore_path() {
        let cfg = CsrfConfig::new()
            .ignore_path("/api/.*")
            .ignore_path("/webhook");

        for path in &["/api/users", "/api/posts/123", "/webhook"] {
            assert!(cfg.is_path_ignored(path), "{} should be ignored", path);
        }
        assert!(!cfg.is_path_ignored("/admin"));
    }

    #[test]
    fn test_csrf_config_protected_methods() {
        let cfg = CsrfConfig::new().protected_methods(vec![Method::POST]);

        assert!(cfg.requires_protection(&Method::POST));
        assert!(![Method::PUT, Method::GET]
            .iter()
            .any(|m| cfg.requires_protection(m)));
    }

    #[test]
    fn test_session_csrf_repository() {
        let token = SessionCsrfTokenRepository::new()
            .session_key("MY_CSRF")
            .header_name("X-My-CSRF")
            .parameter_name("my_csrf")
            .generate_token();

        assert_eq!(token.header_name(), "X-My-CSRF");
        assert_eq!(token.parameter_name(), "my_csrf");
        // 32 random bytes, two hex characters each.
        assert_eq!(token.token.len(), 64);
    }

    #[test]
    fn test_hex_encode() {
        let cases: [(&[u8], &str); 3] = [
            (&[0x00], "00"),
            (&[0xff], "ff"),
            (&[0xde, 0xad, 0xbe, 0xef], "deadbeef"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(hex::encode(input), *expected);
        }
    }
}