1use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `message/send` waits for its task to reach a
/// terminal state before answering with the latest stored snapshot.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(300);
/// How often a blocking `message/send` re-reads the task store while waiting, as a
/// fallback for terminal events it did not observe on the broadcast channel.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17 error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18 CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19 GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20 ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21 SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22 TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Optional request header in which a client states the A2A protocol version it speaks.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the protocol version accepted in the `A2A-Version` header.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the protocol version accepted in the `A2A-Version` header.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::response::{IntoResponse, Response};
33use axum::routing::{get, post};
34use axum::{Json, Router};
35use futures::future::FutureExt;
36use tokio::sync::broadcast;
37use tracing::info;
38
39use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
40use crate::task_store::TaskStore;
41use crate::webhook_delivery::WebhookDelivery;
42use crate::webhook_store::WebhookStore;
43
/// Network settings for the A2A server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// `host:port` to listen on; also used to build the base URL passed to the handler
    /// when generating the public agent card.
    pub bind_address: String,
}

impl Default for ServerConfig {
    /// Listens on every interface, port 8080.
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        Self { bind_address }
    }
}
56
/// Callback that inspects request headers and returns the caller's [`AuthContext`].
/// Returning `None` means the request is treated as unauthenticated (for example,
/// `agentCard/getExtended` then answers with HTTP 401).
pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
58
/// Capacity of the server-wide event broadcast channel. A receiver that falls more than
/// this many events behind misses the oldest ones and sees `RecvError::Lagged`.
const EVENT_CHANNEL_CAPACITY: usize = 1024;
60
/// Builder-style A2A JSON-RPC server: configure it with the chained setters, then call
/// `run` (or `build_router` to mount the routes in another axum application).
pub struct A2aServer {
    /// Listen address.
    config: ServerConfig,
    /// Application logic: answers messages, cancels tasks and produces the agent card.
    handler: BoxedHandler,
    /// Tasks created by `message/send` / `message/stream`.
    task_store: TaskStore,
    /// Push-notification (webhook) configurations keyed by task.
    webhook_store: WebhookStore,
    /// Optional hook deriving an `AuthContext` from request headers.
    auth_extractor: Option<AuthExtractor>,
    /// Extra routes merged into the router after the request-timeout layer is applied.
    additional_routes: Option<Router<AppState>>,
    /// Event channel shared by SSE streams, blocking waits and webhook delivery.
    event_tx: broadcast::Sender<StreamResponse>,
}
70
71impl A2aServer {
72 pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
73 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
74 Self {
75 config: ServerConfig::default(),
76 handler: Arc::new(handler),
77 task_store: TaskStore::new(),
78 webhook_store: WebhookStore::new(),
79 auth_extractor: None,
80 additional_routes: None,
81 event_tx,
82 }
83 }
84
85 pub fn echo() -> Self {
86 Self::new(EchoHandler::default())
87 }
88
89 pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
90 let _: SocketAddr = address.parse()?;
91 self.config.bind_address = address.to_string();
92 Ok(self)
93 }
94
95 pub fn bind_unchecked(mut self, address: &str) -> Self {
96 self.config.bind_address = address.to_string();
97 self
98 }
99
100 pub fn task_store(mut self, store: TaskStore) -> Self {
101 self.task_store = store;
102 self
103 }
104
105 pub fn auth_extractor<F>(mut self, extractor: F) -> Self
106 where
107 F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
108 {
109 self.auth_extractor = Some(Arc::new(extractor));
110 self
111 }
112
113 pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
114 self.additional_routes = Some(routes);
115 self
116 }
117
118 pub fn get_task_store(&self) -> TaskStore {
119 self.task_store.clone()
120 }
121
122 pub fn build_router(self) -> Router {
123 let bind: SocketAddr = self
124 .config
125 .bind_address
126 .parse()
127 .expect("Invalid bind address");
128 let base_url = format!("http://{}", bind);
129 let card = Arc::new(self.handler.agent_card(&base_url));
130
131 let state = AppState {
132 handler: self.handler,
133 task_store: self.task_store,
134 webhook_store: self.webhook_store,
135 card,
136 auth_extractor: self.auth_extractor,
137 event_tx: self.event_tx,
138 };
139
140 let timed_routes = Router::new()
141 .route("/health", get(health))
142 .route("/.well-known/agent-card.json", get(agent_card))
143 .route("/v1/rpc", post(handle_rpc))
144 .layer(
145 ServiceBuilder::new()
146 .layer(HandleErrorLayer::new(handle_timeout_error))
147 .timeout(Duration::from_secs(30)),
148 );
149
150 let mut router = timed_routes;
151
152 if let Some(additional) = self.additional_routes {
153 router = router.merge(additional);
154 }
155
156 router.with_state(state)
157 }
158
159 pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
160 self.event_tx.clone()
161 }
162
163 pub async fn run(self) -> anyhow::Result<()> {
164 let bind: SocketAddr = self.config.bind_address.parse()?;
165
166 let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
167 webhook_delivery.start(self.event_tx.subscribe());
168
169 let router = self.build_router();
170
171 info!(%bind, "Starting A2A server");
172 let listener = tokio::net::TcpListener::bind(bind).await?;
173 axum::serve(listener, router)
174 .with_graceful_shutdown(shutdown_signal())
175 .await?;
176
177 info!("Server shutdown complete");
178 Ok(())
179 }
180}
181
182async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
183 if err.is::<tower::timeout::error::Elapsed>() {
184 (
185 StatusCode::REQUEST_TIMEOUT,
186 Json(error(
187 serde_json::Value::Null,
188 errors::INTERNAL_ERROR,
189 "Request timed out",
190 None,
191 )),
192 )
193 } else {
194 (
195 StatusCode::INTERNAL_SERVER_ERROR,
196 Json(error(
197 serde_json::Value::Null,
198 errors::INTERNAL_ERROR,
199 &format!("Internal error: {}", err),
200 None,
201 )),
202 )
203 }
204}
205
206async fn shutdown_signal() {
207 let ctrl_c = async {
208 signal::ctrl_c()
209 .await
210 .expect("failed to install Ctrl+C handler");
211 };
212
213 #[cfg(unix)]
214 let terminate = async {
215 signal::unix::signal(signal::unix::SignalKind::terminate())
216 .expect("failed to install SIGTERM handler")
217 .recv()
218 .await;
219 };
220
221 #[cfg(not(unix))]
222 let terminate = std::future::pending::<()>();
223
224 tokio::select! {
225 _ = ctrl_c => {},
226 _ = terminate => {},
227 }
228
229 info!("Shutdown signal received, initiating graceful shutdown...");
230}
231
232#[derive(Clone)]
233pub struct AppState {
234 handler: BoxedHandler,
235 task_store: TaskStore,
236 webhook_store: WebhookStore,
237 card: Arc<AgentCard>,
238 auth_extractor: Option<AuthExtractor>,
239 event_tx: broadcast::Sender<StreamResponse>,
240}
241
242impl AppState {
243 pub fn task_store(&self) -> &TaskStore {
244 &self.task_store
245 }
246
247 pub fn agent_card(&self) -> &AgentCard {
248 &self.card
249 }
250
251 pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
252 &self.event_tx
253 }
254
255 pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
256 self.event_tx.subscribe()
257 }
258
259 pub fn broadcast_event(&self, event: StreamResponse) {
260 let _ = self.event_tx.send(event);
261 }
262
263 fn streaming_enabled(&self) -> bool {
265 self.card.capabilities.streaming.unwrap_or(false)
266 }
267
268 fn push_notifications_enabled(&self) -> bool {
270 self.card.capabilities.push_notifications.unwrap_or(false)
271 }
272
273 fn endpoint_url(&self) -> &str {
275 self.card.endpoint().unwrap_or("")
276 }
277}
278
279fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
282 match history_length {
283 Some(0) => {
284 task.history = None;
285 }
286 Some(n) => {
287 if let Some(ref mut history) = task.history {
288 let len = history.len();
289 if len > n as usize {
290 *history = history.split_off(len - n as usize);
291 }
292 }
293 }
294 None => {}
295 }
296}
297
298fn validate_a2a_version(
301 headers: &HeaderMap,
302 req_id: &serde_json::Value,
303) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
304 if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
305 let version_str = version_header.to_str().unwrap_or("");
306
307 let parts: Vec<&str> = version_str.split('.').collect();
308 if parts.len() >= 2 {
309 if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
310 if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
311 return Ok(());
312 }
313
314 return Err((
315 StatusCode::BAD_REQUEST,
316 Json(error(
317 req_id.clone(),
318 errors::VERSION_NOT_SUPPORTED,
319 &format!(
320 "Protocol version {}.{} not supported. Supported version: {}.{}",
321 major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
322 ),
323 Some(serde_json::json!({
324 "requestedVersion": version_str,
325 "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
326 })),
327 )),
328 ));
329 }
330 }
331
332 if version_str == "1" || version_str == "1.0" {
334 return Ok(());
335 }
336
337 return Err((
338 StatusCode::BAD_REQUEST,
339 Json(error(
340 req_id.clone(),
341 errors::VERSION_NOT_SUPPORTED,
342 &format!(
343 "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
344 version_str
345 ),
346 None,
347 )),
348 ));
349 }
350
351 Ok(())
352}
353
354#[allow(dead_code)]
357pub fn rpc_error(
358 id: serde_json::Value,
359 code: i32,
360 message: &str,
361 status: StatusCode,
362) -> (StatusCode, Json<JsonRpcResponse>) {
363 (status, Json(error(id, code, message, None)))
364}
365
366#[allow(dead_code)]
367pub fn rpc_error_with_data(
368 id: serde_json::Value,
369 code: i32,
370 message: &str,
371 data: serde_json::Value,
372 status: StatusCode,
373) -> (StatusCode, Json<JsonRpcResponse>) {
374 (status, Json(error(id, code, message, Some(data))))
375}
376
377#[allow(dead_code)]
378pub fn rpc_success(
379 id: serde_json::Value,
380 result: serde_json::Value,
381) -> (StatusCode, Json<JsonRpcResponse>) {
382 (StatusCode::OK, Json(success(id, result)))
383}
384
385async fn health() -> Json<serde_json::Value> {
388 Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
389}
390
391async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
392 Json((*state.card).clone())
393}
394
395async fn handle_rpc(
399 State(state): State<AppState>,
400 headers: HeaderMap,
401 Json(req): Json<JsonRpcRequest>,
402) -> Response {
403 if req.jsonrpc != "2.0" {
404 let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
405 return (StatusCode::BAD_REQUEST, Json(resp)).into_response();
406 }
407
408 if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
409 return err_response.into_response();
410 }
411
412 let auth_context = state
413 .auth_extractor
414 .as_ref()
415 .and_then(|extractor| extractor(&headers));
416
417 match req.method.as_str() {
418 "message/send" => handle_message_send(state, req, auth_context)
420 .await
421 .into_response(),
422 "message/stream" => {
423 handle_message_stream(state, req, headers, auth_context)
424 .await
425 .into_response()
426 }
427 "tasks/get" => handle_tasks_get(state, req).await.into_response(),
428 "tasks/list" => handle_tasks_list(state, req).await.into_response(),
429 "tasks/cancel" => handle_tasks_cancel(state, req).await.into_response(),
430 "tasks/resubscribe" => handle_tasks_resubscribe(state, req).await.into_response(),
431 "tasks/pushNotificationConfig/create" => {
432 handle_push_config_create(state, req).await.into_response()
433 }
434 "tasks/pushNotificationConfig/get" => {
435 handle_push_config_get(state, req).await.into_response()
436 }
437 "tasks/pushNotificationConfig/list" => {
438 handle_push_config_list(state, req).await.into_response()
439 }
440 "tasks/pushNotificationConfig/delete" => {
441 handle_push_config_delete(state, req).await.into_response()
442 }
443 "agentCard/getExtended" => handle_get_extended_agent_card(state, req, auth_context)
444 .await
445 .into_response(),
446 _ => (
447 StatusCode::NOT_FOUND,
448 Json(error(
449 req.id,
450 errors::METHOD_NOT_FOUND,
451 "method not found",
452 None,
453 )),
454 )
455 .into_response(),
456 }
457}
458
459fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
460 match e {
461 HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
462 HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
463 HandlerError::BackendUnavailable { .. } => {
464 (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
465 }
466 HandlerError::ProcessingFailed { .. } => {
467 (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
468 }
469 HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
470 }
471}
472
473async fn handle_message_send(
474 state: AppState,
475 req: JsonRpcRequest,
476 auth_context: Option<AuthContext>,
477) -> (StatusCode, Json<JsonRpcResponse>) {
478 let req_id = req.id.clone();
479
480 let params: Result<SendMessageRequest, _> =
481 serde_json::from_value(req.params.clone().unwrap_or_default());
482
483 let params = match params {
484 Ok(p) => p,
485 Err(err) => {
486 return (
487 StatusCode::BAD_REQUEST,
488 Json(error(
489 req_id,
490 errors::INVALID_PARAMS,
491 "invalid params",
492 Some(serde_json::json!({"error": err.to_string()})),
493 )),
494 );
495 }
496 };
497
498 let blocking = params
499 .configuration
500 .as_ref()
501 .and_then(|c| c.blocking)
502 .unwrap_or(false);
503 let return_immediately = params
504 .configuration
505 .as_ref()
506 .and_then(|c| c.return_immediately)
507 .unwrap_or(false);
508 let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
509
510 match state
511 .handler
512 .handle_message(params.message, auth_context)
513 .await
514 {
515 Ok(response) => {
516 match response {
517 SendMessageResponse::Task(mut task) => {
518 state.task_store.insert(task.clone()).await;
520 state.broadcast_event(StreamResponse::Task(task.clone()));
521
522 if blocking && !return_immediately && !task.status.state.is_terminal() {
524 let task_id = task.id.clone();
525 let mut rx = state.subscribe_events();
526
527 let wait_result = timeout(BLOCKING_TIMEOUT, async {
528 loop {
529 tokio::select! {
530 result = rx.recv() => {
531 match result {
532 Ok(StreamResponse::Task(t)) if t.id == task_id => {
533 if t.status.state.is_terminal() {
534 return Some(t);
535 }
536 }
537 Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
538 if e.status.state.is_terminal() {
539 if let Some(t) = state.task_store.get(&task_id).await {
540 return Some(t);
541 }
542 }
543 }
544 Err(broadcast::error::RecvError::Closed) => {
545 return None;
546 }
547 _ => {}
548 }
549 }
550 _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
551 if let Some(t) = state.task_store.get(&task_id).await {
552 if t.status.state.is_terminal() {
553 return Some(t);
554 }
555 }
556 }
557 }
558 }
559 })
560 .await;
561
562 match wait_result {
563 Ok(Some(final_task)) => task = final_task,
564 Ok(None) => {
565 if let Some(t) = state.task_store.get(&task.id).await {
566 task = t;
567 }
568 }
569 Err(_) => {
570 tracing::warn!("Blocking request timed out for task {}", task.id);
571 if let Some(t) = state.task_store.get(&task.id).await {
572 task = t;
573 }
574 }
575 }
576 }
577
578 apply_history_length(&mut task, history_length);
579
580 match serde_json::to_value(&task) {
583 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
584 Err(e) => (
585 StatusCode::INTERNAL_SERVER_ERROR,
586 Json(error(
587 req_id,
588 errors::INTERNAL_ERROR,
589 "serialization failed",
590 Some(serde_json::json!({"error": e.to_string()})),
591 )),
592 ),
593 }
594 }
595 SendMessageResponse::Message(msg) => {
596 match serde_json::to_value(&msg) {
598 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
599 Err(e) => (
600 StatusCode::INTERNAL_SERVER_ERROR,
601 Json(error(
602 req_id,
603 errors::INTERNAL_ERROR,
604 "serialization failed",
605 Some(serde_json::json!({"error": e.to_string()})),
606 )),
607 ),
608 }
609 }
610 }
611 }
612 Err(e) => {
613 let (code, status) = handler_error_to_rpc(&e);
614 (status, Json(error(req_id, code, &e.to_string(), None)))
615 }
616 }
617}
618
619async fn handle_message_stream(
624 state: AppState,
625 req: JsonRpcRequest,
626 _headers: HeaderMap,
627 auth_context: Option<AuthContext>,
628) -> Response {
629 let req_id = req.id.clone();
630
631 if !state.streaming_enabled() {
632 return (
633 StatusCode::BAD_REQUEST,
634 Json(error(
635 req_id,
636 errors::UNSUPPORTED_OPERATION,
637 "streaming not supported by this agent",
638 None,
639 )),
640 )
641 .into_response();
642 }
643
644 let params: Result<SendMessageRequest, _> =
645 serde_json::from_value(req.params.clone().unwrap_or_default());
646
647 let params = match params {
648 Ok(p) => p,
649 Err(err) => {
650 return (
651 StatusCode::BAD_REQUEST,
652 Json(error(
653 req_id,
654 errors::INVALID_PARAMS,
655 "invalid params",
656 Some(serde_json::json!({"error": err.to_string()})),
657 )),
658 )
659 .into_response();
660 }
661 };
662
663 let response = match state
664 .handler
665 .handle_message(params.message, auth_context)
666 .await
667 {
668 Ok(r) => r,
669 Err(e) => {
670 let (code, status) = handler_error_to_rpc(&e);
671 return (status, Json(error(req_id, code, &e.to_string(), None))).into_response();
672 }
673 };
674
675 let task = match response {
677 SendMessageResponse::Task(t) => t,
678 SendMessageResponse::Message(_) => {
679 return (
680 StatusCode::BAD_REQUEST,
681 Json(error(
682 req_id,
683 errors::UNSUPPORTED_OPERATION,
684 "handler returned a message, streaming requires a task",
685 None,
686 )),
687 )
688 .into_response();
689 }
690 };
691
692 let task_id = task.id.clone();
693 state.task_store.insert(task.clone()).await;
694 state.broadcast_event(StreamResponse::Task(task.clone()));
695
696 let mut rx = state.subscribe_events();
697 let task_store = state.task_store.clone();
698 let target_task_id = task_id;
699
700 let wrap = move |value: serde_json::Value| -> String {
702 serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
703 };
704
705 let stream = async_stream::stream! {
706 if let Ok(val) = serde_json::to_value(&task) {
708 yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
709 }
710
711 loop {
712 match rx.recv().await {
713 Ok(event) => {
714 let matches = match &event {
715 StreamResponse::Task(t) => t.id == target_task_id,
716 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
717 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
718 StreamResponse::Message(m) => {
719 m.context_id.as_ref().is_some_and(|ctx| {
720 task_store.get(&target_task_id).now_or_never()
721 .flatten()
722 .is_some_and(|t| t.context_id == *ctx)
723 })
724 }
725 };
726
727 if matches {
728 let val = match &event {
730 StreamResponse::Task(t) => serde_json::to_value(t),
731 StreamResponse::Message(m) => serde_json::to_value(m),
732 StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
733 StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
734 };
735 if let Ok(val) = val {
736 yield Ok(Event::default().data(wrap(val)));
737 }
738
739 let is_terminal = match &event {
741 StreamResponse::Task(t) => t.status.state.is_terminal(),
742 StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
743 _ => false,
744 };
745 if is_terminal {
746 break;
747 }
748 }
749 }
750 Err(broadcast::error::RecvError::Lagged(_)) => continue,
751 Err(broadcast::error::RecvError::Closed) => break,
752 }
753 }
754 };
755
756 Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
757}
758
759async fn handle_tasks_get(
760 state: AppState,
761 req: JsonRpcRequest,
762) -> (StatusCode, Json<JsonRpcResponse>) {
763 let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
764
765 match params {
766 Ok(p) => match state.task_store.get_flexible(&p.id).await {
767 Some(mut task) => {
768 apply_history_length(&mut task, p.history_length);
769
770 match serde_json::to_value(task) {
771 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
772 Err(e) => (
773 StatusCode::INTERNAL_SERVER_ERROR,
774 Json(error(
775 req.id,
776 errors::INTERNAL_ERROR,
777 "serialization failed",
778 Some(serde_json::json!({"error": e.to_string()})),
779 )),
780 ),
781 }
782 }
783 None => (
784 StatusCode::NOT_FOUND,
785 Json(error(
786 req.id,
787 errors::TASK_NOT_FOUND,
788 "task not found",
789 None,
790 )),
791 ),
792 },
793 Err(err) => (
794 StatusCode::BAD_REQUEST,
795 Json(error(
796 req.id,
797 errors::INVALID_PARAMS,
798 "invalid params",
799 Some(serde_json::json!({"error": err.to_string()})),
800 )),
801 ),
802 }
803}
804
805async fn handle_tasks_cancel(
806 state: AppState,
807 req: JsonRpcRequest,
808) -> (StatusCode, Json<JsonRpcResponse>) {
809 let params: Result<CancelTaskRequest, _> =
810 serde_json::from_value(req.params.unwrap_or_default());
811
812 match params {
813 Ok(p) => {
814 let result = state
815 .task_store
816 .update_flexible(&p.id, |task| {
817 if task.status.state.is_terminal() {
818 return Err(errors::TASK_NOT_CANCELABLE);
819 }
820 task.status.state = TaskState::Canceled;
821 task.status.timestamp = Some(now_iso8601());
822 Ok(())
823 })
824 .await;
825
826 match result {
827 Some(Ok(task)) => {
828 if let Err(e) = state.handler.cancel_task(&task.id).await {
829 tracing::warn!("Handler cancel_task failed: {}", e);
830 }
831
832 state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
833 kind: "status-update".to_string(),
834 task_id: task.id.clone(),
835 context_id: task.context_id.clone(),
836 status: task.status.clone(),
837 is_final: true,
838 metadata: None,
839 }));
840
841 match serde_json::to_value(task) {
842 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
843 Err(e) => (
844 StatusCode::INTERNAL_SERVER_ERROR,
845 Json(error(
846 req.id,
847 errors::INTERNAL_ERROR,
848 "serialization failed",
849 Some(serde_json::json!({"error": e.to_string()})),
850 )),
851 ),
852 }
853 }
854 Some(Err(error_code)) => (
855 StatusCode::BAD_REQUEST,
856 Json(error(req.id, error_code, "task not cancelable", None)),
857 ),
858 None => (
859 StatusCode::NOT_FOUND,
860 Json(error(
861 req.id,
862 errors::TASK_NOT_FOUND,
863 "task not found",
864 None,
865 )),
866 ),
867 }
868 }
869 Err(err) => (
870 StatusCode::BAD_REQUEST,
871 Json(error(
872 req.id,
873 errors::INVALID_PARAMS,
874 "invalid params",
875 Some(serde_json::json!({"error": err.to_string()})),
876 )),
877 ),
878 }
879}
880
881async fn handle_tasks_list(
882 state: AppState,
883 req: JsonRpcRequest,
884) -> (StatusCode, Json<JsonRpcResponse>) {
885 let params: Result<ListTasksRequest, _> =
886 serde_json::from_value(req.params.unwrap_or_default());
887
888 match params {
889 Ok(p) => {
890 let response = state.task_store.list_filtered(&p).await;
891 match serde_json::to_value(response) {
892 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
893 Err(e) => (
894 StatusCode::INTERNAL_SERVER_ERROR,
895 Json(error(
896 req.id,
897 errors::INTERNAL_ERROR,
898 "serialization failed",
899 Some(serde_json::json!({"error": e.to_string()})),
900 )),
901 ),
902 }
903 }
904 Err(err) => (
905 StatusCode::BAD_REQUEST,
906 Json(error(
907 req.id,
908 errors::INVALID_PARAMS,
909 "invalid params",
910 Some(serde_json::json!({"error": err.to_string()})),
911 )),
912 ),
913 }
914}
915
916async fn handle_tasks_resubscribe(state: AppState, req: JsonRpcRequest) -> Response {
921 let req_id = req.id.clone();
922
923 let params: Result<SubscribeToTaskRequest, _> =
924 serde_json::from_value(req.params.unwrap_or_default());
925
926 let params = match params {
927 Ok(p) => p,
928 Err(err) => {
929 return (
930 StatusCode::BAD_REQUEST,
931 Json(error(
932 req_id,
933 errors::INVALID_PARAMS,
934 "invalid params",
935 Some(serde_json::json!({"error": err.to_string()})),
936 )),
937 )
938 .into_response();
939 }
940 };
941
942 let task = match state.task_store.get_flexible(¶ms.id).await {
944 Some(t) => t,
945 None => {
946 return (
947 StatusCode::NOT_FOUND,
948 Json(error(req_id, errors::TASK_NOT_FOUND, "task not found", None)),
949 )
950 .into_response();
951 }
952 };
953
954 let target_task_id = task.id.clone();
955 let mut rx = state.subscribe_events();
956 let task_store = state.task_store.clone();
957
958 let wrap = move |value: serde_json::Value| -> String {
959 serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
960 };
961
962 let stream = async_stream::stream! {
963 if let Ok(val) = serde_json::to_value(&task) {
965 yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
966 }
967
968 if task.status.state.is_terminal() {
970 return;
971 }
972
973 loop {
974 match rx.recv().await {
975 Ok(event) => {
976 let matches = match &event {
977 StreamResponse::Task(t) => t.id == target_task_id,
978 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
979 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
980 StreamResponse::Message(m) => {
981 m.context_id.as_ref().is_some_and(|ctx| {
982 task_store.get(&target_task_id).now_or_never()
983 .flatten()
984 .is_some_and(|t| t.context_id == *ctx)
985 })
986 }
987 };
988
989 if matches {
990 let val = match &event {
991 StreamResponse::Task(t) => serde_json::to_value(t),
992 StreamResponse::Message(m) => serde_json::to_value(m),
993 StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
994 StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
995 };
996 if let Ok(val) = val {
997 yield Ok(Event::default().data(wrap(val)));
998 }
999
1000 let is_terminal = match &event {
1001 StreamResponse::Task(t) => t.status.state.is_terminal(),
1002 StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
1003 _ => false,
1004 };
1005 if is_terminal {
1006 break;
1007 }
1008 }
1009 }
1010 Err(broadcast::error::RecvError::Lagged(_)) => continue,
1011 Err(broadcast::error::RecvError::Closed) => break,
1012 }
1013 }
1014 };
1015
1016 Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
1017}
1018
1019async fn handle_get_extended_agent_card(
1020 state: AppState,
1021 req: JsonRpcRequest,
1022 auth_context: Option<AuthContext>,
1023) -> (StatusCode, Json<JsonRpcResponse>) {
1024 let Some(auth) = auth_context else {
1025 return (
1026 StatusCode::UNAUTHORIZED,
1027 Json(error(
1028 req.id,
1029 errors::INVALID_REQUEST,
1030 "authentication required for extended agent card",
1031 None,
1032 )),
1033 );
1034 };
1035
1036 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1037
1038 match state.handler.extended_agent_card(base_url, &auth).await {
1039 Some(card) => match serde_json::to_value(card) {
1040 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1041 Err(e) => (
1042 StatusCode::INTERNAL_SERVER_ERROR,
1043 Json(error(
1044 req.id,
1045 errors::INTERNAL_ERROR,
1046 "serialization failed",
1047 Some(serde_json::json!({"error": e.to_string()})),
1048 )),
1049 ),
1050 },
1051 None => (
1052 StatusCode::NOT_FOUND,
1053 Json(error(
1054 req.id,
1055 errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1056 "extended agent card not configured",
1057 None,
1058 )),
1059 ),
1060 }
1061}
1062
1063async fn handle_push_config_create(
1066 state: AppState,
1067 req: JsonRpcRequest,
1068) -> (StatusCode, Json<JsonRpcResponse>) {
1069 if !state.push_notifications_enabled() {
1070 return (
1071 StatusCode::BAD_REQUEST,
1072 Json(error(
1073 req.id,
1074 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1075 "push notifications not supported",
1076 None,
1077 )),
1078 );
1079 }
1080
1081 let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1082 serde_json::from_value(req.params.unwrap_or_default());
1083
1084 match params {
1085 Ok(p) => {
1086 if state.task_store.get_flexible(&p.task_id).await.is_none() {
1087 return (
1088 StatusCode::NOT_FOUND,
1089 Json(error(
1090 req.id,
1091 errors::TASK_NOT_FOUND,
1092 "task not found",
1093 None,
1094 )),
1095 );
1096 }
1097
1098 if let Err(e) = state
1099 .webhook_store
1100 .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1101 .await
1102 {
1103 return (
1104 StatusCode::BAD_REQUEST,
1105 Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1106 );
1107 }
1108
1109 (
1110 StatusCode::OK,
1111 Json(success(
1112 req.id,
1113 serde_json::json!({
1114 "configId": p.config_id,
1115 "config": p.push_notification_config
1116 }),
1117 )),
1118 )
1119 }
1120 Err(err) => (
1121 StatusCode::BAD_REQUEST,
1122 Json(error(
1123 req.id,
1124 errors::INVALID_PARAMS,
1125 "invalid params",
1126 Some(serde_json::json!({"error": err.to_string()})),
1127 )),
1128 ),
1129 }
1130}
1131
1132async fn handle_push_config_get(
1133 state: AppState,
1134 req: JsonRpcRequest,
1135) -> (StatusCode, Json<JsonRpcResponse>) {
1136 if !state.push_notifications_enabled() {
1137 return (
1138 StatusCode::BAD_REQUEST,
1139 Json(error(
1140 req.id,
1141 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1142 "push notifications not supported",
1143 None,
1144 )),
1145 );
1146 }
1147
1148 let params: Result<GetTaskPushNotificationConfigRequest, _> =
1149 serde_json::from_value(req.params.unwrap_or_default());
1150
1151 match params {
1152 Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1153 Some(config) => (
1154 StatusCode::OK,
1155 Json(success(
1156 req.id,
1157 serde_json::json!({
1158 "configId": p.id,
1159 "config": config
1160 }),
1161 )),
1162 ),
1163 None => (
1164 StatusCode::NOT_FOUND,
1165 Json(error(
1166 req.id,
1167 errors::TASK_NOT_FOUND,
1168 "push notification config not found",
1169 None,
1170 )),
1171 ),
1172 },
1173 Err(err) => (
1174 StatusCode::BAD_REQUEST,
1175 Json(error(
1176 req.id,
1177 errors::INVALID_PARAMS,
1178 "invalid params",
1179 Some(serde_json::json!({"error": err.to_string()})),
1180 )),
1181 ),
1182 }
1183}
1184
1185async fn handle_push_config_list(
1186 state: AppState,
1187 req: JsonRpcRequest,
1188) -> (StatusCode, Json<JsonRpcResponse>) {
1189 if !state.push_notifications_enabled() {
1190 return (
1191 StatusCode::BAD_REQUEST,
1192 Json(error(
1193 req.id,
1194 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1195 "push notifications not supported",
1196 None,
1197 )),
1198 );
1199 }
1200
1201 let params: Result<ListTaskPushNotificationConfigRequest, _> =
1202 serde_json::from_value(req.params.unwrap_or_default());
1203
1204 match params {
1205 Ok(p) => {
1206 let configs = state.webhook_store.list(&p.task_id).await;
1207
1208 let configs_json: Vec<_> = configs
1209 .iter()
1210 .map(|c| {
1211 serde_json::json!({
1212 "configId": c.config_id,
1213 "config": c.config
1214 })
1215 })
1216 .collect();
1217
1218 (
1219 StatusCode::OK,
1220 Json(success(
1221 req.id,
1222 serde_json::json!({
1223 "configs": configs_json,
1224 "nextPageToken": ""
1225 }),
1226 )),
1227 )
1228 }
1229 Err(err) => (
1230 StatusCode::BAD_REQUEST,
1231 Json(error(
1232 req.id,
1233 errors::INVALID_PARAMS,
1234 "invalid params",
1235 Some(serde_json::json!({"error": err.to_string()})),
1236 )),
1237 ),
1238 }
1239}
1240
1241async fn handle_push_config_delete(
1242 state: AppState,
1243 req: JsonRpcRequest,
1244) -> (StatusCode, Json<JsonRpcResponse>) {
1245 if !state.push_notifications_enabled() {
1246 return (
1247 StatusCode::BAD_REQUEST,
1248 Json(error(
1249 req.id,
1250 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1251 "push notifications not supported",
1252 None,
1253 )),
1254 );
1255 }
1256
1257 let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1258 serde_json::from_value(req.params.unwrap_or_default());
1259
1260 match params {
1261 Ok(p) => {
1262 if state.webhook_store.delete(&p.task_id, &p.id).await {
1263 (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1264 } else {
1265 (
1266 StatusCode::NOT_FOUND,
1267 Json(error(
1268 req.id,
1269 errors::TASK_NOT_FOUND,
1270 "push notification config not found",
1271 None,
1272 )),
1273 )
1274 }
1275 }
1276 Err(err) => (
1277 StatusCode::BAD_REQUEST,
1278 Json(error(
1279 req.id,
1280 errors::INVALID_PARAMS,
1281 "invalid params",
1282 Some(serde_json::json!({"error": err.to_string()})),
1283 )),
1284 ),
1285 }
1286}
1287
1288pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1289 bind_addr: &str,
1290 handler: H,
1291) -> anyhow::Result<()> {
1292 A2aServer::new(handler).bind(bind_addr)?.run().await
1293}
1294
1295pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1296 A2aServer::echo().bind(bind_addr)?.run().await
1297}