Skip to main content

adk_server/rest/
mod.rs

1pub mod controllers;
2mod routes;
3
4pub use controllers::{
5    A2aController, AppsController, ArtifactsController, DebugController, RuntimeController,
6    SessionController,
7};
8
9use crate::{
10    ServerConfig,
11    auth_bridge::{RequestContext, RequestContextError, RequestContextExtractor},
12    web_ui,
13};
14use axum::{
15    Json, Router,
16    body::Body,
17    extract::{DefaultBodyLimit, State},
18    http::{HeaderMap, HeaderName, HeaderValue, Method, Request, StatusCode, header},
19    middleware::{self, Next},
20    response::{IntoResponse, Response},
21    routing::{get, post},
22};
23use serde::Serialize;
24use std::sync::Arc;
25use tower::ServiceBuilder;
26use tower_http::{
27    cors::{AllowOrigin, CorsLayer},
28    set_header::SetResponseHeaderLayer,
29    timeout::TimeoutLayer,
30    trace::TraceLayer,
31};
32
/// Header carrying the per-request correlation ID: read from incoming requests
/// by `validate_request_id`, set on responses by `request_id_middleware`, and
/// listed as an allowed request header in the CORS configuration.
const REQUEST_ID_HEADER: &str = "x-request-id";
34
35#[derive(Clone)]
36struct HealthController {
37    session_service: Arc<dyn adk_session::SessionService>,
38    artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
39    memory_service: Option<Arc<dyn adk_core::Memory>>,
40}
41
42impl HealthController {
43    fn new(config: &ServerConfig) -> Self {
44        Self {
45            session_service: config.session_service.clone(),
46            artifact_service: config.artifact_service.clone(),
47            memory_service: config.memory_service.clone(),
48        }
49    }
50}
51
/// Correlation identifier attached to each request's extensions by
/// `request_id_middleware`; the tracing span reads it back via [`RequestId::as_str`].
#[derive(Clone, Debug)]
struct RequestId(String);

impl RequestId {
    /// Borrows the identifier text.
    fn as_str(&self) -> &str {
        let RequestId(id) = self;
        id.as_str()
    }
}
60
/// JSON body returned by `health_check`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct HealthResponse {
    /// Aggregate verdict: `"healthy"` or `"unhealthy"`.
    status: &'static str,
    /// Individual report for each backing service.
    components: HealthComponents,
}

/// Per-component health reports embedded in [`HealthResponse`].
// Field order here is the key order of the serialized JSON object.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct HealthComponents {
    session: ComponentHealth,
    memory: ComponentHealth,
    artifact: ComponentHealth,
}
75
76#[derive(Serialize)]
77#[serde(rename_all = "camelCase")]
78struct ComponentHealth {
79    status: &'static str,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    error: Option<String>,
82}
83
84impl ComponentHealth {
85    fn healthy() -> Self {
86        Self { status: "healthy", error: None }
87    }
88
89    fn unhealthy(error: impl Into<String>) -> Self {
90        Self { status: "unhealthy", error: Some(error.into()) }
91    }
92
93    fn not_configured() -> Self {
94        Self { status: "not_configured", error: None }
95    }
96}
97
98/// Build CORS layer based on security configuration
99fn build_cors_layer(config: &ServerConfig) -> CorsLayer {
100    let cors = CorsLayer::new()
101        .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
102        .allow_headers([
103            header::CONTENT_TYPE,
104            header::AUTHORIZATION,
105            HeaderName::from_static(REQUEST_ID_HEADER),
106            HeaderName::from_static("x-adk-ui-protocol"),
107            HeaderName::from_static("x-adk-ui-transport"),
108        ]);
109
110    if config.security.allowed_origins.is_empty() {
111        cors.allow_origin(AllowOrigin::any())
112    } else {
113        let origins: Vec<HeaderValue> = config
114            .security
115            .allowed_origins
116            .iter()
117            .filter_map(|origin| origin.parse().ok())
118            .collect();
119        cors.allow_origin(origins)
120    }
121}
122
123fn validate_request_id(headers: &HeaderMap) -> Option<String> {
124    let value = headers.get(REQUEST_ID_HEADER)?;
125    let raw = value.to_str().ok()?;
126    if raw.len() > 128 {
127        return None;
128    }
129    uuid::Uuid::parse_str(raw).ok()?;
130    Some(raw.to_string())
131}
132
133async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
134    let request_id =
135        validate_request_id(request.headers()).unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
136
137    request.extensions_mut().insert(RequestId(request_id.clone()));
138
139    let mut response = next.run(request).await;
140    if let Ok(value) = HeaderValue::from_str(&request_id) {
141        response.headers_mut().insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
142    }
143    response
144}
145
146async fn auth_middleware(
147    request: Request<Body>,
148    next: Next,
149    extractor: Option<Arc<dyn RequestContextExtractor>>,
150) -> Response {
151    let (mut parts, body) = request.into_parts();
152
153    let request_context = match extractor {
154        Some(extractor) => match extractor.extract(&parts).await {
155            Ok(context) => Some(context),
156            Err(RequestContextError::MissingAuth) => {
157                return (
158                    StatusCode::UNAUTHORIZED,
159                    Json(serde_json::json!({ "error": "missing authorization" })),
160                )
161                    .into_response();
162            }
163            Err(RequestContextError::InvalidToken(message)) => {
164                return (
165                    StatusCode::UNAUTHORIZED,
166                    Json(serde_json::json!({ "error": format!("invalid token: {message}") })),
167                )
168                    .into_response();
169            }
170            Err(RequestContextError::ExtractionFailed(message)) => {
171                return (
172                    StatusCode::INTERNAL_SERVER_ERROR,
173                    Json(serde_json::json!({
174                        "error": format!("auth extraction failed: {message}")
175                    })),
176                )
177                    .into_response();
178            }
179        },
180        None => None,
181    };
182
183    parts.extensions.insert::<Option<RequestContext>>(request_context);
184    next.run(Request::from_parts(parts, body)).await
185}
186
187async fn health_check(State(controller): State<HealthController>) -> impl IntoResponse {
188    let session = match controller.session_service.health_check().await {
189        Ok(()) => ComponentHealth::healthy(),
190        Err(error) => ComponentHealth::unhealthy(error.to_string()),
191    };
192
193    let memory = match controller.memory_service.as_ref() {
194        Some(service) => match service.health_check().await {
195            Ok(()) => ComponentHealth::healthy(),
196            Err(error) => ComponentHealth::unhealthy(error.to_string()),
197        },
198        None => ComponentHealth::not_configured(),
199    };
200
201    let artifact = match controller.artifact_service.as_ref() {
202        Some(service) => match service.health_check().await {
203            Ok(()) => ComponentHealth::healthy(),
204            Err(error) => ComponentHealth::unhealthy(error.to_string()),
205        },
206        None => ComponentHealth::not_configured(),
207    };
208
209    let healthy = session.status == "healthy"
210        && memory.status != "unhealthy"
211        && artifact.status != "unhealthy";
212
213    (
214        if healthy { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE },
215        Json(HealthResponse {
216            status: if healthy { "healthy" } else { "unhealthy" },
217            components: HealthComponents { session, memory, artifact },
218        }),
219    )
220}
221
222/// Create the server application with optional A2A support
223pub fn create_app(config: ServerConfig) -> Router {
224    create_app_with_a2a(config, None)
225}
226
227/// Create the server application with A2A support at the specified base URL
228pub fn create_app_with_a2a(config: ServerConfig, a2a_base_url: Option<&str>) -> Router {
229    let session_controller = SessionController::new(config.session_service.clone());
230    let runtime_controller = RuntimeController::new(config.clone());
231    let apps_controller = AppsController::new(config.clone());
232    let artifacts_controller = ArtifactsController::new(config.clone());
233    let debug_controller = DebugController::new(config.clone());
234    let health_controller = HealthController::new(&config);
235
236    let auth_layer = middleware::from_fn({
237        let extractor = config.request_context_extractor.clone();
238        move |request: Request<Body>, next: Next| {
239            let extractor = extractor.clone();
240            async move { auth_middleware(request, next, extractor).await }
241        }
242    });
243
244    let health_router =
245        Router::new().route("/health", get(health_check)).with_state(health_controller);
246
247    let ui_api_router = Router::new()
248        .route("/apps", get(controllers::apps::list_apps))
249        .route("/list-apps", get(controllers::apps::list_apps_compat))
250        .with_state(apps_controller)
251        .route("/ui/capabilities", get(controllers::ui::ui_capabilities))
252        .route("/ui/initialize", post(controllers::ui::ui_initialize))
253        .route("/ui/message", post(controllers::ui::ui_message))
254        .route("/ui/update-model-context", post(controllers::ui::ui_update_model_context))
255        .route("/ui/notifications/poll", post(controllers::ui::ui_poll_notifications))
256        .route(
257            "/ui/notifications/resources-list-changed",
258            post(controllers::ui::ui_notify_resources_list_changed),
259        )
260        .route(
261            "/ui/notifications/tools-list-changed",
262            post(controllers::ui::ui_notify_tools_list_changed),
263        )
264        .route("/ui/resources", get(controllers::ui::list_ui_resources))
265        .route("/ui/resources/read", get(controllers::ui::read_ui_resource))
266        .route("/ui/resources/register", post(controllers::ui::register_ui_resource));
267
268    let session_router = Router::new()
269        .route("/sessions", post(controllers::session::create_session))
270        .route(
271            "/sessions/{app_name}/{user_id}/{session_id}",
272            get(controllers::session::get_session).delete(controllers::session::delete_session),
273        )
274        .route(
275            "/apps/{app_name}/users/{user_id}/sessions",
276            get(controllers::session::list_sessions)
277                .post(controllers::session::create_session_from_path),
278        )
279        .route(
280            "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
281            get(controllers::session::get_session_from_path)
282                .post(controllers::session::create_session_from_path)
283                .delete(controllers::session::delete_session_from_path),
284        )
285        .with_state(session_controller)
286        .layer(auth_layer.clone());
287
288    let runtime_router = Router::new()
289        .route("/run/{app_name}/{user_id}/{session_id}", post(controllers::runtime::run_sse))
290        .route("/run_sse", post(controllers::runtime::run_sse_compat))
291        .with_state(runtime_controller);
292
293    let artifacts_router = Router::new()
294        .route(
295            "/sessions/{app_name}/{user_id}/{session_id}/artifacts",
296            get(controllers::artifacts::list_artifacts),
297        )
298        .route(
299            "/sessions/{app_name}/{user_id}/{session_id}/artifacts/{artifact_name}",
300            get(controllers::artifacts::get_artifact),
301        )
302        .with_state(artifacts_controller)
303        .layer(auth_layer.clone());
304
305    let mut debug_router = Router::new()
306        .route("/debug/trace/session/{session_id}", get(controllers::debug::get_session_traces))
307        .route(
308            "/debug/graph/{app_name}/{user_id}/{session_id}/{event_id}",
309            get(controllers::debug::get_graph),
310        )
311        .route(
312            "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
313            get(controllers::debug::get_graph),
314        )
315        .route("/apps/{app_name}/eval_sets", get(controllers::debug::get_eval_sets))
316        .route(
317            "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}",
318            get(controllers::debug::get_event),
319        );
320
321    if config.request_context_extractor.is_none() || config.security.expose_admin_debug {
322        debug_router = debug_router
323            .route("/debug/trace/{event_id}", get(controllers::debug::get_trace_by_event_id));
324    }
325
326    let debug_router = debug_router.with_state(debug_controller.clone()).layer(auth_layer.clone());
327
328    let api_router = Router::new()
329        .merge(health_router)
330        .merge(ui_api_router)
331        .merge(session_router)
332        .merge(runtime_router)
333        .merge(artifacts_router)
334        .merge(debug_router);
335
336    let ui_router = Router::new()
337        .route("/", get(web_ui::root_redirect))
338        .route("/ui/", get(web_ui::serve_ui_index))
339        .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
340        .with_state(config.clone())
341        .route("/ui/{*path}", get(web_ui::serve_ui_assets));
342
343    let mut app = Router::new().nest("/api", api_router).merge(ui_router);
344
345    if let Some(base_url) = a2a_base_url {
346        let a2a_controller = A2aController::new(config.clone(), base_url);
347        let a2a_router = Router::new()
348            .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
349            .route("/a2a", post(controllers::a2a::handle_jsonrpc))
350            .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
351            .with_state(a2a_controller);
352        app = app.merge(a2a_router);
353    }
354
355    let cors_layer = build_cors_layer(&config);
356    let trace_layer = TraceLayer::new_for_http().make_span_with(|request: &Request<Body>| {
357        let request_id =
358            request.extensions().get::<RequestId>().map(RequestId::as_str).unwrap_or("");
359        tracing::info_span!(
360            "http.request",
361            request.id = %request_id,
362            http.method = %request.method(),
363            http.path = %request.uri().path()
364        )
365    });
366
367    app.layer(
368        ServiceBuilder::new()
369            .layer(middleware::from_fn(request_id_middleware))
370            .layer(trace_layer)
371            .layer(TimeoutLayer::with_status_code(
372                StatusCode::REQUEST_TIMEOUT,
373                config.security.request_timeout,
374            ))
375            .layer(DefaultBodyLimit::max(config.security.max_body_size))
376            .layer(cors_layer)
377            .layer(SetResponseHeaderLayer::if_not_present(
378                header::X_CONTENT_TYPE_OPTIONS,
379                HeaderValue::from_static("nosniff"),
380            ))
381            .layer(SetResponseHeaderLayer::if_not_present(
382                header::X_FRAME_OPTIONS,
383                HeaderValue::from_static("DENY"),
384            ))
385            .layer(SetResponseHeaderLayer::if_not_present(
386                header::X_XSS_PROTECTION,
387                HeaderValue::from_static("1; mode=block"),
388            )),
389    )
390}
391
392/// Wait for a process shutdown signal.
393pub async fn shutdown_signal() {
394    let ctrl_c = async {
395        let _ = tokio::signal::ctrl_c().await;
396    };
397
398    #[cfg(unix)]
399    let terminate = async {
400        if let Ok(mut signal) =
401            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
402        {
403            let _ = signal.recv().await;
404        }
405    };
406
407    #[cfg(not(unix))]
408    let terminate = std::future::pending::<()>();
409
410    tokio::select! {
411        _ = ctrl_c => {}
412        _ = terminate => {}
413    }
414}