Skip to main content

a2a_rs_server/
server.rs

1//! Generic A2A JSON-RPC Server
2
3use std::convert::Infallible;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::error_handling::HandleErrorLayer;
9use tokio::signal;
10use tokio::time::timeout;
11use tower::ServiceBuilder;
12
/// Upper bound on how long `message/send` with `blocking: true` waits for the task to reach a
/// terminal state before returning the latest stored version of the task.
const BLOCKING_TIMEOUT: Duration = Duration::from_secs(300);
/// How often the blocking `message/send` wait re-reads the task store, as a fallback to the
/// event channel.
const BLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(100);
15
16use a2a_rs_core::{
17    error, errors, now_iso8601, success, AgentCard, CancelTaskRequest,
18    CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest,
19    GetTaskPushNotificationConfigRequest, GetTaskRequest, JsonRpcRequest, JsonRpcResponse,
20    ListTaskPushNotificationConfigRequest, ListTasksRequest, SendMessageRequest,
21    SendMessageResponse, StreamResponse, SubscribeToTaskRequest, Task, TaskState,
22    TaskStatusUpdateEvent, PROTOCOL_VERSION,
23};
24
/// Request header in which clients state the A2A protocol version they speak (e.g. `1.0`).
const A2A_VERSION_HEADER: &str = "A2A-Version";
/// Major component of the only protocol version accepted in the `A2A-Version` header.
const SUPPORTED_VERSION_MAJOR: u32 = 1;
/// Minor component of the only protocol version accepted in the `A2A-Version` header.
const SUPPORTED_VERSION_MINOR: u32 = 0;
28
29use axum::extract::State;
30use axum::http::{HeaderMap, StatusCode};
31use axum::response::sse::{Event, KeepAlive, Sse};
32use axum::routing::{get, post};
33use axum::{Json, Router};
34use futures::future::FutureExt;
35use futures::stream::Stream;
36use tokio::sync::broadcast;
37use tracing::info;
38
39use crate::handler::{AuthContext, BoxedHandler, EchoHandler, HandlerError};
40use crate::task_store::TaskStore;
41use crate::webhook_delivery::WebhookDelivery;
42use crate::webhook_store::WebhookStore;
43
/// Listener configuration for an A2A server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address (`host:port`) to listen on; defaults to `0.0.0.0:8080`.
    pub bind_address: String,
}

impl Default for ServerConfig {
    /// Listens on every IPv4 interface, port 8080.
    fn default() -> Self {
        let bind_address = String::from("0.0.0.0:8080");
        Self { bind_address }
    }
}
56
/// Callback that derives an optional [`AuthContext`] from a request's headers.
///
/// Its result is handed to the message handler (and to `agentCard/getExtended`);
/// returning `None` means the request carries no auth context.
pub type AuthExtractor = Arc<dyn Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync>;

/// Capacity of the broadcast channel that fans out `StreamResponse` events. Receivers that
/// fall further behind than this observe `RecvError::Lagged` and miss the dropped events.
const EVENT_CHANNEL_CAPACITY: usize = 1024;
60
/// Builder-style A2A JSON-RPC server: configure it with the chained setters, then either call
/// `build_router` (to embed the routes in another app) or `run` (to bind and serve).
pub struct A2aServer {
    /// Listener configuration (bind address).
    config: ServerConfig,
    /// Message handler; also produces the agent card served by the router.
    handler: BoxedHandler,
    /// Task storage shared with the request handlers.
    task_store: TaskStore,
    /// Push-notification webhook registrations; read by `WebhookDelivery` started in `run`.
    webhook_store: WebhookStore,
    /// Optional callback turning request headers into an `AuthContext`.
    auth_extractor: Option<AuthExtractor>,
    /// Extra routes merged into the router after the built-in ones.
    additional_routes: Option<Router<AppState>>,
    /// Sender for task/message stream events (SSE streams, blocking waits, webhook delivery).
    event_tx: broadcast::Sender<StreamResponse>,
}
70
71impl A2aServer {
72    pub fn new(handler: impl crate::handler::MessageHandler + 'static) -> Self {
73        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
74        Self {
75            config: ServerConfig::default(),
76            handler: Arc::new(handler),
77            task_store: TaskStore::new(),
78            webhook_store: WebhookStore::new(),
79            auth_extractor: None,
80            additional_routes: None,
81            event_tx,
82        }
83    }
84
85    pub fn echo() -> Self {
86        Self::new(EchoHandler::default())
87    }
88
89    pub fn bind(mut self, address: &str) -> Result<Self, std::net::AddrParseError> {
90        let _: SocketAddr = address.parse()?;
91        self.config.bind_address = address.to_string();
92        Ok(self)
93    }
94
95    pub fn bind_unchecked(mut self, address: &str) -> Self {
96        self.config.bind_address = address.to_string();
97        self
98    }
99
100    pub fn task_store(mut self, store: TaskStore) -> Self {
101        self.task_store = store;
102        self
103    }
104
105    pub fn auth_extractor<F>(mut self, extractor: F) -> Self
106    where
107        F: Fn(&HeaderMap) -> Option<AuthContext> + Send + Sync + 'static,
108    {
109        self.auth_extractor = Some(Arc::new(extractor));
110        self
111    }
112
113    pub fn additional_routes(mut self, routes: Router<AppState>) -> Self {
114        self.additional_routes = Some(routes);
115        self
116    }
117
118    pub fn get_task_store(&self) -> TaskStore {
119        self.task_store.clone()
120    }
121
122    pub fn build_router(self) -> Router {
123        let bind: SocketAddr = self
124            .config
125            .bind_address
126            .parse()
127            .expect("Invalid bind address");
128        let base_url = format!("http://{}", bind);
129        let card = Arc::new(self.handler.agent_card(&base_url));
130
131        let state = AppState {
132            handler: self.handler,
133            task_store: self.task_store,
134            webhook_store: self.webhook_store,
135            card,
136            auth_extractor: self.auth_extractor,
137            event_tx: self.event_tx,
138        };
139
140        let timed_routes = Router::new()
141            .route("/health", get(health))
142            .route("/.well-known/agent-card.json", get(agent_card))
143            .route("/v1/rpc", post(handle_rpc))
144            .layer(
145                ServiceBuilder::new()
146                    .layer(HandleErrorLayer::new(handle_timeout_error))
147                    .timeout(Duration::from_secs(30)),
148            );
149
150        let sse_routes = Router::new()
151            .route(
152                "/v1/tasks/:task_id/subscribe",
153                get(handle_task_subscribe_sse),
154            )
155            .route("/v1/message/stream", post(handle_message_stream_sse));
156
157        let mut router = timed_routes.merge(sse_routes);
158
159        if let Some(additional) = self.additional_routes {
160            router = router.merge(additional);
161        }
162
163        router.with_state(state)
164    }
165
166    pub fn get_event_sender(&self) -> broadcast::Sender<StreamResponse> {
167        self.event_tx.clone()
168    }
169
170    pub async fn run(self) -> anyhow::Result<()> {
171        let bind: SocketAddr = self.config.bind_address.parse()?;
172
173        let webhook_delivery = Arc::new(WebhookDelivery::new(self.webhook_store.clone()));
174        webhook_delivery.start(self.event_tx.subscribe());
175
176        let router = self.build_router();
177
178        info!(%bind, "Starting A2A server");
179        let listener = tokio::net::TcpListener::bind(bind).await?;
180        axum::serve(listener, router)
181            .with_graceful_shutdown(shutdown_signal())
182            .await?;
183
184        info!("Server shutdown complete");
185        Ok(())
186    }
187}
188
189async fn handle_timeout_error(err: tower::BoxError) -> (StatusCode, Json<JsonRpcResponse>) {
190    if err.is::<tower::timeout::error::Elapsed>() {
191        (
192            StatusCode::REQUEST_TIMEOUT,
193            Json(error(
194                serde_json::Value::Null,
195                errors::INTERNAL_ERROR,
196                "Request timed out",
197                None,
198            )),
199        )
200    } else {
201        (
202            StatusCode::INTERNAL_SERVER_ERROR,
203            Json(error(
204                serde_json::Value::Null,
205                errors::INTERNAL_ERROR,
206                &format!("Internal error: {}", err),
207                None,
208            )),
209        )
210    }
211}
212
213async fn shutdown_signal() {
214    let ctrl_c = async {
215        signal::ctrl_c()
216            .await
217            .expect("failed to install Ctrl+C handler");
218    };
219
220    #[cfg(unix)]
221    let terminate = async {
222        signal::unix::signal(signal::unix::SignalKind::terminate())
223            .expect("failed to install SIGTERM handler")
224            .recv()
225            .await;
226    };
227
228    #[cfg(not(unix))]
229    let terminate = std::future::pending::<()>();
230
231    tokio::select! {
232        _ = ctrl_c => {},
233        _ = terminate => {},
234    }
235
236    info!("Shutdown signal received, initiating graceful shutdown...");
237}
238
/// Shared state handed to every route via `State<AppState>` (axum requires it to be `Clone`).
#[derive(Clone)]
pub struct AppState {
    /// Message handler taken over from `A2aServer`.
    handler: BoxedHandler,
    /// Task storage used by the task RPCs and streaming endpoints.
    task_store: TaskStore,
    /// Push-notification webhook registrations.
    webhook_store: WebhookStore,
    /// Agent card produced by the handler when the router was built.
    card: Arc<AgentCard>,
    /// Optional callback turning request headers into an `AuthContext`.
    auth_extractor: Option<AuthExtractor>,
    /// Sender half of the stream-event broadcast channel.
    event_tx: broadcast::Sender<StreamResponse>,
}
248
249impl AppState {
250    pub fn task_store(&self) -> &TaskStore {
251        &self.task_store
252    }
253
254    pub fn agent_card(&self) -> &AgentCard {
255        &self.card
256    }
257
258    pub fn event_sender(&self) -> &broadcast::Sender<StreamResponse> {
259        &self.event_tx
260    }
261
262    pub fn subscribe_events(&self) -> broadcast::Receiver<StreamResponse> {
263        self.event_tx.subscribe()
264    }
265
266    pub fn broadcast_event(&self, event: StreamResponse) {
267        let _ = self.event_tx.send(event);
268    }
269
270    /// Check if streaming is enabled in capabilities
271    fn streaming_enabled(&self) -> bool {
272        self.card.capabilities.streaming.unwrap_or(false)
273    }
274
275    /// Check if push notifications are enabled in capabilities
276    fn push_notifications_enabled(&self) -> bool {
277        self.card.capabilities.push_notifications.unwrap_or(false)
278    }
279
280    /// Get the endpoint URL from the agent card
281    fn endpoint_url(&self) -> &str {
282        self.card.endpoint().unwrap_or("")
283    }
284}
285
286// ============ History Trimming ============
287
288fn apply_history_length(task: &mut Task, history_length: Option<u32>) {
289    match history_length {
290        Some(0) => {
291            task.history = None;
292        }
293        Some(n) => {
294            if let Some(ref mut history) = task.history {
295                let len = history.len();
296                if len > n as usize {
297                    *history = history.split_off(len - n as usize);
298                }
299            }
300        }
301        None => {}
302    }
303}
304
305// ============ Version Validation ============
306
307fn validate_a2a_version(
308    headers: &HeaderMap,
309    req_id: &serde_json::Value,
310) -> Result<(), (StatusCode, Json<JsonRpcResponse>)> {
311    if let Some(version_header) = headers.get(A2A_VERSION_HEADER) {
312        let version_str = version_header.to_str().unwrap_or("");
313
314        let parts: Vec<&str> = version_str.split('.').collect();
315        if parts.len() >= 2 {
316            if let (Ok(major), Ok(minor)) = (parts[0].parse::<u32>(), parts[1].parse::<u32>()) {
317                if major == SUPPORTED_VERSION_MAJOR && minor == SUPPORTED_VERSION_MINOR {
318                    return Ok(());
319                }
320
321                return Err((
322                    StatusCode::BAD_REQUEST,
323                    Json(error(
324                        req_id.clone(),
325                        errors::VERSION_NOT_SUPPORTED,
326                        &format!(
327                            "Protocol version {}.{} not supported. Supported version: {}.{}",
328                            major, minor, SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR
329                        ),
330                        Some(serde_json::json!({
331                            "requestedVersion": version_str,
332                            "supportedVersion": format!("{}.{}", SUPPORTED_VERSION_MAJOR, SUPPORTED_VERSION_MINOR)
333                        })),
334                    )),
335                ));
336            }
337        }
338
339        // Also accept bare "1.0" without minor
340        if version_str == "1" || version_str == "1.0" {
341            return Ok(());
342        }
343
344        return Err((
345            StatusCode::BAD_REQUEST,
346            Json(error(
347                req_id.clone(),
348                errors::VERSION_NOT_SUPPORTED,
349                &format!(
350                    "Invalid version format: {}. Expected major.minor (e.g., '1.0')",
351                    version_str
352                ),
353                None,
354            )),
355        ));
356    }
357
358    Ok(())
359}
360
361// ============ Error Response Helpers ============
362
363#[allow(dead_code)]
364pub fn rpc_error(
365    id: serde_json::Value,
366    code: i32,
367    message: &str,
368    status: StatusCode,
369) -> (StatusCode, Json<JsonRpcResponse>) {
370    (status, Json(error(id, code, message, None)))
371}
372
373#[allow(dead_code)]
374pub fn rpc_error_with_data(
375    id: serde_json::Value,
376    code: i32,
377    message: &str,
378    data: serde_json::Value,
379    status: StatusCode,
380) -> (StatusCode, Json<JsonRpcResponse>) {
381    (status, Json(error(id, code, message, Some(data))))
382}
383
384#[allow(dead_code)]
385pub fn rpc_success(
386    id: serde_json::Value,
387    result: serde_json::Value,
388) -> (StatusCode, Json<JsonRpcResponse>) {
389    (StatusCode::OK, Json(success(id, result)))
390}
391
392// ============ Route Handlers ============
393
394async fn health() -> Json<serde_json::Value> {
395    Json(serde_json::json!({"status": "ok", "protocol": PROTOCOL_VERSION}))
396}
397
398async fn agent_card(State(state): State<AppState>) -> Json<AgentCard> {
399    Json((*state.card).clone())
400}
401
402async fn handle_task_subscribe_sse(
403    State(state): State<AppState>,
404    axum::extract::Path(task_id): axum::extract::Path<String>,
405) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
406    let mut rx = state.subscribe_events();
407    let task_store = state.task_store.clone();
408    let target_task_id = task_id;
409
410    let stream = async_stream::stream! {
411        if let Some(task) = task_store.get_flexible(&target_task_id).await {
412            let event = StreamResponse::Task(task);
413            if let Ok(json) = serde_json::to_string(&event) {
414                yield Ok(Event::default().data(json));
415            }
416        }
417
418        loop {
419            match rx.recv().await {
420                Ok(event) => {
421                    let matches = match &event {
422                        StreamResponse::Task(t) => t.id == target_task_id,
423                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
424                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
425                        StreamResponse::Message(_) => false,
426                    };
427                    if matches {
428                        if let Ok(json) = serde_json::to_string(&event) {
429                            yield Ok(Event::default().data(json));
430                        }
431
432                        if let StreamResponse::Task(t) = &event {
433                            if t.status.state.is_terminal() {
434                                break;
435                            }
436                        }
437                        if let StreamResponse::StatusUpdate(e) = &event {
438                            if e.status.state.is_terminal() {
439                                break;
440                            }
441                        }
442                    }
443                }
444                Err(broadcast::error::RecvError::Lagged(_)) => {
445                    continue;
446                }
447                Err(broadcast::error::RecvError::Closed) => {
448                    break;
449                }
450            }
451        }
452    };
453
454    Sse::new(stream).keep_alive(KeepAlive::default())
455}
456
457async fn handle_rpc(
458    State(state): State<AppState>,
459    headers: HeaderMap,
460    Json(req): Json<JsonRpcRequest>,
461) -> (StatusCode, Json<JsonRpcResponse>) {
462    if req.jsonrpc != "2.0" {
463        let resp = error(req.id, errors::INVALID_REQUEST, "jsonrpc must be 2.0", None);
464        return (StatusCode::BAD_REQUEST, Json(resp));
465    }
466
467    if let Err(err_response) = validate_a2a_version(&headers, &req.id) {
468        return err_response;
469    }
470
471    let auth_context = state
472        .auth_extractor
473        .as_ref()
474        .and_then(|extractor| extractor(&headers));
475
476    match req.method.as_str() {
477        // Spec method names per JSON-RPC binding
478        "message/send" => handle_message_send(state, req, auth_context).await,
479        "message/sendStreaming" => handle_message_stream_rpc(state, req).await,
480        "tasks/get" => handle_tasks_get(state, req).await,
481        "tasks/list" => handle_tasks_list(state, req).await,
482        "tasks/cancel" => handle_tasks_cancel(state, req).await,
483        "tasks/subscribe" => handle_tasks_subscribe(state, req).await,
484        "tasks/pushNotificationConfig/create" => handle_push_config_create(state, req).await,
485        "tasks/pushNotificationConfig/get" => handle_push_config_get(state, req).await,
486        "tasks/pushNotificationConfig/list" => handle_push_config_list(state, req).await,
487        "tasks/pushNotificationConfig/delete" => handle_push_config_delete(state, req).await,
488        "agentCard/getExtended" => handle_get_extended_agent_card(state, req, auth_context).await,
489        _ => (
490            StatusCode::NOT_FOUND,
491            Json(error(
492                req.id,
493                errors::METHOD_NOT_FOUND,
494                "method not found",
495                None,
496            )),
497        ),
498    }
499}
500
501fn handler_error_to_rpc(e: &HandlerError) -> (i32, StatusCode) {
502    match e {
503        HandlerError::InvalidInput(_) => (errors::INVALID_PARAMS, StatusCode::BAD_REQUEST),
504        HandlerError::AuthRequired(_) => (errors::INVALID_REQUEST, StatusCode::UNAUTHORIZED),
505        HandlerError::BackendUnavailable { .. } => {
506            (errors::INTERNAL_ERROR, StatusCode::SERVICE_UNAVAILABLE)
507        }
508        HandlerError::ProcessingFailed { .. } => {
509            (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR)
510        }
511        HandlerError::Internal(_) => (errors::INTERNAL_ERROR, StatusCode::INTERNAL_SERVER_ERROR),
512    }
513}
514
515async fn handle_message_send(
516    state: AppState,
517    req: JsonRpcRequest,
518    auth_context: Option<AuthContext>,
519) -> (StatusCode, Json<JsonRpcResponse>) {
520    let req_id = req.id.clone();
521
522    let params: Result<SendMessageRequest, _> =
523        serde_json::from_value(req.params.clone().unwrap_or_default());
524
525    let params = match params {
526        Ok(p) => p,
527        Err(err) => {
528            return (
529                StatusCode::BAD_REQUEST,
530                Json(error(
531                    req_id,
532                    errors::INVALID_PARAMS,
533                    "invalid params",
534                    Some(serde_json::json!({"error": err.to_string()})),
535                )),
536            );
537        }
538    };
539
540    let blocking = params
541        .configuration
542        .as_ref()
543        .and_then(|c| c.blocking)
544        .unwrap_or(false);
545    let history_length = params.configuration.as_ref().and_then(|c| c.history_length);
546
547    match state
548        .handler
549        .handle_message(params.message, auth_context)
550        .await
551    {
552        Ok(response) => {
553            match response {
554                SendMessageResponse::Task(mut task) => {
555                    // Store the task and broadcast event
556                    state.task_store.insert(task.clone()).await;
557                    state.broadcast_event(StreamResponse::Task(task.clone()));
558
559                    // If blocking mode, wait for task to reach terminal state
560                    if blocking && !task.status.state.is_terminal() {
561                        let task_id = task.id.clone();
562                        let mut rx = state.subscribe_events();
563
564                        let wait_result = timeout(BLOCKING_TIMEOUT, async {
565                            loop {
566                                tokio::select! {
567                                    result = rx.recv() => {
568                                        match result {
569                                            Ok(StreamResponse::Task(t)) if t.id == task_id => {
570                                                if t.status.state.is_terminal() {
571                                                    return Some(t);
572                                                }
573                                            }
574                                            Ok(StreamResponse::StatusUpdate(e)) if e.task_id == task_id => {
575                                                if e.status.state.is_terminal() {
576                                                    if let Some(t) = state.task_store.get(&task_id).await {
577                                                        return Some(t);
578                                                    }
579                                                }
580                                            }
581                                            Err(broadcast::error::RecvError::Closed) => {
582                                                return None;
583                                            }
584                                            _ => {}
585                                        }
586                                    }
587                                    _ = tokio::time::sleep(BLOCKING_POLL_INTERVAL) => {
588                                        if let Some(t) = state.task_store.get(&task_id).await {
589                                            if t.status.state.is_terminal() {
590                                                return Some(t);
591                                            }
592                                        }
593                                    }
594                                }
595                            }
596                        })
597                        .await;
598
599                        match wait_result {
600                            Ok(Some(final_task)) => task = final_task,
601                            Ok(None) => {
602                                if let Some(t) = state.task_store.get(&task.id).await {
603                                    task = t;
604                                }
605                            }
606                            Err(_) => {
607                                tracing::warn!("Blocking request timed out for task {}", task.id);
608                                if let Some(t) = state.task_store.get(&task.id).await {
609                                    task = t;
610                                }
611                            }
612                        }
613                    }
614
615                    apply_history_length(&mut task, history_length);
616
617                    // Wrap in SendMessageResponse for spec-compliant serialization
618                    let resp = SendMessageResponse::Task(task);
619                    match serde_json::to_value(resp) {
620                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
621                        Err(e) => (
622                            StatusCode::INTERNAL_SERVER_ERROR,
623                            Json(error(
624                                req_id,
625                                errors::INTERNAL_ERROR,
626                                "serialization failed",
627                                Some(serde_json::json!({"error": e.to_string()})),
628                            )),
629                        ),
630                    }
631                }
632                SendMessageResponse::Message(msg) => {
633                    let resp = SendMessageResponse::Message(msg);
634                    match serde_json::to_value(resp) {
635                        Ok(val) => (StatusCode::OK, Json(success(req_id, val))),
636                        Err(e) => (
637                            StatusCode::INTERNAL_SERVER_ERROR,
638                            Json(error(
639                                req_id,
640                                errors::INTERNAL_ERROR,
641                                "serialization failed",
642                                Some(serde_json::json!({"error": e.to_string()})),
643                            )),
644                        ),
645                    }
646                }
647            }
648        }
649        Err(e) => {
650            let (code, status) = handler_error_to_rpc(&e);
651            (status, Json(error(req_id, code, &e.to_string(), None)))
652        }
653    }
654}
655
656async fn handle_message_stream_rpc(
657    state: AppState,
658    req: JsonRpcRequest,
659) -> (StatusCode, Json<JsonRpcResponse>) {
660    if !state.streaming_enabled() {
661        return (
662            StatusCode::BAD_REQUEST,
663            Json(error(
664                req.id,
665                errors::UNSUPPORTED_OPERATION,
666                "streaming not supported by this agent",
667                None,
668            )),
669        );
670    }
671
672    let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
673    let stream_url = format!("{}/v1/message/stream", base_url);
674
675    (
676        StatusCode::OK,
677        Json(success(
678            req.id,
679            serde_json::json!({
680                "streamUrl": stream_url,
681                "protocol": "sse",
682                "method": "POST"
683            }),
684        )),
685    )
686}
687
688async fn handle_message_stream_sse(
689    State(state): State<AppState>,
690    headers: HeaderMap,
691    Json(params): Json<SendMessageRequest>,
692) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<JsonRpcResponse>)>
693{
694    if !state.streaming_enabled() {
695        return Err((
696            StatusCode::BAD_REQUEST,
697            Json(error(
698                serde_json::Value::Null,
699                errors::UNSUPPORTED_OPERATION,
700                "streaming not supported by this agent",
701                None,
702            )),
703        ));
704    }
705
706    let auth_context = state
707        .auth_extractor
708        .as_ref()
709        .and_then(|extractor| extractor(&headers));
710
711    let response = state
712        .handler
713        .handle_message(params.message, auth_context)
714        .await
715        .map_err(|e| {
716            let (code, status) = handler_error_to_rpc(&e);
717            (
718                status,
719                Json(error(serde_json::Value::Null, code, &e.to_string(), None)),
720            )
721        })?;
722
723    // Extract task from response (streaming only works with tasks)
724    let task = match response {
725        SendMessageResponse::Task(t) => t,
726        SendMessageResponse::Message(_) => {
727            return Err((
728                StatusCode::BAD_REQUEST,
729                Json(error(
730                    serde_json::Value::Null,
731                    errors::UNSUPPORTED_OPERATION,
732                    "handler returned a message, streaming requires a task",
733                    None,
734                )),
735            ));
736        }
737    };
738
739    let task_id = task.id.clone();
740    state.task_store.insert(task.clone()).await;
741    state.broadcast_event(StreamResponse::Task(task.clone()));
742
743    let mut rx = state.subscribe_events();
744    let task_store = state.task_store.clone();
745    let target_task_id = task_id;
746
747    let stream = async_stream::stream! {
748        let event = StreamResponse::Task(task);
749        if let Ok(json) = serde_json::to_string(&event) {
750            yield Ok(Event::default().data(json));
751        }
752
753        loop {
754            match rx.recv().await {
755                Ok(event) => {
756                    let matches = match &event {
757                        StreamResponse::Task(t) => t.id == target_task_id,
758                        StreamResponse::StatusUpdate(e) => e.task_id == target_task_id,
759                        StreamResponse::ArtifactUpdate(e) => e.task_id == target_task_id,
760                        StreamResponse::Message(m) => {
761                            m.context_id.as_ref().is_some_and(|ctx| {
762                                task_store.get(&target_task_id).now_or_never()
763                                    .flatten()
764                                    .is_some_and(|t| t.context_id == *ctx)
765                            })
766                        }
767                    };
768
769                    if matches {
770                        if let Ok(json) = serde_json::to_string(&event) {
771                            yield Ok(Event::default().data(json));
772                        }
773
774                        if let StreamResponse::Task(t) = &event {
775                            if t.status.state.is_terminal() {
776                                break;
777                            }
778                        }
779                        if let StreamResponse::StatusUpdate(e) = &event {
780                            if e.status.state.is_terminal() {
781                                break;
782                            }
783                        }
784                    }
785                }
786                Err(broadcast::error::RecvError::Lagged(_)) => continue,
787                Err(broadcast::error::RecvError::Closed) => break,
788            }
789        }
790    };
791
792    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
793}
794
795async fn handle_tasks_get(
796    state: AppState,
797    req: JsonRpcRequest,
798) -> (StatusCode, Json<JsonRpcResponse>) {
799    let params: Result<GetTaskRequest, _> = serde_json::from_value(req.params.unwrap_or_default());
800
801    match params {
802        Ok(p) => match state.task_store.get_flexible(&p.id).await {
803            Some(mut task) => {
804                apply_history_length(&mut task, p.history_length);
805
806                match serde_json::to_value(task) {
807                    Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
808                    Err(e) => (
809                        StatusCode::INTERNAL_SERVER_ERROR,
810                        Json(error(
811                            req.id,
812                            errors::INTERNAL_ERROR,
813                            "serialization failed",
814                            Some(serde_json::json!({"error": e.to_string()})),
815                        )),
816                    ),
817                }
818            }
819            None => (
820                StatusCode::NOT_FOUND,
821                Json(error(
822                    req.id,
823                    errors::TASK_NOT_FOUND,
824                    "task not found",
825                    None,
826                )),
827            ),
828        },
829        Err(err) => (
830            StatusCode::BAD_REQUEST,
831            Json(error(
832                req.id,
833                errors::INVALID_PARAMS,
834                "invalid params",
835                Some(serde_json::json!({"error": err.to_string()})),
836            )),
837        ),
838    }
839}
840
841async fn handle_tasks_cancel(
842    state: AppState,
843    req: JsonRpcRequest,
844) -> (StatusCode, Json<JsonRpcResponse>) {
845    let params: Result<CancelTaskRequest, _> =
846        serde_json::from_value(req.params.unwrap_or_default());
847
848    match params {
849        Ok(p) => {
850            let result = state
851                .task_store
852                .update_flexible(&p.id, |task| {
853                    if task.status.state.is_terminal() {
854                        return Err(errors::TASK_NOT_CANCELABLE);
855                    }
856                    task.status.state = TaskState::Canceled;
857                    task.status.timestamp = Some(now_iso8601());
858                    Ok(())
859                })
860                .await;
861
862            match result {
863                Some(Ok(task)) => {
864                    if let Err(e) = state.handler.cancel_task(&task.id).await {
865                        tracing::warn!("Handler cancel_task failed: {}", e);
866                    }
867
868                    state.broadcast_event(StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
869                        task_id: task.id.clone(),
870                        context_id: task.context_id.clone(),
871                        status: task.status.clone(),
872                        metadata: None,
873                    }));
874
875                    match serde_json::to_value(task) {
876                        Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
877                        Err(e) => (
878                            StatusCode::INTERNAL_SERVER_ERROR,
879                            Json(error(
880                                req.id,
881                                errors::INTERNAL_ERROR,
882                                "serialization failed",
883                                Some(serde_json::json!({"error": e.to_string()})),
884                            )),
885                        ),
886                    }
887                }
888                Some(Err(error_code)) => (
889                    StatusCode::BAD_REQUEST,
890                    Json(error(req.id, error_code, "task not cancelable", None)),
891                ),
892                None => (
893                    StatusCode::NOT_FOUND,
894                    Json(error(
895                        req.id,
896                        errors::TASK_NOT_FOUND,
897                        "task not found",
898                        None,
899                    )),
900                ),
901            }
902        }
903        Err(err) => (
904            StatusCode::BAD_REQUEST,
905            Json(error(
906                req.id,
907                errors::INVALID_PARAMS,
908                "invalid params",
909                Some(serde_json::json!({"error": err.to_string()})),
910            )),
911        ),
912    }
913}
914
915async fn handle_tasks_list(
916    state: AppState,
917    req: JsonRpcRequest,
918) -> (StatusCode, Json<JsonRpcResponse>) {
919    let params: Result<ListTasksRequest, _> =
920        serde_json::from_value(req.params.unwrap_or_default());
921
922    match params {
923        Ok(p) => {
924            let response = state.task_store.list_filtered(&p).await;
925            match serde_json::to_value(response) {
926                Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
927                Err(e) => (
928                    StatusCode::INTERNAL_SERVER_ERROR,
929                    Json(error(
930                        req.id,
931                        errors::INTERNAL_ERROR,
932                        "serialization failed",
933                        Some(serde_json::json!({"error": e.to_string()})),
934                    )),
935                ),
936            }
937        }
938        Err(err) => (
939            StatusCode::BAD_REQUEST,
940            Json(error(
941                req.id,
942                errors::INVALID_PARAMS,
943                "invalid params",
944                Some(serde_json::json!({"error": err.to_string()})),
945            )),
946        ),
947    }
948}
949
950async fn handle_tasks_subscribe(
951    state: AppState,
952    req: JsonRpcRequest,
953) -> (StatusCode, Json<JsonRpcResponse>) {
954    let params: Result<SubscribeToTaskRequest, _> =
955        serde_json::from_value(req.params.unwrap_or_default());
956
957    match params {
958        Ok(p) => {
959            if state.task_store.get_flexible(&p.id).await.is_none() {
960                return (
961                    StatusCode::NOT_FOUND,
962                    Json(error(
963                        req.id,
964                        errors::TASK_NOT_FOUND,
965                        "task not found",
966                        None,
967                    )),
968                );
969            }
970
971            let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
972            let subscribe_url = format!("{}/v1/tasks/{}/subscribe", base_url, p.id);
973
974            (
975                StatusCode::OK,
976                Json(success(
977                    req.id,
978                    serde_json::json!({
979                        "subscribeUrl": subscribe_url,
980                        "protocol": "sse"
981                    }),
982                )),
983            )
984        }
985        Err(err) => (
986            StatusCode::BAD_REQUEST,
987            Json(error(
988                req.id,
989                errors::INVALID_PARAMS,
990                "invalid params",
991                Some(serde_json::json!({"error": err.to_string()})),
992            )),
993        ),
994    }
995}
996
997async fn handle_get_extended_agent_card(
998    state: AppState,
999    req: JsonRpcRequest,
1000    auth_context: Option<AuthContext>,
1001) -> (StatusCode, Json<JsonRpcResponse>) {
1002    let Some(auth) = auth_context else {
1003        return (
1004            StatusCode::UNAUTHORIZED,
1005            Json(error(
1006                req.id,
1007                errors::INVALID_REQUEST,
1008                "authentication required for extended agent card",
1009                None,
1010            )),
1011        );
1012    };
1013
1014    let base_url = state.endpoint_url().trim_end_matches("/v1/rpc");
1015
1016    match state.handler.extended_agent_card(base_url, &auth).await {
1017        Some(card) => match serde_json::to_value(card) {
1018            Ok(val) => (StatusCode::OK, Json(success(req.id, val))),
1019            Err(e) => (
1020                StatusCode::INTERNAL_SERVER_ERROR,
1021                Json(error(
1022                    req.id,
1023                    errors::INTERNAL_ERROR,
1024                    "serialization failed",
1025                    Some(serde_json::json!({"error": e.to_string()})),
1026                )),
1027            ),
1028        },
1029        None => (
1030            StatusCode::NOT_FOUND,
1031            Json(error(
1032                req.id,
1033                errors::EXTENDED_AGENT_CARD_NOT_CONFIGURED,
1034                "extended agent card not configured",
1035                None,
1036            )),
1037        ),
1038    }
1039}
1040
1041// ============ Push Notification Config Handlers ============
1042
1043async fn handle_push_config_create(
1044    state: AppState,
1045    req: JsonRpcRequest,
1046) -> (StatusCode, Json<JsonRpcResponse>) {
1047    if !state.push_notifications_enabled() {
1048        return (
1049            StatusCode::BAD_REQUEST,
1050            Json(error(
1051                req.id,
1052                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1053                "push notifications not supported",
1054                None,
1055            )),
1056        );
1057    }
1058
1059    let params: Result<CreateTaskPushNotificationConfigRequest, _> =
1060        serde_json::from_value(req.params.unwrap_or_default());
1061
1062    match params {
1063        Ok(p) => {
1064            if state.task_store.get_flexible(&p.task_id).await.is_none() {
1065                return (
1066                    StatusCode::NOT_FOUND,
1067                    Json(error(
1068                        req.id,
1069                        errors::TASK_NOT_FOUND,
1070                        "task not found",
1071                        None,
1072                    )),
1073                );
1074            }
1075
1076            if let Err(e) = state
1077                .webhook_store
1078                .set(&p.task_id, &p.config_id, p.push_notification_config.clone())
1079                .await
1080            {
1081                return (
1082                    StatusCode::BAD_REQUEST,
1083                    Json(error(req.id, errors::INVALID_PARAMS, &e.to_string(), None)),
1084                );
1085            }
1086
1087            (
1088                StatusCode::OK,
1089                Json(success(
1090                    req.id,
1091                    serde_json::json!({
1092                        "configId": p.config_id,
1093                        "config": p.push_notification_config
1094                    }),
1095                )),
1096            )
1097        }
1098        Err(err) => (
1099            StatusCode::BAD_REQUEST,
1100            Json(error(
1101                req.id,
1102                errors::INVALID_PARAMS,
1103                "invalid params",
1104                Some(serde_json::json!({"error": err.to_string()})),
1105            )),
1106        ),
1107    }
1108}
1109
1110async fn handle_push_config_get(
1111    state: AppState,
1112    req: JsonRpcRequest,
1113) -> (StatusCode, Json<JsonRpcResponse>) {
1114    if !state.push_notifications_enabled() {
1115        return (
1116            StatusCode::BAD_REQUEST,
1117            Json(error(
1118                req.id,
1119                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1120                "push notifications not supported",
1121                None,
1122            )),
1123        );
1124    }
1125
1126    let params: Result<GetTaskPushNotificationConfigRequest, _> =
1127        serde_json::from_value(req.params.unwrap_or_default());
1128
1129    match params {
1130        Ok(p) => match state.webhook_store.get(&p.task_id, &p.id).await {
1131            Some(config) => (
1132                StatusCode::OK,
1133                Json(success(
1134                    req.id,
1135                    serde_json::json!({
1136                        "configId": p.id,
1137                        "config": config
1138                    }),
1139                )),
1140            ),
1141            None => (
1142                StatusCode::NOT_FOUND,
1143                Json(error(
1144                    req.id,
1145                    errors::TASK_NOT_FOUND,
1146                    "push notification config not found",
1147                    None,
1148                )),
1149            ),
1150        },
1151        Err(err) => (
1152            StatusCode::BAD_REQUEST,
1153            Json(error(
1154                req.id,
1155                errors::INVALID_PARAMS,
1156                "invalid params",
1157                Some(serde_json::json!({"error": err.to_string()})),
1158            )),
1159        ),
1160    }
1161}
1162
1163async fn handle_push_config_list(
1164    state: AppState,
1165    req: JsonRpcRequest,
1166) -> (StatusCode, Json<JsonRpcResponse>) {
1167    if !state.push_notifications_enabled() {
1168        return (
1169            StatusCode::BAD_REQUEST,
1170            Json(error(
1171                req.id,
1172                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1173                "push notifications not supported",
1174                None,
1175            )),
1176        );
1177    }
1178
1179    let params: Result<ListTaskPushNotificationConfigRequest, _> =
1180        serde_json::from_value(req.params.unwrap_or_default());
1181
1182    match params {
1183        Ok(p) => {
1184            let configs = state.webhook_store.list(&p.task_id).await;
1185
1186            let configs_json: Vec<_> = configs
1187                .iter()
1188                .map(|c| {
1189                    serde_json::json!({
1190                        "configId": c.config_id,
1191                        "config": c.config
1192                    })
1193                })
1194                .collect();
1195
1196            (
1197                StatusCode::OK,
1198                Json(success(
1199                    req.id,
1200                    serde_json::json!({
1201                        "configs": configs_json,
1202                        "nextPageToken": ""
1203                    }),
1204                )),
1205            )
1206        }
1207        Err(err) => (
1208            StatusCode::BAD_REQUEST,
1209            Json(error(
1210                req.id,
1211                errors::INVALID_PARAMS,
1212                "invalid params",
1213                Some(serde_json::json!({"error": err.to_string()})),
1214            )),
1215        ),
1216    }
1217}
1218
1219async fn handle_push_config_delete(
1220    state: AppState,
1221    req: JsonRpcRequest,
1222) -> (StatusCode, Json<JsonRpcResponse>) {
1223    if !state.push_notifications_enabled() {
1224        return (
1225            StatusCode::BAD_REQUEST,
1226            Json(error(
1227                req.id,
1228                errors::PUSH_NOTIFICATION_NOT_SUPPORTED,
1229                "push notifications not supported",
1230                None,
1231            )),
1232        );
1233    }
1234
1235    let params: Result<DeleteTaskPushNotificationConfigRequest, _> =
1236        serde_json::from_value(req.params.unwrap_or_default());
1237
1238    match params {
1239        Ok(p) => {
1240            if state.webhook_store.delete(&p.task_id, &p.id).await {
1241                (StatusCode::OK, Json(success(req.id, serde_json::json!({}))))
1242            } else {
1243                (
1244                    StatusCode::NOT_FOUND,
1245                    Json(error(
1246                        req.id,
1247                        errors::TASK_NOT_FOUND,
1248                        "push notification config not found",
1249                        None,
1250                    )),
1251                )
1252            }
1253        }
1254        Err(err) => (
1255            StatusCode::BAD_REQUEST,
1256            Json(error(
1257                req.id,
1258                errors::INVALID_PARAMS,
1259                "invalid params",
1260                Some(serde_json::json!({"error": err.to_string()})),
1261            )),
1262        ),
1263    }
1264}
1265
1266pub async fn run_server<H: crate::handler::MessageHandler + 'static>(
1267    bind_addr: &str,
1268    handler: H,
1269) -> anyhow::Result<()> {
1270    A2aServer::new(handler).bind(bind_addr)?.run().await
1271}
1272
1273pub async fn run_echo_server(bind_addr: &str) -> anyhow::Result<()> {
1274    A2aServer::echo().bind(bind_addr)?.run().await
1275}