1mod query;
13mod response;
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::sync::Arc;
18
19use bytes::Bytes;
20use http_body_util::combinators::BoxBody;
21use hyper::body::Incoming;
22
23use crate::agent_card::StaticAgentCardHandler;
24use crate::dispatch::cors::CorsConfig;
25use crate::handler::{RequestHandler, SendMessageResult};
26use crate::streaming::build_sse_response;
27
28use query::{
29 contains_path_traversal, parse_list_tasks_query, parse_query_param_u32, strip_tenant_prefix,
30};
31use response::{
32 error_json_response, extract_headers, health_response, inject_field_if_missing,
33 json_ok_response, not_found_response, read_body_limited, server_error_to_response,
34};
35
36pub struct RestDispatcher {
41 handler: Arc<RequestHandler>,
42 card_handler: Option<StaticAgentCardHandler>,
43 cors: Option<CorsConfig>,
44 config: super::DispatchConfig,
45}
46
47impl RestDispatcher {
48 #[must_use]
50 pub fn new(handler: Arc<RequestHandler>) -> Self {
51 Self::with_config(handler, super::DispatchConfig::default())
52 }
53
54 #[must_use]
56 pub fn with_config(handler: Arc<RequestHandler>, config: super::DispatchConfig) -> Self {
57 let card_handler = handler
58 .agent_card
59 .as_ref()
60 .and_then(|card| StaticAgentCardHandler::new(card).ok());
61 Self {
62 handler,
63 card_handler,
64 cors: None,
65 config,
66 }
67 }
68
69 #[must_use]
74 pub fn with_cors(mut self, cors: CorsConfig) -> Self {
75 self.cors = Some(cors);
76 self
77 }
78
79 #[allow(clippy::too_many_lines)]
81 pub async fn dispatch(
82 &self,
83 req: hyper::Request<Incoming>,
84 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
85 let method = req.method().clone();
86 let path = req.uri().path().to_owned();
87 let query = req.uri().query().unwrap_or("").to_owned();
88 trace_info!(http_method = %method, %path, "dispatching REST request");
89
90 if method == "OPTIONS" {
92 if let Some(ref cors) = self.cors {
93 return cors.preflight_response();
94 }
95 return health_response();
96 }
97
98 if query.len() > self.config.max_query_string_length {
100 let mut resp = error_json_response(
101 414,
102 &format!(
103 "query string too long: {} bytes exceeds {} byte limit",
104 query.len(),
105 self.config.max_query_string_length
106 ),
107 );
108 if let Some(ref cors) = self.cors {
109 cors.apply_headers(&mut resp);
110 }
111 return resp;
112 }
113
114 if method == "GET" && (path == "/health" || path == "/ready") {
116 let mut resp = health_response();
117 if let Some(ref cors) = self.cors {
118 cors.apply_headers(&mut resp);
119 }
120 return resp;
121 }
122
123 if method == "POST" || method == "PUT" || method == "PATCH" {
125 if let Some(ct) = req.headers().get("content-type") {
126 let ct_str = ct.to_str().unwrap_or("");
127 if !ct_str.starts_with("application/json")
128 && !ct_str.starts_with(a2a_protocol_types::A2A_CONTENT_TYPE)
129 {
130 return error_json_response(
131 415,
132 &format!("unsupported Content-Type: {ct_str}; expected application/json or application/a2a+json"),
133 );
134 }
135 }
136 }
137
138 if let Some(version) = req.headers().get(a2a_protocol_types::A2A_VERSION_HEADER) {
141 if let Ok(v) = version.to_str() {
142 let v = v.trim();
143 if !v.is_empty() {
145 let major = v.split('.').next().and_then(|s| s.parse::<u32>().ok());
146 if major != Some(1) {
147 return error_json_response(
148 400,
149 &format!("unsupported A2A version: {v}; this server supports 1.x"),
150 );
151 }
152 }
153 }
154 }
155
156 if contains_path_traversal(&path) {
158 return error_json_response(400, "invalid path: path traversal not allowed");
159 }
160
161 if method == "GET" && path == "/.well-known/agent-card.json" {
163 return self
164 .card_handler
165 .as_ref()
166 .map_or_else(not_found_response, |h| {
167 h.handle(&req).map(http_body_util::BodyExt::boxed)
168 });
169 }
170
171 let (tenant, rest_path) = strip_tenant_prefix(&path);
173
174 let headers = extract_headers(req.headers());
176
177 let mut resp = self
178 .dispatch_rest(req, method.as_str(), rest_path, &query, tenant, &headers)
179 .await;
180 if let Some(ref cors) = self.cors {
181 cors.apply_headers(&mut resp);
182 }
183 resp
184 }
185
186 #[allow(clippy::too_many_lines)]
188 async fn dispatch_rest(
189 &self,
190 req: hyper::Request<Incoming>,
191 method: &str,
192 path: &str,
193 query: &str,
194 tenant: Option<&str>,
195 headers: &HashMap<String, String>,
196 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
197 match (method, path) {
200 ("POST", "/message:send" | "/message/send") => {
201 return self.handle_send(req, false, headers).await;
202 }
203 ("POST", "/message:stream" | "/message/stream") => {
204 return self.handle_send(req, true, headers).await;
205 }
206 _ => {}
207 }
208
209 if let Some(rest) = path.strip_prefix("/tasks/") {
211 if let Some((id, action)) = rest.split_once(':') {
212 if !id.is_empty() {
213 match (method, action) {
214 ("POST", "cancel") => {
215 return self.handle_cancel_task(id, tenant, headers).await;
216 }
217 ("POST" | "GET", "subscribe") => {
218 return self.handle_resubscribe(id, tenant, headers).await;
219 }
220 _ => {}
221 }
222 }
223 }
224 }
225
226 let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
227
228 match (method, segments.as_slice()) {
229 ("GET", ["tasks"]) => self.handle_list_tasks(query, tenant, headers).await,
231 ("GET", ["tasks", id]) => self.handle_get_task(id, query, tenant, headers).await,
232
233 ("POST", ["tasks", id, "cancel"]) => self.handle_cancel_task(id, tenant, headers).await,
235
236 ("POST", ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig"]) => {
238 self.handle_set_push_config(req, task_id, headers).await
239 }
240 (
241 "GET",
242 ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig", config_id],
243 ) => {
244 self.handle_get_push_config(task_id, config_id, tenant, headers)
245 .await
246 }
247 ("GET", ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig"]) => {
248 self.handle_list_push_configs(task_id, tenant, headers)
249 .await
250 }
251 (
252 "DELETE",
253 ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig", config_id],
254 )
255 | (
256 "POST",
257 ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig", config_id, "delete"],
258 ) => {
259 self.handle_delete_push_config(task_id, config_id, tenant, headers)
260 .await
261 }
262
263 ("GET", ["extendedAgentCard"]) => self.handle_extended_card(headers).await,
265
266 _ => not_found_response(),
267 }
268 }
269
270 async fn handle_send(
273 &self,
274 req: hyper::Request<Incoming>,
275 streaming: bool,
276 headers: &HashMap<String, String>,
277 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
278 let body_bytes = match read_body_limited(
279 req.into_body(),
280 self.config.max_request_body_size,
281 self.config.body_read_timeout,
282 )
283 .await
284 {
285 Ok(bytes) => bytes,
286 Err(msg) => return error_json_response(413, &msg),
287 };
288 let params: a2a_protocol_types::params::MessageSendParams =
289 match serde_json::from_slice(&body_bytes) {
290 Ok(p) => p,
291 Err(e) => return error_json_response(400, &e.to_string()),
292 };
293 match self
294 .handler
295 .on_send_message(params, streaming, Some(headers))
296 .await
297 {
298 Ok(SendMessageResult::Response(resp)) => json_ok_response(&resp),
299 Ok(SendMessageResult::Stream(reader)) => build_sse_response(
300 reader,
301 Some(self.config.sse_keep_alive_interval),
302 Some(self.config.sse_channel_capacity),
303 false, ),
305 Err(e) => server_error_to_response(&e),
306 }
307 }
308
309 async fn handle_get_task(
310 &self,
311 id: &str,
312 query: &str,
313 tenant: Option<&str>,
314 headers: &HashMap<String, String>,
315 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
316 let history_length = parse_query_param_u32(query, "historyLength");
317 let params = a2a_protocol_types::params::TaskQueryParams {
318 tenant: tenant.map(str::to_owned),
319 id: id.to_owned(),
320 history_length,
321 };
322 match self.handler.on_get_task(params, Some(headers)).await {
323 Ok(task) => json_ok_response(&task),
324 Err(e) => server_error_to_response(&e),
325 }
326 }
327
328 async fn handle_list_tasks(
329 &self,
330 query: &str,
331 tenant: Option<&str>,
332 headers: &HashMap<String, String>,
333 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
334 let params = parse_list_tasks_query(query, tenant);
335 match self.handler.on_list_tasks(params, Some(headers)).await {
336 Ok(result) => json_ok_response(&result),
337 Err(e) => server_error_to_response(&e),
338 }
339 }
340
341 async fn handle_cancel_task(
342 &self,
343 id: &str,
344 tenant: Option<&str>,
345 headers: &HashMap<String, String>,
346 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
347 let params = a2a_protocol_types::params::CancelTaskParams {
348 tenant: tenant.map(str::to_owned),
349 id: id.to_owned(),
350 metadata: None,
351 };
352 match self.handler.on_cancel_task(params, Some(headers)).await {
353 Ok(task) => json_ok_response(&task),
354 Err(e) => server_error_to_response(&e),
355 }
356 }
357
358 async fn handle_resubscribe(
359 &self,
360 id: &str,
361 tenant: Option<&str>,
362 headers: &HashMap<String, String>,
363 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
364 let params = a2a_protocol_types::params::TaskIdParams {
365 tenant: tenant.map(str::to_owned),
366 id: id.to_owned(),
367 };
368 match self.handler.on_resubscribe(params, Some(headers)).await {
369 Ok(reader) => build_sse_response(
370 reader,
371 Some(self.config.sse_keep_alive_interval),
372 Some(self.config.sse_channel_capacity),
373 false, ),
375 Err(e) => server_error_to_response(&e),
376 }
377 }
378
379 async fn handle_set_push_config(
380 &self,
381 req: hyper::Request<Incoming>,
382 task_id: &str,
383 headers: &HashMap<String, String>,
384 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
385 let body_bytes = match read_body_limited(
386 req.into_body(),
387 self.config.max_request_body_size,
388 self.config.body_read_timeout,
389 )
390 .await
391 {
392 Ok(bytes) => bytes,
393 Err(msg) => return error_json_response(413, &msg),
394 };
395 let body_value: serde_json::Value = match serde_json::from_slice(&body_bytes) {
399 Ok(v) => v,
400 Err(e) => return error_json_response(400, &e.to_string()),
401 };
402 let body_value = inject_field_if_missing(body_value, "taskId", task_id);
403 let config: a2a_protocol_types::push::TaskPushNotificationConfig =
404 match serde_json::from_value(body_value) {
405 Ok(c) => c,
406 Err(e) => return error_json_response(400, &e.to_string()),
407 };
408 match self.handler.on_set_push_config(config, Some(headers)).await {
409 Ok(result) => json_ok_response(&result),
410 Err(e) => server_error_to_response(&e),
411 }
412 }
413
414 async fn handle_get_push_config(
415 &self,
416 task_id: &str,
417 config_id: &str,
418 tenant: Option<&str>,
419 headers: &HashMap<String, String>,
420 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
421 let params = a2a_protocol_types::params::GetPushConfigParams {
422 tenant: tenant.map(str::to_owned),
423 task_id: task_id.to_owned(),
424 id: config_id.to_owned(),
425 };
426 match self.handler.on_get_push_config(params, Some(headers)).await {
427 Ok(config) => json_ok_response(&config),
428 Err(e) => server_error_to_response(&e),
429 }
430 }
431
432 async fn handle_list_push_configs(
433 &self,
434 task_id: &str,
435 tenant: Option<&str>,
436 headers: &HashMap<String, String>,
437 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
438 match self
439 .handler
440 .on_list_push_configs(task_id, tenant, Some(headers))
441 .await
442 {
443 Ok(configs) => {
444 let resp = a2a_protocol_types::responses::ListPushConfigsResponse {
445 configs,
446 next_page_token: None,
447 };
448 json_ok_response(&resp)
449 }
450 Err(e) => server_error_to_response(&e),
451 }
452 }
453
454 async fn handle_delete_push_config(
455 &self,
456 task_id: &str,
457 config_id: &str,
458 tenant: Option<&str>,
459 headers: &HashMap<String, String>,
460 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
461 let params = a2a_protocol_types::params::DeletePushConfigParams {
462 tenant: tenant.map(str::to_owned),
463 task_id: task_id.to_owned(),
464 id: config_id.to_owned(),
465 };
466 match self
467 .handler
468 .on_delete_push_config(params, Some(headers))
469 .await
470 {
471 Ok(()) => json_ok_response(&serde_json::json!({})),
472 Err(e) => server_error_to_response(&e),
473 }
474 }
475
476 async fn handle_extended_card(
477 &self,
478 headers: &HashMap<String, String>,
479 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
480 match self.handler.on_get_extended_agent_card(Some(headers)).await {
481 Ok(card) => json_ok_response(&card),
482 Err(e) => server_error_to_response(&e),
483 }
484 }
485}
486
487impl std::fmt::Debug for RestDispatcher {
488 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489 f.debug_struct("RestDispatcher").finish()
490 }
491}
492
493impl crate::serve::Dispatcher for RestDispatcher {
496 fn dispatch(
497 &self,
498 req: hyper::Request<Incoming>,
499 ) -> std::pin::Pin<
500 Box<dyn std::future::Future<Output = crate::serve::DispatchResponse> + Send + '_>,
501 > {
502 Box::pin(self.dispatch(req))
503 }
504}
505
#[cfg(test)]
mod tests {
    /// `RestDispatcher` must implement `Debug`.
    ///
    /// This is checked at compile time, because building a real dispatcher
    /// requires a full `RequestHandler`.
    #[test]
    fn rest_dispatcher_debug_format() {
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<super::RestDispatcher>();
    }

    /// `RestDispatcher` must be usable wherever the server expects a `Dispatcher`.
    #[test]
    fn rest_dispatcher_implements_dispatcher() {
        fn assert_dispatcher<T: crate::serve::Dispatcher>() {}
        assert_dispatcher::<super::RestDispatcher>();
    }

    #[test]
    fn dispatch_config_default_query_limit() {
        let config = super::super::DispatchConfig::default();
        assert_eq!(config.max_query_string_length, 4096);
    }
}