adk_server/rest/
mod.rs

1pub mod controllers;
2mod routes;
3
4pub use controllers::{
5    A2aController, AppsController, ArtifactsController, DebugController, RuntimeController,
6    SessionController,
7};
8
9use crate::{web_ui, ServerConfig};
10use axum::{
11    extract::DefaultBodyLimit,
12    http::{header, HeaderValue, Method},
13    routing::{get, post},
14    Router,
15};
16use tower::ServiceBuilder;
17use tower_http::{
18    cors::{AllowOrigin, CorsLayer},
19    set_header::SetResponseHeaderLayer,
20    timeout::TimeoutLayer,
21    trace::TraceLayer,
22};
23
24/// Build CORS layer based on security configuration
25fn build_cors_layer(config: &ServerConfig) -> CorsLayer {
26    let cors = CorsLayer::new()
27        .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
28        .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]);
29
30    if config.security.allowed_origins.is_empty() {
31        // Development mode: allow all origins (with warning logged at startup)
32        cors.allow_origin(AllowOrigin::any())
33    } else {
34        // Production mode: only allow specified origins
35        let origins: Vec<HeaderValue> =
36            config.security.allowed_origins.iter().filter_map(|o| o.parse().ok()).collect();
37        cors.allow_origin(origins)
38    }
39}
40
41/// Create the server application with optional A2A support
42pub fn create_app(config: ServerConfig) -> Router {
43    create_app_with_a2a(config, None)
44}
45
46/// Create the server application with A2A support at the specified base URL
47pub fn create_app_with_a2a(config: ServerConfig, a2a_base_url: Option<&str>) -> Router {
48    let session_controller = SessionController::new(config.session_service.clone());
49    let runtime_controller = RuntimeController::new(config.clone());
50    let apps_controller = AppsController::new(config.clone());
51    let artifacts_controller = ArtifactsController::new(config.clone());
52    let debug_controller = DebugController::new(config.clone());
53
54    let api_router = Router::new()
55        .route("/health", get(health_check))
56        .route("/apps", get(controllers::apps::list_apps))
57        .route("/list-apps", get(controllers::apps::list_apps_compat))
58        .with_state(apps_controller)
59        .route("/sessions", post(controllers::session::create_session))
60        .route(
61            "/sessions/{app_name}/{user_id}/{session_id}",
62            get(controllers::session::get_session).delete(controllers::session::delete_session),
63        )
64        // adk-go compatible routes
65        .route(
66            "/apps/{app_name}/users/{user_id}/sessions",
67            get(controllers::session::list_sessions)
68                .post(controllers::session::create_session_from_path),
69        )
70        .route(
71            "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
72            get(controllers::session::get_session_from_path)
73                .post(controllers::session::create_session_from_path)
74                .delete(controllers::session::delete_session_from_path),
75        )
76        .with_state(session_controller)
77        .route("/run/{app_name}/{user_id}/{session_id}", post(controllers::runtime::run_sse))
78        .route("/run_sse", post(controllers::runtime::run_sse_compat))
79        .with_state(runtime_controller)
80        .route(
81            "/sessions/{app_name}/{user_id}/{session_id}/artifacts",
82            get(controllers::artifacts::list_artifacts),
83        )
84        .route(
85            "/sessions/{app_name}/{user_id}/{session_id}/artifacts/{artifact_name}",
86            get(controllers::artifacts::get_artifact),
87        )
88        .with_state(artifacts_controller)
89        .route("/debug/trace/{event_id}", get(controllers::debug::get_trace))
90        .route(
91            "/debug/graph/{app_name}/{user_id}/{session_id}/{event_id}",
92            get(controllers::debug::get_graph),
93        )
94        .with_state(debug_controller);
95
96    let ui_router = Router::new()
97        .route("/", get(web_ui::root_redirect))
98        .route("/ui/", get(web_ui::serve_ui_index))
99        .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
100        .with_state(config.clone())
101        .route("/ui/{*path}", get(web_ui::serve_ui_assets));
102
103    let mut app = Router::new().nest("/api", api_router).merge(ui_router);
104
105    // Add A2A routes if base URL is provided
106    if let Some(base_url) = a2a_base_url {
107        let a2a_controller = A2aController::new(config.clone(), base_url);
108        let a2a_router = Router::new()
109            .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
110            .route("/a2a", post(controllers::a2a::handle_jsonrpc))
111            .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
112            .with_state(a2a_controller);
113        app = app.merge(a2a_router);
114    }
115
116    // Build security layers
117    let cors_layer = build_cors_layer(&config);
118
119    // Apply all middleware layers
120    app.layer(
121        ServiceBuilder::new()
122            // Tracing for observability
123            .layer(TraceLayer::new_for_http())
124            // Request timeout
125            .layer(TimeoutLayer::with_status_code(
126                axum::http::StatusCode::REQUEST_TIMEOUT,
127                config.security.request_timeout,
128            ))
129            // Request body size limit
130            .layer(DefaultBodyLimit::max(config.security.max_body_size))
131            // CORS configuration
132            .layer(cors_layer)
133            // Security headers
134            .layer(SetResponseHeaderLayer::if_not_present(
135                header::X_CONTENT_TYPE_OPTIONS,
136                HeaderValue::from_static("nosniff"),
137            ))
138            .layer(SetResponseHeaderLayer::if_not_present(
139                header::X_FRAME_OPTIONS,
140                HeaderValue::from_static("DENY"),
141            ))
142            .layer(SetResponseHeaderLayer::if_not_present(
143                header::X_XSS_PROTECTION,
144                HeaderValue::from_static("1; mode=block"),
145            )),
146    )
147}
148
/// Liveness handler: always answers with the plain-text body `OK`.
async fn health_check() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}