1use std::collections::HashMap;
53use std::convert::Infallible;
54use std::sync::Arc;
55
56use axum::body::Body;
57use axum::extract::{Path, Query, State};
58use axum::response::IntoResponse;
59use axum::routing::{get, post};
60use axum::Router;
61use bytes::Bytes;
62
63use crate::handler::{RequestHandler, SendMessageResult};
64use crate::streaming::build_sse_response;
65
66pub struct A2aRouter {
91 handler: Arc<RequestHandler>,
92 config: super::DispatchConfig,
93}
94
95impl A2aRouter {
96 #[must_use]
98 pub fn new(handler: Arc<RequestHandler>) -> Self {
99 Self {
100 handler,
101 config: super::DispatchConfig::default(),
102 }
103 }
104
105 #[must_use]
107 pub const fn with_config(handler: Arc<RequestHandler>, config: super::DispatchConfig) -> Self {
108 Self { handler, config }
109 }
110
111 pub fn into_router(self) -> Router {
116 let state = A2aState {
117 handler: self.handler,
118 config: Arc::new(self.config),
119 };
120
121 Router::new()
122 .route("/message:send", post(handle_send_message))
124 .route("/message:stream", post(handle_stream_message))
125 .route("/tasks", get(handle_list_tasks))
127 .route("/tasks/{*rest}", axum::routing::any(handle_tasks_catchall))
132 .route("/extendedAgentCard", get(handle_extended_card))
134 .route("/.well-known/agent.json", get(handle_agent_card))
136 .route("/health", get(handle_health))
138 .with_state(state)
139 }
140}
141
142#[derive(Clone)]
145struct A2aState {
146 handler: Arc<RequestHandler>,
147 config: Arc<super::DispatchConfig>,
148}
149
150fn extract_headers(headers: &axum::http::HeaderMap) -> HashMap<String, String> {
153 headers
154 .iter()
155 .filter_map(|(k, v)| {
156 v.to_str()
157 .ok()
158 .map(|val| (k.as_str().to_lowercase(), val.to_owned()))
159 })
160 .collect()
161}
162
163fn a2a_error_to_response(err: &dyn std::fmt::Display, status: u16) -> axum::response::Response {
166 let body = serde_json::json!({ "error": err.to_string() });
167 (
168 axum::http::StatusCode::from_u16(status)
169 .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
170 axum::Json(body),
171 )
172 .into_response()
173}
174
175const fn server_error_status(err: &crate::error::ServerError) -> u16 {
176 use crate::error::ServerError;
177
178 match err {
179 ServerError::TaskNotFound(_) | ServerError::MethodNotFound(_) => 404,
180 ServerError::InvalidParams(_) | ServerError::Serialization(_) => 400,
181 ServerError::InvalidStateTransition { .. } | ServerError::TaskNotCancelable(_) => 409,
182 ServerError::PushNotSupported => 501,
183 ServerError::PayloadTooLarge(_) => 413,
184 _ => 500,
185 }
186}
187
188fn handler_error_to_response(err: &crate::error::ServerError) -> axum::response::Response {
189 a2a_error_to_response(err, server_error_status(err))
190}
191
192fn hyper_sse_to_axum(
197 resp: hyper::Response<http_body_util::combinators::BoxBody<Bytes, Infallible>>,
198) -> axum::response::Response {
199 let (parts, body) = resp.into_parts();
200 let axum_body = Body::new(body);
201 axum::response::Response::from_parts(parts, axum_body)
202}
203
204async fn handle_tasks_catchall(
217 State(state): State<A2aState>,
218 method: axum::http::Method,
219 Path(rest): Path<String>,
220 headers: axum::http::HeaderMap,
221 body: Bytes,
222) -> axum::response::Response {
223 let hdrs = extract_headers(&headers);
224 let segments: Vec<&str> = rest.split('/').filter(|s| !s.is_empty()).collect();
225
226 match (method.as_str(), segments.as_slice()) {
227 ("GET", [id]) if !id.contains(':') => handle_get_task_inner(&state, id, &hdrs).await,
229
230 ("POST", [id_action]) if id_action.ends_with(":cancel") => {
232 let id = &id_action[..id_action.len() - ":cancel".len()];
233 handle_cancel_task_inner(&state, id, &hdrs).await
234 }
235
236 ("GET" | "POST", [id_action]) if id_action.ends_with(":subscribe") => {
238 let id = &id_action[..id_action.len() - ":subscribe".len()];
239 handle_subscribe_inner(&state, id, &hdrs).await
240 }
241
242 ("POST", [task_id, "pushNotificationConfigs"]) => {
244 handle_create_push_config_inner(&state, task_id, &hdrs, body).await
245 }
246
247 ("GET", [task_id, "pushNotificationConfigs"]) => {
249 handle_list_push_configs_inner(&state, task_id, &hdrs).await
250 }
251
252 ("GET", [task_id, "pushNotificationConfigs", config_id]) => {
254 handle_get_push_config_inner(&state, task_id, config_id, &hdrs).await
255 }
256
257 ("DELETE", [task_id, "pushNotificationConfigs", config_id]) => {
259 handle_delete_push_config_inner(&state, task_id, config_id, &hdrs).await
260 }
261
262 _ => a2a_error_to_response(&"not found", 404),
263 }
264}
265
266async fn handle_send_message(
269 State(state): State<A2aState>,
270 headers: axum::http::HeaderMap,
271 body: Bytes,
272) -> axum::response::Response {
273 handle_send_inner(&state, false, &headers, body).await
274}
275
276async fn handle_stream_message(
277 State(state): State<A2aState>,
278 headers: axum::http::HeaderMap,
279 body: Bytes,
280) -> axum::response::Response {
281 handle_send_inner(&state, true, &headers, body).await
282}
283
284async fn handle_list_tasks(
285 State(state): State<A2aState>,
286 Query(query): Query<HashMap<String, String>>,
287 headers: axum::http::HeaderMap,
288) -> axum::response::Response {
289 let hdrs = extract_headers(&headers);
290 let params = a2a_protocol_types::params::ListTasksParams {
291 tenant: None,
292 context_id: query.get("contextId").cloned(),
293 status: query
294 .get("status")
295 .and_then(|s| serde_json::from_value(serde_json::Value::String(s.clone())).ok()),
296 page_size: query.get("pageSize").and_then(|v| v.parse().ok()),
297 page_token: query.get("pageToken").cloned(),
298 status_timestamp_after: query.get("statusTimestampAfter").cloned(),
299 include_artifacts: query.get("includeArtifacts").and_then(|v| v.parse().ok()),
300 history_length: query.get("historyLength").and_then(|v| v.parse().ok()),
301 };
302 match state.handler.on_list_tasks(params, Some(&hdrs)).await {
303 Ok(result) => axum::Json(result).into_response(),
304 Err(e) => handler_error_to_response(&e),
305 }
306}
307
308async fn handle_extended_card(
309 State(state): State<A2aState>,
310 headers: axum::http::HeaderMap,
311) -> axum::response::Response {
312 let hdrs = extract_headers(&headers);
313 match state.handler.on_get_extended_agent_card(Some(&hdrs)).await {
314 Ok(card) => axum::Json(card).into_response(),
315 Err(e) => handler_error_to_response(&e),
316 }
317}
318
319async fn handle_agent_card(State(state): State<A2aState>) -> axum::response::Response {
320 state.handler.agent_card.as_ref().map_or_else(
321 || a2a_error_to_response(&"agent card not configured", 404),
322 |card| axum::Json(card).into_response(),
323 )
324}
325
326async fn handle_health() -> axum::response::Response {
327 axum::Json(serde_json::json!({"status": "ok"})).into_response()
328}
329
330async fn handle_send_inner(
333 state: &A2aState,
334 streaming: bool,
335 headers: &axum::http::HeaderMap,
336 body: Bytes,
337) -> axum::response::Response {
338 let hdrs = extract_headers(headers);
339 let params: a2a_protocol_types::params::MessageSendParams = match serde_json::from_slice(&body)
340 {
341 Ok(p) => p,
342 Err(e) => return a2a_error_to_response(&e, 400),
343 };
344 match state
345 .handler
346 .on_send_message(params, streaming, Some(&hdrs))
347 .await
348 {
349 Ok(SendMessageResult::Response(resp)) => axum::Json(resp).into_response(),
350 Ok(SendMessageResult::Stream(reader)) => hyper_sse_to_axum(build_sse_response(
351 reader,
352 Some(state.config.sse_keep_alive_interval),
353 Some(state.config.sse_channel_capacity),
354 )),
355 Err(e) => handler_error_to_response(&e),
356 }
357}
358
359async fn handle_get_task_inner(
360 state: &A2aState,
361 id: &str,
362 hdrs: &HashMap<String, String>,
363) -> axum::response::Response {
364 let params = a2a_protocol_types::params::TaskQueryParams {
365 tenant: None,
366 id: id.to_owned(),
367 history_length: None,
368 };
369 match state.handler.on_get_task(params, Some(hdrs)).await {
370 Ok(task) => axum::Json(task).into_response(),
371 Err(e) => handler_error_to_response(&e),
372 }
373}
374
375async fn handle_cancel_task_inner(
376 state: &A2aState,
377 id: &str,
378 hdrs: &HashMap<String, String>,
379) -> axum::response::Response {
380 let params = a2a_protocol_types::params::CancelTaskParams {
381 tenant: None,
382 id: id.to_owned(),
383 metadata: None,
384 };
385 match state.handler.on_cancel_task(params, Some(hdrs)).await {
386 Ok(task) => axum::Json(task).into_response(),
387 Err(e) => handler_error_to_response(&e),
388 }
389}
390
391async fn handle_subscribe_inner(
392 state: &A2aState,
393 id: &str,
394 hdrs: &HashMap<String, String>,
395) -> axum::response::Response {
396 let params = a2a_protocol_types::params::TaskIdParams {
397 tenant: None,
398 id: id.to_owned(),
399 };
400 match state.handler.on_resubscribe(params, Some(hdrs)).await {
401 Ok(reader) => hyper_sse_to_axum(build_sse_response(
402 reader,
403 Some(state.config.sse_keep_alive_interval),
404 Some(state.config.sse_channel_capacity),
405 )),
406 Err(e) => handler_error_to_response(&e),
407 }
408}
409
410async fn handle_create_push_config_inner(
411 state: &A2aState,
412 task_id: &str,
413 hdrs: &HashMap<String, String>,
414 body: Bytes,
415) -> axum::response::Response {
416 let mut value: serde_json::Value = match serde_json::from_slice(&body) {
417 Ok(v) => v,
418 Err(e) => return a2a_error_to_response(&e, 400),
419 };
420 if let Some(obj) = value.as_object_mut() {
421 obj.entry("taskId")
422 .or_insert_with(|| serde_json::Value::String(task_id.to_owned()));
423 }
424 let config: a2a_protocol_types::push::TaskPushNotificationConfig =
425 match serde_json::from_value(value) {
426 Ok(c) => c,
427 Err(e) => return a2a_error_to_response(&e, 400),
428 };
429 match state.handler.on_set_push_config(config, Some(hdrs)).await {
430 Ok(result) => axum::Json(result).into_response(),
431 Err(e) => handler_error_to_response(&e),
432 }
433}
434
435async fn handle_get_push_config_inner(
436 state: &A2aState,
437 task_id: &str,
438 config_id: &str,
439 hdrs: &HashMap<String, String>,
440) -> axum::response::Response {
441 let params = a2a_protocol_types::params::GetPushConfigParams {
442 tenant: None,
443 task_id: task_id.to_owned(),
444 id: config_id.to_owned(),
445 };
446 match state.handler.on_get_push_config(params, Some(hdrs)).await {
447 Ok(config) => axum::Json(config).into_response(),
448 Err(e) => handler_error_to_response(&e),
449 }
450}
451
452async fn handle_list_push_configs_inner(
453 state: &A2aState,
454 task_id: &str,
455 hdrs: &HashMap<String, String>,
456) -> axum::response::Response {
457 match state
458 .handler
459 .on_list_push_configs(task_id, None, Some(hdrs))
460 .await
461 {
462 Ok(configs) => {
463 let resp = a2a_protocol_types::responses::ListPushConfigsResponse {
464 configs,
465 next_page_token: None,
466 };
467 axum::Json(resp).into_response()
468 }
469 Err(e) => handler_error_to_response(&e),
470 }
471}
472
473async fn handle_delete_push_config_inner(
474 state: &A2aState,
475 task_id: &str,
476 config_id: &str,
477 hdrs: &HashMap<String, String>,
478) -> axum::response::Response {
479 let params = a2a_protocol_types::params::DeletePushConfigParams {
480 tenant: None,
481 task_id: task_id.to_owned(),
482 id: config_id.to_owned(),
483 };
484 match state
485 .handler
486 .on_delete_push_config(params, Some(hdrs))
487 .await
488 {
489 Ok(()) => axum::Json(serde_json::json!({})).into_response(),
490 Err(e) => handler_error_to_response(&e),
491 }
492}
493
#[cfg(test)]
// Tests intentionally panic on setup failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Executor that completes immediately without emitting events; enough to build a
    /// `RequestHandler` for router-construction tests.
    struct NoopExecutor;

    impl crate::executor::AgentExecutor for NoopExecutor {
        fn execute<'a>(
            &'a self,
            _ctx: &'a crate::request_context::RequestContext,
            _queue: &'a dyn crate::streaming::EventQueueWriter,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
                    + Send
                    + 'a,
            >,
        > {
            Box::pin(async { Ok(()) })
        }
    }

    /// Builds a shared handler around [`NoopExecutor`].
    fn noop_handler() -> Arc<RequestHandler> {
        Arc::new(
            crate::builder::RequestHandlerBuilder::new(NoopExecutor)
                .build()
                .unwrap(),
        )
    }

    #[test]
    fn extract_headers_lowercases_names() {
        let mut map = axum::http::HeaderMap::new();
        map.insert("X-Request-ID", "abc".parse().unwrap());
        map.insert("content-type", "application/json".parse().unwrap());

        let result = extract_headers(&map);
        assert_eq!(result.get("x-request-id").unwrap(), "abc");
        assert_eq!(result.get("content-type").unwrap(), "application/json");
    }

    #[test]
    fn extract_headers_skips_non_utf8_values() {
        let mut map = axum::http::HeaderMap::new();
        map.insert("good", "valid".parse().unwrap());
        // 0xFF/0xFE are accepted by `HeaderValue::from_bytes` (obs-text) but rejected by
        // `to_str`, so this entry must be dropped.
        map.insert(
            "bad",
            axum::http::HeaderValue::from_bytes(b"\xff\xfe").unwrap(),
        );

        let result = extract_headers(&map);
        assert_eq!(result.len(), 1);
        assert_eq!(result.get("good").unwrap(), "valid");
        assert!(!result.contains_key("bad"));
    }

    #[test]
    fn extract_headers_empty_map() {
        let map = axum::http::HeaderMap::new();
        let result = extract_headers(&map);
        assert!(result.is_empty());
    }

    #[test]
    fn a2a_state_is_clone() {
        fn assert_clone<T: Clone>() {}
        assert_clone::<A2aState>();
    }

    #[test]
    fn server_error_status_task_not_found() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::TaskNotFound("t".into())),
            404
        );
    }

    #[test]
    fn server_error_status_method_not_found() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::MethodNotFound("m".into())),
            404
        );
    }

    #[test]
    fn server_error_status_invalid_params() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::InvalidParams("p".into())),
            400
        );
    }

    #[test]
    fn server_error_status_serialization() {
        use crate::error::ServerError;
        let err = ServerError::Serialization(serde_json::from_str::<String>("bad").unwrap_err());
        assert_eq!(server_error_status(&err), 400);
    }

    #[test]
    fn server_error_status_task_not_cancelable() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::TaskNotCancelable("t".into())),
            409
        );
    }

    #[test]
    fn server_error_status_invalid_state_transition() {
        use crate::error::ServerError;
        let err = ServerError::InvalidStateTransition {
            task_id: "t".into(),
            from: a2a_protocol_types::task::TaskState::Working,
            to: a2a_protocol_types::task::TaskState::Submitted,
        };
        assert_eq!(server_error_status(&err), 409);
    }

    #[test]
    fn server_error_status_push_not_supported() {
        use crate::error::ServerError;
        assert_eq!(server_error_status(&ServerError::PushNotSupported), 501);
    }

    #[test]
    fn server_error_status_payload_too_large() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::PayloadTooLarge("big".into())),
            413
        );
    }

    #[test]
    fn server_error_status_internal() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::Internal("oops".into())),
            500
        );
    }

    #[test]
    fn a2a_error_to_response_returns_correct_status() {
        let resp = a2a_error_to_response(&"test error", 400);
        assert_eq!(resp.status().as_u16(), 400);
    }

    #[test]
    fn a2a_error_to_response_returns_json_body() {
        let resp = a2a_error_to_response(&"not found", 404);
        assert_eq!(resp.status().as_u16(), 404);
        let content_type = resp
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok());
        assert_eq!(content_type, Some("application/json"));
    }

    #[test]
    fn a2a_error_to_response_invalid_status_falls_back_to_500() {
        // 1000 is outside the valid 100..=999 status range.
        let resp = a2a_error_to_response(&"bad status", 1000);
        assert_eq!(resp.status().as_u16(), 500);
    }

    #[test]
    fn handler_error_to_response_maps_correctly() {
        use crate::error::ServerError;
        let resp = handler_error_to_response(&ServerError::TaskNotFound("t1".into()));
        assert_eq!(resp.status().as_u16(), 404);

        let resp = handler_error_to_response(&ServerError::InvalidParams("bad".into()));
        assert_eq!(resp.status().as_u16(), 400);

        let resp = handler_error_to_response(&ServerError::Internal("oops".into()));
        assert_eq!(resp.status().as_u16(), 500);
    }

    #[test]
    fn a2a_router_new_creates_with_defaults() {
        let router = A2aRouter::new(noop_handler());
        let _axum_router = router.into_router();
    }

    #[test]
    fn a2a_router_with_config() {
        let config =
            super::super::DispatchConfig::default().with_max_request_body_size(8 * 1024 * 1024);
        let router = A2aRouter::with_config(noop_handler(), config);
        let _axum_router = router.into_router();
    }
}