1use std::convert::Infallible;
11use std::sync::Arc;
12
13use bytes::Bytes;
14use http_body_util::combinators::BoxBody;
15use http_body_util::{BodyExt, Full};
16use hyper::body::Incoming;
17
18use crate::agent_card::StaticAgentCardHandler;
19use crate::dispatch::cors::CorsConfig;
20use crate::error::ServerError;
21use crate::handler::{RequestHandler, SendMessageResult};
22use crate::streaming::build_sse_response;
23
/// HTTP front end that routes REST-style requests to a shared [`RequestHandler`].
pub struct RestDispatcher {
    /// Handler that every REST route ultimately delegates to.
    handler: Arc<RequestHandler>,
    /// Serves the agent card route; `None` when the handler has no agent card
    /// or the static card handler could not be built.
    card_handler: Option<StaticAgentCardHandler>,
    /// CORS settings; `None` means no CORS headers are added to responses.
    cors: Option<CorsConfig>,
}
33
34impl RestDispatcher {
35 #[must_use]
37 pub fn new(handler: Arc<RequestHandler>) -> Self {
38 let card_handler = handler
39 .agent_card
40 .as_ref()
41 .and_then(|card| StaticAgentCardHandler::new(card).ok());
42 Self {
43 handler,
44 card_handler,
45 cors: None,
46 }
47 }
48
49 #[must_use]
54 pub fn with_cors(mut self, cors: CorsConfig) -> Self {
55 self.cors = Some(cors);
56 self
57 }
58
59 #[allow(clippy::too_many_lines)]
61 pub async fn dispatch(
62 &self,
63 req: hyper::Request<Incoming>,
64 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
65 let method = req.method().clone();
66 let path = req.uri().path().to_owned();
67 let query = req.uri().query().unwrap_or("").to_owned();
68 trace_info!(http_method = %method, %path, "dispatching REST request");
69
70 if method == "OPTIONS" {
72 if let Some(ref cors) = self.cors {
73 return cors.preflight_response();
74 }
75 return health_response();
76 }
77
78 if query.len() > MAX_QUERY_STRING_LENGTH {
80 let mut resp = error_json_response(
81 414,
82 &format!(
83 "query string too long: {} bytes exceeds {} byte limit",
84 query.len(),
85 MAX_QUERY_STRING_LENGTH
86 ),
87 );
88 if let Some(ref cors) = self.cors {
89 cors.apply_headers(&mut resp);
90 }
91 return resp;
92 }
93
94 if method == "GET" && (path == "/health" || path == "/ready") {
96 let mut resp = health_response();
97 if let Some(ref cors) = self.cors {
98 cors.apply_headers(&mut resp);
99 }
100 return resp;
101 }
102
103 if method == "POST" || method == "PUT" || method == "PATCH" {
105 if let Some(ct) = req.headers().get("content-type") {
106 let ct_str = ct.to_str().unwrap_or("");
107 if !ct_str.starts_with("application/json")
108 && !ct_str.starts_with(a2a_protocol_types::A2A_CONTENT_TYPE)
109 {
110 return error_json_response(
111 415,
112 &format!("unsupported Content-Type: {ct_str}; expected application/json or application/a2a+json"),
113 );
114 }
115 }
116 }
117
118 if contains_path_traversal(&path) {
120 return error_json_response(400, "invalid path: path traversal not allowed");
121 }
122
123 if method == "GET" && path == "/.well-known/agent.json" {
125 return self
126 .card_handler
127 .as_ref()
128 .map_or_else(not_found_response, |h| {
129 h.handle(&req).map(http_body_util::BodyExt::boxed)
130 });
131 }
132
133 let (tenant, rest_path) = strip_tenant_prefix(&path);
135
136 let mut resp = self
137 .dispatch_rest(req, method.as_str(), rest_path, &query, tenant)
138 .await;
139 if let Some(ref cors) = self.cors {
140 cors.apply_headers(&mut resp);
141 }
142 resp
143 }
144
145 #[allow(clippy::too_many_lines)]
147 async fn dispatch_rest(
148 &self,
149 req: hyper::Request<Incoming>,
150 method: &str,
151 path: &str,
152 query: &str,
153 tenant: Option<&str>,
154 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
155 match (method, path) {
157 ("POST", "/message:send") => return self.handle_send(req, false).await,
158 ("POST", "/message:stream") => return self.handle_send(req, true).await,
159 _ => {}
160 }
161
162 if let Some(rest) = path.strip_prefix("/tasks/") {
164 if let Some((id, action)) = rest.split_once(':') {
165 if !id.is_empty() {
166 match (method, action) {
167 ("POST", "cancel") => return self.handle_cancel_task(id).await,
168 ("POST" | "GET", "subscribe") => {
169 return self.handle_resubscribe(id).await;
170 }
171 _ => {}
172 }
173 }
174 }
175 }
176
177 let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
178
179 match (method, segments.as_slice()) {
180 ("GET", ["tasks"]) => self.handle_list_tasks(query, tenant).await,
182 ("GET", ["tasks", id]) => self.handle_get_task(id, query).await,
183
184 ("POST", ["tasks", task_id, "pushNotificationConfigs"]) => {
186 self.handle_set_push_config(req, task_id).await
187 }
188 ("GET", ["tasks", task_id, "pushNotificationConfigs", config_id]) => {
189 self.handle_get_push_config(task_id, config_id).await
190 }
191 ("GET", ["tasks", task_id, "pushNotificationConfigs"]) => {
192 self.handle_list_push_configs(task_id).await
193 }
194 ("DELETE", ["tasks", task_id, "pushNotificationConfigs", config_id]) => {
195 self.handle_delete_push_config(task_id, config_id).await
196 }
197
198 ("GET", ["extendedAgentCard"]) => self.handle_extended_card().await,
200
201 _ => not_found_response(),
202 }
203 }
204
205 async fn handle_send(
208 &self,
209 req: hyper::Request<Incoming>,
210 streaming: bool,
211 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
212 let body_bytes = match read_body_limited(req.into_body(), MAX_REQUEST_BODY_SIZE).await {
213 Ok(bytes) => bytes,
214 Err(msg) => return error_json_response(413, &msg),
215 };
216 let params: a2a_protocol_types::params::MessageSendParams =
217 match serde_json::from_slice(&body_bytes) {
218 Ok(p) => p,
219 Err(e) => return error_json_response(400, &e.to_string()),
220 };
221 match self.handler.on_send_message(params, streaming).await {
222 Ok(SendMessageResult::Response(resp)) => json_ok_response(&resp),
223 Ok(SendMessageResult::Stream(reader)) => build_sse_response(reader, None),
224 Err(e) => server_error_to_response(&e),
225 }
226 }
227
228 async fn handle_get_task(
229 &self,
230 id: &str,
231 query: &str,
232 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
233 let history_length = parse_query_param_u32(query, "historyLength");
234 let params = a2a_protocol_types::params::TaskQueryParams {
235 tenant: None,
236 id: id.to_owned(),
237 history_length,
238 };
239 match self.handler.on_get_task(params).await {
240 Ok(task) => json_ok_response(&task),
241 Err(e) => server_error_to_response(&e),
242 }
243 }
244
245 async fn handle_list_tasks(
246 &self,
247 query: &str,
248 tenant: Option<&str>,
249 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
250 let params = parse_list_tasks_query(query, tenant);
251 match self.handler.on_list_tasks(params).await {
252 Ok(result) => json_ok_response(&result),
253 Err(e) => server_error_to_response(&e),
254 }
255 }
256
257 async fn handle_cancel_task(&self, id: &str) -> hyper::Response<BoxBody<Bytes, Infallible>> {
258 let params = a2a_protocol_types::params::CancelTaskParams {
259 tenant: None,
260 id: id.to_owned(),
261 metadata: None,
262 };
263 match self.handler.on_cancel_task(params).await {
264 Ok(task) => json_ok_response(&task),
265 Err(e) => server_error_to_response(&e),
266 }
267 }
268
269 async fn handle_resubscribe(&self, id: &str) -> hyper::Response<BoxBody<Bytes, Infallible>> {
270 let params = a2a_protocol_types::params::TaskIdParams {
271 tenant: None,
272 id: id.to_owned(),
273 };
274 match self.handler.on_resubscribe(params).await {
275 Ok(reader) => build_sse_response(reader, None),
276 Err(e) => server_error_to_response(&e),
277 }
278 }
279
280 async fn handle_set_push_config(
281 &self,
282 req: hyper::Request<Incoming>,
283 _task_id: &str,
284 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
285 let body_bytes = match read_body_limited(req.into_body(), MAX_REQUEST_BODY_SIZE).await {
286 Ok(bytes) => bytes,
287 Err(msg) => return error_json_response(413, &msg),
288 };
289 let config: a2a_protocol_types::push::TaskPushNotificationConfig =
290 match serde_json::from_slice(&body_bytes) {
291 Ok(c) => c,
292 Err(e) => return error_json_response(400, &e.to_string()),
293 };
294 match self.handler.on_set_push_config(config).await {
295 Ok(result) => json_ok_response(&result),
296 Err(e) => server_error_to_response(&e),
297 }
298 }
299
300 async fn handle_get_push_config(
301 &self,
302 task_id: &str,
303 config_id: &str,
304 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
305 let params = a2a_protocol_types::params::GetPushConfigParams {
306 tenant: None,
307 task_id: task_id.to_owned(),
308 id: config_id.to_owned(),
309 };
310 match self.handler.on_get_push_config(params).await {
311 Ok(config) => json_ok_response(&config),
312 Err(e) => server_error_to_response(&e),
313 }
314 }
315
316 async fn handle_list_push_configs(
317 &self,
318 task_id: &str,
319 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
320 match self.handler.on_list_push_configs(task_id).await {
321 Ok(configs) => json_ok_response(&configs),
322 Err(e) => server_error_to_response(&e),
323 }
324 }
325
326 async fn handle_delete_push_config(
327 &self,
328 task_id: &str,
329 config_id: &str,
330 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
331 let params = a2a_protocol_types::params::DeletePushConfigParams {
332 tenant: None,
333 task_id: task_id.to_owned(),
334 id: config_id.to_owned(),
335 };
336 match self.handler.on_delete_push_config(params).await {
337 Ok(()) => json_ok_response(&serde_json::json!({})),
338 Err(e) => server_error_to_response(&e),
339 }
340 }
341
342 async fn handle_extended_card(&self) -> hyper::Response<BoxBody<Bytes, Infallible>> {
343 match self.handler.on_get_extended_agent_card().await {
344 Ok(card) => json_ok_response(&card),
345 Err(e) => server_error_to_response(&e),
346 }
347 }
348}
349
350impl std::fmt::Debug for RestDispatcher {
351 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352 f.debug_struct("RestDispatcher").finish()
353 }
354}
355
356fn json_ok_response<T: serde::Serialize>(value: &T) -> hyper::Response<BoxBody<Bytes, Infallible>> {
359 match serde_json::to_vec(value) {
360 Ok(body) => build_json_response(200, body),
361 Err(_err) => {
362 trace_error!(error = %_err, "REST response serialization failed");
363 internal_error_response()
364 }
365 }
366}
367
368fn error_json_response(status: u16, message: &str) -> hyper::Response<BoxBody<Bytes, Infallible>> {
369 let body = serde_json::json!({ "error": message });
370 serde_json::to_vec(&body).map_or_else(
371 |_| internal_error_response(),
372 |bytes| build_json_response(status, bytes),
373 )
374}
375
376fn internal_error_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
378 let body = br#"{"error":"internal serialization error"}"#;
379 build_json_response(500, body.to_vec())
380}
381
382fn not_found_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
383 error_json_response(404, "not found")
384}
385
386fn server_error_to_response(err: &ServerError) -> hyper::Response<BoxBody<Bytes, Infallible>> {
387 let status = match err {
388 ServerError::TaskNotFound(_) | ServerError::MethodNotFound(_) => 404,
389 ServerError::TaskNotCancelable(_) => 409,
390 ServerError::InvalidParams(_)
391 | ServerError::Serialization(_)
392 | ServerError::PushNotSupported => 400,
393 _ => 500,
394 };
395 let a2a_err = err.to_a2a_error();
396 serde_json::to_vec(&a2a_err).map_or_else(
397 |_| internal_error_response(),
398 |body| build_json_response(status, body),
399 )
400}
401
/// Splits an optional `/tenants/{tenant}` prefix off `path`.
///
/// Returns `(Some(tenant), remainder)` where `remainder` starts at the slash
/// following the tenant segment. When the path does not start with
/// `/tenants/`, or nothing follows the tenant segment, returns `(None, path)`
/// unchanged. The tenant is returned as written (no percent-decoding) and may
/// be empty.
fn strip_tenant_prefix(path: &str) -> (Option<&str>, &str) {
    let parts = path
        .strip_prefix("/tenants/")
        .and_then(|after| after.find('/').map(|cut| after.split_at(cut)));
    match parts {
        Some((tenant, remaining)) => (Some(tenant), remaining),
        None => (None, path),
    }
}
416
417fn parse_query_param_u32(query: &str, key: &str) -> Option<u32> {
419 parse_query_param(query, key).and_then(|v| v.parse::<u32>().ok())
420}
421
422fn parse_query_param(query: &str, key: &str) -> Option<String> {
424 query.split('&').find_map(|pair| {
425 let (k, v) = pair.split_once('=')?;
426 if k == key {
427 Some(percent_decode(v))
428 } else {
429 None
430 }
431 })
432}
433
434fn percent_decode(input: &str) -> String {
438 let mut output = String::with_capacity(input.len());
439 let mut bytes = input.as_bytes().iter();
440 while let Some(&b) = bytes.next() {
441 match b {
442 b'%' => {
443 let hi = bytes.next().copied();
444 let lo = bytes.next().copied();
445 if let (Some(h), Some(l)) = (hi, lo) {
446 if let (Some(h), Some(l)) = (hex_val(h), hex_val(l)) {
447 output.push(char::from(h << 4 | l));
448 continue;
449 }
450 }
451 output.push('%');
453 }
454 b'+' => output.push(' '),
455 _ => output.push(char::from(b)),
456 }
457 }
458 output
459}
460
461fn contains_path_traversal(path: &str) -> bool {
464 if path.contains("..") {
465 return true;
466 }
467 let decoded = percent_decode(path);
469 decoded.contains("..")
470}
471
/// Returns the numeric value (0-15) of an ASCII hex digit, accepting both
/// cases, or `None` for any other byte.
const fn hex_val(b: u8) -> Option<u8> {
    // Pick the byte to subtract so that the first digit of each range lands
    // on its value: '0' -> 0, 'a'/'A' -> 10.
    let base = match b {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(b - base)
}
481
482fn parse_query_param_bool(query: &str, key: &str) -> Option<bool> {
484 parse_query_param(query, key).map(|v| v == "true" || v == "1")
485}
486
/// Longest query string, in bytes, that `dispatch` accepts; longer ones are
/// rejected with 414.
const MAX_QUERY_STRING_LENGTH: usize = 4096;

/// Largest request body, in bytes, accepted by the handlers that read a body
/// (4 MiB).
const MAX_REQUEST_BODY_SIZE: usize = 4 * 1024 * 1024;

/// Upper bound on the time spent collecting a request body.
const BODY_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
495
496async fn read_body_limited(body: Incoming, max_size: usize) -> Result<Bytes, String> {
498 let size_hint = <Incoming as hyper::body::Body>::size_hint(&body);
499 if let Some(upper) = size_hint.upper() {
500 if upper > max_size as u64 {
501 return Err(format!(
502 "request body too large: {upper} bytes exceeds {max_size} byte limit"
503 ));
504 }
505 }
506 let collected = tokio::time::timeout(BODY_READ_TIMEOUT, body.collect())
507 .await
508 .map_err(|_| "request body read timed out".to_owned())?
509 .map_err(|e| e.to_string())?;
510 let bytes = collected.to_bytes();
511 if bytes.len() > max_size {
512 return Err(format!(
513 "request body too large: {} bytes exceeds {max_size} byte limit",
514 bytes.len()
515 ));
516 }
517 Ok(bytes)
518}
519
520fn health_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
522 let body = br#"{"status":"ok"}"#;
523 build_json_response(200, body.to_vec())
524}
525
526fn build_json_response(status: u16, body: Vec<u8>) -> hyper::Response<BoxBody<Bytes, Infallible>> {
528 hyper::Response::builder()
529 .status(status)
530 .header("content-type", a2a_protocol_types::A2A_CONTENT_TYPE)
531 .header(
532 a2a_protocol_types::A2A_VERSION_HEADER,
533 a2a_protocol_types::A2A_VERSION,
534 )
535 .body(Full::new(Bytes::from(body)).boxed())
536 .unwrap_or_else(|_| {
537 hyper::Response::new(
540 Full::new(Bytes::from_static(br#"{"error":"response build error"}"#)).boxed(),
541 )
542 })
543}
544
545fn parse_list_tasks_query(
547 query: &str,
548 tenant: Option<&str>,
549) -> a2a_protocol_types::params::ListTasksParams {
550 let status = parse_query_param(query, "status")
551 .and_then(|s| serde_json::from_value(serde_json::Value::String(s)).ok());
552 a2a_protocol_types::params::ListTasksParams {
553 tenant: tenant.map(str::to_owned),
554 context_id: parse_query_param(query, "contextId"),
555 status,
556 page_size: parse_query_param_u32(query, "pageSize"),
557 page_token: parse_query_param(query, "pageToken"),
558 status_timestamp_after: parse_query_param(query, "statusTimestampAfter"),
559 include_artifacts: parse_query_param_bool(query, "includeArtifacts"),
560 history_length: parse_query_param_u32(query, "historyLength"),
561 }
562}