Skip to main content

a2a_rs_server/
server.rs

1//! Generic A2A JSON-RPC Server
2
3use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long a blocking `message/send` waits for its task to finish (5 minutes).
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// How often a blocking `message/send` re-reads the task store while waiting.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17    error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18    CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19    GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20    ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21    SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22    TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header a client may send to pin the A2A protocol version.
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the single protocol version accepted by this server.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the single protocol version accepted by this server.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::routing::{get, post};
33use axum::{Json, Router};
34use futures::future::FutureExt;
35use futures::stream::Stream;
36use tokio::sync::broadcast;
37use tracing::info;
38
39use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
40use crate::task_store::TaskStore;
41use crate::webhook_delivery::WebhookDelivery;
42use crate::webhook_store::WebhookStore;
43
/// Listener settings for the A2A server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address to listen on. `run` requires it to parse as a `SocketAddr` (`ip:port`).
    pub bind_address: String,
}

impl Default for ServerConfig {
    /// Listens on every IPv4 interface, port 8080.
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        Self { bind_address }
    }
}
56
57pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;
58
59const EVENT_CHANNEL_CAPACITY: usize = 1024;
60
61pub struct A2aServer {
62    config: ServerConfig,
63    handler: BoxedHandler,
64    task_store: TaskStore,
65    webhook_store: WebhookStore,
66    auth_extractor: Option<AuthExtractor>,
67    additional_routes: Option<Router<AppState>>,
68    event_tx: broadcast::Sender<StreamResponse>,
69}
70
71impl A2aServer {
72    pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
73        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
74        Self {
75            config: ServerConfig::default(),
76            handler: Arc::new(handler),
77            task_store: TaskStore::new(),
78            webhook_store: WebhookStore::new(),
79            auth_extractor: None,
80            additional_routes: None,
81            event_tx,
82        }
83    }
84
85    pub fn echo() -> Self {
86        Self::new(EchoHandler::default())
87    }
88
89    pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
90        let _: SocketAddr = address.parse()?;
91        self.config.bind_address = address.to_string();
92        Ok(self)
93    }
94
95    pub fn bind_unchecked(mut self, address: &str) -> Self {
96        self.config.bind_address = address.to_string();
97        self
98    }
99
100    pub fn task_store(mut self, store: TaskStore) -> Self {
101        self.task_store = store;
102        self
103    }
104
105    pub fn auth_extractor<F>(mut self, extractor: F) -> Self
106    where
107        F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
108    {
109        self.auth_extractor = Some(Arc::new(extractor));
110        self
111    }
112
113    pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
114        self.additional_routes = Some(routes);
115        self
116    }
117
118    pub fn get_task_store(&self) -> TaskStore {
119        self.task_store.clone()
120    }
121
122    pub fn build_router(self) -> Router {
123        let bind: SocketAddr = self.config.bind_address.parse().expect("Invalid bind address");
124        let base_url = format!("http://{}", bind);
125        let card = Arc::new(self.handler.agent_card(&base_url));
126
127        let state = AppState {
128            handler: self.handler,
129            task_store: self.task_store,
130            webhook_store: self.webhook_store,
131            card,
132            auth_extractor: self.auth_extractor,
133            event_tx: self.event_tx,
134        };
135
136        let timed_routes = Router::new()
137            .route("/health", get(health))
138            .route("/.well-known/agent-card.json", get(agent_card))
139            .route("/v1/rpc", post(handle_rpc))
140            .layer(
141                ServiceBuilder::new()
142                    .layer(HandleErrorLayer::new(handle_timeout_error))
143                    .timeout(Duration::from_secs(30)),
144            );
145
146        let sse_routes = Router::new()
147            .route("/v1/tasks/:task_id/subscribe", get(handle_task_subscribe_sse))
148            .route("/v1/message/stream", post(handle_message_stream_sse));
149
150        let mut router = timed_routes.merge(sse_routes);
151
152        if let Some(additional) = self.additional_routes {
153            router = router.merge(additional);
154        }
155
156        router.with_state(state)
157    }
158
159    pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
160        self.event_tx.clone()
161    }
162
163    pub async fn run(self) -> anyhow::Result<()> {
164        let bind: SocketAddr = self.config.bind_address.parse()?;
165
166        let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
167        webhook_delivery.start(self.event_tx.subscribe());
168
169        let router = self.build_router();
170
171        info!(%bind, "Starting A2A server");
172        let listener = tokio::net::TcpListener::bind(bind).await?;
173        axum::serve(listener, router)
174            .with_graceful_shutdown(shutdown_signal())
175            .await?;
176
177        info!("Server shutdown complete");
178        Ok(())
179    }
180}
181
182async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
183    if err.is::<tower::timeout::error::Elapsed>() {
184        (
185            StatusCode::REQUEST_TIMEOUT,
186            Json(error(
187                serde_json::Value::Null,
188                errors::INTERNAL_ERROR,
189                "Request timed out",
190                None,
191            )),
192        )
193    } else {
194        (
195            StatusCode::INTERNAL_SERVER_ERROR,
196            Json(error(
197                serde_json::Value::Null,
198                errors::INTERNAL_ERROR,
199                &format!("Internal error: {}", err),
200                None,
201            )),
202        )
203    }
204}
205
206async fn shutdown_signal() {
207    let ctrl_c = async {
208        signal::ctrl_c()
209            .await
210            .expect("failed to install Ctrl+C handler");
211    };
212
213    #[cfg(unix)]
214    let terminate = async {
215        signal::unix::signal(signal::unix::SignalKind::terminate())
216            .expect("failed to install SIGTERM handler")
217            .recv()
218            .await;
219    };
220
221    #[cfg(not(unix))]
222    let terminate = std::future::pending::<()>();
223
224    tokio::select! {
225        _ = ctrl_c => {},
226        _ = terminate => {},
227    }
228
229    info!("Shutdown signal received, initiating graceful shutdown...");
230}
231
232#[derive(Clone)]
233pub struct AppState {
234    handler: BoxedHandler,
235    task_store: TaskStore,
236    webhook_store: WebhookStore,
237    card: Arc<AgentCard>,
238    auth_extractor: Option<AuthExtractor>,
239    event_tx: broadcast::Sender<StreamResponse>,
240}
241
242impl AppState {
243    pub fn task_store(&self) -> &TaskStore {
244        &self.task_store
245    }
246
247    pub fn agent_card(&self) -> &AgentCard {
248        &self.card
249    }
250
251    pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
252        &self.event_tx
253    }
254
255    pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
256        self.event_tx.subscribe()
257    }
258
259    pub fn broadcast_event(&self, event: StreamResponse) {
260        let _ = self.event_tx.send(event);
261    }
262
263    /// Check if streaming is enabled in capabilities
264    fn streaming_enabled(&self) -> bool {
265        self.card.capabilities.streaming.unwrap_or(false)
266    }
267
268    /// Check if push notifications are enabled in capabilities
269    fn push_notifications_enabled(&self) -> bool {
270        self.card.capabilities.push_notifications.unwrap_or(false)
271    }
272
273    /// Get the endpoint URL from the agent card
274    fn endpoint_url(&self) -> &str {
275        self.card.endpoint().unwrap_or("")
276    }
277}
278
279// ============ History Trimming ============
280
281fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
282    match history_length {
283        Some(0) => {
284            task.history = None;
285        }
286        Some(n) => {
287            if let Some(ref mut history) = task.history {
288                let len = history.len();
289                if len > n as usize {
290                    *history = history.split_off(len - n as usize);
291                }
292            }
293        }
294        None => {}
295    }
296}
297
298// ============ Version Validation ============
299
300fn validate_a2a_version(headers: &HeaderMap, req_id: &serde_json::Value) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
301    if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
302        let version_str = version_header.to_str().unwrap_or("");
303
304        let parts: Vec<&str> = version_str.split('.').collect();
305        if parts.len() >= 2 {
306            if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
307                if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
308                    return Ok(());
309                }
310
311                return Err((
312                    StatusCode::BAD_REQUEST,
313                    Json(error(
314                        req_id.clone(),
315                        errors::VERSION_NOT_SUPPORTED,
316                        &format!(
317                            "Protocol version {}.{} not supported. Supported version: {}.{}",
318                            major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
319                        ),
320                        Some(serde_json::json!({
321                            "requestedVersion": version_str,
322                            "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
323                        })),
324                    )),
325                ));
326            }
327        }
328
329        // Also accept bare "1.0" without minor
330        if version_str == "1" || version_str == "1.0" {
331            return Ok(());
332        }
333
334        return Err((
335            StatusCode::BAD_REQUEST,
336            Json(error(
337                req_id.clone(),
338                errors::VERSION_NOT_SUPPORTED,
339                &format!("Invalid version format: {}. Expected major.minor (e.g., '1.0')", version_str),
340                None,
341            )),
342        ));
343    }
344
345    Ok(())
346}
347
348// ============ Error Response Helpers ============
349
350#[allow(dead_code)]
351pub fn rpc_error(
352    id: serde_json::Value,
353    code: i32,
354    message: &str,
355    status: StatusCode,
356) -> (StatusCode, Json<JsonRpcResponse>) {
357    (status, Json(error(id, code, message, None)))
358}
359
360#[allow(dead_code)]
361pub fn rpc_error_with_data(
362    id: serde_json::Value,
363    code: i32,
364    message: &str,
365    data: serde_json::Value,
366    status: StatusCode,
367) -> (StatusCode, Json<JsonRpcResponse>) {
368    (status, Json(error(id, code, message, Some(data))))
369}
370
371#[allow(dead_code)]
372pub fn rpc_success(id: serde_json::Value, result: serde_json::Value) -> (StatusCode, Json<JsonRpcResponse>) {
373    (StatusCode::OK, Json(success(id, result)))
374}
375
376// ============ Route Handlers ============
377
378async fn health() -> Json<serde_json::Value> {
379    Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
380}
381
382async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
383    Json((*state.card).clone())
384}
385
386async fn handle_task_subscribe_sse(
387    State(state): State<AppState>,
388    axum::extract::Path(task_id): axum::extract::Path<String>,
389) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
390    let mut rx = state.subscribe_events();
391    let task_store = state.task_store.clone();
392    let target_task_id = task_id;
393
394    let stream = async_stream::stream! {
395        if let Some(task) = task_store.get_flexible(&target_task_id).await {
396            let event = StreamResponse::Task(task);
397            if let Ok(json) = serde_json::to_string(&event) {
398                yield Ok(Event::default().data(json));
399            }
400        }
401
402        loop {
403            match rx.recv().await {
404                Ok(event) => {
405                    let matches = match &event {
406                        StreamResponse::Task(t) => t.id == target_task_id,
407                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
408                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
409                        StreamResponse::Message(_) => false,
410                    };
411                    if matches {
412                        if let Ok(json) = serde_json::to_string(&event) {
413                            yield Ok(Event::default().data(json));
414                        }
415
416                        if let StreamResponse::Task(t) = &event {
417                            if t.status.state.is_terminal() {
418                                break;
419                            }
420                        }
421                        if let StreamResponse::StatusUpdate(e) = &event {
422                            if e.status.state.is_terminal() {
423                                break;
424                            }
425                        }
426                    }
427                }
428                Err(broadcast::error::RecvError::Lagged(_)) => {
429                    continue;
430                }
431                Err(broadcast::error::RecvError::Closed) => {
432                    break;
433                }
434            }
435        }
436    };
437
438    Sse::new(stream).keep_alive(KeepAlive::default())
439}
440
441async fn handle_rpc(
442    State(state): State<AppState>,
443    headers: HeaderMap,
444    Json(req): Json<JsonRpcRequest>,
445) -> (StatusCode, Json<JsonRpcResponse>) {
446    if req.jsonrpc != "2.0" {
447        let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
448        return (StatusCode::BAD_REQUEST, Json(resp));
449    }
450
451    if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
452        return err_response;
453    }
454
455    let auth_context = state
456        .auth_extractor
457        .as_ref()
458        .and_then(|extractor| extractor(&headers));
459
460    match req.method.as_str() {
461        // Spec method names per JSON-RPC binding
462        "message/send" => handle_message_send(state, req, auth_context).await,
463        "message/sendStreaming" => handle_message_stream_rpc(state, req).await,
464        "tasks/get" => handle_tasks_get(state, req).await,
465        "tasks/list" => handle_tasks_list(state, req).await,
466        "tasks/cancel" => handle_tasks_cancel(state, req).await,
467        "tasks/subscribe" => handle_tasks_subscribe(state, req).await,
468        "tasks/pushNotificationConfig/create" => handle_push_config_create(state, req).await,
469        "tasks/pushNotificationConfig/get" => handle_push_config_get(state, req).await,
470        "tasks/pushNotificationConfig/list" => handle_push_config_list(state, req).await,
471        "tasks/pushNotificationConfig/delete" => handle_push_config_delete(state, req).await,
472        "agentCard/getExtended" => {
473            handle_get_extended_agent_card(state, req, auth_context).await
474        }
475        _ => (
476            StatusCode::NOT_FOUND,
477            Json(error(
478                req.id,
479                errors::METHOD_NOT_FOUND,
480                "method not found",
481                None,
482            )),
483        ),
484    }
485}
486
487fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
488    match e {
489        HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
490        HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
491        HandlerError::BackendUnavailable { .. } => (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE),
492        HandlerError::ProcessingFailed { .. } => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
493        HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
494    }
495}
496
497async fn handle_message_send(
498    state: AppState,
499    req: JsonRpcRequest,
500    auth_context: Option<AuthContext>,
501) -> (StatusCode, Json<JsonRpcResponse>) {
502    let req_id = req.id.clone();
503
504    let params: Result<SendMessageRequest, _> =
505        serde_json::from_value(req.params.clone().unwrap_or_default());
506
507    let params = match params {
508        Ok(p) => p,
509        Err(err) => {
510            return (
511                StatusCode::BAD_REQUEST,
512                Json(error(
513                    req_id,
514                    errors::INVALID_PARAMS,
515                    "invalid params",
516                    Some(serde_json::json!({"error": err.to_string()})),
517                )),
518            );
519        }
520    };
521
522    let blocking = params
523        .configuration
524        .as_ref()
525        .and_then(|c| c.blocking)
526        .unwrap_or(false);
527    let history_length = params
528        .configuration
529        .as_ref()
530        .and_then(|c| c.history_length);
531
532    match state.handler.handle_message(params.message, auth_context).await {
533        Ok(response) => {
534            match response {
535                SendMessageResponse::Task(mut task) => {
536                    // Store the task and broadcast event
537                    state.task_store.insert(task.clone()).await;
538                    state.broadcast_event(StreamResponse::Task(task.clone()));
539
540                    // If blocking mode, wait for task to reach terminal state
541                    if blocking && !task.status.state.is_terminal() {
542                        let task_id = task.id.clone();
543                        let mut rx = state.subscribe_events();
544
545                        let wait_result = timeout(BLOCKING_TIMEOUT, async {
546                            loop {
547                                tokio::select! {
548                                    result = rx.recv() => {
549                                        match result {
550                                            Ok(StreamResponse::Task(t)) if t.id == task_id => {
551                                                if t.status.state.is_terminal() {
552                                                    return Some(t);
553                                                }
554                                            }
555                                            Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
556                                                if e.status.state.is_terminal() {
557                                                    if let Some(t) = state.task_store.get(&task_id).await {
558                                                        return Some(t);
559                                                    }
560                                                }
561                                            }
562                                            Err(broadcast::error::RecvError::Closed) => {
563                                                return None;
564                                            }
565                                            _ => {}
566                                        }
567                                    }
568                                    _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
569                                        if let Some(t) = state.task_store.get(&task_id).await {
570                                            if t.status.state.is_terminal() {
571                                                return Some(t);
572                                            }
573                                        }
574                                    }
575                                }
576                            }
577                        })
578                        .await;
579
580                        match wait_result {
581                            Ok(Some(final_task)) => task = final_task,
582                            Ok(None) => {
583                                if let Some(t) = state.task_store.get(&task.id).await {
584                                    task = t;
585                                }
586                            }
587                            Err(_) => {
588                                tracing::warn!("Blocking request timed out for task {}", task.id);
589                                if let Some(t) = state.task_store.get(&task.id).await {
590                                    task = t;
591                                }
592                            }
593                        }
594                    }
595
596                    apply_history_length(&mut task, history_length);
597
598                    // Wrap in SendMessageResponse for spec-compliant serialization
599                    let resp = SendMessageResponse::Task(task);
600                    match serde_json::to_value(resp) {
601                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
602                        Err(e) => (
603                            StatusCode::INTERNAL_SERVER_ERROR,
604                            Json(error(
605                                req_id,
606                                errors::INTERNAL_ERROR,
607                                "serialization failed",
608                                Some(serde_json::json!({"error": e.to_string()})),
609                            )),
610                        ),
611                    }
612                }
613                SendMessageResponse::Message(msg) => {
614                    let resp = SendMessageResponse::Message(msg);
615                    match serde_json::to_value(resp) {
616                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
617                        Err(e) => (
618                            StatusCode::INTERNAL_SERVER_ERROR,
619                            Json(error(
620                                req_id,
621                                errors::INTERNAL_ERROR,
622                                "serialization failed",
623                                Some(serde_json::json!({"error": e.to_string()})),
624                            )),
625                        ),
626                    }
627                }
628            }
629        }
630        Err(e) => {
631            let (code, status) = handler_error_to_rpc(&e);
632            (status, Json(error(req_id, code, &e.to_string(), None)))
633        }
634    }
635}
636
637async fn handle_message_stream_rpc(
638    state: AppState,
639    req: JsonRpcRequest,
640) -> (StatusCode, Json<JsonRpcResponse>) {
641    if !state.streaming_enabled() {
642        return (
643            StatusCode::BAD_REQUEST,
644            Json(error(
645                req.id,
646                errors::UNSUPPORTED_OPERATION,
647                "streaming not supported by this agent",
648                None,
649            )),
650        );
651    }
652
653    let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
654    let stream_url = format!("{}/v1/message/stream", base_url);
655
656    (
657        StatusCode::OK,
658        Json(success(
659            req.id,
660            serde_json::json!({
661                "streamUrl": stream_url,
662                "protocol": "sse",
663                "method": "POST"
664            }),
665        )),
666    )
667}
668
669async fn handle_message_stream_sse(
670    State(state): State<AppState>,
671    headers: HeaderMap,
672    Json(params): Json<SendMessageRequest>,
673) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)> {
674    if !state.streaming_enabled() {
675        return Err((
676            StatusCode::BAD_REQUEST,
677            Json(error(
678                serde_json::Value::Null,
679                errors::UNSUPPORTED_OPERATION,
680                "streaming not supported by this agent",
681                None,
682            )),
683        ));
684    }
685
686    let auth_context = state
687        .auth_extractor
688        .as_ref()
689        .and_then(|extractor| extractor(&headers));
690
691    let response = state
692        .handler
693        .handle_message(params.message, auth_context)
694        .await
695        .map_err(|e| {
696            let (code, status) = handler_error_to_rpc(&e);
697            (status, Json(error(serde_json::Value::Null, code, &e.to_string(), None)))
698        })?;
699
700    // Extract task from response (streaming only works with tasks)
701    let task = match response {
702        SendMessageResponse::Task(t) => t,
703        SendMessageResponse::Message(_) => {
704            return Err((
705                StatusCode::BAD_REQUEST,
706                Json(error(
707                    serde_json::Value::Null,
708                    errors::UNSUPPORTED_OPERATION,
709                    "handler returned a message, streaming requires a task",
710                    None,
711                )),
712            ));
713        }
714    };
715
716    let task_id = task.id.clone();
717    state.task_store.insert(task.clone()).await;
718    state.broadcast_event(StreamResponse::Task(task.clone()));
719
720    let mut rx = state.subscribe_events();
721    let task_store = state.task_store.clone();
722    let target_task_id = task_id;
723
724    let stream = async_stream::stream! {
725        let event = StreamResponse::Task(task);
726        if let Ok(json) = serde_json::to_string(&event) {
727            yield Ok(Event::default().data(json));
728        }
729
730        loop {
731            match rx.recv().await {
732                Ok(event) => {
733                    let matches = match &event {
734                        StreamResponse::Task(t) => t.id == target_task_id,
735                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
736                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
737                        StreamResponse::Message(m) => {
738                            m.context_id.as_ref().is_some_and(|ctx| {
739                                task_store.get(&target_task_id).now_or_never()
740                                    .flatten()
741                                    .is_some_and(|t| t.context_id == *ctx)
742                            })
743                        }
744                    };
745
746                    if matches {
747                        if let Ok(json) = serde_json::to_string(&event) {
748                            yield Ok(Event::default().data(json));
749                        }
750
751                        if let StreamResponse::Task(t) = &event {
752                            if t.status.state.is_terminal() {
753                                break;
754                            }
755                        }
756                        if let StreamResponse::StatusUpdate(e) = &event {
757                            if e.status.state.is_terminal() {
758                                break;
759                            }
760                        }
761                    }
762                }
763                Err(broadcast::error::RecvError::Lagged(_)) => continue,
764                Err(broadcast::error::RecvError::Closed) => break,
765            }
766        }
767    };
768
769    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
770}
771
772async fn handle_tasks_get(
773    state: AppState,
774    req: JsonRpcRequest,
775) -> (StatusCode, Json<JsonRpcResponse>) {
776    let params: Result<GetTaskRequest, _> =
777        serde_json::from_value(req.params.unwrap_or_default());
778
779    match params {
780        Ok(p) => {
781            match state.task_store.get_flexible(&p.id).await {
782                Some(mut task) => {
783                    apply_history_length(&mut task, p.history_length);
784
785                    match serde_json::to_value(task) {
786                        Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
787                        Err(e) => (
788                            StatusCode::INTERNAL_SERVER_ERROR,
789                            Json(error(
790                                req.id,
791                                errors::INTERNAL_ERROR,
792                                "serialization failed",
793                                Some(serde_json::json!({"error": e.to_string()})),
794                            )),
795                        ),
796                    }
797                }
798                None => (
799                    StatusCode::NOT_FOUND,
800                    Json(error(req.id, errors::TASK_NOT_FOUND, "task not found", None)),
801                ),
802            }
803        }
804        Err(err) => (
805            StatusCode::BAD_REQUEST,
806            Json(error(
807                req.id,
808                errors::INVALID_PARAMS,
809                "invalid params",
810                Some(serde_json::json!({"error": err.to_string()})),
811            )),
812        ),
813    }
814}
815
816async fn handle_tasks_cancel(
817    state: AppState,
818    req: JsonRpcRequest,
819) -> (StatusCode, Json<JsonRpcResponse>) {
820    let params: Result<CancelTaskRequest, _> =
821        serde_json::from_value(req.params.unwrap_or_default());
822
823    match params {
824        Ok(p) => {
825            let result = state
826                .task_store
827                .update_flexible(&p.id, |task| {
828                    if task.status.state.is_terminal() {
829                        return Err(errors::TASK_NOT_CANCELABLE);
830                    }
831                    task.status.state = TaskState::Canceled;
832                    task.status.timestamp = Some(now_iso8601());
833                    Ok(())
834                })
835                .await;
836
837            match result {
838                Some(Ok(task)) => {
839                    if let Err(e) = state.handler.cancel_task(&task.id).await {
840                        tracing::warn!("Handler cancel_task failed: {}", e);
841                    }
842
843                    state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
844                        task_id: task.id.clone(),
845                        context_id: task.context_id.clone(),
846                        status: task.status.clone(),
847                        metadata: None,
848                    }));
849
850                    match serde_json::to_value(task) {
851                        Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
852                        Err(e) => (
853                            StatusCode::INTERNAL_SERVER_ERROR,
854                            Json(error(
855                                req.id,
856                                errors::INTERNAL_ERROR,
857                                "serialization failed",
858                                Some(serde_json::json!({"error": e.to_string()})),
859                            )),
860                        ),
861                    }
862                }
863                Some(Err(error_code)) => (
864                    StatusCode::BAD_REQUEST,
865                    Json(error(req.id, error_code, "task not cancelable", None)),
866                ),
867                None => (
868                    StatusCode::NOT_FOUND,
869                    Json(error(req.id, errors::TASK_NOT_FOUND, "task not found", None)),
870                ),
871            }
872        }
873        Err(err) => (
874            StatusCode::BAD_REQUEST,
875            Json(error(
876                req.id,
877                errors::INVALID_PARAMS,
878                "invalid params",
879                Some(serde_json::json!({"error": err.to_string()})),
880            )),
881        ),
882    }
883}
884
885async fn handle_tasks_list(
886    state: AppState,
887    req: JsonRpcRequest,
888) -> (StatusCode, Json<JsonRpcResponse>) {
889    let params: Result<ListTasksRequest, _> =
890        serde_json::from_value(req.params.unwrap_or_default());
891
892    match params {
893        Ok(p) => {
894            let response = state.task_store.list_filtered(&p).await;
895            match serde_json::to_value(response) {
896                Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
897                Err(e) => (
898                    StatusCode::INTERNAL_SERVER_ERROR,
899                    Json(error(
900                        req.id,
901                        errors::INTERNAL_ERROR,
902                        "serialization failed",
903                        Some(serde_json::json!({"error": e.to_string()})),
904                    )),
905                ),
906            }
907        }
908        Err(err) => (
909            StatusCode::BAD_REQUEST,
910            Json(error(
911                req.id,
912                errors::INVALID_PARAMS,
913                "invalid params",
914                Some(serde_json::json!({"error": err.to_string()})),
915            )),
916        ),
917    }
918}
919
920async fn handle_tasks_subscribe(
921    state: AppState,
922    req: JsonRpcRequest,
923) -> (StatusCode, Json<JsonRpcResponse>) {
924    let params: Result<SubscribeToTaskRequest, _> =
925        serde_json::from_value(req.params.unwrap_or_default());
926
927    match params {
928        Ok(p) => {
929            if state.task_store.get_flexible(&p.id).await.is_none() {
930                return (
931                    StatusCode::NOT_FOUND,
932                    Json(error(req.id, errors::TASK_NOT_FOUND, "task not found", None)),
933                );
934            }
935
936            let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
937            let subscribe_url = format!("{}/v1/tasks/{}/subscribe", base_url, p.id);
938
939            (
940                StatusCode::OK,
941                Json(success(
942                    req.id,
943                    serde_json::json!({
944                        "subscribeUrl": subscribe_url,
945                        "protocol": "sse"
946                    }),
947                )),
948            )
949        }
950        Err(err) => (
951            StatusCode::BAD_REQUEST,
952            Json(error(
953                req.id,
954                errors::INVALID_PARAMS,
955                "invalid params",
956                Some(serde_json::json!({"error": err.to_string()})),
957            )),
958        ),
959    }
960}
961
962async fn handle_get_extended_agent_card(
963    state: AppState,
964    req: JsonRpcRequest,
965    auth_context: Option<AuthContext>,
966) -> (StatusCode, Json<JsonRpcResponse>) {
967    let Some(auth) = auth_context else {
968        return (
969            StatusCode::UNAUTHORIZED,
970            Json(error(
971                req.id,
972                errors::INVALID_REQUEST,
973                "authentication required for extended agent card",
974                None,
975            )),
976        );
977    };
978
979    let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
980
981    match state.handler.extended_agent_card(base_url, &auth).await {
982        Some(card) => match serde_json::to_value(card) {
983            Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
984            Err(e) => (
985                StatusCode::INTERNAL_SERVER_ERROR,
986                Json(error(
987                    req.id,
988                    errors::INTERNAL_ERROR,
989                    "serialization failed",
990                    Some(serde_json::json!({"error": e.to_string()})),
991                )),
992            ),
993        },
994        None => (
995            StatusCode::NOT_FOUND,
996            Json(error(
997                req.id,
998                errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
999                "extended agent card not configured",
1000                None,
1001            )),
1002        ),
1003    }
1004}
1005
1006// ============ Push Notification Config Handlers ============
1007
1008async fn handle_push_config_create(
1009    state: AppState,
1010    req: JsonRpcRequest,
1011) -> (StatusCode, Json<JsonRpcResponse>) {
1012    if !state.push_notifications_enabled() {
1013        return (
1014            StatusCode::BAD_REQUEST,
1015            Json(error(
1016                req.id,
1017                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1018                "push notifications not supported",
1019                None,
1020            )),
1021        );
1022    }
1023
1024    let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1025        serde_json::from_value(req.params.unwrap_or_default());
1026
1027    match params {
1028        Ok(p) => {
1029            if state.task_store.get_flexible(&p.task_id).await.is_none() {
1030                return (
1031                    StatusCode::NOT_FOUND,
1032                    Json(error(req.id, errors::TASK_NOT_FOUND, "task not found", None)),
1033                );
1034            }
1035
1036            if let Err(e) = state
1037                .webhook_store
1038                .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1039                .await
1040            {
1041                return (
1042                    StatusCode::BAD_REQUEST,
1043                    Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1044                );
1045            }
1046
1047            (
1048                StatusCode::OK,
1049                Json(success(
1050                    req.id,
1051                    serde_json::json!({
1052                        "configId": p.config_id,
1053                        "config": p.push_notification_config
1054                    }),
1055                )),
1056            )
1057        }
1058        Err(err) => (
1059            StatusCode::BAD_REQUEST,
1060            Json(error(
1061                req.id,
1062                errors::INVALID_PARAMS,
1063                "invalid params",
1064                Some(serde_json::json!({"error": err.to_string()})),
1065            )),
1066        ),
1067    }
1068}
1069
1070async fn handle_push_config_get(
1071    state: AppState,
1072    req: JsonRpcRequest,
1073) -> (StatusCode, Json<JsonRpcResponse>) {
1074    if !state.push_notifications_enabled() {
1075        return (
1076            StatusCode::BAD_REQUEST,
1077            Json(error(
1078                req.id,
1079                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1080                "push notifications not supported",
1081                None,
1082            )),
1083        );
1084    }
1085
1086    let params: Result<GetTaskPushNotificationConfigRequest, _> =
1087        serde_json::from_value(req.params.unwrap_or_default());
1088
1089    match params {
1090        Ok(p) => {
1091            match state.webhook_store.get(&p.task_id, &p.id).await {
1092                Some(config) => (
1093                    StatusCode::OK,
1094                    Json(success(
1095                        req.id,
1096                        serde_json::json!({
1097                            "configId": p.id,
1098                            "config": config
1099                        }),
1100                    )),
1101                ),
1102                None => (
1103                    StatusCode::NOT_FOUND,
1104                    Json(error(req.id, errors::TASK_NOT_FOUND, "push notification config not found", None)),
1105                ),
1106            }
1107        }
1108        Err(err) => (
1109            StatusCode::BAD_REQUEST,
1110            Json(error(
1111                req.id,
1112                errors::INVALID_PARAMS,
1113                "invalid params",
1114                Some(serde_json::json!({"error": err.to_string()})),
1115            )),
1116        ),
1117    }
1118}
1119
1120async fn handle_push_config_list(
1121    state: AppState,
1122    req: JsonRpcRequest,
1123) -> (StatusCode, Json<JsonRpcResponse>) {
1124    if !state.push_notifications_enabled() {
1125        return (
1126            StatusCode::BAD_REQUEST,
1127            Json(error(
1128                req.id,
1129                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1130                "push notifications not supported",
1131                None,
1132            )),
1133        );
1134    }
1135
1136    let params: Result<ListTaskPushNotificationConfigRequest, _> =
1137        serde_json::from_value(req.params.unwrap_or_default());
1138
1139    match params {
1140        Ok(p) => {
1141            let configs = state.webhook_store.list(&p.task_id).await;
1142
1143            let configs_json: Vec<_> = configs
1144                .iter()
1145                .map(|c| {
1146                    serde_json::json!({
1147                        "configId": c.config_id,
1148                        "config": c.config
1149                    })
1150                })
1151                .collect();
1152
1153            (
1154                StatusCode::OK,
1155                Json(success(
1156                    req.id,
1157                    serde_json::json!({
1158                        "configs": configs_json,
1159                        "nextPageToken": ""
1160                    }),
1161                )),
1162            )
1163        }
1164        Err(err) => (
1165            StatusCode::BAD_REQUEST,
1166            Json(error(
1167                req.id,
1168                errors::INVALID_PARAMS,
1169                "invalid params",
1170                Some(serde_json::json!({"error": err.to_string()})),
1171            )),
1172        ),
1173    }
1174}
1175
1176async fn handle_push_config_delete(
1177    state: AppState,
1178    req: JsonRpcRequest,
1179) -> (StatusCode, Json<JsonRpcResponse>) {
1180    if !state.push_notifications_enabled() {
1181        return (
1182            StatusCode::BAD_REQUEST,
1183            Json(error(
1184                req.id,
1185                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1186                "push notifications not supported",
1187                None,
1188            )),
1189        );
1190    }
1191
1192    let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1193        serde_json::from_value(req.params.unwrap_or_default());
1194
1195    match params {
1196        Ok(p) => {
1197            if state.webhook_store.delete(&p.task_id, &p.id).await {
1198                (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1199            } else {
1200                (
1201                    StatusCode::NOT_FOUND,
1202                    Json(error(req.id, errors::TASK_NOT_FOUND, "push notification config not found", None)),
1203                )
1204            }
1205        }
1206        Err(err) => (
1207            StatusCode::BAD_REQUEST,
1208            Json(error(
1209                req.id,
1210                errors::INVALID_PARAMS,
1211                "invalid params",
1212                Some(serde_json::json!({"error": err.to_string()})),
1213            )),
1214        ),
1215    }
1216}
1217
1218pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1219    bind_addr: &str,
1220    handler: H,
1221) -> anyhow::Result<()> {
1222    A2aServer::new(handler).bind(bind_addr)?.run().await
1223}
1224
1225pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1226    A2aServer::echo().bind(bind_addr)?.run().await
1227}