1mod query;
13mod response;
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::sync::Arc;
18
19use bytes::Bytes;
20use http_body_util::combinators::BoxBody;
21use hyper::body::Incoming;
22
23use crate::agent_card::StaticAgentCardHandler;
24use crate::dispatch::cors::CorsConfig;
25use crate::handler::{RequestHandler, SendMessageResult};
26use crate::streaming::build_sse_response;
27
28use query::{
29 contains_path_traversal, parse_list_tasks_query, parse_query_param_u32, strip_tenant_prefix,
30};
31use response::{
32 error_json_response, extract_headers, health_response, inject_field_if_missing,
33 json_ok_response, not_found_response, read_body_limited, server_error_to_response,
34};
35
36pub struct RestDispatcher {
41 handler: Arc<RequestHandler>,
42 card_handler: Option<StaticAgentCardHandler>,
43 cors: Option<CorsConfig>,
44 config: super::DispatchConfig,
45}
46
47impl RestDispatcher {
48 #[must_use]
50 pub fn new(handler: Arc<RequestHandler>) -> Self {
51 Self::with_config(handler, super::DispatchConfig::default())
52 }
53
54 #[must_use]
56 pub fn with_config(handler: Arc<RequestHandler>, config: super::DispatchConfig) -> Self {
57 let card_handler = handler
58 .agent_card
59 .as_ref()
60 .and_then(|card| StaticAgentCardHandler::new(card).ok());
61 Self {
62 handler,
63 card_handler,
64 cors: None,
65 config,
66 }
67 }
68
69 #[must_use]
74 pub fn with_cors(mut self, cors: CorsConfig) -> Self {
75 self.cors = Some(cors);
76 self
77 }
78
79 #[allow(clippy::too_many_lines)]
81 pub async fn dispatch(
82 &self,
83 req: hyper::Request<Incoming>,
84 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
85 let method = req.method().clone();
86 let path = req.uri().path().to_owned();
87 let query = req.uri().query().unwrap_or("").to_owned();
88 trace_info!(http_method = %method, %path, "dispatching REST request");
89
90 if method == "OPTIONS" {
92 if let Some(ref cors) = self.cors {
93 return cors.preflight_response();
94 }
95 return health_response();
96 }
97
98 if query.len() > self.config.max_query_string_length {
100 let mut resp = error_json_response(
101 414,
102 &format!(
103 "query string too long: {} bytes exceeds {} byte limit",
104 query.len(),
105 self.config.max_query_string_length
106 ),
107 );
108 if let Some(ref cors) = self.cors {
109 cors.apply_headers(&mut resp);
110 }
111 return resp;
112 }
113
114 if method == "GET" && (path == "/health" || path == "/ready") {
116 let mut resp = health_response();
117 if let Some(ref cors) = self.cors {
118 cors.apply_headers(&mut resp);
119 }
120 return resp;
121 }
122
123 if method == "POST" || method == "PUT" || method == "PATCH" {
125 if let Some(ct) = req.headers().get("content-type") {
126 let ct_str = ct.to_str().unwrap_or("");
127 if !ct_str.starts_with("application/json")
128 && !ct_str.starts_with(a2a_protocol_types::A2A_CONTENT_TYPE)
129 {
130 return error_json_response(
131 415,
132 &format!("unsupported Content-Type: {ct_str}; expected application/json or application/a2a+json"),
133 );
134 }
135 }
136 }
137
138 if contains_path_traversal(&path) {
140 return error_json_response(400, "invalid path: path traversal not allowed");
141 }
142
143 if method == "GET" && path == "/.well-known/agent.json" {
145 return self
146 .card_handler
147 .as_ref()
148 .map_or_else(not_found_response, |h| {
149 h.handle(&req).map(http_body_util::BodyExt::boxed)
150 });
151 }
152
153 let (tenant, rest_path) = strip_tenant_prefix(&path);
155
156 let headers = extract_headers(req.headers());
158
159 let mut resp = self
160 .dispatch_rest(req, method.as_str(), rest_path, &query, tenant, &headers)
161 .await;
162 if let Some(ref cors) = self.cors {
163 cors.apply_headers(&mut resp);
164 }
165 resp
166 }
167
168 #[allow(clippy::too_many_lines)]
170 async fn dispatch_rest(
171 &self,
172 req: hyper::Request<Incoming>,
173 method: &str,
174 path: &str,
175 query: &str,
176 tenant: Option<&str>,
177 headers: &HashMap<String, String>,
178 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
179 match (method, path) {
182 ("POST", "/message:send" | "/message/send") => {
183 return self.handle_send(req, false, headers).await;
184 }
185 ("POST", "/message:stream" | "/message/stream") => {
186 return self.handle_send(req, true, headers).await;
187 }
188 _ => {}
189 }
190
191 if let Some(rest) = path.strip_prefix("/tasks/") {
193 if let Some((id, action)) = rest.split_once(':') {
194 if !id.is_empty() {
195 match (method, action) {
196 ("POST", "cancel") => {
197 return self.handle_cancel_task(id, tenant, headers).await;
198 }
199 ("POST" | "GET", "subscribe") => {
200 return self.handle_resubscribe(id, tenant, headers).await;
201 }
202 _ => {}
203 }
204 }
205 }
206 }
207
208 let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
209
210 match (method, segments.as_slice()) {
211 ("GET", ["tasks"]) => self.handle_list_tasks(query, tenant, headers).await,
213 ("GET", ["tasks", id]) => self.handle_get_task(id, query, tenant, headers).await,
214
215 ("POST", ["tasks", id, "cancel"]) => self.handle_cancel_task(id, tenant, headers).await,
217
218 ("POST", ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig"]) => {
220 self.handle_set_push_config(req, task_id, headers).await
221 }
222 (
223 "GET",
224 ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig", config_id],
225 ) => {
226 self.handle_get_push_config(task_id, config_id, tenant, headers)
227 .await
228 }
229 ("GET", ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig"]) => {
230 self.handle_list_push_configs(task_id, tenant, headers)
231 .await
232 }
233 (
234 "DELETE",
235 ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig", config_id],
236 )
237 | (
238 "POST",
239 ["tasks", task_id, "pushNotificationConfigs" | "pushNotificationConfig", config_id, "delete"],
240 ) => {
241 self.handle_delete_push_config(task_id, config_id, tenant, headers)
242 .await
243 }
244
245 ("GET", ["extendedAgentCard"]) => self.handle_extended_card(headers).await,
247
248 _ => not_found_response(),
249 }
250 }
251
252 async fn handle_send(
255 &self,
256 req: hyper::Request<Incoming>,
257 streaming: bool,
258 headers: &HashMap<String, String>,
259 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
260 let body_bytes = match read_body_limited(
261 req.into_body(),
262 self.config.max_request_body_size,
263 self.config.body_read_timeout,
264 )
265 .await
266 {
267 Ok(bytes) => bytes,
268 Err(msg) => return error_json_response(413, &msg),
269 };
270 let params: a2a_protocol_types::params::MessageSendParams =
271 match serde_json::from_slice(&body_bytes) {
272 Ok(p) => p,
273 Err(e) => return error_json_response(400, &e.to_string()),
274 };
275 match self
276 .handler
277 .on_send_message(params, streaming, Some(headers))
278 .await
279 {
280 Ok(SendMessageResult::Response(resp)) => json_ok_response(&resp),
281 Ok(SendMessageResult::Stream(reader)) => build_sse_response(
282 reader,
283 Some(self.config.sse_keep_alive_interval),
284 Some(self.config.sse_channel_capacity),
285 ),
286 Err(e) => server_error_to_response(&e),
287 }
288 }
289
290 async fn handle_get_task(
291 &self,
292 id: &str,
293 query: &str,
294 tenant: Option<&str>,
295 headers: &HashMap<String, String>,
296 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
297 let history_length = parse_query_param_u32(query, "historyLength");
298 let params = a2a_protocol_types::params::TaskQueryParams {
299 tenant: tenant.map(str::to_owned),
300 id: id.to_owned(),
301 history_length,
302 };
303 match self.handler.on_get_task(params, Some(headers)).await {
304 Ok(task) => json_ok_response(&task),
305 Err(e) => server_error_to_response(&e),
306 }
307 }
308
309 async fn handle_list_tasks(
310 &self,
311 query: &str,
312 tenant: Option<&str>,
313 headers: &HashMap<String, String>,
314 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
315 let params = parse_list_tasks_query(query, tenant);
316 match self.handler.on_list_tasks(params, Some(headers)).await {
317 Ok(result) => json_ok_response(&result),
318 Err(e) => server_error_to_response(&e),
319 }
320 }
321
322 async fn handle_cancel_task(
323 &self,
324 id: &str,
325 tenant: Option<&str>,
326 headers: &HashMap<String, String>,
327 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
328 let params = a2a_protocol_types::params::CancelTaskParams {
329 tenant: tenant.map(str::to_owned),
330 id: id.to_owned(),
331 metadata: None,
332 };
333 match self.handler.on_cancel_task(params, Some(headers)).await {
334 Ok(task) => json_ok_response(&task),
335 Err(e) => server_error_to_response(&e),
336 }
337 }
338
339 async fn handle_resubscribe(
340 &self,
341 id: &str,
342 tenant: Option<&str>,
343 headers: &HashMap<String, String>,
344 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
345 let params = a2a_protocol_types::params::TaskIdParams {
346 tenant: tenant.map(str::to_owned),
347 id: id.to_owned(),
348 };
349 match self.handler.on_resubscribe(params, Some(headers)).await {
350 Ok(reader) => build_sse_response(
351 reader,
352 Some(self.config.sse_keep_alive_interval),
353 Some(self.config.sse_channel_capacity),
354 ),
355 Err(e) => server_error_to_response(&e),
356 }
357 }
358
359 async fn handle_set_push_config(
360 &self,
361 req: hyper::Request<Incoming>,
362 task_id: &str,
363 headers: &HashMap<String, String>,
364 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
365 let body_bytes = match read_body_limited(
366 req.into_body(),
367 self.config.max_request_body_size,
368 self.config.body_read_timeout,
369 )
370 .await
371 {
372 Ok(bytes) => bytes,
373 Err(msg) => return error_json_response(413, &msg),
374 };
375 let body_value: serde_json::Value = match serde_json::from_slice(&body_bytes) {
379 Ok(v) => v,
380 Err(e) => return error_json_response(400, &e.to_string()),
381 };
382 let body_value = inject_field_if_missing(body_value, "taskId", task_id);
383 let config: a2a_protocol_types::push::TaskPushNotificationConfig =
384 match serde_json::from_value(body_value) {
385 Ok(c) => c,
386 Err(e) => return error_json_response(400, &e.to_string()),
387 };
388 match self.handler.on_set_push_config(config, Some(headers)).await {
389 Ok(result) => json_ok_response(&result),
390 Err(e) => server_error_to_response(&e),
391 }
392 }
393
394 async fn handle_get_push_config(
395 &self,
396 task_id: &str,
397 config_id: &str,
398 tenant: Option<&str>,
399 headers: &HashMap<String, String>,
400 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
401 let params = a2a_protocol_types::params::GetPushConfigParams {
402 tenant: tenant.map(str::to_owned),
403 task_id: task_id.to_owned(),
404 id: config_id.to_owned(),
405 };
406 match self.handler.on_get_push_config(params, Some(headers)).await {
407 Ok(config) => json_ok_response(&config),
408 Err(e) => server_error_to_response(&e),
409 }
410 }
411
412 async fn handle_list_push_configs(
413 &self,
414 task_id: &str,
415 tenant: Option<&str>,
416 headers: &HashMap<String, String>,
417 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
418 match self
419 .handler
420 .on_list_push_configs(task_id, tenant, Some(headers))
421 .await
422 {
423 Ok(configs) => {
424 let resp = a2a_protocol_types::responses::ListPushConfigsResponse {
425 configs,
426 next_page_token: None,
427 };
428 json_ok_response(&resp)
429 }
430 Err(e) => server_error_to_response(&e),
431 }
432 }
433
434 async fn handle_delete_push_config(
435 &self,
436 task_id: &str,
437 config_id: &str,
438 tenant: Option<&str>,
439 headers: &HashMap<String, String>,
440 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
441 let params = a2a_protocol_types::params::DeletePushConfigParams {
442 tenant: tenant.map(str::to_owned),
443 task_id: task_id.to_owned(),
444 id: config_id.to_owned(),
445 };
446 match self
447 .handler
448 .on_delete_push_config(params, Some(headers))
449 .await
450 {
451 Ok(()) => json_ok_response(&serde_json::json!({})),
452 Err(e) => server_error_to_response(&e),
453 }
454 }
455
456 async fn handle_extended_card(
457 &self,
458 headers: &HashMap<String, String>,
459 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
460 match self.handler.on_get_extended_agent_card(Some(headers)).await {
461 Ok(card) => json_ok_response(&card),
462 Err(e) => server_error_to_response(&e),
463 }
464 }
465}
466
467impl std::fmt::Debug for RestDispatcher {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 f.debug_struct("RestDispatcher").finish()
470 }
471}
472
473impl crate::serve::Dispatcher for RestDispatcher {
476 fn dispatch(
477 &self,
478 req: hyper::Request<Incoming>,
479 ) -> std::pin::Pin<
480 Box<dyn std::future::Future<Output = crate::serve::DispatchResponse> + Send + '_>,
481 > {
482 Box::pin(self.dispatch(req))
483 }
484}
485
#[cfg(test)]
mod tests {
    /// Compile-time check that `RestDispatcher` implements `Debug`.
    ///
    /// Building a real dispatcher requires a full `RequestHandler`, so this asserts the
    /// trait bound rather than comparing against a string literal (the previous version
    /// never referenced `RestDispatcher` at all).
    #[test]
    fn rest_dispatcher_debug_format() {
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<super::RestDispatcher>();
    }

    /// Compile-time check that `RestDispatcher` can be used as a `crate::serve::Dispatcher`.
    #[test]
    fn rest_dispatcher_implements_dispatcher() {
        fn assert_dispatcher<T: crate::serve::Dispatcher>() {}
        assert_dispatcher::<super::RestDispatcher>();
    }

    #[test]
    fn dispatch_config_default_query_limit() {
        let config = super::super::DispatchConfig::default();
        assert_eq!(config.max_query_string_length, 4096);
    }
}