1pub mod api_keys;
4pub mod approve_run;
5pub mod audit_logs;
6pub mod auth;
7pub mod cancel_run;
8pub mod create_run;
9pub mod events;
10pub mod get_run;
11pub mod get_stats;
12pub mod get_workflow;
13pub mod health_check;
14mod internal;
15pub mod list_runs;
16pub mod list_workflows;
17#[cfg(feature = "prometheus")]
18pub mod metrics;
19pub mod openapi_spec;
20pub mod retry_run;
21pub mod secrets;
22pub mod users;
23
24use std::path::PathBuf;
25
26use axum::Extension;
27use axum::Router;
28use axum::middleware as axum_mw;
29use axum::routing::{delete, get, patch, post, put};
30use tower_http::limit::RequestBodyLimitLayer;
31use tower_http::services::{ServeDir, ServeFile};
32
33use crate::middleware::{WorkerToken, security_headers, worker_token_auth};
34use crate::rate_limit::{per_minute, rate_limit};
35use crate::state::AppState;
36
/// Largest request body accepted by any route, in bytes (2 MiB = 2 × 2^20).
///
/// Enforced globally by a `RequestBodyLimitLayer` in `create_router`.
const MAX_BODY_SIZE: usize = 2 << 20;
39
/// Tunables consumed by [`create_router`].
///
/// [`RouterConfig::default`] disables the on-disk dashboard and enables both
/// rate limiters with their default budgets.
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Directory holding static dashboard assets.
    ///
    /// When `Some`, requests that match no API route are served from this
    /// directory, falling back to its `index.html` for unknown paths. When
    /// `None`, the embedded dashboard is used if the `dashboard` feature is
    /// enabled; otherwise no dashboard fallback is installed.
    pub dashboard_dir: Option<PathBuf>,
    /// Per-minute limit (passed to `per_minute`) applied to the credential
    /// endpoints `/api/v1/auth/sign-up` and `/api/v1/auth/sign-in`.
    /// `None` disables this limiter.
    pub rate_limit_auth: Option<u32>,
    /// Per-minute limit (passed to `per_minute`) applied to every `/api/v1`
    /// route, including `/auth`, but not to `/api/v1/internal`.
    /// `None` disables this limiter.
    pub rate_limit_general: Option<u32>,
}

impl Default for RouterConfig {
    /// No on-disk dashboard; 10/min on credential endpoints, 60/min on the API.
    fn default() -> Self {
        const AUTH_PER_MINUTE: u32 = 10;
        const GENERAL_PER_MINUTE: u32 = 60;

        RouterConfig {
            rate_limit_general: Some(GENERAL_PER_MINUTE),
            rate_limit_auth: Some(AUTH_PER_MINUTE),
            dashboard_dir: None,
        }
    }
}
83
84#[cfg(not(feature = "sign-up"))]
86async fn sign_up_disabled() -> impl axum::response::IntoResponse {
87 crate::error::ApiError::BadRequest("sign-up is disabled".to_string())
88}
89
90pub fn create_router(state: AppState, config: RouterConfig) -> Router {
122 let internal_routes = Router::new()
124 .route("/runs", post(internal::create_run::create_run))
125 .route("/runs/next", get(internal::pick_next_run::pick_next_run))
126 .route(
127 "/runs/{id}",
128 get(internal::get_run::get_run).put(internal::update_run::update_run),
129 )
130 .route(
131 "/runs/{id}/status",
132 put(internal::update_run_status::update_run_status),
133 )
134 .route("/runs/{id}/logs", post(internal::push_logs::push_logs))
135 .route("/steps", post(internal::create_step::create_step))
136 .route("/steps/{id}", put(internal::update_step::update_step))
137 .route(
138 "/step-dependencies",
139 post(internal::create_step_dependencies::create_step_dependencies),
140 )
141 .route("/secrets/{*key}", get(internal::get_secret::get_secret))
142 .layer(axum_mw::from_fn(worker_token_auth))
143 .layer(Extension(WorkerToken(state.worker_token.clone())))
144 .with_state(state.clone());
145
146 #[allow(unused_mut)]
148 let mut auth_credential_routes = Router::new();
149
150 #[cfg(feature = "sign-up")]
151 {
152 auth_credential_routes =
153 auth_credential_routes.route("/sign-up", post(auth::sign_up::sign_up));
154 }
155
156 #[cfg(not(feature = "sign-up"))]
157 {
158 auth_credential_routes = auth_credential_routes.route("/sign-up", post(sign_up_disabled));
159 }
160
161 let mut auth_credential_routes =
162 auth_credential_routes.route("/sign-in", post(auth::sign_in::sign_in));
163
164 if let Some(rpm) = config.rate_limit_auth {
165 auth_credential_routes = auth_credential_routes
166 .layer(axum_mw::from_fn(rate_limit))
167 .layer(Extension(per_minute(rpm)));
168 }
169
170 let auth_session_routes = Router::new()
172 .route("/refresh", post(auth::refresh::refresh))
173 .route("/sign-out", post(auth::sign_out::sign_out))
174 .route("/me", get(auth::me::me));
175
176 #[allow(unused_mut)]
178 let mut api_v1 = Router::new()
179 .route("/health-check", get(health_check::health_check))
180 .route("/openapi.json", get(openapi_spec::openapi_spec))
181 .route(
182 "/runs",
183 get(list_runs::list_runs).post(create_run::create_run),
184 )
185 .route("/runs/{id}", get(get_run::get_run))
186 .route("/runs/{id}/cancel", post(cancel_run::cancel_run))
187 .route("/runs/{id}/approve", post(approve_run::approve_run))
188 .route("/runs/{id}/reject", post(approve_run::reject_run))
189 .route("/runs/{id}/retry", post(retry_run::retry_run))
190 .route("/workflows", get(list_workflows::list_workflows))
191 .route("/workflows/{name}", get(get_workflow::get_workflow))
192 .route("/stats", get(get_stats::get_stats))
193 .route("/audit-logs", get(audit_logs::list_audit_logs))
194 .route("/events", get(events::events))
195 .route(
196 "/api-keys",
197 get(api_keys::list::list_api_keys).post(api_keys::create::create_api_key),
198 )
199 .route(
200 "/api-keys/scopes",
201 get(api_keys::available_scopes::available_scopes),
202 )
203 .route("/api-keys/{id}", delete(api_keys::delete::delete_api_key))
204 .route(
205 "/users",
206 get(users::list::list_users).post(users::create::create_user),
207 )
208 .route("/users/{id}", delete(users::delete::delete_user))
209 .route("/users/{id}/role", patch(users::update_role::update_role))
210 .route(
211 "/secrets",
212 get(secrets::list::list_secrets).post(secrets::create::create_secret),
213 )
214 .route(
215 "/secrets/{*key}",
216 put(secrets::update::update_secret).delete(secrets::delete::delete_secret),
217 );
218
219 #[cfg(feature = "prometheus")]
220 {
221 api_v1 = api_v1.route("/metrics", get(metrics::metrics));
222 }
223
224 let mut api_v1 = api_v1
225 .nest("/auth", auth_credential_routes)
226 .nest("/auth", auth_session_routes);
227
228 if let Some(rpm) = config.rate_limit_general {
229 api_v1 = api_v1
230 .layer(axum_mw::from_fn(rate_limit))
231 .layer(Extension(per_minute(rpm)));
232 }
233
234 let api_v1 = api_v1.with_state(state.clone());
235
236 #[allow(unused_mut)]
237 let mut app = Router::new()
238 .nest("/api/v1/internal", internal_routes)
239 .nest("/api/v1", api_v1)
240 .with_state(state)
241 .layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE))
242 .layer(axum_mw::from_fn(security_headers));
243
244 #[cfg(feature = "prometheus")]
245 {
246 app = app.layer(axum_mw::from_fn(crate::middleware::request_metrics));
247 }
248
249 match config.dashboard_dir {
250 Some(dir) => {
251 let index = dir.join("index.html");
252 let serve = ServeDir::new(dir).fallback(ServeFile::new(index));
253 app.fallback_service(serve)
254 }
255 #[cfg(feature = "dashboard")]
256 None => app.fallback_service(crate::dashboard::EmbeddedDashboard),
257 #[cfg(not(feature = "dashboard"))]
258 None => app,
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use http_body_util::BodyExt;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_engine::notify::Event;
    use ironflow_store::memory::InMemoryStore;
    use std::sync::Arc;
    use tokio::sync::broadcast;
    use tower::ServiceExt;

    /// Builds an `AppState` backed by an in-memory store, a fixed JWT secret
    /// and the worker token `"test-worker-token"`.
    fn test_state() -> AppState {
        let store = Arc::new(InMemoryStore::new());
        let provider = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), provider));

        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: String::from("test-secret"),
            access_token_ttl_secs: 15 * 60,
            refresh_token_ttl_secs: 7 * 24 * 60 * 60,
            cookie_domain: None,
            cookie_secure: false,
        });

        let (event_sender, _) = broadcast::channel::<Event>(1);

        AppState::new(
            store,
            engine,
            jwt_config,
            String::from("test-worker-token"),
            event_sender,
        )
    }

    /// Returns a `Bearer` authorization header value for a fresh non-admin user.
    fn make_auth_header(state: &AppState) -> String {
        use ironflow_auth::jwt::AccessToken;
        use uuid::Uuid;

        let token =
            AccessToken::for_user(Uuid::now_v7(), "testuser", false, &state.jwt_config).unwrap();
        format!("Bearer {}", token.0)
    }

    /// Builds a body-less GET request, optionally carrying an `authorization` header.
    fn get_request(uri: &str, authorization: Option<String>) -> Request<Body> {
        let builder = Request::builder().uri(uri);
        let builder = match authorization {
            Some(value) => builder.header("authorization", value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Sends an authenticated GET to `uri` on a default router and expects 200.
    async fn assert_authed_get_ok(uri: &str) {
        let state = test_state();
        let auth_header = make_auth_header(&state);
        let app = create_router(state, RouterConfig::default());

        let resp = app
            .oneshot(get_request(uri, Some(auth_header)))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn health_check_route() {
        let app = create_router(test_state(), RouterConfig::default());

        let resp = app
            .oneshot(get_request("/api/v1/health-check", None))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"OK");
    }

    #[tokio::test]
    async fn runs_route_exists() {
        assert_authed_get_ok("/api/v1/runs?page=1&per_page=20").await;
    }

    #[tokio::test]
    async fn stats_route_exists() {
        assert_authed_get_ok("/api/v1/stats").await;
    }

    #[tokio::test]
    async fn responses_include_security_headers() {
        let app = create_router(test_state(), RouterConfig::default());

        let resp = app
            .oneshot(get_request("/api/v1/health-check", None))
            .await
            .unwrap();
        let headers = resp.headers();

        assert_eq!(headers.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(headers.get("x-frame-options").unwrap(), "DENY");
        assert_eq!(headers.get("x-xss-protection").unwrap(), "1; mode=block");
        assert_eq!(
            headers.get("strict-transport-security").unwrap(),
            "max-age=63072000; includeSubDomains"
        );

        let csp = headers
            .get("content-security-policy")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(csp.contains("default-src 'self'"));
    }

    #[tokio::test]
    async fn body_size_limit_rejects_oversized_payload() {
        let state = test_state();
        let auth_header = make_auth_header(&state);
        let app = create_router(state, RouterConfig::default());

        // 3 MiB: comfortably above the 2 MiB MAX_BODY_SIZE.
        let oversized = vec![0u8; 3 * 1024 * 1024];
        let req = Request::builder()
            .method("POST")
            .uri("/api/v1/runs")
            .header("content-type", "application/json")
            .header("authorization", auth_header)
            .body(Body::from(oversized))
            .unwrap();

        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}