Skip to main content

ironflow_api/routes/
mod.rs

1//! Router assembly — one module per route.
2
3mod approve_run;
4mod auth;
5mod cancel_run;
6mod create_run;
7mod get_run;
8mod get_stats;
9mod get_workflow;
10mod health_check;
11mod internal;
12mod list_runs;
13mod list_workflows;
14#[cfg(feature = "prometheus")]
15pub mod metrics;
16mod retry_run;
17
18use std::path::PathBuf;
19
20use axum::Extension;
21use axum::Router;
22use axum::middleware as axum_mw;
23use axum::routing::{get, post, put};
24use tower_http::limit::RequestBodyLimitLayer;
25use tower_http::services::{ServeDir, ServeFile};
26
27use crate::middleware::{WorkerToken, security_headers, worker_token_auth};
28use crate::rate_limit::{per_minute, rate_limit};
29use crate::state::AppState;
30
/// Upper bound on accepted request bodies: 2 MiB (2 × 2^20 bytes).
const MAX_BODY_SIZE: usize = 2 << 20;
33
/// Router-level configuration with sensible defaults.
///
/// Controls dashboard serving and rate limiting for [`create_router`].
/// Use [`Default::default()`] for production-ready defaults (auth credential
/// limiter at 10 requests/minute, general limiter at 60 requests/minute, no
/// custom dashboard directory), then override individual fields as needed.
///
/// # Examples
///
/// ```
/// use std::path::PathBuf;
/// use ironflow_api::routes::RouterConfig;
///
/// // All defaults: rate limiting enabled, no custom dashboard dir
/// let config = RouterConfig::default();
/// assert_eq!(config.rate_limit_auth, Some(10));
/// assert_eq!(config.rate_limit_general, Some(60));
///
/// // Disable auth rate limiting and serve the dashboard from disk
/// let config = RouterConfig {
///     dashboard_dir: Some(PathBuf::from("./dashboard/dist")),
///     rate_limit_auth: None,
///     ..RouterConfig::default()
/// };
/// assert_eq!(config.rate_limit_general, Some(60));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RouterConfig {
    /// Filesystem path to dashboard assets. When set, requests that match no
    /// API route are served from this directory, falling back to its
    /// `index.html`, instead of the embedded build.
    pub dashboard_dir: Option<PathBuf>,
    /// Rate limit for auth credential routes (sign-in, sign-up) in
    /// requests per minute per IP. `None` disables this limiter. These
    /// routes sit under the general API router, so `rate_limit_general`
    /// (when set) applies to them as well.
    pub rate_limit_auth: Option<u32>,
    /// Rate limit for general public API routes in requests per minute
    /// per IP. `None` disables the limiter.
    pub rate_limit_general: Option<u32>,
}

impl Default for RouterConfig {
    /// Rate limiting enabled (10 req/min for auth credential routes,
    /// 60 req/min for the general API) and no custom dashboard directory.
    fn default() -> Self {
        Self {
            dashboard_dir: None,
            rate_limit_auth: Some(10),
            rate_limit_general: Some(60),
        }
    }
}
77
78/// Create the main application router.
79///
80/// # Examples
81///
82/// ```no_run
83/// use ironflow_api::routes::{RouterConfig, create_router};
84/// use ironflow_api::state::AppState;
85/// use ironflow_auth::jwt::JwtConfig;
86/// use ironflow_store::prelude::*;
87/// use ironflow_engine::engine::Engine;
88/// use ironflow_core::providers::claude::ClaudeCodeProvider;
89/// use std::sync::Arc;
90///
91/// # async fn example() {
92/// let store = Arc::new(InMemoryStore::new());
93/// let user_store: Arc<dyn UserStore> = Arc::new(InMemoryStore::new());
94/// let provider = Arc::new(ClaudeCodeProvider::new());
95/// let engine = Arc::new(Engine::new(store.clone(), provider));
96/// let jwt_config = Arc::new(JwtConfig {
97///     secret: "secret".to_string(),
98///     access_token_ttl_secs: 900,
99///     refresh_token_ttl_secs: 604800,
100///     cookie_domain: None,
101///     cookie_secure: false,
102/// });
103/// let state = AppState::new(store, user_store, engine, jwt_config, "token".to_string());
104/// let router = create_router(state, RouterConfig::default());
105/// # }
106/// ```
107pub fn create_router(state: AppState, config: RouterConfig) -> Router {
108    // Internal routes (worker-to-API, protected by WORKER_TOKEN)
109    let internal_routes = Router::new()
110        .route("/runs", post(internal::create_run::create_run))
111        .route("/runs/next", get(internal::pick_next_run::pick_next_run))
112        .route(
113            "/runs/{id}",
114            get(internal::get_run::get_run).put(internal::update_run::update_run),
115        )
116        .route(
117            "/runs/{id}/status",
118            put(internal::update_run_status::update_run_status),
119        )
120        .route("/steps", post(internal::create_step::create_step))
121        .route("/steps/{id}", put(internal::update_step::update_step))
122        .route(
123            "/step-dependencies",
124            post(internal::create_step_dependencies::create_step_dependencies),
125        )
126        .layer(axum_mw::from_fn(worker_token_auth))
127        .layer(Extension(WorkerToken(state.worker_token.clone())))
128        .with_state(state.clone());
129
130    // Auth credential routes (rate-limited when configured)
131    #[allow(unused_mut)]
132    let mut auth_credential_routes = Router::new();
133
134    #[cfg(feature = "sign-up")]
135    {
136        auth_credential_routes =
137            auth_credential_routes.route("/sign-up", post(auth::sign_up::sign_up));
138    }
139
140    let mut auth_credential_routes =
141        auth_credential_routes.route("/sign-in", post(auth::sign_in::sign_in));
142
143    if let Some(rpm) = config.rate_limit_auth {
144        auth_credential_routes = auth_credential_routes
145            .layer(axum_mw::from_fn(rate_limit))
146            .layer(Extension(per_minute(rpm)));
147    }
148
149    // Auth session routes (no strict rate limiting, covered by general limiter)
150    let auth_session_routes = Router::new()
151        .route("/refresh", post(auth::refresh::refresh))
152        .route("/sign-out", post(auth::sign_out::sign_out))
153        .route("/me", get(auth::me::me));
154
155    // Public + user-authenticated routes (rate-limited when configured)
156    #[allow(unused_mut)]
157    let mut api_v1 = Router::new()
158        .route("/health-check", get(health_check::health_check))
159        .route(
160            "/runs",
161            get(list_runs::list_runs).post(create_run::create_run),
162        )
163        .route("/runs/{id}", get(get_run::get_run))
164        .route("/runs/{id}/cancel", post(cancel_run::cancel_run))
165        .route("/runs/{id}/approve", post(approve_run::approve_run))
166        .route("/runs/{id}/reject", post(approve_run::reject_run))
167        .route("/runs/{id}/retry", post(retry_run::retry_run))
168        .route("/workflows", get(list_workflows::list_workflows))
169        .route("/workflows/{name}", get(get_workflow::get_workflow))
170        .route("/stats", get(get_stats::get_stats));
171
172    #[cfg(feature = "prometheus")]
173    {
174        api_v1 = api_v1.route("/metrics", get(metrics::metrics));
175    }
176
177    let mut api_v1 = api_v1
178        .nest("/auth", auth_credential_routes)
179        .nest("/auth", auth_session_routes);
180
181    if let Some(rpm) = config.rate_limit_general {
182        api_v1 = api_v1
183            .layer(axum_mw::from_fn(rate_limit))
184            .layer(Extension(per_minute(rpm)));
185    }
186
187    let api_v1 = api_v1.with_state(state.clone());
188
189    #[allow(unused_mut)]
190    let mut app = Router::new()
191        .nest("/api/v1/internal", internal_routes)
192        .nest("/api/v1", api_v1)
193        .with_state(state)
194        .layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE))
195        .layer(axum_mw::from_fn(security_headers));
196
197    #[cfg(feature = "prometheus")]
198    {
199        app = app.layer(axum_mw::from_fn(crate::middleware::request_metrics));
200    }
201
202    match config.dashboard_dir {
203        Some(dir) => {
204            let index = dir.join("index.html");
205            let serve = ServeDir::new(dir).fallback(ServeFile::new(index));
206            app.fallback_service(serve)
207        }
208        #[cfg(feature = "dashboard")]
209        None => app.fallback_service(crate::dashboard::EmbeddedDashboard),
210        #[cfg(not(feature = "dashboard"))]
211        None => app,
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use http_body_util::BodyExt;
    use ironflow_core::providers::claude::ClaudeCodeProvider;
    use ironflow_engine::engine::Engine;
    use ironflow_store::memory::InMemoryStore;
    use ironflow_store::user_store::UserStore;
    use std::sync::Arc;
    use tower::ServiceExt;

    /// Worker token handed to every test `AppState`.
    const TEST_WORKER_TOKEN: &str = "test-worker-token";

    /// Builds an `AppState` backed by in-memory stores and a fixed test JWT secret.
    fn test_state() -> AppState {
        let jwt_config = Arc::new(ironflow_auth::jwt::JwtConfig {
            secret: "test-secret".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        });
        let store = Arc::new(InMemoryStore::new());
        let user_store: Arc<dyn UserStore> = Arc::new(InMemoryStore::new());
        let llm = Arc::new(ClaudeCodeProvider::new());
        let engine = Arc::new(Engine::new(store.clone(), llm));

        AppState::new(
            store,
            user_store,
            engine,
            jwt_config,
            TEST_WORKER_TOKEN.to_string(),
        )
    }

    /// Mints a `Bearer` authorization header value for a freshly generated
    /// user id, signed with the state's JWT config.
    fn make_auth_header(state: &AppState) -> String {
        use ironflow_auth::jwt::AccessToken;
        use uuid::Uuid;

        let token =
            AccessToken::for_user(Uuid::now_v7(), "testuser", false, &state.jwt_config).unwrap();
        format!("Bearer {}", token.0)
    }

    /// Builds a request with an empty body, optionally carrying an
    /// `authorization` header.
    fn empty_request(method: &str, uri: &str, authorization: Option<String>) -> Request<Body> {
        let builder = Request::builder().method(method).uri(uri);
        let builder = match authorization {
            Some(value) => builder.header("authorization", value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Sends an authenticated GET to a default-configured router and returns
    /// the response status.
    async fn authenticated_get_status(uri: &str) -> StatusCode {
        let state = test_state();
        let authorization = make_auth_header(&state);
        let app = create_router(state, RouterConfig::default());

        app.oneshot(empty_request("GET", uri, Some(authorization)))
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn health_check_route() {
        let app = create_router(test_state(), RouterConfig::default());

        let resp = app
            .oneshot(empty_request("GET", "/api/v1/health-check", None))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&bytes[..], b"OK");
    }

    #[tokio::test]
    async fn runs_route_exists() {
        let status = authenticated_get_status("/api/v1/runs?page=1&per_page=20").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn stats_route_exists() {
        let status = authenticated_get_status("/api/v1/stats").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn responses_include_security_headers() {
        let app = create_router(test_state(), RouterConfig::default());

        let resp = app
            .oneshot(empty_request("GET", "/api/v1/health-check", None))
            .await
            .unwrap();
        let headers = resp.headers();

        let expected = [
            ("x-content-type-options", "nosniff"),
            ("x-frame-options", "DENY"),
            ("x-xss-protection", "1; mode=block"),
            (
                "strict-transport-security",
                "max-age=63072000; includeSubDomains",
            ),
        ];
        for (name, value) in expected {
            assert_eq!(headers.get(name).unwrap(), value, "unexpected `{name}` header");
        }

        let csp = headers
            .get("content-security-policy")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(csp.contains("default-src 'self'"));
    }

    #[tokio::test]
    async fn body_size_limit_rejects_oversized_payload() {
        let state = test_state();
        let authorization = make_auth_header(&state);
        let app = create_router(state, RouterConfig::default());

        // 3 MiB payload: one MiB past the 2 MiB limit.
        let payload = vec![0u8; MAX_BODY_SIZE + (1 << 20)];

        let req = Request::builder()
            .method("POST")
            .uri("/api/v1/runs")
            .header("content-type", "application/json")
            .header("authorization", authorization)
            .body(Body::from(payload))
            .unwrap();

        let status = app.oneshot(req).await.unwrap().status();
        assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
    }
}