ironflow_api/
middleware.rs1use axum::Json;
4use axum::extract::Request;
5use axum::http::header::{
6 CONTENT_SECURITY_POLICY, STRICT_TRANSPORT_SECURITY, X_CONTENT_TYPE_OPTIONS, X_FRAME_OPTIONS,
7 X_XSS_PROTECTION,
8};
9use axum::http::{HeaderValue, StatusCode};
10use axum::middleware::Next;
11use axum::response::{IntoResponse, Response};
12use serde_json::json;
13use subtle::ConstantTimeEq;
14
15pub async fn worker_token_auth(req: Request, next: Next) -> Response {
20 let expected = req.extensions().get::<WorkerToken>().map(|t| t.0.clone());
21
22 let provided = req
23 .headers()
24 .get("authorization")
25 .and_then(|v| v.to_str().ok())
26 .and_then(|v| v.strip_prefix("Bearer "))
27 .map(|t| t.to_string());
28
29 match (expected, provided) {
30 (Some(expected), Some(provided))
31 if expected.as_bytes().ct_eq(provided.as_bytes()).into() =>
32 {
33 next.run(req).await
34 }
35 _ => (
36 StatusCode::UNAUTHORIZED,
37 Json(json!({
38 "error": {
39 "code": "INVALID_WORKER_TOKEN",
40 "message": "Invalid or missing worker token",
41 }
42 })),
43 )
44 .into_response(),
45 }
46}
47
/// Secret bearer token that worker requests must present.
///
/// `worker_token_auth` looks this value up in the request extensions and compares
/// it against the token from the `Authorization` header.
#[derive(Clone)]
pub struct WorkerToken(pub String);

/// `Debug` is implemented by hand so the secret never ends up in logs or panic
/// messages when a value containing a `WorkerToken` is formatted with `{:?}`.
impl std::fmt::Debug for WorkerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WorkerToken(<redacted>)")
    }
}
51
/// Middleware that records per-request metrics around the inner service.
///
/// Once the inner service has produced a response, this increments
/// `API_REQUESTS_TOTAL` (labels: `method`, `path`, `status`) and records the
/// elapsed time in seconds into `API_REQUEST_DURATION_SECONDS` (labels: `method`,
/// `path`). The response itself is returned unchanged.
///
/// NOTE(review): `path` is the raw URI path of the request, so routes with
/// embedded identifiers would produce one label value per identifier — confirm
/// this is acceptable for the metrics backend.
#[cfg(feature = "prometheus")]
pub async fn request_metrics(req: Request, next: Next) -> Response {
    use std::time::Instant;

    use ironflow_core::metric_names::{API_REQUEST_DURATION_SECONDS, API_REQUESTS_TOTAL};
    use metrics::{counter, histogram};

    // Capture the labels up front: `req` is consumed by the inner service.
    let verb = req.method().to_string();
    let route = req.uri().path().to_string();
    let started = Instant::now();

    let response = next.run(req).await;

    let code = response.status().as_u16().to_string();
    let elapsed_secs = started.elapsed().as_secs_f64();

    let requests = counter!(
        API_REQUESTS_TOTAL,
        "method" => verb.clone(),
        "path" => route.clone(),
        "status" => code
    );
    requests.increment(1);

    let latency = histogram!(
        API_REQUEST_DURATION_SECONDS,
        "method" => verb,
        "path" => route
    );
    latency.record(elapsed_secs);

    response
}
77
78pub async fn security_headers(req: Request, next: Next) -> Response {
87 let mut resp = next.run(req).await;
88 let headers = resp.headers_mut();
89
90 headers.insert(X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));
91 headers.insert(X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));
92 headers.insert(X_XSS_PROTECTION, HeaderValue::from_static("1; mode=block"));
93 headers.insert(
94 STRICT_TRANSPORT_SECURITY,
95 HeaderValue::from_static("max-age=63072000; includeSubDomains"),
96 );
97 headers.insert(
98 CONTENT_SECURITY_POLICY,
99 HeaderValue::from_static(
100 "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'",
101 ),
102 );
103
104 resp
105}
106
#[cfg(test)]
mod tests {
    // Router-level tests: each test builds the full application router and sends a
    // single request through it, so the middleware is exercised as it is wired up
    // by `create_router`.

    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use http_body_util::BodyExt;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::api_key_store::ApiKeyStore;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::user_store::UserStore;
    use serde_json::Value as JsonValue;
    use std::sync::Arc;
    use tokio::sync::broadcast;
    use tower::ServiceExt;

    use crate::routes::{RouterConfig, create_router};
    use crate::state::AppState;

    /// Endpoint used by the worker-token tests.
    const WORKER_ENDPOINT: &str = "/api/v1/internal/runs/next";

    /// Builds an `AppState` backed by in-memory stores, with the worker token set
    /// to `"test-worker-token"`.
    fn test_state() -> AppState {
        let run_store = Arc::new(InMemoryStore::new());
        let users: Arc<dyn UserStore> = Arc::new(InMemoryStore::new());
        let api_keys: Arc<dyn ApiKeyStore> = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(run_store.clone(), provider));
        let jwt = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        let (events, _) = broadcast::channel::<Event>(1);

        AppState::new(
            run_store,
            users,
            api_keys,
            engine,
            jwt,
            "test-worker-token".to_string(),
            events,
        )
    }

    /// Builds a body-less request for `uri`, optionally with an `authorization`
    /// header set to `authorization`.
    fn empty_request(uri: &str, authorization: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri(uri);
        let builder = match authorization {
            Some(value) => builder.header("authorization", value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Parses `body` as JSON and asserts that `error.code` equals `expected`.
    fn assert_error_code(body: &[u8], expected: &str) {
        let parsed: JsonValue = serde_json::from_slice(body).unwrap();
        assert_eq!(parsed["error"]["code"], expected);
    }

    #[tokio::test]
    async fn worker_token_valid() {
        let router = create_router(test_state(), RouterConfig::default());
        let request = empty_request(WORKER_ENDPOINT, Some("Bearer test-worker-token"));

        let response = router.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn worker_token_missing() {
        let router = create_router(test_state(), RouterConfig::default());
        let request = empty_request(WORKER_ENDPOINT, None);

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        assert_error_code(&bytes, "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn worker_token_invalid() {
        let router = create_router(test_state(), RouterConfig::default());
        let request = empty_request(WORKER_ENDPOINT, Some("Bearer wrong-token"));

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        assert_error_code(&bytes, "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn security_headers_present() {
        let router = create_router(test_state(), RouterConfig::default());
        let request = empty_request("/api/v1/health-check", None);

        let response = router.oneshot(request).await.unwrap();
        let headers = response.headers();

        // Headers whose value must match exactly.
        let exact = [
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("x-xss-protection", "1; mode=block"),
            (
                "strict-transport-security",
                "max-age=63072000; includeSubDomains",
            ),
        ];
        for (name, expected) in exact {
            assert_eq!(headers.get(name).unwrap(), expected);
        }

        // The CSP is only checked for its leading directive.
        let csp = headers
            .get("content-security-policy")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(csp.contains("default-src 'self'"));
    }
}