Skip to main content

ironflow_api/
middleware.rs

1//! Middleware for internal route protection and HTTP security hardening.
2
3use axum::Json;
4use axum::extract::Request;
5use axum::http::header::{
6    CONTENT_SECURITY_POLICY, STRICT_TRANSPORT_SECURITY, X_CONTENT_TYPE_OPTIONS, X_FRAME_OPTIONS,
7    X_XSS_PROTECTION,
8};
9use axum::http::{HeaderValue, StatusCode};
10use axum::middleware::Next;
11use axum::response::{IntoResponse, Response};
12use serde_json::json;
13use subtle::ConstantTimeEq;
14
15/// Axum middleware that validates a static worker token.
16///
17/// Extracts `Authorization: Bearer {token}` and compares against the
18/// expected token. Returns 401 if missing or invalid.
19pub async fn worker_token_auth(req: Request, next: Next) -> Response {
20    let expected = req.extensions().get::<WorkerToken>().map(|t| t.0.clone());
21
22    let provided = req
23        .headers()
24        .get("authorization")
25        .and_then(|v| v.to_str().ok())
26        .and_then(|v| v.strip_prefix("Bearer "))
27        .map(|t| t.to_string());
28
29    match (expected, provided) {
30        (Some(expected), Some(provided))
31            if expected.as_bytes().ct_eq(provided.as_bytes()).into() =>
32        {
33            next.run(req).await
34        }
35        _ => (
36            StatusCode::UNAUTHORIZED,
37            Json(json!({
38                "error": {
39                    "code": "INVALID_WORKER_TOKEN",
40                    "message": "Invalid or missing worker token",
41                }
42            })),
43        )
44            .into_response(),
45    }
46}
47
/// Newtype wrapper for the static worker token, stored in request extensions
/// and read by [`worker_token_auth`].
///
/// `Debug` is implemented by hand and redacts the value, so the secret cannot
/// end up in logs or panic messages through `{:?}` formatting.
#[derive(Clone)]
pub struct WorkerToken(pub String);

impl std::fmt::Debug for WorkerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("WorkerToken")
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
51
52/// Middleware that injects standard HTTP security headers on every response.
53///
54/// Headers set:
55/// - `X-Content-Type-Options: nosniff` — prevents MIME-type sniffing
56/// - `X-Frame-Options: DENY` — blocks clickjacking via iframes
57/// - `X-XSS-Protection: 1; mode=block` — legacy XSS filter hint
58/// - `Strict-Transport-Security: max-age=63072000; includeSubDomains` — enforces HTTPS for 2 years
59/// - `Content-Security-Policy: default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'`
60pub async fn security_headers(req: Request, next: Next) -> Response {
61    let mut resp = next.run(req).await;
62    let headers = resp.headers_mut();
63
64    headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
65    headers.insert(X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
66    headers.insert(X_XSS_PROTECTION, HeaderValue::from_static("1; mode=block"));
67    headers.insert(
68        STRICT_TRANSPORT_SECURITY,
69        HeaderValue::from_static("max-age=63072000; includeSubDomains"),
70    );
71    headers.insert(
72        CONTENT_SECURITY_POLICY,
73        HeaderValue::from_static(
74            "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'",
75        ),
76    );
77
78    resp
79}
80
#[cfg(test)]
mod tests {

    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::response::Response;
    use http_body_util::BodyExt;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_store::memory::InMemoryStore;
    use serde_json::Value as JsonValue;
    use std::sync::Arc;
    use tower::ServiceExt;

    use crate::routes::create_router;
    use crate::state::AppState;

    /// Internal endpoint guarded by the worker-token middleware.
    const NEXT_RUN_URI: &str = "/api/v1/internal/runs/next";

    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let user_store = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));
        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        AppState {
            store,
            user_store,
            engine,
            jwt_config,
            worker_token: "test-worker-token".to_string(),
        }
    }

    /// Builds a fresh router over `test_state()` and drives one request through it.
    async fn send(req: Request<Body>) -> Response {
        create_router(test_state(), None)
            .oneshot(req)
            .await
            .unwrap()
    }

    /// Builds a request to the internal "next run" endpoint, optionally
    /// carrying the given `Authorization` header value.
    fn next_run_request(authorization: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri(NEXT_RUN_URI);
        let builder = match authorization {
            Some(value) => builder.header("authorization", value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Reads a JSON response body and returns its `error.code` field.
    async fn error_code(resp: Response) -> JsonValue {
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json_val: JsonValue = serde_json::from_slice(&body).unwrap();
        json_val["error"]["code"].clone()
    }

    #[tokio::test]
    async fn worker_token_valid() {
        let resp = send(next_run_request(Some("Bearer test-worker-token"))).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn worker_token_missing() {
        let resp = send(next_run_request(None)).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(error_code(resp).await, "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn worker_token_invalid() {
        let resp = send(next_run_request(Some("Bearer wrong-token"))).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(error_code(resp).await, "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn worker_token_wrong_scheme() {
        // The correct token under a non-Bearer scheme must still be rejected.
        let resp = send(next_run_request(Some("Basic test-worker-token"))).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(error_code(resp).await, "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn security_headers_present() {
        let req = Request::builder()
            .uri("/api/v1/health-check")
            .body(Body::empty())
            .unwrap();

        let resp = send(req).await;
        let headers = resp.headers();

        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(headers.get("x-frame-options").unwrap(), "DENY");
        assert_eq!(headers.get("x-xss-protection").unwrap(), "1; mode=block");
        assert_eq!(
            headers.get("strict-transport-security").unwrap(),
            "max-age=63072000; includeSubDomains"
        );
        assert!(
            headers
                .get("content-security-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("default-src 'self'")
        );
    }
}