Skip to main content

a2a_protocol_server/dispatch/
rest.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
4//! REST dispatcher.
5//!
6//! [`RestDispatcher`] routes HTTP requests by method and path to the
7//! appropriate [`RequestHandler`] method, following the REST transport
8//! convention defined in the A2A protocol.
9
10use std::convert::Infallible;
11use std::sync::Arc;
12
13use bytes::Bytes;
14use http_body_util::combinators::BoxBody;
15use http_body_util::{BodyExt, Full};
16use hyper::body::Incoming;
17
18use crate::agent_card::StaticAgentCardHandler;
19use crate::dispatch::cors::CorsConfig;
20use crate::error::ServerError;
21use crate::handler::{RequestHandler, SendMessageResult};
22use crate::streaming::build_sse_response;
23
24/// REST HTTP request dispatcher.
25///
26/// Routes requests by HTTP method and path to the underlying [`RequestHandler`].
27/// Optionally applies CORS headers to all responses.
28pub struct RestDispatcher {
29    handler: Arc<RequestHandler>,
30    card_handler: Option<StaticAgentCardHandler>,
31    cors: Option<CorsConfig>,
32}
33
34impl RestDispatcher {
35    /// Creates a new REST dispatcher.
36    #[must_use]
37    pub fn new(handler: Arc<RequestHandler>) -> Self {
38        let card_handler = handler
39            .agent_card
40            .as_ref()
41            .and_then(|card| StaticAgentCardHandler::new(card).ok());
42        Self {
43            handler,
44            card_handler,
45            cors: None,
46        }
47    }
48
49    /// Sets CORS configuration for this dispatcher.
50    ///
51    /// When set, all responses will include CORS headers, and `OPTIONS` preflight
52    /// requests will be handled automatically.
53    #[must_use]
54    pub fn with_cors(mut self, cors: CorsConfig) -> Self {
55        self.cors = Some(cors);
56        self
57    }
58
59    /// Dispatches an HTTP request to the appropriate handler method.
60    #[allow(clippy::too_many_lines)]
61    pub async fn dispatch(
62        &self,
63        req: hyper::Request<Incoming>,
64    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
65        let method = req.method().clone();
66        let path = req.uri().path().to_owned();
67        let query = req.uri().query().unwrap_or("").to_owned();
68        trace_info!(http_method = %method, %path, "dispatching REST request");
69
70        // Handle CORS preflight requests.
71        if method == "OPTIONS" {
72            if let Some(ref cors) = self.cors {
73                return cors.preflight_response();
74            }
75            return health_response();
76        }
77
78        // Reject oversized query strings (DoS protection).
79        if query.len() > MAX_QUERY_STRING_LENGTH {
80            let mut resp = error_json_response(
81                414,
82                &format!(
83                    "query string too long: {} bytes exceeds {} byte limit",
84                    query.len(),
85                    MAX_QUERY_STRING_LENGTH
86                ),
87            );
88            if let Some(ref cors) = self.cors {
89                cors.apply_headers(&mut resp);
90            }
91            return resp;
92        }
93
94        // Health check endpoint.
95        if method == "GET" && (path == "/health" || path == "/ready") {
96            let mut resp = health_response();
97            if let Some(ref cors) = self.cors {
98                cors.apply_headers(&mut resp);
99            }
100            return resp;
101        }
102
103        // Validate Content-Type for POST/PUT/PATCH requests.
104        if method == "POST" || method == "PUT" || method == "PATCH" {
105            if let Some(ct) = req.headers().get("content-type") {
106                let ct_str = ct.to_str().unwrap_or("");
107                if !ct_str.starts_with("application/json")
108                    && !ct_str.starts_with(a2a_protocol_types::A2A_CONTENT_TYPE)
109                {
110                    return error_json_response(
111                        415,
112                        &format!("unsupported Content-Type: {ct_str}; expected application/json or application/a2a+json"),
113                    );
114                }
115            }
116        }
117
118        // Reject path traversal attempts (check both raw and percent-decoded forms).
119        if contains_path_traversal(&path) {
120            return error_json_response(400, "invalid path: path traversal not allowed");
121        }
122
123        // Agent card is always at the well-known path (no tenant prefix).
124        if method == "GET" && path == "/.well-known/agent.json" {
125            return self
126                .card_handler
127                .as_ref()
128                .map_or_else(not_found_response, |h| {
129                    h.handle(&req).map(http_body_util::BodyExt::boxed)
130                });
131        }
132
133        // Strip optional /tenants/{tenant}/ prefix.
134        let (tenant, rest_path) = strip_tenant_prefix(&path);
135
136        let mut resp = self
137            .dispatch_rest(req, method.as_str(), rest_path, &query, tenant)
138            .await;
139        if let Some(ref cors) = self.cors {
140            cors.apply_headers(&mut resp);
141        }
142        resp
143    }
144
145    /// Dispatch on the tenant-stripped path.
146    #[allow(clippy::too_many_lines)]
147    async fn dispatch_rest(
148        &self,
149        req: hyper::Request<Incoming>,
150        method: &str,
151        path: &str,
152        query: &str,
153        tenant: Option<&str>,
154    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
155        // Colon-suffixed routes: /message:send, /message:stream.
156        match (method, path) {
157            ("POST", "/message:send") => return self.handle_send(req, false).await,
158            ("POST", "/message:stream") => return self.handle_send(req, true).await,
159            _ => {}
160        }
161
162        // Colon-action routes on tasks: /tasks/{id}:cancel, /tasks/{id}:subscribe.
163        if let Some(rest) = path.strip_prefix("/tasks/") {
164            if let Some((id, action)) = rest.split_once(':') {
165                if !id.is_empty() {
166                    match (method, action) {
167                        ("POST", "cancel") => return self.handle_cancel_task(id).await,
168                        ("POST" | "GET", "subscribe") => {
169                            return self.handle_resubscribe(id).await;
170                        }
171                        _ => {}
172                    }
173                }
174            }
175        }
176
177        let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
178
179        match (method, segments.as_slice()) {
180            // Tasks.
181            ("GET", ["tasks"]) => self.handle_list_tasks(query, tenant).await,
182            ("GET", ["tasks", id]) => self.handle_get_task(id, query).await,
183
184            // Push notification configs.
185            ("POST", ["tasks", task_id, "pushNotificationConfigs"]) => {
186                self.handle_set_push_config(req, task_id).await
187            }
188            ("GET", ["tasks", task_id, "pushNotificationConfigs", config_id]) => {
189                self.handle_get_push_config(task_id, config_id).await
190            }
191            ("GET", ["tasks", task_id, "pushNotificationConfigs"]) => {
192                self.handle_list_push_configs(task_id).await
193            }
194            ("DELETE", ["tasks", task_id, "pushNotificationConfigs", config_id]) => {
195                self.handle_delete_push_config(task_id, config_id).await
196            }
197
198            // Extended card.
199            ("GET", ["extendedAgentCard"]) => self.handle_extended_card().await,
200
201            _ => not_found_response(),
202        }
203    }
204
205    // ── Route handlers ───────────────────────────────────────────────────
206
207    async fn handle_send(
208        &self,
209        req: hyper::Request<Incoming>,
210        streaming: bool,
211    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
212        let body_bytes = match read_body_limited(req.into_body(), MAX_REQUEST_BODY_SIZE).await {
213            Ok(bytes) => bytes,
214            Err(msg) => return error_json_response(413, &msg),
215        };
216        let params: a2a_protocol_types::params::MessageSendParams =
217            match serde_json::from_slice(&body_bytes) {
218                Ok(p) => p,
219                Err(e) => return error_json_response(400, &e.to_string()),
220            };
221        match self.handler.on_send_message(params, streaming).await {
222            Ok(SendMessageResult::Response(resp)) => json_ok_response(&resp),
223            Ok(SendMessageResult::Stream(reader)) => build_sse_response(reader, None),
224            Err(e) => server_error_to_response(&e),
225        }
226    }
227
228    async fn handle_get_task(
229        &self,
230        id: &str,
231        query: &str,
232    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
233        let history_length = parse_query_param_u32(query, "historyLength");
234        let params = a2a_protocol_types::params::TaskQueryParams {
235            tenant: None,
236            id: id.to_owned(),
237            history_length,
238        };
239        match self.handler.on_get_task(params).await {
240            Ok(task) => json_ok_response(&task),
241            Err(e) => server_error_to_response(&e),
242        }
243    }
244
245    async fn handle_list_tasks(
246        &self,
247        query: &str,
248        tenant: Option<&str>,
249    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
250        let params = parse_list_tasks_query(query, tenant);
251        match self.handler.on_list_tasks(params).await {
252            Ok(result) => json_ok_response(&result),
253            Err(e) => server_error_to_response(&e),
254        }
255    }
256
257    async fn handle_cancel_task(&self, id: &str) -> hyper::Response<BoxBody<Bytes, Infallible>> {
258        let params = a2a_protocol_types::params::CancelTaskParams {
259            tenant: None,
260            id: id.to_owned(),
261            metadata: None,
262        };
263        match self.handler.on_cancel_task(params).await {
264            Ok(task) => json_ok_response(&task),
265            Err(e) => server_error_to_response(&e),
266        }
267    }
268
269    async fn handle_resubscribe(&self, id: &str) -> hyper::Response<BoxBody<Bytes, Infallible>> {
270        let params = a2a_protocol_types::params::TaskIdParams {
271            tenant: None,
272            id: id.to_owned(),
273        };
274        match self.handler.on_resubscribe(params).await {
275            Ok(reader) => build_sse_response(reader, None),
276            Err(e) => server_error_to_response(&e),
277        }
278    }
279
280    async fn handle_set_push_config(
281        &self,
282        req: hyper::Request<Incoming>,
283        _task_id: &str,
284    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
285        let body_bytes = match read_body_limited(req.into_body(), MAX_REQUEST_BODY_SIZE).await {
286            Ok(bytes) => bytes,
287            Err(msg) => return error_json_response(413, &msg),
288        };
289        let config: a2a_protocol_types::push::TaskPushNotificationConfig =
290            match serde_json::from_slice(&body_bytes) {
291                Ok(c) => c,
292                Err(e) => return error_json_response(400, &e.to_string()),
293            };
294        match self.handler.on_set_push_config(config).await {
295            Ok(result) => json_ok_response(&result),
296            Err(e) => server_error_to_response(&e),
297        }
298    }
299
300    async fn handle_get_push_config(
301        &self,
302        task_id: &str,
303        config_id: &str,
304    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
305        let params = a2a_protocol_types::params::GetPushConfigParams {
306            tenant: None,
307            task_id: task_id.to_owned(),
308            id: config_id.to_owned(),
309        };
310        match self.handler.on_get_push_config(params).await {
311            Ok(config) => json_ok_response(&config),
312            Err(e) => server_error_to_response(&e),
313        }
314    }
315
316    async fn handle_list_push_configs(
317        &self,
318        task_id: &str,
319    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
320        match self.handler.on_list_push_configs(task_id).await {
321            Ok(configs) => json_ok_response(&configs),
322            Err(e) => server_error_to_response(&e),
323        }
324    }
325
326    async fn handle_delete_push_config(
327        &self,
328        task_id: &str,
329        config_id: &str,
330    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
331        let params = a2a_protocol_types::params::DeletePushConfigParams {
332            tenant: None,
333            task_id: task_id.to_owned(),
334            id: config_id.to_owned(),
335        };
336        match self.handler.on_delete_push_config(params).await {
337            Ok(()) => json_ok_response(&serde_json::json!({})),
338            Err(e) => server_error_to_response(&e),
339        }
340    }
341
342    async fn handle_extended_card(&self) -> hyper::Response<BoxBody<Bytes, Infallible>> {
343        match self.handler.on_get_extended_agent_card().await {
344            Ok(card) => json_ok_response(&card),
345            Err(e) => server_error_to_response(&e),
346        }
347    }
348}
349
350impl std::fmt::Debug for RestDispatcher {
351    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352        f.debug_struct("RestDispatcher").finish()
353    }
354}
355
356// ── Response helpers ─────────────────────────────────────────────────────────
357
358fn json_ok_response<T: serde::Serialize>(value: &T) -> hyper::Response<BoxBody<Bytes, Infallible>> {
359    match serde_json::to_vec(value) {
360        Ok(body) => build_json_response(200, body),
361        Err(_err) => {
362            trace_error!(error = %_err, "REST response serialization failed");
363            internal_error_response()
364        }
365    }
366}
367
368fn error_json_response(status: u16, message: &str) -> hyper::Response<BoxBody<Bytes, Infallible>> {
369    let body = serde_json::json!({ "error": message });
370    serde_json::to_vec(&body).map_or_else(
371        |_| internal_error_response(),
372        |bytes| build_json_response(status, bytes),
373    )
374}
375
376/// Fallback when serialization itself fails.
377fn internal_error_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
378    let body = br#"{"error":"internal serialization error"}"#;
379    build_json_response(500, body.to_vec())
380}
381
382fn not_found_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
383    error_json_response(404, "not found")
384}
385
386fn server_error_to_response(err: &ServerError) -> hyper::Response<BoxBody<Bytes, Infallible>> {
387    let status = match err {
388        ServerError::TaskNotFound(_) | ServerError::MethodNotFound(_) => 404,
389        ServerError::TaskNotCancelable(_) => 409,
390        ServerError::InvalidParams(_)
391        | ServerError::Serialization(_)
392        | ServerError::PushNotSupported => 400,
393        _ => 500,
394    };
395    let a2a_err = err.to_a2a_error();
396    serde_json::to_vec(&a2a_err).map_or_else(
397        |_| internal_error_response(),
398        |body| build_json_response(status, body),
399    )
400}
401
402// ── Query parsing helpers ───────────────────────────────────────────────────
403
/// Splits a leading `/tenants/{tenant}` off `path`.
///
/// Returns `(Some(tenant), remainder)`, where the remainder keeps its leading
/// `/`. When `path` does not start with `/tenants/` or has no `/` after the
/// tenant segment, returns `(None, path)` unchanged.
fn strip_tenant_prefix(path: &str) -> (Option<&str>, &str) {
    let split = path
        .strip_prefix("/tenants/")
        .and_then(|rest| rest.find('/').map(|idx| rest.split_at(idx)));
    match split {
        Some((tenant, remaining)) => (Some(tenant), remaining),
        None => (None, path),
    }
}
416
417/// Parses a single query parameter value as `u32`.
418fn parse_query_param_u32(query: &str, key: &str) -> Option<u32> {
419    parse_query_param(query, key).and_then(|v| v.parse::<u32>().ok())
420}
421
422/// Parses a single query parameter value as a string, with percent-decoding.
423fn parse_query_param(query: &str, key: &str) -> Option<String> {
424    query.split('&').find_map(|pair| {
425        let (k, v) = pair.split_once('=')?;
426        if k == key {
427            Some(percent_decode(v))
428        } else {
429            None
430        }
431    })
432}
433
434/// Decodes percent-encoded characters in a query parameter value.
435///
436/// Handles `%XX` hex sequences and `+` as space (application/x-www-form-urlencoded).
437fn percent_decode(input: &str) -> String {
438    let mut output = String::with_capacity(input.len());
439    let mut bytes = input.as_bytes().iter();
440    while let Some(&b) = bytes.next() {
441        match b {
442            b'%' => {
443                let hi = bytes.next().copied();
444                let lo = bytes.next().copied();
445                if let (Some(h), Some(l)) = (hi, lo) {
446                    if let (Some(h), Some(l)) = (hex_val(h), hex_val(l)) {
447                        output.push(char::from(h << 4 | l));
448                        continue;
449                    }
450                }
451                // Invalid percent sequence — pass through as-is.
452                output.push('%');
453            }
454            b'+' => output.push(' '),
455            _ => output.push(char::from(b)),
456        }
457    }
458    output
459}
460
461/// Checks if a path contains traversal sequences (`..`) in either raw or
462/// percent-encoded form (`%2E%2E`, `%2e%2e`).
463fn contains_path_traversal(path: &str) -> bool {
464    if path.contains("..") {
465        return true;
466    }
467    // Also check percent-encoded variants.
468    let decoded = percent_decode(path);
469    decoded.contains("..")
470}
471
/// Converts an ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`;
/// any other byte yields `None`.
const fn hex_val(b: u8) -> Option<u8> {
    let value = match b {
        b'0'..=b'9' => b - b'0',
        // Setting bit 0x20 folds ASCII `A-F` onto `a-f`.
        b'a'..=b'f' | b'A'..=b'F' => (b | 0x20) - b'a' + 10,
        _ => return None,
    };
    Some(value)
}
481
482/// Parses a single query parameter value as `bool`.
483fn parse_query_param_bool(query: &str, key: &str) -> Option<bool> {
484    parse_query_param(query, key).map(|v| v == "true" || v == "1")
485}
486
/// Largest accepted raw query string, in bytes; longer ones are rejected
/// before any parsing (denial-of-service guard).
const MAX_QUERY_STRING_LENGTH: usize = 4 * 1024;

/// Largest accepted request body, in bytes (4 MiB).
const MAX_REQUEST_BODY_SIZE: usize = 4 << 20;

/// Longest time allowed for reading one complete request body (slow-loris guard).
const BODY_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
495
496/// Reads a request body with a size limit and timeout.
497async fn read_body_limited(body: Incoming, max_size: usize) -> Result<Bytes, String> {
498    let size_hint = <Incoming as hyper::body::Body>::size_hint(&body);
499    if let Some(upper) = size_hint.upper() {
500        if upper > max_size as u64 {
501            return Err(format!(
502                "request body too large: {upper} bytes exceeds {max_size} byte limit"
503            ));
504        }
505    }
506    let collected = tokio::time::timeout(BODY_READ_TIMEOUT, body.collect())
507        .await
508        .map_err(|_| "request body read timed out".to_owned())?
509        .map_err(|e| e.to_string())?;
510    let bytes = collected.to_bytes();
511    if bytes.len() > max_size {
512        return Err(format!(
513            "request body too large: {} bytes exceeds {max_size} byte limit",
514            bytes.len()
515        ));
516    }
517    Ok(bytes)
518}
519
520/// Returns a health check response.
521fn health_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
522    let body = br#"{"status":"ok"}"#;
523    build_json_response(200, body.to_vec())
524}
525
526/// Builds a JSON HTTP response with the given status and body.
527fn build_json_response(status: u16, body: Vec<u8>) -> hyper::Response<BoxBody<Bytes, Infallible>> {
528    hyper::Response::builder()
529        .status(status)
530        .header("content-type", a2a_protocol_types::A2A_CONTENT_TYPE)
531        .header(
532            a2a_protocol_types::A2A_VERSION_HEADER,
533            a2a_protocol_types::A2A_VERSION,
534        )
535        .body(Full::new(Bytes::from(body)).boxed())
536        .unwrap_or_else(|_| {
537            // Fallback: plain 500 response if builder fails (should never happen
538            // with valid static header names).
539            hyper::Response::new(
540                Full::new(Bytes::from_static(br#"{"error":"response build error"}"#)).boxed(),
541            )
542        })
543}
544
545/// Parses `ListTasksParams` from URL query parameters.
546fn parse_list_tasks_query(
547    query: &str,
548    tenant: Option<&str>,
549) -> a2a_protocol_types::params::ListTasksParams {
550    let status = parse_query_param(query, "status")
551        .and_then(|s| serde_json::from_value(serde_json::Value::String(s)).ok());
552    a2a_protocol_types::params::ListTasksParams {
553        tenant: tenant.map(str::to_owned),
554        context_id: parse_query_param(query, "contextId"),
555        status,
556        page_size: parse_query_param_u32(query, "pageSize"),
557        page_token: parse_query_param(query, "pageToken"),
558        status_timestamp_after: parse_query_param(query, "statusTimestampAfter"),
559        include_artifacts: parse_query_param_bool(query, "includeArtifacts"),
560        history_length: parse_query_param_u32(query, "historyLength"),
561    }
562}