use crate::agents::{CsrfToken, ValidateToken};
use crate::auth::session::SessionId;
use crate::state::ActonHtmxState;
use acton_reactive::prelude::{AgentHandle, AgentHandleInterface};
use axum::{
body::Body,
extract::Request,
http::{Method, StatusCode},
response::{IntoResponse, Response},
};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tower::{Layer, Service};
pub const CSRF_HEADER_NAME: &str = "x-csrf-token";
pub const CSRF_FORM_FIELD: &str = "_csrf_token";
#[derive(Clone, Debug)]
pub struct CsrfConfig {
pub header_name: String,
pub form_field: String,
pub agent_timeout_ms: u64,
pub skip_paths: Vec<String>,
}
impl Default for CsrfConfig {
fn default() -> Self {
Self {
header_name: CSRF_HEADER_NAME.to_string(),
form_field: CSRF_FORM_FIELD.to_string(),
agent_timeout_ms: 100,
skip_paths: vec![],
}
}
}
impl CsrfConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn skip_path(mut self, path: impl Into<String>) -> Self {
self.skip_paths.push(path.into());
self
}
#[must_use]
pub fn skip_paths(mut self, paths: Vec<String>) -> Self {
self.skip_paths.extend(paths);
self
}
}
#[derive(Clone)]
pub struct CsrfLayer {
config: CsrfConfig,
csrf_manager: AgentHandle,
}
impl std::fmt::Debug for CsrfLayer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CsrfLayer")
.field("config", &self.config)
.field("csrf_manager", &"AgentHandle")
.finish()
}
}
impl CsrfLayer {
#[must_use]
pub fn new(state: &ActonHtmxState) -> Self {
Self {
config: CsrfConfig::default(),
csrf_manager: state.csrf_manager().clone(),
}
}
#[must_use]
pub fn with_config(state: &ActonHtmxState, config: CsrfConfig) -> Self {
Self {
config,
csrf_manager: state.csrf_manager().clone(),
}
}
#[must_use]
pub fn from_handle(csrf_manager: AgentHandle) -> Self {
Self {
config: CsrfConfig::default(),
csrf_manager,
}
}
#[must_use]
pub const fn from_handle_with_config(csrf_manager: AgentHandle, config: CsrfConfig) -> Self {
Self {
config,
csrf_manager,
}
}
}
impl<S> Layer<S> for CsrfLayer {
type Service = CsrfMiddleware<S>;
fn layer(&self, inner: S) -> Self::Service {
CsrfMiddleware {
inner,
config: Arc::new(self.config.clone()),
csrf_manager: self.csrf_manager.clone(),
}
}
}
#[derive(Clone)]
pub struct CsrfMiddleware<S> {
inner: S,
config: Arc<CsrfConfig>,
csrf_manager: AgentHandle,
}
impl<S: std::fmt::Debug> std::fmt::Debug for CsrfMiddleware<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CsrfMiddleware")
.field("inner", &self.inner)
.field("config", &self.config)
.field("csrf_manager", &"AgentHandle")
.finish()
}
}
impl<S> Service<Request> for CsrfMiddleware<S>
where
S: Service<Request, Response = Response<Body>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
let config = self.config.clone();
let csrf_manager = self.csrf_manager.clone();
let mut inner = self.inner.clone();
let timeout = Duration::from_millis(config.agent_timeout_ms);
if is_method_safe(req.method()) {
return Box::pin(inner.call(req));
}
let path = req.uri().path().to_string();
if config.skip_paths.iter().any(|skip| skip == &path) {
return Box::pin(inner.call(req));
}
let Some(session_id) = req.extensions().get::<SessionId>().cloned() else {
tracing::warn!("CSRF middleware requires SessionMiddleware to be applied first");
return Box::pin(async move {
Ok(csrf_validation_error(
"Session not found - ensure SessionMiddleware is applied",
))
});
};
let Some(token) = extract_csrf_token(&req, &config) else {
let method = req.method().clone();
tracing::warn!("CSRF token missing for {} {}", method, path);
return Box::pin(async move { Ok(csrf_validation_error("CSRF token missing")) });
};
Box::pin(async move {
let (validate_request, rx) = ValidateToken::new(session_id, token);
csrf_manager.send(validate_request).await;
let is_valid = match tokio::time::timeout(timeout, rx).await {
Ok(Ok(valid)) => valid,
Ok(Err(_)) => {
tracing::error!("CSRF validation channel error");
false
}
Err(_) => {
tracing::error!("CSRF validation timeout");
false
}
};
if !is_valid {
tracing::warn!("CSRF token validation failed");
return Ok(csrf_validation_error("CSRF token validation failed"));
}
inner.call(req).await
})
}
}
const fn is_method_safe(method: &Method) -> bool {
matches!(
*method,
Method::GET | Method::HEAD | Method::OPTIONS | Method::TRACE
)
}
fn extract_csrf_token(req: &Request, config: &CsrfConfig) -> Option<CsrfToken> {
if let Some(token_value) = req.headers().get(&config.header_name) {
if let Ok(token_str) = token_value.to_str() {
return Some(CsrfToken::from_string(token_str.to_string()));
}
}
None
}
fn csrf_validation_error(message: &str) -> Response<Body> {
let body = if cfg!(debug_assertions) {
format!("CSRF validation failed: {message}")
} else {
"Forbidden".to_string()
};
(StatusCode::FORBIDDEN, body).into_response()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_csrf_config_default() {
let config = CsrfConfig::default();
assert_eq!(config.header_name, CSRF_HEADER_NAME);
assert_eq!(config.form_field, CSRF_FORM_FIELD);
assert_eq!(config.agent_timeout_ms, 100);
assert!(config.skip_paths.is_empty());
}
#[test]
fn test_csrf_config_skip_path() {
let config = CsrfConfig::new().skip_path("/webhooks/github");
assert_eq!(config.skip_paths.len(), 1);
assert_eq!(config.skip_paths[0], "/webhooks/github");
}
#[test]
fn test_csrf_config_skip_paths() {
let config = CsrfConfig::new().skip_paths(vec![
"/health".to_string(),
"/metrics".to_string(),
]);
assert_eq!(config.skip_paths.len(), 2);
assert!(config.skip_paths.contains(&"/health".to_string()));
assert!(config.skip_paths.contains(&"/metrics".to_string()));
}
#[test]
fn test_is_method_safe() {
assert!(is_method_safe(&Method::GET));
assert!(is_method_safe(&Method::HEAD));
assert!(is_method_safe(&Method::OPTIONS));
assert!(is_method_safe(&Method::TRACE));
assert!(!is_method_safe(&Method::POST));
assert!(!is_method_safe(&Method::PUT));
assert!(!is_method_safe(&Method::DELETE));
assert!(!is_method_safe(&Method::PATCH));
}
}