oxidite_middleware/
csrf.rs1use oxidite_core::{OxiditeRequest, OxiditeResponse, Error as CoreError};
2use tower::{Service, Layer};
3use std::task::{Context, Poll};
4use std::future::Future;
5use std::pin::Pin;
6
7
8use base64::{Engine as _, engine::general_purpose};
9use rand::Rng;
10
/// Request header in which clients echo the CSRF token back to the server.
const CSRF_TOKEN_HEADER: &str = "x-csrf-token";
/// Name of the cookie through which the server hands the CSRF token to clients.
const CSRF_COOKIE_NAME: &str = "csrf_token";

14#[derive(Clone)]
16pub struct CsrfMiddleware<S> {
17 inner: S,
18 config: CsrfConfig,
19}
20
/// Settings shared by `CsrfMiddleware` and `CsrfLayer`.
#[derive(Clone, Debug)]
pub struct CsrfConfig {
    /// Size of each generated token, in random bytes before encoding.
    /// NOTE(review): confirm the token generator in use reads this value.
    pub token_length: usize,
    /// Request-path prefixes whose requests bypass CSRF validation.
    pub exempt_paths: Vec<String>,
}

impl Default for CsrfConfig {
    /// 32-byte tokens and no exempt paths.
    fn default() -> Self {
        let exempt_paths = Vec::new();
        CsrfConfig {
            exempt_paths,
            token_length: 32,
        }
    }
}

36impl<S> CsrfMiddleware<S> {
37 pub fn new(inner: S, config: CsrfConfig) -> Self {
38 Self { inner, config }
39 }
40
41 fn is_exempt(&self, path: &str) -> bool {
42 self.config.exempt_paths.iter().any(|exempt| path.starts_with(exempt))
43 }
44
45 fn generate_token() -> String {
46 let random_bytes: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect();
47 general_purpose::STANDARD.encode(random_bytes)
48 }
49
50 fn verify_token(token: &str, cookie_token: &str) -> bool {
51 token == cookie_token
52 }
53}
54
55impl<S> Service<OxiditeRequest> for CsrfMiddleware<S>
56where
57 S: Service<OxiditeRequest, Response = OxiditeResponse, Error = CoreError> + Clone + Send + 'static,
58 S::Future: Send + 'static,
59{
60 type Response = S::Response;
61 type Error = S::Error;
62 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
63
64 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65 self.inner.poll_ready(cx)
66 }
67
68 fn call(&mut self, req: OxiditeRequest) -> Self::Future {
69 let path = req.uri().path().to_string();
70 let method = req.method().clone();
71
72 let is_exempt = self.is_exempt(&path);
74
75 let header_token = req
77 .headers()
78 .get(CSRF_TOKEN_HEADER)
79 .and_then(|h| h.to_str().ok())
80 .map(|s| s.to_string());
81
82 let cookie_token = req
84 .headers()
85 .get("cookie")
86 .and_then(|h| h.to_str().ok())
87 .and_then(|cookies| {
88 cookies.split(';')
89 .find(|c| c.trim().starts_with(CSRF_COOKIE_NAME))
90 .and_then(|c| c.split('=').nth(1))
91 .map(|s| s.trim().to_string())
92 });
93
94 let mut inner = self.inner.clone();
95
96 Box::pin(async move {
97 if !is_exempt && (method == "POST" || method == "PUT" || method == "DELETE" || method == "PATCH") {
99 match (header_token, cookie_token.clone()) {
100 (Some(h_token), Some(c_token)) => {
101 if !CsrfMiddleware::<S>::verify_token(&h_token, &c_token) {
102 return Err(CoreError::BadRequest("Invalid CSRF token".to_string()));
103 }
104 }
105 _ => {
106 return Err(CoreError::BadRequest("Missing CSRF token".to_string()));
107 }
108 }
109 }
110
111 let mut response = inner.call(req).await?;
112
113 if cookie_token.is_none() {
115 let new_token = CsrfMiddleware::<S>::generate_token();
116 let cookie_value = format!("{}={}; HttpOnly; SameSite=Strict; Path=/", CSRF_COOKIE_NAME, new_token);
117 if let Ok(value) = cookie_value.parse() {
118 response.headers_mut().insert("set-cookie", value);
119 }
120 }
121
122 Ok(response)
123 })
124 }
125}
126
127#[derive(Clone)]
129pub struct CsrfLayer {
130 config: CsrfConfig,
131}
132
133impl CsrfLayer {
134 pub fn new(config: CsrfConfig) -> Self {
135 Self { config }
136 }
137
138 pub fn with_defaults() -> Self {
139 Self {
140 config: CsrfConfig::default(),
141 }
142 }
143}
144
145impl<S> Layer<S> for CsrfLayer {
146 type Service = CsrfMiddleware<S>;
147
148 fn layer(&self, inner: S) -> Self::Service {
149 CsrfMiddleware::new(inner, self.config.clone())
150 }
151}