durable_streams_server/
router.rs1use crate::config::{Config, LongPollTimeout, SseReconnectInterval};
6use crate::protocol::stream_name::StreamNameLimits;
7use crate::{handlers, middleware, storage::Storage};
8use axum::http::HeaderValue;
9use axum::{Extension, Router, middleware as axum_middleware, routing::get};
10use std::sync::Arc;
11use std::sync::atomic::AtomicBool;
12use tokio_util::sync::CancellationToken;
13use tower_http::cors::{AllowOrigin, CorsLayer};
14
15#[derive(Clone)]
20pub struct ShutdownToken(pub CancellationToken);
21
/// Default URL prefix for the stream protocol routes.
///
/// The router itself mounts the protocol routes at whatever `Config::stream_base_path`
/// holds. NOTE(review): presumably this constant is that setting's default — confirm in
/// the config module.
pub const DEFAULT_STREAM_BASE_PATH: &str = "/v1/stream";
24
/// The configured prefix under which the stream protocol routes are nested.
///
/// Installed as a request extension on the protocol routes so handlers can see the mount
/// point. Stored as `Arc<str>` so cloning per request is a refcount bump, not a copy.
#[derive(Clone, Debug)]
pub(crate) struct StreamBasePath(pub Arc<str>);
28
29pub fn build_router<S: Storage + 'static>(storage: Arc<S>, config: &Config) -> Router {
39 build_router_with_ready(storage, config, None, CancellationToken::new())
40}
41
42pub fn build_router_with_ready<S: Storage + 'static>(
51 storage: Arc<S>,
52 config: &Config,
53 ready: Option<Arc<AtomicBool>>,
54 shutdown: CancellationToken,
55) -> Router {
56 let stream_base_path = Arc::<str>::from(config.stream_base_path.as_str());
57 let mut app = Router::new()
58 .route("/healthz", get(handlers::health::health_check))
59 .nest(
60 stream_base_path.as_ref(),
61 protocol_routes(storage, config, shutdown, Arc::clone(&stream_base_path)),
62 );
63
64 if let Some(flag) = ready {
65 app = app
66 .route("/readyz", get(handlers::health::readiness_check))
67 .layer(Extension(flag));
68 }
69
70 app.layer(axum_middleware::from_fn(
71 middleware::telemetry::track_requests,
72 ))
73 .layer(cors_layer(&config.cors_origins))
74}
75
76fn cors_layer(origins: &str) -> CorsLayer {
81 let allow_origin = if origins == "*" {
82 AllowOrigin::any()
83 } else {
84 let values: Vec<HeaderValue> = origins
85 .split(',')
86 .filter_map(|s| s.trim().parse().ok())
87 .collect();
88 AllowOrigin::list(values)
89 };
90
91 CorsLayer::new()
92 .allow_origin(allow_origin)
93 .allow_methods(tower_http::cors::Any)
94 .allow_headers(tower_http::cors::Any)
95 .expose_headers(tower_http::cors::Any)
96}
97
98fn protocol_routes<S: Storage + 'static>(
102 storage: Arc<S>,
103 config: &Config,
104 shutdown: CancellationToken,
105 stream_base_path: Arc<str>,
106) -> Router {
107 Router::new()
108 .route(
109 "/{*name}",
110 get(handlers::get::read_stream::<S>)
111 .put(handlers::put::create_stream::<S>)
112 .head(handlers::head::stream_metadata::<S>)
113 .post(handlers::post::append_data::<S>)
114 .delete(handlers::delete::delete_stream::<S>),
115 )
116 .layer(Extension(StreamNameLimits {
117 max_bytes: config.max_stream_name_bytes,
118 max_segments: config.max_stream_name_segments,
119 }))
120 .layer(Extension(ShutdownToken(shutdown)))
121 .layer(Extension(SseReconnectInterval(
122 config.sse_reconnect_interval_secs,
123 )))
124 .layer(Extension(LongPollTimeout(config.long_poll_timeout)))
125 .layer(Extension(StreamBasePath(stream_base_path)))
126 .layer(axum_middleware::from_fn(
127 middleware::security::add_security_headers,
128 ))
129 .with_state(storage)
130}