use axum::Json;
use axum::extract::Request;
use axum::http::header::{
CONTENT_SECURITY_POLICY, STRICT_TRANSPORT_SECURITY, X_CONTENT_TYPE_OPTIONS, X_FRAME_OPTIONS,
X_XSS_PROTECTION,
};
use axum::http::{HeaderValue, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde_json::json;
use subtle::ConstantTimeEq;
/// Middleware that admits a request only when it carries the expected worker token.
///
/// The expected token is read from the request extensions as a [`WorkerToken`];
/// the presented token is taken from the `authorization` header after removing
/// the literal `Bearer ` prefix. The two are compared with `subtle`'s
/// constant-time equality.
///
/// # Returns
///
/// The downstream response when the tokens match. Otherwise a
/// `401 Unauthorized` response with a JSON body whose `error.code` is
/// `INVALID_WORKER_TOKEN`. This covers a missing [`WorkerToken`] extension, a
/// missing header, a header value that is not valid visible ASCII, a value
/// without the `Bearer ` prefix, and a token mismatch.
pub async fn worker_token_auth(req: Request, next: Next) -> Response {
    // Decide while `req` is only borrowed: this avoids copying the secret and
    // the presented token into fresh heap strings on every request, and the
    // borrows end before `req` is moved into `next.run`.
    let authorized = {
        let expected = req
            .extensions()
            .get::<WorkerToken>()
            .map(|token| token.0.as_bytes());
        let provided = req
            .headers()
            .get("authorization")
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.strip_prefix("Bearer "))
            .map(str::as_bytes);

        match (expected, provided) {
            (Some(expected), Some(provided)) => bool::from(expected.ct_eq(provided)),
            // Either side absent: reject without comparing anything.
            _ => false,
        }
    };

    if authorized {
        next.run(req).await
    } else {
        (
            StatusCode::UNAUTHORIZED,
            Json(json!({
                "error": {
                    "code": "INVALID_WORKER_TOKEN",
                    "message": "Invalid or missing worker token",
                }
            })),
        )
            .into_response()
    }
}
/// Newtype wrapping the worker token string.
///
/// `worker_token_auth` looks this type up in the request extensions and
/// compares its inner value against the bearer token presented by the caller.
#[derive(Clone)]
pub struct WorkerToken(pub String);
/// Middleware that records a request counter and a latency histogram.
///
/// Compiled only with the `prometheus` feature. For every request it runs the
/// rest of the stack, then:
///
/// * increments `API_REQUESTS_TOTAL` labelled with `method`, `path` and `status`;
/// * records the elapsed wall-clock seconds in `API_REQUEST_DURATION_SECONDS`
///   labelled with `method` and `path`.
///
/// The downstream response is returned untouched.
#[cfg(feature = "prometheus")]
pub async fn request_metrics(req: Request, next: Next) -> Response {
    use std::time::Instant;

    use ironflow_core::metric_names::{API_REQUEST_DURATION_SECONDS, API_REQUESTS_TOTAL};
    use metrics::{counter, histogram};

    // Labels are captured as owned strings because `req` is consumed below.
    // NOTE(review): the raw URI path is used as a label, so every distinct
    // path presumably becomes its own time series — confirm this is acceptable
    // for routes that embed identifiers.
    let method_label = req.method().to_string();
    let path_label = req.uri().path().to_owned();

    let started_at = Instant::now();
    let response = next.run(req).await;

    let status_label = response.status().as_u16().to_string();
    let elapsed_secs = started_at.elapsed().as_secs_f64();

    counter!(
        API_REQUESTS_TOTAL,
        "method" => method_label.clone(),
        "path" => path_label.clone(),
        "status" => status_label
    )
    .increment(1);
    histogram!(
        API_REQUEST_DURATION_SECONDS,
        "method" => method_label,
        "path" => path_label
    )
    .record(elapsed_secs);

    response
}
/// Middleware that stamps a fixed set of security headers on every response.
///
/// The downstream handler runs first; the headers below are then inserted,
/// replacing any value the handler may have set under the same name:
///
/// * `X-Content-Type-Options: nosniff`
/// * `X-Frame-Options: DENY`
/// * `X-XSS-Protection: 1; mode=block`
/// * `Strict-Transport-Security: max-age=63072000; includeSubDomains`
/// * `Content-Security-Policy` restricting sources to `'self'` (plus inline
///   styles and `data:` images/fonts)
pub async fn security_headers(req: Request, next: Next) -> Response {
    let mut response = next.run(req).await;

    // One row per header, applied in the listed order.
    let hardening = [
        (X_CONTENT_TYPE_OPTIONS, "nosniff"),
        (X_FRAME_OPTIONS, "DENY"),
        (X_XSS_PROTECTION, "1; mode=block"),
        (
            STRICT_TRANSPORT_SECURITY,
            "max-age=63072000; includeSubDomains",
        ),
        (
            CONTENT_SECURITY_POLICY,
            "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self' data:; connect-src 'self'",
        ),
    ];

    let header_map = response.headers_mut();
    for (name, value) in hardening {
        header_map.insert(name, HeaderValue::from_static(value));
    }

    response
}
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use http_body_util::BodyExt;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::api_key_store::ApiKeyStore;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::user_store::UserStore;
    use serde_json::Value as JsonValue;
    use std::sync::Arc;
    use tokio::sync::broadcast;
    use tower::ServiceExt;

    use crate::routes::{RouterConfig, create_router};
    use crate::state::AppState;

    /// Internal route that the worker-token tests send requests to.
    const WORKER_ROUTE: &str = "/api/v1/internal/runs/next";

    /// Builds an `AppState` backed by in-memory stores, with the worker token
    /// set to `test-worker-token`.
    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let user_store: Arc<dyn UserStore> = Arc::new(InMemoryStore::new());
        let api_key_store: Arc<dyn ApiKeyStore> = Arc::new(InMemoryStore::new());

        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));

        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });

        // Only the sending half is handed to the state; the receiver is dropped.
        let (event_sender, _) = broadcast::channel::<Event>(1);

        AppState::new(
            store,
            user_store,
            api_key_store,
            engine,
            jwt_config,
            "test-worker-token".to_string(),
            event_sender,
        )
    }

    /// Sends one request with an empty body to a router built over fresh test
    /// state, optionally attaching an `authorization` header.
    async fn dispatch(uri: &str, authorization: Option<&str>) -> axum::response::Response {
        let mut builder = Request::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header("authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();

        create_router(test_state(), RouterConfig::default())
            .oneshot(request)
            .await
            .unwrap()
    }

    /// Reads the whole response body as JSON and returns its `error.code` value.
    async fn error_code(resp: axum::response::Response) -> JsonValue {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        let parsed: JsonValue = serde_json::from_slice(&bytes).unwrap();
        parsed["error"]["code"].clone()
    }

    #[tokio::test]
    async fn worker_token_valid() {
        let resp = dispatch(WORKER_ROUTE, Some("Bearer test-worker-token")).await;

        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn worker_token_missing() {
        let resp = dispatch(WORKER_ROUTE, None).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(error_code(resp).await, "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn worker_token_invalid() {
        let resp = dispatch(WORKER_ROUTE, Some("Bearer wrong-token")).await;

        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(error_code(resp).await, "INVALID_WORKER_TOKEN");
    }

    #[tokio::test]
    async fn security_headers_present() {
        let resp = dispatch("/api/v1/health-check", None).await;

        // Headers whose value must match exactly.
        let exact = [
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("x-xss-protection", "1; mode=block"),
            (
                "strict-transport-security",
                "max-age=63072000; includeSubDomains",
            ),
        ];
        for (name, expected) in exact {
            assert_eq!(resp.headers().get(name).unwrap(), expected);
        }

        // The CSP is only checked for its leading directive.
        let csp = resp
            .headers()
            .get("content-security-policy")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(csp.contains("default-src 'self'"));
    }
}