Skip to main content

a2a_rs_server/
server.rs

1//! Generic A2A JSON-RPC Server
2
3use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Longest time a blocking `message/send` waits for its task to reach a terminal state.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// How often a blocking `message/send` re-reads the task store while it waits.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17    error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18    CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19    GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20    ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21    SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22    TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header in which clients state the A2A protocol version they speak.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the only protocol version accepted in [`A2A_VERSION_HEADER`].
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the only protocol version accepted in [`A2A_VERSION_HEADER`].
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::response::{IntoResponse, Response};
33use axum::routing::{get, post};
34use axum::{Json, Router};
35use futures::future::FutureExt;
36use futures::stream::Stream;
37use tokio::sync::broadcast;
38use tracing::info;
39
40use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
41use crate::task_store::TaskStore;
42use crate::webhook_delivery::WebhookDelivery;
43use crate::webhook_store::WebhookStore;
44
/// Network settings for [`A2aServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address to listen on, normally `ip:port`; defaults to `0.0.0.0:8080`.
    pub bind_address: String,
}

impl Default for ServerConfig {
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        ServerConfig { bind_address }
    }
}
57
58pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
59
/// Number of events the server-wide broadcast channel buffers. A receiver that falls further
/// behind observes `RecvError::Lagged` and skips ahead.
const EVENT_CHANNEL_CAPACITY: usize = 1 << 10;
61
62pub struct A2aServer {
63    config: ServerConfig,
64    handler: BoxedHandler,
65    task_store: TaskStore,
66    webhook_store: WebhookStore,
67    auth_extractor: Option<AuthExtractor>,
68    additional_routes: Option<Router<AppState>>,
69    event_tx: broadcast::Sender<StreamResponse>,
70}
71
72impl A2aServer {
73    pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
74        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
75        Self {
76            config: ServerConfig::default(),
77            handler: Arc::new(handler),
78            task_store: TaskStore::new(),
79            webhook_store: WebhookStore::new(),
80            auth_extractor: None,
81            additional_routes: None,
82            event_tx,
83        }
84    }
85
86    pub fn echo() -> Self {
87        Self::new(EchoHandler::default())
88    }
89
90    pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
91        let _: SocketAddr = address.parse()?;
92        self.config.bind_address = address.to_string();
93        Ok(self)
94    }
95
96    pub fn bind_unchecked(mut self, address: &str) -> Self {
97        self.config.bind_address = address.to_string();
98        self
99    }
100
101    pub fn task_store(mut self, store: TaskStore) -> Self {
102        self.task_store = store;
103        self
104    }
105
106    pub fn auth_extractor<F>(mut self, extractor: F) -> Self
107    where
108        F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
109    {
110        self.auth_extractor = Some(Arc::new(extractor));
111        self
112    }
113
114    pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
115        self.additional_routes = Some(routes);
116        self
117    }
118
119    pub fn get_task_store(&self) -> TaskStore {
120        self.task_store.clone()
121    }
122
123    pub fn build_router(self) -> Router {
124        let bind: SocketAddr = self
125            .config
126            .bind_address
127            .parse()
128            .expect("Invalid bind address");
129        let base_url = format!("http://{}", bind);
130        let card = Arc::new(self.handler.agent_card(&base_url));
131
132        let state = AppState {
133            handler: self.handler,
134            task_store: self.task_store,
135            webhook_store: self.webhook_store,
136            card,
137            auth_extractor: self.auth_extractor,
138            event_tx: self.event_tx,
139        };
140
141        let timed_routes = Router::new()
142            .route("/health", get(health))
143            .route("/.well-known/agent-card.json", get(agent_card))
144            .route("/v1/rpc", post(handle_rpc))
145            .layer(
146                ServiceBuilder::new()
147                    .layer(HandleErrorLayer::new(handle_timeout_error))
148                    .timeout(Duration::from_secs(30)),
149            );
150
151        let sse_routes = Router::new().route(
152            "/v1/tasks/:task_id/subscribe",
153            get(handle_task_subscribe_sse),
154        );
155
156        let mut router = timed_routes.merge(sse_routes);
157
158        if let Some(additional) = self.additional_routes {
159            router = router.merge(additional);
160        }
161
162        router.with_state(state)
163    }
164
165    pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
166        self.event_tx.clone()
167    }
168
169    pub async fn run(self) -> anyhow::Result<()> {
170        let bind: SocketAddr = self.config.bind_address.parse()?;
171
172        let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
173        webhook_delivery.start(self.event_tx.subscribe());
174
175        let router = self.build_router();
176
177        info!(%bind, "Starting A2A server");
178        let listener = tokio::net::TcpListener::bind(bind).await?;
179        axum::serve(listener, router)
180            .with_graceful_shutdown(shutdown_signal())
181            .await?;
182
183        info!("Server shutdown complete");
184        Ok(())
185    }
186}
187
188async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
189    if err.is::<tower::timeout::error::Elapsed>() {
190        (
191            StatusCode::REQUEST_TIMEOUT,
192            Json(error(
193                serde_json::Value::Null,
194                errors::INTERNAL_ERROR,
195                "Request timed out",
196                None,
197            )),
198        )
199    } else {
200        (
201            StatusCode::INTERNAL_SERVER_ERROR,
202            Json(error(
203                serde_json::Value::Null,
204                errors::INTERNAL_ERROR,
205                &format!("Internal error: {}", err),
206                None,
207            )),
208        )
209    }
210}
211
212async fn shutdown_signal() {
213    let ctrl_c = async {
214        signal::ctrl_c()
215            .await
216            .expect("failed to install Ctrl+C handler");
217    };
218
219    #[cfg(unix)]
220    let terminate = async {
221        signal::unix::signal(signal::unix::SignalKind::terminate())
222            .expect("failed to install SIGTERM handler")
223            .recv()
224            .await;
225    };
226
227    #[cfg(not(unix))]
228    let terminate = std::future::pending::<()>();
229
230    tokio::select! {
231        _ = ctrl_c => {},
232        _ = terminate => {},
233    }
234
235    info!("Shutdown signal received, initiating graceful shutdown...");
236}
237
238#[derive(Clone)]
239pub struct AppState {
240    handler: BoxedHandler,
241    task_store: TaskStore,
242    webhook_store: WebhookStore,
243    card: Arc<AgentCard>,
244    auth_extractor: Option<AuthExtractor>,
245    event_tx: broadcast::Sender<StreamResponse>,
246}
247
248impl AppState {
249    pub fn task_store(&self) -> &TaskStore {
250        &self.task_store
251    }
252
253    pub fn agent_card(&self) -> &AgentCard {
254        &self.card
255    }
256
257    pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
258        &self.event_tx
259    }
260
261    pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
262        self.event_tx.subscribe()
263    }
264
265    pub fn broadcast_event(&self, event: StreamResponse) {
266        let _ = self.event_tx.send(event);
267    }
268
269    /// Check if streaming is enabled in capabilities
270    fn streaming_enabled(&self) -> bool {
271        self.card.capabilities.streaming.unwrap_or(false)
272    }
273
274    /// Check if push notifications are enabled in capabilities
275    fn push_notifications_enabled(&self) -> bool {
276        self.card.capabilities.push_notifications.unwrap_or(false)
277    }
278
279    /// Get the endpoint URL from the agent card
280    fn endpoint_url(&self) -> &str {
281        self.card.endpoint().unwrap_or("")
282    }
283}
284
285// ============ History Trimming ============
286
287fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
288    match history_length {
289        Some(0) => {
290            task.history = None;
291        }
292        Some(n) => {
293            if let Some(ref mut history) = task.history {
294                let len = history.len();
295                if len > n as usize {
296                    *history = history.split_off(len - n as usize);
297                }
298            }
299        }
300        None => {}
301    }
302}
303
304// ============ Version Validation ============
305
306fn validate_a2a_version(
307    headers: &HeaderMap,
308    req_id: &serde_json::Value,
309) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
310    if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
311        let version_str = version_header.to_str().unwrap_or("");
312
313        let parts: Vec<&str> = version_str.split('.').collect();
314        if parts.len() >= 2 {
315            if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
316                if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
317                    return Ok(());
318                }
319
320                return Err((
321                    StatusCode::BAD_REQUEST,
322                    Json(error(
323                        req_id.clone(),
324                        errors::VERSION_NOT_SUPPORTED,
325                        &format!(
326                            "Protocol version {}.{} not supported. Supported version: {}.{}",
327                            major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
328                        ),
329                        Some(serde_json::json!({
330                            "requestedVersion": version_str,
331                            "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
332                        })),
333                    )),
334                ));
335            }
336        }
337
338        // Also accept bare "1.0" without minor
339        if version_str == "1" || version_str == "1.0" {
340            return Ok(());
341        }
342
343        return Err((
344            StatusCode::BAD_REQUEST,
345            Json(error(
346                req_id.clone(),
347                errors::VERSION_NOT_SUPPORTED,
348                &format!(
349                    "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
350                    version_str
351                ),
352                None,
353            )),
354        ));
355    }
356
357    Ok(())
358}
359
360// ============ Error Response Helpers ============
361
362#[allow(dead_code)]
363pub fn rpc_error(
364    id: serde_json::Value,
365    code: i32,
366    message: &str,
367    status: StatusCode,
368) -> (StatusCode, Json<JsonRpcResponse>) {
369    (status, Json(error(id, code, message, None)))
370}
371
372#[allow(dead_code)]
373pub fn rpc_error_with_data(
374    id: serde_json::Value,
375    code: i32,
376    message: &str,
377    data: serde_json::Value,
378    status: StatusCode,
379) -> (StatusCode, Json<JsonRpcResponse>) {
380    (status, Json(error(id, code, message, Some(data))))
381}
382
383#[allow(dead_code)]
384pub fn rpc_success(
385    id: serde_json::Value,
386    result: serde_json::Value,
387) -> (StatusCode, Json<JsonRpcResponse>) {
388    (StatusCode::OK, Json(success(id, result)))
389}
390
391// ============ Route Handlers ============
392
393async fn health() -> Json<serde_json::Value> {
394    Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
395}
396
397async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
398    Json((*state.card).clone())
399}
400
401async fn handle_task_subscribe_sse(
402    State(state): State<AppState>,
403    axum::extract::Path(task_id): axum::extract::Path<String>,
404) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
405    let mut rx = state.subscribe_events();
406    let task_store = state.task_store.clone();
407    let target_task_id = task_id;
408
409    let stream = async_stream::stream! {
410        if let Some(task) = task_store.get_flexible(&target_task_id).await {
411            if let Ok(json) = serde_json::to_string(&task) {
412                yield Ok(Event::default().data(json));
413            }
414        }
415
416        loop {
417            match rx.recv().await {
418                Ok(event) => {
419                    let matches = match &event {
420                        StreamResponse::Task(t) => t.id == target_task_id,
421                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
422                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
423                        StreamResponse::Message(_) => false,
424                    };
425                    if matches {
426                        // Serialize the inner value directly
427                        let json = match &event {
428                            StreamResponse::Task(t) => serde_json::to_string(t),
429                            StreamResponse::Message(m) => serde_json::to_string(m),
430                            StreamResponse::StatusUpdate(e) => serde_json::to_string(e),
431                            StreamResponse::ArtifactUpdate(e) => serde_json::to_string(e),
432                        };
433                        if let Ok(json) = json {
434                            yield Ok(Event::default().data(json));
435                        }
436
437                        let is_terminal = match &event {
438                            StreamResponse::Task(t) => t.status.state.is_terminal(),
439                            StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
440                            _ => false,
441                        };
442                        if is_terminal {
443                            break;
444                        }
445                    }
446                }
447                Err(broadcast::error::RecvError::Lagged(_)) => continue,
448                Err(broadcast::error::RecvError::Closed) => break,
449            }
450        }
451    };
452
453    Sse::new(stream).keep_alive(KeepAlive::default())
454}
455
456async fn handle_rpc(
457    State(state): State<AppState>,
458    headers: HeaderMap,
459    Json(req): Json<JsonRpcRequest>,
460) -> Response {
461    if req.jsonrpc != "2.0" {
462        let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
463        return (StatusCode::BAD_REQUEST, Json(resp)).into_response();
464    }
465
466    if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
467        return err_response.into_response();
468    }
469
470    let auth_context = state
471        .auth_extractor
472        .as_ref()
473        .and_then(|extractor| extractor(&headers));
474
475    match req.method.as_str() {
476        // Spec method names per JSON-RPC binding
477        "message/send" => handle_message_send(state, req, auth_context)
478            .await
479            .into_response(),
480        "message/stream" => {
481            handle_message_stream(state, req, headers, auth_context)
482                .await
483                .into_response()
484        }
485        "tasks/get" => handle_tasks_get(state, req).await.into_response(),
486        "tasks/list" => handle_tasks_list(state, req).await.into_response(),
487        "tasks/cancel" => handle_tasks_cancel(state, req).await.into_response(),
488        "tasks/subscribe" => handle_tasks_subscribe(state, req).await.into_response(),
489        "tasks/pushNotificationConfig/create" => {
490            handle_push_config_create(state, req).await.into_response()
491        }
492        "tasks/pushNotificationConfig/get" => {
493            handle_push_config_get(state, req).await.into_response()
494        }
495        "tasks/pushNotificationConfig/list" => {
496            handle_push_config_list(state, req).await.into_response()
497        }
498        "tasks/pushNotificationConfig/delete" => {
499            handle_push_config_delete(state, req).await.into_response()
500        }
501        "agentCard/getExtended" => handle_get_extended_agent_card(state, req, auth_context)
502            .await
503            .into_response(),
504        _ => (
505            StatusCode::NOT_FOUND,
506            Json(error(
507                req.id,
508                errors::METHOD_NOT_FOUND,
509                "method not found",
510                None,
511            )),
512        )
513            .into_response(),
514    }
515}
516
517fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
518    match e {
519        HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
520        HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
521        HandlerError::BackendUnavailable { .. } => {
522            (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
523        }
524        HandlerError::ProcessingFailed { .. } => {
525            (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
526        }
527        HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
528    }
529}
530
531async fn handle_message_send(
532    state: AppState,
533    req: JsonRpcRequest,
534    auth_context: Option<AuthContext>,
535) -> (StatusCode, Json<JsonRpcResponse>) {
536    let req_id = req.id.clone();
537
538    let params: Result<SendMessageRequest, _> =
539        serde_json::from_value(req.params.clone().unwrap_or_default());
540
541    let params = match params {
542        Ok(p) => p,
543        Err(err) => {
544            return (
545                StatusCode::BAD_REQUEST,
546                Json(error(
547                    req_id,
548                    errors::INVALID_PARAMS,
549                    "invalid params",
550                    Some(serde_json::json!({"error": err.to_string()})),
551                )),
552            );
553        }
554    };
555
556    let blocking = params
557        .configuration
558        .as_ref()
559        .and_then(|c| c.blocking)
560        .unwrap_or(false);
561    let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
562
563    match state
564        .handler
565        .handle_message(params.message, auth_context)
566        .await
567    {
568        Ok(response) => {
569            match response {
570                SendMessageResponse::Task(mut task) => {
571                    // Store the task and broadcast event
572                    state.task_store.insert(task.clone()).await;
573                    state.broadcast_event(StreamResponse::Task(task.clone()));
574
575                    // If blocking mode, wait for task to reach terminal state
576                    if blocking && !task.status.state.is_terminal() {
577                        let task_id = task.id.clone();
578                        let mut rx = state.subscribe_events();
579
580                        let wait_result = timeout(BLOCKING_TIMEOUT, async {
581                            loop {
582                                tokio::select! {
583                                    result = rx.recv() => {
584                                        match result {
585                                            Ok(StreamResponse::Task(t)) if t.id == task_id => {
586                                                if t.status.state.is_terminal() {
587                                                    return Some(t);
588                                                }
589                                            }
590                                            Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
591                                                if e.status.state.is_terminal() {
592                                                    if let Some(t) = state.task_store.get(&task_id).await {
593                                                        return Some(t);
594                                                    }
595                                                }
596                                            }
597                                            Err(broadcast::error::RecvError::Closed) => {
598                                                return None;
599                                            }
600                                            _ => {}
601                                        }
602                                    }
603                                    _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
604                                        if let Some(t) = state.task_store.get(&task_id).await {
605                                            if t.status.state.is_terminal() {
606                                                return Some(t);
607                                            }
608                                        }
609                                    }
610                                }
611                            }
612                        })
613                        .await;
614
615                        match wait_result {
616                            Ok(Some(final_task)) => task = final_task,
617                            Ok(None) => {
618                                if let Some(t) = state.task_store.get(&task.id).await {
619                                    task = t;
620                                }
621                            }
622                            Err(_) => {
623                                tracing::warn!("Blocking request timed out for task {}", task.id);
624                                if let Some(t) = state.task_store.get(&task.id).await {
625                                    task = t;
626                                }
627                            }
628                        }
629                    }
630
631                    apply_history_length(&mut task, history_length);
632
633                    // Serialize the Task directly into the JSON-RPC result field
634                    // (no wrapper key — matches the A2A reference SDK wire format)
635                    match serde_json::to_value(&task) {
636                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
637                        Err(e) => (
638                            StatusCode::INTERNAL_SERVER_ERROR,
639                            Json(error(
640                                req_id,
641                                errors::INTERNAL_ERROR,
642                                "serialization failed",
643                                Some(serde_json::json!({"error": e.to_string()})),
644                            )),
645                        ),
646                    }
647                }
648                SendMessageResponse::Message(msg) => {
649                    // Serialize the Message directly into the JSON-RPC result field
650                    match serde_json::to_value(&msg) {
651                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
652                        Err(e) => (
653                            StatusCode::INTERNAL_SERVER_ERROR,
654                            Json(error(
655                                req_id,
656                                errors::INTERNAL_ERROR,
657                                "serialization failed",
658                                Some(serde_json::json!({"error": e.to_string()})),
659                            )),
660                        ),
661                    }
662                }
663            }
664        }
665        Err(e) => {
666            let (code, status) = handler_error_to_rpc(&e);
667            (status, Json(error(req_id, code, &e.to_string(), None)))
668        }
669    }
670}
671
672/// Handle `message/stream` — returns SSE directly from the JSON-RPC endpoint.
673///
674/// Each SSE event's `data:` is a full JSON-RPC response envelope wrapping the result
675/// (Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent).
676async fn handle_message_stream(
677    state: AppState,
678    req: JsonRpcRequest,
679    _headers: HeaderMap,
680    auth_context: Option<AuthContext>,
681) -> Response {
682    let req_id = req.id.clone();
683
684    if !state.streaming_enabled() {
685        return (
686            StatusCode::BAD_REQUEST,
687            Json(error(
688                req_id,
689                errors::UNSUPPORTED_OPERATION,
690                "streaming not supported by this agent",
691                None,
692            )),
693        )
694            .into_response();
695    }
696
697    let params: Result<SendMessageRequest, _> =
698        serde_json::from_value(req.params.clone().unwrap_or_default());
699
700    let params = match params {
701        Ok(p) => p,
702        Err(err) => {
703            return (
704                StatusCode::BAD_REQUEST,
705                Json(error(
706                    req_id,
707                    errors::INVALID_PARAMS,
708                    "invalid params",
709                    Some(serde_json::json!({"error": err.to_string()})),
710                )),
711            )
712                .into_response();
713        }
714    };
715
716    let response = match state
717        .handler
718        .handle_message(params.message, auth_context)
719        .await
720    {
721        Ok(r) => r,
722        Err(e) => {
723            let (code, status) = handler_error_to_rpc(&e);
724            return (status, Json(error(req_id, code, &e.to_string(), None))).into_response();
725        }
726    };
727
728    // Extract task from response (streaming only works with tasks)
729    let task = match response {
730        SendMessageResponse::Task(t) => t,
731        SendMessageResponse::Message(_) => {
732            return (
733                StatusCode::BAD_REQUEST,
734                Json(error(
735                    req_id,
736                    errors::UNSUPPORTED_OPERATION,
737                    "handler returned a message, streaming requires a task",
738                    None,
739                )),
740            )
741                .into_response();
742        }
743    };
744
745    let task_id = task.id.clone();
746    state.task_store.insert(task.clone()).await;
747    state.broadcast_event(StreamResponse::Task(task.clone()));
748
749    let mut rx = state.subscribe_events();
750    let task_store = state.task_store.clone();
751    let target_task_id = task_id;
752
753    // Helper: wrap a value in a JSON-RPC success response envelope
754    let wrap = move |value: serde_json::Value| -> String {
755        serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
756    };
757
758    let stream = async_stream::stream! {
759        // Yield initial task
760        if let Ok(val) = serde_json::to_value(&task) {
761            yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
762        }
763
764        loop {
765            match rx.recv().await {
766                Ok(event) => {
767                    let matches = match &event {
768                        StreamResponse::Task(t) => t.id == target_task_id,
769                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
770                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
771                        StreamResponse::Message(m) => {
772                            m.context_id.as_ref().is_some_and(|ctx| {
773                                task_store.get(&target_task_id).now_or_never()
774                                    .flatten()
775                                    .is_some_and(|t| t.context_id == *ctx)
776                            })
777                        }
778                    };
779
780                    if matches {
781                        // Serialize the inner value directly (not the StreamResponse wrapper)
782                        let val = match &event {
783                            StreamResponse::Task(t) => serde_json::to_value(t),
784                            StreamResponse::Message(m) => serde_json::to_value(m),
785                            StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
786                            StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
787                        };
788                        if let Ok(val) = val {
789                            yield Ok(Event::default().data(wrap(val)));
790                        }
791
792                        // End stream on terminal state or final flag
793                        let is_terminal = match &event {
794                            StreamResponse::Task(t) => t.status.state.is_terminal(),
795                            StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
796                            _ => false,
797                        };
798                        if is_terminal {
799                            break;
800                        }
801                    }
802                }
803                Err(broadcast::error::RecvError::Lagged(_)) => continue,
804                Err(broadcast::error::RecvError::Closed) => break,
805            }
806        }
807    };
808
809    Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
810}
811
812async fn handle_tasks_get(
813    state: AppState,
814    req: JsonRpcRequest,
815) -> (StatusCode, Json<JsonRpcResponse>) {
816    let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
817
818    match params {
819        Ok(p) => match state.task_store.get_flexible(&p.id).await {
820            Some(mut task) => {
821                apply_history_length(&mut task, p.history_length);
822
823                match serde_json::to_value(task) {
824                    Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
825                    Err(e) => (
826                        StatusCode::INTERNAL_SERVER_ERROR,
827                        Json(error(
828                            req.id,
829                            errors::INTERNAL_ERROR,
830                            "serialization failed",
831                            Some(serde_json::json!({"error": e.to_string()})),
832                        )),
833                    ),
834                }
835            }
836            None => (
837                StatusCode::NOT_FOUND,
838                Json(error(
839                    req.id,
840                    errors::TASK_NOT_FOUND,
841                    "task not found",
842                    None,
843                )),
844            ),
845        },
846        Err(err) => (
847            StatusCode::BAD_REQUEST,
848            Json(error(
849                req.id,
850                errors::INVALID_PARAMS,
851                "invalid params",
852                Some(serde_json::json!({"error": err.to_string()})),
853            )),
854        ),
855    }
856}
857
858async fn handle_tasks_cancel(
859    state: AppState,
860    req: JsonRpcRequest,
861) -> (StatusCode, Json<JsonRpcResponse>) {
862    let params: Result<CancelTaskRequest, _> =
863        serde_json::from_value(req.params.unwrap_or_default());
864
865    match params {
866        Ok(p) => {
867            let result = state
868                .task_store
869                .update_flexible(&p.id, |task| {
870                    if task.status.state.is_terminal() {
871                        return Err(errors::TASK_NOT_CANCELABLE);
872                    }
873                    task.status.state = TaskState::Canceled;
874                    task.status.timestamp = Some(now_iso8601());
875                    Ok(())
876                })
877                .await;
878
879            match result {
880                Some(Ok(task)) => {
881                    if let Err(e) = state.handler.cancel_task(&task.id).await {
882                        tracing::warn!("Handler cancel_task failed: {}", e);
883                    }
884
885                    state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
886                        kind: "status-update".to_string(),
887                        task_id: task.id.clone(),
888                        context_id: task.context_id.clone(),
889                        status: task.status.clone(),
890                        is_final: true,
891                        metadata: None,
892                    }));
893
894                    match serde_json::to_value(task) {
895                        Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
896                        Err(e) => (
897                            StatusCode::INTERNAL_SERVER_ERROR,
898                            Json(error(
899                                req.id,
900                                errors::INTERNAL_ERROR,
901                                "serialization failed",
902                                Some(serde_json::json!({"error": e.to_string()})),
903                            )),
904                        ),
905                    }
906                }
907                Some(Err(error_code)) => (
908                    StatusCode::BAD_REQUEST,
909                    Json(error(req.id, error_code, "task not cancelable", None)),
910                ),
911                None => (
912                    StatusCode::NOT_FOUND,
913                    Json(error(
914                        req.id,
915                        errors::TASK_NOT_FOUND,
916                        "task not found",
917                        None,
918                    )),
919                ),
920            }
921        }
922        Err(err) => (
923            StatusCode::BAD_REQUEST,
924            Json(error(
925                req.id,
926                errors::INVALID_PARAMS,
927                "invalid params",
928                Some(serde_json::json!({"error": err.to_string()})),
929            )),
930        ),
931    }
932}
933
934async fn handle_tasks_list(
935    state: AppState,
936    req: JsonRpcRequest,
937) -> (StatusCode, Json<JsonRpcResponse>) {
938    let params: Result<ListTasksRequest, _> =
939        serde_json::from_value(req.params.unwrap_or_default());
940
941    match params {
942        Ok(p) => {
943            let response = state.task_store.list_filtered(&p).await;
944            match serde_json::to_value(response) {
945                Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
946                Err(e) => (
947                    StatusCode::INTERNAL_SERVER_ERROR,
948                    Json(error(
949                        req.id,
950                        errors::INTERNAL_ERROR,
951                        "serialization failed",
952                        Some(serde_json::json!({"error": e.to_string()})),
953                    )),
954                ),
955            }
956        }
957        Err(err) => (
958            StatusCode::BAD_REQUEST,
959            Json(error(
960                req.id,
961                errors::INVALID_PARAMS,
962                "invalid params",
963                Some(serde_json::json!({"error": err.to_string()})),
964            )),
965        ),
966    }
967}
968
969async fn handle_tasks_subscribe(
970    state: AppState,
971    req: JsonRpcRequest,
972) -> (StatusCode, Json<JsonRpcResponse>) {
973    let params: Result<SubscribeToTaskRequest, _> =
974        serde_json::from_value(req.params.unwrap_or_default());
975
976    match params {
977        Ok(p) => {
978            if state.task_store.get_flexible(&p.id).await.is_none() {
979                return (
980                    StatusCode::NOT_FOUND,
981                    Json(error(
982                        req.id,
983                        errors::TASK_NOT_FOUND,
984                        "task not found",
985                        None,
986                    )),
987                );
988            }
989
990            let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
991            let subscribe_url = format!("{}/v1/tasks/{}/subscribe", base_url, p.id);
992
993            (
994                StatusCode::OK,
995                Json(success(
996                    req.id,
997                    serde_json::json!({
998                        "subscribeUrl": subscribe_url,
999                        "protocol": "sse"
1000                    }),
1001                )),
1002            )
1003        }
1004        Err(err) => (
1005            StatusCode::BAD_REQUEST,
1006            Json(error(
1007                req.id,
1008                errors::INVALID_PARAMS,
1009                "invalid params",
1010                Some(serde_json::json!({"error": err.to_string()})),
1011            )),
1012        ),
1013    }
1014}
1015
1016async fn handle_get_extended_agent_card(
1017    state: AppState,
1018    req: JsonRpcRequest,
1019    auth_context: Option<AuthContext>,
1020) -> (StatusCode, Json<JsonRpcResponse>) {
1021    let Some(auth) = auth_context else {
1022        return (
1023            StatusCode::UNAUTHORIZED,
1024            Json(error(
1025                req.id,
1026                errors::INVALID_REQUEST,
1027                "authentication required for extended agent card",
1028                None,
1029            )),
1030        );
1031    };
1032
1033    let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1034
1035    match state.handler.extended_agent_card(base_url, &auth).await {
1036        Some(card) => match serde_json::to_value(card) {
1037            Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1038            Err(e) => (
1039                StatusCode::INTERNAL_SERVER_ERROR,
1040                Json(error(
1041                    req.id,
1042                    errors::INTERNAL_ERROR,
1043                    "serialization failed",
1044                    Some(serde_json::json!({"error": e.to_string()})),
1045                )),
1046            ),
1047        },
1048        None => (
1049            StatusCode::NOT_FOUND,
1050            Json(error(
1051                req.id,
1052                errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1053                "extended agent card not configured",
1054                None,
1055            )),
1056        ),
1057    }
1058}
1059
1060// ============ Push Notification Config Handlers ============
1061
1062async fn handle_push_config_create(
1063    state: AppState,
1064    req: JsonRpcRequest,
1065) -> (StatusCode, Json<JsonRpcResponse>) {
1066    if !state.push_notifications_enabled() {
1067        return (
1068            StatusCode::BAD_REQUEST,
1069            Json(error(
1070                req.id,
1071                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1072                "push notifications not supported",
1073                None,
1074            )),
1075        );
1076    }
1077
1078    let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1079        serde_json::from_value(req.params.unwrap_or_default());
1080
1081    match params {
1082        Ok(p) => {
1083            if state.task_store.get_flexible(&p.task_id).await.is_none() {
1084                return (
1085                    StatusCode::NOT_FOUND,
1086                    Json(error(
1087                        req.id,
1088                        errors::TASK_NOT_FOUND,
1089                        "task not found",
1090                        None,
1091                    )),
1092                );
1093            }
1094
1095            if let Err(e) = state
1096                .webhook_store
1097                .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1098                .await
1099            {
1100                return (
1101                    StatusCode::BAD_REQUEST,
1102                    Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1103                );
1104            }
1105
1106            (
1107                StatusCode::OK,
1108                Json(success(
1109                    req.id,
1110                    serde_json::json!({
1111                        "configId": p.config_id,
1112                        "config": p.push_notification_config
1113                    }),
1114                )),
1115            )
1116        }
1117        Err(err) => (
1118            StatusCode::BAD_REQUEST,
1119            Json(error(
1120                req.id,
1121                errors::INVALID_PARAMS,
1122                "invalid params",
1123                Some(serde_json::json!({"error": err.to_string()})),
1124            )),
1125        ),
1126    }
1127}
1128
1129async fn handle_push_config_get(
1130    state: AppState,
1131    req: JsonRpcRequest,
1132) -> (StatusCode, Json<JsonRpcResponse>) {
1133    if !state.push_notifications_enabled() {
1134        return (
1135            StatusCode::BAD_REQUEST,
1136            Json(error(
1137                req.id,
1138                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1139                "push notifications not supported",
1140                None,
1141            )),
1142        );
1143    }
1144
1145    let params: Result<GetTaskPushNotificationConfigRequest, _> =
1146        serde_json::from_value(req.params.unwrap_or_default());
1147
1148    match params {
1149        Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1150            Some(config) => (
1151                StatusCode::OK,
1152                Json(success(
1153                    req.id,
1154                    serde_json::json!({
1155                        "configId": p.id,
1156                        "config": config
1157                    }),
1158                )),
1159            ),
1160            None => (
1161                StatusCode::NOT_FOUND,
1162                Json(error(
1163                    req.id,
1164                    errors::TASK_NOT_FOUND,
1165                    "push notification config not found",
1166                    None,
1167                )),
1168            ),
1169        },
1170        Err(err) => (
1171            StatusCode::BAD_REQUEST,
1172            Json(error(
1173                req.id,
1174                errors::INVALID_PARAMS,
1175                "invalid params",
1176                Some(serde_json::json!({"error": err.to_string()})),
1177            )),
1178        ),
1179    }
1180}
1181
1182async fn handle_push_config_list(
1183    state: AppState,
1184    req: JsonRpcRequest,
1185) -> (StatusCode, Json<JsonRpcResponse>) {
1186    if !state.push_notifications_enabled() {
1187        return (
1188            StatusCode::BAD_REQUEST,
1189            Json(error(
1190                req.id,
1191                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1192                "push notifications not supported",
1193                None,
1194            )),
1195        );
1196    }
1197
1198    let params: Result<ListTaskPushNotificationConfigRequest, _> =
1199        serde_json::from_value(req.params.unwrap_or_default());
1200
1201    match params {
1202        Ok(p) => {
1203            let configs = state.webhook_store.list(&p.task_id).await;
1204
1205            let configs_json: Vec<_> = configs
1206                .iter()
1207                .map(|c| {
1208                    serde_json::json!({
1209                        "configId": c.config_id,
1210                        "config": c.config
1211                    })
1212                })
1213                .collect();
1214
1215            (
1216                StatusCode::OK,
1217                Json(success(
1218                    req.id,
1219                    serde_json::json!({
1220                        "configs": configs_json,
1221                        "nextPageToken": ""
1222                    }),
1223                )),
1224            )
1225        }
1226        Err(err) => (
1227            StatusCode::BAD_REQUEST,
1228            Json(error(
1229                req.id,
1230                errors::INVALID_PARAMS,
1231                "invalid params",
1232                Some(serde_json::json!({"error": err.to_string()})),
1233            )),
1234        ),
1235    }
1236}
1237
1238async fn handle_push_config_delete(
1239    state: AppState,
1240    req: JsonRpcRequest,
1241) -> (StatusCode, Json<JsonRpcResponse>) {
1242    if !state.push_notifications_enabled() {
1243        return (
1244            StatusCode::BAD_REQUEST,
1245            Json(error(
1246                req.id,
1247                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1248                "push notifications not supported",
1249                None,
1250            )),
1251        );
1252    }
1253
1254    let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1255        serde_json::from_value(req.params.unwrap_or_default());
1256
1257    match params {
1258        Ok(p) => {
1259            if state.webhook_store.delete(&p.task_id, &p.id).await {
1260                (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1261            } else {
1262                (
1263                    StatusCode::NOT_FOUND,
1264                    Json(error(
1265                        req.id,
1266                        errors::TASK_NOT_FOUND,
1267                        "push notification config not found",
1268                        None,
1269                    )),
1270                )
1271            }
1272        }
1273        Err(err) => (
1274            StatusCode::BAD_REQUEST,
1275            Json(error(
1276                req.id,
1277                errors::INVALID_PARAMS,
1278                "invalid params",
1279                Some(serde_json::json!({"error": err.to_string()})),
1280            )),
1281        ),
1282    }
1283}
1284
1285pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1286    bind_addr: &str,
1287    handler: H,
1288) -> anyhow::Result<()> {
1289    A2aServer::new(handler).bind(bind_addr)?.run().await
1290}
1291
1292pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1293    A2aServer::echo().bind(bind_addr)?.run().await
1294}