1pub mod controllers;
2mod routes;
3
4pub use controllers::{
5 A2aController, AppsController, ArtifactsController, DebugController, RuntimeController,
6 SessionController,
7};
8
9use crate::{
10 ServerConfig,
11 auth_bridge::{RequestContext, RequestContextError, RequestContextExtractor},
12 web_ui,
13};
14use axum::{
15 Json, Router,
16 body::Body,
17 extract::{DefaultBodyLimit, State},
18 http::{HeaderMap, HeaderName, HeaderValue, Method, Request, StatusCode, header},
19 middleware::{self, Next},
20 response::{IntoResponse, Response},
21 routing::{get, post},
22};
23use serde::Serialize;
24use std::sync::Arc;
25use tower::ServiceBuilder;
26use tower_http::{
27 cors::{AllowOrigin, CorsLayer},
28 set_header::SetResponseHeaderLayer,
29 timeout::TimeoutLayer,
30 trace::TraceLayer,
31};
32
/// Header name carrying the per-request correlation id.
///
/// Kept lowercase because it is passed to `HeaderName::from_static`, which
/// panics on names containing uppercase characters.
const REQUEST_ID_HEADER: &str = "x-request-id";
34
35#[derive(Clone)]
36struct HealthController {
37 session_service: Arc<dyn adk_session::SessionService>,
38 artifact_service: Option<Arc<dyn adk_artifact::ArtifactService>>,
39 memory_service: Option<Arc<dyn adk_core::Memory>>,
40}
41
42impl HealthController {
43 fn new(config: &ServerConfig) -> Self {
44 Self {
45 session_service: config.session_service.clone(),
46 artifact_service: config.artifact_service.clone(),
47 memory_service: config.memory_service.clone(),
48 }
49 }
50}
51
/// Correlation id assigned to a request, stored in the request extensions.
#[derive(Clone, Debug)]
struct RequestId(String);

impl RequestId {
    /// Borrows the id as a string slice.
    fn as_str(&self) -> &str {
        let RequestId(inner) = self;
        inner.as_str()
    }
}
60
61#[derive(Serialize)]
62#[serde(rename_all = "camelCase")]
63struct HealthResponse {
64 status: &'static str,
65 components: HealthComponents,
66}
67
68#[derive(Serialize)]
69#[serde(rename_all = "camelCase")]
70struct HealthComponents {
71 session: ComponentHealth,
72 memory: ComponentHealth,
73 artifact: ComponentHealth,
74}
75
76#[derive(Serialize)]
77#[serde(rename_all = "camelCase")]
78struct ComponentHealth {
79 status: &'static str,
80 #[serde(skip_serializing_if = "Option::is_none")]
81 error: Option<String>,
82}
83
84impl ComponentHealth {
85 fn healthy() -> Self {
86 Self { status: "healthy", error: None }
87 }
88
89 fn unhealthy(error: impl Into<String>) -> Self {
90 Self { status: "unhealthy", error: Some(error.into()) }
91 }
92
93 fn not_configured() -> Self {
94 Self { status: "not_configured", error: None }
95 }
96}
97
98fn build_cors_layer(config: &ServerConfig) -> CorsLayer {
100 let cors = CorsLayer::new()
101 .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
102 .allow_headers([
103 header::CONTENT_TYPE,
104 header::AUTHORIZATION,
105 HeaderName::from_static(REQUEST_ID_HEADER),
106 HeaderName::from_static("x-adk-ui-protocol"),
107 HeaderName::from_static("x-adk-ui-transport"),
108 ]);
109
110 if config.security.allowed_origins.is_empty() {
111 cors.allow_origin(AllowOrigin::any())
112 } else {
113 let origins: Vec<HeaderValue> = config
114 .security
115 .allowed_origins
116 .iter()
117 .filter_map(|origin| origin.parse().ok())
118 .collect();
119 cors.allow_origin(origins)
120 }
121}
122
123fn validate_request_id(headers: &HeaderMap) -> Option<String> {
124 let value = headers.get(REQUEST_ID_HEADER)?;
125 let raw = value.to_str().ok()?;
126 if raw.len() > 128 {
127 return None;
128 }
129 uuid::Uuid::parse_str(raw).ok()?;
130 Some(raw.to_string())
131}
132
133async fn request_id_middleware(mut request: Request<Body>, next: Next) -> Response {
134 let request_id =
135 validate_request_id(request.headers()).unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
136
137 request.extensions_mut().insert(RequestId(request_id.clone()));
138
139 let mut response = next.run(request).await;
140 if let Ok(value) = HeaderValue::from_str(&request_id) {
141 response.headers_mut().insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
142 }
143 response
144}
145
146async fn auth_middleware(
147 request: Request<Body>,
148 next: Next,
149 extractor: Option<Arc<dyn RequestContextExtractor>>,
150) -> Response {
151 let (mut parts, body) = request.into_parts();
152
153 let request_context = match extractor {
154 Some(extractor) => match extractor.extract(&parts).await {
155 Ok(context) => Some(context),
156 Err(RequestContextError::MissingAuth) => {
157 return (
158 StatusCode::UNAUTHORIZED,
159 Json(serde_json::json!({ "error": "missing authorization" })),
160 )
161 .into_response();
162 }
163 Err(RequestContextError::InvalidToken(message)) => {
164 return (
165 StatusCode::UNAUTHORIZED,
166 Json(serde_json::json!({ "error": format!("invalid token: {message}") })),
167 )
168 .into_response();
169 }
170 Err(RequestContextError::ExtractionFailed(message)) => {
171 return (
172 StatusCode::INTERNAL_SERVER_ERROR,
173 Json(serde_json::json!({
174 "error": format!("auth extraction failed: {message}")
175 })),
176 )
177 .into_response();
178 }
179 },
180 None => None,
181 };
182
183 parts.extensions.insert::<Option<RequestContext>>(request_context);
184 next.run(Request::from_parts(parts, body)).await
185}
186
187async fn health_check(State(controller): State<HealthController>) -> impl IntoResponse {
188 let session = match controller.session_service.health_check().await {
189 Ok(()) => ComponentHealth::healthy(),
190 Err(error) => ComponentHealth::unhealthy(error.to_string()),
191 };
192
193 let memory = match controller.memory_service.as_ref() {
194 Some(service) => match service.health_check().await {
195 Ok(()) => ComponentHealth::healthy(),
196 Err(error) => ComponentHealth::unhealthy(error.to_string()),
197 },
198 None => ComponentHealth::not_configured(),
199 };
200
201 let artifact = match controller.artifact_service.as_ref() {
202 Some(service) => match service.health_check().await {
203 Ok(()) => ComponentHealth::healthy(),
204 Err(error) => ComponentHealth::unhealthy(error.to_string()),
205 },
206 None => ComponentHealth::not_configured(),
207 };
208
209 let healthy = session.status == "healthy"
210 && memory.status != "unhealthy"
211 && artifact.status != "unhealthy";
212
213 (
214 if healthy { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE },
215 Json(HealthResponse {
216 status: if healthy { "healthy" } else { "unhealthy" },
217 components: HealthComponents { session, memory, artifact },
218 }),
219 )
220}
221
222pub fn create_app(config: ServerConfig) -> Router {
224 create_app_with_a2a(config, None)
225}
226
227pub fn create_app_with_a2a(config: ServerConfig, a2a_base_url: Option<&str>) -> Router {
229 let session_controller = SessionController::new(config.session_service.clone());
230 let runtime_controller = RuntimeController::new(config.clone());
231 let apps_controller = AppsController::new(config.clone());
232 let artifacts_controller = ArtifactsController::new(config.clone());
233 let debug_controller = DebugController::new(config.clone());
234 let health_controller = HealthController::new(&config);
235
236 let auth_layer = middleware::from_fn({
237 let extractor = config.request_context_extractor.clone();
238 move |request: Request<Body>, next: Next| {
239 let extractor = extractor.clone();
240 async move { auth_middleware(request, next, extractor).await }
241 }
242 });
243
244 let health_router =
245 Router::new().route("/health", get(health_check)).with_state(health_controller);
246
247 let ui_api_router = Router::new()
248 .route("/apps", get(controllers::apps::list_apps))
249 .route("/list-apps", get(controllers::apps::list_apps_compat))
250 .with_state(apps_controller)
251 .route("/ui/capabilities", get(controllers::ui::ui_capabilities))
252 .route("/ui/initialize", post(controllers::ui::ui_initialize))
253 .route("/ui/message", post(controllers::ui::ui_message))
254 .route("/ui/update-model-context", post(controllers::ui::ui_update_model_context))
255 .route("/ui/notifications/poll", post(controllers::ui::ui_poll_notifications))
256 .route(
257 "/ui/notifications/resources-list-changed",
258 post(controllers::ui::ui_notify_resources_list_changed),
259 )
260 .route(
261 "/ui/notifications/tools-list-changed",
262 post(controllers::ui::ui_notify_tools_list_changed),
263 )
264 .route("/ui/resources", get(controllers::ui::list_ui_resources))
265 .route("/ui/resources/read", get(controllers::ui::read_ui_resource))
266 .route("/ui/resources/register", post(controllers::ui::register_ui_resource));
267
268 let session_router = Router::new()
269 .route("/sessions", post(controllers::session::create_session))
270 .route(
271 "/sessions/{app_name}/{user_id}/{session_id}",
272 get(controllers::session::get_session).delete(controllers::session::delete_session),
273 )
274 .route(
275 "/apps/{app_name}/users/{user_id}/sessions",
276 get(controllers::session::list_sessions)
277 .post(controllers::session::create_session_from_path),
278 )
279 .route(
280 "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
281 get(controllers::session::get_session_from_path)
282 .post(controllers::session::create_session_from_path)
283 .delete(controllers::session::delete_session_from_path),
284 )
285 .with_state(session_controller)
286 .layer(auth_layer.clone());
287
288 let runtime_router = Router::new()
289 .route("/run/{app_name}/{user_id}/{session_id}", post(controllers::runtime::run_sse))
290 .route("/run_sse", post(controllers::runtime::run_sse_compat))
291 .with_state(runtime_controller);
292
293 let artifacts_router = Router::new()
294 .route(
295 "/sessions/{app_name}/{user_id}/{session_id}/artifacts",
296 get(controllers::artifacts::list_artifacts),
297 )
298 .route(
299 "/sessions/{app_name}/{user_id}/{session_id}/artifacts/{artifact_name}",
300 get(controllers::artifacts::get_artifact),
301 )
302 .with_state(artifacts_controller)
303 .layer(auth_layer.clone());
304
305 let mut debug_router = Router::new()
306 .route("/debug/trace/session/{session_id}", get(controllers::debug::get_session_traces))
307 .route(
308 "/debug/graph/{app_name}/{user_id}/{session_id}/{event_id}",
309 get(controllers::debug::get_graph),
310 )
311 .route(
312 "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
313 get(controllers::debug::get_graph),
314 )
315 .route("/apps/{app_name}/eval_sets", get(controllers::debug::get_eval_sets))
316 .route(
317 "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}",
318 get(controllers::debug::get_event),
319 );
320
321 if config.request_context_extractor.is_none() || config.security.expose_admin_debug {
322 debug_router = debug_router
323 .route("/debug/trace/{event_id}", get(controllers::debug::get_trace_by_event_id));
324 }
325
326 let debug_router = debug_router.with_state(debug_controller.clone()).layer(auth_layer.clone());
327
328 let api_router = Router::new()
329 .merge(health_router)
330 .merge(ui_api_router)
331 .merge(session_router)
332 .merge(runtime_router)
333 .merge(artifacts_router)
334 .merge(debug_router);
335
336 let ui_router = Router::new()
337 .route("/", get(web_ui::root_redirect))
338 .route("/ui/", get(web_ui::serve_ui_index))
339 .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
340 .with_state(config.clone())
341 .route("/ui/{*path}", get(web_ui::serve_ui_assets));
342
343 let mut app = Router::new().nest("/api", api_router).merge(ui_router);
344
345 if let Some(base_url) = a2a_base_url {
346 let a2a_controller = A2aController::new(config.clone(), base_url);
347 let a2a_router = Router::new()
348 .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
349 .route("/a2a", post(controllers::a2a::handle_jsonrpc))
350 .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
351 .with_state(a2a_controller);
352 app = app.merge(a2a_router);
353 }
354
355 let cors_layer = build_cors_layer(&config);
356 let trace_layer = TraceLayer::new_for_http().make_span_with(|request: &Request<Body>| {
357 let request_id =
358 request.extensions().get::<RequestId>().map(RequestId::as_str).unwrap_or("");
359 tracing::info_span!(
360 "http.request",
361 request.id = %request_id,
362 http.method = %request.method(),
363 http.path = %request.uri().path()
364 )
365 });
366
367 app.layer(
368 ServiceBuilder::new()
369 .layer(middleware::from_fn(request_id_middleware))
370 .layer(trace_layer)
371 .layer(TimeoutLayer::with_status_code(
372 StatusCode::REQUEST_TIMEOUT,
373 config.security.request_timeout,
374 ))
375 .layer(DefaultBodyLimit::max(config.security.max_body_size))
376 .layer(cors_layer)
377 .layer(SetResponseHeaderLayer::if_not_present(
378 header::X_CONTENT_TYPE_OPTIONS,
379 HeaderValue::from_static("nosniff"),
380 ))
381 .layer(SetResponseHeaderLayer::if_not_present(
382 header::X_FRAME_OPTIONS,
383 HeaderValue::from_static("DENY"),
384 ))
385 .layer(SetResponseHeaderLayer::if_not_present(
386 header::X_XSS_PROTECTION,
387 HeaderValue::from_static("1; mode=block"),
388 )),
389 )
390}
391
392pub async fn shutdown_signal() {
394 let ctrl_c = async {
395 let _ = tokio::signal::ctrl_c().await;
396 };
397
398 #[cfg(unix)]
399 let terminate = async {
400 if let Ok(mut signal) =
401 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
402 {
403 let _ = signal.recv().await;
404 }
405 };
406
407 #[cfg(not(unix))]
408 let terminate = std::future::pending::<()>();
409
410 tokio::select! {
411 _ = ctrl_c => {}
412 _ = terminate => {}
413 }
414}