Skip to main content

ironflow_api/
middleware.rs

1//! Middleware for internal route protection and HTTP security hardening.
2
3use axum::Json;
4use axum::extract::Request;
5use axum::http::header::{
6    CONTENT_SECURITY_POLICY, STRICT_TRANSPORT_SECURITY, X_CONTENT_TYPE_OPTIONS, X_FRAME_OPTIONS,
7    X_XSS_PROTECTION,
8};
9use axum::http::{HeaderValue, StatusCode};
10use axum::middleware::Next;
11use axum::response::{IntoResponse, Response};
12use subtle::ConstantTimeEq;
13
14/// Axum middleware that validates a static worker token.
15///
16/// Extracts `Authorization: Bearer {token}` and compares against the
17/// expected token. Returns 401 if missing or invalid.
18pub async fn worker_token_auth(req: Request, next: Next) -> Response {
19    let expected = req.extensions().get::<WorkerToken>().map(|t| t.0.clone());
20
21    let provided = req
22        .headers()
23        .get("authorization")
24        .and_then(|v| v.to_str().ok())
25        .and_then(|v| v.strip_prefix("Bearer "))
26        .map(|t| t.to_string());
27
28    match (expected, provided) {
29        (Some(expected), Some(provided))
30            if expected.as_bytes().ct_eq(provided.as_bytes()).into() =>
31        {
32            next.run(req).await
33        }
34        _ => (
35            StatusCode::UNAUTHORIZED,
36            Json(serde_json::json!({
37                "error": {
38                    "code": "INVALID_WORKER_TOKEN",
39                    "message": "Invalid or missing worker token",
40                }
41            })),
42        )
43            .into_response(),
44    }
45}
46
/// Newtype wrapper for the static worker token, stored in request extensions
/// and read back by `worker_token_auth`.
///
/// `Debug` is implemented by hand rather than derived, so the secret never
/// appears in logs, traces or panic messages. It always renders as
/// `WorkerToken(<redacted>)`.
#[derive(Clone)]
pub struct WorkerToken(pub String);

impl std::fmt::Debug for WorkerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit `self.0`: this is a bearer secret.
        f.write_str("WorkerToken(<redacted>)")
    }
}
50
51/// Middleware that injects standard HTTP security headers on every response.
52///
53/// Headers set:
54/// - `X-Content-Type-Options: nosniff` — prevents MIME-type sniffing
55/// - `X-Frame-Options: DENY` — blocks clickjacking via iframes
56/// - `X-XSS-Protection: 1; mode=block` — legacy XSS filter hint
57/// - `Strict-Transport-Security: max-age=63072000; includeSubDomains` — enforces HTTPS for 2 years
58/// - `Content-Security-Policy: default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'`
59pub async fn security_headers(req: Request, next: Next) -> Response {
60    let mut resp = next.run(req).await;
61    let headers = resp.headers_mut();
62
63    headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
64    headers.insert(X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
65    headers.insert(X_XSS_PROTECTION, HeaderValue::from_static("1; mode=block"));
66    headers.insert(
67        STRICT_TRANSPORT_SECURITY,
68        HeaderValue::from_static("max-age=63072000; includeSubDomains"),
69    );
70    headers.insert(
71        CONTENT_SECURITY_POLICY,
72        HeaderValue::from_static(
73            "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'",
74        ),
75    );
76
77    resp
78}
79
#[cfg(test)]
mod tests {

    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::response::Response;
    use http_body_util::BodyExt;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_store::memory::InMemoryStore;
    use std::sync::Arc;
    use tower::ServiceExt;

    use crate::routes::create_router;
    use crate::state::AppState;

    /// Internal route protected by the worker-token middleware.
    const WORKER_ROUTE: &str = "/api/v1/internal/runs/next";

    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let user_store = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));
        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        AppState {
            store,
            user_store,
            engine,
            jwt_config,
            worker_token: "test-worker-token".to_string(),
        }
    }

    /// Sends `req` through a freshly built router and returns the response.
    async fn send(req: Request<Body>) -> Response {
        create_router(test_state(), None)
            .oneshot(req)
            .await
            .expect("router service is infallible")
    }

    /// Builds a GET request to the worker route with an optional raw
    /// `Authorization` header value.
    fn worker_request(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri(WORKER_ROUTE);
        if let Some(value) = authorization {
            builder = builder.header("authorization", value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Collects a JSON response body and returns its `error.code` field.
    async fn error_code(resp: Response) -> serde_json::Value {
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        json["error"]["code"].clone()
    }

    /// Asserts `resp` is the worker-token middleware's 401 rejection.
    async fn assert_rejected(resp: Response) {
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(error_code(resp).await, "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn worker_token_valid() {
        let resp = send(worker_request(Some("Bearer test-worker-token"))).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn worker_token_missing() {
        assert_rejected(send(worker_request(None)).await).await;
    }

    #[tokio::test]
    async fn worker_token_invalid() {
        assert_rejected(send(worker_request(Some("Bearer wrong-token"))).await).await;
    }

    #[tokio::test]
    async fn worker_token_wrong_scheme() {
        // The correct token under a non-Bearer scheme must still be rejected.
        assert_rejected(send(worker_request(Some("Basic test-worker-token"))).await).await;
    }

    #[tokio::test]
    async fn security_headers_present() {
        let req = Request::builder()
            .uri("/api/v1/health-check")
            .body(Body::empty())
            .unwrap();

        let resp = send(req).await;
        let headers = resp.headers();

        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(headers.get("x-frame-options").unwrap(), "DENY");
        assert_eq!(headers.get("x-xss-protection").unwrap(), "1; mode=block");
        assert_eq!(
            headers.get("strict-transport-security").unwrap(),
            "max-age=63072000; includeSubDomains"
        );
        assert!(
            headers
                .get("content-security-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("default-src 'self'")
        );
    }
}