Skip to main content

a2a_rs_server/
server.rs

1//! Generic A2A JSON-RPC Server
2
3use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `message/send` waits for its task to finish.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// How often a blocking `message/send` re-reads the task store while it waits.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17    error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18    CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19    GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20    ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21    SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22    TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header in which clients announce the A2A protocol version they speak.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the single protocol version accepted in `A2A-Version`.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the single protocol version accepted in `A2A-Version`.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::response::{IntoResponse, Response};
33use axum::routing::{get, post};
34use axum::{Json, Router};
35use futures::future::FutureExt;
36use futures::stream::Stream;
37use tokio::sync::broadcast;
38use tracing::info;
39
40use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
41use crate::task_store::TaskStore;
42use crate::webhook_delivery::WebhookDelivery;
43use crate::webhook_store::WebhookStore;
44
/// Network configuration for the A2A server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address the server listens on, e.g. `0.0.0.0:8080`.
    pub bind_address: String,
}

impl Default for ServerConfig {
    /// Listen on every IPv4 interface, port 8080.
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        ServerConfig { bind_address }
    }
}
57
58pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
59
/// Capacity of the task-event broadcast channel; receivers that fall further behind
/// than this observe `RecvError::Lagged` (tokio broadcast semantics).
const EVENT_CHANNEL_CAPACITY: usize = 1 << 10;
61
62pub struct A2aServer {
63    config: ServerConfig,
64    handler: BoxedHandler,
65    task_store: TaskStore,
66    webhook_store: WebhookStore,
67    auth_extractor: Option<AuthExtractor>,
68    additional_routes: Option<Router<AppState>>,
69    event_tx: broadcast::Sender<StreamResponse>,
70}
71
72impl A2aServer {
73    pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
74        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
75        Self {
76            config: ServerConfig::default(),
77            handler: Arc::new(handler),
78            task_store: TaskStore::new(),
79            webhook_store: WebhookStore::new(),
80            auth_extractor: None,
81            additional_routes: None,
82            event_tx,
83        }
84    }
85
86    pub fn echo() -> Self {
87        Self::new(EchoHandler::default())
88    }
89
90    pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
91        let _: SocketAddr = address.parse()?;
92        self.config.bind_address = address.to_string();
93        Ok(self)
94    }
95
96    pub fn bind_unchecked(mut self, address: &str) -> Self {
97        self.config.bind_address = address.to_string();
98        self
99    }
100
101    pub fn task_store(mut self, store: TaskStore) -> Self {
102        self.task_store = store;
103        self
104    }
105
106    pub fn auth_extractor<F>(mut self, extractor: F) -> Self
107    where
108        F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
109    {
110        self.auth_extractor = Some(Arc::new(extractor));
111        self
112    }
113
114    pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
115        self.additional_routes = Some(routes);
116        self
117    }
118
119    pub fn get_task_store(&self) -> TaskStore {
120        self.task_store.clone()
121    }
122
123    pub fn build_router(self) -> Router {
124        let bind: SocketAddr = self
125            .config
126            .bind_address
127            .parse()
128            .expect("Invalid bind address");
129        let base_url = format!("http://{}", bind);
130        let card = Arc::new(self.handler.agent_card(&base_url));
131
132        let state = AppState {
133            handler: self.handler,
134            task_store: self.task_store,
135            webhook_store: self.webhook_store,
136            card,
137            auth_extractor: self.auth_extractor,
138            event_tx: self.event_tx,
139        };
140
141        let timed_routes = Router::new()
142            .route("/health", get(health))
143            .route("/.well-known/agent-card.json", get(agent_card))
144            .route("/v1/rpc", post(handle_rpc))
145            .layer(
146                ServiceBuilder::new()
147                    .layer(HandleErrorLayer::new(handle_timeout_error))
148                    .timeout(Duration::from_secs(30)),
149            );
150
151        let mut router = timed_routes;
152
153        if let Some(additional) = self.additional_routes {
154            router = router.merge(additional);
155        }
156
157        router.with_state(state)
158    }
159
160    pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
161        self.event_tx.clone()
162    }
163
164    pub async fn run(self) -> anyhow::Result<()> {
165        let bind: SocketAddr = self.config.bind_address.parse()?;
166
167        let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
168        webhook_delivery.start(self.event_tx.subscribe());
169
170        let router = self.build_router();
171
172        info!(%bind, "Starting A2A server");
173        let listener = tokio::net::TcpListener::bind(bind).await?;
174        axum::serve(listener, router)
175            .with_graceful_shutdown(shutdown_signal())
176            .await?;
177
178        info!("Server shutdown complete");
179        Ok(())
180    }
181}
182
183async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
184    if err.is::<tower::timeout::error::Elapsed>() {
185        (
186            StatusCode::REQUEST_TIMEOUT,
187            Json(error(
188                serde_json::Value::Null,
189                errors::INTERNAL_ERROR,
190                "Request timed out",
191                None,
192            )),
193        )
194    } else {
195        (
196            StatusCode::INTERNAL_SERVER_ERROR,
197            Json(error(
198                serde_json::Value::Null,
199                errors::INTERNAL_ERROR,
200                &format!("Internal error: {}", err),
201                None,
202            )),
203        )
204    }
205}
206
207async fn shutdown_signal() {
208    let ctrl_c = async {
209        signal::ctrl_c()
210            .await
211            .expect("failed to install Ctrl+C handler");
212    };
213
214    #[cfg(unix)]
215    let terminate = async {
216        signal::unix::signal(signal::unix::SignalKind::terminate())
217            .expect("failed to install SIGTERM handler")
218            .recv()
219            .await;
220    };
221
222    #[cfg(not(unix))]
223    let terminate = std::future::pending::<()>();
224
225    tokio::select! {
226        _ = ctrl_c => {},
227        _ = terminate => {},
228    }
229
230    info!("Shutdown signal received, initiating graceful shutdown...");
231}
232
233#[derive(Clone)]
234pub struct AppState {
235    handler: BoxedHandler,
236    task_store: TaskStore,
237    webhook_store: WebhookStore,
238    card: Arc<AgentCard>,
239    auth_extractor: Option<AuthExtractor>,
240    event_tx: broadcast::Sender<StreamResponse>,
241}
242
243impl AppState {
244    pub fn task_store(&self) -> &TaskStore {
245        &self.task_store
246    }
247
248    pub fn agent_card(&self) -> &AgentCard {
249        &self.card
250    }
251
252    pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
253        &self.event_tx
254    }
255
256    pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
257        self.event_tx.subscribe()
258    }
259
260    pub fn broadcast_event(&self, event: StreamResponse) {
261        let _ = self.event_tx.send(event);
262    }
263
264    /// Check if streaming is enabled in capabilities
265    fn streaming_enabled(&self) -> bool {
266        self.card.capabilities.streaming.unwrap_or(false)
267    }
268
269    /// Check if push notifications are enabled in capabilities
270    fn push_notifications_enabled(&self) -> bool {
271        self.card.capabilities.push_notifications.unwrap_or(false)
272    }
273
274    /// Get the endpoint URL from the agent card
275    fn endpoint_url(&self) -> &str {
276        self.card.endpoint().unwrap_or("")
277    }
278}
279
280// ============ History Trimming ============
281
282fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
283    match history_length {
284        Some(0) => {
285            task.history = None;
286        }
287        Some(n) => {
288            if let Some(ref mut history) = task.history {
289                let len = history.len();
290                if len > n as usize {
291                    *history = history.split_off(len - n as usize);
292                }
293            }
294        }
295        None => {}
296    }
297}
298
299// ============ Version Validation ============
300
301fn validate_a2a_version(
302    headers: &HeaderMap,
303    req_id: &serde_json::Value,
304) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
305    if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
306        let version_str = version_header.to_str().unwrap_or("");
307
308        let parts: Vec<&str> = version_str.split('.').collect();
309        if parts.len() >= 2 {
310            if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
311                if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
312                    return Ok(());
313                }
314
315                return Err((
316                    StatusCode::BAD_REQUEST,
317                    Json(error(
318                        req_id.clone(),
319                        errors::VERSION_NOT_SUPPORTED,
320                        &format!(
321                            "Protocol version {}.{} not supported. Supported version: {}.{}",
322                            major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
323                        ),
324                        Some(serde_json::json!({
325                            "requestedVersion": version_str,
326                            "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
327                        })),
328                    )),
329                ));
330            }
331        }
332
333        // Also accept bare "1.0" without minor
334        if version_str == "1" || version_str == "1.0" {
335            return Ok(());
336        }
337
338        return Err((
339            StatusCode::BAD_REQUEST,
340            Json(error(
341                req_id.clone(),
342                errors::VERSION_NOT_SUPPORTED,
343                &format!(
344                    "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
345                    version_str
346                ),
347                None,
348            )),
349        ));
350    }
351
352    Ok(())
353}
354
355// ============ Error Response Helpers ============
356
357#[allow(dead_code)]
358pub fn rpc_error(
359    id: serde_json::Value,
360    code: i32,
361    message: &str,
362    status: StatusCode,
363) -> (StatusCode, Json<JsonRpcResponse>) {
364    (status, Json(error(id, code, message, None)))
365}
366
367#[allow(dead_code)]
368pub fn rpc_error_with_data(
369    id: serde_json::Value,
370    code: i32,
371    message: &str,
372    data: serde_json::Value,
373    status: StatusCode,
374) -> (StatusCode, Json<JsonRpcResponse>) {
375    (status, Json(error(id, code, message, Some(data))))
376}
377
378#[allow(dead_code)]
379pub fn rpc_success(
380    id: serde_json::Value,
381    result: serde_json::Value,
382) -> (StatusCode, Json<JsonRpcResponse>) {
383    (StatusCode::OK, Json(success(id, result)))
384}
385
386// ============ Route Handlers ============
387
388async fn health() -> Json<serde_json::Value> {
389    Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
390}
391
392async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
393    Json((*state.card).clone())
394}
395
396// handle_task_subscribe_sse removed — tasks/resubscribe now returns SSE
397// directly from the /v1/rpc endpoint via handle_tasks_resubscribe.
398
399async fn handle_rpc(
400    State(state): State<AppState>,
401    headers: HeaderMap,
402    Json(req): Json<JsonRpcRequest>,
403) -> Response {
404    if req.jsonrpc != "2.0" {
405        let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
406        return (StatusCode::BAD_REQUEST, Json(resp)).into_response();
407    }
408
409    if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
410        return err_response.into_response();
411    }
412
413    let auth_context = state
414        .auth_extractor
415        .as_ref()
416        .and_then(|extractor| extractor(&headers));
417
418    match req.method.as_str() {
419        // Spec method names per JSON-RPC binding
420        "message/send" => handle_message_send(state, req, auth_context)
421            .await
422            .into_response(),
423        "message/stream" => {
424            handle_message_stream(state, req, headers, auth_context)
425                .await
426                .into_response()
427        }
428        "tasks/get" => handle_tasks_get(state, req).await.into_response(),
429        "tasks/list" => handle_tasks_list(state, req).await.into_response(),
430        "tasks/cancel" => handle_tasks_cancel(state, req).await.into_response(),
431        "tasks/resubscribe" => handle_tasks_resubscribe(state, req).await.into_response(),
432        "tasks/pushNotificationConfig/create" => {
433            handle_push_config_create(state, req).await.into_response()
434        }
435        "tasks/pushNotificationConfig/get" => {
436            handle_push_config_get(state, req).await.into_response()
437        }
438        "tasks/pushNotificationConfig/list" => {
439            handle_push_config_list(state, req).await.into_response()
440        }
441        "tasks/pushNotificationConfig/delete" => {
442            handle_push_config_delete(state, req).await.into_response()
443        }
444        "agentCard/getExtended" => handle_get_extended_agent_card(state, req, auth_context)
445            .await
446            .into_response(),
447        _ => (
448            StatusCode::NOT_FOUND,
449            Json(error(
450                req.id,
451                errors::METHOD_NOT_FOUND,
452                "method not found",
453                None,
454            )),
455        )
456            .into_response(),
457    }
458}
459
460fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
461    match e {
462        HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
463        HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
464        HandlerError::BackendUnavailable { .. } => {
465            (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
466        }
467        HandlerError::ProcessingFailed { .. } => {
468            (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
469        }
470        HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
471    }
472}
473
474async fn handle_message_send(
475    state: AppState,
476    req: JsonRpcRequest,
477    auth_context: Option<AuthContext>,
478) -> (StatusCode, Json<JsonRpcResponse>) {
479    let req_id = req.id.clone();
480
481    let params: Result<SendMessageRequest, _> =
482        serde_json::from_value(req.params.clone().unwrap_or_default());
483
484    let params = match params {
485        Ok(p) => p,
486        Err(err) => {
487            return (
488                StatusCode::BAD_REQUEST,
489                Json(error(
490                    req_id,
491                    errors::INVALID_PARAMS,
492                    "invalid params",
493                    Some(serde_json::json!({"error": err.to_string()})),
494                )),
495            );
496        }
497    };
498
499    let blocking = params
500        .configuration
501        .as_ref()
502        .and_then(|c| c.blocking)
503        .unwrap_or(false);
504    let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
505
506    match state
507        .handler
508        .handle_message(params.message, auth_context)
509        .await
510    {
511        Ok(response) => {
512            match response {
513                SendMessageResponse::Task(mut task) => {
514                    // Store the task and broadcast event
515                    state.task_store.insert(task.clone()).await;
516                    state.broadcast_event(StreamResponse::Task(task.clone()));
517
518                    // If blocking mode, wait for task to reach terminal state
519                    if blocking && !task.status.state.is_terminal() {
520                        let task_id = task.id.clone();
521                        let mut rx = state.subscribe_events();
522
523                        let wait_result = timeout(BLOCKING_TIMEOUT, async {
524                            loop {
525                                tokio::select! {
526                                    result = rx.recv() => {
527                                        match result {
528                                            Ok(StreamResponse::Task(t)) if t.id == task_id => {
529                                                if t.status.state.is_terminal() {
530                                                    return Some(t);
531                                                }
532                                            }
533                                            Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
534                                                if e.status.state.is_terminal() {
535                                                    if let Some(t) = state.task_store.get(&task_id).await {
536                                                        return Some(t);
537                                                    }
538                                                }
539                                            }
540                                            Err(broadcast::error::RecvError::Closed) => {
541                                                return None;
542                                            }
543                                            _ => {}
544                                        }
545                                    }
546                                    _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
547                                        if let Some(t) = state.task_store.get(&task_id).await {
548                                            if t.status.state.is_terminal() {
549                                                return Some(t);
550                                            }
551                                        }
552                                    }
553                                }
554                            }
555                        })
556                        .await;
557
558                        match wait_result {
559                            Ok(Some(final_task)) => task = final_task,
560                            Ok(None) => {
561                                if let Some(t) = state.task_store.get(&task.id).await {
562                                    task = t;
563                                }
564                            }
565                            Err(_) => {
566                                tracing::warn!("Blocking request timed out for task {}", task.id);
567                                if let Some(t) = state.task_store.get(&task.id).await {
568                                    task = t;
569                                }
570                            }
571                        }
572                    }
573
574                    apply_history_length(&mut task, history_length);
575
576                    // Serialize the Task directly into the JSON-RPC result field
577                    // (no wrapper key — matches the A2A reference SDK wire format)
578                    match serde_json::to_value(&task) {
579                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
580                        Err(e) => (
581                            StatusCode::INTERNAL_SERVER_ERROR,
582                            Json(error(
583                                req_id,
584                                errors::INTERNAL_ERROR,
585                                "serialization failed",
586                                Some(serde_json::json!({"error": e.to_string()})),
587                            )),
588                        ),
589                    }
590                }
591                SendMessageResponse::Message(msg) => {
592                    // Serialize the Message directly into the JSON-RPC result field
593                    match serde_json::to_value(&msg) {
594                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
595                        Err(e) => (
596                            StatusCode::INTERNAL_SERVER_ERROR,
597                            Json(error(
598                                req_id,
599                                errors::INTERNAL_ERROR,
600                                "serialization failed",
601                                Some(serde_json::json!({"error": e.to_string()})),
602                            )),
603                        ),
604                    }
605                }
606            }
607        }
608        Err(e) => {
609            let (code, status) = handler_error_to_rpc(&e);
610            (status, Json(error(req_id, code, &e.to_string(), None)))
611        }
612    }
613}
614
615/// Handle `message/stream` — returns SSE directly from the JSON-RPC endpoint.
616///
617/// Each SSE event's `data:` is a full JSON-RPC response envelope wrapping the result
618/// (Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent).
619async fn handle_message_stream(
620    state: AppState,
621    req: JsonRpcRequest,
622    _headers: HeaderMap,
623    auth_context: Option<AuthContext>,
624) -> Response {
625    let req_id = req.id.clone();
626
627    if !state.streaming_enabled() {
628        return (
629            StatusCode::BAD_REQUEST,
630            Json(error(
631                req_id,
632                errors::UNSUPPORTED_OPERATION,
633                "streaming not supported by this agent",
634                None,
635            )),
636        )
637            .into_response();
638    }
639
640    let params: Result<SendMessageRequest, _> =
641        serde_json::from_value(req.params.clone().unwrap_or_default());
642
643    let params = match params {
644        Ok(p) => p,
645        Err(err) => {
646            return (
647                StatusCode::BAD_REQUEST,
648                Json(error(
649                    req_id,
650                    errors::INVALID_PARAMS,
651                    "invalid params",
652                    Some(serde_json::json!({"error": err.to_string()})),
653                )),
654            )
655                .into_response();
656        }
657    };
658
659    let response = match state
660        .handler
661        .handle_message(params.message, auth_context)
662        .await
663    {
664        Ok(r) => r,
665        Err(e) => {
666            let (code, status) = handler_error_to_rpc(&e);
667            return (status, Json(error(req_id, code, &e.to_string(), None))).into_response();
668        }
669    };
670
671    // Extract task from response (streaming only works with tasks)
672    let task = match response {
673        SendMessageResponse::Task(t) => t,
674        SendMessageResponse::Message(_) => {
675            return (
676                StatusCode::BAD_REQUEST,
677                Json(error(
678                    req_id,
679                    errors::UNSUPPORTED_OPERATION,
680                    "handler returned a message, streaming requires a task",
681                    None,
682                )),
683            )
684                .into_response();
685        }
686    };
687
688    let task_id = task.id.clone();
689    state.task_store.insert(task.clone()).await;
690    state.broadcast_event(StreamResponse::Task(task.clone()));
691
692    let mut rx = state.subscribe_events();
693    let task_store = state.task_store.clone();
694    let target_task_id = task_id;
695
696    // Helper: wrap a value in a JSON-RPC success response envelope
697    let wrap = move |value: serde_json::Value| -> String {
698        serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
699    };
700
701    let stream = async_stream::stream! {
702        // Yield initial task
703        if let Ok(val) = serde_json::to_value(&task) {
704            yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
705        }
706
707        loop {
708            match rx.recv().await {
709                Ok(event) => {
710                    let matches = match &event {
711                        StreamResponse::Task(t) => t.id == target_task_id,
712                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
713                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
714                        StreamResponse::Message(m) => {
715                            m.context_id.as_ref().is_some_and(|ctx| {
716                                task_store.get(&target_task_id).now_or_never()
717                                    .flatten()
718                                    .is_some_and(|t| t.context_id == *ctx)
719                            })
720                        }
721                    };
722
723                    if matches {
724                        // Serialize the inner value directly (not the StreamResponse wrapper)
725                        let val = match &event {
726                            StreamResponse::Task(t) => serde_json::to_value(t),
727                            StreamResponse::Message(m) => serde_json::to_value(m),
728                            StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
729                            StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
730                        };
731                        if let Ok(val) = val {
732                            yield Ok(Event::default().data(wrap(val)));
733                        }
734
735                        // End stream on terminal state or final flag
736                        let is_terminal = match &event {
737                            StreamResponse::Task(t) => t.status.state.is_terminal(),
738                            StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
739                            _ => false,
740                        };
741                        if is_terminal {
742                            break;
743                        }
744                    }
745                }
746                Err(broadcast::error::RecvError::Lagged(_)) => continue,
747                Err(broadcast::error::RecvError::Closed) => break,
748            }
749        }
750    };
751
752    Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
753}
754
755async fn handle_tasks_get(
756    state: AppState,
757    req: JsonRpcRequest,
758) -> (StatusCode, Json<JsonRpcResponse>) {
759    let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
760
761    match params {
762        Ok(p) => match state.task_store.get_flexible(&p.id).await {
763            Some(mut task) => {
764                apply_history_length(&mut task, p.history_length);
765
766                match serde_json::to_value(task) {
767                    Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
768                    Err(e) => (
769                        StatusCode::INTERNAL_SERVER_ERROR,
770                        Json(error(
771                            req.id,
772                            errors::INTERNAL_ERROR,
773                            "serialization failed",
774                            Some(serde_json::json!({"error": e.to_string()})),
775                        )),
776                    ),
777                }
778            }
779            None => (
780                StatusCode::NOT_FOUND,
781                Json(error(
782                    req.id,
783                    errors::TASK_NOT_FOUND,
784                    "task not found",
785                    None,
786                )),
787            ),
788        },
789        Err(err) => (
790            StatusCode::BAD_REQUEST,
791            Json(error(
792                req.id,
793                errors::INVALID_PARAMS,
794                "invalid params",
795                Some(serde_json::json!({"error": err.to_string()})),
796            )),
797        ),
798    }
799}
800
801async fn handle_tasks_cancel(
802    state: AppState,
803    req: JsonRpcRequest,
804) -> (StatusCode, Json<JsonRpcResponse>) {
805    let params: Result<CancelTaskRequest, _> =
806        serde_json::from_value(req.params.unwrap_or_default());
807
808    match params {
809        Ok(p) => {
810            let result = state
811                .task_store
812                .update_flexible(&p.id, |task| {
813                    if task.status.state.is_terminal() {
814                        return Err(errors::TASK_NOT_CANCELABLE);
815                    }
816                    task.status.state = TaskState::Canceled;
817                    task.status.timestamp = Some(now_iso8601());
818                    Ok(())
819                })
820                .await;
821
822            match result {
823                Some(Ok(task)) => {
824                    if let Err(e) = state.handler.cancel_task(&task.id).await {
825                        tracing::warn!("Handler cancel_task failed: {}", e);
826                    }
827
828                    state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
829                        kind: "status-update".to_string(),
830                        task_id: task.id.clone(),
831                        context_id: task.context_id.clone(),
832                        status: task.status.clone(),
833                        is_final: true,
834                        metadata: None,
835                    }));
836
837                    match serde_json::to_value(task) {
838                        Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
839                        Err(e) => (
840                            StatusCode::INTERNAL_SERVER_ERROR,
841                            Json(error(
842                                req.id,
843                                errors::INTERNAL_ERROR,
844                                "serialization failed",
845                                Some(serde_json::json!({"error": e.to_string()})),
846                            )),
847                        ),
848                    }
849                }
850                Some(Err(error_code)) => (
851                    StatusCode::BAD_REQUEST,
852                    Json(error(req.id, error_code, "task not cancelable", None)),
853                ),
854                None => (
855                    StatusCode::NOT_FOUND,
856                    Json(error(
857                        req.id,
858                        errors::TASK_NOT_FOUND,
859                        "task not found",
860                        None,
861                    )),
862                ),
863            }
864        }
865        Err(err) => (
866            StatusCode::BAD_REQUEST,
867            Json(error(
868                req.id,
869                errors::INVALID_PARAMS,
870                "invalid params",
871                Some(serde_json::json!({"error": err.to_string()})),
872            )),
873        ),
874    }
875}
876
877async fn handle_tasks_list(
878    state: AppState,
879    req: JsonRpcRequest,
880) -> (StatusCode, Json<JsonRpcResponse>) {
881    let params: Result<ListTasksRequest, _> =
882        serde_json::from_value(req.params.unwrap_or_default());
883
884    match params {
885        Ok(p) => {
886            let response = state.task_store.list_filtered(&p).await;
887            match serde_json::to_value(response) {
888                Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
889                Err(e) => (
890                    StatusCode::INTERNAL_SERVER_ERROR,
891                    Json(error(
892                        req.id,
893                        errors::INTERNAL_ERROR,
894                        "serialization failed",
895                        Some(serde_json::json!({"error": e.to_string()})),
896                    )),
897                ),
898            }
899        }
900        Err(err) => (
901            StatusCode::BAD_REQUEST,
902            Json(error(
903                req.id,
904                errors::INVALID_PARAMS,
905                "invalid params",
906                Some(serde_json::json!({"error": err.to_string()})),
907            )),
908        ),
909    }
910}
911
912/// Handle `tasks/resubscribe` — returns SSE directly from the JSON-RPC endpoint.
913///
914/// Reconnects to an existing task's event stream. Each SSE event's `data:` is
915/// a JSON-RPC response envelope wrapping the result, same as `message/stream`.
916async fn handle_tasks_resubscribe(state: AppState, req: JsonRpcRequest) -> Response {
917    let req_id = req.id.clone();
918
919    let params: Result<SubscribeToTaskRequest, _> =
920        serde_json::from_value(req.params.unwrap_or_default());
921
922    let params = match params {
923        Ok(p) => p,
924        Err(err) => {
925            return (
926                StatusCode::BAD_REQUEST,
927                Json(error(
928                    req_id,
929                    errors::INVALID_PARAMS,
930                    "invalid params",
931                    Some(serde_json::json!({"error": err.to_string()})),
932                )),
933            )
934                .into_response();
935        }
936    };
937
938    // Verify the task exists
939    let task = match state.task_store.get_flexible(&params.id).await {
940        Some(t) => t,
941        None => {
942            return (
943                StatusCode::NOT_FOUND,
944                Json(error(req_id, errors::TASK_NOT_FOUND, "task not found", None)),
945            )
946                .into_response();
947        }
948    };
949
950    let target_task_id = task.id.clone();
951    let mut rx = state.subscribe_events();
952    let task_store = state.task_store.clone();
953
954    let wrap = move |value: serde_json::Value| -> String {
955        serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
956    };
957
958    let stream = async_stream::stream! {
959        // Yield current task snapshot
960        if let Ok(val) = serde_json::to_value(&task) {
961            yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
962        }
963
964        // If already terminal, stop
965        if task.status.state.is_terminal() {
966            return;
967        }
968
969        loop {
970            match rx.recv().await {
971                Ok(event) => {
972                    let matches = match &event {
973                        StreamResponse::Task(t) => t.id == target_task_id,
974                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
975                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
976                        StreamResponse::Message(m) => {
977                            m.context_id.as_ref().is_some_and(|ctx| {
978                                task_store.get(&target_task_id).now_or_never()
979                                    .flatten()
980                                    .is_some_and(|t| t.context_id == *ctx)
981                            })
982                        }
983                    };
984
985                    if matches {
986                        let val = match &event {
987                            StreamResponse::Task(t) => serde_json::to_value(t),
988                            StreamResponse::Message(m) => serde_json::to_value(m),
989                            StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
990                            StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
991                        };
992                        if let Ok(val) = val {
993                            yield Ok(Event::default().data(wrap(val)));
994                        }
995
996                        let is_terminal = match &event {
997                            StreamResponse::Task(t) => t.status.state.is_terminal(),
998                            StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
999                            _ => false,
1000                        };
1001                        if is_terminal {
1002                            break;
1003                        }
1004                    }
1005                }
1006                Err(broadcast::error::RecvError::Lagged(_)) => continue,
1007                Err(broadcast::error::RecvError::Closed) => break,
1008            }
1009        }
1010    };
1011
1012    Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
1013}
1014
1015async fn handle_get_extended_agent_card(
1016    state: AppState,
1017    req: JsonRpcRequest,
1018    auth_context: Option<AuthContext>,
1019) -> (StatusCode, Json<JsonRpcResponse>) {
1020    let Some(auth) = auth_context else {
1021        return (
1022            StatusCode::UNAUTHORIZED,
1023            Json(error(
1024                req.id,
1025                errors::INVALID_REQUEST,
1026                "authentication required for extended agent card",
1027                None,
1028            )),
1029        );
1030    };
1031
1032    let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1033
1034    match state.handler.extended_agent_card(base_url, &auth).await {
1035        Some(card) => match serde_json::to_value(card) {
1036            Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1037            Err(e) => (
1038                StatusCode::INTERNAL_SERVER_ERROR,
1039                Json(error(
1040                    req.id,
1041                    errors::INTERNAL_ERROR,
1042                    "serialization failed",
1043                    Some(serde_json::json!({"error": e.to_string()})),
1044                )),
1045            ),
1046        },
1047        None => (
1048            StatusCode::NOT_FOUND,
1049            Json(error(
1050                req.id,
1051                errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1052                "extended agent card not configured",
1053                None,
1054            )),
1055        ),
1056    }
1057}
1058
1059// ============ Push Notification Config Handlers ============
1060
1061async fn handle_push_config_create(
1062    state: AppState,
1063    req: JsonRpcRequest,
1064) -> (StatusCode, Json<JsonRpcResponse>) {
1065    if !state.push_notifications_enabled() {
1066        return (
1067            StatusCode::BAD_REQUEST,
1068            Json(error(
1069                req.id,
1070                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1071                "push notifications not supported",
1072                None,
1073            )),
1074        );
1075    }
1076
1077    let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1078        serde_json::from_value(req.params.unwrap_or_default());
1079
1080    match params {
1081        Ok(p) => {
1082            if state.task_store.get_flexible(&p.task_id).await.is_none() {
1083                return (
1084                    StatusCode::NOT_FOUND,
1085                    Json(error(
1086                        req.id,
1087                        errors::TASK_NOT_FOUND,
1088                        "task not found",
1089                        None,
1090                    )),
1091                );
1092            }
1093
1094            if let Err(e) = state
1095                .webhook_store
1096                .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1097                .await
1098            {
1099                return (
1100                    StatusCode::BAD_REQUEST,
1101                    Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1102                );
1103            }
1104
1105            (
1106                StatusCode::OK,
1107                Json(success(
1108                    req.id,
1109                    serde_json::json!({
1110                        "configId": p.config_id,
1111                        "config": p.push_notification_config
1112                    }),
1113                )),
1114            )
1115        }
1116        Err(err) => (
1117            StatusCode::BAD_REQUEST,
1118            Json(error(
1119                req.id,
1120                errors::INVALID_PARAMS,
1121                "invalid params",
1122                Some(serde_json::json!({"error": err.to_string()})),
1123            )),
1124        ),
1125    }
1126}
1127
1128async fn handle_push_config_get(
1129    state: AppState,
1130    req: JsonRpcRequest,
1131) -> (StatusCode, Json<JsonRpcResponse>) {
1132    if !state.push_notifications_enabled() {
1133        return (
1134            StatusCode::BAD_REQUEST,
1135            Json(error(
1136                req.id,
1137                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1138                "push notifications not supported",
1139                None,
1140            )),
1141        );
1142    }
1143
1144    let params: Result<GetTaskPushNotificationConfigRequest, _> =
1145        serde_json::from_value(req.params.unwrap_or_default());
1146
1147    match params {
1148        Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1149            Some(config) => (
1150                StatusCode::OK,
1151                Json(success(
1152                    req.id,
1153                    serde_json::json!({
1154                        "configId": p.id,
1155                        "config": config
1156                    }),
1157                )),
1158            ),
1159            None => (
1160                StatusCode::NOT_FOUND,
1161                Json(error(
1162                    req.id,
1163                    errors::TASK_NOT_FOUND,
1164                    "push notification config not found",
1165                    None,
1166                )),
1167            ),
1168        },
1169        Err(err) => (
1170            StatusCode::BAD_REQUEST,
1171            Json(error(
1172                req.id,
1173                errors::INVALID_PARAMS,
1174                "invalid params",
1175                Some(serde_json::json!({"error": err.to_string()})),
1176            )),
1177        ),
1178    }
1179}
1180
1181async fn handle_push_config_list(
1182    state: AppState,
1183    req: JsonRpcRequest,
1184) -> (StatusCode, Json<JsonRpcResponse>) {
1185    if !state.push_notifications_enabled() {
1186        return (
1187            StatusCode::BAD_REQUEST,
1188            Json(error(
1189                req.id,
1190                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1191                "push notifications not supported",
1192                None,
1193            )),
1194        );
1195    }
1196
1197    let params: Result<ListTaskPushNotificationConfigRequest, _> =
1198        serde_json::from_value(req.params.unwrap_or_default());
1199
1200    match params {
1201        Ok(p) => {
1202            let configs = state.webhook_store.list(&p.task_id).await;
1203
1204            let configs_json: Vec<_> = configs
1205                .iter()
1206                .map(|c| {
1207                    serde_json::json!({
1208                        "configId": c.config_id,
1209                        "config": c.config
1210                    })
1211                })
1212                .collect();
1213
1214            (
1215                StatusCode::OK,
1216                Json(success(
1217                    req.id,
1218                    serde_json::json!({
1219                        "configs": configs_json,
1220                        "nextPageToken": ""
1221                    }),
1222                )),
1223            )
1224        }
1225        Err(err) => (
1226            StatusCode::BAD_REQUEST,
1227            Json(error(
1228                req.id,
1229                errors::INVALID_PARAMS,
1230                "invalid params",
1231                Some(serde_json::json!({"error": err.to_string()})),
1232            )),
1233        ),
1234    }
1235}
1236
1237async fn handle_push_config_delete(
1238    state: AppState,
1239    req: JsonRpcRequest,
1240) -> (StatusCode, Json<JsonRpcResponse>) {
1241    if !state.push_notifications_enabled() {
1242        return (
1243            StatusCode::BAD_REQUEST,
1244            Json(error(
1245                req.id,
1246                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1247                "push notifications not supported",
1248                None,
1249            )),
1250        );
1251    }
1252
1253    let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1254        serde_json::from_value(req.params.unwrap_or_default());
1255
1256    match params {
1257        Ok(p) => {
1258            if state.webhook_store.delete(&p.task_id, &p.id).await {
1259                (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1260            } else {
1261                (
1262                    StatusCode::NOT_FOUND,
1263                    Json(error(
1264                        req.id,
1265                        errors::TASK_NOT_FOUND,
1266                        "push notification config not found",
1267                        None,
1268                    )),
1269                )
1270            }
1271        }
1272        Err(err) => (
1273            StatusCode::BAD_REQUEST,
1274            Json(error(
1275                req.id,
1276                errors::INVALID_PARAMS,
1277                "invalid params",
1278                Some(serde_json::json!({"error": err.to_string()})),
1279            )),
1280        ),
1281    }
1282}
1283
1284pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1285    bind_addr: &str,
1286    handler: H,
1287) -> anyhow::Result<()> {
1288    A2aServer::new(handler).bind(bind_addr)?.run().await
1289}
1290
1291pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1292    A2aServer::echo().bind(bind_addr)?.run().await
1293}