// a2a_rs_server/server.rs

//! Generic A2A JSON-RPC Server
2
3use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `message/send` waits for its task to finish.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// How often a blocking `message/send` re-reads the task store while it waits.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17    error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18    CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19    GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20    ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21    SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22    TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header carrying the client's A2A protocol version (`major.minor`).
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Protocol major version accepted in the [`A2A_VERSION_HEADER`] header.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Protocol minor version accepted in the [`A2A_VERSION_HEADER`] header.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::response::{IntoResponse, Response};
33use axum::routing::{get, post};
34use axum::{Json, Router};
35use futures::future::FutureExt;
36use tokio::sync::broadcast;
37use tracing::info;
38
39use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
40use crate::task_store::TaskStore;
41use crate::webhook_delivery::WebhookDelivery;
42use crate::webhook_store::WebhookStore;
43
/// Listener configuration for [`A2aServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address the server listens on, written as `ip:port` (e.g. `"0.0.0.0:8080"`).
    pub bind_address: String,
}

impl Default for ServerConfig {
    /// Listens on every IPv4 interface, port 8080.
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        ServerConfig { bind_address }
    }
}
56
57pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
58
59const EVENT_CHANNEL_CAPACITY: usize = 1024;
60
61pub struct A2aServer {
62    config: ServerConfig,
63    handler: BoxedHandler,
64    task_store: TaskStore,
65    webhook_store: WebhookStore,
66    auth_extractor: Option<AuthExtractor>,
67    additional_routes: Option<Router<AppState>>,
68    event_tx: broadcast::Sender<StreamResponse>,
69}
70
71impl A2aServer {
72    pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
73        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
74        Self {
75            config: ServerConfig::default(),
76            handler: Arc::new(handler),
77            task_store: TaskStore::new(),
78            webhook_store: WebhookStore::new(),
79            auth_extractor: None,
80            additional_routes: None,
81            event_tx,
82        }
83    }
84
85    pub fn echo() -> Self {
86        Self::new(EchoHandler::default())
87    }
88
89    pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
90        let _: SocketAddr = address.parse()?;
91        self.config.bind_address = address.to_string();
92        Ok(self)
93    }
94
95    pub fn bind_unchecked(mut self, address: &str) -> Self {
96        self.config.bind_address = address.to_string();
97        self
98    }
99
100    pub fn task_store(mut self, store: TaskStore) -> Self {
101        self.task_store = store;
102        self
103    }
104
105    pub fn auth_extractor<F>(mut self, extractor: F) -> Self
106    where
107        F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
108    {
109        self.auth_extractor = Some(Arc::new(extractor));
110        self
111    }
112
113    pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
114        self.additional_routes = Some(routes);
115        self
116    }
117
118    pub fn get_task_store(&self) -> TaskStore {
119        self.task_store.clone()
120    }
121
122    pub fn build_router(self) -> Router {
123        let bind: SocketAddr = self
124            .config
125            .bind_address
126            .parse()
127            .expect("Invalid bind address");
128        let base_url = format!("http://{}", bind);
129        let card = Arc::new(self.handler.agent_card(&base_url));
130
131        let state = AppState {
132            handler: self.handler,
133            task_store: self.task_store,
134            webhook_store: self.webhook_store,
135            card,
136            auth_extractor: self.auth_extractor,
137            event_tx: self.event_tx,
138        };
139
140        let timed_routes = Router::new()
141            .route("/health", get(health))
142            .route("/.well-known/agent-card.json", get(agent_card))
143            .route("/v1/rpc", post(handle_rpc))
144            .layer(
145                ServiceBuilder::new()
146                    .layer(HandleErrorLayer::new(handle_timeout_error))
147                    .timeout(Duration::from_secs(30)),
148            );
149
150        let mut router = timed_routes;
151
152        if let Some(additional) = self.additional_routes {
153            router = router.merge(additional);
154        }
155
156        router.with_state(state)
157    }
158
159    pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
160        self.event_tx.clone()
161    }
162
163    pub async fn run(self) -> anyhow::Result<()> {
164        let bind: SocketAddr = self.config.bind_address.parse()?;
165
166        let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
167        webhook_delivery.start(self.event_tx.subscribe());
168
169        let router = self.build_router();
170
171        info!(%bind, "Starting A2A server");
172        let listener = tokio::net::TcpListener::bind(bind).await?;
173        axum::serve(listener, router)
174            .with_graceful_shutdown(shutdown_signal())
175            .await?;
176
177        info!("Server shutdown complete");
178        Ok(())
179    }
180}
181
182async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
183    if err.is::<tower::timeout::error::Elapsed>() {
184        (
185            StatusCode::REQUEST_TIMEOUT,
186            Json(error(
187                serde_json::Value::Null,
188                errors::INTERNAL_ERROR,
189                "Request timed out",
190                None,
191            )),
192        )
193    } else {
194        (
195            StatusCode::INTERNAL_SERVER_ERROR,
196            Json(error(
197                serde_json::Value::Null,
198                errors::INTERNAL_ERROR,
199                &format!("Internal error: {}", err),
200                None,
201            )),
202        )
203    }
204}
205
206async fn shutdown_signal() {
207    let ctrl_c = async {
208        signal::ctrl_c()
209            .await
210            .expect("failed to install Ctrl+C handler");
211    };
212
213    #[cfg(unix)]
214    let terminate = async {
215        signal::unix::signal(signal::unix::SignalKind::terminate())
216            .expect("failed to install SIGTERM handler")
217            .recv()
218            .await;
219    };
220
221    #[cfg(not(unix))]
222    let terminate = std::future::pending::<()>();
223
224    tokio::select! {
225        _ = ctrl_c => {},
226        _ = terminate => {},
227    }
228
229    info!("Shutdown signal received, initiating graceful shutdown...");
230}
231
232#[derive(Clone)]
233pub struct AppState {
234    handler: BoxedHandler,
235    task_store: TaskStore,
236    webhook_store: WebhookStore,
237    card: Arc<AgentCard>,
238    auth_extractor: Option<AuthExtractor>,
239    event_tx: broadcast::Sender<StreamResponse>,
240}
241
242impl AppState {
243    pub fn task_store(&self) -> &TaskStore {
244        &self.task_store
245    }
246
247    pub fn agent_card(&self) -> &AgentCard {
248        &self.card
249    }
250
251    pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
252        &self.event_tx
253    }
254
255    pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
256        self.event_tx.subscribe()
257    }
258
259    pub fn broadcast_event(&self, event: StreamResponse) {
260        let _ = self.event_tx.send(event);
261    }
262
263    /// Check if streaming is enabled in capabilities
264    fn streaming_enabled(&self) -> bool {
265        self.card.capabilities.streaming.unwrap_or(false)
266    }
267
268    /// Check if push notifications are enabled in capabilities
269    fn push_notifications_enabled(&self) -> bool {
270        self.card.capabilities.push_notifications.unwrap_or(false)
271    }
272
273    /// Get the endpoint URL from the agent card
274    fn endpoint_url(&self) -> &str {
275        self.card.endpoint().unwrap_or("")
276    }
277}
278
279// ============ History Trimming ============
280
281fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
282    match history_length {
283        Some(0) => {
284            task.history = None;
285        }
286        Some(n) => {
287            if let Some(ref mut history) = task.history {
288                let len = history.len();
289                if len > n as usize {
290                    *history = history.split_off(len - n as usize);
291                }
292            }
293        }
294        None => {}
295    }
296}
297
298// ============ Version Validation ============
299
300fn validate_a2a_version(
301    headers: &HeaderMap,
302    req_id: &serde_json::Value,
303) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
304    if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
305        let version_str = version_header.to_str().unwrap_or("");
306
307        let parts: Vec<&str> = version_str.split('.').collect();
308        if parts.len() >= 2 {
309            if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
310                if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
311                    return Ok(());
312                }
313
314                return Err((
315                    StatusCode::BAD_REQUEST,
316                    Json(error(
317                        req_id.clone(),
318                        errors::VERSION_NOT_SUPPORTED,
319                        &format!(
320                            "Protocol version {}.{} not supported. Supported version: {}.{}",
321                            major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
322                        ),
323                        Some(serde_json::json!({
324                            "requestedVersion": version_str,
325                            "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
326                        })),
327                    )),
328                ));
329            }
330        }
331
332        // Also accept bare "1.0" without minor
333        if version_str == "1" || version_str == "1.0" {
334            return Ok(());
335        }
336
337        return Err((
338            StatusCode::BAD_REQUEST,
339            Json(error(
340                req_id.clone(),
341                errors::VERSION_NOT_SUPPORTED,
342                &format!(
343                    "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
344                    version_str
345                ),
346                None,
347            )),
348        ));
349    }
350
351    Ok(())
352}
353
354// ============ Error Response Helpers ============
355
356#[allow(dead_code)]
357pub fn rpc_error(
358    id: serde_json::Value,
359    code: i32,
360    message: &str,
361    status: StatusCode,
362) -> (StatusCode, Json<JsonRpcResponse>) {
363    (status, Json(error(id, code, message, None)))
364}
365
366#[allow(dead_code)]
367pub fn rpc_error_with_data(
368    id: serde_json::Value,
369    code: i32,
370    message: &str,
371    data: serde_json::Value,
372    status: StatusCode,
373) -> (StatusCode, Json<JsonRpcResponse>) {
374    (status, Json(error(id, code, message, Some(data))))
375}
376
377#[allow(dead_code)]
378pub fn rpc_success(
379    id: serde_json::Value,
380    result: serde_json::Value,
381) -> (StatusCode, Json<JsonRpcResponse>) {
382    (StatusCode::OK, Json(success(id, result)))
383}
384
385// ============ Route Handlers ============
386
387async fn health() -> Json<serde_json::Value> {
388    Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
389}
390
391async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
392    Json((*state.card).clone())
393}
394
// `tasks/resubscribe` is served as SSE directly from the `/v1/rpc` endpoint by
// `handle_tasks_resubscribe`; the former `handle_task_subscribe_sse` route was removed.
397
398async fn handle_rpc(
399    State(state): State<AppState>,
400    headers: HeaderMap,
401    Json(req): Json<JsonRpcRequest>,
402) -> Response {
403    if req.jsonrpc != "2.0" {
404        let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
405        return (StatusCode::BAD_REQUEST, Json(resp)).into_response();
406    }
407
408    if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
409        return err_response.into_response();
410    }
411
412    let auth_context = state
413        .auth_extractor
414        .as_ref()
415        .and_then(|extractor| extractor(&headers));
416
417    match req.method.as_str() {
418        // Spec method names per JSON-RPC binding
419        "message/send" => handle_message_send(state, req, auth_context)
420            .await
421            .into_response(),
422        "message/stream" => {
423            handle_message_stream(state, req, headers, auth_context)
424                .await
425                .into_response()
426        }
427        "tasks/get" => handle_tasks_get(state, req).await.into_response(),
428        "tasks/list" => handle_tasks_list(state, req).await.into_response(),
429        "tasks/cancel" => handle_tasks_cancel(state, req).await.into_response(),
430        "tasks/resubscribe" => handle_tasks_resubscribe(state, req).await.into_response(),
431        "tasks/pushNotificationConfig/create" => {
432            handle_push_config_create(state, req).await.into_response()
433        }
434        "tasks/pushNotificationConfig/get" => {
435            handle_push_config_get(state, req).await.into_response()
436        }
437        "tasks/pushNotificationConfig/list" => {
438            handle_push_config_list(state, req).await.into_response()
439        }
440        "tasks/pushNotificationConfig/delete" => {
441            handle_push_config_delete(state, req).await.into_response()
442        }
443        "agentCard/getExtended" => handle_get_extended_agent_card(state, req, auth_context)
444            .await
445            .into_response(),
446        _ => (
447            StatusCode::NOT_FOUND,
448            Json(error(
449                req.id,
450                errors::METHOD_NOT_FOUND,
451                "method not found",
452                None,
453            )),
454        )
455            .into_response(),
456    }
457}
458
459fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
460    match e {
461        HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
462        HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
463        HandlerError::BackendUnavailable { .. } => {
464            (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
465        }
466        HandlerError::ProcessingFailed { .. } => {
467            (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
468        }
469        HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
470    }
471}
472
473async fn handle_message_send(
474    state: AppState,
475    req: JsonRpcRequest,
476    auth_context: Option<AuthContext>,
477) -> (StatusCode, Json<JsonRpcResponse>) {
478    let req_id = req.id.clone();
479
480    let params: Result<SendMessageRequest, _> =
481        serde_json::from_value(req.params.clone().unwrap_or_default());
482
483    let params = match params {
484        Ok(p) => p,
485        Err(err) => {
486            return (
487                StatusCode::BAD_REQUEST,
488                Json(error(
489                    req_id,
490                    errors::INVALID_PARAMS,
491                    "invalid params",
492                    Some(serde_json::json!({"error": err.to_string()})),
493                )),
494            );
495        }
496    };
497
498    let blocking = params
499        .configuration
500        .as_ref()
501        .and_then(|c| c.blocking)
502        .unwrap_or(false);
503    let return_immediately = params
504        .configuration
505        .as_ref()
506        .and_then(|c| c.return_immediately)
507        .unwrap_or(false);
508    let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
509
510    match state
511        .handler
512        .handle_message(params.message, auth_context)
513        .await
514    {
515        Ok(response) => {
516            match response {
517                SendMessageResponse::Task(mut task) => {
518                    // Store the task and broadcast event
519                    state.task_store.insert(task.clone()).await;
520                    state.broadcast_event(StreamResponse::Task(task.clone()));
521
522                    // If blocking mode and not returnImmediately, wait for terminal state
523                    if blocking && !return_immediately && !task.status.state.is_terminal() {
524                        let task_id = task.id.clone();
525                        let mut rx = state.subscribe_events();
526
527                        let wait_result = timeout(BLOCKING_TIMEOUT, async {
528                            loop {
529                                tokio::select! {
530                                    result = rx.recv() => {
531                                        match result {
532                                            Ok(StreamResponse::Task(t)) if t.id == task_id => {
533                                                if t.status.state.is_terminal() {
534                                                    return Some(t);
535                                                }
536                                            }
537                                            Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
538                                                if e.status.state.is_terminal() {
539                                                    if let Some(t) = state.task_store.get(&task_id).await {
540                                                        return Some(t);
541                                                    }
542                                                }
543                                            }
544                                            Err(broadcast::error::RecvError::Closed) => {
545                                                return None;
546                                            }
547                                            _ => {}
548                                        }
549                                    }
550                                    _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
551                                        if let Some(t) = state.task_store.get(&task_id).await {
552                                            if t.status.state.is_terminal() {
553                                                return Some(t);
554                                            }
555                                        }
556                                    }
557                                }
558                            }
559                        })
560                        .await;
561
562                        match wait_result {
563                            Ok(Some(final_task)) => task = final_task,
564                            Ok(None) => {
565                                if let Some(t) = state.task_store.get(&task.id).await {
566                                    task = t;
567                                }
568                            }
569                            Err(_) => {
570                                tracing::warn!("Blocking request timed out for task {}", task.id);
571                                if let Some(t) = state.task_store.get(&task.id).await {
572                                    task = t;
573                                }
574                            }
575                        }
576                    }
577
578                    apply_history_length(&mut task, history_length);
579
580                    // Serialize the Task directly into the JSON-RPC result field
581                    // (no wrapper key — matches the A2A reference SDK wire format)
582                    match serde_json::to_value(&task) {
583                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
584                        Err(e) => (
585                            StatusCode::INTERNAL_SERVER_ERROR,
586                            Json(error(
587                                req_id,
588                                errors::INTERNAL_ERROR,
589                                "serialization failed",
590                                Some(serde_json::json!({"error": e.to_string()})),
591                            )),
592                        ),
593                    }
594                }
595                SendMessageResponse::Message(msg) => {
596                    // Serialize the Message directly into the JSON-RPC result field
597                    match serde_json::to_value(&msg) {
598                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
599                        Err(e) => (
600                            StatusCode::INTERNAL_SERVER_ERROR,
601                            Json(error(
602                                req_id,
603                                errors::INTERNAL_ERROR,
604                                "serialization failed",
605                                Some(serde_json::json!({"error": e.to_string()})),
606                            )),
607                        ),
608                    }
609                }
610            }
611        }
612        Err(e) => {
613            let (code, status) = handler_error_to_rpc(&e);
614            (status, Json(error(req_id, code, &e.to_string(), None)))
615        }
616    }
617}
618
619/// Handle `message/stream` — returns SSE directly from the JSON-RPC endpoint.
620///
621/// Each SSE event's `data:` is a full JSON-RPC response envelope wrapping the result
622/// (Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent).
623async fn handle_message_stream(
624    state: AppState,
625    req: JsonRpcRequest,
626    _headers: HeaderMap,
627    auth_context: Option<AuthContext>,
628) -> Response {
629    let req_id = req.id.clone();
630
631    if !state.streaming_enabled() {
632        return (
633            StatusCode::BAD_REQUEST,
634            Json(error(
635                req_id,
636                errors::UNSUPPORTED_OPERATION,
637                "streaming not supported by this agent",
638                None,
639            )),
640        )
641            .into_response();
642    }
643
644    let params: Result<SendMessageRequest, _> =
645        serde_json::from_value(req.params.clone().unwrap_or_default());
646
647    let params = match params {
648        Ok(p) => p,
649        Err(err) => {
650            return (
651                StatusCode::BAD_REQUEST,
652                Json(error(
653                    req_id,
654                    errors::INVALID_PARAMS,
655                    "invalid params",
656                    Some(serde_json::json!({"error": err.to_string()})),
657                )),
658            )
659                .into_response();
660        }
661    };
662
663    let response = match state
664        .handler
665        .handle_message(params.message, auth_context)
666        .await
667    {
668        Ok(r) => r,
669        Err(e) => {
670            let (code, status) = handler_error_to_rpc(&e);
671            return (status, Json(error(req_id, code, &e.to_string(), None))).into_response();
672        }
673    };
674
675    // Extract task from response (streaming only works with tasks)
676    let task = match response {
677        SendMessageResponse::Task(t) => t,
678        SendMessageResponse::Message(_) => {
679            return (
680                StatusCode::BAD_REQUEST,
681                Json(error(
682                    req_id,
683                    errors::UNSUPPORTED_OPERATION,
684                    "handler returned a message, streaming requires a task",
685                    None,
686                )),
687            )
688                .into_response();
689        }
690    };
691
692    let task_id = task.id.clone();
693    state.task_store.insert(task.clone()).await;
694    state.broadcast_event(StreamResponse::Task(task.clone()));
695
696    let mut rx = state.subscribe_events();
697    let task_store = state.task_store.clone();
698    let target_task_id = task_id;
699
700    // Helper: wrap a value in a JSON-RPC success response envelope
701    let wrap = move |value: serde_json::Value| -> String {
702        serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
703    };
704
705    let stream = async_stream::stream! {
706        // Yield initial task
707        if let Ok(val) = serde_json::to_value(&task) {
708            yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
709        }
710
711        loop {
712            match rx.recv().await {
713                Ok(event) => {
714                    let matches = match &event {
715                        StreamResponse::Task(t) => t.id == target_task_id,
716                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
717                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
718                        StreamResponse::Message(m) => {
719                            m.context_id.as_ref().is_some_and(|ctx| {
720                                task_store.get(&target_task_id).now_or_never()
721                                    .flatten()
722                                    .is_some_and(|t| t.context_id == *ctx)
723                            })
724                        }
725                    };
726
727                    if matches {
728                        // Serialize the inner value directly (not the StreamResponse wrapper)
729                        let val = match &event {
730                            StreamResponse::Task(t) => serde_json::to_value(t),
731                            StreamResponse::Message(m) => serde_json::to_value(m),
732                            StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
733                            StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
734                        };
735                        if let Ok(val) = val {
736                            yield Ok(Event::default().data(wrap(val)));
737                        }
738
739                        // End stream on terminal state or final flag
740                        let is_terminal = match &event {
741                            StreamResponse::Task(t) => t.status.state.is_terminal(),
742                            StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
743                            _ => false,
744                        };
745                        if is_terminal {
746                            break;
747                        }
748                    }
749                }
750                Err(broadcast::error::RecvError::Lagged(_)) => continue,
751                Err(broadcast::error::RecvError::Closed) => break,
752            }
753        }
754    };
755
756    Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
757}
758
759async fn handle_tasks_get(
760    state: AppState,
761    req: JsonRpcRequest,
762) -> (StatusCode, Json<JsonRpcResponse>) {
763    let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
764
765    match params {
766        Ok(p) => match state.task_store.get_flexible(&p.id).await {
767            Some(mut task) => {
768                apply_history_length(&mut task, p.history_length);
769
770                match serde_json::to_value(task) {
771                    Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
772                    Err(e) => (
773                        StatusCode::INTERNAL_SERVER_ERROR,
774                        Json(error(
775                            req.id,
776                            errors::INTERNAL_ERROR,
777                            "serialization failed",
778                            Some(serde_json::json!({"error": e.to_string()})),
779                        )),
780                    ),
781                }
782            }
783            None => (
784                StatusCode::NOT_FOUND,
785                Json(error(
786                    req.id,
787                    errors::TASK_NOT_FOUND,
788                    "task not found",
789                    None,
790                )),
791            ),
792        },
793        Err(err) => (
794            StatusCode::BAD_REQUEST,
795            Json(error(
796                req.id,
797                errors::INVALID_PARAMS,
798                "invalid params",
799                Some(serde_json::json!({"error": err.to_string()})),
800            )),
801        ),
802    }
803}
804
805async fn handle_tasks_cancel(
806    state: AppState,
807    req: JsonRpcRequest,
808) -> (StatusCode, Json<JsonRpcResponse>) {
809    let params: Result<CancelTaskRequest, _> =
810        serde_json::from_value(req.params.unwrap_or_default());
811
812    match params {
813        Ok(p) => {
814            let result = state
815                .task_store
816                .update_flexible(&p.id, |task| {
817                    if task.status.state.is_terminal() {
818                        return Err(errors::TASK_NOT_CANCELABLE);
819                    }
820                    task.status.state = TaskState::Canceled;
821                    task.status.timestamp = Some(now_iso8601());
822                    Ok(())
823                })
824                .await;
825
826            match result {
827                Some(Ok(task)) => {
828                    if let Err(e) = state.handler.cancel_task(&task.id).await {
829                        tracing::warn!("Handler cancel_task failed: {}", e);
830                    }
831
832                    state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
833                        kind: "status-update".to_string(),
834                        task_id: task.id.clone(),
835                        context_id: task.context_id.clone(),
836                        status: task.status.clone(),
837                        is_final: true,
838                        metadata: None,
839                    }));
840
841                    match serde_json::to_value(task) {
842                        Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
843                        Err(e) => (
844                            StatusCode::INTERNAL_SERVER_ERROR,
845                            Json(error(
846                                req.id,
847                                errors::INTERNAL_ERROR,
848                                "serialization failed",
849                                Some(serde_json::json!({"error": e.to_string()})),
850                            )),
851                        ),
852                    }
853                }
854                Some(Err(error_code)) => (
855                    StatusCode::BAD_REQUEST,
856                    Json(error(req.id, error_code, "task not cancelable", None)),
857                ),
858                None => (
859                    StatusCode::NOT_FOUND,
860                    Json(error(
861                        req.id,
862                        errors::TASK_NOT_FOUND,
863                        "task not found",
864                        None,
865                    )),
866                ),
867            }
868        }
869        Err(err) => (
870            StatusCode::BAD_REQUEST,
871            Json(error(
872                req.id,
873                errors::INVALID_PARAMS,
874                "invalid params",
875                Some(serde_json::json!({"error": err.to_string()})),
876            )),
877        ),
878    }
879}
880
881async fn handle_tasks_list(
882    state: AppState,
883    req: JsonRpcRequest,
884) -> (StatusCode, Json<JsonRpcResponse>) {
885    let params: Result<ListTasksRequest, _> =
886        serde_json::from_value(req.params.unwrap_or_default());
887
888    match params {
889        Ok(p) => {
890            let response = state.task_store.list_filtered(&p).await;
891            match serde_json::to_value(response) {
892                Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
893                Err(e) => (
894                    StatusCode::INTERNAL_SERVER_ERROR,
895                    Json(error(
896                        req.id,
897                        errors::INTERNAL_ERROR,
898                        "serialization failed",
899                        Some(serde_json::json!({"error": e.to_string()})),
900                    )),
901                ),
902            }
903        }
904        Err(err) => (
905            StatusCode::BAD_REQUEST,
906            Json(error(
907                req.id,
908                errors::INVALID_PARAMS,
909                "invalid params",
910                Some(serde_json::json!({"error": err.to_string()})),
911            )),
912        ),
913    }
914}
915
916/// Handle `tasks/resubscribe` — returns SSE directly from the JSON-RPC endpoint.
917///
918/// Reconnects to an existing task's event stream. Each SSE event's `data:` is
919/// a JSON-RPC response envelope wrapping the result, same as `message/stream`.
920async fn handle_tasks_resubscribe(state: AppState, req: JsonRpcRequest) -> Response {
921    let req_id = req.id.clone();
922
923    let params: Result<SubscribeToTaskRequest, _> =
924        serde_json::from_value(req.params.unwrap_or_default());
925
926    let params = match params {
927        Ok(p) => p,
928        Err(err) => {
929            return (
930                StatusCode::BAD_REQUEST,
931                Json(error(
932                    req_id,
933                    errors::INVALID_PARAMS,
934                    "invalid params",
935                    Some(serde_json::json!({"error": err.to_string()})),
936                )),
937            )
938                .into_response();
939        }
940    };
941
942    // Verify the task exists
943    let task = match state.task_store.get_flexible(&params.id).await {
944        Some(t) => t,
945        None => {
946            return (
947                StatusCode::NOT_FOUND,
948                Json(error(req_id, errors::TASK_NOT_FOUND, "task not found", None)),
949            )
950                .into_response();
951        }
952    };
953
954    let target_task_id = task.id.clone();
955    let mut rx = state.subscribe_events();
956    let task_store = state.task_store.clone();
957
958    let wrap = move |value: serde_json::Value| -> String {
959        serde_json::to_string(&success(req_id.clone(), value)).unwrap_or_default()
960    };
961
962    let stream = async_stream::stream! {
963        // Yield current task snapshot
964        if let Ok(val) = serde_json::to_value(&task) {
965            yield Ok::<_, Infallible>(Event::default().data(wrap(val)));
966        }
967
968        // If already terminal, stop
969        if task.status.state.is_terminal() {
970            return;
971        }
972
973        loop {
974            match rx.recv().await {
975                Ok(event) => {
976                    let matches = match &event {
977                        StreamResponse::Task(t) => t.id == target_task_id,
978                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
979                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
980                        StreamResponse::Message(m) => {
981                            m.context_id.as_ref().is_some_and(|ctx| {
982                                task_store.get(&target_task_id).now_or_never()
983                                    .flatten()
984                                    .is_some_and(|t| t.context_id == *ctx)
985                            })
986                        }
987                    };
988
989                    if matches {
990                        let val = match &event {
991                            StreamResponse::Task(t) => serde_json::to_value(t),
992                            StreamResponse::Message(m) => serde_json::to_value(m),
993                            StreamResponse::StatusUpdate(e) => serde_json::to_value(e),
994                            StreamResponse::ArtifactUpdate(e) => serde_json::to_value(e),
995                        };
996                        if let Ok(val) = val {
997                            yield Ok(Event::default().data(wrap(val)));
998                        }
999
1000                        let is_terminal = match &event {
1001                            StreamResponse::Task(t) => t.status.state.is_terminal(),
1002                            StreamResponse::StatusUpdate(e) => e.is_final || e.status.state.is_terminal(),
1003                            _ => false,
1004                        };
1005                        if is_terminal {
1006                            break;
1007                        }
1008                    }
1009                }
1010                Err(broadcast::error::RecvError::Lagged(_)) => continue,
1011                Err(broadcast::error::RecvError::Closed) => break,
1012            }
1013        }
1014    };
1015
1016    Sse::new(stream).keep_alive(KeepAlive::default()).into_response()
1017}
1018
1019async fn handle_get_extended_agent_card(
1020    state: AppState,
1021    req: JsonRpcRequest,
1022    auth_context: Option<AuthContext>,
1023) -> (StatusCode, Json<JsonRpcResponse>) {
1024    let Some(auth) = auth_context else {
1025        return (
1026            StatusCode::UNAUTHORIZED,
1027            Json(error(
1028                req.id,
1029                errors::INVALID_REQUEST,
1030                "authentication required for extended agent card",
1031                None,
1032            )),
1033        );
1034    };
1035
1036    let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1037
1038    match state.handler.extended_agent_card(base_url, &auth).await {
1039        Some(card) => match serde_json::to_value(card) {
1040            Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1041            Err(e) => (
1042                StatusCode::INTERNAL_SERVER_ERROR,
1043                Json(error(
1044                    req.id,
1045                    errors::INTERNAL_ERROR,
1046                    "serialization failed",
1047                    Some(serde_json::json!({"error": e.to_string()})),
1048                )),
1049            ),
1050        },
1051        None => (
1052            StatusCode::NOT_FOUND,
1053            Json(error(
1054                req.id,
1055                errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1056                "extended agent card not configured",
1057                None,
1058            )),
1059        ),
1060    }
1061}
1062
1063// ============ Push Notification Config Handlers ============
1064
1065async fn handle_push_config_create(
1066    state: AppState,
1067    req: JsonRpcRequest,
1068) -> (StatusCode, Json<JsonRpcResponse>) {
1069    if !state.push_notifications_enabled() {
1070        return (
1071            StatusCode::BAD_REQUEST,
1072            Json(error(
1073                req.id,
1074                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1075                "push notifications not supported",
1076                None,
1077            )),
1078        );
1079    }
1080
1081    let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1082        serde_json::from_value(req.params.unwrap_or_default());
1083
1084    match params {
1085        Ok(p) => {
1086            if state.task_store.get_flexible(&p.task_id).await.is_none() {
1087                return (
1088                    StatusCode::NOT_FOUND,
1089                    Json(error(
1090                        req.id,
1091                        errors::TASK_NOT_FOUND,
1092                        "task not found",
1093                        None,
1094                    )),
1095                );
1096            }
1097
1098            if let Err(e) = state
1099                .webhook_store
1100                .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1101                .await
1102            {
1103                return (
1104                    StatusCode::BAD_REQUEST,
1105                    Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1106                );
1107            }
1108
1109            (
1110                StatusCode::OK,
1111                Json(success(
1112                    req.id,
1113                    serde_json::json!({
1114                        "configId": p.config_id,
1115                        "config": p.push_notification_config
1116                    }),
1117                )),
1118            )
1119        }
1120        Err(err) => (
1121            StatusCode::BAD_REQUEST,
1122            Json(error(
1123                req.id,
1124                errors::INVALID_PARAMS,
1125                "invalid params",
1126                Some(serde_json::json!({"error": err.to_string()})),
1127            )),
1128        ),
1129    }
1130}
1131
1132async fn handle_push_config_get(
1133    state: AppState,
1134    req: JsonRpcRequest,
1135) -> (StatusCode, Json<JsonRpcResponse>) {
1136    if !state.push_notifications_enabled() {
1137        return (
1138            StatusCode::BAD_REQUEST,
1139            Json(error(
1140                req.id,
1141                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1142                "push notifications not supported",
1143                None,
1144            )),
1145        );
1146    }
1147
1148    let params: Result<GetTaskPushNotificationConfigRequest, _> =
1149        serde_json::from_value(req.params.unwrap_or_default());
1150
1151    match params {
1152        Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1153            Some(config) => (
1154                StatusCode::OK,
1155                Json(success(
1156                    req.id,
1157                    serde_json::json!({
1158                        "configId": p.id,
1159                        "config": config
1160                    }),
1161                )),
1162            ),
1163            None => (
1164                StatusCode::NOT_FOUND,
1165                Json(error(
1166                    req.id,
1167                    errors::TASK_NOT_FOUND,
1168                    "push notification config not found",
1169                    None,
1170                )),
1171            ),
1172        },
1173        Err(err) => (
1174            StatusCode::BAD_REQUEST,
1175            Json(error(
1176                req.id,
1177                errors::INVALID_PARAMS,
1178                "invalid params",
1179                Some(serde_json::json!({"error": err.to_string()})),
1180            )),
1181        ),
1182    }
1183}
1184
1185async fn handle_push_config_list(
1186    state: AppState,
1187    req: JsonRpcRequest,
1188) -> (StatusCode, Json<JsonRpcResponse>) {
1189    if !state.push_notifications_enabled() {
1190        return (
1191            StatusCode::BAD_REQUEST,
1192            Json(error(
1193                req.id,
1194                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1195                "push notifications not supported",
1196                None,
1197            )),
1198        );
1199    }
1200
1201    let params: Result<ListTaskPushNotificationConfigRequest, _> =
1202        serde_json::from_value(req.params.unwrap_or_default());
1203
1204    match params {
1205        Ok(p) => {
1206            let configs = state.webhook_store.list(&p.task_id).await;
1207
1208            let configs_json: Vec<_> = configs
1209                .iter()
1210                .map(|c| {
1211                    serde_json::json!({
1212                        "configId": c.config_id,
1213                        "config": c.config
1214                    })
1215                })
1216                .collect();
1217
1218            (
1219                StatusCode::OK,
1220                Json(success(
1221                    req.id,
1222                    serde_json::json!({
1223                        "configs": configs_json,
1224                        "nextPageToken": ""
1225                    }),
1226                )),
1227            )
1228        }
1229        Err(err) => (
1230            StatusCode::BAD_REQUEST,
1231            Json(error(
1232                req.id,
1233                errors::INVALID_PARAMS,
1234                "invalid params",
1235                Some(serde_json::json!({"error": err.to_string()})),
1236            )),
1237        ),
1238    }
1239}
1240
1241async fn handle_push_config_delete(
1242    state: AppState,
1243    req: JsonRpcRequest,
1244) -> (StatusCode, Json<JsonRpcResponse>) {
1245    if !state.push_notifications_enabled() {
1246        return (
1247            StatusCode::BAD_REQUEST,
1248            Json(error(
1249                req.id,
1250                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1251                "push notifications not supported",
1252                None,
1253            )),
1254        );
1255    }
1256
1257    let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1258        serde_json::from_value(req.params.unwrap_or_default());
1259
1260    match params {
1261        Ok(p) => {
1262            if state.webhook_store.delete(&p.task_id, &p.id).await {
1263                (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1264            } else {
1265                (
1266                    StatusCode::NOT_FOUND,
1267                    Json(error(
1268                        req.id,
1269                        errors::TASK_NOT_FOUND,
1270                        "push notification config not found",
1271                        None,
1272                    )),
1273                )
1274            }
1275        }
1276        Err(err) => (
1277            StatusCode::BAD_REQUEST,
1278            Json(error(
1279                req.id,
1280                errors::INVALID_PARAMS,
1281                "invalid params",
1282                Some(serde_json::json!({"error": err.to_string()})),
1283            )),
1284        ),
1285    }
1286}
1287
1288pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1289    bind_addr: &str,
1290    handler: H,
1291) -> anyhow::Result<()> {
1292    A2aServer::new(handler).bind(bind_addr)?.run().await
1293}
1294
1295pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1296    A2aServer::echo().bind(bind_addr)?.run().await
1297}