1use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a `message/send` call with `blocking: true` waits for its task to
/// reach a terminal state before answering with the latest known task.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// How often the blocking wait re-reads the task store as a fallback to broadcast events.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17 error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18 CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19 GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20 ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21 SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22 TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header through which clients announce the A2A protocol version they speak.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the only protocol version this server accepts.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the only protocol version this server accepts.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::routing::{get, post};
33use axum::{Json, Router};
34use futures::future::FutureExt;
35use futures::stream::Stream;
36use tokio::sync::broadcast;
37use tracing::info;
38
39use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
40use crate::task_store::TaskStore;
41use crate::webhook_delivery::WebhookDelivery;
42use crate::webhook_store::WebhookStore;
43
/// Network settings for `A2aServer`.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Listen address in `host:port` form; `A2aServer::bind` checks it parses as a
    /// `SocketAddr`, `A2aServer::bind_unchecked` does not.
    pub bind_address: String,
}

impl Default for ServerConfig {
    /// Listens on all IPv4 interfaces, port 8080.
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        ServerConfig { bind_address }
    }
}
56
57pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
58
/// Slots in the `StreamResponse` broadcast channel. A receiver that falls further behind than
/// this sees `RecvError::Lagged` and loses the oldest events.
const EVENT_CHANNEL_CAPACITY: usize = 1024;
60
61pub struct A2aServer {
62 config: ServerConfig,
63 handler: BoxedHandler,
64 task_store: TaskStore,
65 webhook_store: WebhookStore,
66 auth_extractor: Option<AuthExtractor>,
67 additional_routes: Option<Router<AppState>>,
68 event_tx: broadcast::Sender<StreamResponse>,
69}
70
71impl A2aServer {
72 pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
73 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
74 Self {
75 config: ServerConfig::default(),
76 handler: Arc::new(handler),
77 task_store: TaskStore::new(),
78 webhook_store: WebhookStore::new(),
79 auth_extractor: None,
80 additional_routes: None,
81 event_tx,
82 }
83 }
84
85 pub fn echo() -> Self {
86 Self::new(EchoHandler::default())
87 }
88
89 pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
90 let _: SocketAddr = address.parse()?;
91 self.config.bind_address = address.to_string();
92 Ok(self)
93 }
94
95 pub fn bind_unchecked(mut self, address: &str) -> Self {
96 self.config.bind_address = address.to_string();
97 self
98 }
99
100 pub fn task_store(mut self, store: TaskStore) -> Self {
101 self.task_store = store;
102 self
103 }
104
105 pub fn auth_extractor<F>(mut self, extractor: F) -> Self
106 where
107 F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
108 {
109 self.auth_extractor = Some(Arc::new(extractor));
110 self
111 }
112
113 pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
114 self.additional_routes = Some(routes);
115 self
116 }
117
118 pub fn get_task_store(&self) -> TaskStore {
119 self.task_store.clone()
120 }
121
122 pub fn build_router(self) -> Router {
123 let bind: SocketAddr = self
124 .config
125 .bind_address
126 .parse()
127 .expect("Invalid bind address");
128 let base_url = format!("http://{}", bind);
129 let card = Arc::new(self.handler.agent_card(&base_url));
130
131 let state = AppState {
132 handler: self.handler,
133 task_store: self.task_store,
134 webhook_store: self.webhook_store,
135 card,
136 auth_extractor: self.auth_extractor,
137 event_tx: self.event_tx,
138 };
139
140 let timed_routes = Router::new()
141 .route("/health", get(health))
142 .route("/.well-known/agent-card.json", get(agent_card))
143 .route("/v1/rpc", post(handle_rpc))
144 .layer(
145 ServiceBuilder::new()
146 .layer(HandleErrorLayer::new(handle_timeout_error))
147 .timeout(Duration::from_secs(30)),
148 );
149
150 let sse_routes = Router::new()
151 .route(
152 "/v1/tasks/:task_id/subscribe",
153 get(handle_task_subscribe_sse),
154 )
155 .route("/v1/message/stream", post(handle_message_stream_sse));
156
157 let mut router = timed_routes.merge(sse_routes);
158
159 if let Some(additional) = self.additional_routes {
160 router = router.merge(additional);
161 }
162
163 router.with_state(state)
164 }
165
166 pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
167 self.event_tx.clone()
168 }
169
170 pub async fn run(self) -> anyhow::Result<()> {
171 let bind: SocketAddr = self.config.bind_address.parse()?;
172
173 let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
174 webhook_delivery.start(self.event_tx.subscribe());
175
176 let router = self.build_router();
177
178 info!(%bind, "Starting A2A server");
179 let listener = tokio::net::TcpListener::bind(bind).await?;
180 axum::serve(listener, router)
181 .with_graceful_shutdown(shutdown_signal())
182 .await?;
183
184 info!("Server shutdown complete");
185 Ok(())
186 }
187}
188
189async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
190 if err.is::<tower::timeout::error::Elapsed>() {
191 (
192 StatusCode::REQUEST_TIMEOUT,
193 Json(error(
194 serde_json::Value::Null,
195 errors::INTERNAL_ERROR,
196 "Request timed out",
197 None,
198 )),
199 )
200 } else {
201 (
202 StatusCode::INTERNAL_SERVER_ERROR,
203 Json(error(
204 serde_json::Value::Null,
205 errors::INTERNAL_ERROR,
206 &format!("Internal error: {}", err),
207 None,
208 )),
209 )
210 }
211}
212
213async fn shutdown_signal() {
214 let ctrl_c = async {
215 signal::ctrl_c()
216 .await
217 .expect("failed to install Ctrl+C handler");
218 };
219
220 #[cfg(unix)]
221 let terminate = async {
222 signal::unix::signal(signal::unix::SignalKind::terminate())
223 .expect("failed to install SIGTERM handler")
224 .recv()
225 .await;
226 };
227
228 #[cfg(not(unix))]
229 let terminate = std::future::pending::<()>();
230
231 tokio::select! {
232 _ = ctrl_c => {},
233 _ = terminate => {},
234 }
235
236 info!("Shutdown signal received, initiating graceful shutdown...");
237}
238
239#[derive(Clone)]
240pub struct AppState {
241 handler: BoxedHandler,
242 task_store: TaskStore,
243 webhook_store: WebhookStore,
244 card: Arc<AgentCard>,
245 auth_extractor: Option<AuthExtractor>,
246 event_tx: broadcast::Sender<StreamResponse>,
247}
248
249impl AppState {
250 pub fn task_store(&self) -> &TaskStore {
251 &self.task_store
252 }
253
254 pub fn agent_card(&self) -> &AgentCard {
255 &self.card
256 }
257
258 pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
259 &self.event_tx
260 }
261
262 pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
263 self.event_tx.subscribe()
264 }
265
266 pub fn broadcast_event(&self, event: StreamResponse) {
267 let _ = self.event_tx.send(event);
268 }
269
270 fn streaming_enabled(&self) -> bool {
272 self.card.capabilities.streaming.unwrap_or(false)
273 }
274
275 fn push_notifications_enabled(&self) -> bool {
277 self.card.capabilities.push_notifications.unwrap_or(false)
278 }
279
280 fn endpoint_url(&self) -> &str {
282 self.card.endpoint().unwrap_or("")
283 }
284}
285
286fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
289 match history_length {
290 Some(0) => {
291 task.history = None;
292 }
293 Some(n) => {
294 if let Some(ref mut history) = task.history {
295 let len = history.len();
296 if len > n as usize {
297 *history = history.split_off(len - n as usize);
298 }
299 }
300 }
301 None => {}
302 }
303}
304
305fn validate_a2a_version(
308 headers: &HeaderMap,
309 req_id: &serde_json::Value,
310) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
311 if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
312 let version_str = version_header.to_str().unwrap_or("");
313
314 let parts: Vec<&str> = version_str.split('.').collect();
315 if parts.len() >= 2 {
316 if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
317 if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
318 return Ok(());
319 }
320
321 return Err((
322 StatusCode::BAD_REQUEST,
323 Json(error(
324 req_id.clone(),
325 errors::VERSION_NOT_SUPPORTED,
326 &format!(
327 "Protocol version {}.{} not supported. Supported version: {}.{}",
328 major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
329 ),
330 Some(serde_json::json!({
331 "requestedVersion": version_str,
332 "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
333 })),
334 )),
335 ));
336 }
337 }
338
339 if version_str == "1" || version_str == "1.0" {
341 return Ok(());
342 }
343
344 return Err((
345 StatusCode::BAD_REQUEST,
346 Json(error(
347 req_id.clone(),
348 errors::VERSION_NOT_SUPPORTED,
349 &format!(
350 "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
351 version_str
352 ),
353 None,
354 )),
355 ));
356 }
357
358 Ok(())
359}
360
361#[allow(dead_code)]
364pub fn rpc_error(
365 id: serde_json::Value,
366 code: i32,
367 message: &str,
368 status: StatusCode,
369) -> (StatusCode, Json<JsonRpcResponse>) {
370 (status, Json(error(id, code, message, None)))
371}
372
373#[allow(dead_code)]
374pub fn rpc_error_with_data(
375 id: serde_json::Value,
376 code: i32,
377 message: &str,
378 data: serde_json::Value,
379 status: StatusCode,
380) -> (StatusCode, Json<JsonRpcResponse>) {
381 (status, Json(error(id, code, message, Some(data))))
382}
383
384#[allow(dead_code)]
385pub fn rpc_success(
386 id: serde_json::Value,
387 result: serde_json::Value,
388) -> (StatusCode, Json<JsonRpcResponse>) {
389 (StatusCode::OK, Json(success(id, result)))
390}
391
392async fn health() -> Json<serde_json::Value> {
395 Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
396}
397
398async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
399 Json((*state.card).clone())
400}
401
402async fn handle_task_subscribe_sse(
403 State(state): State<AppState>,
404 axum::extract::Path(task_id): axum::extract::Path<String>,
405) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
406 let mut rx = state.subscribe_events();
407 let task_store = state.task_store.clone();
408 let target_task_id = task_id;
409
410 let stream = async_stream::stream! {
411 if let Some(task) = task_store.get_flexible(&target_task_id).await {
412 let event = StreamResponse::Task(task);
413 if let Ok(json) = serde_json::to_string(&event) {
414 yield Ok(Event::default().data(json));
415 }
416 }
417
418 loop {
419 match rx.recv().await {
420 Ok(event) => {
421 let matches = match &event {
422 StreamResponse::Task(t) => t.id == target_task_id,
423 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
424 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
425 StreamResponse::Message(_) => false,
426 };
427 if matches {
428 if let Ok(json) = serde_json::to_string(&event) {
429 yield Ok(Event::default().data(json));
430 }
431
432 if let StreamResponse::Task(t) = &event {
433 if t.status.state.is_terminal() {
434 break;
435 }
436 }
437 if let StreamResponse::StatusUpdate(e) = &event {
438 if e.status.state.is_terminal() {
439 break;
440 }
441 }
442 }
443 }
444 Err(broadcast::error::RecvError::Lagged(_)) => {
445 continue;
446 }
447 Err(broadcast::error::RecvError::Closed) => {
448 break;
449 }
450 }
451 }
452 };
453
454 Sse::new(stream).keep_alive(KeepAlive::default())
455}
456
457async fn handle_rpc(
458 State(state): State<AppState>,
459 headers: HeaderMap,
460 Json(req): Json<JsonRpcRequest>,
461) -> (StatusCode, Json<JsonRpcResponse>) {
462 if req.jsonrpc != "2.0" {
463 let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
464 return (StatusCode::BAD_REQUEST, Json(resp));
465 }
466
467 if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
468 return err_response;
469 }
470
471 let auth_context = state
472 .auth_extractor
473 .as_ref()
474 .and_then(|extractor| extractor(&headers));
475
476 match req.method.as_str() {
477 "message/send" => handle_message_send(state, req, auth_context).await,
479 "message/sendStreaming" => handle_message_stream_rpc(state, req).await,
480 "tasks/get" => handle_tasks_get(state, req).await,
481 "tasks/list" => handle_tasks_list(state, req).await,
482 "tasks/cancel" => handle_tasks_cancel(state, req).await,
483 "tasks/subscribe" => handle_tasks_subscribe(state, req).await,
484 "tasks/pushNotificationConfig/create" => handle_push_config_create(state, req).await,
485 "tasks/pushNotificationConfig/get" => handle_push_config_get(state, req).await,
486 "tasks/pushNotificationConfig/list" => handle_push_config_list(state, req).await,
487 "tasks/pushNotificationConfig/delete" => handle_push_config_delete(state, req).await,
488 "agentCard/getExtended" => handle_get_extended_agent_card(state, req, auth_context).await,
489 _ => (
490 StatusCode::NOT_FOUND,
491 Json(error(
492 req.id,
493 errors::METHOD_NOT_FOUND,
494 "method not found",
495 None,
496 )),
497 ),
498 }
499}
500
501fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
502 match e {
503 HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
504 HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
505 HandlerError::BackendUnavailable { .. } => {
506 (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
507 }
508 HandlerError::ProcessingFailed { .. } => {
509 (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
510 }
511 HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
512 }
513}
514
515async fn handle_message_send(
516 state: AppState,
517 req: JsonRpcRequest,
518 auth_context: Option<AuthContext>,
519) -> (StatusCode, Json<JsonRpcResponse>) {
520 let req_id = req.id.clone();
521
522 let params: Result<SendMessageRequest, _> =
523 serde_json::from_value(req.params.clone().unwrap_or_default());
524
525 let params = match params {
526 Ok(p) => p,
527 Err(err) => {
528 return (
529 StatusCode::BAD_REQUEST,
530 Json(error(
531 req_id,
532 errors::INVALID_PARAMS,
533 "invalid params",
534 Some(serde_json::json!({"error": err.to_string()})),
535 )),
536 );
537 }
538 };
539
540 let blocking = params
541 .configuration
542 .as_ref()
543 .and_then(|c| c.blocking)
544 .unwrap_or(false);
545 let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
546
547 match state
548 .handler
549 .handle_message(params.message, auth_context)
550 .await
551 {
552 Ok(response) => {
553 match response {
554 SendMessageResponse::Task(mut task) => {
555 state.task_store.insert(task.clone()).await;
557 state.broadcast_event(StreamResponse::Task(task.clone()));
558
559 if blocking && !task.status.state.is_terminal() {
561 let task_id = task.id.clone();
562 let mut rx = state.subscribe_events();
563
564 let wait_result = timeout(BLOCKING_TIMEOUT, async {
565 loop {
566 tokio::select! {
567 result = rx.recv() => {
568 match result {
569 Ok(StreamResponse::Task(t)) if t.id == task_id => {
570 if t.status.state.is_terminal() {
571 return Some(t);
572 }
573 }
574 Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
575 if e.status.state.is_terminal() {
576 if let Some(t) = state.task_store.get(&task_id).await {
577 return Some(t);
578 }
579 }
580 }
581 Err(broadcast::error::RecvError::Closed) => {
582 return None;
583 }
584 _ => {}
585 }
586 }
587 _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
588 if let Some(t) = state.task_store.get(&task_id).await {
589 if t.status.state.is_terminal() {
590 return Some(t);
591 }
592 }
593 }
594 }
595 }
596 })
597 .await;
598
599 match wait_result {
600 Ok(Some(final_task)) => task = final_task,
601 Ok(None) => {
602 if let Some(t) = state.task_store.get(&task.id).await {
603 task = t;
604 }
605 }
606 Err(_) => {
607 tracing::warn!("Blocking request timed out for task {}", task.id);
608 if let Some(t) = state.task_store.get(&task.id).await {
609 task = t;
610 }
611 }
612 }
613 }
614
615 apply_history_length(&mut task, history_length);
616
617 let resp = SendMessageResponse::Task(task);
619 match serde_json::to_value(resp) {
620 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
621 Err(e) => (
622 StatusCode::INTERNAL_SERVER_ERROR,
623 Json(error(
624 req_id,
625 errors::INTERNAL_ERROR,
626 "serialization failed",
627 Some(serde_json::json!({"error": e.to_string()})),
628 )),
629 ),
630 }
631 }
632 SendMessageResponse::Message(msg) => {
633 let resp = SendMessageResponse::Message(msg);
634 match serde_json::to_value(resp) {
635 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
636 Err(e) => (
637 StatusCode::INTERNAL_SERVER_ERROR,
638 Json(error(
639 req_id,
640 errors::INTERNAL_ERROR,
641 "serialization failed",
642 Some(serde_json::json!({"error": e.to_string()})),
643 )),
644 ),
645 }
646 }
647 }
648 }
649 Err(e) => {
650 let (code, status) = handler_error_to_rpc(&e);
651 (status, Json(error(req_id, code, &e.to_string(), None)))
652 }
653 }
654}
655
656async fn handle_message_stream_rpc(
657 state: AppState,
658 req: JsonRpcRequest,
659) -> (StatusCode, Json<JsonRpcResponse>) {
660 if !state.streaming_enabled() {
661 return (
662 StatusCode::BAD_REQUEST,
663 Json(error(
664 req.id,
665 errors::UNSUPPORTED_OPERATION,
666 "streaming not supported by this agent",
667 None,
668 )),
669 );
670 }
671
672 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
673 let stream_url = format!("{}/v1/message/stream", base_url);
674
675 (
676 StatusCode::OK,
677 Json(success(
678 req.id,
679 serde_json::json!({
680 "streamUrl": stream_url,
681 "protocol": "sse",
682 "method": "POST"
683 }),
684 )),
685 )
686}
687
688async fn handle_message_stream_sse(
689 State(state): State<AppState>,
690 headers: HeaderMap,
691 Json(params): Json<SendMessageRequest>,
692) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
693{
694 if !state.streaming_enabled() {
695 return Err((
696 StatusCode::BAD_REQUEST,
697 Json(error(
698 serde_json::Value::Null,
699 errors::UNSUPPORTED_OPERATION,
700 "streaming not supported by this agent",
701 None,
702 )),
703 ));
704 }
705
706 let auth_context = state
707 .auth_extractor
708 .as_ref()
709 .and_then(|extractor| extractor(&headers));
710
711 let response = state
712 .handler
713 .handle_message(params.message, auth_context)
714 .await
715 .map_err(|e| {
716 let (code, status) = handler_error_to_rpc(&e);
717 (
718 status,
719 Json(error(serde_json::Value::Null, code, &e.to_string(), None)),
720 )
721 })?;
722
723 let task = match response {
725 SendMessageResponse::Task(t) => t,
726 SendMessageResponse::Message(_) => {
727 return Err((
728 StatusCode::BAD_REQUEST,
729 Json(error(
730 serde_json::Value::Null,
731 errors::UNSUPPORTED_OPERATION,
732 "handler returned a message, streaming requires a task",
733 None,
734 )),
735 ));
736 }
737 };
738
739 let task_id = task.id.clone();
740 state.task_store.insert(task.clone()).await;
741 state.broadcast_event(StreamResponse::Task(task.clone()));
742
743 let mut rx = state.subscribe_events();
744 let task_store = state.task_store.clone();
745 let target_task_id = task_id;
746
747 let stream = async_stream::stream! {
748 let event = StreamResponse::Task(task);
749 if let Ok(json) = serde_json::to_string(&event) {
750 yield Ok(Event::default().data(json));
751 }
752
753 loop {
754 match rx.recv().await {
755 Ok(event) => {
756 let matches = match &event {
757 StreamResponse::Task(t) => t.id == target_task_id,
758 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
759 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
760 StreamResponse::Message(m) => {
761 m.context_id.as_ref().is_some_and(|ctx| {
762 task_store.get(&target_task_id).now_or_never()
763 .flatten()
764 .is_some_and(|t| t.context_id == *ctx)
765 })
766 }
767 };
768
769 if matches {
770 if let Ok(json) = serde_json::to_string(&event) {
771 yield Ok(Event::default().data(json));
772 }
773
774 if let StreamResponse::Task(t) = &event {
775 if t.status.state.is_terminal() {
776 break;
777 }
778 }
779 if let StreamResponse::StatusUpdate(e) = &event {
780 if e.status.state.is_terminal() {
781 break;
782 }
783 }
784 }
785 }
786 Err(broadcast::error::RecvError::Lagged(_)) => continue,
787 Err(broadcast::error::RecvError::Closed) => break,
788 }
789 }
790 };
791
792 Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
793}
794
795async fn handle_tasks_get(
796 state: AppState,
797 req: JsonRpcRequest,
798) -> (StatusCode, Json<JsonRpcResponse>) {
799 let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
800
801 match params {
802 Ok(p) => match state.task_store.get_flexible(&p.id).await {
803 Some(mut task) => {
804 apply_history_length(&mut task, p.history_length);
805
806 match serde_json::to_value(task) {
807 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
808 Err(e) => (
809 StatusCode::INTERNAL_SERVER_ERROR,
810 Json(error(
811 req.id,
812 errors::INTERNAL_ERROR,
813 "serialization failed",
814 Some(serde_json::json!({"error": e.to_string()})),
815 )),
816 ),
817 }
818 }
819 None => (
820 StatusCode::NOT_FOUND,
821 Json(error(
822 req.id,
823 errors::TASK_NOT_FOUND,
824 "task not found",
825 None,
826 )),
827 ),
828 },
829 Err(err) => (
830 StatusCode::BAD_REQUEST,
831 Json(error(
832 req.id,
833 errors::INVALID_PARAMS,
834 "invalid params",
835 Some(serde_json::json!({"error": err.to_string()})),
836 )),
837 ),
838 }
839}
840
841async fn handle_tasks_cancel(
842 state: AppState,
843 req: JsonRpcRequest,
844) -> (StatusCode, Json<JsonRpcResponse>) {
845 let params: Result<CancelTaskRequest, _> =
846 serde_json::from_value(req.params.unwrap_or_default());
847
848 match params {
849 Ok(p) => {
850 let result = state
851 .task_store
852 .update_flexible(&p.id, |task| {
853 if task.status.state.is_terminal() {
854 return Err(errors::TASK_NOT_CANCELABLE);
855 }
856 task.status.state = TaskState::Canceled;
857 task.status.timestamp = Some(now_iso8601());
858 Ok(())
859 })
860 .await;
861
862 match result {
863 Some(Ok(task)) => {
864 if let Err(e) = state.handler.cancel_task(&task.id).await {
865 tracing::warn!("Handler cancel_task failed: {}", e);
866 }
867
868 state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
869 task_id: task.id.clone(),
870 context_id: task.context_id.clone(),
871 status: task.status.clone(),
872 metadata: None,
873 }));
874
875 match serde_json::to_value(task) {
876 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
877 Err(e) => (
878 StatusCode::INTERNAL_SERVER_ERROR,
879 Json(error(
880 req.id,
881 errors::INTERNAL_ERROR,
882 "serialization failed",
883 Some(serde_json::json!({"error": e.to_string()})),
884 )),
885 ),
886 }
887 }
888 Some(Err(error_code)) => (
889 StatusCode::BAD_REQUEST,
890 Json(error(req.id, error_code, "task not cancelable", None)),
891 ),
892 None => (
893 StatusCode::NOT_FOUND,
894 Json(error(
895 req.id,
896 errors::TASK_NOT_FOUND,
897 "task not found",
898 None,
899 )),
900 ),
901 }
902 }
903 Err(err) => (
904 StatusCode::BAD_REQUEST,
905 Json(error(
906 req.id,
907 errors::INVALID_PARAMS,
908 "invalid params",
909 Some(serde_json::json!({"error": err.to_string()})),
910 )),
911 ),
912 }
913}
914
915async fn handle_tasks_list(
916 state: AppState,
917 req: JsonRpcRequest,
918) -> (StatusCode, Json<JsonRpcResponse>) {
919 let params: Result<ListTasksRequest, _> =
920 serde_json::from_value(req.params.unwrap_or_default());
921
922 match params {
923 Ok(p) => {
924 let response = state.task_store.list_filtered(&p).await;
925 match serde_json::to_value(response) {
926 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
927 Err(e) => (
928 StatusCode::INTERNAL_SERVER_ERROR,
929 Json(error(
930 req.id,
931 errors::INTERNAL_ERROR,
932 "serialization failed",
933 Some(serde_json::json!({"error": e.to_string()})),
934 )),
935 ),
936 }
937 }
938 Err(err) => (
939 StatusCode::BAD_REQUEST,
940 Json(error(
941 req.id,
942 errors::INVALID_PARAMS,
943 "invalid params",
944 Some(serde_json::json!({"error": err.to_string()})),
945 )),
946 ),
947 }
948}
949
950async fn handle_tasks_subscribe(
951 state: AppState,
952 req: JsonRpcRequest,
953) -> (StatusCode, Json<JsonRpcResponse>) {
954 let params: Result<SubscribeToTaskRequest, _> =
955 serde_json::from_value(req.params.unwrap_or_default());
956
957 match params {
958 Ok(p) => {
959 if state.task_store.get_flexible(&p.id).await.is_none() {
960 return (
961 StatusCode::NOT_FOUND,
962 Json(error(
963 req.id,
964 errors::TASK_NOT_FOUND,
965 "task not found",
966 None,
967 )),
968 );
969 }
970
971 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
972 let subscribe_url = format!("{}/v1/tasks/{}/subscribe", base_url, p.id);
973
974 (
975 StatusCode::OK,
976 Json(success(
977 req.id,
978 serde_json::json!({
979 "subscribeUrl": subscribe_url,
980 "protocol": "sse"
981 }),
982 )),
983 )
984 }
985 Err(err) => (
986 StatusCode::BAD_REQUEST,
987 Json(error(
988 req.id,
989 errors::INVALID_PARAMS,
990 "invalid params",
991 Some(serde_json::json!({"error": err.to_string()})),
992 )),
993 ),
994 }
995}
996
997async fn handle_get_extended_agent_card(
998 state: AppState,
999 req: JsonRpcRequest,
1000 auth_context: Option<AuthContext>,
1001) -> (StatusCode, Json<JsonRpcResponse>) {
1002 let Some(auth) = auth_context else {
1003 return (
1004 StatusCode::UNAUTHORIZED,
1005 Json(error(
1006 req.id,
1007 errors::INVALID_REQUEST,
1008 "authentication required for extended agent card",
1009 None,
1010 )),
1011 );
1012 };
1013
1014 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1015
1016 match state.handler.extended_agent_card(base_url, &auth).await {
1017 Some(card) => match serde_json::to_value(card) {
1018 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1019 Err(e) => (
1020 StatusCode::INTERNAL_SERVER_ERROR,
1021 Json(error(
1022 req.id,
1023 errors::INTERNAL_ERROR,
1024 "serialization failed",
1025 Some(serde_json::json!({"error": e.to_string()})),
1026 )),
1027 ),
1028 },
1029 None => (
1030 StatusCode::NOT_FOUND,
1031 Json(error(
1032 req.id,
1033 errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1034 "extended agent card not configured",
1035 None,
1036 )),
1037 ),
1038 }
1039}
1040
1041async fn handle_push_config_create(
1044 state: AppState,
1045 req: JsonRpcRequest,
1046) -> (StatusCode, Json<JsonRpcResponse>) {
1047 if !state.push_notifications_enabled() {
1048 return (
1049 StatusCode::BAD_REQUEST,
1050 Json(error(
1051 req.id,
1052 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1053 "push notifications not supported",
1054 None,
1055 )),
1056 );
1057 }
1058
1059 let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1060 serde_json::from_value(req.params.unwrap_or_default());
1061
1062 match params {
1063 Ok(p) => {
1064 if state.task_store.get_flexible(&p.task_id).await.is_none() {
1065 return (
1066 StatusCode::NOT_FOUND,
1067 Json(error(
1068 req.id,
1069 errors::TASK_NOT_FOUND,
1070 "task not found",
1071 None,
1072 )),
1073 );
1074 }
1075
1076 if let Err(e) = state
1077 .webhook_store
1078 .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1079 .await
1080 {
1081 return (
1082 StatusCode::BAD_REQUEST,
1083 Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1084 );
1085 }
1086
1087 (
1088 StatusCode::OK,
1089 Json(success(
1090 req.id,
1091 serde_json::json!({
1092 "configId": p.config_id,
1093 "config": p.push_notification_config
1094 }),
1095 )),
1096 )
1097 }
1098 Err(err) => (
1099 StatusCode::BAD_REQUEST,
1100 Json(error(
1101 req.id,
1102 errors::INVALID_PARAMS,
1103 "invalid params",
1104 Some(serde_json::json!({"error": err.to_string()})),
1105 )),
1106 ),
1107 }
1108}
1109
1110async fn handle_push_config_get(
1111 state: AppState,
1112 req: JsonRpcRequest,
1113) -> (StatusCode, Json<JsonRpcResponse>) {
1114 if !state.push_notifications_enabled() {
1115 return (
1116 StatusCode::BAD_REQUEST,
1117 Json(error(
1118 req.id,
1119 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1120 "push notifications not supported",
1121 None,
1122 )),
1123 );
1124 }
1125
1126 let params: Result<GetTaskPushNotificationConfigRequest, _> =
1127 serde_json::from_value(req.params.unwrap_or_default());
1128
1129 match params {
1130 Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1131 Some(config) => (
1132 StatusCode::OK,
1133 Json(success(
1134 req.id,
1135 serde_json::json!({
1136 "configId": p.id,
1137 "config": config
1138 }),
1139 )),
1140 ),
1141 None => (
1142 StatusCode::NOT_FOUND,
1143 Json(error(
1144 req.id,
1145 errors::TASK_NOT_FOUND,
1146 "push notification config not found",
1147 None,
1148 )),
1149 ),
1150 },
1151 Err(err) => (
1152 StatusCode::BAD_REQUEST,
1153 Json(error(
1154 req.id,
1155 errors::INVALID_PARAMS,
1156 "invalid params",
1157 Some(serde_json::json!({"error": err.to_string()})),
1158 )),
1159 ),
1160 }
1161}
1162
1163async fn handle_push_config_list(
1164 state: AppState,
1165 req: JsonRpcRequest,
1166) -> (StatusCode, Json<JsonRpcResponse>) {
1167 if !state.push_notifications_enabled() {
1168 return (
1169 StatusCode::BAD_REQUEST,
1170 Json(error(
1171 req.id,
1172 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1173 "push notifications not supported",
1174 None,
1175 )),
1176 );
1177 }
1178
1179 let params: Result<ListTaskPushNotificationConfigRequest, _> =
1180 serde_json::from_value(req.params.unwrap_or_default());
1181
1182 match params {
1183 Ok(p) => {
1184 let configs = state.webhook_store.list(&p.task_id).await;
1185
1186 let configs_json: Vec<_> = configs
1187 .iter()
1188 .map(|c| {
1189 serde_json::json!({
1190 "configId": c.config_id,
1191 "config": c.config
1192 })
1193 })
1194 .collect();
1195
1196 (
1197 StatusCode::OK,
1198 Json(success(
1199 req.id,
1200 serde_json::json!({
1201 "configs": configs_json,
1202 "nextPageToken": ""
1203 }),
1204 )),
1205 )
1206 }
1207 Err(err) => (
1208 StatusCode::BAD_REQUEST,
1209 Json(error(
1210 req.id,
1211 errors::INVALID_PARAMS,
1212 "invalid params",
1213 Some(serde_json::json!({"error": err.to_string()})),
1214 )),
1215 ),
1216 }
1217}
1218
1219async fn handle_push_config_delete(
1220 state: AppState,
1221 req: JsonRpcRequest,
1222) -> (StatusCode, Json<JsonRpcResponse>) {
1223 if !state.push_notifications_enabled() {
1224 return (
1225 StatusCode::BAD_REQUEST,
1226 Json(error(
1227 req.id,
1228 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1229 "push notifications not supported",
1230 None,
1231 )),
1232 );
1233 }
1234
1235 let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1236 serde_json::from_value(req.params.unwrap_or_default());
1237
1238 match params {
1239 Ok(p) => {
1240 if state.webhook_store.delete(&p.task_id, &p.id).await {
1241 (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1242 } else {
1243 (
1244 StatusCode::NOT_FOUND,
1245 Json(error(
1246 req.id,
1247 errors::TASK_NOT_FOUND,
1248 "push notification config not found",
1249 None,
1250 )),
1251 )
1252 }
1253 }
1254 Err(err) => (
1255 StatusCode::BAD_REQUEST,
1256 Json(error(
1257 req.id,
1258 errors::INVALID_PARAMS,
1259 "invalid params",
1260 Some(serde_json::json!({"error": err.to_string()})),
1261 )),
1262 ),
1263 }
1264}
1265
1266pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1267 bind_addr: &str,
1268 handler: H,
1269) -> anyhow::Result<()> {
1270 A2aServer::new(handler).bind(bind_addr)?.run().await
1271}
1272
1273pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1274 A2aServer::echo().bind(bind_addr)?.run().await
1275}