wsi_streamer/server/
routes.rs1use std::time::Duration;
37
38use axum::{middleware, routing::get, Router};
39use http::header::{AUTHORIZATION, CONTENT_TYPE};
40use http::Method;
41use tower_http::cors::{Any, CorsLayer};
42use tower_http::trace::TraceLayer;
43
44use super::auth::SignedUrlAuth;
45use super::handlers::{
46 dzi_descriptor_handler, health_handler, slide_metadata_handler, slides_handler,
47 thumbnail_handler, tile_handler, viewer_handler, AppState,
48};
49use crate::slide::SlideSource;
50use crate::tile::TileService;
51
/// Settings for [`create_router`].
///
/// Start from [`RouterConfig::new`] (signed-URL auth on) or
/// [`RouterConfig::without_auth`] (auth off) and adjust with the `with_*` builders.
///
/// `Debug` is implemented by hand so the secret is never printed.
#[derive(Clone)]
pub struct RouterConfig {
    /// Secret handed to `SignedUrlAuth::new` when `auth_enabled` is true.
    pub auth_secret: String,

    /// When true, the `/tiles` and `/slides` routes go through the signed-URL auth middleware.
    pub auth_enabled: bool,

    /// Allowed CORS origins: `None` allows any origin, `Some(vec![])` configures none.
    pub cors_origins: Option<Vec<String>>,

    /// Cache max-age in seconds, passed to `AppState::with_cache_max_age`.
    pub cache_max_age: u32,

    /// When true, the router is wrapped in `tower_http`'s `TraceLayer`.
    pub enable_tracing: bool,
}

impl RouterConfig {
    /// Cache max-age, in seconds, used by both constructors.
    const DEFAULT_CACHE_MAX_AGE: u32 = 3600;

    /// Creates a configuration with signed-URL authentication enabled.
    ///
    /// Defaults: any CORS origin, cache max-age of 3600 seconds, tracing enabled.
    pub fn new(auth_secret: impl Into<String>) -> Self {
        Self {
            auth_secret: auth_secret.into(),
            auth_enabled: true,
            cors_origins: None,
            cache_max_age: Self::DEFAULT_CACHE_MAX_AGE,
            enable_tracing: true,
        }
    }

    /// Creates a configuration with authentication disabled and an empty secret.
    ///
    /// Every other setting matches [`RouterConfig::new`].
    pub fn without_auth() -> Self {
        Self {
            auth_enabled: false,
            ..Self::new(String::new())
        }
    }

    /// Restricts CORS to the given origins; an empty list configures no allowed origin.
    #[must_use]
    pub fn with_cors_origins(self, origins: Vec<String>) -> Self {
        Self {
            cors_origins: Some(origins),
            ..self
        }
    }

    /// Allows requests from any CORS origin (clears any configured origin list).
    #[must_use]
    pub fn with_cors_any_origin(self) -> Self {
        Self {
            cors_origins: None,
            ..self
        }
    }

    /// Sets the cache max-age, in seconds.
    #[must_use]
    pub fn with_cache_max_age(self, seconds: u32) -> Self {
        Self {
            cache_max_age: seconds,
            ..self
        }
    }

    /// Turns signed-URL authentication on or off; the secret is left unchanged.
    #[must_use]
    pub fn with_auth_enabled(self, enabled: bool) -> Self {
        Self {
            auth_enabled: enabled,
            ..self
        }
    }

    /// Turns HTTP request tracing on or off.
    #[must_use]
    pub fn with_tracing(self, enabled: bool) -> Self {
        Self {
            enable_tracing: enabled,
            ..self
        }
    }
}

impl std::fmt::Debug for RouterConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a secret is present, never its value.
        let secret = if self.auth_secret.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("RouterConfig")
            .field("auth_secret", &secret)
            .field("auth_enabled", &self.auth_enabled)
            .field("cors_origins", &self.cors_origins)
            .field("cache_max_age", &self.cache_max_age)
            .field("enable_tracing", &self.enable_tracing)
            .finish()
    }
}
139
140pub fn create_router<S>(tile_service: TileService<S>, config: RouterConfig) -> Router
161where
162 S: SlideSource + 'static,
163{
164 let app_state = if config.auth_enabled {
166 let auth = SignedUrlAuth::new(&config.auth_secret);
167 AppState::with_cache_max_age(tile_service, config.cache_max_age).with_auth(auth.clone())
168 } else {
169 AppState::with_cache_max_age(tile_service, config.cache_max_age)
170 };
171
172 let auth = SignedUrlAuth::new(&config.auth_secret);
174
175 let cors = build_cors_layer(&config);
177
178 let router = if config.auth_enabled {
180 build_protected_router(app_state, auth, cors)
181 } else {
182 build_public_router(app_state, cors)
183 };
184
185 if config.enable_tracing {
187 router.layer(TraceLayer::new_for_http())
188 } else {
189 router
190 }
191}
192
193fn build_protected_router<S>(app_state: AppState<S>, auth: SignedUrlAuth, cors: CorsLayer) -> Router
195where
196 S: SlideSource + 'static,
197{
198 let tile_routes = Router::new()
202 .route("/{slide_id}/{level}/{x}/{filename}", get(tile_handler::<S>))
203 .with_state(app_state.clone());
204
205 let slides_routes = Router::new()
207 .route("/", get(slides_handler::<S>))
208 .route("/{slide_id}", get(slide_metadata_handler::<S>))
209 .route("/{slide_id}/dzi", get(dzi_descriptor_handler::<S>))
210 .route("/{slide_id}/thumbnail", get(thumbnail_handler::<S>))
211 .with_state(app_state.clone());
212
213 let protected_routes = Router::new()
215 .nest("/tiles", tile_routes)
216 .nest("/slides", slides_routes)
217 .layer(middleware::from_fn_with_state(
218 auth,
219 super::auth::auth_middleware,
220 ));
221
222 let public_routes = Router::new()
225 .route("/health", get(health_handler))
226 .route("/view/{slide_id}", get(viewer_handler::<S>))
227 .with_state(app_state);
228
229 Router::new()
231 .merge(protected_routes)
232 .merge(public_routes)
233 .layer(cors)
234}
235
236fn build_public_router<S>(app_state: AppState<S>, cors: CorsLayer) -> Router
238where
239 S: SlideSource + 'static,
240{
241 Router::new()
244 .route("/health", get(health_handler))
245 .route(
246 "/tiles/{slide_id}/{level}/{x}/{filename}",
247 get(tile_handler::<S>),
248 )
249 .route("/slides", get(slides_handler::<S>))
250 .route("/slides/{slide_id}", get(slide_metadata_handler::<S>))
251 .route("/slides/{slide_id}/dzi", get(dzi_descriptor_handler::<S>))
252 .route("/slides/{slide_id}/thumbnail", get(thumbnail_handler::<S>))
253 .route("/view/{slide_id}", get(viewer_handler::<S>))
254 .with_state(app_state)
255 .layer(cors)
256}
257
258fn build_cors_layer(config: &RouterConfig) -> CorsLayer {
260 let cors = CorsLayer::new()
261 .allow_methods([Method::GET, Method::HEAD, Method::OPTIONS])
262 .allow_headers([AUTHORIZATION, CONTENT_TYPE])
263 .max_age(Duration::from_secs(86400)); match &config.cors_origins {
266 None => cors.allow_origin(Any),
267 Some(origins) if origins.is_empty() => {
268 cors
270 }
271 Some(origins) => {
272 let parsed_origins: Vec<_> = origins.iter().filter_map(|o| o.parse().ok()).collect();
274 cors.allow_origin(parsed_origins)
275 }
276 }
277}
278
279pub fn create_dev_router<S>(tile_service: TileService<S>) -> Router
288where
289 S: SlideSource + 'static,
290{
291 create_router(tile_service, RouterConfig::without_auth())
292}
293
294pub fn create_production_router<S>(
302 tile_service: TileService<S>,
303 auth_secret: impl Into<String>,
304) -> Router
305where
306 S: SlideSource + 'static,
307{
308 create_router(tile_service, RouterConfig::new(auth_secret))
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-origin list reused by the builder tests.
    fn example_origin() -> Vec<String> {
        vec!["https://example.com".to_string()]
    }

    #[test]
    fn test_router_config_defaults() {
        let cfg = RouterConfig::new("secret");

        assert_eq!(cfg.auth_secret, "secret");
        assert!(cfg.auth_enabled);
        assert_eq!(cfg.cors_origins, None);
        assert_eq!(cfg.cache_max_age, 3600);
        assert!(cfg.enable_tracing);
    }

    #[test]
    fn test_router_config_without_auth() {
        let cfg = RouterConfig::without_auth();

        assert!(!cfg.auth_enabled);
        assert!(cfg.auth_secret.is_empty());
    }

    #[test]
    fn test_router_config_builder() {
        let cfg = RouterConfig::new("secret")
            .with_cors_origins(example_origin())
            .with_cache_max_age(7200)
            .with_auth_enabled(false)
            .with_tracing(false);

        assert_eq!(cfg.auth_secret, "secret");
        assert!(!cfg.auth_enabled);
        assert_eq!(cfg.cors_origins, Some(example_origin()));
        assert_eq!(cfg.cache_max_age, 7200);
        assert!(!cfg.enable_tracing);
    }

    #[test]
    fn test_router_config_cors_any() {
        let cfg = RouterConfig::new("secret")
            .with_cors_origins(example_origin())
            .with_cors_any_origin();

        assert_eq!(cfg.cors_origins, None);
    }

    // CorsLayer exposes no accessors, so the tests below only check that
    // construction succeeds (does not panic) for each origin policy.

    #[test]
    fn test_build_cors_layer_any_origin() {
        let _ = build_cors_layer(&RouterConfig::new("secret"));
    }

    #[test]
    fn test_build_cors_layer_specific_origins() {
        let origins = vec![
            "https://example.com".to_string(),
            "https://other.com".to_string(),
        ];
        let _ = build_cors_layer(&RouterConfig::new("secret").with_cors_origins(origins));
    }

    #[test]
    fn test_build_cors_layer_empty_origins() {
        let _ = build_cors_layer(&RouterConfig::new("secret").with_cors_origins(Vec::new()));
    }
}