durable_streams_server/
router.rs1use crate::config::Config;
6use crate::middleware::proxy_trust::ProxyTrustState;
7use crate::protocol::stream_name::StreamNameLimits;
8use crate::{handlers, middleware, storage::Storage};
9use axum::http::HeaderValue;
10use axum::{Extension, Router, middleware as axum_middleware, routing::get};
11use std::sync::Arc;
12use std::sync::atomic::AtomicBool;
13use tokio_util::sync::CancellationToken;
14use tower_http::cors::{AllowOrigin, CorsLayer};
15
16#[derive(Clone)]
21pub struct ShutdownToken(pub CancellationToken);
22
23#[derive(Clone)]
28pub(crate) struct ReadStreamConfig {
29 pub(crate) long_poll_timeout: std::time::Duration,
30 pub(crate) sse_reconnect_interval_secs: u64,
31 pub(crate) shutdown: CancellationToken,
32}
33
/// Default URL prefix under which the stream protocol routes are mounted.
pub const DEFAULT_STREAM_BASE_PATH: &str = "/v1/stream";

/// Configured stream base path, shared with handlers as an `Extension`.
///
/// Stored as `Arc<str>` so cloning the extension per request is a refcount bump.
#[derive(Clone, Debug)]
pub(crate) struct StreamBasePath(pub Arc<str>);
40
41pub fn build_router<S: Storage + 'static>(storage: Arc<S>, config: &Config) -> Router {
51 build_router_with_ready(storage, config, None, CancellationToken::new())
52}
53
54pub fn build_router_with_ready<S: Storage + 'static>(
63 storage: Arc<S>,
64 config: &Config,
65 ready: Option<Arc<AtomicBool>>,
66 shutdown: CancellationToken,
67) -> Router {
68 let stream_base_path = Arc::<str>::from(config.http.stream_base_path.as_str());
69 let mut app = Router::new()
70 .route("/healthz", get(handlers::health::health_check))
71 .nest(
72 stream_base_path.as_ref(),
73 protocol_routes(storage, config, shutdown, Arc::clone(&stream_base_path)),
74 );
75
76 if let Some(flag) = ready {
77 app = app
78 .route("/readyz", get(handlers::health::readiness_check))
79 .layer(Extension(flag));
80 }
81
82 let proxy_trust_state = Arc::new(ProxyTrustState::from_config(config));
83
84 app.layer(axum_middleware::from_fn(
85 middleware::telemetry::track_requests,
86 ))
87 .layer(cors_layer(&config.http.cors_origins))
88 .layer(axum_middleware::from_fn(move |request, next| {
89 middleware::proxy_trust::enforce_proxy_trust(proxy_trust_state.clone(), request, next)
90 }))
91}
92
93fn cors_layer(origins: &str) -> CorsLayer {
98 let allow_origin = if origins == "*" {
99 AllowOrigin::any()
100 } else {
101 let values: Vec<HeaderValue> = origins
102 .split(',')
103 .filter_map(|s| s.trim().parse().ok())
104 .collect();
105 AllowOrigin::list(values)
106 };
107
108 CorsLayer::new()
109 .allow_origin(allow_origin)
110 .allow_methods(tower_http::cors::Any)
111 .allow_headers(tower_http::cors::Any)
112 .expose_headers(tower_http::cors::Any)
113}
114
115fn protocol_routes<S: Storage + 'static>(
119 storage: Arc<S>,
120 config: &Config,
121 shutdown: CancellationToken,
122 stream_base_path: Arc<str>,
123) -> Router {
124 Router::new()
125 .route(
126 "/{*name}",
127 get(handlers::get::read_stream::<S>)
128 .put(handlers::put::create_stream::<S>)
129 .head(handlers::head::stream_metadata::<S>)
130 .post(handlers::post::append_data::<S>)
131 .delete(handlers::delete::delete_stream::<S>),
132 )
133 .layer(Extension(StreamNameLimits {
134 max_bytes: config.limits.max_stream_name_bytes,
135 max_segments: config.limits.max_stream_name_segments,
136 }))
137 .layer(Extension(ReadStreamConfig {
138 long_poll_timeout: config.long_poll_timeout(),
139 sse_reconnect_interval_secs: config.transport.connection.sse_reconnect_interval_secs,
140 shutdown,
141 }))
142 .layer(Extension(StreamBasePath(stream_base_path)))
143 .layer(axum_middleware::from_fn(
144 middleware::security::add_security_headers,
145 ))
146 .with_state(storage)
147}