Skip to main content

adk_server/rest/
mod.rs

1pub mod controllers;
2mod routes;
3
4pub use controllers::{
5    A2aController, AppsController, ArtifactsController, DebugController, RuntimeController,
6    SessionController,
7};
8
9use crate::{
10    ServerConfig,
11    auth_bridge::{RequestContext, RequestContextError, RequestContextExtractor},
12    web_ui,
13};
14use axum::{
15    Json, Router,
16    body::Body,
17    extract::{DefaultBodyLimit, State},
18    http::{HeaderMap, HeaderName, HeaderValue, Method, Request, StatusCode, header},
19    middleware::{self, Next},
20    response::{IntoResponse, Response},
21    routing::{get, post},
22};
23use serde::Serialize;
24use std::sync::Arc;
25use tokio_util::sync::CancellationToken;
26use tower::ServiceBuilder;
27use tower_http::{
28    cors::{AllowOrigin, CorsLayer},
29    set_header::SetResponseHeaderLayer,
30    timeout::TimeoutLayer,
31    trace::TraceLayer,
32};
33
/// Name of the header carrying a per-request correlation id.
///
/// An incoming value is reused only if `validate_request_id` accepts it;
/// the effective id is echoed on the response and recorded on the request's
/// tracing span.
const REQUEST_ID_HEADER: &str = "x-request-id";
35
36#[derive(Clone)]
37struct HealthController {
38    session_service: Arc<dyn adk_session::SessionService>,
39    artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
40    memory_service: Option<Arc<dyn adk_core::Memory>>,
41}
42
43impl HealthController {
44    fn new(config: &ServerConfig) -> Self {
45        Self {
46            session_service: config.session_service.clone(),
47            artifact_service: config.artifact_service.clone(),
48            memory_service: config.memory_service.clone(),
49        }
50    }
51}
52
/// Correlation id of one HTTP request.
///
/// `request_id_middleware` stores it in the request extensions so inner
/// layers (such as the tracing span factory) can read it.
#[derive(Clone, Debug)]
struct RequestId(String);

impl RequestId {
    /// Borrow the id as a string slice.
    fn as_str(&self) -> &str {
        let Self(id) = self;
        id.as_str()
    }
}
61
62#[derive(Serialize)]
63#[serde(rename_all = "camelCase")]
64struct HealthResponse {
65    status: &'static str,
66    components: HealthComponents,
67}
68
69#[derive(Serialize)]
70#[serde(rename_all = "camelCase")]
71struct HealthComponents {
72    session: ComponentHealth,
73    memory: ComponentHealth,
74    artifact: ComponentHealth,
75}
76
77#[derive(Serialize)]
78#[serde(rename_all = "camelCase")]
79struct ComponentHealth {
80    status: &'static str,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    error: Option<String>,
83}
84
85impl ComponentHealth {
86    fn healthy() -> Self {
87        Self { status: "healthy", error: None }
88    }
89
90    fn unhealthy(error: impl Into<String>) -> Self {
91        Self { status: "unhealthy", error: Some(error.into()) }
92    }
93
94    fn not_configured() -> Self {
95        Self { status: "not_configured", error: None }
96    }
97}
98
99/// Build CORS layer based on security configuration
100fn build_cors_layer(config: &ServerConfig) -> CorsLayer {
101    let cors = CorsLayer::new()
102        .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
103        .allow_headers([
104            header::CONTENT_TYPE,
105            header::AUTHORIZATION,
106            HeaderName::from_static(REQUEST_ID_HEADER),
107            HeaderName::from_static("x-adk-ui-protocol"),
108            HeaderName::from_static("x-adk-ui-transport"),
109        ]);
110
111    if config.security.allowed_origins.is_empty() {
112        cors.allow_origin(AllowOrigin::any())
113    } else {
114        let origins: Vec<HeaderValue> = config
115            .security
116            .allowed_origins
117            .iter()
118            .filter_map(|origin| origin.parse().ok())
119            .collect();
120        cors.allow_origin(origins)
121    }
122}
123
124fn validate_request_id(headers: &HeaderMap) -> Option<String> {
125    let value = headers.get(REQUEST_ID_HEADER)?;
126    let raw = value.to_str().ok()?;
127    if raw.len() > 128 {
128        return None;
129    }
130    uuid::Uuid::parse_str(raw).ok()?;
131    Some(raw.to_string())
132}
133
134async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
135    let request_id =
136        validate_request_id(request.headers()).unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
137
138    request.extensions_mut().insert(RequestId(request_id.clone()));
139
140    let mut response = next.run(request).await;
141    if let Ok(value) = HeaderValue::from_str(&request_id) {
142        response.headers_mut().insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
143    }
144    response
145}
146
147async fn auth_middleware(
148    request: Request<Body>,
149    next: Next,
150    extractor: Option<Arc<dyn RequestContextExtractor>>,
151) -> Response {
152    let (mut parts, body) = request.into_parts();
153
154    let request_context = match extractor {
155        Some(extractor) => match extractor.extract(&parts).await {
156            Ok(context) => Some(context),
157            Err(RequestContextError::MissingAuth) => {
158                return (
159                    StatusCode::UNAUTHORIZED,
160                    Json(serde_json::json!({ "error": "missing authorization" })),
161                )
162                    .into_response();
163            }
164            Err(RequestContextError::InvalidToken(message)) => {
165                return (
166                    StatusCode::UNAUTHORIZED,
167                    Json(serde_json::json!({ "error": format!("invalid token: {message}") })),
168                )
169                    .into_response();
170            }
171            Err(RequestContextError::ExtractionFailed(message)) => {
172                return (
173                    StatusCode::INTERNAL_SERVER_ERROR,
174                    Json(serde_json::json!({
175                        "error": format!("auth extraction failed: {message}")
176                    })),
177                )
178                    .into_response();
179            }
180        },
181        None => None,
182    };
183
184    parts.extensions.insert::<Option<RequestContext>>(request_context);
185    next.run(Request::from_parts(parts, body)).await
186}
187
188async fn health_check(State(controller): State<HealthController>) -> impl IntoResponse {
189    let session = match controller.session_service.health_check().await {
190        Ok(()) => ComponentHealth::healthy(),
191        Err(error) => ComponentHealth::unhealthy(error.to_string()),
192    };
193
194    let memory = match controller.memory_service.as_ref() {
195        Some(service) => match service.health_check().await {
196            Ok(()) => ComponentHealth::healthy(),
197            Err(error) => ComponentHealth::unhealthy(error.to_string()),
198        },
199        None => ComponentHealth::not_configured(),
200    };
201
202    let artifact = match controller.artifact_service.as_ref() {
203        Some(service) => match service.health_check().await {
204            Ok(()) => ComponentHealth::healthy(),
205            Err(error) => ComponentHealth::unhealthy(error.to_string()),
206        },
207        None => ComponentHealth::not_configured(),
208    };
209
210    let healthy = session.status == "healthy"
211        && memory.status != "unhealthy"
212        && artifact.status != "unhealthy";
213
214    (
215        if healthy { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE },
216        Json(HealthResponse {
217            status: if healthy { "healthy" } else { "unhealthy" },
218            components: HealthComponents { session, memory, artifact },
219        }),
220    )
221}
222
223/// Create the server application with optional A2A support
224pub fn create_app(config: ServerConfig) -> Router {
225    create_app_with_a2a(config, None)
226}
227
/// Start a hot-reload watcher for every YAML agent directory in `dirs`.
///
/// Each directory gets its own
/// [`AgentConfigLoader`](crate::yaml_agent::AgentConfigLoader) wrapped in a
/// [`HotReloadWatcher`](crate::yaml_agent::HotReloadWatcher). The loader is
/// built with an empty tool registry and a model factory that rejects every
/// request. Directories whose watcher fails to start are logged at `warn`
/// level and skipped.
///
/// Returns the watchers that started successfully.
#[cfg(feature = "yaml-agent")]
async fn start_yaml_agent_watchers(
    dirs: &[std::path::PathBuf],
) -> Vec<Arc<crate::yaml_agent::HotReloadWatcher>> {
    use crate::yaml_agent::{AgentConfigLoader, HotReloadWatcher};

    let mut started = Vec::new();

    for dir in dirs {
        // Placeholder collaborators: no tools are pre-registered and model
        // creation always fails. The intent is for real implementations to be
        // supplied through `ServerConfig` later.
        let tools: Arc<dyn adk_core::ToolRegistry> = Arc::new(EmptyToolRegistry);
        let models: Arc<dyn crate::yaml_agent::ModelFactory> = Arc::new(NoOpModelFactory);
        let loader = Arc::new(AgentConfigLoader::new(tools, models));
        let watcher = Arc::new(HotReloadWatcher::new(loader));

        let handle = match watcher.watch(dir).await {
            Ok(handle) => handle,
            Err(e) => {
                tracing::warn!("failed to start YAML agent watcher for {}: {e}", dir.display());
                continue;
            }
        };

        tracing::info!("started YAML agent hot reload watcher for {}", dir.display());
        // Dropping the handle is meant to detach the background watch task.
        // NOTE(review): confirm `watch`'s handle detaches on drop (like a
        // JoinHandle) rather than cancelling the task.
        drop(handle);
        started.push(watcher);
    }

    started
}
270
/// Tool registry that knows no tools: every lookup misses and the tool list
/// is empty.
///
/// Used as the default registry for YAML agent loading when no tools have
/// been registered.
#[cfg(feature = "yaml-agent")]
struct EmptyToolRegistry;

#[cfg(feature = "yaml-agent")]
impl adk_core::ToolRegistry for EmptyToolRegistry {
    fn resolve(&self, _name: &str) -> Option<Arc<dyn adk_core::Tool>> {
        Option::None
    }

    fn available_tools(&self) -> Vec<String> {
        Vec::new()
    }
}
285
/// Model factory that refuses every request.
///
/// It lets the YAML hot-reload watcher start when no real `ModelFactory` has
/// been configured. `create_model` always returns a configuration error that
/// names the requested provider and model id.
#[cfg(feature = "yaml-agent")]
struct NoOpModelFactory;

#[cfg(feature = "yaml-agent")]
#[async_trait::async_trait]
impl crate::yaml_agent::ModelFactory for NoOpModelFactory {
    async fn create_model(
        &self,
        provider: &str,
        model_id: &str,
    ) -> adk_core::Result<Arc<dyn adk_core::Llm>> {
        let message = format!(
            "no model factory configured for YAML agent loading \
             (requested provider='{provider}', model_id='{model_id}'). \
             Configure a ModelFactory on the server to enable YAML agent model creation."
        );
        Err(adk_core::AdkError::config(message))
    }
}
309
310/// Create the server application with A2A support at the specified base URL
311pub fn create_app_with_a2a(config: ServerConfig, a2a_base_url: Option<&str>) -> Router {
312    let session_controller = SessionController::new(config.session_service.clone());
313    let runtime_controller = RuntimeController::new(config.clone());
314    let apps_controller = AppsController::new(config.clone());
315    let artifacts_controller = ArtifactsController::new(config.clone());
316    let debug_controller = DebugController::new(config.clone());
317    let health_controller = HealthController::new(&config);
318
319    // Start YAML agent hot reload watchers if configured.
320    #[cfg(feature = "yaml-agent")]
321    {
322        let dirs = config.yaml_agent_dirs.clone();
323        if !dirs.is_empty() {
324            tokio::spawn(async move {
325                let _watchers = start_yaml_agent_watchers(&dirs).await;
326                // Watchers are kept alive for the lifetime of this task.
327                // They run background filesystem watch loops internally.
328                // We hold them here so they aren't dropped.
329                std::future::pending::<()>().await;
330            });
331        }
332    }
333
334    let auth_layer = middleware::from_fn({
335        let extractor = config.request_context_extractor.clone();
336        move |request: Request<Body>, next: Next| {
337            let extractor = extractor.clone();
338            async move { auth_middleware(request, next, extractor).await }
339        }
340    });
341
342    let health_router =
343        Router::new().route("/health", get(health_check)).with_state(health_controller);
344
345    let ui_api_router = Router::new()
346        .route("/apps", get(controllers::apps::list_apps))
347        .route("/list-apps", get(controllers::apps::list_apps_compat))
348        .with_state(apps_controller)
349        .route("/ui/capabilities", get(controllers::ui::ui_capabilities))
350        .route("/ui/initialize", post(controllers::ui::ui_initialize))
351        .route("/ui/message", post(controllers::ui::ui_message))
352        .route("/ui/update-model-context", post(controllers::ui::ui_update_model_context))
353        .route("/ui/notifications/poll", post(controllers::ui::ui_poll_notifications))
354        .route(
355            "/ui/notifications/resources-list-changed",
356            post(controllers::ui::ui_notify_resources_list_changed),
357        )
358        .route(
359            "/ui/notifications/tools-list-changed",
360            post(controllers::ui::ui_notify_tools_list_changed),
361        )
362        .route("/ui/resources", get(controllers::ui::list_ui_resources))
363        .route("/ui/resources/read", get(controllers::ui::read_ui_resource))
364        .route("/ui/resources/register", post(controllers::ui::register_ui_resource));
365
366    let session_router = Router::new()
367        .route("/sessions", post(controllers::session::create_session))
368        .route(
369            "/sessions/{app_name}/{user_id}/{session_id}",
370            get(controllers::session::get_session).delete(controllers::session::delete_session),
371        )
372        .route(
373            "/apps/{app_name}/users/{user_id}/sessions",
374            get(controllers::session::list_sessions)
375                .post(controllers::session::create_session_from_path),
376        )
377        .route(
378            "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
379            get(controllers::session::get_session_from_path)
380                .post(controllers::session::create_session_from_path)
381                .delete(controllers::session::delete_session_from_path),
382        )
383        .with_state(session_controller)
384        .layer(auth_layer.clone());
385
386    let runtime_router = Router::new()
387        .route("/run/{app_name}/{user_id}/{session_id}", post(controllers::runtime::run_sse))
388        .route("/run_sse", post(controllers::runtime::run_sse_compat))
389        .with_state(runtime_controller);
390
391    let artifacts_router = Router::new()
392        .route(
393            "/sessions/{app_name}/{user_id}/{session_id}/artifacts",
394            get(controllers::artifacts::list_artifacts),
395        )
396        .route(
397            "/sessions/{app_name}/{user_id}/{session_id}/artifacts/{artifact_name}",
398            get(controllers::artifacts::get_artifact),
399        )
400        .with_state(artifacts_controller)
401        .layer(auth_layer.clone());
402
403    let mut debug_router = Router::new()
404        .route("/debug/trace/session/{session_id}", get(controllers::debug::get_session_traces))
405        .route(
406            "/debug/graph/{app_name}/{user_id}/{session_id}/{event_id}",
407            get(controllers::debug::get_graph),
408        )
409        .route(
410            "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
411            get(controllers::debug::get_graph),
412        )
413        .route("/apps/{app_name}/eval_sets", get(controllers::debug::get_eval_sets))
414        .route(
415            "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}",
416            get(controllers::debug::get_event),
417        );
418
419    if config.request_context_extractor.is_none() || config.security.expose_admin_debug {
420        debug_router = debug_router
421            .route("/debug/trace/{event_id}", get(controllers::debug::get_trace_by_event_id));
422    }
423
424    let debug_router = debug_router.with_state(debug_controller.clone()).layer(auth_layer.clone());
425
426    let api_router = Router::new()
427        .merge(health_router)
428        .merge(ui_api_router)
429        .merge(session_router)
430        .merge(runtime_router)
431        .merge(artifacts_router)
432        .merge(debug_router);
433
434    let ui_router = Router::new()
435        .route("/", get(web_ui::root_redirect))
436        .route("/ui/", get(web_ui::serve_ui_index))
437        .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
438        .with_state(config.clone())
439        .route("/ui/{*path}", get(web_ui::serve_ui_assets));
440
441    let mut app = Router::new().nest("/api", api_router).merge(ui_router);
442
443    if let Some(base_url) = a2a_base_url {
444        let a2a_controller = A2aController::new(config.clone(), base_url);
445        let a2a_router = Router::new()
446            .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
447            .route("/a2a", post(controllers::a2a::handle_jsonrpc))
448            .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
449            .with_state(a2a_controller);
450        app = app.merge(a2a_router);
451    }
452
453    let cors_layer = build_cors_layer(&config);
454    let trace_layer = TraceLayer::new_for_http().make_span_with(|request: &Request<Body>| {
455        let request_id =
456            request.extensions().get::<RequestId>().map(RequestId::as_str).unwrap_or("");
457        tracing::info_span!(
458            "http.request",
459            request.id = %request_id,
460            http.method = %request.method(),
461            http.path = %request.uri().path()
462        )
463    });
464
465    app.layer(
466        ServiceBuilder::new()
467            .layer(middleware::from_fn(request_id_middleware))
468            .layer(trace_layer)
469            .layer(TimeoutLayer::with_status_code(
470                StatusCode::REQUEST_TIMEOUT,
471                config.security.request_timeout,
472            ))
473            .layer(DefaultBodyLimit::max(config.security.max_body_size))
474            .layer(cors_layer)
475            .layer(SetResponseHeaderLayer::if_not_present(
476                header::X_CONTENT_TYPE_OPTIONS,
477                HeaderValue::from_static("nosniff"),
478            ))
479            .layer(SetResponseHeaderLayer::if_not_present(
480                header::X_FRAME_OPTIONS,
481                HeaderValue::from_static("DENY"),
482            ))
483            .layer(SetResponseHeaderLayer::if_not_present(
484                header::X_XSS_PROTECTION,
485                HeaderValue::from_static("1; mode=block"),
486            )),
487    )
488}
489
490// ---------------------------------------------------------------------------
491// ServerBuilder — extensible server construction with custom routes
492// ---------------------------------------------------------------------------
493
494/// Builder for constructing an ADK server with custom routes.
495///
496/// `ServerBuilder` allows registering additional Axum routers alongside the
497/// built-in REST, A2A, and UI routes. Custom routes benefit from the same
498/// middleware stack (auth, CORS, tracing, timeout, security headers) as the
499/// built-in routes.
500///
501/// # Example
502///
503/// ```rust,ignore
504/// use adk_server::{ServerBuilder, ServerConfig};
505/// use axum::{Router, routing::get};
506///
507/// let config = ServerConfig::new(agent, session_service);
508///
509/// let app = ServerBuilder::new(config)
510///     .add_api_routes(
511///         Router::new()
512///             .route("/projects", get(list_projects))
513///             .route("/projects/{id}", get(get_project))
514///     )
515///     .add_api_routes(
516///         Router::new()
517///             .route("/automations", get(list_automations))
518///     )
519///     .with_a2a("http://localhost:8080")
520///     .build();
521///
522/// let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await?;
523/// axum::serve(listener, app).await?;
524/// ```
525pub struct ServerBuilder {
526    config: ServerConfig,
527    a2a_base_url: Option<String>,
528    api_routes: Vec<Router>,
529    root_routes: Vec<Router>,
530    shutdown_endpoint: bool,
531}
532
533impl ServerBuilder {
534    /// Create a new server builder with the given configuration.
535    pub fn new(config: ServerConfig) -> Self {
536        Self {
537            config,
538            a2a_base_url: None,
539            api_routes: Vec::new(),
540            root_routes: Vec::new(),
541            shutdown_endpoint: false,
542        }
543    }
544
545    /// Add custom routes nested under `/api`.
546    ///
547    /// These routes are merged into the API router and benefit from the auth
548    /// middleware layer. Multiple calls accumulate routes.
549    ///
550    /// # Example
551    ///
552    /// ```rust,ignore
553    /// builder.add_api_routes(
554    ///     Router::new()
555    ///         .route("/projects", get(list_projects))
556    ///         .route("/projects/{id}", get(get_project))
557    /// )
558    /// ```
559    pub fn add_api_routes(mut self, routes: Router) -> Self {
560        self.api_routes.push(routes);
561        self
562    }
563
564    /// Add custom routes at the root level (not nested under `/api`).
565    ///
566    /// These routes are merged at the top level of the application, alongside
567    /// the UI and A2A routes. They receive the full middleware stack (CORS,
568    /// tracing, timeout, security headers) but NOT the auth middleware.
569    ///
570    /// Use this for routes that need their own auth handling or public endpoints.
571    pub fn add_root_routes(mut self, routes: Router) -> Self {
572        self.root_routes.push(routes);
573        self
574    }
575
576    /// Enable A2A protocol support at the specified base URL.
577    ///
578    /// The base URL is used to construct the agent card's endpoint URL.
579    pub fn with_a2a(mut self, base_url: impl Into<String>) -> Self {
580        self.a2a_base_url = Some(base_url.into());
581        self
582    }
583
584    /// Enable the `POST /api/shutdown` endpoint for graceful shutdown.
585    ///
586    /// When enabled, the server exposes a shutdown endpoint that triggers
587    /// graceful shutdown: stops accepting new connections, completes in-flight
588    /// requests, and then exits. Use [`build_with_shutdown`](Self::build_with_shutdown)
589    /// to get the [`ShutdownHandle`] for wiring into `axum::serve().with_graceful_shutdown()`.
590    ///
591    /// The endpoint is protected by the auth middleware when a
592    /// `RequestContextExtractor` is configured.
593    pub fn enable_shutdown_endpoint(mut self) -> Self {
594        self.shutdown_endpoint = true;
595        self
596    }
597
598    /// Build the final Axum router with all routes and middleware applied.
599    pub fn build(self) -> Router {
600        self.build_inner().0
601    }
602
603    /// Build the final Axum router and return a [`ShutdownHandle`].
604    ///
605    /// Use this when [`enable_shutdown_endpoint()`](Self::enable_shutdown_endpoint) is set.
606    /// Pass the handle's signal to `axum::serve().with_graceful_shutdown()`.
607    ///
608    /// # Example
609    ///
610    /// ```rust,ignore
611    /// let (app, shutdown_handle) = ServerBuilder::new(config)
612    ///     .enable_shutdown_endpoint()
613    ///     .build_with_shutdown();
614    ///
615    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await?;
616    /// axum::serve(listener, app)
617    ///     .with_graceful_shutdown(shutdown_handle.signal())
618    ///     .await?;
619    /// ```
620    pub fn build_with_shutdown(self) -> (Router, ShutdownHandle) {
621        let (router, handle) = self.build_inner();
622        (router, handle.expect("build_with_shutdown requires enable_shutdown_endpoint()"))
623    }
624
625    fn build_inner(self) -> (Router, Option<ShutdownHandle>) {
626        let config = &self.config;
627        let session_controller = SessionController::new(config.session_service.clone());
628        let runtime_controller = RuntimeController::new(config.clone());
629        let apps_controller = AppsController::new(config.clone());
630        let artifacts_controller = ArtifactsController::new(config.clone());
631        let debug_controller = DebugController::new(config.clone());
632        let health_controller = HealthController::new(config);
633
634        // Start YAML agent hot reload watchers if configured.
635        #[cfg(feature = "yaml-agent")]
636        {
637            let dirs = config.yaml_agent_dirs.clone();
638            if !dirs.is_empty() {
639                tokio::spawn(async move {
640                    let _watchers = start_yaml_agent_watchers(&dirs).await;
641                    std::future::pending::<()>().await;
642                });
643            }
644        }
645
646        let auth_layer = middleware::from_fn({
647            let extractor = config.request_context_extractor.clone();
648            move |request: Request<Body>, next: Next| {
649                let extractor = extractor.clone();
650                async move { auth_middleware(request, next, extractor).await }
651            }
652        });
653
654        let health_router =
655            Router::new().route("/health", get(health_check)).with_state(health_controller);
656
657        let ui_api_router = Router::new()
658            .route("/apps", get(controllers::apps::list_apps))
659            .route("/list-apps", get(controllers::apps::list_apps_compat))
660            .with_state(apps_controller)
661            .route("/ui/capabilities", get(controllers::ui::ui_capabilities))
662            .route("/ui/initialize", post(controllers::ui::ui_initialize))
663            .route("/ui/message", post(controllers::ui::ui_message))
664            .route("/ui/update-model-context", post(controllers::ui::ui_update_model_context))
665            .route("/ui/notifications/poll", post(controllers::ui::ui_poll_notifications))
666            .route(
667                "/ui/notifications/resources-list-changed",
668                post(controllers::ui::ui_notify_resources_list_changed),
669            )
670            .route(
671                "/ui/notifications/tools-list-changed",
672                post(controllers::ui::ui_notify_tools_list_changed),
673            )
674            .route("/ui/resources", get(controllers::ui::list_ui_resources))
675            .route("/ui/resources/read", get(controllers::ui::read_ui_resource))
676            .route("/ui/resources/register", post(controllers::ui::register_ui_resource));
677
678        let session_router = Router::new()
679            .route("/sessions", post(controllers::session::create_session))
680            .route(
681                "/sessions/{app_name}/{user_id}/{session_id}",
682                get(controllers::session::get_session).delete(controllers::session::delete_session),
683            )
684            .route(
685                "/apps/{app_name}/users/{user_id}/sessions",
686                get(controllers::session::list_sessions)
687                    .post(controllers::session::create_session_from_path),
688            )
689            .route(
690                "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
691                get(controllers::session::get_session_from_path)
692                    .post(controllers::session::create_session_from_path)
693                    .delete(controllers::session::delete_session_from_path),
694            )
695            .with_state(session_controller)
696            .layer(auth_layer.clone());
697
698        let runtime_router = Router::new()
699            .route("/run/{app_name}/{user_id}/{session_id}", post(controllers::runtime::run_sse))
700            .route("/run_sse", post(controllers::runtime::run_sse_compat))
701            .with_state(runtime_controller);
702
703        let artifacts_router = Router::new()
704            .route(
705                "/sessions/{app_name}/{user_id}/{session_id}/artifacts",
706                get(controllers::artifacts::list_artifacts),
707            )
708            .route(
709                "/sessions/{app_name}/{user_id}/{session_id}/artifacts/{artifact_name}",
710                get(controllers::artifacts::get_artifact),
711            )
712            .with_state(artifacts_controller)
713            .layer(auth_layer.clone());
714
715        let mut debug_router = Router::new()
716            .route("/debug/trace/session/{session_id}", get(controllers::debug::get_session_traces))
717            .route(
718                "/debug/graph/{app_name}/{user_id}/{session_id}/{event_id}",
719                get(controllers::debug::get_graph),
720            )
721            .route(
722                "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
723                get(controllers::debug::get_graph),
724            )
725            .route("/apps/{app_name}/eval_sets", get(controllers::debug::get_eval_sets))
726            .route(
727                "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}",
728                get(controllers::debug::get_event),
729            );
730
731        if config.request_context_extractor.is_none() || config.security.expose_admin_debug {
732            debug_router = debug_router
733                .route("/debug/trace/{event_id}", get(controllers::debug::get_trace_by_event_id));
734        }
735
736        let debug_router =
737            debug_router.with_state(debug_controller.clone()).layer(auth_layer.clone());
738
739        // Assemble the API router with built-in + custom routes
740        let mut api_router = Router::new()
741            .merge(health_router)
742            .merge(ui_api_router)
743            .merge(session_router)
744            .merge(runtime_router)
745            .merge(artifacts_router)
746            .merge(debug_router);
747
748        // Merge custom API routes — these get the same /api prefix and auth middleware
749        for custom_routes in self.api_routes {
750            api_router = api_router.merge(custom_routes.layer(auth_layer.clone()));
751        }
752
753        // Add shutdown endpoint if enabled
754        let shutdown_handle = if self.shutdown_endpoint {
755            let handle = ShutdownHandle::new();
756            let shutdown_router = Router::new()
757                .route("/shutdown", post(handle_shutdown))
758                .with_state(handle.token.clone())
759                .layer(auth_layer);
760            api_router = api_router.merge(shutdown_router);
761            Some(handle)
762        } else {
763            None
764        };
765
766        let ui_router = Router::new()
767            .route("/", get(web_ui::root_redirect))
768            .route("/ui/", get(web_ui::serve_ui_index))
769            .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
770            .with_state(config.clone())
771            .route("/ui/{*path}", get(web_ui::serve_ui_assets));
772
773        let mut app = Router::new().nest("/api", api_router).merge(ui_router);
774
775        // Merge custom root routes
776        for custom_routes in self.root_routes {
777            app = app.merge(custom_routes);
778        }
779
780        if let Some(base_url) = &self.a2a_base_url {
781            let a2a_controller = A2aController::new(config.clone(), base_url);
782            let a2a_router = Router::new()
783                .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
784                .route("/a2a", post(controllers::a2a::handle_jsonrpc))
785                .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
786                .with_state(a2a_controller);
787            app = app.merge(a2a_router);
788        }
789
790        let cors_layer = build_cors_layer(config);
791        let trace_layer = TraceLayer::new_for_http().make_span_with(|request: &Request<Body>| {
792            let request_id =
793                request.extensions().get::<RequestId>().map(RequestId::as_str).unwrap_or("");
794            tracing::info_span!(
795                "http.request",
796                request.id = %request_id,
797                http.method = %request.method(),
798                http.path = %request.uri().path()
799            )
800        });
801
802        (
803            app.layer(
804                ServiceBuilder::new()
805                    .layer(middleware::from_fn(request_id_middleware))
806                    .layer(trace_layer)
807                    .layer(TimeoutLayer::with_status_code(
808                        StatusCode::REQUEST_TIMEOUT,
809                        config.security.request_timeout,
810                    ))
811                    .layer(DefaultBodyLimit::max(config.security.max_body_size))
812                    .layer(cors_layer)
813                    .layer(SetResponseHeaderLayer::if_not_present(
814                        header::X_CONTENT_TYPE_OPTIONS,
815                        HeaderValue::from_static("nosniff"),
816                    ))
817                    .layer(SetResponseHeaderLayer::if_not_present(
818                        header::X_FRAME_OPTIONS,
819                        HeaderValue::from_static("DENY"),
820                    ))
821                    .layer(SetResponseHeaderLayer::if_not_present(
822                        header::X_XSS_PROTECTION,
823                        HeaderValue::from_static("1; mode=block"),
824                    )),
825            ),
826            shutdown_handle,
827        )
828    }
829}
830
831/// Wait for a process shutdown signal.
832pub async fn shutdown_signal() {
833    let ctrl_c = async {
834        let _ = tokio::signal::ctrl_c().await;
835    };
836
837    #[cfg(unix)]
838    let terminate = async {
839        if let Ok(mut signal) =
840            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
841        {
842            let _ = signal.recv().await;
843        }
844    };
845
846    #[cfg(not(unix))]
847    let terminate = std::future::pending::<()>();
848
849    tokio::select! {
850        _ = ctrl_c => {}
851        _ = terminate => {}
852    }
853}
854
855// ---------------------------------------------------------------------------
856// ShutdownHandle — programmatic graceful shutdown trigger
857// ---------------------------------------------------------------------------
858
859/// Handle for triggering graceful server shutdown.
860///
861/// Returned by [`ServerBuilder::build_with_shutdown`]. Pass the future from
862/// [`ShutdownHandle::signal()`] to `axum::serve(...).with_graceful_shutdown()`
863/// to enable both OS signal-based and HTTP endpoint-based shutdown.
864///
865/// # Example
866///
867/// ```rust,ignore
868/// use adk_server::{ServerBuilder, ServerConfig};
869///
870/// let (app, shutdown_handle) = ServerBuilder::new(config)
871///     .enable_shutdown_endpoint()
872///     .build_with_shutdown();
873///
874/// let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await?;
875/// axum::serve(listener, app)
876///     .with_graceful_shutdown(shutdown_handle.signal())
877///     .await?;
878/// ```
879#[derive(Clone)]
880pub struct ShutdownHandle {
881    token: CancellationToken,
882}
883
884impl ShutdownHandle {
885    /// Create a new shutdown handle.
886    fn new() -> Self {
887        Self { token: CancellationToken::new() }
888    }
889
890    /// Trigger graceful shutdown programmatically.
891    ///
892    /// This has the same effect as calling `POST /api/shutdown` — the server
893    /// stops accepting new connections and completes in-flight requests.
894    pub fn shutdown(&self) {
895        tracing::info!("graceful shutdown triggered programmatically");
896        self.token.cancel();
897    }
898
899    /// Returns a future that resolves when shutdown is triggered.
900    ///
901    /// Combines OS signals (Ctrl+C, SIGTERM) with the programmatic/HTTP trigger.
902    /// Pass this to `axum::serve(...).with_graceful_shutdown()`.
903    pub async fn signal(self) {
904        let token = self.token.clone();
905
906        let ctrl_c = async {
907            let _ = tokio::signal::ctrl_c().await;
908        };
909
910        #[cfg(unix)]
911        let terminate = async {
912            if let Ok(mut signal) =
913                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
914            {
915                let _ = signal.recv().await;
916            }
917        };
918
919        #[cfg(not(unix))]
920        let terminate = std::future::pending::<()>();
921
922        tokio::select! {
923            _ = ctrl_c => {
924                tracing::info!("received Ctrl+C, initiating graceful shutdown");
925            }
926            _ = terminate => {
927                tracing::info!("received SIGTERM, initiating graceful shutdown");
928            }
929            _ = token.cancelled() => {
930                // Shutdown triggered via POST /api/shutdown or programmatic call
931            }
932        }
933    }
934
935    /// Returns whether shutdown has been triggered.
936    pub fn is_shutdown(&self) -> bool {
937        self.token.is_cancelled()
938    }
939}
940
941/// Handler for `POST /api/shutdown`.
942///
943/// Triggers graceful shutdown: the server stops accepting new connections,
944/// completes in-flight requests, and then exits.
945async fn handle_shutdown(State(token): State<CancellationToken>) -> impl IntoResponse {
946    tracing::info!("POST /api/shutdown received, initiating graceful shutdown");
947    token.cancel();
948    (StatusCode::OK, Json(serde_json::json!({ "status": "shutting_down" })))
949}