1use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `message/send` call waits for its task to reach a
/// terminal state before answering with whatever state the task store currently holds.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(300);
/// How often a blocking `message/send` call re-reads the task store while it waits.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17 error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18 CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19 GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20 ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21 SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22 TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Name of the HTTP request header inspected by `validate_a2a_version`.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the only protocol version `validate_a2a_version` accepts.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the only protocol version `validate_a2a_version` accepts.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::response::{IntoResponse, Response};
33use axum::routing::{get, post};
34use axum::{Json, Router};
35use futures::future::FutureExt;
36use futures::stream::Stream;
37use tokio::sync::broadcast;
38use tracing::info;
39
40use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
41use crate::task_store::TaskStore;
42use crate::webhook_delivery::WebhookDelivery;
43use crate::webhook_store::WebhookStore;
44
/// Settings for an `A2aServer`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address the server listens on, kept as text. It is parsed into a `SocketAddr` by
    /// `A2aServer::build_router` and `A2aServer::run`.
    pub bind_address: String,
}
49
50impl Default for ServerConfig {
51 fn default() -> Self {
52 Self {
53 bind_address: "0.0.0.0:8080".to_string(),
54 }
55 }
56}
57
/// Shared callback that derives an optional `AuthContext` from the request headers.
///
/// `handle_rpc` calls it once per request; a `None` result is passed on to the handlers as
/// "no authentication context".
pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
59
/// Capacity passed to `broadcast::channel` when the server's event channel is created.
const EVENT_CHANNEL_CAPACITY: usize = 1024;
61
/// Builder-style description of an A2A server; consumed by `build_router` or `run`.
pub struct A2aServer {
    /// Listening configuration.
    config: ServerConfig,
    /// Message handler that produces the agent card and processes incoming messages.
    handler: BoxedHandler,
    /// Store of tasks returned by the handler.
    task_store: TaskStore,
    /// Store of per-task push-notification (webhook) configurations.
    webhook_store: WebhookStore,
    /// Optional callback turning request headers into an `AuthContext`.
    auth_extractor: Option<AuthExtractor>,
    /// Extra routes merged into the router built by `build_router`.
    additional_routes: Option<Router<AppState>>,
    /// Sending half of the broadcast channel carrying task events.
    event_tx: broadcast::Sender<StreamResponse>,
}
71
72impl A2aServer {
73 pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
74 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
75 Self {
76 config: ServerConfig::default(),
77 handler: Arc::new(handler),
78 task_store: TaskStore::new(),
79 webhook_store: WebhookStore::new(),
80 auth_extractor: None,
81 additional_routes: None,
82 event_tx,
83 }
84 }
85
86 pub fn echo() -> Self {
87 Self::new(EchoHandler::default())
88 }
89
90 pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
91 let _: SocketAddr = address.parse()?;
92 self.config.bind_address = address.to_string();
93 Ok(self)
94 }
95
96 pub fn bind_unchecked(mut self, address: &str) -> Self {
97 self.config.bind_address = address.to_string();
98 self
99 }
100
101 pub fn task_store(mut self, store: TaskStore) -> Self {
102 self.task_store = store;
103 self
104 }
105
106 pub fn auth_extractor<F>(mut self, extractor: F) -> Self
107 where
108 F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
109 {
110 self.auth_extractor = Some(Arc::new(extractor));
111 self
112 }
113
114 pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
115 self.additional_routes = Some(routes);
116 self
117 }
118
119 pub fn get_task_store(&self) -> TaskStore {
120 self.task_store.clone()
121 }
122
123 pub fn build_router(self) -> Router {
124 let bind: SocketAddr = self
125 .config
126 .bind_address
127 .parse()
128 .expect("Invalid bind address");
129 let base_url = format!("http://{}", bind);
130 let card = Arc::new(self.handler.agent_card(&base_url));
131
132 let state = AppState {
133 handler: self.handler,
134 task_store: self.task_store,
135 webhook_store: self.webhook_store,
136 card,
137 auth_extractor: self.auth_extractor,
138 event_tx: self.event_tx,
139 };
140
141 let timed_routes = Router::new()
142 .route("/health", get(health))
143 .route("/.well-known/agent-card.json", get(agent_card))
144 .route("/v1/rpc", post(handle_rpc))
145 .layer(
146 ServiceBuilder::new()
147 .layer(HandleErrorLayer::new(handle_timeout_error))
148 .timeout(Duration::from_secs(30)),
149 );
150
151 let sse_routes = Router::new().route(
152 "/v1/tasks/:task_id/subscribe",
153 get(handle_task_subscribe_sse),
154 );
155
156 let mut router = timed_routes.merge(sse_routes);
157
158 if let Some(additional) = self.additional_routes {
159 router = router.merge(additional);
160 }
161
162 router.with_state(state)
163 }
164
165 pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
166 self.event_tx.clone()
167 }
168
169 pub async fn run(self) -> anyhow::Result<()> {
170 let bind: SocketAddr = self.config.bind_address.parse()?;
171
172 let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
173 webhook_delivery.start(self.event_tx.subscribe());
174
175 let router = self.build_router();
176
177 info!(%bind, "Starting A2A server");
178 let listener = tokio::net::TcpListener::bind(bind).await?;
179 axum::serve(listener, router)
180 .with_graceful_shutdown(shutdown_signal())
181 .await?;
182
183 info!("Server shutdown complete");
184 Ok(())
185 }
186}
187
188async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
189 if err.is::<tower::timeout::error::Elapsed>() {
190 (
191 StatusCode::REQUEST_TIMEOUT,
192 Json(error(
193 serde_json::Value::Null,
194 errors::INTERNAL_ERROR,
195 "Request timed out",
196 None,
197 )),
198 )
199 } else {
200 (
201 StatusCode::INTERNAL_SERVER_ERROR,
202 Json(error(
203 serde_json::Value::Null,
204 errors::INTERNAL_ERROR,
205 &format!("Internal error: {}", err),
206 None,
207 )),
208 )
209 }
210}
211
212async fn shutdown_signal() {
213 let ctrl_c = async {
214 signal::ctrl_c()
215 .await
216 .expect("failed to install Ctrl+C handler");
217 };
218
219 #[cfg(unix)]
220 let terminate = async {
221 signal::unix::signal(signal::unix::SignalKind::terminate())
222 .expect("failed to install SIGTERM handler")
223 .recv()
224 .await;
225 };
226
227 #[cfg(not(unix))]
228 let terminate = std::future::pending::<()>();
229
230 tokio::select! {
231 _ = ctrl_c => {},
232 _ = terminate => {},
233 }
234
235 info!("Shutdown signal received, initiating graceful shutdown...");
236}
237
/// Shared state handed to every route handler.
#[derive(Clone)]
pub struct AppState {
    /// Message handler that processes incoming messages and cancellations.
    handler: BoxedHandler,
    /// Store of known tasks.
    task_store: TaskStore,
    /// Store of per-task push-notification configurations.
    webhook_store: WebhookStore,
    /// Agent card computed once when the router is built.
    card: Arc<AgentCard>,
    /// Optional callback turning request headers into an `AuthContext`.
    auth_extractor: Option<AuthExtractor>,
    /// Sending half of the broadcast channel carrying task events.
    event_tx: broadcast::Sender<StreamResponse>,
}
247
248impl AppState {
249 pub fn task_store(&self) -> &TaskStore {
250 &self.task_store
251 }
252
253 pub fn agent_card(&self) -> &AgentCard {
254 &self.card
255 }
256
257 pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
258 &self.event_tx
259 }
260
261 pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
262 self.event_tx.subscribe()
263 }
264
265 pub fn broadcast_event(&self, event: StreamResponse) {
266 let _ = self.event_tx.send(event);
267 }
268
269 fn streaming_enabled(&self) -> bool {
271 self.card.capabilities.streaming.unwrap_or(false)
272 }
273
274 fn push_notifications_enabled(&self) -> bool {
276 self.card.capabilities.push_notifications.unwrap_or(false)
277 }
278
279 fn endpoint_url(&self) -> &str {
281 self.card.endpoint().unwrap_or("")
282 }
283}
284
285fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
288 match history_length {
289 Some(0) => {
290 task.history = None;
291 }
292 Some(n) => {
293 if let Some(ref mut history) = task.history {
294 let len = history.len();
295 if len > n as usize {
296 *history = history.split_off(len - n as usize);
297 }
298 }
299 }
300 None => {}
301 }
302}
303
304fn validate_a2a_version(
307 headers: &HeaderMap,
308 req_id: &serde_json::Value,
309) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
310 if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
311 let version_str = version_header.to_str().unwrap_or("");
312
313 let parts: Vec<&str> = version_str.split('.').collect();
314 if parts.len() >= 2 {
315 if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
316 if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
317 return Ok(());
318 }
319
320 return Err((
321 StatusCode::BAD_REQUEST,
322 Json(error(
323 req_id.clone(),
324 errors::VERSION_NOT_SUPPORTED,
325 &format!(
326 "Protocol version {}.{} not supported. Supported version: {}.{}",
327 major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
328 ),
329 Some(serde_json::json!({
330 "requestedVersion": version_str,
331 "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
332 })),
333 )),
334 ));
335 }
336 }
337
338 if version_str == "1" || version_str == "1.0" {
340 return Ok(());
341 }
342
343 return Err((
344 StatusCode::BAD_REQUEST,
345 Json(error(
346 req_id.clone(),
347 errors::VERSION_NOT_SUPPORTED,
348 &format!(
349 "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
350 version_str
351 ),
352 None,
353 )),
354 ));
355 }
356
357 Ok(())
358}
359
360#[allow(dead_code)]
363pub fn rpc_error(
364 id: serde_json::Value,
365 code: i32,
366 message: &str,
367 status: StatusCode,
368) -> (StatusCode, Json<JsonRpcResponse>) {
369 (status, Json(error(id, code, message, None)))
370}
371
372#[allow(dead_code)]
373pub fn rpc_error_with_data(
374 id: serde_json::Value,
375 code: i32,
376 message: &str,
377 data: serde_json::Value,
378 status: StatusCode,
379) -> (StatusCode, Json<JsonRpcResponse>) {
380 (status, Json(error(id, code, message, Some(data))))
381}
382
383#[allow(dead_code)]
384pub fn rpc_success(
385 id: serde_json::Value,
386 result: serde_json::Value,
387) -> (StatusCode, Json<JsonRpcResponse>) {
388 (StatusCode::OK, Json(success(id, result)))
389}
390
391async fn health() -> Json<serde_json::Value> {
394 Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
395}
396
397async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
398 Json((*state.card).clone())
399}
400
401async fn handle_task_subscribe_sse(
402 State(state): State<AppState>,
403 axum::extract::Path(task_id): axum::extract::Path<String>,
404) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
405 let mut rx = state.subscribe_events();
406 let task_store = state.task_store.clone();
407 let target_task_id = task_id;
408
409 let stream = async_stream::stream! {
410 if let Some(task) = task_store.get_flexible(&target_task_id).await {
411 if let Ok(json) = serde_json::to_string(&task) {
412 yield Ok(Event::default().data(json));
413 }
414 }
415
416 loop {
417 match rx.recv().await {
418 Ok(event) => {
419 let matches = match &event {
420 StreamResponse::Task(t) => t.id == target_task_id,
421 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
422 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
423 StreamResponse::Message(_) => false,
424 };
425 if matches {
426 let json = match &event {
428 StreamResponse::Task(t) => serde_json::to_string(t),
429 StreamResponse::Message(m) => serde_json::to_string(m),
430 StreamResponse::StatusUpdate(e) => serde_json::to_string(e),
431 StreamResponse::ArtifactUpdate(e) => serde_json::to_string(e),
432 };
433 if let Ok(json) = json {
434 yield Ok(Event::default().data(json));
435 }
436
437 let is_terminal = match &event {
438 StreamResponse::Task(t) => t.status.state.is_terminal(),
439 StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
440 _ => false,
441 };
442 if is_terminal {
443 break;
444 }
445 }
446 }
447 Err(broadcast::error::RecvError::Lagged(_)) => continue,
448 Err(broadcast::error::RecvError::Closed) => break,
449 }
450 }
451 };
452
453 Sse::new(stream).keep_alive(KeepAlive::default())
454}
455
456async fn handle_rpc(
457 State(state): State<AppState>,
458 headers: HeaderMap,
459 Json(req): Json<JsonRpcRequest>,
460) -> Response {
461 if req.jsonrpc != "2.0" {
462 let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
463 return (StatusCode::BAD_REQUEST, Json(resp)).into_response();
464 }
465
466 if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
467 return err_response.into_response();
468 }
469
470 let auth_context = state
471 .auth_extractor
472 .as_ref()
473 .and_then(|extractor| extractor(&headers));
474
475 match req.method.as_str() {
476 "message/send" => handle_message_send(state, req, auth_context)
478 .await
479 .into_response(),
480 "message/stream" => {
481 handle_message_stream(state, req, headers, auth_context)
482 .await
483 .into_response()
484 }
485 "tasks/get" => handle_tasks_get(state, req).await.into_response(),
486 "tasks/list" => handle_tasks_list(state, req).await.into_response(),
487 "tasks/cancel" => handle_tasks_cancel(state, req).await.into_response(),
488 "tasks/subscribe" => handle_tasks_subscribe(state, req).await.into_response(),
489 "tasks/pushNotificationConfig/create" => {
490 handle_push_config_create(state, req).await.into_response()
491 }
492 "tasks/pushNotificationConfig/get" => {
493 handle_push_config_get(state, req).await.into_response()
494 }
495 "tasks/pushNotificationConfig/list" => {
496 handle_push_config_list(state, req).await.into_response()
497 }
498 "tasks/pushNotificationConfig/delete" => {
499 handle_push_config_delete(state, req).await.into_response()
500 }
501 "agentCard/getExtended" => handle_get_extended_agent_card(state, req, auth_context)
502 .await
503 .into_response(),
504 _ => (
505 StatusCode::NOT_FOUND,
506 Json(error(
507 req.id,
508 errors::METHOD_NOT_FOUND,
509 "method not found",
510 None,
511 )),
512 )
513 .into_response(),
514 }
515}
516
517fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
518 match e {
519 HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
520 HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
521 HandlerError::BackendUnavailable { .. } => {
522 (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
523 }
524 HandlerError::ProcessingFailed { .. } => {
525 (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
526 }
527 HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
528 }
529}
530
531async fn handle_message_send(
532 state: AppState,
533 req: JsonRpcRequest,
534 auth_context: Option<AuthContext>,
535) -> (StatusCode, Json<JsonRpcResponse>) {
536 let req_id = req.id.clone();
537
538 let params: Result<SendMessageRequest, _> =
539 serde_json::from_value(req.params.clone().unwrap_or_default());
540
541 let params = match params {
542 Ok(p) => p,
543 Err(err) => {
544 return (
545 StatusCode::BAD_REQUEST,
546 Json(error(
547 req_id,
548 errors::INVALID_PARAMS,
549 "invalid params",
550 Some(serde_json::json!({"error": err.to_string()})),
551 )),
552 );
553 }
554 };
555
556 let blocking = params
557 .configuration
558 .as_ref()
559 .and_then(|c| c.blocking)
560 .unwrap_or(false);
561 let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
562
563 match state
564 .handler
565 .handle_message(params.message, auth_context)
566 .await
567 {
568 Ok(response) => {
569 match response {
570 SendMessageResponse::Task(mut task) => {
571 state.task_store.insert(task.clone()).await;
573 state.broadcast_event(StreamResponse::Task(task.clone()));
574
575 if blocking && !task.status.state.is_terminal() {
577 let task_id = task.id.clone();
578 let mut rx = state.subscribe_events();
579
580 let wait_result = timeout(BLOCKING_TIMEOUT, async {
581 loop {
582 tokio::select! {
583 result = rx.recv() => {
584 match result {
585 Ok(StreamResponse::Task(t)) if t.id == task_id => {
586 if t.status.state.is_terminal() {
587 return Some(t);
588 }
589 }
590 Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
591 if e.status.state.is_terminal() {
592 if let Some(t) = state.task_store.get(&task_id).await {
593 return Some(t);
594 }
595 }
596 }
597 Err(broadcast::error::RecvError::Closed) => {
598 return None;
599 }
600 _ => {}
601 }
602 }
603 _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
604 if let Some(t) = state.task_store.get(&task_id).await {
605 if t.status.state.is_terminal() {
606 return Some(t);
607 }
608 }
609 }
610 }
611 }
612 })
613 .await;
614
615 match wait_result {
616 Ok(Some(final_task)) => task = final_task,
617 Ok(None) => {
618 if let Some(t) = state.task_store.get(&task.id).await {
619 task = t;
620 }
621 }
622 Err(_) => {
623 tracing::warn!("Blocking request timed out for task {}", task.id);
624 if let Some(t) = state.task_store.get(&task.id).await {
625 task = t;
626 }
627 }
628 }
629 }
630
631 apply_history_length(&mut task, history_length);
632
633 match serde_json::to_value(&task) {
636 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
637 Err(e) => (
638 StatusCode::INTERNAL_SERVER_ERROR,
639 Json(error(
640 req_id,
641 errors::INTERNAL_ERROR,
642 "serialization failed",
643 Some(serde_json::json!({"error": e.to_string()})),
644 )),
645 ),
646 }
647 }
648 SendMessageResponse::Message(msg) => {
649 match serde_json::to_value(&msg) {
651 Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
652 Err(e) => (
653 StatusCode::INTERNAL_SERVER_ERROR,
654 Json(error(
655 req_id,
656 errors::INTERNAL_ERROR,
657 "serialization failed",
658 Some(serde_json::json!({"error": e.to_string()})),
659 )),
660 ),
661 }
662 }
663 }
664 }
665 Err(e) => {
666 let (code, status) = handler_error_to_rpc(&e);
667 (status, Json(error(req_id, code, &e.to_string(), None)))
668 }
669 }
670}
671
672async fn handle_message_stream(
677 state: AppState,
678 req: JsonRpcRequest,
679 _headers: HeaderMap,
680 auth_context: Option<AuthContext>,
681) -> Response {
682 let req_id = req.id.clone();
683
684 if !state.streaming_enabled() {
685 return (
686 StatusCode::BAD_REQUEST,
687 Json(error(
688 req_id,
689 errors::UNSUPPORTED_OPERATION,
690 "streaming not supported by this agent",
691 None,
692 )),
693 )
694 .into_response();
695 }
696
697 let params: Result<SendMessageRequest, _> =
698 serde_json::from_value(req.params.clone().unwrap_or_default());
699
700 let params = match params {
701 Ok(p) => p,
702 Err(err) => {
703 return (
704 StatusCode::BAD_REQUEST,
705 Json(error(
706 req_id,
707 errors::INVALID_PARAMS,
708 "invalid params",
709 Some(serde_json::json!({"error": err.to_string()})),
710 )),
711 )
712 .into_response();
713 }
714 };
715
716 let response = match state
717 .handler
718 .handle_message(params.message, auth_context)
719 .await
720 {
721 Ok(r) => r,
722 Err(e) => {
723 let (code, status) = handler_error_to_rpc(&e);
724 return (status, Json(error(req_id, code, &e.to_string(), None))).into_response();
725 }
726 };
727
728 let task = match response {
730 SendMessageResponse::Task(t) => t,
731 SendMessageResponse::Message(_) => {
732 return (
733 StatusCode::BAD_REQUEST,
734 Json(error(
735 req_id,
736 errors::UNSUPPORTED_OPERATION,
737 "handler returned a message, streaming requires a task",
738 None,
739 )),
740 )
741 .into_response();
742 }
743 };
744
745 let task_id = task.id.clone();
746 state.task_store.insert(task.clone()).await;
747 state.broadcast_event(StreamResponse::Task(task.clone()));
748
749 let mut rx = state.subscribe_events();
750 let task_store = state.task_store.clone();
751 let target_task_id = task_id;
752
753 let wrap = move |value: serde_json::Value| -> String {
755 serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
756 };
757
758 let stream = async_stream::stream! {
759 if let Ok(val) = serde_json::to_value(&task) {
761 yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
762 }
763
764 loop {
765 match rx.recv().await {
766 Ok(event) => {
767 let matches = match &event {
768 StreamResponse::Task(t) => t.id == target_task_id,
769 StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
770 StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
771 StreamResponse::Message(m) => {
772 m.context_id.as_ref().is_some_and(|ctx| {
773 task_store.get(&target_task_id).now_or_never()
774 .flatten()
775 .is_some_and(|t| t.context_id == *ctx)
776 })
777 }
778 };
779
780 if matches {
781 let val = match &event {
783 StreamResponse::Task(t) => serde_json::to_value(t),
784 StreamResponse::Message(m) => serde_json::to_value(m),
785 StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
786 StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
787 };
788 if let Ok(val) = val {
789 yield Ok(Event::default().data(wrap(val)));
790 }
791
792 let is_terminal = match &event {
794 StreamResponse::Task(t) => t.status.state.is_terminal(),
795 StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
796 _ => false,
797 };
798 if is_terminal {
799 break;
800 }
801 }
802 }
803 Err(broadcast::error::RecvError::Lagged(_)) => continue,
804 Err(broadcast::error::RecvError::Closed) => break,
805 }
806 }
807 };
808
809 Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
810}
811
812async fn handle_tasks_get(
813 state: AppState,
814 req: JsonRpcRequest,
815) -> (StatusCode, Json<JsonRpcResponse>) {
816 let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
817
818 match params {
819 Ok(p) => match state.task_store.get_flexible(&p.id).await {
820 Some(mut task) => {
821 apply_history_length(&mut task, p.history_length);
822
823 match serde_json::to_value(task) {
824 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
825 Err(e) => (
826 StatusCode::INTERNAL_SERVER_ERROR,
827 Json(error(
828 req.id,
829 errors::INTERNAL_ERROR,
830 "serialization failed",
831 Some(serde_json::json!({"error": e.to_string()})),
832 )),
833 ),
834 }
835 }
836 None => (
837 StatusCode::NOT_FOUND,
838 Json(error(
839 req.id,
840 errors::TASK_NOT_FOUND,
841 "task not found",
842 None,
843 )),
844 ),
845 },
846 Err(err) => (
847 StatusCode::BAD_REQUEST,
848 Json(error(
849 req.id,
850 errors::INVALID_PARAMS,
851 "invalid params",
852 Some(serde_json::json!({"error": err.to_string()})),
853 )),
854 ),
855 }
856}
857
858async fn handle_tasks_cancel(
859 state: AppState,
860 req: JsonRpcRequest,
861) -> (StatusCode, Json<JsonRpcResponse>) {
862 let params: Result<CancelTaskRequest, _> =
863 serde_json::from_value(req.params.unwrap_or_default());
864
865 match params {
866 Ok(p) => {
867 let result = state
868 .task_store
869 .update_flexible(&p.id, |task| {
870 if task.status.state.is_terminal() {
871 return Err(errors::TASK_NOT_CANCELABLE);
872 }
873 task.status.state = TaskState::Canceled;
874 task.status.timestamp = Some(now_iso8601());
875 Ok(())
876 })
877 .await;
878
879 match result {
880 Some(Ok(task)) => {
881 if let Err(e) = state.handler.cancel_task(&task.id).await {
882 tracing::warn!("Handler cancel_task failed: {}", e);
883 }
884
885 state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
886 kind: "status-update".to_string(),
887 task_id: task.id.clone(),
888 context_id: task.context_id.clone(),
889 status: task.status.clone(),
890 is_final: true,
891 metadata: None,
892 }));
893
894 match serde_json::to_value(task) {
895 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
896 Err(e) => (
897 StatusCode::INTERNAL_SERVER_ERROR,
898 Json(error(
899 req.id,
900 errors::INTERNAL_ERROR,
901 "serialization failed",
902 Some(serde_json::json!({"error": e.to_string()})),
903 )),
904 ),
905 }
906 }
907 Some(Err(error_code)) => (
908 StatusCode::BAD_REQUEST,
909 Json(error(req.id, error_code, "task not cancelable", None)),
910 ),
911 None => (
912 StatusCode::NOT_FOUND,
913 Json(error(
914 req.id,
915 errors::TASK_NOT_FOUND,
916 "task not found",
917 None,
918 )),
919 ),
920 }
921 }
922 Err(err) => (
923 StatusCode::BAD_REQUEST,
924 Json(error(
925 req.id,
926 errors::INVALID_PARAMS,
927 "invalid params",
928 Some(serde_json::json!({"error": err.to_string()})),
929 )),
930 ),
931 }
932}
933
934async fn handle_tasks_list(
935 state: AppState,
936 req: JsonRpcRequest,
937) -> (StatusCode, Json<JsonRpcResponse>) {
938 let params: Result<ListTasksRequest, _> =
939 serde_json::from_value(req.params.unwrap_or_default());
940
941 match params {
942 Ok(p) => {
943 let response = state.task_store.list_filtered(&p).await;
944 match serde_json::to_value(response) {
945 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
946 Err(e) => (
947 StatusCode::INTERNAL_SERVER_ERROR,
948 Json(error(
949 req.id,
950 errors::INTERNAL_ERROR,
951 "serialization failed",
952 Some(serde_json::json!({"error": e.to_string()})),
953 )),
954 ),
955 }
956 }
957 Err(err) => (
958 StatusCode::BAD_REQUEST,
959 Json(error(
960 req.id,
961 errors::INVALID_PARAMS,
962 "invalid params",
963 Some(serde_json::json!({"error": err.to_string()})),
964 )),
965 ),
966 }
967}
968
969async fn handle_tasks_subscribe(
970 state: AppState,
971 req: JsonRpcRequest,
972) -> (StatusCode, Json<JsonRpcResponse>) {
973 let params: Result<SubscribeToTaskRequest, _> =
974 serde_json::from_value(req.params.unwrap_or_default());
975
976 match params {
977 Ok(p) => {
978 if state.task_store.get_flexible(&p.id).await.is_none() {
979 return (
980 StatusCode::NOT_FOUND,
981 Json(error(
982 req.id,
983 errors::TASK_NOT_FOUND,
984 "task not found",
985 None,
986 )),
987 );
988 }
989
990 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
991 let subscribe_url = format!("{}/v1/tasks/{}/subscribe", base_url, p.id);
992
993 (
994 StatusCode::OK,
995 Json(success(
996 req.id,
997 serde_json::json!({
998 "subscribeUrl": subscribe_url,
999 "protocol": "sse"
1000 }),
1001 )),
1002 )
1003 }
1004 Err(err) => (
1005 StatusCode::BAD_REQUEST,
1006 Json(error(
1007 req.id,
1008 errors::INVALID_PARAMS,
1009 "invalid params",
1010 Some(serde_json::json!({"error": err.to_string()})),
1011 )),
1012 ),
1013 }
1014}
1015
1016async fn handle_get_extended_agent_card(
1017 state: AppState,
1018 req: JsonRpcRequest,
1019 auth_context: Option<AuthContext>,
1020) -> (StatusCode, Json<JsonRpcResponse>) {
1021 let Some(auth) = auth_context else {
1022 return (
1023 StatusCode::UNAUTHORIZED,
1024 Json(error(
1025 req.id,
1026 errors::INVALID_REQUEST,
1027 "authentication required for extended agent card",
1028 None,
1029 )),
1030 );
1031 };
1032
1033 let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1034
1035 match state.handler.extended_agent_card(base_url, &auth).await {
1036 Some(card) => match serde_json::to_value(card) {
1037 Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1038 Err(e) => (
1039 StatusCode::INTERNAL_SERVER_ERROR,
1040 Json(error(
1041 req.id,
1042 errors::INTERNAL_ERROR,
1043 "serialization failed",
1044 Some(serde_json::json!({"error": e.to_string()})),
1045 )),
1046 ),
1047 },
1048 None => (
1049 StatusCode::NOT_FOUND,
1050 Json(error(
1051 req.id,
1052 errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1053 "extended agent card not configured",
1054 None,
1055 )),
1056 ),
1057 }
1058}
1059
1060async fn handle_push_config_create(
1063 state: AppState,
1064 req: JsonRpcRequest,
1065) -> (StatusCode, Json<JsonRpcResponse>) {
1066 if !state.push_notifications_enabled() {
1067 return (
1068 StatusCode::BAD_REQUEST,
1069 Json(error(
1070 req.id,
1071 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1072 "push notifications not supported",
1073 None,
1074 )),
1075 );
1076 }
1077
1078 let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1079 serde_json::from_value(req.params.unwrap_or_default());
1080
1081 match params {
1082 Ok(p) => {
1083 if state.task_store.get_flexible(&p.task_id).await.is_none() {
1084 return (
1085 StatusCode::NOT_FOUND,
1086 Json(error(
1087 req.id,
1088 errors::TASK_NOT_FOUND,
1089 "task not found",
1090 None,
1091 )),
1092 );
1093 }
1094
1095 if let Err(e) = state
1096 .webhook_store
1097 .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1098 .await
1099 {
1100 return (
1101 StatusCode::BAD_REQUEST,
1102 Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1103 );
1104 }
1105
1106 (
1107 StatusCode::OK,
1108 Json(success(
1109 req.id,
1110 serde_json::json!({
1111 "configId": p.config_id,
1112 "config": p.push_notification_config
1113 }),
1114 )),
1115 )
1116 }
1117 Err(err) => (
1118 StatusCode::BAD_REQUEST,
1119 Json(error(
1120 req.id,
1121 errors::INVALID_PARAMS,
1122 "invalid params",
1123 Some(serde_json::json!({"error": err.to_string()})),
1124 )),
1125 ),
1126 }
1127}
1128
1129async fn handle_push_config_get(
1130 state: AppState,
1131 req: JsonRpcRequest,
1132) -> (StatusCode, Json<JsonRpcResponse>) {
1133 if !state.push_notifications_enabled() {
1134 return (
1135 StatusCode::BAD_REQUEST,
1136 Json(error(
1137 req.id,
1138 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1139 "push notifications not supported",
1140 None,
1141 )),
1142 );
1143 }
1144
1145 let params: Result<GetTaskPushNotificationConfigRequest, _> =
1146 serde_json::from_value(req.params.unwrap_or_default());
1147
1148 match params {
1149 Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1150 Some(config) => (
1151 StatusCode::OK,
1152 Json(success(
1153 req.id,
1154 serde_json::json!({
1155 "configId": p.id,
1156 "config": config
1157 }),
1158 )),
1159 ),
1160 None => (
1161 StatusCode::NOT_FOUND,
1162 Json(error(
1163 req.id,
1164 errors::TASK_NOT_FOUND,
1165 "push notification config not found",
1166 None,
1167 )),
1168 ),
1169 },
1170 Err(err) => (
1171 StatusCode::BAD_REQUEST,
1172 Json(error(
1173 req.id,
1174 errors::INVALID_PARAMS,
1175 "invalid params",
1176 Some(serde_json::json!({"error": err.to_string()})),
1177 )),
1178 ),
1179 }
1180}
1181
1182async fn handle_push_config_list(
1183 state: AppState,
1184 req: JsonRpcRequest,
1185) -> (StatusCode, Json<JsonRpcResponse>) {
1186 if !state.push_notifications_enabled() {
1187 return (
1188 StatusCode::BAD_REQUEST,
1189 Json(error(
1190 req.id,
1191 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1192 "push notifications not supported",
1193 None,
1194 )),
1195 );
1196 }
1197
1198 let params: Result<ListTaskPushNotificationConfigRequest, _> =
1199 serde_json::from_value(req.params.unwrap_or_default());
1200
1201 match params {
1202 Ok(p) => {
1203 let configs = state.webhook_store.list(&p.task_id).await;
1204
1205 let configs_json: Vec<_> = configs
1206 .iter()
1207 .map(|c| {
1208 serde_json::json!({
1209 "configId": c.config_id,
1210 "config": c.config
1211 })
1212 })
1213 .collect();
1214
1215 (
1216 StatusCode::OK,
1217 Json(success(
1218 req.id,
1219 serde_json::json!({
1220 "configs": configs_json,
1221 "nextPageToken": ""
1222 }),
1223 )),
1224 )
1225 }
1226 Err(err) => (
1227 StatusCode::BAD_REQUEST,
1228 Json(error(
1229 req.id,
1230 errors::INVALID_PARAMS,
1231 "invalid params",
1232 Some(serde_json::json!({"error": err.to_string()})),
1233 )),
1234 ),
1235 }
1236}
1237
1238async fn handle_push_config_delete(
1239 state: AppState,
1240 req: JsonRpcRequest,
1241) -> (StatusCode, Json<JsonRpcResponse>) {
1242 if !state.push_notifications_enabled() {
1243 return (
1244 StatusCode::BAD_REQUEST,
1245 Json(error(
1246 req.id,
1247 errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1248 "push notifications not supported",
1249 None,
1250 )),
1251 );
1252 }
1253
1254 let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1255 serde_json::from_value(req.params.unwrap_or_default());
1256
1257 match params {
1258 Ok(p) => {
1259 if state.webhook_store.delete(&p.task_id, &p.id).await {
1260 (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1261 } else {
1262 (
1263 StatusCode::NOT_FOUND,
1264 Json(error(
1265 req.id,
1266 errors::TASK_NOT_FOUND,
1267 "push notification config not found",
1268 None,
1269 )),
1270 )
1271 }
1272 }
1273 Err(err) => (
1274 StatusCode::BAD_REQUEST,
1275 Json(error(
1276 req.id,
1277 errors::INVALID_PARAMS,
1278 "invalid params",
1279 Some(serde_json::json!({"error": err.to_string()})),
1280 )),
1281 ),
1282 }
1283}
1284
1285pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1286 bind_addr: &str,
1287 handler: H,
1288) -> anyhow::Result<()> {
1289 A2aServer::new(handler).bind(bind_addr)?.run().await
1290}
1291
1292pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1293 A2aServer::echo().bind(bind_addr)?.run().await
1294}