Skip to main content

ironflow_api/
middleware.rs

//! Middleware for internal route protection, request metrics, and HTTP security hardening.
2
3use axum::Json;
4use axum::extract::Request;
5use axum::http::header::{
6    CONTENT_SECURITY_POLICY, STRICT_TRANSPORT_SECURITY, X_CONTENT_TYPE_OPTIONS, X_FRAME_OPTIONS,
7    X_XSS_PROTECTION,
8};
9use axum::http::{HeaderValue, StatusCode};
10use axum::middleware::Next;
11use axum::response::{IntoResponse, Response};
12use serde_json::json;
13use subtle::ConstantTimeEq;
14
15/// Axum middleware that validates a static worker token.
16///
17/// Extracts `Authorization: Bearer {token}` and compares against the
18/// expected token. Returns 401 if missing or invalid.
19pub async fn worker_token_auth(req: Request, next: Next) -> Response {
20    let expected = req.extensions().get::<WorkerToken>().map(|t| t.0.clone());
21
22    let provided = req
23        .headers()
24        .get("authorization")
25        .and_then(|v| v.to_str().ok())
26        .and_then(|v| v.strip_prefix("Bearer "))
27        .map(|t| t.to_string());
28
29    match (expected, provided) {
30        (Some(expected), Some(provided))
31            if expected.as_bytes().ct_eq(provided.as_bytes()).into() =>
32        {
33            next.run(req).await
34        }
35        _ => (
36            StatusCode::UNAUTHORIZED,
37            Json(json!({
38                "error": {
39                    "code": "INVALID_WORKER_TOKEN",
40                    "message": "Invalid or missing worker token",
41                }
42            })),
43        )
44            .into_response(),
45    }
46}
47
/// Newtype wrapper for the static worker token, stored in request extensions.
///
/// `Debug` is implemented by hand so the secret is never written to logs or
/// panic messages. It always renders as `WorkerToken(<redacted>)`.
#[derive(Clone)]
pub struct WorkerToken(pub String);

impl std::fmt::Debug for WorkerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("WorkerToken")
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
51
/// Middleware that records API request metrics (counter + duration histogram).
///
/// Emits `ironflow_api_requests_total` (labels `method`, `path`, `status`) and
/// `ironflow_api_request_duration_seconds` (labels `method`, `path`) for every
/// request. Only compiled when the `prometheus` feature is enabled.
///
/// Label values are bounded so clients cannot mint unlimited time series:
/// - `path` is the route template from [`axum::extract::MatchedPath`], not the
///   raw URI, so IDs embedded in URLs never become label values. Requests that
///   matched no route are recorded as `"unmatched"`. `MatchedPath` is inserted
///   by the router, so this middleware must run inside it (e.g. installed via
///   `Router::layer` / `route_layer`).
/// - `method` is the standard method name (`GET`, `POST`, …); non-standard
///   extension methods are recorded as `"OTHER"`.
#[cfg(feature = "prometheus")]
pub async fn request_metrics(req: Request, next: Next) -> Response {
    use std::time::Instant;

    use axum::extract::MatchedPath;
    use ironflow_core::metric_names::{API_REQUEST_DURATION_SECONDS, API_REQUESTS_TOTAL};
    use metrics::{counter, histogram};

    // `path` label for requests that did not match any route (e.g. 404 probes).
    const UNMATCHED_PATH: &str = "unmatched";
    // `method` label for arbitrary extension methods, which clients control freely.
    const OTHER_METHOD: &str = "OTHER";

    let method = match req.method().as_str() {
        standard @ ("GET" | "POST" | "PUT" | "DELETE" | "PATCH" | "HEAD" | "OPTIONS"
        | "CONNECT" | "TRACE") => standard.to_owned(),
        _ => OTHER_METHOD.to_owned(),
    };
    let path = req
        .extensions()
        .get::<MatchedPath>()
        .map_or_else(|| UNMATCHED_PATH.to_owned(), |matched| matched.as_str().to_owned());
    let start = Instant::now();

    let resp = next.run(req).await;

    let status = resp.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    counter!(API_REQUESTS_TOTAL, "method" => method.clone(), "path" => path.clone(), "status" => status).increment(1);
    histogram!(API_REQUEST_DURATION_SECONDS, "method" => method, "path" => path).record(duration);

    resp
}
77
78/// Middleware that injects standard HTTP security headers on every response.
79///
80/// Headers set:
81/// - `X-Content-Type-Options: nosniff` — prevents MIME-type sniffing
82/// - `X-Frame-Options: DENY` — blocks clickjacking via iframes
83/// - `X-XSS-Protection: 1; mode=block` — legacy XSS filter hint
84/// - `Strict-Transport-Security: max-age=63072000; includeSubDomains` — enforces HTTPS for 2 years
85/// - `Content-Security-Policy: default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'`
86pub async fn security_headers(req: Request, next: Next) -> Response {
87    let mut resp = next.run(req).await;
88    let headers = resp.headers_mut();
89
90    headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
91    headers.insert(X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
92    headers.insert(X_XSS_PROTECTION, HeaderValue::from_static("1; mode=block"));
93    headers.insert(
94        STRICT_TRANSPORT_SECURITY,
95        HeaderValue::from_static("max-age=63072000; includeSubDomains"),
96    );
97    headers.insert(
98        CONTENT_SECURITY_POLICY,
99        HeaderValue::from_static(
100            "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'",
101        ),
102    );
103
104    resp
105}
106
#[cfg(test)]
mod tests {

    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use http_body_util::BodyExt;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::api_key_store::ApiKeyStore;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::user_store::UserStore;
    use serde_json::Value as JsonValue;
    use std::sync::Arc;
    use tokio::sync::broadcast;
    use tower::ServiceExt;

    use crate::routes::{RouterConfig, create_router};
    use crate::state::AppState;

    /// Worker token configured by `test_state`.
    const WORKER_TOKEN: &str = "test-worker-token";

    /// Exact Content-Security-Policy the `security_headers` middleware emits.
    const EXPECTED_CSP: &str = "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'";

    /// Builds an `AppState` backed by in-memory stores and `WORKER_TOKEN`.
    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let user_store: Arc<dyn UserStore> = Arc::new(InMemoryStore::new());
        let api_key_store: Arc<dyn ApiKeyStore> = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));
        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        let (event_sender, _) = broadcast::channel::<Event>(1);
        AppState::new(
            store,
            user_store,
            api_key_store,
            engine,
            jwt_config,
            WORKER_TOKEN.to_string(),
            event_sender,
        )
    }

    /// Builds a request to the worker-protected "next run" endpoint, optionally
    /// carrying the given `Authorization` header value.
    fn internal_request(authorization: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri("/api/v1/internal/runs/next");
        let builder = match authorization {
            Some(value) => builder.header("authorization", value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Asserts `resp` is a 401 whose JSON body carries `INVALID_WORKER_TOKEN`.
    async fn assert_invalid_worker_token(resp: axum::response::Response) {
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json_val: JsonValue = serde_json::from_slice(&body).unwrap();
        assert_eq!(json_val["error"]["code"], "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn worker_token_valid() {
        let app = create_router(test_state(), RouterConfig::default());
        let auth = format!("Bearer {WORKER_TOKEN}");

        let resp = app.oneshot(internal_request(Some(&auth))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn worker_token_missing() {
        let app = create_router(test_state(), RouterConfig::default());

        let resp = app.oneshot(internal_request(None)).await.unwrap();
        assert_invalid_worker_token(resp).await;
    }

    #[tokio::test]
    async fn worker_token_invalid() {
        let app = create_router(test_state(), RouterConfig::default());

        let resp = app
            .oneshot(internal_request(Some("Bearer wrong-token")))
            .await
            .unwrap();
        assert_invalid_worker_token(resp).await;
    }

    #[tokio::test]
    async fn worker_token_wrong_scheme() {
        let app = create_router(test_state(), RouterConfig::default());
        // The correct secret under a non-Bearer scheme must still be rejected.
        let auth = format!("Basic {WORKER_TOKEN}");

        let resp = app.oneshot(internal_request(Some(&auth))).await.unwrap();
        assert_invalid_worker_token(resp).await;
    }

    #[tokio::test]
    async fn worker_token_without_scheme() {
        let app = create_router(test_state(), RouterConfig::default());

        let resp = app
            .oneshot(internal_request(Some(WORKER_TOKEN)))
            .await
            .unwrap();
        assert_invalid_worker_token(resp).await;
    }

    #[tokio::test]
    async fn security_headers_present() {
        let app = create_router(test_state(), RouterConfig::default());

        let req = Request::builder()
            .uri("/api/v1/health-check")
            .body(Body::empty())
            .unwrap();

        let resp = app.oneshot(req).await.unwrap();
        let headers = resp.headers();

        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(headers.get("x-frame-options").unwrap(), "DENY");
        assert_eq!(headers.get("x-xss-protection").unwrap(), "1; mode=block");
        assert_eq!(
            headers.get("strict-transport-security").unwrap(),
            "max-age=63072000; includeSubDomains"
        );
        // Pin the whole policy so a regression in any directive is caught.
        assert_eq!(
            headers.get("content-security-policy").unwrap(),
            EXPECTED_CSP
        );
    }
}