acton_htmx/middleware/
csrf.rs1use crate::agents::{CsrfToken, ValidateToken};
16use crate::auth::session::SessionId;
17use crate::state::ActonHtmxState;
18use acton_reactive::prelude::{AgentHandle, AgentHandleInterface};
19use axum::{
20 body::Body,
21 extract::Request,
22 http::{Method, StatusCode},
23 response::{IntoResponse, Response},
24};
25use std::sync::Arc;
26use std::task::{Context, Poll};
27use std::time::Duration;
28use tower::{Layer, Service};
29
/// Default name of the request header that carries the CSRF token.
pub const CSRF_HEADER_NAME: &str = "x-csrf-token";

/// Default name of the form field associated with the CSRF token.
pub const CSRF_FORM_FIELD: &str = "_csrf_token";

/// Settings controlling how the CSRF middleware locates and validates tokens.
#[derive(Clone, Debug)]
pub struct CsrfConfig {
    /// Name of the request header that is read to obtain the CSRF token.
    pub header_name: String,
    /// Form field name associated with the CSRF token.
    ///
    /// NOTE(review): header-based extraction (`extract_csrf_token`) does not read this
    /// value; it is carried for callers/templates that render the token into forms.
    pub form_field: String,
    /// Time budget, in milliseconds, for the CSRF manager agent to validate a token.
    /// If it is exceeded the request is rejected.
    pub agent_timeout_ms: u64,
    /// Paths exempt from CSRF validation, compared for exact equality with the request
    /// URI path (no prefix or pattern matching).
    pub skip_paths: Vec<String>,
}

impl Default for CsrfConfig {
    /// Uses [`CSRF_HEADER_NAME`] and [`CSRF_FORM_FIELD`], a 100 ms agent timeout, and no
    /// exempt paths.
    fn default() -> Self {
        Self {
            skip_paths: Vec::new(),
            agent_timeout_ms: 100,
            form_field: String::from(CSRF_FORM_FIELD),
            header_name: String::from(CSRF_HEADER_NAME),
        }
    }
}

impl CsrfConfig {
    /// Creates a configuration identical to [`CsrfConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a single exempt path and returns the config for further chaining.
    #[must_use]
    pub fn skip_path(mut self, path: impl Into<String>) -> Self {
        let exempt: String = path.into();
        self.skip_paths.push(exempt);
        self
    }

    /// Appends every path in `paths` (in order) to the exempt list and returns the
    /// config for further chaining.
    #[must_use]
    pub fn skip_paths(mut self, mut paths: Vec<String>) -> Self {
        self.skip_paths.append(&mut paths);
        self
    }
}
81
82#[derive(Clone)]
86pub struct CsrfLayer {
87 config: CsrfConfig,
88 csrf_manager: AgentHandle,
89}
90
91impl std::fmt::Debug for CsrfLayer {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("CsrfLayer")
94 .field("config", &self.config)
95 .field("csrf_manager", &"AgentHandle")
96 .finish()
97 }
98}
99
100impl CsrfLayer {
101 #[must_use]
103 pub fn new(state: &ActonHtmxState) -> Self {
104 Self {
105 config: CsrfConfig::default(),
106 csrf_manager: state.csrf_manager().clone(),
107 }
108 }
109
110 #[must_use]
112 pub fn with_config(state: &ActonHtmxState, config: CsrfConfig) -> Self {
113 Self {
114 config,
115 csrf_manager: state.csrf_manager().clone(),
116 }
117 }
118
119 #[must_use]
121 pub fn from_handle(csrf_manager: AgentHandle) -> Self {
122 Self {
123 config: CsrfConfig::default(),
124 csrf_manager,
125 }
126 }
127
128 #[must_use]
130 pub const fn from_handle_with_config(csrf_manager: AgentHandle, config: CsrfConfig) -> Self {
131 Self {
132 config,
133 csrf_manager,
134 }
135 }
136}
137
138impl<S> Layer<S> for CsrfLayer {
139 type Service = CsrfMiddleware<S>;
140
141 fn layer(&self, inner: S) -> Self::Service {
142 CsrfMiddleware {
143 inner,
144 config: Arc::new(self.config.clone()),
145 csrf_manager: self.csrf_manager.clone(),
146 }
147 }
148}
149
150#[derive(Clone)]
155pub struct CsrfMiddleware<S> {
156 inner: S,
157 config: Arc<CsrfConfig>,
158 csrf_manager: AgentHandle,
159}
160
161impl<S: std::fmt::Debug> std::fmt::Debug for CsrfMiddleware<S> {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("CsrfMiddleware")
164 .field("inner", &self.inner)
165 .field("config", &self.config)
166 .field("csrf_manager", &"AgentHandle")
167 .finish()
168 }
169}
170
171impl<S> Service<Request> for CsrfMiddleware<S>
172where
173 S: Service<Request, Response = Response<Body>> + Clone + Send + 'static,
174 S::Future: Send + 'static,
175{
176 type Response = Response<Body>;
177 type Error = S::Error;
178 type Future = std::pin::Pin<
179 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
180 >;
181
182 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183 self.inner.poll_ready(cx)
184 }
185
186 fn call(&mut self, req: Request) -> Self::Future {
187 let config = self.config.clone();
188 let csrf_manager = self.csrf_manager.clone();
189 let mut inner = self.inner.clone();
190 let timeout = Duration::from_millis(config.agent_timeout_ms);
191
192 if is_method_safe(req.method()) {
194 return Box::pin(inner.call(req));
195 }
196
197 let path = req.uri().path().to_string();
199 if config.skip_paths.iter().any(|skip| skip == &path) {
200 return Box::pin(inner.call(req));
201 }
202
203 let Some(session_id) = req.extensions().get::<SessionId>().cloned() else {
205 tracing::warn!("CSRF middleware requires SessionMiddleware to be applied first");
206 return Box::pin(async move {
207 Ok(csrf_validation_error(
208 "Session not found - ensure SessionMiddleware is applied",
209 ))
210 });
211 };
212
213 let Some(token) = extract_csrf_token(&req, &config) else {
215 let method = req.method().clone();
216 tracing::warn!("CSRF token missing for {} {}", method, path);
217 return Box::pin(async move { Ok(csrf_validation_error("CSRF token missing")) });
218 };
219
220 Box::pin(async move {
221 let (validate_request, rx) = ValidateToken::new(session_id, token);
223 csrf_manager.send(validate_request).await;
224
225 let is_valid = match tokio::time::timeout(timeout, rx).await {
226 Ok(Ok(valid)) => valid,
227 Ok(Err(_)) => {
228 tracing::error!("CSRF validation channel error");
229 false
230 }
231 Err(_) => {
232 tracing::error!("CSRF validation timeout");
233 false
234 }
235 };
236
237 if !is_valid {
238 tracing::warn!("CSRF token validation failed");
239 return Ok(csrf_validation_error("CSRF token validation failed"));
240 }
241
242 inner.call(req).await
244 })
245 }
246}
247
248const fn is_method_safe(method: &Method) -> bool {
250 matches!(
251 *method,
252 Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE
253 )
254}
255
256fn extract_csrf_token(req: &Request, config: &CsrfConfig) -> Option<CsrfToken> {
258 if let Some(token_value) = req.headers().get(&config.header_name) {
260 if let Ok(token_str) = token_value.to_str() {
261 return Some(CsrfToken::from_string(token_str.to_string()));
262 }
263 }
264
265 None
271}
272
273fn csrf_validation_error(message: &str) -> Response<Body> {
275 let body = if cfg!(debug_assertions) {
276 format!("CSRF validation failed: {message}")
278 } else {
279 "Forbidden".to_string()
281 };
282
283 (StatusCode::FORBIDDEN, body).into_response()
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_csrf_config_default() {
        let config = CsrfConfig::default();
        assert_eq!(config.header_name, CSRF_HEADER_NAME);
        assert_eq!(config.form_field, CSRF_FORM_FIELD);
        assert_eq!(config.agent_timeout_ms, 100);
        assert!(config.skip_paths.is_empty());
    }

    #[test]
    fn test_csrf_config_skip_path() {
        let config = CsrfConfig::new().skip_path("/webhooks/github");
        assert_eq!(config.skip_paths.len(), 1);
        assert_eq!(config.skip_paths[0], "/webhooks/github");
    }

    #[test]
    fn test_csrf_config_skip_paths() {
        let config = CsrfConfig::new().skip_paths(vec![
            "/health".to_string(),
            "/metrics".to_string(),
        ]);
        assert_eq!(config.skip_paths.len(), 2);
        assert!(config.skip_paths.contains(&"/health".to_string()));
        assert!(config.skip_paths.contains(&"/metrics".to_string()));
    }

    // Builder calls must append (not replace) and preserve insertion order.
    #[test]
    fn test_csrf_config_builders_accumulate_in_order() {
        let config = CsrfConfig::new()
            .skip_path("/a")
            .skip_paths(vec!["/b".to_string(), "/c".to_string()])
            .skip_path("/d");
        assert_eq!(config.skip_paths, vec!["/a", "/b", "/c", "/d"]);
    }

    #[test]
    fn test_is_method_safe() {
        assert!(is_method_safe(&Method::GET));
        assert!(is_method_safe(&Method::HEAD));
        assert!(is_method_safe(&Method::OPTIONS));
        assert!(is_method_safe(&Method::TRACE));

        assert!(!is_method_safe(&Method::POST));
        assert!(!is_method_safe(&Method::PUT));
        assert!(!is_method_safe(&Method::DELETE));
        assert!(!is_method_safe(&Method::PATCH));
    }

    // CONNECT and extension methods are not safe and must require a token.
    #[test]
    fn test_is_method_safe_rejects_connect_and_extensions() {
        assert!(!is_method_safe(&Method::CONNECT));
        let custom = Method::from_bytes(b"PURGE").expect("PURGE is a valid method token");
        assert!(!is_method_safe(&custom));
    }

    #[test]
    fn test_extract_csrf_token_reads_configured_header() {
        let config = CsrfConfig::default();

        let with_header = axum::http::Request::builder()
            .header(CSRF_HEADER_NAME, "token-value")
            .body(Body::empty())
            .expect("valid request");
        assert!(extract_csrf_token(&with_header, &config).is_some());

        let without_header = axum::http::Request::builder()
            .body(Body::empty())
            .expect("valid request");
        assert!(extract_csrf_token(&without_header, &config).is_none());
    }

    #[test]
    fn test_csrf_validation_error_is_forbidden() {
        let response = csrf_validation_error("anything");
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }
}