1use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `message/send` waits for its task to finish (5 minutes).
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// How often a blocking `message/send` re-reads the task store while waiting.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17 error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18 CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19 GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20 ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21 SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22 TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header through which clients announce the A2A protocol version they speak.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the only protocol version this server accepts.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the only protocol version this server accepts.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::response::{IntoResponse, Response};
33use axum::routing::{get, post};
34use axum::{Json, Router};
35use futures::future::FutureExt;
36use futures::stream::Stream;
37use tokio::sync::broadcast;
38use tracing::info;
39
40use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
41use crate::task_store::TaskStore;
42use crate::webhook_delivery::WebhookDelivery;
43use crate::webhook_store::WebhookStore;
44
/// Network configuration for `A2aServer`.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// `host:port` the HTTP listener binds to.
    pub bind_address: String,
}

impl Default for ServerConfig {
    /// Listens on every interface, port 8080.
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        ServerConfig { bind_address }
    }
}
57
58pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
59
60const EVENT_CHANNEL_CAPACITY: usize = 1024;
61
62pub struct A2aServer {
63 config: ServerConfig,
64 handler: BoxedHandler,
65 task_store: TaskStore,
66 webhook_store: WebhookStore,
67 auth_extractor: Option<AuthExtractor>,
68 additional_routes: Option<Router<AppState>>,
69 event_tx: broadcast::Sender<StreamResponse>,
70}
71
72impl A2aServer {
73 pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
74 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
75 Self {
76 config: ServerConfig::default(),
77 handler: Arc::new(handler),
78 task_store: TaskStore::new(),
79 webhook_store: WebhookStore::new(),
80 auth_extractor: None,
81 additional_routes: None,
82 event_tx,
83 }
84 }
85
86 pub fn echo() -> Self {
87 Self::new(EchoHandler::default())
88 }
89
90 pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
91 let _: SocketAddr = address.parse()?;
92 self.config.bind_address = address.to_string();
93 Ok(self)
94 }
95
96 pub fn bind_unchecked(mut self, address: &str) -> Self {
97 self.config.bind_address = address.to_string();
98 self
99 }
100
101 pub fn task_store(mut self, store: TaskStore) -> Self {
102 self.task_store = store;
103 self
104 }
105
106 pub fn auth_extractor<F>(mut self, extractor: F) -> Self
107 where
108 F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
109 {
110 self.auth_extractor = Some(Arc::new(extractor));
111 self
112 }
113
114 pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
115 self.additional_routes = Some(routes);
116 self
117 }
118
119 pub fn get_task_store(&self) -> TaskStore {
120 self.task_store.clone()
121 }
122
123 pub fn build_router(self) -> Router {
124 let bind: SocketAddr = self
125 .config
126 .bind_address
127 .parse()
128 .expect("Invalid bind address");
129 let base_url = format!("http://{}", bind);
130 let card = Arc::new(self.handler.agent_card(&base_url));
131
132 let state = AppState {
133 handler: self.handler,
134 task_store: self.task_store,
135 webhook_store: self.webhook_store,
136 card,
137 auth_extractor: self.auth_extractor,
138 event_tx: self.event_tx,
139 };
140
141 let timed_routes = Router::new()
142 .route("/health", get(health))
143 .route("/.well-known/agent-card.json", get(agent_card))
144 .route("/v1/rpc", post(handle_rpc))
145 .layer(
146 ServiceBuilder::new()
147 .layer(HandleErrorLayer::new(handle_timeout_error))
148 .timeout(Duration::from_secs(30)),
149 );
150
151 let mut router = timed_routes;
152
153 if let Some(additional) = self.additional_routes {
154 router = router.merge(additional);
155 }
156
157 router.with_state(state)
158 }
159
160 pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
161 self.event_tx.clone()
162 }
163
164 pub async fn run(self) -> anyhow::Result<()> {
165 let bind: SocketAddr = self.config.bind_address.parse()?;
166
167 let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
168 webhook_delivery.start(self.event_tx.subscribe());
169
170 let router = self.build_router();
171
172 info!(%bind, "Starting A2A server");
173 let listener = tokio::net::TcpListener::bind(bind).await?;
174 axum::serve(listener, router)
175 .with_graceful_shutdown(shutdown_signal())
176 .await?;
177
178 info!("Server shutdown complete");
179 Ok(())
180 }
181}
182
183async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
184 if err.is::<tower::timeout::error::Elapsed>() {
185 (
186 StatusCode::REQUEST_TIMEOUT,
187 Json(error(
188 serde_json::Value::Null,
189 errors::INTERNAL_ERROR,
190 "Request timed out",
191 None,
192 )),
193 )
194 } else {
195 (
196 StatusCode::INTERNAL_SERVER_ERROR,
197 Json(error(
198 serde_json::Value::Null,
199 errors::INTERNAL_ERROR,
200 &format!("Internal error: {}", err),
201 None,
202 )),
203 )
204 }
205}
206
207async fn shutdown_signal() {
208 let ctrl_c = async {
209 signal::ctrl_c()
210 .await
211 .expect("failed to install Ctrl+C handler");
212 };
213
214 #[cfg(unix)]
215 let terminate = async {
216 signal::unix::signal(signal::unix::SignalKind::terminate())
217 .expect("failed to install SIGTERM handler")
218 .recv()
219 .await;
220 };
221
222 #[cfg(not(unix))]
223 let terminate = std::future::pending::<()>();
224
225 tokio::select! {
226 _ = ctrl_c => {},
227 _ = terminate => {},
228 }
229
230 info!("Shutdown signal received, initiating graceful shutdown...");
231}
232
233#[derive(Clone)]
234pub struct AppState {
235 handler: BoxedHandler,
236 task_store: TaskStore,
237 webhook_store: WebhookStore,
238 card: Arc<AgentCard>,
239 auth_extractor: Option<AuthExtractor>,
240 event_tx: broadcast::Sender<StreamResponse>,
241}
242
243impl AppState {
244 pub fn task_store(&self) -> &TaskStore {
245 &self.task_store
246 }
247
248 pub fn agent_card(&self) -> &AgentCard {
249 &self.card
250 }
251
252 pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
253 &self.event_tx
254 }
255
256 pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
257 self.event_tx.subscribe()
258 }
259
260 pub fn broadcast_event(&self, event: StreamResponse) {
261 let _ = self.event_tx.send(event);
262 }
263
264 fn streaming_enabled(&self) -> bool {
266 self.card.capabilities.streaming.unwrap_or(false)
267 }
268
269 fn push_notifications_enabled(&self) -> bool {
271 self.card.capabilities.push_notifications.unwrap_or(false)
272 }
273
274 fn endpoint_url(&self) -> &str {
276 self.card.endpoint().unwrap_or("")
277 }
278}
279
280fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
283 match history_length {
284 Some(0) => {
285 task.history = None;
286 }
287 Some(n) => {
288 if let Some(ref mut history) = task.history {
289 let len = history.len();
290 if len > n as usize {
291 *history = history.split_off(len - n as usize);
292 }
293 }
294 }
295 None => {}
296 }
297}
298
299fn validate_a2a_version(
302 headers: &HeaderMap,
303 req_id: &serde_json::Value,
304) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
305 if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
306 let version_str = version_header.to_str().unwrap_or("");
307
308 let parts: Vec<&str> = version_str.split('.').collect();
309 if parts.len() >= 2 {
310 if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
311 if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
312 return Ok(());
313 }
314
315 return Err((
316 StatusCode::BAD_REQUEST,
317 Json(error(
318 req_id.clone(),
319 errors::VERSION_NOT_SUPPORTED,
320 &format!(
321 "Protocol version {}.{} not supported. Supported version: {}.{}",
322 major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
323 ),
324 Some(serde_json::json!({
325 "requestedVersion": version_str,
326 "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
327 })),
328 )),
329 ));
330 }
331 }
332
333 if version_str == "1" || version_str == "1.0" {
335 return Ok(());
336 }
337
338 return Err((
339 StatusCode::BAD_REQUEST,
340 Json(error(
341 req_id.clone(),
342 errors::VERSION_NOT_SUPPORTED,
343 &format!(
344 "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
345 version_str
346 ),
347 None,
348 )),
349 ));
350 }
351
352 Ok(())
353}
354
355#[allow(dead_code)]
358pub fn rpc_error(
359 id: serde_json::Value,
360 code: i32,
361 message: &str,
362 status: StatusCode,
363) -> (StatusCode, Json<JsonRpcResponse>) {
364 (status, Json(error(id, code, message, None)))
365}
366
367#[allow(dead_code)]
368pub fn rpc_error_with_data(
369 id: serde_json::Value,
370 code: i32,
371 message: &str,
372 data: serde_json::Value,
373 status: StatusCode,
374) -> (StatusCode, Json<JsonRpcResponse>) {
375 (status, Json(error(id, code, message, Some(data))))
376}
377
378#[allow(dead_code)]
379pub fn rpc_success(
380 id: serde_json::Value,
381 result: serde_json::Value,
382) -> (StatusCode, Json<JsonRpcResponse>) {
383 (StatusCode::OK, Json(success(id, result)))
384}
385
386async fn health() -> Json<serde_json::Value> {
389 Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
390}
391
392async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
393 Json((*state.card).clone())
394}
395
396async fn handle_rpc(
400 State(state): State<AppState>,
401 headers: HeaderMap,
402 Json(req): Json<JsonRpcRequest>,
403) -> Response {
404 if req.jsonrpc != "2.0" {
405 let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
406 return (StatusCode::BAD_REQUEST, Json(resp)).into_response();
407 }
408
409 if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
410 return err_response.into_response();
411 }
412
413 let auth_context = state
414 .auth_extractor
415 .as_ref()
416 .and_then(|extractor| extractor(&headers));
417
418 match req.method.as_str() {
419 "message/send" => handle_message_send(state, req, auth_context)
421 .await
422 .into_response(),
423 "message/stream" => {
424 handle_message_stream(state, req, headers, auth_context)
425 .await
426 .into_response()
427 }
428 "tasks/get" => handle_tasks_get(state, req).await.into_response(),
429 "tasks/list" => handle_tasks_list(state, req).await.into_response(),
430 "tasks/cancel" => handle_tasks_cancel(state, req).await.into_response(),
431 "tasks/resubscribe" => handle_tasks_resubscribe(state, req).await.into_response(),
432 "tasks/pushNotificationConfig/create" => {
433 handle_push_config_create(state, req).await.into_response()
434 }
435 "tasks/pushNotificationConfig/get" => {
436 handle_push_config_get(state, req).await.into_response()
437 }
438 "tasks/pushNotificationConfig/list" => {
439 handle_push_config_list(state, req).await.into_response()
440 }
441 "tasks/pushNotificationConfig/delete" => {
442 handle_push_config_delete(state, req).await.into_response()
443 }
444 "agentCard/getExtended" => handle_get_extended_agent_card(state, req, auth_context)
445 .await
446 .into_response(),
447 _ => (
448 StatusCode::NOT_FOUND,
449 Json(error(
450 req.id,
451 errors::METHOD_NOT_FOUND,
452 "method not found",
453 None,
454 )),
455 )
456 .into_response(),
457 }
458}
459
460fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
461 match e {
462 HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
463 HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
464 HandlerError::BackendUnavailable { .. } => {
465 (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
466 }
467 HandlerError::ProcessingFailed { .. } => {
468 (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
469 }
470 HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
471 }
472}
473
474async fn handle_message_send(
475 state: AppState,
476 req: JsonRpcRequest,
477 auth_context: Option<AuthContext>,
478) -> (StatusCode, Json<JsonRpcResponse>) {
479 let req_id = req.id.clone();
480
481 let params: Result<SendMessageRequest, _> =
482 serde_json::from_value(req.params.clone().unwrap_or_default());
483
484 let params = match params {
485 Ok(p) => p,
486 Err(err) => {
487 return (
488 StatusCode::BAD_REQUEST,
489 Json(error(
490 req_id,
491 errors::INVALID_PARAMS,
492 "invalid params",
493 Some(serde_json::json!({"error": err.to_string()})),
494 )),
495 );
496 }
497 };
498
499 let blocking = params
500 .configuration
501 .as_ref()
502 .and_then(|c| c.blocking)
503 .unwrap_or(false);
504 let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
505
506 match state
507 .handler
508 .handle_message(params.message, auth_context)
509 .await
510 {
511 Ok(response) => {
512 match response {
513 SendMessageResponse::Task(mut task) => {
514 state.task_store.insert(task.clone()).await;
516 state.broadcast_event(StreamResponse::Task(task.clone()));
517
518 if blocking && !task.status.state.is_terminal() {
520 let task_id = task.id.clone();
521 let mut rx = state.subscribe_events();
522
523 let wait_result = timeout(BLOCKING_TIMEOUT, async {
524 loop {
525 tokio::select! {
526 result = rx.recv() => {
527 match result {
528 Ok(StreamResponse::Task(t)) if t.id == task_id => {
529 if t.status.state.is_terminal() {
530 return Some(t);
531 }
532 }
533 Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
534 if e.status.state.is_terminal() {
535 if let Some(t) = state.task_store.get(&task_id).await {
536 return Some(t);
537 }
538 }
539 }
540 Err(broadcast::error::RecvError::Closed) => {
541 return None;
542 }
543 _ => {}
544 }
545 }
546 _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
547 if let Some(t) = state.task_store.get(&task_id).await {
548 if t.status.state.is_terminal() {
549 return Some(t);
550 }
551 }
552 }
553 }
554 }
555 })
556 .await;
557
558 match wait_result {
559 Ok(Some(final_task)) => task = final_task,
560 Ok(None) => {
561 if let Some(t) = state.task_store.get(&task.id).await {
562 task = t;
563 }
564 }
565 Err(_) => {
566 tracing::warn!("Blocking request timed out for task {}", task.id);
567 if let Some(t) = state.task_store.get(&task.id).await {
568 task = t;
569 }
570 }
571 }
572 }
573
574 apply_history_length(&mut task, history_length);
575
576 match serde_json::to_value(&task) {
579 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
580 Err(e) => (
581 StatusCode::INTERNAL_SERVER_ERROR,
582 Json(error(
583 req_id,
584 errors::INTERNAL_ERROR,
585 "serialization failed",
586 Some(serde_json::json!({"error": e.to_string()})),
587 )),
588 ),
589 }
590 }
591 SendMessageResponse::Message(msg) => {
592 match serde_json::to_value(&msg) {
594 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
595 Err(e) => (
596 StatusCode::INTERNAL_SERVER_ERROR,
597 Json(error(
598 req_id,
599 errors::INTERNAL_ERROR,
600 "serialization failed",
601 Some(serde_json::json!({"error": e.to_string()})),
602 )),
603 ),
604 }
605 }
606 }
607 }
608 Err(e) => {
609 let (code, status) = handler_error_to_rpc(&e);
610 (status, Json(error(req_id, code, &e.to_string(), None)))
611 }
612 }
613}
614
615async fn handle_message_stream(
620 state: AppState,
621 req: JsonRpcRequest,
622 _headers: HeaderMap,
623 auth_context: Option<AuthContext>,
624) -> Response {
625 let req_id = req.id.clone();
626
627 if !state.streaming_enabled() {
628 return (
629 StatusCode::BAD_REQUEST,
630 Json(error(
631 req_id,
632 errors::UNSUPPORTED_OPERATION,
633 "streaming not supported by this agent",
634 None,
635 )),
636 )
637 .into_response();
638 }
639
640 let params: Result<SendMessageRequest, _> =
641 serde_json::from_value(req.params.clone().unwrap_or_default());
642
643 let params = match params {
644 Ok(p) => p,
645 Err(err) => {
646 return (
647 StatusCode::BAD_REQUEST,
648 Json(error(
649 req_id,
650 errors::INVALID_PARAMS,
651 "invalid params",
652 Some(serde_json::json!({"error": err.to_string()})),
653 )),
654 )
655 .into_response();
656 }
657 };
658
659 let response = match state
660 .handler
661 .handle_message(params.message, auth_context)
662 .await
663 {
664 Ok(r) => r,
665 Err(e) => {
666 let (code, status) = handler_error_to_rpc(&e);
667 return (status, Json(error(req_id, code, &e.to_string(), None))).into_response();
668 }
669 };
670
671 let task = match response {
673 SendMessageResponse::Task(t) => t,
674 SendMessageResponse::Message(_) => {
675 return (
676 StatusCode::BAD_REQUEST,
677 Json(error(
678 req_id,
679 errors::UNSUPPORTED_OPERATION,
680 "handler returned a message, streaming requires a task",
681 None,
682 )),
683 )
684 .into_response();
685 }
686 };
687
688 let task_id = task.id.clone();
689 state.task_store.insert(task.clone()).await;
690 state.broadcast_event(StreamResponse::Task(task.clone()));
691
692 let mut rx = state.subscribe_events();
693 let task_store = state.task_store.clone();
694 let target_task_id = task_id;
695
696 let wrap = move |value: serde_json::Value| -> String {
698 serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
699 };
700
701 let stream = async_stream::stream! {
702 if let Ok(val) = serde_json::to_value(&task) {
704 yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
705 }
706
707 loop {
708 match rx.recv().await {
709 Ok(event) => {
710 let matches = match &event {
711 StreamResponse::Task(t) => t.id == target_task_id,
712 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
713 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
714 StreamResponse::Message(m) => {
715 m.context_id.as_ref().is_some_and(|ctx| {
716 task_store.get(&target_task_id).now_or_never()
717 .flatten()
718 .is_some_and(|t| t.context_id == *ctx)
719 })
720 }
721 };
722
723 if matches {
724 let val = match &event {
726 StreamResponse::Task(t) => serde_json::to_value(t),
727 StreamResponse::Message(m) => serde_json::to_value(m),
728 StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
729 StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
730 };
731 if let Ok(val) = val {
732 yield Ok(Event::default().data(wrap(val)));
733 }
734
735 let is_terminal = match &event {
737 StreamResponse::Task(t) => t.status.state.is_terminal(),
738 StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
739 _ => false,
740 };
741 if is_terminal {
742 break;
743 }
744 }
745 }
746 Err(broadcast::error::RecvError::Lagged(_)) => continue,
747 Err(broadcast::error::RecvError::Closed) => break,
748 }
749 }
750 };
751
752 Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
753}
754
755async fn handle_tasks_get(
756 state: AppState,
757 req: JsonRpcRequest,
758) -> (StatusCode, Json<JsonRpcResponse>) {
759 let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
760
761 match params {
762 Ok(p) => match state.task_store.get_flexible(&p.id).await {
763 Some(mut task) => {
764 apply_history_length(&mut task, p.history_length);
765
766 match serde_json::to_value(task) {
767 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
768 Err(e) => (
769 StatusCode::INTERNAL_SERVER_ERROR,
770 Json(error(
771 req.id,
772 errors::INTERNAL_ERROR,
773 "serialization failed",
774 Some(serde_json::json!({"error": e.to_string()})),
775 )),
776 ),
777 }
778 }
779 None => (
780 StatusCode::NOT_FOUND,
781 Json(error(
782 req.id,
783 errors::TASK_NOT_FOUND,
784 "task not found",
785 None,
786 )),
787 ),
788 },
789 Err(err) => (
790 StatusCode::BAD_REQUEST,
791 Json(error(
792 req.id,
793 errors::INVALID_PARAMS,
794 "invalid params",
795 Some(serde_json::json!({"error": err.to_string()})),
796 )),
797 ),
798 }
799}
800
801async fn handle_tasks_cancel(
802 state: AppState,
803 req: JsonRpcRequest,
804) -> (StatusCode, Json<JsonRpcResponse>) {
805 let params: Result<CancelTaskRequest, _> =
806 serde_json::from_value(req.params.unwrap_or_default());
807
808 match params {
809 Ok(p) => {
810 let result = state
811 .task_store
812 .update_flexible(&p.id, |task| {
813 if task.status.state.is_terminal() {
814 return Err(errors::TASK_NOT_CANCELABLE);
815 }
816 task.status.state = TaskState::Canceled;
817 task.status.timestamp = Some(now_iso8601());
818 Ok(())
819 })
820 .await;
821
822 match result {
823 Some(Ok(task)) => {
824 if let Err(e) = state.handler.cancel_task(&task.id).await {
825 tracing::warn!("Handler cancel_task failed: {}", e);
826 }
827
828 state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
829 kind: "status-update".to_string(),
830 task_id: task.id.clone(),
831 context_id: task.context_id.clone(),
832 status: task.status.clone(),
833 is_final: true,
834 metadata: None,
835 }));
836
837 match serde_json::to_value(task) {
838 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
839 Err(e) => (
840 StatusCode::INTERNAL_SERVER_ERROR,
841 Json(error(
842 req.id,
843 errors::INTERNAL_ERROR,
844 "serialization failed",
845 Some(serde_json::json!({"error": e.to_string()})),
846 )),
847 ),
848 }
849 }
850 Some(Err(error_code)) => (
851 StatusCode::BAD_REQUEST,
852 Json(error(req.id, error_code, "task not cancelable", None)),
853 ),
854 None => (
855 StatusCode::NOT_FOUND,
856 Json(error(
857 req.id,
858 errors::TASK_NOT_FOUND,
859 "task not found",
860 None,
861 )),
862 ),
863 }
864 }
865 Err(err) => (
866 StatusCode::BAD_REQUEST,
867 Json(error(
868 req.id,
869 errors::INVALID_PARAMS,
870 "invalid params",
871 Some(serde_json::json!({"error": err.to_string()})),
872 )),
873 ),
874 }
875}
876
877async fn handle_tasks_list(
878 state: AppState,
879 req: JsonRpcRequest,
880) -> (StatusCode, Json<JsonRpcResponse>) {
881 let params: Result<ListTasksRequest, _> =
882 serde_json::from_value(req.params.unwrap_or_default());
883
884 match params {
885 Ok(p) => {
886 let response = state.task_store.list_filtered(&p).await;
887 match serde_json::to_value(response) {
888 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
889 Err(e) => (
890 StatusCode::INTERNAL_SERVER_ERROR,
891 Json(error(
892 req.id,
893 errors::INTERNAL_ERROR,
894 "serialization failed",
895 Some(serde_json::json!({"error": e.to_string()})),
896 )),
897 ),
898 }
899 }
900 Err(err) => (
901 StatusCode::BAD_REQUEST,
902 Json(error(
903 req.id,
904 errors::INVALID_PARAMS,
905 "invalid params",
906 Some(serde_json::json!({"error": err.to_string()})),
907 )),
908 ),
909 }
910}
911
912async fn handle_tasks_resubscribe(state: AppState, req: JsonRpcRequest) -> Response {
917 let req_id = req.id.clone();
918
919 let params: Result<SubscribeToTaskRequest, _> =
920 serde_json::from_value(req.params.unwrap_or_default());
921
922 let params = match params {
923 Ok(p) => p,
924 Err(err) => {
925 return (
926 StatusCode::BAD_REQUEST,
927 Json(error(
928 req_id,
929 errors::INVALID_PARAMS,
930 "invalid params",
931 Some(serde_json::json!({"error": err.to_string()})),
932 )),
933 )
934 .into_response();
935 }
936 };
937
938 let task = match state.task_store.get_flexible(¶ms.id).await {
940 Some(t) => t,
941 None => {
942 return (
943 StatusCode::NOT_FOUND,
944 Json(error(req_id, errors::TASK_NOT_FOUND, "task not found", None)),
945 )
946 .into_response();
947 }
948 };
949
950 let target_task_id = task.id.clone();
951 let mut rx = state.subscribe_events();
952 let task_store = state.task_store.clone();
953
954 let wrap = move |value: serde_json::Value| -> String {
955 serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
956 };
957
958 let stream = async_stream::stream! {
959 if let Ok(val) = serde_json::to_value(&task) {
961 yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
962 }
963
964 if task.status.state.is_terminal() {
966 return;
967 }
968
969 loop {
970 match rx.recv().await {
971 Ok(event) => {
972 let matches = match &event {
973 StreamResponse::Task(t) => t.id == target_task_id,
974 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
975 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
976 StreamResponse::Message(m) => {
977 m.context_id.as_ref().is_some_and(|ctx| {
978 task_store.get(&target_task_id).now_or_never()
979 .flatten()
980 .is_some_and(|t| t.context_id == *ctx)
981 })
982 }
983 };
984
985 if matches {
986 let val = match &event {
987 StreamResponse::Task(t) => serde_json::to_value(t),
988 StreamResponse::Message(m) => serde_json::to_value(m),
989 StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
990 StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
991 };
992 if let Ok(val) = val {
993 yield Ok(Event::default().data(wrap(val)));
994 }
995
996 let is_terminal = match &event {
997 StreamResponse::Task(t) => t.status.state.is_terminal(),
998 StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
999 _ => false,
1000 };
1001 if is_terminal {
1002 break;
1003 }
1004 }
1005 }
1006 Err(broadcast::error::RecvError::Lagged(_)) => continue,
1007 Err(broadcast::error::RecvError::Closed) => break,
1008 }
1009 }
1010 };
1011
1012 Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
1013}
1014
1015async fn handle_get_extended_agent_card(
1016 state: AppState,
1017 req: JsonRpcRequest,
1018 auth_context: Option<AuthContext>,
1019) -> (StatusCode, Json<JsonRpcResponse>) {
1020 let Some(auth) = auth_context else {
1021 return (
1022 StatusCode::UNAUTHORIZED,
1023 Json(error(
1024 req.id,
1025 errors::INVALID_REQUEST,
1026 "authentication required for extended agent card",
1027 None,
1028 )),
1029 );
1030 };
1031
1032 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1033
1034 match state.handler.extended_agent_card(base_url, &auth).await {
1035 Some(card) => match serde_json::to_value(card) {
1036 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1037 Err(e) => (
1038 StatusCode::INTERNAL_SERVER_ERROR,
1039 Json(error(
1040 req.id,
1041 errors::INTERNAL_ERROR,
1042 "serialization failed",
1043 Some(serde_json::json!({"error": e.to_string()})),
1044 )),
1045 ),
1046 },
1047 None => (
1048 StatusCode::NOT_FOUND,
1049 Json(error(
1050 req.id,
1051 errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1052 "extended agent card not configured",
1053 None,
1054 )),
1055 ),
1056 }
1057}
1058
1059async fn handle_push_config_create(
1062 state: AppState,
1063 req: JsonRpcRequest,
1064) -> (StatusCode, Json<JsonRpcResponse>) {
1065 if !state.push_notifications_enabled() {
1066 return (
1067 StatusCode::BAD_REQUEST,
1068 Json(error(
1069 req.id,
1070 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1071 "push notifications not supported",
1072 None,
1073 )),
1074 );
1075 }
1076
1077 let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1078 serde_json::from_value(req.params.unwrap_or_default());
1079
1080 match params {
1081 Ok(p) => {
1082 if state.task_store.get_flexible(&p.task_id).await.is_none() {
1083 return (
1084 StatusCode::NOT_FOUND,
1085 Json(error(
1086 req.id,
1087 errors::TASK_NOT_FOUND,
1088 "task not found",
1089 None,
1090 )),
1091 );
1092 }
1093
1094 if let Err(e) = state
1095 .webhook_store
1096 .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1097 .await
1098 {
1099 return (
1100 StatusCode::BAD_REQUEST,
1101 Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1102 );
1103 }
1104
1105 (
1106 StatusCode::OK,
1107 Json(success(
1108 req.id,
1109 serde_json::json!({
1110 "configId": p.config_id,
1111 "config": p.push_notification_config
1112 }),
1113 )),
1114 )
1115 }
1116 Err(err) => (
1117 StatusCode::BAD_REQUEST,
1118 Json(error(
1119 req.id,
1120 errors::INVALID_PARAMS,
1121 "invalid params",
1122 Some(serde_json::json!({"error": err.to_string()})),
1123 )),
1124 ),
1125 }
1126}
1127
1128async fn handle_push_config_get(
1129 state: AppState,
1130 req: JsonRpcRequest,
1131) -> (StatusCode, Json<JsonRpcResponse>) {
1132 if !state.push_notifications_enabled() {
1133 return (
1134 StatusCode::BAD_REQUEST,
1135 Json(error(
1136 req.id,
1137 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1138 "push notifications not supported",
1139 None,
1140 )),
1141 );
1142 }
1143
1144 let params: Result<GetTaskPushNotificationConfigRequest, _> =
1145 serde_json::from_value(req.params.unwrap_or_default());
1146
1147 match params {
1148 Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1149 Some(config) => (
1150 StatusCode::OK,
1151 Json(success(
1152 req.id,
1153 serde_json::json!({
1154 "configId": p.id,
1155 "config": config
1156 }),
1157 )),
1158 ),
1159 None => (
1160 StatusCode::NOT_FOUND,
1161 Json(error(
1162 req.id,
1163 errors::TASK_NOT_FOUND,
1164 "push notification config not found",
1165 None,
1166 )),
1167 ),
1168 },
1169 Err(err) => (
1170 StatusCode::BAD_REQUEST,
1171 Json(error(
1172 req.id,
1173 errors::INVALID_PARAMS,
1174 "invalid params",
1175 Some(serde_json::json!({"error": err.to_string()})),
1176 )),
1177 ),
1178 }
1179}
1180
1181async fn handle_push_config_list(
1182 state: AppState,
1183 req: JsonRpcRequest,
1184) -> (StatusCode, Json<JsonRpcResponse>) {
1185 if !state.push_notifications_enabled() {
1186 return (
1187 StatusCode::BAD_REQUEST,
1188 Json(error(
1189 req.id,
1190 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1191 "push notifications not supported",
1192 None,
1193 )),
1194 );
1195 }
1196
1197 let params: Result<ListTaskPushNotificationConfigRequest, _> =
1198 serde_json::from_value(req.params.unwrap_or_default());
1199
1200 match params {
1201 Ok(p) => {
1202 let configs = state.webhook_store.list(&p.task_id).await;
1203
1204 let configs_json: Vec<_> = configs
1205 .iter()
1206 .map(|c| {
1207 serde_json::json!({
1208 "configId": c.config_id,
1209 "config": c.config
1210 })
1211 })
1212 .collect();
1213
1214 (
1215 StatusCode::OK,
1216 Json(success(
1217 req.id,
1218 serde_json::json!({
1219 "configs": configs_json,
1220 "nextPageToken": ""
1221 }),
1222 )),
1223 )
1224 }
1225 Err(err) => (
1226 StatusCode::BAD_REQUEST,
1227 Json(error(
1228 req.id,
1229 errors::INVALID_PARAMS,
1230 "invalid params",
1231 Some(serde_json::json!({"error": err.to_string()})),
1232 )),
1233 ),
1234 }
1235}
1236
1237async fn handle_push_config_delete(
1238 state: AppState,
1239 req: JsonRpcRequest,
1240) -> (StatusCode, Json<JsonRpcResponse>) {
1241 if !state.push_notifications_enabled() {
1242 return (
1243 StatusCode::BAD_REQUEST,
1244 Json(error(
1245 req.id,
1246 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1247 "push notifications not supported",
1248 None,
1249 )),
1250 );
1251 }
1252
1253 let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1254 serde_json::from_value(req.params.unwrap_or_default());
1255
1256 match params {
1257 Ok(p) => {
1258 if state.webhook_store.delete(&p.task_id, &p.id).await {
1259 (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1260 } else {
1261 (
1262 StatusCode::NOT_FOUND,
1263 Json(error(
1264 req.id,
1265 errors::TASK_NOT_FOUND,
1266 "push notification config not found",
1267 None,
1268 )),
1269 )
1270 }
1271 }
1272 Err(err) => (
1273 StatusCode::BAD_REQUEST,
1274 Json(error(
1275 req.id,
1276 errors::INVALID_PARAMS,
1277 "invalid params",
1278 Some(serde_json::json!({"error": err.to_string()})),
1279 )),
1280 ),
1281 }
1282}
1283
1284pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1285 bind_addr: &str,
1286 handler: H,
1287) -> anyhow::Result<()> {
1288 A2aServer::new(handler).bind(bind_addr)?.run().await
1289}
1290
1291pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1292 A2aServer::echo().bind(bind_addr)?.run().await
1293}