1use std::collections::HashMap;
53use std::convert::Infallible;
54use std::sync::Arc;
55
56use axum::body::Body;
57use axum::extract::{Path, Query, State};
58use axum::response::IntoResponse;
59use axum::routing::{get, post};
60use axum::Router;
61use bytes::Bytes;
62
63use crate::handler::{RequestHandler, SendMessageResult};
64use crate::streaming::build_sse_response;
65
66pub struct A2aRouter {
91 handler: Arc<RequestHandler>,
92 config: super::DispatchConfig,
93}
94
95impl A2aRouter {
96 #[must_use]
98 pub fn new(handler: Arc<RequestHandler>) -> Self {
99 Self {
100 handler,
101 config: super::DispatchConfig::default(),
102 }
103 }
104
105 #[must_use]
107 pub const fn with_config(handler: Arc<RequestHandler>, config: super::DispatchConfig) -> Self {
108 Self { handler, config }
109 }
110
111 pub fn into_router(self) -> Router {
116 let state = A2aState {
117 handler: self.handler,
118 config: Arc::new(self.config),
119 };
120
121 Router::new()
122 .route("/message:send", post(handle_send_message))
124 .route("/message:stream", post(handle_stream_message))
125 .route("/tasks", get(handle_list_tasks))
127 .route("/tasks/{*rest}", axum::routing::any(handle_tasks_catchall))
132 .route("/extendedAgentCard", get(handle_extended_card))
134 .route("/.well-known/agent-card.json", get(handle_agent_card))
136 .route("/health", get(handle_health))
138 .with_state(state)
139 }
140}
141
142#[derive(Clone)]
145struct A2aState {
146 handler: Arc<RequestHandler>,
147 config: Arc<super::DispatchConfig>,
148}
149
150fn extract_headers(headers: &axum::http::HeaderMap) -> HashMap<String, String> {
153 headers
154 .iter()
155 .filter_map(|(k, v)| {
156 v.to_str()
157 .ok()
158 .map(|val| (k.as_str().to_lowercase(), val.to_owned()))
159 })
160 .collect()
161}
162
163fn a2a_error_to_response(err: &dyn std::fmt::Display, status: u16) -> axum::response::Response {
166 let body = serde_json::json!({ "error": err.to_string() });
167 (
168 axum::http::StatusCode::from_u16(status)
169 .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
170 axum::Json(body),
171 )
172 .into_response()
173}
174
175const fn server_error_status(err: &crate::error::ServerError) -> u16 {
176 use crate::error::ServerError;
177
178 match err {
179 ServerError::TaskNotFound(_) | ServerError::MethodNotFound(_) => 404,
180 ServerError::InvalidParams(_) | ServerError::Serialization(_) => 400,
181 ServerError::InvalidStateTransition { .. } | ServerError::TaskNotCancelable(_) => 409,
182 ServerError::PushNotSupported => 501,
183 ServerError::PayloadTooLarge(_) => 413,
184 _ => 500,
185 }
186}
187
188fn handler_error_to_response(err: &crate::error::ServerError) -> axum::response::Response {
189 a2a_error_to_response(err, server_error_status(err))
190}
191
192fn hyper_sse_to_axum(
197 resp: hyper::Response<http_body_util::combinators::BoxBody<Bytes, Infallible>>,
198) -> axum::response::Response {
199 let (parts, body) = resp.into_parts();
200 let axum_body = Body::new(body);
201 axum::response::Response::from_parts(parts, axum_body)
202}
203
204async fn handle_tasks_catchall(
217 State(state): State<A2aState>,
218 method: axum::http::Method,
219 Path(rest): Path<String>,
220 headers: axum::http::HeaderMap,
221 body: Bytes,
222) -> axum::response::Response {
223 let hdrs = extract_headers(&headers);
224 let segments: Vec<&str> = rest.split('/').filter(|s| !s.is_empty()).collect();
225
226 match (method.as_str(), segments.as_slice()) {
227 ("GET", [id]) if !id.contains(':') => handle_get_task_inner(&state, id, &hdrs).await,
229
230 ("POST", [id_action]) if id_action.ends_with(":cancel") => {
232 let id = &id_action[..id_action.len() - ":cancel".len()];
233 handle_cancel_task_inner(&state, id, &hdrs).await
234 }
235
236 ("GET" | "POST", [id_action]) if id_action.ends_with(":subscribe") => {
238 let id = &id_action[..id_action.len() - ":subscribe".len()];
239 handle_subscribe_inner(&state, id, &hdrs).await
240 }
241
242 ("POST", [task_id, "pushNotificationConfigs"]) => {
244 handle_create_push_config_inner(&state, task_id, &hdrs, body).await
245 }
246
247 ("GET", [task_id, "pushNotificationConfigs"]) => {
249 handle_list_push_configs_inner(&state, task_id, &hdrs).await
250 }
251
252 ("GET", [task_id, "pushNotificationConfigs", config_id]) => {
254 handle_get_push_config_inner(&state, task_id, config_id, &hdrs).await
255 }
256
257 ("DELETE", [task_id, "pushNotificationConfigs", config_id]) => {
259 handle_delete_push_config_inner(&state, task_id, config_id, &hdrs).await
260 }
261
262 _ => a2a_error_to_response(&"not found", 404),
263 }
264}
265
266async fn handle_send_message(
269 State(state): State<A2aState>,
270 headers: axum::http::HeaderMap,
271 body: Bytes,
272) -> axum::response::Response {
273 handle_send_inner(&state, false, &headers, body).await
274}
275
276async fn handle_stream_message(
277 State(state): State<A2aState>,
278 headers: axum::http::HeaderMap,
279 body: Bytes,
280) -> axum::response::Response {
281 handle_send_inner(&state, true, &headers, body).await
282}
283
284async fn handle_list_tasks(
285 State(state): State<A2aState>,
286 Query(query): Query<HashMap<String, String>>,
287 headers: axum::http::HeaderMap,
288) -> axum::response::Response {
289 let hdrs = extract_headers(&headers);
290 let params = a2a_protocol_types::params::ListTasksParams {
291 tenant: None,
292 context_id: query.get("contextId").cloned(),
293 status: query
294 .get("status")
295 .and_then(|s| serde_json::from_value(serde_json::Value::String(s.clone())).ok()),
296 page_size: query.get("pageSize").and_then(|v| v.parse().ok()),
297 page_token: query.get("pageToken").cloned(),
298 status_timestamp_after: query.get("statusTimestampAfter").cloned(),
299 include_artifacts: query.get("includeArtifacts").and_then(|v| v.parse().ok()),
300 history_length: query.get("historyLength").and_then(|v| v.parse().ok()),
301 };
302 match state.handler.on_list_tasks(params, Some(&hdrs)).await {
303 Ok(result) => axum::Json(result).into_response(),
304 Err(e) => handler_error_to_response(&e),
305 }
306}
307
308async fn handle_extended_card(
309 State(state): State<A2aState>,
310 headers: axum::http::HeaderMap,
311) -> axum::response::Response {
312 let hdrs = extract_headers(&headers);
313 match state.handler.on_get_extended_agent_card(Some(&hdrs)).await {
314 Ok(card) => axum::Json(card).into_response(),
315 Err(e) => handler_error_to_response(&e),
316 }
317}
318
319async fn handle_agent_card(State(state): State<A2aState>) -> axum::response::Response {
320 state.handler.agent_card.as_ref().map_or_else(
321 || a2a_error_to_response(&"agent card not configured", 404),
322 |card| axum::Json(card).into_response(),
323 )
324}
325
326async fn handle_health() -> axum::response::Response {
327 axum::Json(serde_json::json!({"status": "ok"})).into_response()
328}
329
330async fn handle_send_inner(
333 state: &A2aState,
334 streaming: bool,
335 headers: &axum::http::HeaderMap,
336 body: Bytes,
337) -> axum::response::Response {
338 let hdrs = extract_headers(headers);
339 let params: a2a_protocol_types::params::MessageSendParams = match serde_json::from_slice(&body)
340 {
341 Ok(p) => p,
342 Err(e) => return a2a_error_to_response(&e, 400),
343 };
344 match state
345 .handler
346 .on_send_message(params, streaming, Some(&hdrs))
347 .await
348 {
349 Ok(SendMessageResult::Response(resp)) => axum::Json(resp).into_response(),
350 Ok(SendMessageResult::Stream(reader)) => hyper_sse_to_axum(build_sse_response(
351 reader,
352 Some(state.config.sse_keep_alive_interval),
353 Some(state.config.sse_channel_capacity),
354 false, )),
356 Err(e) => handler_error_to_response(&e),
357 }
358}
359
360async fn handle_get_task_inner(
361 state: &A2aState,
362 id: &str,
363 hdrs: &HashMap<String, String>,
364) -> axum::response::Response {
365 let params = a2a_protocol_types::params::TaskQueryParams {
366 tenant: None,
367 id: id.to_owned(),
368 history_length: None,
369 };
370 match state.handler.on_get_task(params, Some(hdrs)).await {
371 Ok(task) => axum::Json(task).into_response(),
372 Err(e) => handler_error_to_response(&e),
373 }
374}
375
376async fn handle_cancel_task_inner(
377 state: &A2aState,
378 id: &str,
379 hdrs: &HashMap<String, String>,
380) -> axum::response::Response {
381 let params = a2a_protocol_types::params::CancelTaskParams {
382 tenant: None,
383 id: id.to_owned(),
384 metadata: None,
385 };
386 match state.handler.on_cancel_task(params, Some(hdrs)).await {
387 Ok(task) => axum::Json(task).into_response(),
388 Err(e) => handler_error_to_response(&e),
389 }
390}
391
392async fn handle_subscribe_inner(
393 state: &A2aState,
394 id: &str,
395 hdrs: &HashMap<String, String>,
396) -> axum::response::Response {
397 let params = a2a_protocol_types::params::TaskIdParams {
398 tenant: None,
399 id: id.to_owned(),
400 };
401 match state.handler.on_resubscribe(params, Some(hdrs)).await {
402 Ok(reader) => hyper_sse_to_axum(build_sse_response(
403 reader,
404 Some(state.config.sse_keep_alive_interval),
405 Some(state.config.sse_channel_capacity),
406 false, )),
408 Err(e) => handler_error_to_response(&e),
409 }
410}
411
412async fn handle_create_push_config_inner(
413 state: &A2aState,
414 task_id: &str,
415 hdrs: &HashMap<String, String>,
416 body: Bytes,
417) -> axum::response::Response {
418 let mut value: serde_json::Value = match serde_json::from_slice(&body) {
419 Ok(v) => v,
420 Err(e) => return a2a_error_to_response(&e, 400),
421 };
422 if let Some(obj) = value.as_object_mut() {
423 obj.entry("taskId")
424 .or_insert_with(|| serde_json::Value::String(task_id.to_owned()));
425 }
426 let config: a2a_protocol_types::push::TaskPushNotificationConfig =
427 match serde_json::from_value(value) {
428 Ok(c) => c,
429 Err(e) => return a2a_error_to_response(&e, 400),
430 };
431 match state.handler.on_set_push_config(config, Some(hdrs)).await {
432 Ok(result) => axum::Json(result).into_response(),
433 Err(e) => handler_error_to_response(&e),
434 }
435}
436
437async fn handle_get_push_config_inner(
438 state: &A2aState,
439 task_id: &str,
440 config_id: &str,
441 hdrs: &HashMap<String, String>,
442) -> axum::response::Response {
443 let params = a2a_protocol_types::params::GetPushConfigParams {
444 tenant: None,
445 task_id: task_id.to_owned(),
446 id: config_id.to_owned(),
447 };
448 match state.handler.on_get_push_config(params, Some(hdrs)).await {
449 Ok(config) => axum::Json(config).into_response(),
450 Err(e) => handler_error_to_response(&e),
451 }
452}
453
454async fn handle_list_push_configs_inner(
455 state: &A2aState,
456 task_id: &str,
457 hdrs: &HashMap<String, String>,
458) -> axum::response::Response {
459 match state
460 .handler
461 .on_list_push_configs(task_id, None, Some(hdrs))
462 .await
463 {
464 Ok(configs) => {
465 let resp = a2a_protocol_types::responses::ListPushConfigsResponse {
466 configs,
467 next_page_token: None,
468 };
469 axum::Json(resp).into_response()
470 }
471 Err(e) => handler_error_to_response(&e),
472 }
473}
474
475async fn handle_delete_push_config_inner(
476 state: &A2aState,
477 task_id: &str,
478 config_id: &str,
479 hdrs: &HashMap<String, String>,
480) -> axum::response::Response {
481 let params = a2a_protocol_types::params::DeletePushConfigParams {
482 tenant: None,
483 task_id: task_id.to_owned(),
484 id: config_id.to_owned(),
485 };
486 match state
487 .handler
488 .on_delete_push_config(params, Some(hdrs))
489 .await
490 {
491 Ok(()) => axum::Json(serde_json::json!({})).into_response(),
492 Err(e) => handler_error_to_response(&e),
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor that completes immediately without emitting events. The router
    /// construction tests only need a buildable handler.
    struct NoopExecutor;

    impl crate::executor::AgentExecutor for NoopExecutor {
        fn execute<'a>(
            &'a self,
            _ctx: &'a crate::request_context::RequestContext,
            _queue: &'a dyn crate::streaming::EventQueueWriter,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = a2a_protocol_types::error::A2aResult<()>>
                    + Send
                    + 'a,
            >,
        > {
            Box::pin(async { Ok(()) })
        }
    }

    /// Builds a `RequestHandler` around [`NoopExecutor`] with builder defaults.
    fn noop_handler() -> Arc<RequestHandler> {
        Arc::new(
            crate::builder::RequestHandlerBuilder::new(NoopExecutor)
                .build()
                .unwrap(),
        )
    }

    #[test]
    fn extract_headers_lowercases_names() {
        let mut map = axum::http::HeaderMap::new();
        map.insert("X-Request-ID", "abc".parse().unwrap());
        map.insert("content-type", "application/json".parse().unwrap());

        let result = extract_headers(&map);
        assert_eq!(result.get("x-request-id").unwrap(), "abc");
        assert_eq!(result.get("content-type").unwrap(), "application/json");
    }

    #[test]
    fn extract_headers_skips_non_utf8_values() {
        let mut map = axum::http::HeaderMap::new();
        map.insert("good", "valid".parse().unwrap());
        // 0xFF is a legal header byte (obs-text), but it is not visible ASCII,
        // so `HeaderValue::to_str` rejects it and the header must be skipped.
        map.insert(
            "bad",
            axum::http::HeaderValue::from_bytes(b"\xFFbad").unwrap(),
        );
        let result = extract_headers(&map);
        assert_eq!(result.len(), 1);
        assert_eq!(result.get("good").unwrap(), "valid");
        assert!(!result.contains_key("bad"));
    }

    #[test]
    fn extract_headers_empty_map() {
        let map = axum::http::HeaderMap::new();
        let result = extract_headers(&map);
        assert!(result.is_empty());
    }

    #[test]
    fn a2a_state_is_clone() {
        fn assert_clone<T: Clone>() {}
        assert_clone::<A2aState>();
    }

    #[test]
    fn server_error_status_task_not_found() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::TaskNotFound("t".into())),
            404
        );
    }

    #[test]
    fn server_error_status_method_not_found() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::MethodNotFound("m".into())),
            404
        );
    }

    #[test]
    fn server_error_status_invalid_params() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::InvalidParams("p".into())),
            400
        );
    }

    #[test]
    fn server_error_status_serialization() {
        use crate::error::ServerError;
        let err = ServerError::Serialization(serde_json::from_str::<String>("bad").unwrap_err());
        assert_eq!(server_error_status(&err), 400);
    }

    #[test]
    fn server_error_status_task_not_cancelable() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::TaskNotCancelable("t".into())),
            409
        );
    }

    #[test]
    fn server_error_status_invalid_state_transition() {
        use crate::error::ServerError;
        let err = ServerError::InvalidStateTransition {
            task_id: "t".into(),
            from: a2a_protocol_types::task::TaskState::Working,
            to: a2a_protocol_types::task::TaskState::Submitted,
        };
        assert_eq!(server_error_status(&err), 409);
    }

    #[test]
    fn server_error_status_push_not_supported() {
        use crate::error::ServerError;
        assert_eq!(server_error_status(&ServerError::PushNotSupported), 501);
    }

    #[test]
    fn server_error_status_payload_too_large() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::PayloadTooLarge("big".into())),
            413
        );
    }

    #[test]
    fn server_error_status_internal() {
        use crate::error::ServerError;
        assert_eq!(
            server_error_status(&ServerError::Internal("oops".into())),
            500
        );
    }

    #[test]
    fn a2a_error_to_response_returns_correct_status() {
        let resp = a2a_error_to_response(&"test error", 400);
        assert_eq!(resp.status().as_u16(), 400);
    }

    #[test]
    fn a2a_error_to_response_returns_json_body() {
        let resp = a2a_error_to_response(&"not found", 404);
        assert_eq!(resp.status().as_u16(), 404);
        let content_type = resp
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok());
        assert_eq!(content_type, Some("application/json"));
    }

    #[test]
    fn a2a_error_to_response_invalid_status_falls_back_to_500() {
        // 1000 is outside the valid 100..=999 status-code range.
        let resp = a2a_error_to_response(&"bad status", 1000);
        assert_eq!(resp.status().as_u16(), 500);
    }

    #[test]
    fn handler_error_to_response_maps_correctly() {
        use crate::error::ServerError;
        let resp = handler_error_to_response(&ServerError::TaskNotFound("t1".into()));
        assert_eq!(resp.status().as_u16(), 404);

        let resp = handler_error_to_response(&ServerError::InvalidParams("bad".into()));
        assert_eq!(resp.status().as_u16(), 400);

        let resp = handler_error_to_response(&ServerError::Internal("oops".into()));
        assert_eq!(resp.status().as_u16(), 500);
    }

    #[test]
    fn a2a_router_new_creates_with_defaults() {
        let router = A2aRouter::new(noop_handler());
        let _axum_router = router.into_router();
    }

    #[test]
    fn a2a_router_with_config() {
        let config =
            super::super::DispatchConfig::default().with_max_request_body_size(8 * 1024 * 1024);
        let router = A2aRouter::with_config(noop_handler(), config);
        let _axum_router = router.into_router();
    }
}