// adk_server/rest/mod.rs

1pub mod controllers;
2mod routes;
3
4pub use controllers::{
5    A2aController, AppsController, ArtifactsController, DebugController, RuntimeController,
6    SessionController,
7};
8
9use crate::{ServerConfig, web_ui};
10use axum::{
11    Router,
12    extract::DefaultBodyLimit,
13    http::{HeaderValue, Method, header},
14    routing::{get, post},
15};
16use tower::ServiceBuilder;
17use tower_http::{
18    cors::{AllowOrigin, CorsLayer},
19    set_header::SetResponseHeaderLayer,
20    timeout::TimeoutLayer,
21    trace::TraceLayer,
22};
23
24/// Build CORS layer based on security configuration
25fn build_cors_layer(config: &ServerConfig) -> CorsLayer {
26    let cors = CorsLayer::new()
27        .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
28        .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]);
29
30    if config.security.allowed_origins.is_empty() {
31        // Development mode: allow all origins (with warning logged at startup)
32        cors.allow_origin(AllowOrigin::any())
33    } else {
34        // Production mode: only allow specified origins
35        let origins: Vec<HeaderValue> =
36            config.security.allowed_origins.iter().filter_map(|o| o.parse().ok()).collect();
37        cors.allow_origin(origins)
38    }
39}
40
41/// Create the server application with optional A2A support
42pub fn create_app(config: ServerConfig) -> Router {
43    create_app_with_a2a(config, None)
44}
45
46/// Create the server application with A2A support at the specified base URL
47pub fn create_app_with_a2a(config: ServerConfig, a2a_base_url: Option<&str>) -> Router {
48    let session_controller = SessionController::new(config.session_service.clone());
49    let runtime_controller = RuntimeController::new(config.clone());
50    let apps_controller = AppsController::new(config.clone());
51    let artifacts_controller = ArtifactsController::new(config.clone());
52    let debug_controller = DebugController::new(config.clone());
53
54    let api_router = Router::new()
55        .route("/health", get(health_check))
56        .route("/apps", get(controllers::apps::list_apps))
57        .route("/list-apps", get(controllers::apps::list_apps_compat))
58        .with_state(apps_controller)
59        .route("/sessions", post(controllers::session::create_session))
60        .route(
61            "/sessions/{app_name}/{user_id}/{session_id}",
62            get(controllers::session::get_session).delete(controllers::session::delete_session),
63        )
64        // adk-go compatible routes
65        .route(
66            "/apps/{app_name}/users/{user_id}/sessions",
67            get(controllers::session::list_sessions)
68                .post(controllers::session::create_session_from_path),
69        )
70        .route(
71            "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
72            get(controllers::session::get_session_from_path)
73                .post(controllers::session::create_session_from_path)
74                .delete(controllers::session::delete_session_from_path),
75        )
76        .with_state(session_controller)
77        .route("/run/{app_name}/{user_id}/{session_id}", post(controllers::runtime::run_sse))
78        .route("/run_sse", post(controllers::runtime::run_sse_compat))
79        .with_state(runtime_controller)
80        .route(
81            "/sessions/{app_name}/{user_id}/{session_id}/artifacts",
82            get(controllers::artifacts::list_artifacts),
83        )
84        .route(
85            "/sessions/{app_name}/{user_id}/{session_id}/artifacts/{artifact_name}",
86            get(controllers::artifacts::get_artifact),
87        )
88        .with_state(artifacts_controller)
89        .route("/debug/trace/{event_id}", get(controllers::debug::get_trace_by_event_id))
90        .route("/debug/trace/session/{session_id}", get(controllers::debug::get_session_traces))
91        .route(
92            "/debug/graph/{app_name}/{user_id}/{session_id}/{event_id}",
93            get(controllers::debug::get_graph),
94        )
95        // UI-compatible graph route
96        .route(
97            "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
98            get(controllers::debug::get_graph),
99        )
100        // UI-compatible eval_sets route (stub)
101        .route("/apps/{app_name}/eval_sets", get(controllers::debug::get_eval_sets))
102        // UI-compatible event route - for trace-event linking
103        .route(
104            "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}",
105            get(controllers::debug::get_event),
106        )
107        .with_state(debug_controller);
108
109    let ui_router = Router::new()
110        .route("/", get(web_ui::root_redirect))
111        .route("/ui/", get(web_ui::serve_ui_index))
112        .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
113        .with_state(config.clone())
114        .route("/ui/{*path}", get(web_ui::serve_ui_assets));
115
116    let mut app = Router::new().nest("/api", api_router).merge(ui_router);
117
118    // Add A2A routes if base URL is provided
119    if let Some(base_url) = a2a_base_url {
120        let a2a_controller = A2aController::new(config.clone(), base_url);
121        let a2a_router = Router::new()
122            .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
123            .route("/a2a", post(controllers::a2a::handle_jsonrpc))
124            .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
125            .with_state(a2a_controller);
126        app = app.merge(a2a_router);
127    }
128
129    // Build security layers
130    let cors_layer = build_cors_layer(&config);
131
132    // Apply all middleware layers
133    app.layer(
134        ServiceBuilder::new()
135            // Tracing for observability
136            .layer(TraceLayer::new_for_http())
137            // Request timeout
138            .layer(TimeoutLayer::with_status_code(
139                axum::http::StatusCode::REQUEST_TIMEOUT,
140                config.security.request_timeout,
141            ))
142            // Request body size limit
143            .layer(DefaultBodyLimit::max(config.security.max_body_size))
144            // CORS configuration
145            .layer(cors_layer)
146            // Security headers
147            .layer(SetResponseHeaderLayer::if_not_present(
148                header::X_CONTENT_TYPE_OPTIONS,
149                HeaderValue::from_static("nosniff"),
150            ))
151            .layer(SetResponseHeaderLayer::if_not_present(
152                header::X_FRAME_OPTIONS,
153                HeaderValue::from_static("DENY"),
154            ))
155            .layer(SetResponseHeaderLayer::if_not_present(
156                header::X_XSS_PROTECTION,
157                HeaderValue::from_static("1; mode=block"),
158            )),
159    )
160}
161
/// Liveness probe served at `GET /api/health`.
///
/// Performs no checks of its own; it always answers with the plain-text body `OK`.
async fn health_check() -> &'static str {
    const HEALTHY_BODY: &str = "OK";
    HEALTHY_BODY
}