Skip to main content

a2a_rs_server/
server.rs

1//! Generic A2A JSON-RPC Server
2
3use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `SendMessage` call waits for its task to finish.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// How often a blocking `SendMessage` call re-reads the task store while it waits.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17    error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18    CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19    GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20    ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21    SendMessageResponse, SendMessageResult, StreamResponse, StreamingMessageResult,
22    SubscribeToTaskRequest, Task, TaskState, TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header in which clients announce their A2A protocol version.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the only protocol version this server accepts.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the only protocol version this server accepts.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::response::{IntoResponse, Response};
33use axum::routing::{get, post};
34use axum::{Json, Router};
35use futures::future::FutureExt;
36use tokio::sync::broadcast;
37use tracing::info;
38
39use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
40use crate::task_store::TaskStore;
41use crate::webhook_delivery::WebhookDelivery;
42use crate::webhook_store::WebhookStore;
43
/// Network configuration for an `A2aServer`.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Address to listen on, in `host:port` form (for example `0.0.0.0:8080`).
    pub bind_address: String,
}
48
49impl Default for ServerConfig {
50    fn default() -> Self {
51        Self {
52            bind_address: "0.0.0.0:8080".to_string(),
53        }
54    }
55}
56
57pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
58
/// Number of task events buffered in the broadcast channel; receivers that fall further
/// behind than this observe `RecvError::Lagged`.
const EVENT_CHANNEL_CAPACITY: usize = 1 << 10;
60
61pub struct A2aServer {
62    config: ServerConfig,
63    handler: BoxedHandler,
64    task_store: TaskStore,
65    webhook_store: WebhookStore,
66    auth_extractor: Option<AuthExtractor>,
67    additional_routes: Option<Router<AppState>>,
68    event_tx: broadcast::Sender<StreamResponse>,
69}
70
71impl A2aServer {
72    pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
73        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
74        Self {
75            config: ServerConfig::default(),
76            handler: Arc::new(handler),
77            task_store: TaskStore::new(),
78            webhook_store: WebhookStore::new(),
79            auth_extractor: None,
80            additional_routes: None,
81            event_tx,
82        }
83    }
84
85    pub fn echo() -> Self {
86        Self::new(EchoHandler::default())
87    }
88
89    pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
90        let _: SocketAddr = address.parse()?;
91        self.config.bind_address = address.to_string();
92        Ok(self)
93    }
94
95    pub fn bind_unchecked(mut self, address: &str) -> Self {
96        self.config.bind_address = address.to_string();
97        self
98    }
99
100    pub fn task_store(mut self, store: TaskStore) -> Self {
101        self.task_store = store;
102        self
103    }
104
105    pub fn auth_extractor<F>(mut self, extractor: F) -> Self
106    where
107        F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
108    {
109        self.auth_extractor = Some(Arc::new(extractor));
110        self
111    }
112
113    pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
114        self.additional_routes = Some(routes);
115        self
116    }
117
118    pub fn get_task_store(&self) -> TaskStore {
119        self.task_store.clone()
120    }
121
122    pub fn build_router(self) -> Router {
123        let bind: SocketAddr = self
124            .config
125            .bind_address
126            .parse()
127            .expect("Invalid bind address");
128        let base_url = format!("http://{}", bind);
129        let card = Arc::new(self.handler.agent_card(&base_url));
130
131        let state = AppState {
132            handler: self.handler,
133            task_store: self.task_store,
134            webhook_store: self.webhook_store,
135            card,
136            auth_extractor: self.auth_extractor,
137            event_tx: self.event_tx,
138        };
139
140        let timed_routes = Router::new()
141            .route("/health", get(health))
142            .route("/.well-known/agent-card.json", get(agent_card))
143            .route("/v1/rpc", post(handle_rpc))
144            .layer(
145                ServiceBuilder::new()
146                    .layer(HandleErrorLayer::new(handle_timeout_error))
147                    .timeout(Duration::from_secs(30)),
148            );
149
150        let mut router = timed_routes;
151
152        if let Some(additional) = self.additional_routes {
153            router = router.merge(additional);
154        }
155
156        router.with_state(state)
157    }
158
159    pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
160        self.event_tx.clone()
161    }
162
163    pub async fn run(self) -> anyhow::Result<()> {
164        let bind: SocketAddr = self.config.bind_address.parse()?;
165
166        let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
167        webhook_delivery.start(self.event_tx.subscribe());
168
169        let router = self.build_router();
170
171        info!(%bind, "Starting A2A server");
172        let listener = tokio::net::TcpListener::bind(bind).await?;
173        axum::serve(listener, router)
174            .with_graceful_shutdown(shutdown_signal())
175            .await?;
176
177        info!("Server shutdown complete");
178        Ok(())
179    }
180}
181
182async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
183    if err.is::<tower::timeout::error::Elapsed>() {
184        (
185            StatusCode::REQUEST_TIMEOUT,
186            Json(error(
187                serde_json::Value::Null,
188                errors::INTERNAL_ERROR,
189                "Request timed out",
190                None,
191            )),
192        )
193    } else {
194        (
195            StatusCode::INTERNAL_SERVER_ERROR,
196            Json(error(
197                serde_json::Value::Null,
198                errors::INTERNAL_ERROR,
199                &format!("Internal error: {}", err),
200                None,
201            )),
202        )
203    }
204}
205
206async fn shutdown_signal() {
207    let ctrl_c = async {
208        signal::ctrl_c()
209            .await
210            .expect("failed to install Ctrl+C handler");
211    };
212
213    #[cfg(unix)]
214    let terminate = async {
215        signal::unix::signal(signal::unix::SignalKind::terminate())
216            .expect("failed to install SIGTERM handler")
217            .recv()
218            .await;
219    };
220
221    #[cfg(not(unix))]
222    let terminate = std::future::pending::<()>();
223
224    tokio::select! {
225        _ = ctrl_c => {},
226        _ = terminate => {},
227    }
228
229    info!("Shutdown signal received, initiating graceful shutdown...");
230}
231
232#[derive(Clone)]
233pub struct AppState {
234    handler: BoxedHandler,
235    task_store: TaskStore,
236    webhook_store: WebhookStore,
237    card: Arc<AgentCard>,
238    auth_extractor: Option<AuthExtractor>,
239    event_tx: broadcast::Sender<StreamResponse>,
240}
241
242impl AppState {
243    pub fn task_store(&self) -> &TaskStore {
244        &self.task_store
245    }
246
247    pub fn agent_card(&self) -> &AgentCard {
248        &self.card
249    }
250
251    pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
252        &self.event_tx
253    }
254
255    pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
256        self.event_tx.subscribe()
257    }
258
259    pub fn broadcast_event(&self, event: StreamResponse) {
260        let _ = self.event_tx.send(event);
261    }
262
263    /// Check if streaming is enabled in capabilities
264    fn streaming_enabled(&self) -> bool {
265        self.card.capabilities.streaming.unwrap_or(false)
266    }
267
268    /// Check if push notifications are enabled in capabilities
269    fn push_notifications_enabled(&self) -> bool {
270        self.card.capabilities.push_notifications.unwrap_or(false)
271    }
272
273    /// Get the endpoint URL from the agent card
274    fn endpoint_url(&self) -> &str {
275        self.card.endpoint().unwrap_or("")
276    }
277}
278
279// ============ History Trimming ============
280
281fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
282    match history_length {
283        Some(0) => {
284            task.history = None;
285        }
286        Some(n) => {
287            if let Some(ref mut history) = task.history {
288                let len = history.len();
289                if len > n as usize {
290                    *history = history.split_off(len - n as usize);
291                }
292            }
293        }
294        None => {}
295    }
296}
297
298// ============ Version Validation ============
299
300fn validate_a2a_version(
301    headers: &HeaderMap,
302    req_id: &serde_json::Value,
303) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
304    if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
305        let version_str = version_header.to_str().unwrap_or("");
306
307        let parts: Vec<&str> = version_str.split('.').collect();
308        if parts.len() >= 2 {
309            if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
310                if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
311                    return Ok(());
312                }
313
314                return Err((
315                    StatusCode::BAD_REQUEST,
316                    Json(error(
317                        req_id.clone(),
318                        errors::VERSION_NOT_SUPPORTED,
319                        &format!(
320                            "Protocol version {}.{} not supported. Supported version: {}.{}",
321                            major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
322                        ),
323                        Some(serde_json::json!({
324                            "requestedVersion": version_str,
325                            "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
326                        })),
327                    )),
328                ));
329            }
330        }
331
332        // Also accept bare "1.0" without minor
333        if version_str == "1" || version_str == "1.0" {
334            return Ok(());
335        }
336
337        return Err((
338            StatusCode::BAD_REQUEST,
339            Json(error(
340                req_id.clone(),
341                errors::VERSION_NOT_SUPPORTED,
342                &format!(
343                    "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
344                    version_str
345                ),
346                None,
347            )),
348        ));
349    }
350
351    Ok(())
352}
353
354// ============ Error Response Helpers ============
355
356#[allow(dead_code)]
357pub fn rpc_error(
358    id: serde_json::Value,
359    code: i32,
360    message: &str,
361    status: StatusCode,
362) -> (StatusCode, Json<JsonRpcResponse>) {
363    (status, Json(error(id, code, message, None)))
364}
365
366#[allow(dead_code)]
367pub fn rpc_error_with_data(
368    id: serde_json::Value,
369    code: i32,
370    message: &str,
371    data: serde_json::Value,
372    status: StatusCode,
373) -> (StatusCode, Json<JsonRpcResponse>) {
374    (status, Json(error(id, code, message, Some(data))))
375}
376
377#[allow(dead_code)]
378pub fn rpc_success(
379    id: serde_json::Value,
380    result: serde_json::Value,
381) -> (StatusCode, Json<JsonRpcResponse>) {
382    (StatusCode::OK, Json(success(id, result)))
383}
384
385// ============ Route Handlers ============
386
387async fn health() -> Json<serde_json::Value> {
388    Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
389}
390
391async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
392    Json((*state.card).clone())
393}
394
395// handle_task_subscribe_sse removed — tasks/resubscribe now returns SSE
396// directly from the /v1/rpc endpoint via handle_tasks_resubscribe.
397
398async fn handle_rpc(
399    State(state): State<AppState>,
400    headers: HeaderMap,
401    Json(req): Json<JsonRpcRequest>,
402) -> Response {
403    if req.jsonrpc != "2.0" {
404        let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
405        return (StatusCode::BAD_REQUEST, Json(resp)).into_response();
406    }
407
408    if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
409        return err_response.into_response();
410    }
411
412    let auth_context = state
413        .auth_extractor
414        .as_ref()
415        .and_then(|extractor| extractor(&headers));
416
417    match req.method.as_str() {
418        // Accept both v1.0 PascalCase and v0.3 kebab-case method names
419        "SendMessage" | "message/send" => handle_message_send(state, req, auth_context)
420            .await
421            .into_response(),
422        "SendStreamingMessage" | "message/stream" => {
423            handle_message_stream(state, req, headers, auth_context)
424                .await
425                .into_response()
426        }
427        "GetTask" | "tasks/get" => handle_tasks_get(state, req).await.into_response(),
428        "ListTasks" | "tasks/list" => handle_tasks_list(state, req).await.into_response(),
429        "CancelTask" | "tasks/cancel" => handle_tasks_cancel(state, req).await.into_response(),
430        "SubscribeToTask" | "tasks/resubscribe" => {
431            handle_tasks_resubscribe(state, req).await.into_response()
432        }
433        "CreateTaskPushNotificationConfig" | "tasks/pushNotificationConfig/create" => {
434            handle_push_config_create(state, req).await.into_response()
435        }
436        "GetTaskPushNotificationConfig" | "tasks/pushNotificationConfig/get" => {
437            handle_push_config_get(state, req).await.into_response()
438        }
439        "ListTaskPushNotificationConfigs" | "tasks/pushNotificationConfig/list" => {
440            handle_push_config_list(state, req).await.into_response()
441        }
442        "DeleteTaskPushNotificationConfig" | "tasks/pushNotificationConfig/delete" => {
443            handle_push_config_delete(state, req).await.into_response()
444        }
445        "GetExtendedAgentCard" | "agentCard/getExtended" => {
446            handle_get_extended_agent_card(state, req, auth_context)
447                .await
448                .into_response()
449        }
450        _ => (
451            StatusCode::NOT_FOUND,
452            Json(error(
453                req.id,
454                errors::METHOD_NOT_FOUND,
455                "method not found",
456                None,
457            )),
458        )
459            .into_response(),
460    }
461}
462
463fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
464    match e {
465        HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
466        HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
467        HandlerError::BackendUnavailable { .. } => {
468            (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
469        }
470        HandlerError::ProcessingFailed { .. } => {
471            (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
472        }
473        HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
474    }
475}
476
477async fn handle_message_send(
478    state: AppState,
479    req: JsonRpcRequest,
480    auth_context: Option<AuthContext>,
481) -> (StatusCode, Json<JsonRpcResponse>) {
482    let req_id = req.id.clone();
483
484    let params: Result<SendMessageRequest, _> =
485        serde_json::from_value(req.params.clone().unwrap_or_default());
486
487    let params = match params {
488        Ok(p) => p,
489        Err(err) => {
490            return (
491                StatusCode::BAD_REQUEST,
492                Json(error(
493                    req_id,
494                    errors::INVALID_PARAMS,
495                    "invalid params",
496                    Some(serde_json::json!({"error": err.to_string()})),
497                )),
498            );
499        }
500    };
501
502    let blocking = params
503        .configuration
504        .as_ref()
505        .and_then(|c| c.blocking)
506        .unwrap_or(false);
507    let return_immediately = params
508        .configuration
509        .as_ref()
510        .and_then(|c| c.return_immediately)
511        .unwrap_or(false);
512    let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
513
514    match state
515        .handler
516        .handle_message(params.message, auth_context)
517        .await
518    {
519        Ok(response) => {
520            match response {
521                SendMessageResponse::Task(mut task) => {
522                    // Store the task and broadcast event
523                    state.task_store.insert(task.clone()).await;
524                    state.broadcast_event(StreamResponse::Task(task.clone()));
525
526                    // If blocking mode and not returnImmediately, wait for terminal state
527                    if blocking && !return_immediately && !task.status.state.is_terminal() {
528                        let task_id = task.id.clone();
529                        let mut rx = state.subscribe_events();
530
531                        let wait_result = timeout(BLOCKING_TIMEOUT, async {
532                            loop {
533                                tokio::select! {
534                                    result = rx.recv() => {
535                                        match result {
536                                            Ok(StreamResponse::Task(t)) if t.id == task_id => {
537                                                if t.status.state.is_terminal() {
538                                                    return Some(t);
539                                                }
540                                            }
541                                            Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
542                                                if e.status.state.is_terminal() {
543                                                    if let Some(t) = state.task_store.get(&task_id).await {
544                                                        return Some(t);
545                                                    }
546                                                }
547                                            }
548                                            Err(broadcast::error::RecvError::Closed) => {
549                                                return None;
550                                            }
551                                            _ => {}
552                                        }
553                                    }
554                                    _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
555                                        if let Some(t) = state.task_store.get(&task_id).await {
556                                            if t.status.state.is_terminal() {
557                                                return Some(t);
558                                            }
559                                        }
560                                    }
561                                }
562                            }
563                        })
564                        .await;
565
566                        match wait_result {
567                            Ok(Some(final_task)) => task = final_task,
568                            Ok(None) => {
569                                if let Some(t) = state.task_store.get(&task.id).await {
570                                    task = t;
571                                }
572                            }
573                            Err(_) => {
574                                tracing::warn!("Blocking request timed out for task {}", task.id);
575                                if let Some(t) = state.task_store.get(&task.id).await {
576                                    task = t;
577                                }
578                            }
579                        }
580                    }
581
582                    apply_history_length(&mut task, history_length);
583
584                    // Serialize via SendMessageResult for externally tagged wrapping
585                    match serde_json::to_value(SendMessageResult::Task(task.clone())) {
586                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
587                        Err(e) => (
588                            StatusCode::INTERNAL_SERVER_ERROR,
589                            Json(error(
590                                req_id,
591                                errors::INTERNAL_ERROR,
592                                "serialization failed",
593                                Some(serde_json::json!({"error": e.to_string()})),
594                            )),
595                        ),
596                    }
597                }
598                SendMessageResponse::Message(msg) => {
599                    // Serialize via SendMessageResult for externally tagged wrapping
600                    match serde_json::to_value(SendMessageResult::Message(msg)) {
601                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
602                        Err(e) => (
603                            StatusCode::INTERNAL_SERVER_ERROR,
604                            Json(error(
605                                req_id,
606                                errors::INTERNAL_ERROR,
607                                "serialization failed",
608                                Some(serde_json::json!({"error": e.to_string()})),
609                            )),
610                        ),
611                    }
612                }
613            }
614        }
615        Err(e) => {
616            let (code, status) = handler_error_to_rpc(&e);
617            (status, Json(error(req_id, code, &e.to_string(), None)))
618        }
619    }
620}
621
622/// Handle `message/stream` — returns SSE directly from the JSON-RPC endpoint.
623///
624/// Each SSE event's `data:` is a full JSON-RPC response envelope wrapping the result
625/// (Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent).
626async fn handle_message_stream(
627    state: AppState,
628    req: JsonRpcRequest,
629    _headers: HeaderMap,
630    auth_context: Option<AuthContext>,
631) -> Response {
632    let req_id = req.id.clone();
633
634    if !state.streaming_enabled() {
635        return (
636            StatusCode::BAD_REQUEST,
637            Json(error(
638                req_id,
639                errors::UNSUPPORTED_OPERATION,
640                "streaming not supported by this agent",
641                None,
642            )),
643        )
644            .into_response();
645    }
646
647    let params: Result<SendMessageRequest, _> =
648        serde_json::from_value(req.params.clone().unwrap_or_default());
649
650    let params = match params {
651        Ok(p) => p,
652        Err(err) => {
653            return (
654                StatusCode::BAD_REQUEST,
655                Json(error(
656                    req_id,
657                    errors::INVALID_PARAMS,
658                    "invalid params",
659                    Some(serde_json::json!({"error": err.to_string()})),
660                )),
661            )
662                .into_response();
663        }
664    };
665
666    let response = match state
667        .handler
668        .handle_message(params.message, auth_context)
669        .await
670    {
671        Ok(r) => r,
672        Err(e) => {
673            let (code, status) = handler_error_to_rpc(&e);
674            return (status, Json(error(req_id, code, &e.to_string(), None))).into_response();
675        }
676    };
677
678    // Extract task from response (streaming only works with tasks)
679    let task = match response {
680        SendMessageResponse::Task(t) => t,
681        SendMessageResponse::Message(_) => {
682            return (
683                StatusCode::BAD_REQUEST,
684                Json(error(
685                    req_id,
686                    errors::UNSUPPORTED_OPERATION,
687                    "handler returned a message, streaming requires a task",
688                    None,
689                )),
690            )
691                .into_response();
692        }
693    };
694
695    let task_id = task.id.clone();
696    state.task_store.insert(task.clone()).await;
697    state.broadcast_event(StreamResponse::Task(task.clone()));
698
699    let mut rx = state.subscribe_events();
700    let task_store = state.task_store.clone();
701    let target_task_id = task_id;
702
703    // Helper: wrap a value in a JSON-RPC success response envelope
704    let wrap = move |value: serde_json::Value| -> String {
705        serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
706    };
707
708    let stream = async_stream::stream! {
709        // Yield initial task via StreamingMessageResult for proper tagging
710        if let Ok(val) = serde_json::to_value(StreamingMessageResult::Task(task)) {
711            yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
712        }
713
714        loop {
715            match rx.recv().await {
716                Ok(event) => {
717                    let matches = match &event {
718                        StreamResponse::Task(t) => t.id == target_task_id,
719                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
720                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
721                        StreamResponse::Message(m) => {
722                            m.context_id.as_ref().is_some_and(|ctx| {
723                                task_store.get(&target_task_id).now_or_never()
724                                    .flatten()
725                                    .is_some_and(|t| t.context_id == *ctx)
726                            })
727                        }
728                    };
729
730                    if matches {
731                        // Serialize via StreamingMessageResult for proper external tagging
732                        let val = match event.clone() {
733                            StreamResponse::Task(t) => serde_json::to_value(StreamingMessageResult::Task(t)),
734                            StreamResponse::Message(m) => serde_json::to_value(StreamingMessageResult::Message(m)),
735                            StreamResponse::StatusUpdate(e) => serde_json::to_value(StreamingMessageResult::StatusUpdate(e)),
736                            StreamResponse::ArtifactUpdate(e) => serde_json::to_value(StreamingMessageResult::ArtifactUpdate(e)),
737                        };
738                        if let Ok(val) = val {
739                            yield Ok(Event::default().data(wrap(val)));
740                        }
741
742                        // End stream on terminal state or final flag
743                        let is_terminal = match &event {
744                            StreamResponse::Task(t) => t.status.state.is_terminal(),
745                            StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
746                            _ => false,
747                        };
748                        if is_terminal {
749                            break;
750                        }
751                    }
752                }
753                Err(broadcast::error::RecvError::Lagged(_)) => continue,
754                Err(broadcast::error::RecvError::Closed) => break,
755            }
756        }
757    };
758
759    Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
760}
761
762async fn handle_tasks_get(
763    state: AppState,
764    req: JsonRpcRequest,
765) -> (StatusCode, Json<JsonRpcResponse>) {
766    let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
767
768    match params {
769        Ok(p) => match state.task_store.get_flexible(&p.id).await {
770            Some(mut task) => {
771                apply_history_length(&mut task, p.history_length);
772
773                match serde_json::to_value(task) {
774                    Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
775                    Err(e) => (
776                        StatusCode::INTERNAL_SERVER_ERROR,
777                        Json(error(
778                            req.id,
779                            errors::INTERNAL_ERROR,
780                            "serialization failed",
781                            Some(serde_json::json!({"error": e.to_string()})),
782                        )),
783                    ),
784                }
785            }
786            None => (
787                StatusCode::NOT_FOUND,
788                Json(error(
789                    req.id,
790                    errors::TASK_NOT_FOUND,
791                    "task not found",
792                    None,
793                )),
794            ),
795        },
796        Err(err) => (
797            StatusCode::BAD_REQUEST,
798            Json(error(
799                req.id,
800                errors::INVALID_PARAMS,
801                "invalid params",
802                Some(serde_json::json!({"error": err.to_string()})),
803            )),
804        ),
805    }
806}
807
808async fn handle_tasks_cancel(
809    state: AppState,
810    req: JsonRpcRequest,
811) -> (StatusCode, Json<JsonRpcResponse>) {
812    let params: Result<CancelTaskRequest, _> =
813        serde_json::from_value(req.params.unwrap_or_default());
814
815    match params {
816        Ok(p) => {
817            let result = state
818                .task_store
819                .update_flexible(&p.id, |task| {
820                    if task.status.state.is_terminal() {
821                        return Err(errors::TASK_NOT_CANCELABLE);
822                    }
823                    task.status.state = TaskState::Canceled;
824                    task.status.timestamp = Some(now_iso8601());
825                    Ok(())
826                })
827                .await;
828
829            match result {
830                Some(Ok(task)) => {
831                    if let Err(e) = state.handler.cancel_task(&task.id).await {
832                        tracing::warn!("Handler cancel_task failed: {}", e);
833                    }
834
835                    state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
836                        kind: "status-update".to_string(),
837                        task_id: task.id.clone(),
838                        context_id: task.context_id.clone(),
839                        status: task.status.clone(),
840                        is_final: true,
841                        metadata: None,
842                    }));
843
844                    match serde_json::to_value(task) {
845                        Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
846                        Err(e) => (
847                            StatusCode::INTERNAL_SERVER_ERROR,
848                            Json(error(
849                                req.id,
850                                errors::INTERNAL_ERROR,
851                                "serialization failed",
852                                Some(serde_json::json!({"error": e.to_string()})),
853                            )),
854                        ),
855                    }
856                }
857                Some(Err(error_code)) => (
858                    StatusCode::BAD_REQUEST,
859                    Json(error(req.id, error_code, "task not cancelable", None)),
860                ),
861                None => (
862                    StatusCode::NOT_FOUND,
863                    Json(error(
864                        req.id,
865                        errors::TASK_NOT_FOUND,
866                        "task not found",
867                        None,
868                    )),
869                ),
870            }
871        }
872        Err(err) => (
873            StatusCode::BAD_REQUEST,
874            Json(error(
875                req.id,
876                errors::INVALID_PARAMS,
877                "invalid params",
878                Some(serde_json::json!({"error": err.to_string()})),
879            )),
880        ),
881    }
882}
883
884async fn handle_tasks_list(
885    state: AppState,
886    req: JsonRpcRequest,
887) -> (StatusCode, Json<JsonRpcResponse>) {
888    let params: Result<ListTasksRequest, _> =
889        serde_json::from_value(req.params.unwrap_or_default());
890
891    match params {
892        Ok(p) => {
893            let response = state.task_store.list_filtered(&p).await;
894            match serde_json::to_value(response) {
895                Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
896                Err(e) => (
897                    StatusCode::INTERNAL_SERVER_ERROR,
898                    Json(error(
899                        req.id,
900                        errors::INTERNAL_ERROR,
901                        "serialization failed",
902                        Some(serde_json::json!({"error": e.to_string()})),
903                    )),
904                ),
905            }
906        }
907        Err(err) => (
908            StatusCode::BAD_REQUEST,
909            Json(error(
910                req.id,
911                errors::INVALID_PARAMS,
912                "invalid params",
913                Some(serde_json::json!({"error": err.to_string()})),
914            )),
915        ),
916    }
917}
918
919/// Handle `tasks/resubscribe` — returns SSE directly from the JSON-RPC endpoint.
920///
921/// Reconnects to an existing task's event stream. Each SSE event's `data:` is
922/// a JSON-RPC response envelope wrapping the result, same as `message/stream`.
923async fn handle_tasks_resubscribe(state: AppState, req: JsonRpcRequest) -> Response {
924    let req_id = req.id.clone();
925
926    let params: Result<SubscribeToTaskRequest, _> =
927        serde_json::from_value(req.params.unwrap_or_default());
928
929    let params = match params {
930        Ok(p) => p,
931        Err(err) => {
932            return (
933                StatusCode::BAD_REQUEST,
934                Json(error(
935                    req_id,
936                    errors::INVALID_PARAMS,
937                    "invalid params",
938                    Some(serde_json::json!({"error": err.to_string()})),
939                )),
940            )
941                .into_response();
942        }
943    };
944
945    // Verify the task exists
946    let task = match state.task_store.get_flexible(&params.id).await {
947        Some(t) => t,
948        None => {
949            return (
950                StatusCode::NOT_FOUND,
951                Json(error(req_id, errors::TASK_NOT_FOUND, "task not found", None)),
952            )
953                .into_response();
954        }
955    };
956
957    let target_task_id = task.id.clone();
958    let mut rx = state.subscribe_events();
959    let task_store = state.task_store.clone();
960
961    let wrap = move |value: serde_json::Value| -> String {
962        serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
963    };
964
965    let stream = async_stream::stream! {
966        // Yield current task snapshot via StreamingMessageResult
967        if let Ok(val) = serde_json::to_value(StreamingMessageResult::Task(task.clone())) {
968            yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
969        }
970
971        // If already terminal, stop
972        if task.status.state.is_terminal() {
973            return;
974        }
975
976        loop {
977            match rx.recv().await {
978                Ok(event) => {
979                    let matches = match &event {
980                        StreamResponse::Task(t) => t.id == target_task_id,
981                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
982                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
983                        StreamResponse::Message(m) => {
984                            m.context_id.as_ref().is_some_and(|ctx| {
985                                task_store.get(&target_task_id).now_or_never()
986                                    .flatten()
987                                    .is_some_and(|t| t.context_id == *ctx)
988                            })
989                        }
990                    };
991
992                    if matches {
993                        let val = match event.clone() {
994                            StreamResponse::Task(t) => serde_json::to_value(StreamingMessageResult::Task(t)),
995                            StreamResponse::Message(m) => serde_json::to_value(StreamingMessageResult::Message(m)),
996                            StreamResponse::StatusUpdate(e) => serde_json::to_value(StreamingMessageResult::StatusUpdate(e)),
997                            StreamResponse::ArtifactUpdate(e) => serde_json::to_value(StreamingMessageResult::ArtifactUpdate(e)),
998                        };
999                        if let Ok(val) = val {
1000                            yield Ok(Event::default().data(wrap(val)));
1001                        }
1002
1003                        let is_terminal = match &event {
1004                            StreamResponse::Task(t) => t.status.state.is_terminal(),
1005                            StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
1006                            _ => false,
1007                        };
1008                        if is_terminal {
1009                            break;
1010                        }
1011                    }
1012                }
1013                Err(broadcast::error::RecvError::Lagged(_)) => continue,
1014                Err(broadcast::error::RecvError::Closed) => break,
1015            }
1016        }
1017    };
1018
1019    Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
1020}
1021
1022async fn handle_get_extended_agent_card(
1023    state: AppState,
1024    req: JsonRpcRequest,
1025    auth_context: Option<AuthContext>,
1026) -> (StatusCode, Json<JsonRpcResponse>) {
1027    let Some(auth) = auth_context else {
1028        return (
1029            StatusCode::UNAUTHORIZED,
1030            Json(error(
1031                req.id,
1032                errors::INVALID_REQUEST,
1033                "authentication required for extended agent card",
1034                None,
1035            )),
1036        );
1037    };
1038
1039    let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1040
1041    match state.handler.extended_agent_card(base_url, &auth).await {
1042        Some(card) => match serde_json::to_value(card) {
1043            Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1044            Err(e) => (
1045                StatusCode::INTERNAL_SERVER_ERROR,
1046                Json(error(
1047                    req.id,
1048                    errors::INTERNAL_ERROR,
1049                    "serialization failed",
1050                    Some(serde_json::json!({"error": e.to_string()})),
1051                )),
1052            ),
1053        },
1054        None => (
1055            StatusCode::NOT_FOUND,
1056            Json(error(
1057                req.id,
1058                errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1059                "extended agent card not configured",
1060                None,
1061            )),
1062        ),
1063    }
1064}
1065
1066// ============ Push Notification Config Handlers ============
1067
1068async fn handle_push_config_create(
1069    state: AppState,
1070    req: JsonRpcRequest,
1071) -> (StatusCode, Json<JsonRpcResponse>) {
1072    if !state.push_notifications_enabled() {
1073        return (
1074            StatusCode::BAD_REQUEST,
1075            Json(error(
1076                req.id,
1077                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1078                "push notifications not supported",
1079                None,
1080            )),
1081        );
1082    }
1083
1084    let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1085        serde_json::from_value(req.params.unwrap_or_default());
1086
1087    match params {
1088        Ok(p) => {
1089            if state.task_store.get_flexible(&p.task_id).await.is_none() {
1090                return (
1091                    StatusCode::NOT_FOUND,
1092                    Json(error(
1093                        req.id,
1094                        errors::TASK_NOT_FOUND,
1095                        "task not found",
1096                        None,
1097                    )),
1098                );
1099            }
1100
1101            if let Err(e) = state
1102                .webhook_store
1103                .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1104                .await
1105            {
1106                return (
1107                    StatusCode::BAD_REQUEST,
1108                    Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1109                );
1110            }
1111
1112            (
1113                StatusCode::OK,
1114                Json(success(
1115                    req.id,
1116                    serde_json::json!({
1117                        "configId": p.config_id,
1118                        "config": p.push_notification_config
1119                    }),
1120                )),
1121            )
1122        }
1123        Err(err) => (
1124            StatusCode::BAD_REQUEST,
1125            Json(error(
1126                req.id,
1127                errors::INVALID_PARAMS,
1128                "invalid params",
1129                Some(serde_json::json!({"error": err.to_string()})),
1130            )),
1131        ),
1132    }
1133}
1134
1135async fn handle_push_config_get(
1136    state: AppState,
1137    req: JsonRpcRequest,
1138) -> (StatusCode, Json<JsonRpcResponse>) {
1139    if !state.push_notifications_enabled() {
1140        return (
1141            StatusCode::BAD_REQUEST,
1142            Json(error(
1143                req.id,
1144                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1145                "push notifications not supported",
1146                None,
1147            )),
1148        );
1149    }
1150
1151    let params: Result<GetTaskPushNotificationConfigRequest, _> =
1152        serde_json::from_value(req.params.unwrap_or_default());
1153
1154    match params {
1155        Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1156            Some(config) => (
1157                StatusCode::OK,
1158                Json(success(
1159                    req.id,
1160                    serde_json::json!({
1161                        "configId": p.id,
1162                        "config": config
1163                    }),
1164                )),
1165            ),
1166            None => (
1167                StatusCode::NOT_FOUND,
1168                Json(error(
1169                    req.id,
1170                    errors::TASK_NOT_FOUND,
1171                    "push notification config not found",
1172                    None,
1173                )),
1174            ),
1175        },
1176        Err(err) => (
1177            StatusCode::BAD_REQUEST,
1178            Json(error(
1179                req.id,
1180                errors::INVALID_PARAMS,
1181                "invalid params",
1182                Some(serde_json::json!({"error": err.to_string()})),
1183            )),
1184        ),
1185    }
1186}
1187
1188async fn handle_push_config_list(
1189    state: AppState,
1190    req: JsonRpcRequest,
1191) -> (StatusCode, Json<JsonRpcResponse>) {
1192    if !state.push_notifications_enabled() {
1193        return (
1194            StatusCode::BAD_REQUEST,
1195            Json(error(
1196                req.id,
1197                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1198                "push notifications not supported",
1199                None,
1200            )),
1201        );
1202    }
1203
1204    let params: Result<ListTaskPushNotificationConfigRequest, _> =
1205        serde_json::from_value(req.params.unwrap_or_default());
1206
1207    match params {
1208        Ok(p) => {
1209            let configs = state.webhook_store.list(&p.task_id).await;
1210
1211            let configs_json: Vec<_> = configs
1212                .iter()
1213                .map(|c| {
1214                    serde_json::json!({
1215                        "configId": c.config_id,
1216                        "config": c.config
1217                    })
1218                })
1219                .collect();
1220
1221            (
1222                StatusCode::OK,
1223                Json(success(
1224                    req.id,
1225                    serde_json::json!({
1226                        "configs": configs_json,
1227                        "nextPageToken": ""
1228                    }),
1229                )),
1230            )
1231        }
1232        Err(err) => (
1233            StatusCode::BAD_REQUEST,
1234            Json(error(
1235                req.id,
1236                errors::INVALID_PARAMS,
1237                "invalid params",
1238                Some(serde_json::json!({"error": err.to_string()})),
1239            )),
1240        ),
1241    }
1242}
1243
1244async fn handle_push_config_delete(
1245    state: AppState,
1246    req: JsonRpcRequest,
1247) -> (StatusCode, Json<JsonRpcResponse>) {
1248    if !state.push_notifications_enabled() {
1249        return (
1250            StatusCode::BAD_REQUEST,
1251            Json(error(
1252                req.id,
1253                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1254                "push notifications not supported",
1255                None,
1256            )),
1257        );
1258    }
1259
1260    let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1261        serde_json::from_value(req.params.unwrap_or_default());
1262
1263    match params {
1264        Ok(p) => {
1265            if state.webhook_store.delete(&p.task_id, &p.id).await {
1266                (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1267            } else {
1268                (
1269                    StatusCode::NOT_FOUND,
1270                    Json(error(
1271                        req.id,
1272                        errors::TASK_NOT_FOUND,
1273                        "push notification config not found",
1274                        None,
1275                    )),
1276                )
1277            }
1278        }
1279        Err(err) => (
1280            StatusCode::BAD_REQUEST,
1281            Json(error(
1282                req.id,
1283                errors::INVALID_PARAMS,
1284                "invalid params",
1285                Some(serde_json::json!({"error": err.to_string()})),
1286            )),
1287        ),
1288    }
1289}
1290
1291pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1292    bind_addr: &str,
1293    handler: H,
1294) -> anyhow::Result<()> {
1295    A2aServer::new(handler).bind(bind_addr)?.run().await
1296}
1297
/// Convenience wrapper: builds the server returned by `A2aServer::echo()`,
/// binds it to `bind_addr`, and runs it.
///
/// # Errors
///
/// Propagates any error returned by `bind` or by `run`.
pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
    A2aServer::echo().bind(bind_addr)?.run().await
}