1pub mod controllers;
2mod routes;
3
4pub use controllers::{
5 A2aController, AppsController, ArtifactsController, DebugController, RuntimeController,
6 SessionController,
7};
8
9use crate::{
10 ServerConfig,
11 auth_bridge::{RequestContext, RequestContextError, RequestContextExtractor},
12 web_ui,
13};
14use axum::{
15 Json, Router,
16 body::Body,
17 extract::{DefaultBodyLimit, State},
18 http::{HeaderMap, HeaderName, HeaderValue, Method, Request, StatusCode, header},
19 middleware::{self, Next},
20 response::{IntoResponse, Response},
21 routing::{get, post},
22};
23use serde::Serialize;
24use std::sync::Arc;
25use tokio_util::sync::CancellationToken;
26use tower::ServiceBuilder;
27use tower_http::{
28 cors::{AllowOrigin, CorsLayer},
29 set_header::SetResponseHeaderLayer,
30 timeout::TimeoutLayer,
31 trace::TraceLayer,
32};
33
/// Header carrying the per-request correlation ID.
///
/// A client-supplied value is reused when it parses as a UUID, the effective ID is echoed
/// back on every response, and the header is listed among the CORS allowed headers.
const REQUEST_ID_HEADER: &str = "x-request-id";
35
36#[derive(Clone)]
37struct HealthController {
38 session_service: Arc<dyn adk_session::SessionService>,
39 artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
40 memory_service: Option<Arc<dyn adk_core::Memory>>,
41}
42
43impl HealthController {
44 fn new(config: &ServerConfig) -> Self {
45 Self {
46 session_service: config.session_service.clone(),
47 artifact_service: config.artifact_service.clone(),
48 memory_service: config.memory_service.clone(),
49 }
50 }
51}
52
/// Request-scoped correlation ID, stored in the request extensions by the request-ID
/// middleware so later layers (such as the tracing span) can read it.
#[derive(Clone, Debug)]
struct RequestId(String);

impl RequestId {
    /// Borrows the ID as a string slice.
    fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
61
/// JSON body returned by the `/api/health` route.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct HealthResponse {
    /// Aggregate status: `"healthy"` or `"unhealthy"`. It is unhealthy when the session service
    /// is not healthy or when an optional component reports `"unhealthy"`.
    status: &'static str,
    /// Per-service breakdown.
    components: HealthComponents,
}

/// Health of each backing service probed by the health endpoint.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct HealthComponents {
    /// Session service (always configured).
    session: ComponentHealth,
    /// Memory service; `"not_configured"` when the server has none.
    memory: ComponentHealth,
    /// Artifact service; `"not_configured"` when the server has none.
    artifact: ComponentHealth,
}

/// Health of a single backing service.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ComponentHealth {
    /// One of `"healthy"`, `"unhealthy"` or `"not_configured"`.
    status: &'static str,
    /// Failure message from the service's health probe; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
84
85impl ComponentHealth {
86 fn healthy() -> Self {
87 Self { status: "healthy", error: None }
88 }
89
90 fn unhealthy(error: impl Into<String>) -> Self {
91 Self { status: "unhealthy", error: Some(error.into()) }
92 }
93
94 fn not_configured() -> Self {
95 Self { status: "not_configured", error: None }
96 }
97}
98
99fn build_cors_layer(config: &ServerConfig) -> CorsLayer {
101 let cors = CorsLayer::new()
102 .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
103 .allow_headers([
104 header::CONTENT_TYPE,
105 header::AUTHORIZATION,
106 HeaderName::from_static(REQUEST_ID_HEADER),
107 HeaderName::from_static("x-adk-ui-protocol"),
108 HeaderName::from_static("x-adk-ui-transport"),
109 ]);
110
111 if config.security.allowed_origins.is_empty() {
112 cors.allow_origin(AllowOrigin::any())
113 } else {
114 let origins: Vec<HeaderValue> = config
115 .security
116 .allowed_origins
117 .iter()
118 .filter_map(|origin| origin.parse().ok())
119 .collect();
120 cors.allow_origin(origins)
121 }
122}
123
124fn validate_request_id(headers: &HeaderMap) -> Option<String> {
125 let value = headers.get(REQUEST_ID_HEADER)?;
126 let raw = value.to_str().ok()?;
127 if raw.len() > 128 {
128 return None;
129 }
130 uuid::Uuid::parse_str(raw).ok()?;
131 Some(raw.to_string())
132}
133
134async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
135 let request_id =
136 validate_request_id(request.headers()).unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
137
138 request.extensions_mut().insert(RequestId(request_id.clone()));
139
140 let mut response = next.run(request).await;
141 if let Ok(value) = HeaderValue::from_str(&request_id) {
142 response.headers_mut().insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
143 }
144 response
145}
146
147async fn auth_middleware(
148 request: Request<Body>,
149 next: Next,
150 extractor: Option<Arc<dyn RequestContextExtractor>>,
151) -> Response {
152 let (mut parts, body) = request.into_parts();
153
154 let request_context = match extractor {
155 Some(extractor) => match extractor.extract(&parts).await {
156 Ok(context) => Some(context),
157 Err(RequestContextError::MissingAuth) => {
158 return (
159 StatusCode::UNAUTHORIZED,
160 Json(serde_json::json!({ "error": "missing authorization" })),
161 )
162 .into_response();
163 }
164 Err(RequestContextError::InvalidToken(message)) => {
165 return (
166 StatusCode::UNAUTHORIZED,
167 Json(serde_json::json!({ "error": format!("invalid token: {message}") })),
168 )
169 .into_response();
170 }
171 Err(RequestContextError::ExtractionFailed(message)) => {
172 return (
173 StatusCode::INTERNAL_SERVER_ERROR,
174 Json(serde_json::json!({
175 "error": format!("auth extraction failed: {message}")
176 })),
177 )
178 .into_response();
179 }
180 },
181 None => None,
182 };
183
184 parts.extensions.insert::<Option<RequestContext>>(request_context);
185 next.run(Request::from_parts(parts, body)).await
186}
187
188async fn health_check(State(controller): State<HealthController>) -> impl IntoResponse {
189 let session = match controller.session_service.health_check().await {
190 Ok(()) => ComponentHealth::healthy(),
191 Err(error) => ComponentHealth::unhealthy(error.to_string()),
192 };
193
194 let memory = match controller.memory_service.as_ref() {
195 Some(service) => match service.health_check().await {
196 Ok(()) => ComponentHealth::healthy(),
197 Err(error) => ComponentHealth::unhealthy(error.to_string()),
198 },
199 None => ComponentHealth::not_configured(),
200 };
201
202 let artifact = match controller.artifact_service.as_ref() {
203 Some(service) => match service.health_check().await {
204 Ok(()) => ComponentHealth::healthy(),
205 Err(error) => ComponentHealth::unhealthy(error.to_string()),
206 },
207 None => ComponentHealth::not_configured(),
208 };
209
210 let healthy = session.status == "healthy"
211 && memory.status != "unhealthy"
212 && artifact.status != "unhealthy";
213
214 (
215 if healthy { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE },
216 Json(HealthResponse {
217 status: if healthy { "healthy" } else { "unhealthy" },
218 components: HealthComponents { session, memory, artifact },
219 }),
220 )
221}
222
223pub fn create_app(config: ServerConfig) -> Router {
225 create_app_with_a2a(config, None)
226}
227
/// Starts one YAML agent hot-reload watcher per directory in `dirs`.
///
/// Each watcher gets its own loader backed by `EmptyToolRegistry` and `NoOpModelFactory`.
/// Directories whose watcher fails to start are logged at warn level and skipped. Only the
/// watchers that started successfully are returned.
#[cfg(feature = "yaml-agent")]
async fn start_yaml_agent_watchers(
    dirs: &[std::path::PathBuf],
) -> Vec<Arc<crate::yaml_agent::HotReloadWatcher>> {
    use crate::yaml_agent::{AgentConfigLoader, HotReloadWatcher};

    let mut active = Vec::new();

    for dir in dirs {
        let tool_registry: Arc<dyn adk_core::ToolRegistry> = Arc::new(EmptyToolRegistry);
        let model_factory: Arc<dyn crate::yaml_agent::ModelFactory> =
            Arc::new(NoOpModelFactory);
        let watcher = Arc::new(HotReloadWatcher::new(Arc::new(AgentConfigLoader::new(
            tool_registry,
            model_factory,
        ))));

        let handle = match watcher.watch(dir).await {
            Ok(handle) => handle,
            Err(e) => {
                tracing::warn!("failed to start YAML agent watcher for {}: {e}", dir.display());
                continue;
            }
        };

        tracing::info!("started YAML agent hot reload watcher for {}", dir.display());
        // NOTE(review): the handle is dropped explicitly here, exactly as before; confirm
        // that dropping it does not stop the watch.
        drop(handle);
        active.push(watcher);
    }

    active
}
270
/// Tool registry used by the YAML agent watchers. It exposes no tools, so every lookup
/// misses and the tool list is empty.
#[cfg(feature = "yaml-agent")]
struct EmptyToolRegistry;

#[cfg(feature = "yaml-agent")]
impl adk_core::ToolRegistry for EmptyToolRegistry {
    /// Always `None`: this registry has no tools.
    fn resolve(&self, _tool_name: &str) -> Option<Arc<dyn adk_core::Tool>> {
        Option::None
    }

    /// Always empty: this registry has no tools.
    fn available_tools(&self) -> Vec<String> {
        Vec::new()
    }
}
285
/// Model factory used by the YAML agent watchers. It cannot create any model and always
/// returns a configuration error naming the requested provider and model ID.
#[cfg(feature = "yaml-agent")]
struct NoOpModelFactory;

#[cfg(feature = "yaml-agent")]
#[async_trait::async_trait]
impl crate::yaml_agent::ModelFactory for NoOpModelFactory {
    /// Always fails with `AdkError::config`, explaining that no factory is configured.
    async fn create_model(
        &self,
        provider: &str,
        model_id: &str,
    ) -> adk_core::Result<Arc<dyn adk_core::Llm>> {
        let message = format!(
            "no model factory configured for YAML agent loading \
             (requested provider='{provider}', model_id='{model_id}'). \
             Configure a ModelFactory on the server to enable YAML agent model creation."
        );
        Err(adk_core::AdkError::config(message))
    }
}
309
310pub fn create_app_with_a2a(config: ServerConfig, a2a_base_url: Option<&str>) -> Router {
312 let session_controller = SessionController::new(config.session_service.clone());
313 let runtime_controller = RuntimeController::new(config.clone());
314 let apps_controller = AppsController::new(config.clone());
315 let artifacts_controller = ArtifactsController::new(config.clone());
316 let debug_controller = DebugController::new(config.clone());
317 let health_controller = HealthController::new(&config);
318
319 #[cfg(feature = "yaml-agent")]
321 {
322 let dirs = config.yaml_agent_dirs.clone();
323 if !dirs.is_empty() {
324 tokio::spawn(async move {
325 let _watchers = start_yaml_agent_watchers(&dirs).await;
326 std::future::pending::<()>().await;
330 });
331 }
332 }
333
334 let auth_layer = middleware::from_fn({
335 let extractor = config.request_context_extractor.clone();
336 move |request: Request<Body>, next: Next| {
337 let extractor = extractor.clone();
338 async move { auth_middleware(request, next, extractor).await }
339 }
340 });
341
342 let health_router =
343 Router::new().route("/health", get(health_check)).with_state(health_controller);
344
345 let ui_api_router = Router::new()
346 .route("/apps", get(controllers::apps::list_apps))
347 .route("/list-apps", get(controllers::apps::list_apps_compat))
348 .with_state(apps_controller)
349 .route("/ui/capabilities", get(controllers::ui::ui_capabilities))
350 .route("/ui/initialize", post(controllers::ui::ui_initialize))
351 .route("/ui/message", post(controllers::ui::ui_message))
352 .route("/ui/update-model-context", post(controllers::ui::ui_update_model_context))
353 .route("/ui/notifications/poll", post(controllers::ui::ui_poll_notifications))
354 .route(
355 "/ui/notifications/resources-list-changed",
356 post(controllers::ui::ui_notify_resources_list_changed),
357 )
358 .route(
359 "/ui/notifications/tools-list-changed",
360 post(controllers::ui::ui_notify_tools_list_changed),
361 )
362 .route("/ui/resources", get(controllers::ui::list_ui_resources))
363 .route("/ui/resources/read", get(controllers::ui::read_ui_resource))
364 .route("/ui/resources/register", post(controllers::ui::register_ui_resource));
365
366 let session_router = Router::new()
367 .route("/sessions", post(controllers::session::create_session))
368 .route(
369 "/sessions/{app_name}/{user_id}/{session_id}",
370 get(controllers::session::get_session).delete(controllers::session::delete_session),
371 )
372 .route(
373 "/apps/{app_name}/users/{user_id}/sessions",
374 get(controllers::session::list_sessions)
375 .post(controllers::session::create_session_from_path),
376 )
377 .route(
378 "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
379 get(controllers::session::get_session_from_path)
380 .post(controllers::session::create_session_from_path)
381 .delete(controllers::session::delete_session_from_path),
382 )
383 .with_state(session_controller)
384 .layer(auth_layer.clone());
385
386 let runtime_router = Router::new()
387 .route("/run/{app_name}/{user_id}/{session_id}", post(controllers::runtime::run_sse))
388 .route("/run_sse", post(controllers::runtime::run_sse_compat))
389 .with_state(runtime_controller);
390
391 let artifacts_router = Router::new()
392 .route(
393 "/sessions/{app_name}/{user_id}/{session_id}/artifacts",
394 get(controllers::artifacts::list_artifacts),
395 )
396 .route(
397 "/sessions/{app_name}/{user_id}/{session_id}/artifacts/{artifact_name}",
398 get(controllers::artifacts::get_artifact),
399 )
400 .with_state(artifacts_controller)
401 .layer(auth_layer.clone());
402
403 let mut debug_router = Router::new()
404 .route("/debug/trace/session/{session_id}", get(controllers::debug::get_session_traces))
405 .route(
406 "/debug/graph/{app_name}/{user_id}/{session_id}/{event_id}",
407 get(controllers::debug::get_graph),
408 )
409 .route(
410 "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
411 get(controllers::debug::get_graph),
412 )
413 .route("/apps/{app_name}/eval_sets", get(controllers::debug::get_eval_sets))
414 .route(
415 "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}",
416 get(controllers::debug::get_event),
417 );
418
419 if config.request_context_extractor.is_none() || config.security.expose_admin_debug {
420 debug_router = debug_router
421 .route("/debug/trace/{event_id}", get(controllers::debug::get_trace_by_event_id));
422 }
423
424 let debug_router = debug_router.with_state(debug_controller.clone()).layer(auth_layer.clone());
425
426 let api_router = Router::new()
427 .merge(health_router)
428 .merge(ui_api_router)
429 .merge(session_router)
430 .merge(runtime_router)
431 .merge(artifacts_router)
432 .merge(debug_router);
433
434 let ui_router = Router::new()
435 .route("/", get(web_ui::root_redirect))
436 .route("/ui/", get(web_ui::serve_ui_index))
437 .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
438 .with_state(config.clone())
439 .route("/ui/{*path}", get(web_ui::serve_ui_assets));
440
441 let mut app = Router::new().nest("/api", api_router).merge(ui_router);
442
443 if let Some(base_url) = a2a_base_url {
444 let a2a_controller = A2aController::new(config.clone(), base_url);
445 let a2a_router = Router::new()
446 .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
447 .route("/a2a", post(controllers::a2a::handle_jsonrpc))
448 .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
449 .with_state(a2a_controller);
450 app = app.merge(a2a_router);
451 }
452
453 let cors_layer = build_cors_layer(&config);
454 let trace_layer = TraceLayer::new_for_http().make_span_with(|request: &Request<Body>| {
455 let request_id =
456 request.extensions().get::<RequestId>().map(RequestId::as_str).unwrap_or("");
457 tracing::info_span!(
458 "http.request",
459 request.id = %request_id,
460 http.method = %request.method(),
461 http.path = %request.uri().path()
462 )
463 });
464
465 app.layer(
466 ServiceBuilder::new()
467 .layer(middleware::from_fn(request_id_middleware))
468 .layer(trace_layer)
469 .layer(TimeoutLayer::with_status_code(
470 StatusCode::REQUEST_TIMEOUT,
471 config.security.request_timeout,
472 ))
473 .layer(DefaultBodyLimit::max(config.security.max_body_size))
474 .layer(cors_layer)
475 .layer(SetResponseHeaderLayer::if_not_present(
476 header::X_CONTENT_TYPE_OPTIONS,
477 HeaderValue::from_static("nosniff"),
478 ))
479 .layer(SetResponseHeaderLayer::if_not_present(
480 header::X_FRAME_OPTIONS,
481 HeaderValue::from_static("DENY"),
482 ))
483 .layer(SetResponseHeaderLayer::if_not_present(
484 header::X_XSS_PROTECTION,
485 HeaderValue::from_static("1; mode=block"),
486 )),
487 )
488}
489
490pub struct ServerBuilder {
526 config: ServerConfig,
527 a2a_base_url: Option<String>,
528 api_routes: Vec<Router>,
529 root_routes: Vec<Router>,
530 shutdown_endpoint: bool,
531}
532
533impl ServerBuilder {
534 pub fn new(config: ServerConfig) -> Self {
536 Self {
537 config,
538 a2a_base_url: None,
539 api_routes: Vec::new(),
540 root_routes: Vec::new(),
541 shutdown_endpoint: false,
542 }
543 }
544
545 pub fn add_api_routes(mut self, routes: Router) -> Self {
560 self.api_routes.push(routes);
561 self
562 }
563
564 pub fn add_root_routes(mut self, routes: Router) -> Self {
572 self.root_routes.push(routes);
573 self
574 }
575
576 pub fn with_a2a(mut self, base_url: impl Into<String>) -> Self {
580 self.a2a_base_url = Some(base_url.into());
581 self
582 }
583
584 pub fn enable_shutdown_endpoint(mut self) -> Self {
594 self.shutdown_endpoint = true;
595 self
596 }
597
598 pub fn build(self) -> Router {
600 self.build_inner().0
601 }
602
603 pub fn build_with_shutdown(self) -> (Router, ShutdownHandle) {
621 let (router, handle) = self.build_inner();
622 (router, handle.expect("build_with_shutdown requires enable_shutdown_endpoint()"))
623 }
624
625 fn build_inner(self) -> (Router, Option<ShutdownHandle>) {
626 let config = &self.config;
627 let session_controller = SessionController::new(config.session_service.clone());
628 let runtime_controller = RuntimeController::new(config.clone());
629 let apps_controller = AppsController::new(config.clone());
630 let artifacts_controller = ArtifactsController::new(config.clone());
631 let debug_controller = DebugController::new(config.clone());
632 let health_controller = HealthController::new(config);
633
634 #[cfg(feature = "yaml-agent")]
636 {
637 let dirs = config.yaml_agent_dirs.clone();
638 if !dirs.is_empty() {
639 tokio::spawn(async move {
640 let _watchers = start_yaml_agent_watchers(&dirs).await;
641 std::future::pending::<()>().await;
642 });
643 }
644 }
645
646 let auth_layer = middleware::from_fn({
647 let extractor = config.request_context_extractor.clone();
648 move |request: Request<Body>, next: Next| {
649 let extractor = extractor.clone();
650 async move { auth_middleware(request, next, extractor).await }
651 }
652 });
653
654 let health_router =
655 Router::new().route("/health", get(health_check)).with_state(health_controller);
656
657 let ui_api_router = Router::new()
658 .route("/apps", get(controllers::apps::list_apps))
659 .route("/list-apps", get(controllers::apps::list_apps_compat))
660 .with_state(apps_controller)
661 .route("/ui/capabilities", get(controllers::ui::ui_capabilities))
662 .route("/ui/initialize", post(controllers::ui::ui_initialize))
663 .route("/ui/message", post(controllers::ui::ui_message))
664 .route("/ui/update-model-context", post(controllers::ui::ui_update_model_context))
665 .route("/ui/notifications/poll", post(controllers::ui::ui_poll_notifications))
666 .route(
667 "/ui/notifications/resources-list-changed",
668 post(controllers::ui::ui_notify_resources_list_changed),
669 )
670 .route(
671 "/ui/notifications/tools-list-changed",
672 post(controllers::ui::ui_notify_tools_list_changed),
673 )
674 .route("/ui/resources", get(controllers::ui::list_ui_resources))
675 .route("/ui/resources/read", get(controllers::ui::read_ui_resource))
676 .route("/ui/resources/register", post(controllers::ui::register_ui_resource));
677
678 let session_router = Router::new()
679 .route("/sessions", post(controllers::session::create_session))
680 .route(
681 "/sessions/{app_name}/{user_id}/{session_id}",
682 get(controllers::session::get_session).delete(controllers::session::delete_session),
683 )
684 .route(
685 "/apps/{app_name}/users/{user_id}/sessions",
686 get(controllers::session::list_sessions)
687 .post(controllers::session::create_session_from_path),
688 )
689 .route(
690 "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
691 get(controllers::session::get_session_from_path)
692 .post(controllers::session::create_session_from_path)
693 .delete(controllers::session::delete_session_from_path),
694 )
695 .with_state(session_controller)
696 .layer(auth_layer.clone());
697
698 let runtime_router = Router::new()
699 .route("/run/{app_name}/{user_id}/{session_id}", post(controllers::runtime::run_sse))
700 .route("/run_sse", post(controllers::runtime::run_sse_compat))
701 .with_state(runtime_controller);
702
703 let artifacts_router = Router::new()
704 .route(
705 "/sessions/{app_name}/{user_id}/{session_id}/artifacts",
706 get(controllers::artifacts::list_artifacts),
707 )
708 .route(
709 "/sessions/{app_name}/{user_id}/{session_id}/artifacts/{artifact_name}",
710 get(controllers::artifacts::get_artifact),
711 )
712 .with_state(artifacts_controller)
713 .layer(auth_layer.clone());
714
715 let mut debug_router = Router::new()
716 .route("/debug/trace/session/{session_id}", get(controllers::debug::get_session_traces))
717 .route(
718 "/debug/graph/{app_name}/{user_id}/{session_id}/{event_id}",
719 get(controllers::debug::get_graph),
720 )
721 .route(
722 "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
723 get(controllers::debug::get_graph),
724 )
725 .route("/apps/{app_name}/eval_sets", get(controllers::debug::get_eval_sets))
726 .route(
727 "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}",
728 get(controllers::debug::get_event),
729 );
730
731 if config.request_context_extractor.is_none() || config.security.expose_admin_debug {
732 debug_router = debug_router
733 .route("/debug/trace/{event_id}", get(controllers::debug::get_trace_by_event_id));
734 }
735
736 let debug_router =
737 debug_router.with_state(debug_controller.clone()).layer(auth_layer.clone());
738
739 let mut api_router = Router::new()
741 .merge(health_router)
742 .merge(ui_api_router)
743 .merge(session_router)
744 .merge(runtime_router)
745 .merge(artifacts_router)
746 .merge(debug_router);
747
748 for custom_routes in self.api_routes {
750 api_router = api_router.merge(custom_routes.layer(auth_layer.clone()));
751 }
752
753 let shutdown_handle = if self.shutdown_endpoint {
755 let handle = ShutdownHandle::new();
756 let shutdown_router = Router::new()
757 .route("/shutdown", post(handle_shutdown))
758 .with_state(handle.token.clone())
759 .layer(auth_layer);
760 api_router = api_router.merge(shutdown_router);
761 Some(handle)
762 } else {
763 None
764 };
765
766 let ui_router = Router::new()
767 .route("/", get(web_ui::root_redirect))
768 .route("/ui/", get(web_ui::serve_ui_index))
769 .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
770 .with_state(config.clone())
771 .route("/ui/{*path}", get(web_ui::serve_ui_assets));
772
773 let mut app = Router::new().nest("/api", api_router).merge(ui_router);
774
775 for custom_routes in self.root_routes {
777 app = app.merge(custom_routes);
778 }
779
780 if let Some(base_url) = &self.a2a_base_url {
781 let a2a_controller = A2aController::new(config.clone(), base_url);
782 let a2a_router = Router::new()
783 .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
784 .route("/a2a", post(controllers::a2a::handle_jsonrpc))
785 .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
786 .with_state(a2a_controller);
787 app = app.merge(a2a_router);
788 }
789
790 let cors_layer = build_cors_layer(config);
791 let trace_layer = TraceLayer::new_for_http().make_span_with(|request: &Request<Body>| {
792 let request_id =
793 request.extensions().get::<RequestId>().map(RequestId::as_str).unwrap_or("");
794 tracing::info_span!(
795 "http.request",
796 request.id = %request_id,
797 http.method = %request.method(),
798 http.path = %request.uri().path()
799 )
800 });
801
802 (
803 app.layer(
804 ServiceBuilder::new()
805 .layer(middleware::from_fn(request_id_middleware))
806 .layer(trace_layer)
807 .layer(TimeoutLayer::with_status_code(
808 StatusCode::REQUEST_TIMEOUT,
809 config.security.request_timeout,
810 ))
811 .layer(DefaultBodyLimit::max(config.security.max_body_size))
812 .layer(cors_layer)
813 .layer(SetResponseHeaderLayer::if_not_present(
814 header::X_CONTENT_TYPE_OPTIONS,
815 HeaderValue::from_static("nosniff"),
816 ))
817 .layer(SetResponseHeaderLayer::if_not_present(
818 header::X_FRAME_OPTIONS,
819 HeaderValue::from_static("DENY"),
820 ))
821 .layer(SetResponseHeaderLayer::if_not_present(
822 header::X_XSS_PROTECTION,
823 HeaderValue::from_static("1; mode=block"),
824 )),
825 ),
826 shutdown_handle,
827 )
828 }
829}
830
831pub async fn shutdown_signal() {
833 let ctrl_c = async {
834 let _ = tokio::signal::ctrl_c().await;
835 };
836
837 #[cfg(unix)]
838 let terminate = async {
839 if let Ok(mut signal) =
840 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
841 {
842 let _ = signal.recv().await;
843 }
844 };
845
846 #[cfg(not(unix))]
847 let terminate = std::future::pending::<()>();
848
849 tokio::select! {
850 _ = ctrl_c => {}
851 _ = terminate => {}
852 }
853}
854
855#[derive(Clone)]
880pub struct ShutdownHandle {
881 token: CancellationToken,
882}
883
884impl ShutdownHandle {
885 fn new() -> Self {
887 Self { token: CancellationToken::new() }
888 }
889
890 pub fn shutdown(&self) {
895 tracing::info!("graceful shutdown triggered programmatically");
896 self.token.cancel();
897 }
898
899 pub async fn signal(self) {
904 let token = self.token.clone();
905
906 let ctrl_c = async {
907 let _ = tokio::signal::ctrl_c().await;
908 };
909
910 #[cfg(unix)]
911 let terminate = async {
912 if let Ok(mut signal) =
913 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
914 {
915 let _ = signal.recv().await;
916 }
917 };
918
919 #[cfg(not(unix))]
920 let terminate = std::future::pending::<()>();
921
922 tokio::select! {
923 _ = ctrl_c => {
924 tracing::info!("received Ctrl+C, initiating graceful shutdown");
925 }
926 _ = terminate => {
927 tracing::info!("received SIGTERM, initiating graceful shutdown");
928 }
929 _ = token.cancelled() => {
930 }
932 }
933 }
934
935 pub fn is_shutdown(&self) -> bool {
937 self.token.is_cancelled()
938 }
939}
940
941async fn handle_shutdown(State(token): State<CancellationToken>) -> impl IntoResponse {
946 tracing::info!("POST /api/shutdown received, initiating graceful shutdown");
947 token.cancel();
948 (StatusCode::OK, Json(serde_json::json!({ "status": "shutting_down" })))
949}