adk_server/rest/
mod.rs

1pub mod controllers;
2mod routes;
3
4pub use controllers::{
5    A2aController, AppsController, ArtifactsController, DebugController, RuntimeController,
6    SessionController,
7};
8
9use crate::{web_ui, ServerConfig};
10use axum::{
11    extract::DefaultBodyLimit,
12    http::{header, HeaderValue, Method},
13    routing::{get, post},
14    Router,
15};
16use tower::ServiceBuilder;
17use tower_http::{
18    cors::{AllowOrigin, CorsLayer},
19    set_header::SetResponseHeaderLayer,
20    timeout::TimeoutLayer,
21    trace::TraceLayer,
22};
23
24/// Build CORS layer based on security configuration
25fn build_cors_layer(config: &ServerConfig) -> CorsLayer {
26    let cors = CorsLayer::new()
27        .allow_methods([
28            Method::GET,
29            Method::POST,
30            Method::PUT,
31            Method::DELETE,
32            Method::OPTIONS,
33        ])
34        .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]);
35
36    if config.security.allowed_origins.is_empty() {
37        // Development mode: allow all origins (with warning logged at startup)
38        cors.allow_origin(AllowOrigin::any())
39    } else {
40        // Production mode: only allow specified origins
41        let origins: Vec<HeaderValue> = config
42            .security
43            .allowed_origins
44            .iter()
45            .filter_map(|o| o.parse().ok())
46            .collect();
47        cors.allow_origin(origins)
48    }
49}
50
51/// Create the server application with optional A2A support
52pub fn create_app(config: ServerConfig) -> Router {
53    create_app_with_a2a(config, None)
54}
55
56/// Create the server application with A2A support at the specified base URL
57pub fn create_app_with_a2a(config: ServerConfig, a2a_base_url: Option<&str>) -> Router {
58    let session_controller = SessionController::new(config.session_service.clone());
59    let runtime_controller = RuntimeController::new(config.clone());
60    let apps_controller = AppsController::new(config.clone());
61    let artifacts_controller = ArtifactsController::new(config.clone());
62    let debug_controller = DebugController::new(config.clone());
63
64    let api_router = Router::new()
65        .route("/health", get(health_check))
66        .route("/apps", get(controllers::apps::list_apps))
67        .route("/list-apps", get(controllers::apps::list_apps_compat))
68        .with_state(apps_controller)
69        .route("/sessions", post(controllers::session::create_session))
70        .route(
71            "/sessions/:app_name/:user_id/:session_id",
72            get(controllers::session::get_session).delete(controllers::session::delete_session),
73        )
74        // adk-go compatible routes
75        .route(
76            "/apps/:app_name/users/:user_id/sessions",
77            get(controllers::session::list_sessions)
78                .post(controllers::session::create_session_from_path),
79        )
80        .route(
81            "/apps/:app_name/users/:user_id/sessions/:session_id",
82            get(controllers::session::get_session_from_path)
83                .post(controllers::session::create_session_from_path)
84                .delete(controllers::session::delete_session_from_path),
85        )
86        .with_state(session_controller)
87        .route("/run/:app_name/:user_id/:session_id", post(controllers::runtime::run_sse))
88        .route("/run_sse", post(controllers::runtime::run_sse_compat))
89        .with_state(runtime_controller)
90        .route(
91            "/sessions/:app_name/:user_id/:session_id/artifacts",
92            get(controllers::artifacts::list_artifacts),
93        )
94        .route(
95            "/sessions/:app_name/:user_id/:session_id/artifacts/:artifact_name",
96            get(controllers::artifacts::get_artifact),
97        )
98        .with_state(artifacts_controller)
99        .route("/debug/trace/:event_id", get(controllers::debug::get_trace))
100        .route(
101            "/debug/graph/:app_name/:user_id/:session_id/:event_id",
102            get(controllers::debug::get_graph),
103        )
104        .with_state(debug_controller);
105
106    let ui_router = Router::new()
107        .route("/", get(web_ui::root_redirect))
108        .route("/ui/", get(web_ui::serve_ui_index))
109        .route("/ui/assets/config/runtime-config.json", get(web_ui::serve_runtime_config))
110        .with_state(config.clone())
111        .route("/ui/*path", get(web_ui::serve_ui_assets));
112
113    let mut app = Router::new().nest("/api", api_router).merge(ui_router);
114
115    // Add A2A routes if base URL is provided
116    if let Some(base_url) = a2a_base_url {
117        let a2a_controller = A2aController::new(config.clone(), base_url);
118        let a2a_router = Router::new()
119            .route("/.well-known/agent.json", get(controllers::a2a::get_agent_card))
120            .route("/a2a", post(controllers::a2a::handle_jsonrpc))
121            .route("/a2a/stream", post(controllers::a2a::handle_jsonrpc_stream))
122            .with_state(a2a_controller);
123        app = app.merge(a2a_router);
124    }
125
126    // Build security layers
127    let cors_layer = build_cors_layer(&config);
128
129    // Apply all middleware layers
130    app.layer(
131        ServiceBuilder::new()
132            // Tracing for observability
133            .layer(TraceLayer::new_for_http())
134            // Request timeout
135            .layer(TimeoutLayer::new(config.security.request_timeout))
136            // Request body size limit
137            .layer(DefaultBodyLimit::max(config.security.max_body_size))
138            // CORS configuration
139            .layer(cors_layer)
140            // Security headers
141            .layer(SetResponseHeaderLayer::if_not_present(
142                header::X_CONTENT_TYPE_OPTIONS,
143                HeaderValue::from_static("nosniff"),
144            ))
145            .layer(SetResponseHeaderLayer::if_not_present(
146                header::X_FRAME_OPTIONS,
147                HeaderValue::from_static("DENY"),
148            ))
149            .layer(SetResponseHeaderLayer::if_not_present(
150                header::X_XSS_PROTECTION,
151                HeaderValue::from_static("1; mode=block"),
152            )),
153    )
154}
155
/// Health-check handler: always responds with the plain-text body `OK`.
async fn health_check() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}