Skip to main content

oxidite_middleware/
csrf.rs

1use oxidite_core::{OxiditeRequest, OxiditeResponse, Error as CoreError};
2use tower::{Service, Layer};
3use std::task::{Context, Poll};
4use std::future::Future;
5use std::pin::Pin;
6
7
8use base64::{Engine as _, engine::general_purpose};
9use rand::Rng;
10
/// Name of the cookie that carries the server-issued CSRF token.
const CSRF_COOKIE_NAME: &str = "csrf_token";
/// Request header in which clients must echo the CSRF token back.
const CSRF_TOKEN_HEADER: &str = "x-csrf-token";
13
14/// CSRF protection middleware
15#[derive(Clone)]
16pub struct CsrfMiddleware<S> {
17    inner: S,
18    config: CsrfConfig,
19}
20
/// Settings controlling [`CsrfMiddleware`].
#[derive(Clone, Debug)]
pub struct CsrfConfig {
    /// Requested size of generated tokens (32 by default).
    /// NOTE(review): confirm the token generator honors this value.
    pub token_length: usize,
    /// Path prefixes whose requests skip CSRF validation.
    pub exempt_paths: Vec<String>,
}

impl Default for CsrfConfig {
    /// A token length of 32 and no exempt paths.
    fn default() -> Self {
        let token_length = 32;
        let exempt_paths = Vec::new();
        Self {
            token_length,
            exempt_paths,
        }
    }
}
35
36impl<S> CsrfMiddleware<S> {
37    pub fn new(inner: S, config: CsrfConfig) -> Self {
38        Self { inner, config }
39    }
40
41    fn is_exempt(&self, path: &str) -> bool {
42        self.config.exempt_paths.iter().any(|exempt| path.starts_with(exempt))
43    }
44
45    fn generate_token() -> String {
46        let random_bytes: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect();
47        general_purpose::STANDARD.encode(random_bytes)
48    }
49
50    fn verify_token(token: &str, cookie_token: &str) -> bool {
51        token == cookie_token
52    }
53}
54
55impl<S> Service<OxiditeRequest> for CsrfMiddleware<S>
56where
57    S: Service<OxiditeRequest, Response = OxiditeResponse, Error = CoreError> + Clone + Send + 'static,
58    S::Future: Send + 'static,
59{
60    type Response = S::Response;
61    type Error = S::Error;
62    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
63
64    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65        self.inner.poll_ready(cx)
66    }
67
68    fn call(&mut self, req: OxiditeRequest) -> Self::Future {
69        let path = req.uri().path().to_string();
70        let method = req.method().clone();
71        
72        // Check if path is exempt
73        let is_exempt =  self.is_exempt(&path);
74        
75        // Extract CSRF token from header
76        let header_token = req
77            .headers()
78            .get(CSRF_TOKEN_HEADER)
79            .and_then(|h| h.to_str().ok())
80            .map(|s| s.to_string());
81
82        // Extract CSRF token from cookie (simplified - in production use proper cookie parsing)
83        let cookie_token = req
84            .headers()
85            .get("cookie")
86            .and_then(|h| h.to_str().ok())
87            .and_then(|cookies| {
88                cookies.split(';')
89                    .find(|c| c.trim().starts_with(CSRF_COOKIE_NAME))
90                    .and_then(|c| c.split('=').nth(1))
91                    .map(|s| s.trim().to_string())
92            });
93
94        let mut inner = self.inner.clone();
95
96        Box::pin(async move {
97            // Validate CSRF for state-changing methods
98            if !is_exempt && (method == "POST" || method == "PUT" || method == "DELETE" || method == "PATCH") {
99                match (header_token, cookie_token.clone()) {
100                    (Some(h_token), Some(c_token)) => {
101                        if !CsrfMiddleware::<S>::verify_token(&h_token, &c_token) {
102                            return Err(CoreError::BadRequest("Invalid CSRF token".to_string()));
103                        }
104                    }
105                    _ => {
106                        return Err(CoreError::BadRequest("Missing CSRF token".to_string()));
107                    }
108                }
109            }
110
111            let mut response = inner.call(req).await?;
112
113            // Set CSRF token cookie if not present
114            if cookie_token.is_none() {
115                let new_token = CsrfMiddleware::<S>::generate_token();
116                let cookie_value = format!("{}={}; HttpOnly; SameSite=Strict; Path=/", CSRF_COOKIE_NAME, new_token);
117                if let Ok(value) = cookie_value.parse() {
118                    response.headers_mut().insert("set-cookie", value);
119                }
120            }
121
122            Ok(response)
123        })
124    }
125}
126
127/// Layer for CSRF middleware
128#[derive(Clone)]
129pub struct CsrfLayer {
130    config: CsrfConfig,
131}
132
133impl CsrfLayer {
134    pub fn new(config: CsrfConfig) -> Self {
135        Self { config }
136    }
137
138    pub fn with_defaults() -> Self {
139        Self {
140            config: CsrfConfig::default(),
141        }
142    }
143}
144
145impl<S> Layer<S> for CsrfLayer {
146    type Service = CsrfMiddleware<S>;
147
148    fn layer(&self, inner: S) -> Self::Service {
149        CsrfMiddleware::new(inner, self.config.clone())
150    }
151}