// actix_security_core/http/security/csrf.rs
use std::rc::Rc;
use std::sync::Arc;

use actix_session::SessionExt;
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
use actix_web::http::Method;
use actix_web::{body::EitherBody, Error, HttpMessage, HttpResponse};
use futures_util::future::{ok, LocalBoxFuture, Ready};
use rand::Rng;
use regex::Regex;
40
/// A CSRF token value bundled with the header and parameter names that go with it.
#[derive(Debug, Clone)]
pub struct CsrfToken {
    /// The secret token value.
    pub token: String,
    /// Header name associated with this token (default `X-CSRF-TOKEN`).
    pub header_name: String,
    /// Request-parameter name associated with this token (default `_csrf`).
    pub parameter_name: String,
}

impl CsrfToken {
    /// Wraps `token` using the default names: header `X-CSRF-TOKEN`, parameter `_csrf`.
    pub fn new(token: String) -> Self {
        Self::with_names(token, "X-CSRF-TOKEN", "_csrf")
    }

    /// Wraps `token` using caller-supplied header and parameter names.
    pub fn with_names(token: String, header_name: &str, parameter_name: &str) -> Self {
        let header_name = String::from(header_name);
        let parameter_name = String::from(parameter_name);
        CsrfToken {
            token,
            header_name,
            parameter_name,
        }
    }

    /// Returns the raw token value.
    pub fn value(&self) -> &str {
        self.token.as_str()
    }

    /// Returns the header name associated with this token.
    pub fn header_name(&self) -> &str {
        self.header_name.as_str()
    }

    /// Returns the request-parameter name associated with this token.
    pub fn parameter_name(&self) -> &str {
        self.parameter_name.as_str()
    }
}
93
94pub trait CsrfTokenRepository: Send + Sync {
103 fn generate_token(&self) -> CsrfToken;
105
106 fn save_token(&self, req: &ServiceRequest, token: &CsrfToken) -> Result<(), CsrfError>;
108
109 fn load_token(&self, req: &ServiceRequest) -> Option<CsrfToken>;
111}
112
113#[derive(Clone)]
124pub struct SessionCsrfTokenRepository {
125 session_key: String,
127 header_name: String,
129 parameter_name: String,
131}
132
133impl Default for SessionCsrfTokenRepository {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139impl SessionCsrfTokenRepository {
140 pub fn new() -> Self {
142 Self {
143 session_key: "CSRF_TOKEN".to_string(),
144 header_name: "X-CSRF-TOKEN".to_string(),
145 parameter_name: "_csrf".to_string(),
146 }
147 }
148
149 pub fn session_key(mut self, key: &str) -> Self {
151 self.session_key = key.to_string();
152 self
153 }
154
155 pub fn header_name(mut self, name: &str) -> Self {
157 self.header_name = name.to_string();
158 self
159 }
160
161 pub fn parameter_name(mut self, name: &str) -> Self {
163 self.parameter_name = name.to_string();
164 self
165 }
166
167 fn generate_token_value(&self) -> String {
169 let mut rng = rand::thread_rng();
170 let bytes: [u8; 32] = rng.gen();
171 hex::encode(&bytes)
172 }
173}
174
175impl CsrfTokenRepository for SessionCsrfTokenRepository {
176 fn generate_token(&self) -> CsrfToken {
177 CsrfToken::with_names(
178 self.generate_token_value(),
179 &self.header_name,
180 &self.parameter_name,
181 )
182 }
183
184 fn save_token(&self, req: &ServiceRequest, token: &CsrfToken) -> Result<(), CsrfError> {
185 let session = req.get_session();
186 session
187 .insert(&self.session_key, &token.token)
188 .map_err(|e| CsrfError::StorageError(e.to_string()))
189 }
190
191 fn load_token(&self, req: &ServiceRequest) -> Option<CsrfToken> {
192 let session = req.get_session();
193 session
194 .get::<String>(&self.session_key)
195 .ok()
196 .flatten()
197 .map(|token| CsrfToken::with_names(token, &self.header_name, &self.parameter_name))
198 }
199}
200
201#[derive(Clone)]
210pub struct CsrfConfig {
211 repository: Arc<dyn CsrfTokenRepository>,
213 protected_methods: Vec<Method>,
215 ignored_paths: Vec<Regex>,
217 header_name: String,
219 parameter_name: String,
221}
222
223impl Default for CsrfConfig {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
229impl CsrfConfig {
230 pub fn new() -> Self {
238 Self {
239 repository: Arc::new(SessionCsrfTokenRepository::new()),
240 protected_methods: vec![Method::POST, Method::PUT, Method::DELETE, Method::PATCH],
241 ignored_paths: Vec::new(),
242 header_name: "X-CSRF-TOKEN".to_string(),
243 parameter_name: "_csrf".to_string(),
244 }
245 }
246
247 pub fn repository<R: CsrfTokenRepository + 'static>(mut self, repository: R) -> Self {
249 self.repository = Arc::new(repository);
250 self
251 }
252
253 pub fn protected_methods(mut self, methods: Vec<Method>) -> Self {
255 self.protected_methods = methods;
256 self
257 }
258
259 pub fn ignore_path(mut self, pattern: &str) -> Self {
268 if let Ok(regex) = Regex::new(pattern) {
269 self.ignored_paths.push(regex);
270 }
271 self
272 }
273
274 pub fn header_name(mut self, name: &str) -> Self {
276 self.header_name = name.to_string();
277 self
278 }
279
280 pub fn parameter_name(mut self, name: &str) -> Self {
282 self.parameter_name = name.to_string();
283 self
284 }
285
286 fn is_path_ignored(&self, path: &str) -> bool {
288 self.ignored_paths.iter().any(|regex| regex.is_match(path))
289 }
290
291 fn requires_protection(&self, method: &Method) -> bool {
293 self.protected_methods.contains(method)
294 }
295}
296
297#[derive(Clone)]
319pub struct CsrfProtection {
320 config: CsrfConfig,
321}
322
323impl CsrfProtection {
324 pub fn new(config: CsrfConfig) -> Self {
326 Self { config }
327 }
328
329 pub fn default_config() -> Self {
331 Self::new(CsrfConfig::default())
332 }
333}
334
335impl<S, B> Transform<S, ServiceRequest> for CsrfProtection
336where
337 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
338 S::Future: 'static,
339 B: 'static,
340{
341 type Response = ServiceResponse<EitherBody<B>>;
342 type Error = Error;
343 type Transform = CsrfMiddleware<S>;
344 type InitError = ();
345 type Future = Ready<Result<Self::Transform, Self::InitError>>;
346
347 fn new_transform(&self, service: S) -> Self::Future {
348 ok(CsrfMiddleware {
349 service: Rc::new(service),
350 config: self.config.clone(),
351 })
352 }
353}
354
355pub struct CsrfMiddleware<S> {
357 service: Rc<S>,
358 config: CsrfConfig,
359}
360
361impl<S, B> Service<ServiceRequest> for CsrfMiddleware<S>
362where
363 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
364 S::Future: 'static,
365 B: 'static,
366{
367 type Response = ServiceResponse<EitherBody<B>>;
368 type Error = Error;
369 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
370
371 forward_ready!(service);
372
373 fn call(&self, req: ServiceRequest) -> Self::Future {
374 let service = self.service.clone();
375 let config = self.config.clone();
376
377 Box::pin(async move {
378 let path = req.path().to_string();
379 let method = req.method().clone();
380
381 if config.is_path_ignored(&path) {
383 let res = service.call(req).await?;
384 return Ok(res.map_into_left_body());
385 }
386
387 let token = match config.repository.load_token(&req) {
389 Some(token) => token,
390 None => {
391 let token = config.repository.generate_token();
392 let _ = config.repository.save_token(&req, &token);
393 token
394 }
395 };
396
397 req.extensions_mut().insert(token.clone());
399
400 if config.requires_protection(&method) {
402 let request_token = get_token_from_request(&req, &config);
404
405 match request_token {
406 Some(submitted_token) if submitted_token == token.token => {
407 let res = service.call(req).await?;
409 Ok(res.map_into_left_body())
410 }
411 Some(_) => {
412 let response = HttpResponse::Forbidden()
414 .body("CSRF token mismatch")
415 .map_into_right_body();
416 Ok(req.into_response(response))
417 }
418 None => {
419 let response = HttpResponse::Forbidden()
421 .body("CSRF token missing")
422 .map_into_right_body();
423 Ok(req.into_response(response))
424 }
425 }
426 } else {
427 let res = service.call(req).await?;
429 Ok(res.map_into_left_body())
430 }
431 })
432 }
433}
434
435fn get_token_from_request(req: &ServiceRequest, config: &CsrfConfig) -> Option<String> {
437 if let Some(header_value) = req.headers().get(&config.header_name) {
439 if let Ok(token) = header_value.to_str() {
440 return Some(token.to_string());
441 }
442 }
443
444 let query_string = req.query_string();
446 let param_prefix = format!("{}=", config.parameter_name);
447 for pair in query_string.split('&') {
448 if pair.starts_with(¶m_prefix) {
449 return Some(pair[param_prefix.len()..].to_string());
450 }
451 }
452
453 None
454}
455
/// Errors produced while handling CSRF tokens.
#[derive(Debug)]
pub enum CsrfError {
    /// No token accompanied the request.
    MissingToken,
    /// The token was rejected as invalid.
    InvalidToken,
    /// The submitted token differs from the stored one.
    TokenMismatch,
    /// The token store failed; carries the underlying error's message.
    StorageError(String),
}

impl std::fmt::Display for CsrfError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fixed = match self {
            CsrfError::MissingToken => "CSRF token missing",
            CsrfError::InvalidToken => "Invalid CSRF token",
            CsrfError::TokenMismatch => "CSRF token mismatch",
            CsrfError::StorageError(detail) => {
                return write!(f, "CSRF storage error: {}", detail);
            }
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for CsrfError {}
485
mod hex {
    //! Minimal lowercase hexadecimal encoder, so the crate does not need a `hex` dependency.

    const HEX_CHARS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `bytes` as lowercase hex: two characters per byte, high nibble first.
    pub fn encode(bytes: &[u8]) -> String {
        bytes
            .iter()
            .flat_map(|&byte| {
                let high = HEX_CHARS[usize::from(byte >> 4)];
                let low = HEX_CHARS[usize::from(byte & 0x0f)];
                [char::from(high), char::from(low)]
            })
            .collect()
    }
}
502
#[cfg(test)]
mod tests {
    //! Unit tests for the token, configuration, repository and hex helpers.
    //! The middleware request flow itself is not exercised here.
    use super::*;

    #[test]
    fn test_csrf_token() {
        let token = CsrfToken::new(String::from("test-token"));
        assert_eq!(token.value(), "test-token");
        assert_eq!(token.header_name(), "X-CSRF-TOKEN");
        assert_eq!(token.parameter_name(), "_csrf");
    }

    #[test]
    fn test_csrf_token_custom_names() {
        let token =
            CsrfToken::with_names(String::from("test-token"), "X-Custom-CSRF", "csrf_token");
        assert_eq!(token.header_name(), "X-Custom-CSRF");
        assert_eq!(token.parameter_name(), "csrf_token");
    }

    #[test]
    fn test_csrf_config_default() {
        let config = CsrfConfig::default();
        assert_eq!(config.header_name, "X-CSRF-TOKEN");
        assert_eq!(config.parameter_name, "_csrf");

        let expected = [Method::POST, Method::PUT, Method::DELETE, Method::PATCH];
        for method in expected.iter() {
            assert!(
                config.protected_methods.contains(method),
                "{:?} should be protected",
                method
            );
        }
        assert!(!config.protected_methods.contains(&Method::GET));
    }

    #[test]
    fn test_csrf_config_ignore_path() {
        let config = CsrfConfig::new()
            .ignore_path("/api/.*")
            .ignore_path("/webhook");

        for path in ["/api/users", "/api/posts/123", "/webhook"].iter() {
            assert!(config.is_path_ignored(path), "{} should be ignored", path);
        }
        assert!(!config.is_path_ignored("/admin"));
    }

    #[test]
    fn test_csrf_config_protected_methods() {
        let config = CsrfConfig::new().protected_methods(vec![Method::POST]);

        assert!(config.requires_protection(&Method::POST));
        assert!(!config.requires_protection(&Method::PUT));
        assert!(!config.requires_protection(&Method::GET));
    }

    #[test]
    fn test_session_csrf_repository() {
        let repo = SessionCsrfTokenRepository::new()
            .session_key("MY_CSRF")
            .header_name("X-My-CSRF")
            .parameter_name("my_csrf");

        let token = repo.generate_token();
        assert_eq!(token.header_name(), "X-My-CSRF");
        assert_eq!(token.parameter_name(), "my_csrf");
        // 32 random bytes, hex-encoded.
        assert_eq!(token.token.len(), 64);
    }

    #[test]
    fn test_hex_encode() {
        assert_eq!(hex::encode(&[0x00]), "00");
        assert_eq!(hex::encode(&[0xff]), "ff");
        assert_eq!(hex::encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef");
    }
}