1use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `message/send` waits for its task to reach a
/// terminal state before answering with whatever state the store holds.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(300);
/// How often a blocking `message/send` re-reads the task store while it waits.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17 error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18 CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19 GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20 ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21 SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22 TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header carrying the A2A protocol version requested by the client.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the only protocol version accepted by `validate_a2a_version`.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the only protocol version accepted by `validate_a2a_version`.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::routing::{get, post};
33use axum::{Json, Router};
34use futures::future::FutureExt;
35use futures::stream::Stream;
36use tokio::sync::broadcast;
37use tracing::info;
38
39use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
40use crate::task_store::TaskStore;
41use crate::webhook_delivery::WebhookDelivery;
42use crate::webhook_store::WebhookStore;
43
/// Network configuration for [`A2aServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Listen address such as `"0.0.0.0:8080"`; it is parsed as a `SocketAddr` when the
    /// router is built and when the server runs.
    pub bind_address: String,
}
48
49impl Default for ServerConfig {
50 fn default() -> Self {
51 Self {
52 bind_address: "0.0.0.0:8080".to_string(),
53 }
54 }
55}
56
/// Callback that derives an [`AuthContext`] from request headers. Returning `None` means
/// no auth context is passed on to the message handler / extended-card lookup.
pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
58
/// Capacity of the broadcast channel carrying [`StreamResponse`] events. A receiver that
/// falls further behind than this observes `RecvError::Lagged` and loses the oldest events.
const EVENT_CHANNEL_CAPACITY: usize = 1024;
60
/// Builder and runner for an A2A JSON-RPC + SSE server.
///
/// Construct with [`A2aServer::new`] or [`A2aServer::echo`], configure with the builder
/// methods, then call [`A2aServer::run`] (or [`A2aServer::build_router`] to embed the
/// routes elsewhere).
pub struct A2aServer {
    config: ServerConfig,
    /// Application logic that answers `message/send` and related calls.
    handler: BoxedHandler,
    task_store: TaskStore,
    webhook_store: WebhookStore,
    /// Optional hook that extracts caller identity from request headers.
    auth_extractor: Option<AuthExtractor>,
    /// Extra routes merged into the router produced by `build_router`.
    additional_routes: Option<Router<AppState>>,
    /// Sender side of the event bus consumed by SSE streams and webhook delivery.
    event_tx: broadcast::Sender<StreamResponse>,
}
70
71impl A2aServer {
72 pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
73 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
74 Self {
75 config: ServerConfig::default(),
76 handler: Arc::new(handler),
77 task_store: TaskStore::new(),
78 webhook_store: WebhookStore::new(),
79 auth_extractor: None,
80 additional_routes: None,
81 event_tx,
82 }
83 }
84
85 pub fn echo() -> Self {
86 Self::new(EchoHandler::default())
87 }
88
89 pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
90 let _: SocketAddr = address.parse()?;
91 self.config.bind_address = address.to_string();
92 Ok(self)
93 }
94
95 pub fn bind_unchecked(mut self, address: &str) -> Self {
96 self.config.bind_address = address.to_string();
97 self
98 }
99
100 pub fn task_store(mut self, store: TaskStore) -> Self {
101 self.task_store = store;
102 self
103 }
104
105 pub fn auth_extractor<F>(mut self, extractor: F) -> Self
106 where
107 F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
108 {
109 self.auth_extractor = Some(Arc::new(extractor));
110 self
111 }
112
113 pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
114 self.additional_routes = Some(routes);
115 self
116 }
117
118 pub fn get_task_store(&self) -> TaskStore {
119 self.task_store.clone()
120 }
121
122 pub fn build_router(self) -> Router {
123 let bind: SocketAddr = self.config.bind_address.parse().expect("Invalid bind address");
124 let base_url = format!("http://{}", bind);
125 let card = Arc::new(self.handler.agent_card(&base_url));
126
127 let state = AppState {
128 handler: self.handler,
129 task_store: self.task_store,
130 webhook_store: self.webhook_store,
131 card,
132 auth_extractor: self.auth_extractor,
133 event_tx: self.event_tx,
134 };
135
136 let timed_routes = Router::new()
137 .route("/health", get(health))
138 .route("/.well-known/agent-card.json", get(agent_card))
139 .route("/v1/rpc", post(handle_rpc))
140 .layer(
141 ServiceBuilder::new()
142 .layer(HandleErrorLayer::new(handle_timeout_error))
143 .timeout(Duration::from_secs(30)),
144 );
145
146 let sse_routes = Router::new()
147 .route("/v1/tasks/:task_id/subscribe", get(handle_task_subscribe_sse))
148 .route("/v1/message/stream", post(handle_message_stream_sse));
149
150 let mut router = timed_routes.merge(sse_routes);
151
152 if let Some(additional) = self.additional_routes {
153 router = router.merge(additional);
154 }
155
156 router.with_state(state)
157 }
158
159 pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
160 self.event_tx.clone()
161 }
162
163 pub async fn run(self) -> anyhow::Result<()> {
164 let bind: SocketAddr = self.config.bind_address.parse()?;
165
166 let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
167 webhook_delivery.start(self.event_tx.subscribe());
168
169 let router = self.build_router();
170
171 info!(%bind, "Starting A2A server");
172 let listener = tokio::net::TcpListener::bind(bind).await?;
173 axum::serve(listener, router)
174 .with_graceful_shutdown(shutdown_signal())
175 .await?;
176
177 info!("Server shutdown complete");
178 Ok(())
179 }
180}
181
182async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
183 if err.is::<tower::timeout::error::Elapsed>() {
184 (
185 StatusCode::REQUEST_TIMEOUT,
186 Json(error(
187 serde_json::Value::Null,
188 errors::INTERNAL_ERROR,
189 "Request timed out",
190 None,
191 )),
192 )
193 } else {
194 (
195 StatusCode::INTERNAL_SERVER_ERROR,
196 Json(error(
197 serde_json::Value::Null,
198 errors::INTERNAL_ERROR,
199 &format!("Internal error: {}", err),
200 None,
201 )),
202 )
203 }
204}
205
206async fn shutdown_signal() {
207 let ctrl_c = async {
208 signal::ctrl_c()
209 .await
210 .expect("failed to install Ctrl+C handler");
211 };
212
213 #[cfg(unix)]
214 let terminate = async {
215 signal::unix::signal(signal::unix::SignalKind::terminate())
216 .expect("failed to install SIGTERM handler")
217 .recv()
218 .await;
219 };
220
221 #[cfg(not(unix))]
222 let terminate = std::future::pending::<()>();
223
224 tokio::select! {
225 _ = ctrl_c => {},
226 _ = terminate => {},
227 }
228
229 info!("Shutdown signal received, initiating graceful shutdown...");
230}
231
/// State shared by every axum handler; axum's `State` extractor clones it per request.
#[derive(Clone)]
pub struct AppState {
    /// Application message handler.
    handler: BoxedHandler,
    task_store: TaskStore,
    webhook_store: WebhookStore,
    /// Agent card computed once in `A2aServer::build_router`.
    card: Arc<AgentCard>,
    auth_extractor: Option<AuthExtractor>,
    /// Sender for task/status/artifact/message events (SSE streams, webhook delivery).
    event_tx: broadcast::Sender<StreamResponse>,
}
241
242impl AppState {
243 pub fn task_store(&self) -> &TaskStore {
244 &self.task_store
245 }
246
247 pub fn agent_card(&self) -> &AgentCard {
248 &self.card
249 }
250
251 pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
252 &self.event_tx
253 }
254
255 pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
256 self.event_tx.subscribe()
257 }
258
259 pub fn broadcast_event(&self, event: StreamResponse) {
260 let _ = self.event_tx.send(event);
261 }
262
263 fn streaming_enabled(&self) -> bool {
265 self.card.capabilities.streaming.unwrap_or(false)
266 }
267
268 fn push_notifications_enabled(&self) -> bool {
270 self.card.capabilities.push_notifications.unwrap_or(false)
271 }
272
273 fn endpoint_url(&self) -> &str {
275 self.card.endpoint().unwrap_or("")
276 }
277}
278
279fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
282 match history_length {
283 Some(0) => {
284 task.history = None;
285 }
286 Some(n) => {
287 if let Some(ref mut history) = task.history {
288 let len = history.len();
289 if len > n as usize {
290 *history = history.split_off(len - n as usize);
291 }
292 }
293 }
294 None => {}
295 }
296}
297
298fn validate_a2a_version(headers: &HeaderMap, req_id: &serde_json::Value) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
301 if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
302 let version_str = version_header.to_str().unwrap_or("");
303
304 let parts: Vec<&str> = version_str.split('.').collect();
305 if parts.len() >= 2 {
306 if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
307 if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
308 return Ok(());
309 }
310
311 return Err((
312 StatusCode::BAD_REQUEST,
313 Json(error(
314 req_id.clone(),
315 errors::VERSION_NOT_SUPPORTED,
316 &format!(
317 "Protocol version {}.{} not supported. Supported version: {}.{}",
318 major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
319 ),
320 Some(serde_json::json!({
321 "requestedVersion": version_str,
322 "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
323 })),
324 )),
325 ));
326 }
327 }
328
329 if version_str == "1" || version_str == "1.0" {
331 return Ok(());
332 }
333
334 return Err((
335 StatusCode::BAD_REQUEST,
336 Json(error(
337 req_id.clone(),
338 errors::VERSION_NOT_SUPPORTED,
339 &format!("Invalid version format: {}. Expected major.minor (e.g., '1.0')", version_str),
340 None,
341 )),
342 ));
343 }
344
345 Ok(())
346}
347
348#[allow(dead_code)]
351pub fn rpc_error(
352 id: serde_json::Value,
353 code: i32,
354 message: &str,
355 status: StatusCode,
356) -> (StatusCode, Json<JsonRpcResponse>) {
357 (status, Json(error(id, code, message, None)))
358}
359
360#[allow(dead_code)]
361pub fn rpc_error_with_data(
362 id: serde_json::Value,
363 code: i32,
364 message: &str,
365 data: serde_json::Value,
366 status: StatusCode,
367) -> (StatusCode, Json<JsonRpcResponse>) {
368 (status, Json(error(id, code, message, Some(data))))
369}
370
371#[allow(dead_code)]
372pub fn rpc_success(id: serde_json::Value, result: serde_json::Value) -> (StatusCode, Json<JsonRpcResponse>) {
373 (StatusCode::OK, Json(success(id, result)))
374}
375
376async fn health() -> Json<serde_json::Value> {
379 Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
380}
381
382async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
383 Json((*state.card).clone())
384}
385
386async fn handle_task_subscribe_sse(
387 State(state): State<AppState>,
388 axum::extract::Path(task_id): axum::extract::Path<String>,
389) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
390 let mut rx = state.subscribe_events();
391 let task_store = state.task_store.clone();
392 let target_task_id = task_id;
393
394 let stream = async_stream::stream! {
395 if let Some(task) = task_store.get_flexible(&target_task_id).await {
396 let event = StreamResponse::Task(task);
397 if let Ok(json) = serde_json::to_string(&event) {
398 yield Ok(Event::default().data(json));
399 }
400 }
401
402 loop {
403 match rx.recv().await {
404 Ok(event) => {
405 let matches = match &event {
406 StreamResponse::Task(t) => t.id == target_task_id,
407 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
408 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
409 StreamResponse::Message(_) => false,
410 };
411 if matches {
412 if let Ok(json) = serde_json::to_string(&event) {
413 yield Ok(Event::default().data(json));
414 }
415
416 if let StreamResponse::Task(t) = &event {
417 if t.status.state.is_terminal() {
418 break;
419 }
420 }
421 if let StreamResponse::StatusUpdate(e) = &event {
422 if e.status.state.is_terminal() {
423 break;
424 }
425 }
426 }
427 }
428 Err(broadcast::error::RecvError::Lagged(_)) => {
429 continue;
430 }
431 Err(broadcast::error::RecvError::Closed) => {
432 break;
433 }
434 }
435 }
436 };
437
438 Sse::new(stream).keep_alive(KeepAlive::default())
439}
440
441async fn handle_rpc(
442 State(state): State<AppState>,
443 headers: HeaderMap,
444 Json(req): Json<JsonRpcRequest>,
445) -> (StatusCode, Json<JsonRpcResponse>) {
446 if req.jsonrpc != "2.0" {
447 let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
448 return (StatusCode::BAD_REQUEST, Json(resp));
449 }
450
451 if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
452 return err_response;
453 }
454
455 let auth_context = state
456 .auth_extractor
457 .as_ref()
458 .and_then(|extractor| extractor(&headers));
459
460 match req.method.as_str() {
461 "message/send" => handle_message_send(state, req, auth_context).await,
463 "message/sendStreaming" => handle_message_stream_rpc(state, req).await,
464 "tasks/get" => handle_tasks_get(state, req).await,
465 "tasks/list" => handle_tasks_list(state, req).await,
466 "tasks/cancel" => handle_tasks_cancel(state, req).await,
467 "tasks/subscribe" => handle_tasks_subscribe(state, req).await,
468 "tasks/pushNotificationConfig/create" => handle_push_config_create(state, req).await,
469 "tasks/pushNotificationConfig/get" => handle_push_config_get(state, req).await,
470 "tasks/pushNotificationConfig/list" => handle_push_config_list(state, req).await,
471 "tasks/pushNotificationConfig/delete" => handle_push_config_delete(state, req).await,
472 "agentCard/getExtended" => {
473 handle_get_extended_agent_card(state, req, auth_context).await
474 }
475 _ => (
476 StatusCode::NOT_FOUND,
477 Json(error(
478 req.id,
479 errors::METHOD_NOT_FOUND,
480 "method not found",
481 None,
482 )),
483 ),
484 }
485}
486
487fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
488 match e {
489 HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
490 HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
491 HandlerError::BackendUnavailable { .. } => (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE),
492 HandlerError::ProcessingFailed { .. } => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
493 HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
494 }
495}
496
497async fn handle_message_send(
498 state: AppState,
499 req: JsonRpcRequest,
500 auth_context: Option<AuthContext>,
501) -> (StatusCode, Json<JsonRpcResponse>) {
502 let req_id = req.id.clone();
503
504 let params: Result<SendMessageRequest, _> =
505 serde_json::from_value(req.params.clone().unwrap_or_default());
506
507 let params = match params {
508 Ok(p) => p,
509 Err(err) => {
510 return (
511 StatusCode::BAD_REQUEST,
512 Json(error(
513 req_id,
514 errors::INVALID_PARAMS,
515 "invalid params",
516 Some(serde_json::json!({"error": err.to_string()})),
517 )),
518 );
519 }
520 };
521
522 let blocking = params
523 .configuration
524 .as_ref()
525 .and_then(|c| c.blocking)
526 .unwrap_or(false);
527 let history_length = params
528 .configuration
529 .as_ref()
530 .and_then(|c| c.history_length);
531
532 match state.handler.handle_message(params.message, auth_context).await {
533 Ok(response) => {
534 match response {
535 SendMessageResponse::Task(mut task) => {
536 state.task_store.insert(task.clone()).await;
538 state.broadcast_event(StreamResponse::Task(task.clone()));
539
540 if blocking && !task.status.state.is_terminal() {
542 let task_id = task.id.clone();
543 let mut rx = state.subscribe_events();
544
545 let wait_result = timeout(BLOCKING_TIMEOUT, async {
546 loop {
547 tokio::select! {
548 result = rx.recv() => {
549 match result {
550 Ok(StreamResponse::Task(t)) if t.id == task_id => {
551 if t.status.state.is_terminal() {
552 return Some(t);
553 }
554 }
555 Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
556 if e.status.state.is_terminal() {
557 if let Some(t) = state.task_store.get(&task_id).await {
558 return Some(t);
559 }
560 }
561 }
562 Err(broadcast::error::RecvError::Closed) => {
563 return None;
564 }
565 _ => {}
566 }
567 }
568 _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
569 if let Some(t) = state.task_store.get(&task_id).await {
570 if t.status.state.is_terminal() {
571 return Some(t);
572 }
573 }
574 }
575 }
576 }
577 })
578 .await;
579
580 match wait_result {
581 Ok(Some(final_task)) => task = final_task,
582 Ok(None) => {
583 if let Some(t) = state.task_store.get(&task.id).await {
584 task = t;
585 }
586 }
587 Err(_) => {
588 tracing::warn!("Blocking request timed out for task {}", task.id);
589 if let Some(t) = state.task_store.get(&task.id).await {
590 task = t;
591 }
592 }
593 }
594 }
595
596 apply_history_length(&mut task, history_length);
597
598 let resp = SendMessageResponse::Task(task);
600 match serde_json::to_value(resp) {
601 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
602 Err(e) => (
603 StatusCode::INTERNAL_SERVER_ERROR,
604 Json(error(
605 req_id,
606 errors::INTERNAL_ERROR,
607 "serialization failed",
608 Some(serde_json::json!({"error": e.to_string()})),
609 )),
610 ),
611 }
612 }
613 SendMessageResponse::Message(msg) => {
614 let resp = SendMessageResponse::Message(msg);
615 match serde_json::to_value(resp) {
616 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
617 Err(e) => (
618 StatusCode::INTERNAL_SERVER_ERROR,
619 Json(error(
620 req_id,
621 errors::INTERNAL_ERROR,
622 "serialization failed",
623 Some(serde_json::json!({"error": e.to_string()})),
624 )),
625 ),
626 }
627 }
628 }
629 }
630 Err(e) => {
631 let (code, status) = handler_error_to_rpc(&e);
632 (status, Json(error(req_id, code, &e.to_string(), None)))
633 }
634 }
635}
636
637async fn handle_message_stream_rpc(
638 state: AppState,
639 req: JsonRpcRequest,
640) -> (StatusCode, Json<JsonRpcResponse>) {
641 if !state.streaming_enabled() {
642 return (
643 StatusCode::BAD_REQUEST,
644 Json(error(
645 req.id,
646 errors::UNSUPPORTED_OPERATION,
647 "streaming not supported by this agent",
648 None,
649 )),
650 );
651 }
652
653 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
654 let stream_url = format!("{}/v1/message/stream", base_url);
655
656 (
657 StatusCode::OK,
658 Json(success(
659 req.id,
660 serde_json::json!({
661 "streamUrl": stream_url,
662 "protocol": "sse",
663 "method": "POST"
664 }),
665 )),
666 )
667}
668
669async fn handle_message_stream_sse(
670 State(state): State<AppState>,
671 headers: HeaderMap,
672 Json(params): Json<SendMessageRequest>,
673) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)> {
674 if !state.streaming_enabled() {
675 return Err((
676 StatusCode::BAD_REQUEST,
677 Json(error(
678 serde_json::Value::Null,
679 errors::UNSUPPORTED_OPERATION,
680 "streaming not supported by this agent",
681 None,
682 )),
683 ));
684 }
685
686 let auth_context = state
687 .auth_extractor
688 .as_ref()
689 .and_then(|extractor| extractor(&headers));
690
691 let response = state
692 .handler
693 .handle_message(params.message, auth_context)
694 .await
695 .map_err(|e| {
696 let (code, status) = handler_error_to_rpc(&e);
697 (status, Json(error(serde_json::Value::Null, code, &e.to_string(), None)))
698 })?;
699
700 let task = match response {
702 SendMessageResponse::Task(t) => t,
703 SendMessageResponse::Message(_) => {
704 return Err((
705 StatusCode::BAD_REQUEST,
706 Json(error(
707 serde_json::Value::Null,
708 errors::UNSUPPORTED_OPERATION,
709 "handler returned a message, streaming requires a task",
710 None,
711 )),
712 ));
713 }
714 };
715
716 let task_id = task.id.clone();
717 state.task_store.insert(task.clone()).await;
718 state.broadcast_event(StreamResponse::Task(task.clone()));
719
720 let mut rx = state.subscribe_events();
721 let task_store = state.task_store.clone();
722 let target_task_id = task_id;
723
724 let stream = async_stream::stream! {
725 let event = StreamResponse::Task(task);
726 if let Ok(json) = serde_json::to_string(&event) {
727 yield Ok(Event::default().data(json));
728 }
729
730 loop {
731 match rx.recv().await {
732 Ok(event) => {
733 let matches = match &event {
734 StreamResponse::Task(t) => t.id == target_task_id,
735 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
736 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
737 StreamResponse::Message(m) => {
738 m.context_id.as_ref().is_some_and(|ctx| {
739 task_store.get(&target_task_id).now_or_never()
740 .flatten()
741 .is_some_and(|t| t.context_id == *ctx)
742 })
743 }
744 };
745
746 if matches {
747 if let Ok(json) = serde_json::to_string(&event) {
748 yield Ok(Event::default().data(json));
749 }
750
751 if let StreamResponse::Task(t) = &event {
752 if t.status.state.is_terminal() {
753 break;
754 }
755 }
756 if let StreamResponse::StatusUpdate(e) = &event {
757 if e.status.state.is_terminal() {
758 break;
759 }
760 }
761 }
762 }
763 Err(broadcast::error::RecvError::Lagged(_)) => continue,
764 Err(broadcast::error::RecvError::Closed) => break,
765 }
766 }
767 };
768
769 Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
770}
771
772async fn handle_tasks_get(
773 state: AppState,
774 req: JsonRpcRequest,
775) -> (StatusCode, Json<JsonRpcResponse>) {
776 let params: Result<GetTaskRequest, _> =
777 serde_json::from_value(req.params.unwrap_or_default());
778
779 match params {
780 Ok(p) => {
781 match state.task_store.get_flexible(&p.id).await {
782 Some(mut task) => {
783 apply_history_length(&mut task, p.history_length);
784
785 match serde_json::to_value(task) {
786 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
787 Err(e) => (
788 StatusCode::INTERNAL_SERVER_ERROR,
789 Json(error(
790 req.id,
791 errors::INTERNAL_ERROR,
792 "serialization failed",
793 Some(serde_json::json!({"error": e.to_string()})),
794 )),
795 ),
796 }
797 }
798 None => (
799 StatusCode::NOT_FOUND,
800 Json(error(req.id, errors::TASK_NOT_FOUND, "task not found", None)),
801 ),
802 }
803 }
804 Err(err) => (
805 StatusCode::BAD_REQUEST,
806 Json(error(
807 req.id,
808 errors::INVALID_PARAMS,
809 "invalid params",
810 Some(serde_json::json!({"error": err.to_string()})),
811 )),
812 ),
813 }
814}
815
816async fn handle_tasks_cancel(
817 state: AppState,
818 req: JsonRpcRequest,
819) -> (StatusCode, Json<JsonRpcResponse>) {
820 let params: Result<CancelTaskRequest, _> =
821 serde_json::from_value(req.params.unwrap_or_default());
822
823 match params {
824 Ok(p) => {
825 let result = state
826 .task_store
827 .update_flexible(&p.id, |task| {
828 if task.status.state.is_terminal() {
829 return Err(errors::TASK_NOT_CANCELABLE);
830 }
831 task.status.state = TaskState::Canceled;
832 task.status.timestamp = Some(now_iso8601());
833 Ok(())
834 })
835 .await;
836
837 match result {
838 Some(Ok(task)) => {
839 if let Err(e) = state.handler.cancel_task(&task.id).await {
840 tracing::warn!("Handler cancel_task failed: {}", e);
841 }
842
843 state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
844 task_id: task.id.clone(),
845 context_id: task.context_id.clone(),
846 status: task.status.clone(),
847 metadata: None,
848 }));
849
850 match serde_json::to_value(task) {
851 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
852 Err(e) => (
853 StatusCode::INTERNAL_SERVER_ERROR,
854 Json(error(
855 req.id,
856 errors::INTERNAL_ERROR,
857 "serialization failed",
858 Some(serde_json::json!({"error": e.to_string()})),
859 )),
860 ),
861 }
862 }
863 Some(Err(error_code)) => (
864 StatusCode::BAD_REQUEST,
865 Json(error(req.id, error_code, "task not cancelable", None)),
866 ),
867 None => (
868 StatusCode::NOT_FOUND,
869 Json(error(req.id, errors::TASK_NOT_FOUND, "task not found", None)),
870 ),
871 }
872 }
873 Err(err) => (
874 StatusCode::BAD_REQUEST,
875 Json(error(
876 req.id,
877 errors::INVALID_PARAMS,
878 "invalid params",
879 Some(serde_json::json!({"error": err.to_string()})),
880 )),
881 ),
882 }
883}
884
885async fn handle_tasks_list(
886 state: AppState,
887 req: JsonRpcRequest,
888) -> (StatusCode, Json<JsonRpcResponse>) {
889 let params: Result<ListTasksRequest, _> =
890 serde_json::from_value(req.params.unwrap_or_default());
891
892 match params {
893 Ok(p) => {
894 let response = state.task_store.list_filtered(&p).await;
895 match serde_json::to_value(response) {
896 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
897 Err(e) => (
898 StatusCode::INTERNAL_SERVER_ERROR,
899 Json(error(
900 req.id,
901 errors::INTERNAL_ERROR,
902 "serialization failed",
903 Some(serde_json::json!({"error": e.to_string()})),
904 )),
905 ),
906 }
907 }
908 Err(err) => (
909 StatusCode::BAD_REQUEST,
910 Json(error(
911 req.id,
912 errors::INVALID_PARAMS,
913 "invalid params",
914 Some(serde_json::json!({"error": err.to_string()})),
915 )),
916 ),
917 }
918}
919
920async fn handle_tasks_subscribe(
921 state: AppState,
922 req: JsonRpcRequest,
923) -> (StatusCode, Json<JsonRpcResponse>) {
924 let params: Result<SubscribeToTaskRequest, _> =
925 serde_json::from_value(req.params.unwrap_or_default());
926
927 match params {
928 Ok(p) => {
929 if state.task_store.get_flexible(&p.id).await.is_none() {
930 return (
931 StatusCode::NOT_FOUND,
932 Json(error(req.id, errors::TASK_NOT_FOUND, "task not found", None)),
933 );
934 }
935
936 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
937 let subscribe_url = format!("{}/v1/tasks/{}/subscribe", base_url, p.id);
938
939 (
940 StatusCode::OK,
941 Json(success(
942 req.id,
943 serde_json::json!({
944 "subscribeUrl": subscribe_url,
945 "protocol": "sse"
946 }),
947 )),
948 )
949 }
950 Err(err) => (
951 StatusCode::BAD_REQUEST,
952 Json(error(
953 req.id,
954 errors::INVALID_PARAMS,
955 "invalid params",
956 Some(serde_json::json!({"error": err.to_string()})),
957 )),
958 ),
959 }
960}
961
962async fn handle_get_extended_agent_card(
963 state: AppState,
964 req: JsonRpcRequest,
965 auth_context: Option<AuthContext>,
966) -> (StatusCode, Json<JsonRpcResponse>) {
967 let Some(auth) = auth_context else {
968 return (
969 StatusCode::UNAUTHORIZED,
970 Json(error(
971 req.id,
972 errors::INVALID_REQUEST,
973 "authentication required for extended agent card",
974 None,
975 )),
976 );
977 };
978
979 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
980
981 match state.handler.extended_agent_card(base_url, &auth).await {
982 Some(card) => match serde_json::to_value(card) {
983 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
984 Err(e) => (
985 StatusCode::INTERNAL_SERVER_ERROR,
986 Json(error(
987 req.id,
988 errors::INTERNAL_ERROR,
989 "serialization failed",
990 Some(serde_json::json!({"error": e.to_string()})),
991 )),
992 ),
993 },
994 None => (
995 StatusCode::NOT_FOUND,
996 Json(error(
997 req.id,
998 errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
999 "extended agent card not configured",
1000 None,
1001 )),
1002 ),
1003 }
1004}
1005
1006async fn handle_push_config_create(
1009 state: AppState,
1010 req: JsonRpcRequest,
1011) -> (StatusCode, Json<JsonRpcResponse>) {
1012 if !state.push_notifications_enabled() {
1013 return (
1014 StatusCode::BAD_REQUEST,
1015 Json(error(
1016 req.id,
1017 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1018 "push notifications not supported",
1019 None,
1020 )),
1021 );
1022 }
1023
1024 let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1025 serde_json::from_value(req.params.unwrap_or_default());
1026
1027 match params {
1028 Ok(p) => {
1029 if state.task_store.get_flexible(&p.task_id).await.is_none() {
1030 return (
1031 StatusCode::NOT_FOUND,
1032 Json(error(req.id, errors::TASK_NOT_FOUND, "task not found", None)),
1033 );
1034 }
1035
1036 if let Err(e) = state
1037 .webhook_store
1038 .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1039 .await
1040 {
1041 return (
1042 StatusCode::BAD_REQUEST,
1043 Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1044 );
1045 }
1046
1047 (
1048 StatusCode::OK,
1049 Json(success(
1050 req.id,
1051 serde_json::json!({
1052 "configId": p.config_id,
1053 "config": p.push_notification_config
1054 }),
1055 )),
1056 )
1057 }
1058 Err(err) => (
1059 StatusCode::BAD_REQUEST,
1060 Json(error(
1061 req.id,
1062 errors::INVALID_PARAMS,
1063 "invalid params",
1064 Some(serde_json::json!({"error": err.to_string()})),
1065 )),
1066 ),
1067 }
1068}
1069
1070async fn handle_push_config_get(
1071 state: AppState,
1072 req: JsonRpcRequest,
1073) -> (StatusCode, Json<JsonRpcResponse>) {
1074 if !state.push_notifications_enabled() {
1075 return (
1076 StatusCode::BAD_REQUEST,
1077 Json(error(
1078 req.id,
1079 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1080 "push notifications not supported",
1081 None,
1082 )),
1083 );
1084 }
1085
1086 let params: Result<GetTaskPushNotificationConfigRequest, _> =
1087 serde_json::from_value(req.params.unwrap_or_default());
1088
1089 match params {
1090 Ok(p) => {
1091 match state.webhook_store.get(&p.task_id, &p.id).await {
1092 Some(config) => (
1093 StatusCode::OK,
1094 Json(success(
1095 req.id,
1096 serde_json::json!({
1097 "configId": p.id,
1098 "config": config
1099 }),
1100 )),
1101 ),
1102 None => (
1103 StatusCode::NOT_FOUND,
1104 Json(error(req.id, errors::TASK_NOT_FOUND, "push notification config not found", None)),
1105 ),
1106 }
1107 }
1108 Err(err) => (
1109 StatusCode::BAD_REQUEST,
1110 Json(error(
1111 req.id,
1112 errors::INVALID_PARAMS,
1113 "invalid params",
1114 Some(serde_json::json!({"error": err.to_string()})),
1115 )),
1116 ),
1117 }
1118}
1119
1120async fn handle_push_config_list(
1121 state: AppState,
1122 req: JsonRpcRequest,
1123) -> (StatusCode, Json<JsonRpcResponse>) {
1124 if !state.push_notifications_enabled() {
1125 return (
1126 StatusCode::BAD_REQUEST,
1127 Json(error(
1128 req.id,
1129 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1130 "push notifications not supported",
1131 None,
1132 )),
1133 );
1134 }
1135
1136 let params: Result<ListTaskPushNotificationConfigRequest, _> =
1137 serde_json::from_value(req.params.unwrap_or_default());
1138
1139 match params {
1140 Ok(p) => {
1141 let configs = state.webhook_store.list(&p.task_id).await;
1142
1143 let configs_json: Vec<_> = configs
1144 .iter()
1145 .map(|c| {
1146 serde_json::json!({
1147 "configId": c.config_id,
1148 "config": c.config
1149 })
1150 })
1151 .collect();
1152
1153 (
1154 StatusCode::OK,
1155 Json(success(
1156 req.id,
1157 serde_json::json!({
1158 "configs": configs_json,
1159 "nextPageToken": ""
1160 }),
1161 )),
1162 )
1163 }
1164 Err(err) => (
1165 StatusCode::BAD_REQUEST,
1166 Json(error(
1167 req.id,
1168 errors::INVALID_PARAMS,
1169 "invalid params",
1170 Some(serde_json::json!({"error": err.to_string()})),
1171 )),
1172 ),
1173 }
1174}
1175
1176async fn handle_push_config_delete(
1177 state: AppState,
1178 req: JsonRpcRequest,
1179) -> (StatusCode, Json<JsonRpcResponse>) {
1180 if !state.push_notifications_enabled() {
1181 return (
1182 StatusCode::BAD_REQUEST,
1183 Json(error(
1184 req.id,
1185 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1186 "push notifications not supported",
1187 None,
1188 )),
1189 );
1190 }
1191
1192 let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1193 serde_json::from_value(req.params.unwrap_or_default());
1194
1195 match params {
1196 Ok(p) => {
1197 if state.webhook_store.delete(&p.task_id, &p.id).await {
1198 (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1199 } else {
1200 (
1201 StatusCode::NOT_FOUND,
1202 Json(error(req.id, errors::TASK_NOT_FOUND, "push notification config not found", None)),
1203 )
1204 }
1205 }
1206 Err(err) => (
1207 StatusCode::BAD_REQUEST,
1208 Json(error(
1209 req.id,
1210 errors::INVALID_PARAMS,
1211 "invalid params",
1212 Some(serde_json::json!({"error": err.to_string()})),
1213 )),
1214 ),
1215 }
1216}
1217
1218pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1219 bind_addr: &str,
1220 handler: H,
1221) -> anyhow::Result<()> {
1222 A2aServer::new(handler).bind(bind_addr)?.run().await
1223}
1224
1225pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1226 A2aServer::echo().bind(bind_addr)?.run().await
1227}