Skip to main content

durable_streams_server/
router.rs

1//! Axum router construction for the Durable Streams HTTP surface.
2//!
3//! [`build_router`] is the main embedding entry point for library consumers.
4
5use crate::config::Config;
6use crate::middleware::proxy_trust::ProxyTrustState;
7use crate::protocol::stream_name::StreamNameLimits;
8use crate::{handlers, middleware, storage::Storage};
9use axum::http::HeaderValue;
10use axum::{Extension, Router, middleware as axum_middleware, routing::get};
11use std::sync::Arc;
12use std::sync::atomic::AtomicBool;
13use tokio_util::sync::CancellationToken;
14use tower_http::cors::{AllowOrigin, CorsLayer};
15
16/// Wrapper around [`CancellationToken`] for axum `Extension` extraction.
17///
18/// Long-poll and SSE handlers observe this token so they can drain cleanly
19/// when the server begins a graceful shutdown.
20#[derive(Clone)]
21pub struct ShutdownToken(pub CancellationToken);
22
23/// Combined read-stream configuration extracted as a single axum `Extension`.
24///
25/// Groups the long-poll timeout, SSE reconnect interval, and shutdown token
26/// so handlers that need all three only consume one extractor slot.
27#[derive(Clone)]
28pub(crate) struct ReadStreamConfig {
29    pub(crate) long_poll_timeout: std::time::Duration,
30    pub(crate) sse_reconnect_interval_secs: u64,
31    pub(crate) shutdown: CancellationToken,
32}
33
/// Mount path for the Durable Streams protocol routes used when
/// `config.http.stream_base_path` is left at its default.
pub const DEFAULT_STREAM_BASE_PATH: &str = "/v1/stream";
36
/// Configured mount path of the protocol routes, shared with handlers via an
/// axum `Extension`.
///
/// Holds the same string that the router passes to `Router::nest`, so handlers
/// can tell where the protocol routes are mounted.
#[derive(Clone, Debug)]
pub(crate) struct StreamBasePath(pub Arc<str>);
40
41/// Build the application router with storage state.
42///
43/// Routes:
44/// - `GET /healthz`        – Liveness probe (always 200)
45/// - `GET/PUT/... <path>`  – Protocol routes mounted at the configured
46///   `config.stream_base_path` (default [`DEFAULT_STREAM_BASE_PATH`])
47///
48/// Uses a no-op cancellation token (never cancelled). For production
49/// use with graceful shutdown, prefer [`build_router_with_ready`].
50pub fn build_router<S: Storage + 'static>(storage: Arc<S>, config: &Config) -> Router {
51    build_router_with_ready(storage, config, None, CancellationToken::new())
52}
53
54/// Build the router with readiness flag and shutdown token.
55///
56/// When `ready` is `Some`, the `/readyz` endpoint is registered and returns
57/// 200 only after the flag is set to `true`. When `None`, the endpoint is
58/// not registered (backwards-compatible).
59///
60/// The `shutdown` token is propagated to long-poll and SSE handlers so they
61/// can observe server shutdown and drain in-flight connections cleanly.
62pub fn build_router_with_ready<S: Storage + 'static>(
63    storage: Arc<S>,
64    config: &Config,
65    ready: Option<Arc<AtomicBool>>,
66    shutdown: CancellationToken,
67) -> Router {
68    let stream_base_path = Arc::<str>::from(config.http.stream_base_path.as_str());
69    let mut app = Router::new()
70        .route("/healthz", get(handlers::health::health_check))
71        .nest(
72            stream_base_path.as_ref(),
73            protocol_routes(storage, config, shutdown, Arc::clone(&stream_base_path)),
74        );
75
76    if let Some(flag) = ready {
77        app = app
78            .route("/readyz", get(handlers::health::readiness_check))
79            .layer(Extension(flag));
80    }
81
82    let proxy_trust_state = Arc::new(ProxyTrustState::from_config(config));
83
84    app.layer(axum_middleware::from_fn(
85        middleware::telemetry::track_requests,
86    ))
87    .layer(cors_layer(&config.http.cors_origins))
88    .layer(axum_middleware::from_fn(move |request, next| {
89        middleware::proxy_trust::enforce_proxy_trust(proxy_trust_state.clone(), request, next)
90    }))
91}
92
93/// Build a CORS layer from the configured origins string.
94///
95/// Accepts `"*"` for permissive (any origin) or a comma-separated list of
96/// allowed origins (e.g. `"http://localhost:3000,https://app.example.com"`).
97fn cors_layer(origins: &str) -> CorsLayer {
98    let allow_origin = if origins == "*" {
99        AllowOrigin::any()
100    } else {
101        let values: Vec<HeaderValue> = origins
102            .split(',')
103            .filter_map(|s| s.trim().parse().ok())
104            .collect();
105        AllowOrigin::list(values)
106    };
107
108    CorsLayer::new()
109        .allow_origin(allow_origin)
110        .allow_methods(tower_http::cors::Any)
111        .allow_headers(tower_http::cors::Any)
112        .expose_headers(tower_http::cors::Any)
113}
114
115/// Protocol routes under /v1/stream
116///
117/// All protocol routes have security headers applied via middleware.
118fn protocol_routes<S: Storage + 'static>(
119    storage: Arc<S>,
120    config: &Config,
121    shutdown: CancellationToken,
122    stream_base_path: Arc<str>,
123) -> Router {
124    Router::new()
125        .route(
126            "/{*name}",
127            get(handlers::get::read_stream::<S>)
128                .put(handlers::put::create_stream::<S>)
129                .head(handlers::head::stream_metadata::<S>)
130                .post(handlers::post::append_data::<S>)
131                .delete(handlers::delete::delete_stream::<S>),
132        )
133        .layer(Extension(StreamNameLimits {
134            max_bytes: config.limits.max_stream_name_bytes,
135            max_segments: config.limits.max_stream_name_segments,
136        }))
137        .layer(Extension(ReadStreamConfig {
138            long_poll_timeout: config.long_poll_timeout(),
139            sse_reconnect_interval_secs: config.transport.connection.sse_reconnect_interval_secs,
140            shutdown,
141        }))
142        .layer(Extension(StreamBasePath(stream_base_path)))
143        .layer(axum_middleware::from_fn(
144            middleware::security::add_security_headers,
145        ))
146        .with_state(storage)
147}