1use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `SendMessage` call waits for its task to
/// reach a terminal state before answering with the latest stored task.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(300);
/// Delay between task-store polls while a blocking `SendMessage` call waits.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17 error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18 CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19 GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20 ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21 SendMessageResponse, SendMessageResult, StreamResponse, StreamingMessageResult,
22 SubscribeToTaskRequest, Task, TaskState, TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Name of the request header inspected by `validate_a2a_version`.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the only protocol version `validate_a2a_version` accepts.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the only protocol version `validate_a2a_version` accepts.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::response::{IntoResponse, Response};
33use axum::routing::{get, post};
34use axum::{Json, Router};
35use futures::future::FutureExt;
36use tokio::sync::broadcast;
37use tracing::info;
38
39use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
40use crate::task_store::TaskStore;
41use crate::webhook_delivery::WebhookDelivery;
42use crate::webhook_store::WebhookStore;
43
/// Listener settings for the server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address the server binds to, written as `ip:port`.
    pub bind_address: String,
}

impl Default for ServerConfig {
    /// Binds to every IPv4 interface on port 8080.
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        ServerConfig { bind_address }
    }
}
56
/// Shared callback that inspects the request headers and optionally produces
/// an [`AuthContext`]; `handle_rpc` calls it once per request when configured.
pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;

/// Capacity passed to `broadcast::channel` when `A2aServer::new` creates the
/// event channel.
const EVENT_CHANNEL_CAPACITY: usize = 1024;
60
/// Builder-style A2A server: collects the handler, stores, and options, and is
/// then turned into a router by `build_router` or served directly by `run`.
pub struct A2aServer {
    /// Listener settings (the bind address).
    config: ServerConfig,
    /// Handler that processes incoming messages and supplies the agent card.
    handler: BoxedHandler,
    /// Store holding the tasks returned by the handler.
    task_store: TaskStore,
    /// Store holding push-notification (webhook) configurations per task.
    webhook_store: WebhookStore,
    /// Optional callback that derives an `AuthContext` from request headers.
    auth_extractor: Option<AuthExtractor>,
    /// Optional caller-supplied routes merged into the router.
    additional_routes: Option<Router<AppState>>,
    /// Sending half of the broadcast channel carrying task/stream events.
    event_tx: broadcast::Sender<StreamResponse>,
}
70
71impl A2aServer {
72 pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
73 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
74 Self {
75 config: ServerConfig::default(),
76 handler: Arc::new(handler),
77 task_store: TaskStore::new(),
78 webhook_store: WebhookStore::new(),
79 auth_extractor: None,
80 additional_routes: None,
81 event_tx,
82 }
83 }
84
85 pub fn echo() -> Self {
86 Self::new(EchoHandler::default())
87 }
88
89 pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
90 let _: SocketAddr = address.parse()?;
91 self.config.bind_address = address.to_string();
92 Ok(self)
93 }
94
95 pub fn bind_unchecked(mut self, address: &str) -> Self {
96 self.config.bind_address = address.to_string();
97 self
98 }
99
100 pub fn task_store(mut self, store: TaskStore) -> Self {
101 self.task_store = store;
102 self
103 }
104
105 pub fn auth_extractor<F>(mut self, extractor: F) -> Self
106 where
107 F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
108 {
109 self.auth_extractor = Some(Arc::new(extractor));
110 self
111 }
112
113 pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
114 self.additional_routes = Some(routes);
115 self
116 }
117
118 pub fn get_task_store(&self) -> TaskStore {
119 self.task_store.clone()
120 }
121
122 pub fn build_router(self) -> Router {
123 let bind: SocketAddr = self
124 .config
125 .bind_address
126 .parse()
127 .expect("Invalid bind address");
128 let base_url = format!("http://{}", bind);
129 let card = Arc::new(self.handler.agent_card(&base_url));
130
131 let state = AppState {
132 handler: self.handler,
133 task_store: self.task_store,
134 webhook_store: self.webhook_store,
135 card,
136 auth_extractor: self.auth_extractor,
137 event_tx: self.event_tx,
138 };
139
140 let timed_routes = Router::new()
141 .route("/health", get(health))
142 .route("/.well-known/agent-card.json", get(agent_card))
143 .route("/v1/rpc", post(handle_rpc))
144 .layer(
145 ServiceBuilder::new()
146 .layer(HandleErrorLayer::new(handle_timeout_error))
147 .timeout(Duration::from_secs(30)),
148 );
149
150 let mut router = timed_routes;
151
152 if let Some(additional) = self.additional_routes {
153 router = router.merge(additional);
154 }
155
156 router.with_state(state)
157 }
158
159 pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
160 self.event_tx.clone()
161 }
162
163 pub async fn run(self) -> anyhow::Result<()> {
164 let bind: SocketAddr = self.config.bind_address.parse()?;
165
166 let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
167 webhook_delivery.start(self.event_tx.subscribe());
168
169 let router = self.build_router();
170
171 info!(%bind, "Starting A2A server");
172 let listener = tokio::net::TcpListener::bind(bind).await?;
173 axum::serve(listener, router)
174 .with_graceful_shutdown(shutdown_signal())
175 .await?;
176
177 info!("Server shutdown complete");
178 Ok(())
179 }
180}
181
182async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
183 if err.is::<tower::timeout::error::Elapsed>() {
184 (
185 StatusCode::REQUEST_TIMEOUT,
186 Json(error(
187 serde_json::Value::Null,
188 errors::INTERNAL_ERROR,
189 "Request timed out",
190 None,
191 )),
192 )
193 } else {
194 (
195 StatusCode::INTERNAL_SERVER_ERROR,
196 Json(error(
197 serde_json::Value::Null,
198 errors::INTERNAL_ERROR,
199 &format!("Internal error: {}", err),
200 None,
201 )),
202 )
203 }
204}
205
206async fn shutdown_signal() {
207 let ctrl_c = async {
208 signal::ctrl_c()
209 .await
210 .expect("failed to install Ctrl+C handler");
211 };
212
213 #[cfg(unix)]
214 let terminate = async {
215 signal::unix::signal(signal::unix::SignalKind::terminate())
216 .expect("failed to install SIGTERM handler")
217 .recv()
218 .await;
219 };
220
221 #[cfg(not(unix))]
222 let terminate = std::future::pending::<()>();
223
224 tokio::select! {
225 _ = ctrl_c => {},
226 _ = terminate => {},
227 }
228
229 info!("Shutdown signal received, initiating graceful shutdown...");
230}
231
/// Shared state handed to every route handler; cloned per request.
#[derive(Clone)]
pub struct AppState {
    /// Handler that processes messages, cancellations, and extended-card requests.
    handler: BoxedHandler,
    /// Store holding the tasks returned by the handler.
    task_store: TaskStore,
    /// Store holding push-notification (webhook) configurations per task.
    webhook_store: WebhookStore,
    /// Agent card generated once when the router is built.
    card: Arc<AgentCard>,
    /// Optional callback that derives an `AuthContext` from request headers.
    auth_extractor: Option<AuthExtractor>,
    /// Sending half of the broadcast channel carrying task/stream events.
    event_tx: broadcast::Sender<StreamResponse>,
}
241
242impl AppState {
243 pub fn task_store(&self) -> &TaskStore {
244 &self.task_store
245 }
246
247 pub fn agent_card(&self) -> &AgentCard {
248 &self.card
249 }
250
251 pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
252 &self.event_tx
253 }
254
255 pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
256 self.event_tx.subscribe()
257 }
258
259 pub fn broadcast_event(&self, event: StreamResponse) {
260 let _ = self.event_tx.send(event);
261 }
262
263 fn streaming_enabled(&self) -> bool {
265 self.card.capabilities.streaming.unwrap_or(false)
266 }
267
268 fn push_notifications_enabled(&self) -> bool {
270 self.card.capabilities.push_notifications.unwrap_or(false)
271 }
272
273 fn endpoint_url(&self) -> &str {
275 self.card.endpoint().unwrap_or("")
276 }
277}
278
279fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
282 match history_length {
283 Some(0) => {
284 task.history = None;
285 }
286 Some(n) => {
287 if let Some(ref mut history) = task.history {
288 let len = history.len();
289 if len > n as usize {
290 *history = history.split_off(len - n as usize);
291 }
292 }
293 }
294 None => {}
295 }
296}
297
298fn validate_a2a_version(
301 headers: &HeaderMap,
302 req_id: &serde_json::Value,
303) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
304 if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
305 let version_str = version_header.to_str().unwrap_or("");
306
307 let parts: Vec<&str> = version_str.split('.').collect();
308 if parts.len() >= 2 {
309 if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
310 if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
311 return Ok(());
312 }
313
314 return Err((
315 StatusCode::BAD_REQUEST,
316 Json(error(
317 req_id.clone(),
318 errors::VERSION_NOT_SUPPORTED,
319 &format!(
320 "Protocol version {}.{} not supported. Supported version: {}.{}",
321 major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
322 ),
323 Some(serde_json::json!({
324 "requestedVersion": version_str,
325 "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
326 })),
327 )),
328 ));
329 }
330 }
331
332 if version_str == "1" || version_str == "1.0" {
334 return Ok(());
335 }
336
337 return Err((
338 StatusCode::BAD_REQUEST,
339 Json(error(
340 req_id.clone(),
341 errors::VERSION_NOT_SUPPORTED,
342 &format!(
343 "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
344 version_str
345 ),
346 None,
347 )),
348 ));
349 }
350
351 Ok(())
352}
353
354#[allow(dead_code)]
357pub fn rpc_error(
358 id: serde_json::Value,
359 code: i32,
360 message: &str,
361 status: StatusCode,
362) -> (StatusCode, Json<JsonRpcResponse>) {
363 (status, Json(error(id, code, message, None)))
364}
365
366#[allow(dead_code)]
367pub fn rpc_error_with_data(
368 id: serde_json::Value,
369 code: i32,
370 message: &str,
371 data: serde_json::Value,
372 status: StatusCode,
373) -> (StatusCode, Json<JsonRpcResponse>) {
374 (status, Json(error(id, code, message, Some(data))))
375}
376
377#[allow(dead_code)]
378pub fn rpc_success(
379 id: serde_json::Value,
380 result: serde_json::Value,
381) -> (StatusCode, Json<JsonRpcResponse>) {
382 (StatusCode::OK, Json(success(id, result)))
383}
384
385async fn health() -> Json<serde_json::Value> {
388 Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
389}
390
391async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
392 Json((*state.card).clone())
393}
394
395async fn handle_rpc(
399 State(state): State<AppState>,
400 headers: HeaderMap,
401 Json(req): Json<JsonRpcRequest>,
402) -> Response {
403 if req.jsonrpc != "2.0" {
404 let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
405 return (StatusCode::BAD_REQUEST, Json(resp)).into_response();
406 }
407
408 if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
409 return err_response.into_response();
410 }
411
412 let auth_context = state
413 .auth_extractor
414 .as_ref()
415 .and_then(|extractor| extractor(&headers));
416
417 match req.method.as_str() {
418 "SendMessage" | "message/send" => handle_message_send(state, req, auth_context)
420 .await
421 .into_response(),
422 "SendStreamingMessage" | "message/stream" => {
423 handle_message_stream(state, req, headers, auth_context)
424 .await
425 .into_response()
426 }
427 "GetTask" | "tasks/get" => handle_tasks_get(state, req).await.into_response(),
428 "ListTasks" | "tasks/list" => handle_tasks_list(state, req).await.into_response(),
429 "CancelTask" | "tasks/cancel" => handle_tasks_cancel(state, req).await.into_response(),
430 "SubscribeToTask" | "tasks/resubscribe" => {
431 handle_tasks_resubscribe(state, req).await.into_response()
432 }
433 "CreateTaskPushNotificationConfig" | "tasks/pushNotificationConfig/create" => {
434 handle_push_config_create(state, req).await.into_response()
435 }
436 "GetTaskPushNotificationConfig" | "tasks/pushNotificationConfig/get" => {
437 handle_push_config_get(state, req).await.into_response()
438 }
439 "ListTaskPushNotificationConfigs" | "tasks/pushNotificationConfig/list" => {
440 handle_push_config_list(state, req).await.into_response()
441 }
442 "DeleteTaskPushNotificationConfig" | "tasks/pushNotificationConfig/delete" => {
443 handle_push_config_delete(state, req).await.into_response()
444 }
445 "GetExtendedAgentCard" | "agentCard/getExtended" => {
446 handle_get_extended_agent_card(state, req, auth_context)
447 .await
448 .into_response()
449 }
450 _ => (
451 StatusCode::NOT_FOUND,
452 Json(error(
453 req.id,
454 errors::METHOD_NOT_FOUND,
455 "method not found",
456 None,
457 )),
458 )
459 .into_response(),
460 }
461}
462
463fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
464 match e {
465 HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
466 HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
467 HandlerError::BackendUnavailable { .. } => {
468 (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
469 }
470 HandlerError::ProcessingFailed { .. } => {
471 (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
472 }
473 HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
474 }
475}
476
477async fn handle_message_send(
478 state: AppState,
479 req: JsonRpcRequest,
480 auth_context: Option<AuthContext>,
481) -> (StatusCode, Json<JsonRpcResponse>) {
482 let req_id = req.id.clone();
483
484 let params: Result<SendMessageRequest, _> =
485 serde_json::from_value(req.params.clone().unwrap_or_default());
486
487 let params = match params {
488 Ok(p) => p,
489 Err(err) => {
490 return (
491 StatusCode::BAD_REQUEST,
492 Json(error(
493 req_id,
494 errors::INVALID_PARAMS,
495 "invalid params",
496 Some(serde_json::json!({"error": err.to_string()})),
497 )),
498 );
499 }
500 };
501
502 let blocking = params
503 .configuration
504 .as_ref()
505 .and_then(|c| c.blocking)
506 .unwrap_or(false);
507 let return_immediately = params
508 .configuration
509 .as_ref()
510 .and_then(|c| c.return_immediately)
511 .unwrap_or(false);
512 let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
513
514 match state
515 .handler
516 .handle_message(params.message, auth_context)
517 .await
518 {
519 Ok(response) => {
520 match response {
521 SendMessageResponse::Task(mut task) => {
522 state.task_store.insert(task.clone()).await;
524 state.broadcast_event(StreamResponse::Task(task.clone()));
525
526 if blocking && !return_immediately && !task.status.state.is_terminal() {
528 let task_id = task.id.clone();
529 let mut rx = state.subscribe_events();
530
531 let wait_result = timeout(BLOCKING_TIMEOUT, async {
532 loop {
533 tokio::select! {
534 result = rx.recv() => {
535 match result {
536 Ok(StreamResponse::Task(t)) if t.id == task_id => {
537 if t.status.state.is_terminal() {
538 return Some(t);
539 }
540 }
541 Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
542 if e.status.state.is_terminal() {
543 if let Some(t) = state.task_store.get(&task_id).await {
544 return Some(t);
545 }
546 }
547 }
548 Err(broadcast::error::RecvError::Closed) => {
549 return None;
550 }
551 _ => {}
552 }
553 }
554 _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
555 if let Some(t) = state.task_store.get(&task_id).await {
556 if t.status.state.is_terminal() {
557 return Some(t);
558 }
559 }
560 }
561 }
562 }
563 })
564 .await;
565
566 match wait_result {
567 Ok(Some(final_task)) => task = final_task,
568 Ok(None) => {
569 if let Some(t) = state.task_store.get(&task.id).await {
570 task = t;
571 }
572 }
573 Err(_) => {
574 tracing::warn!("Blocking request timed out for task {}", task.id);
575 if let Some(t) = state.task_store.get(&task.id).await {
576 task = t;
577 }
578 }
579 }
580 }
581
582 apply_history_length(&mut task, history_length);
583
584 match serde_json::to_value(SendMessageResult::Task(task.clone())) {
586 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
587 Err(e) => (
588 StatusCode::INTERNAL_SERVER_ERROR,
589 Json(error(
590 req_id,
591 errors::INTERNAL_ERROR,
592 "serialization failed",
593 Some(serde_json::json!({"error": e.to_string()})),
594 )),
595 ),
596 }
597 }
598 SendMessageResponse::Message(msg) => {
599 match serde_json::to_value(SendMessageResult::Message(msg)) {
601 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
602 Err(e) => (
603 StatusCode::INTERNAL_SERVER_ERROR,
604 Json(error(
605 req_id,
606 errors::INTERNAL_ERROR,
607 "serialization failed",
608 Some(serde_json::json!({"error": e.to_string()})),
609 )),
610 ),
611 }
612 }
613 }
614 }
615 Err(e) => {
616 let (code, status) = handler_error_to_rpc(&e);
617 (status, Json(error(req_id, code, &e.to_string(), None)))
618 }
619 }
620}
621
622async fn handle_message_stream(
627 state: AppState,
628 req: JsonRpcRequest,
629 _headers: HeaderMap,
630 auth_context: Option<AuthContext>,
631) -> Response {
632 let req_id = req.id.clone();
633
634 if !state.streaming_enabled() {
635 return (
636 StatusCode::BAD_REQUEST,
637 Json(error(
638 req_id,
639 errors::UNSUPPORTED_OPERATION,
640 "streaming not supported by this agent",
641 None,
642 )),
643 )
644 .into_response();
645 }
646
647 let params: Result<SendMessageRequest, _> =
648 serde_json::from_value(req.params.clone().unwrap_or_default());
649
650 let params = match params {
651 Ok(p) => p,
652 Err(err) => {
653 return (
654 StatusCode::BAD_REQUEST,
655 Json(error(
656 req_id,
657 errors::INVALID_PARAMS,
658 "invalid params",
659 Some(serde_json::json!({"error": err.to_string()})),
660 )),
661 )
662 .into_response();
663 }
664 };
665
666 let response = match state
667 .handler
668 .handle_message(params.message, auth_context)
669 .await
670 {
671 Ok(r) => r,
672 Err(e) => {
673 let (code, status) = handler_error_to_rpc(&e);
674 return (status, Json(error(req_id, code, &e.to_string(), None))).into_response();
675 }
676 };
677
678 let task = match response {
680 SendMessageResponse::Task(t) => t,
681 SendMessageResponse::Message(_) => {
682 return (
683 StatusCode::BAD_REQUEST,
684 Json(error(
685 req_id,
686 errors::UNSUPPORTED_OPERATION,
687 "handler returned a message, streaming requires a task",
688 None,
689 )),
690 )
691 .into_response();
692 }
693 };
694
695 let task_id = task.id.clone();
696 state.task_store.insert(task.clone()).await;
697 state.broadcast_event(StreamResponse::Task(task.clone()));
698
699 let mut rx = state.subscribe_events();
700 let task_store = state.task_store.clone();
701 let target_task_id = task_id;
702
703 let wrap = move |value: serde_json::Value| -> String {
705 serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
706 };
707
708 let stream = async_stream::stream! {
709 if let Ok(val) = serde_json::to_value(StreamingMessageResult::Task(task)) {
711 yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
712 }
713
714 loop {
715 match rx.recv().await {
716 Ok(event) => {
717 let matches = match &event {
718 StreamResponse::Task(t) => t.id == target_task_id,
719 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
720 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
721 StreamResponse::Message(m) => {
722 m.context_id.as_ref().is_some_and(|ctx| {
723 task_store.get(&target_task_id).now_or_never()
724 .flatten()
725 .is_some_and(|t| t.context_id == *ctx)
726 })
727 }
728 };
729
730 if matches {
731 let val = match event.clone() {
733 StreamResponse::Task(t) => serde_json::to_value(StreamingMessageResult::Task(t)),
734 StreamResponse::Message(m) => serde_json::to_value(StreamingMessageResult::Message(m)),
735 StreamResponse::StatusUpdate(e) => serde_json::to_value(StreamingMessageResult::StatusUpdate(e)),
736 StreamResponse::ArtifactUpdate(e) => serde_json::to_value(StreamingMessageResult::ArtifactUpdate(e)),
737 };
738 if let Ok(val) = val {
739 yield Ok(Event::default().data(wrap(val)));
740 }
741
742 let is_terminal = match &event {
744 StreamResponse::Task(t) => t.status.state.is_terminal(),
745 StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
746 _ => false,
747 };
748 if is_terminal {
749 break;
750 }
751 }
752 }
753 Err(broadcast::error::RecvError::Lagged(_)) => continue,
754 Err(broadcast::error::RecvError::Closed) => break,
755 }
756 }
757 };
758
759 Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
760}
761
762async fn handle_tasks_get(
763 state: AppState,
764 req: JsonRpcRequest,
765) -> (StatusCode, Json<JsonRpcResponse>) {
766 let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
767
768 match params {
769 Ok(p) => match state.task_store.get_flexible(&p.id).await {
770 Some(mut task) => {
771 apply_history_length(&mut task, p.history_length);
772
773 match serde_json::to_value(task) {
774 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
775 Err(e) => (
776 StatusCode::INTERNAL_SERVER_ERROR,
777 Json(error(
778 req.id,
779 errors::INTERNAL_ERROR,
780 "serialization failed",
781 Some(serde_json::json!({"error": e.to_string()})),
782 )),
783 ),
784 }
785 }
786 None => (
787 StatusCode::NOT_FOUND,
788 Json(error(
789 req.id,
790 errors::TASK_NOT_FOUND,
791 "task not found",
792 None,
793 )),
794 ),
795 },
796 Err(err) => (
797 StatusCode::BAD_REQUEST,
798 Json(error(
799 req.id,
800 errors::INVALID_PARAMS,
801 "invalid params",
802 Some(serde_json::json!({"error": err.to_string()})),
803 )),
804 ),
805 }
806}
807
808async fn handle_tasks_cancel(
809 state: AppState,
810 req: JsonRpcRequest,
811) -> (StatusCode, Json<JsonRpcResponse>) {
812 let params: Result<CancelTaskRequest, _> =
813 serde_json::from_value(req.params.unwrap_or_default());
814
815 match params {
816 Ok(p) => {
817 let result = state
818 .task_store
819 .update_flexible(&p.id, |task| {
820 if task.status.state.is_terminal() {
821 return Err(errors::TASK_NOT_CANCELABLE);
822 }
823 task.status.state = TaskState::Canceled;
824 task.status.timestamp = Some(now_iso8601());
825 Ok(())
826 })
827 .await;
828
829 match result {
830 Some(Ok(task)) => {
831 if let Err(e) = state.handler.cancel_task(&task.id).await {
832 tracing::warn!("Handler cancel_task failed: {}", e);
833 }
834
835 state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
836 kind: "status-update".to_string(),
837 task_id: task.id.clone(),
838 context_id: task.context_id.clone(),
839 status: task.status.clone(),
840 is_final: true,
841 metadata: None,
842 }));
843
844 match serde_json::to_value(task) {
845 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
846 Err(e) => (
847 StatusCode::INTERNAL_SERVER_ERROR,
848 Json(error(
849 req.id,
850 errors::INTERNAL_ERROR,
851 "serialization failed",
852 Some(serde_json::json!({"error": e.to_string()})),
853 )),
854 ),
855 }
856 }
857 Some(Err(error_code)) => (
858 StatusCode::BAD_REQUEST,
859 Json(error(req.id, error_code, "task not cancelable", None)),
860 ),
861 None => (
862 StatusCode::NOT_FOUND,
863 Json(error(
864 req.id,
865 errors::TASK_NOT_FOUND,
866 "task not found",
867 None,
868 )),
869 ),
870 }
871 }
872 Err(err) => (
873 StatusCode::BAD_REQUEST,
874 Json(error(
875 req.id,
876 errors::INVALID_PARAMS,
877 "invalid params",
878 Some(serde_json::json!({"error": err.to_string()})),
879 )),
880 ),
881 }
882}
883
884async fn handle_tasks_list(
885 state: AppState,
886 req: JsonRpcRequest,
887) -> (StatusCode, Json<JsonRpcResponse>) {
888 let params: Result<ListTasksRequest, _> =
889 serde_json::from_value(req.params.unwrap_or_default());
890
891 match params {
892 Ok(p) => {
893 let response = state.task_store.list_filtered(&p).await;
894 match serde_json::to_value(response) {
895 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
896 Err(e) => (
897 StatusCode::INTERNAL_SERVER_ERROR,
898 Json(error(
899 req.id,
900 errors::INTERNAL_ERROR,
901 "serialization failed",
902 Some(serde_json::json!({"error": e.to_string()})),
903 )),
904 ),
905 }
906 }
907 Err(err) => (
908 StatusCode::BAD_REQUEST,
909 Json(error(
910 req.id,
911 errors::INVALID_PARAMS,
912 "invalid params",
913 Some(serde_json::json!({"error": err.to_string()})),
914 )),
915 ),
916 }
917}
918
919async fn handle_tasks_resubscribe(state: AppState, req: JsonRpcRequest) -> Response {
924 let req_id = req.id.clone();
925
926 let params: Result<SubscribeToTaskRequest, _> =
927 serde_json::from_value(req.params.unwrap_or_default());
928
929 let params = match params {
930 Ok(p) => p,
931 Err(err) => {
932 return (
933 StatusCode::BAD_REQUEST,
934 Json(error(
935 req_id,
936 errors::INVALID_PARAMS,
937 "invalid params",
938 Some(serde_json::json!({"error": err.to_string()})),
939 )),
940 )
941 .into_response();
942 }
943 };
944
945 let task = match state.task_store.get_flexible(¶ms.id).await {
947 Some(t) => t,
948 None => {
949 return (
950 StatusCode::NOT_FOUND,
951 Json(error(req_id, errors::TASK_NOT_FOUND, "task not found", None)),
952 )
953 .into_response();
954 }
955 };
956
957 let target_task_id = task.id.clone();
958 let mut rx = state.subscribe_events();
959 let task_store = state.task_store.clone();
960
961 let wrap = move |value: serde_json::Value| -> String {
962 serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
963 };
964
965 let stream = async_stream::stream! {
966 if let Ok(val) = serde_json::to_value(StreamingMessageResult::Task(task.clone())) {
968 yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
969 }
970
971 if task.status.state.is_terminal() {
973 return;
974 }
975
976 loop {
977 match rx.recv().await {
978 Ok(event) => {
979 let matches = match &event {
980 StreamResponse::Task(t) => t.id == target_task_id,
981 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
982 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
983 StreamResponse::Message(m) => {
984 m.context_id.as_ref().is_some_and(|ctx| {
985 task_store.get(&target_task_id).now_or_never()
986 .flatten()
987 .is_some_and(|t| t.context_id == *ctx)
988 })
989 }
990 };
991
992 if matches {
993 let val = match event.clone() {
994 StreamResponse::Task(t) => serde_json::to_value(StreamingMessageResult::Task(t)),
995 StreamResponse::Message(m) => serde_json::to_value(StreamingMessageResult::Message(m)),
996 StreamResponse::StatusUpdate(e) => serde_json::to_value(StreamingMessageResult::StatusUpdate(e)),
997 StreamResponse::ArtifactUpdate(e) => serde_json::to_value(StreamingMessageResult::ArtifactUpdate(e)),
998 };
999 if let Ok(val) = val {
1000 yield Ok(Event::default().data(wrap(val)));
1001 }
1002
1003 let is_terminal = match &event {
1004 StreamResponse::Task(t) => t.status.state.is_terminal(),
1005 StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
1006 _ => false,
1007 };
1008 if is_terminal {
1009 break;
1010 }
1011 }
1012 }
1013 Err(broadcast::error::RecvError::Lagged(_)) => continue,
1014 Err(broadcast::error::RecvError::Closed) => break,
1015 }
1016 }
1017 };
1018
1019 Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
1020}
1021
1022async fn handle_get_extended_agent_card(
1023 state: AppState,
1024 req: JsonRpcRequest,
1025 auth_context: Option<AuthContext>,
1026) -> (StatusCode, Json<JsonRpcResponse>) {
1027 let Some(auth) = auth_context else {
1028 return (
1029 StatusCode::UNAUTHORIZED,
1030 Json(error(
1031 req.id,
1032 errors::INVALID_REQUEST,
1033 "authentication required for extended agent card",
1034 None,
1035 )),
1036 );
1037 };
1038
1039 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1040
1041 match state.handler.extended_agent_card(base_url, &auth).await {
1042 Some(card) => match serde_json::to_value(card) {
1043 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1044 Err(e) => (
1045 StatusCode::INTERNAL_SERVER_ERROR,
1046 Json(error(
1047 req.id,
1048 errors::INTERNAL_ERROR,
1049 "serialization failed",
1050 Some(serde_json::json!({"error": e.to_string()})),
1051 )),
1052 ),
1053 },
1054 None => (
1055 StatusCode::NOT_FOUND,
1056 Json(error(
1057 req.id,
1058 errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1059 "extended agent card not configured",
1060 None,
1061 )),
1062 ),
1063 }
1064}
1065
1066async fn handle_push_config_create(
1069 state: AppState,
1070 req: JsonRpcRequest,
1071) -> (StatusCode, Json<JsonRpcResponse>) {
1072 if !state.push_notifications_enabled() {
1073 return (
1074 StatusCode::BAD_REQUEST,
1075 Json(error(
1076 req.id,
1077 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1078 "push notifications not supported",
1079 None,
1080 )),
1081 );
1082 }
1083
1084 let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1085 serde_json::from_value(req.params.unwrap_or_default());
1086
1087 match params {
1088 Ok(p) => {
1089 if state.task_store.get_flexible(&p.task_id).await.is_none() {
1090 return (
1091 StatusCode::NOT_FOUND,
1092 Json(error(
1093 req.id,
1094 errors::TASK_NOT_FOUND,
1095 "task not found",
1096 None,
1097 )),
1098 );
1099 }
1100
1101 if let Err(e) = state
1102 .webhook_store
1103 .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1104 .await
1105 {
1106 return (
1107 StatusCode::BAD_REQUEST,
1108 Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1109 );
1110 }
1111
1112 (
1113 StatusCode::OK,
1114 Json(success(
1115 req.id,
1116 serde_json::json!({
1117 "configId": p.config_id,
1118 "config": p.push_notification_config
1119 }),
1120 )),
1121 )
1122 }
1123 Err(err) => (
1124 StatusCode::BAD_REQUEST,
1125 Json(error(
1126 req.id,
1127 errors::INVALID_PARAMS,
1128 "invalid params",
1129 Some(serde_json::json!({"error": err.to_string()})),
1130 )),
1131 ),
1132 }
1133}
1134
1135async fn handle_push_config_get(
1136 state: AppState,
1137 req: JsonRpcRequest,
1138) -> (StatusCode, Json<JsonRpcResponse>) {
1139 if !state.push_notifications_enabled() {
1140 return (
1141 StatusCode::BAD_REQUEST,
1142 Json(error(
1143 req.id,
1144 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1145 "push notifications not supported",
1146 None,
1147 )),
1148 );
1149 }
1150
1151 let params: Result<GetTaskPushNotificationConfigRequest, _> =
1152 serde_json::from_value(req.params.unwrap_or_default());
1153
1154 match params {
1155 Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1156 Some(config) => (
1157 StatusCode::OK,
1158 Json(success(
1159 req.id,
1160 serde_json::json!({
1161 "configId": p.id,
1162 "config": config
1163 }),
1164 )),
1165 ),
1166 None => (
1167 StatusCode::NOT_FOUND,
1168 Json(error(
1169 req.id,
1170 errors::TASK_NOT_FOUND,
1171 "push notification config not found",
1172 None,
1173 )),
1174 ),
1175 },
1176 Err(err) => (
1177 StatusCode::BAD_REQUEST,
1178 Json(error(
1179 req.id,
1180 errors::INVALID_PARAMS,
1181 "invalid params",
1182 Some(serde_json::json!({"error": err.to_string()})),
1183 )),
1184 ),
1185 }
1186}
1187
1188async fn handle_push_config_list(
1189 state: AppState,
1190 req: JsonRpcRequest,
1191) -> (StatusCode, Json<JsonRpcResponse>) {
1192 if !state.push_notifications_enabled() {
1193 return (
1194 StatusCode::BAD_REQUEST,
1195 Json(error(
1196 req.id,
1197 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1198 "push notifications not supported",
1199 None,
1200 )),
1201 );
1202 }
1203
1204 let params: Result<ListTaskPushNotificationConfigRequest, _> =
1205 serde_json::from_value(req.params.unwrap_or_default());
1206
1207 match params {
1208 Ok(p) => {
1209 let configs = state.webhook_store.list(&p.task_id).await;
1210
1211 let configs_json: Vec<_> = configs
1212 .iter()
1213 .map(|c| {
1214 serde_json::json!({
1215 "configId": c.config_id,
1216 "config": c.config
1217 })
1218 })
1219 .collect();
1220
1221 (
1222 StatusCode::OK,
1223 Json(success(
1224 req.id,
1225 serde_json::json!({
1226 "configs": configs_json,
1227 "nextPageToken": ""
1228 }),
1229 )),
1230 )
1231 }
1232 Err(err) => (
1233 StatusCode::BAD_REQUEST,
1234 Json(error(
1235 req.id,
1236 errors::INVALID_PARAMS,
1237 "invalid params",
1238 Some(serde_json::json!({"error": err.to_string()})),
1239 )),
1240 ),
1241 }
1242}
1243
1244async fn handle_push_config_delete(
1245 state: AppState,
1246 req: JsonRpcRequest,
1247) -> (StatusCode, Json<JsonRpcResponse>) {
1248 if !state.push_notifications_enabled() {
1249 return (
1250 StatusCode::BAD_REQUEST,
1251 Json(error(
1252 req.id,
1253 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1254 "push notifications not supported",
1255 None,
1256 )),
1257 );
1258 }
1259
1260 let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1261 serde_json::from_value(req.params.unwrap_or_default());
1262
1263 match params {
1264 Ok(p) => {
1265 if state.webhook_store.delete(&p.task_id, &p.id).await {
1266 (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1267 } else {
1268 (
1269 StatusCode::NOT_FOUND,
1270 Json(error(
1271 req.id,
1272 errors::TASK_NOT_FOUND,
1273 "push notification config not found",
1274 None,
1275 )),
1276 )
1277 }
1278 }
1279 Err(err) => (
1280 StatusCode::BAD_REQUEST,
1281 Json(error(
1282 req.id,
1283 errors::INVALID_PARAMS,
1284 "invalid params",
1285 Some(serde_json::json!({"error": err.to_string()})),
1286 )),
1287 ),
1288 }
1289}
1290
1291pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1292 bind_addr: &str,
1293 handler: H,
1294) -> anyhow::Result<()> {
1295 A2aServer::new(handler).bind(bind_addr)?.run().await
1296}
1297
1298pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1299 A2aServer::echo().bind(bind_addr)?.run().await
1300}