Skip to main content

brainwires_a2a/server/
mod.rs

1//! A2A server — serves JSON-RPC, REST, and optionally gRPC.
2
3/// gRPC service implementation.
4pub mod grpc_service;
5/// Core handler trait.
6pub mod handler;
7/// JSON-RPC method dispatch.
8pub mod jsonrpc_router;
9/// HTTP/REST route handling.
10pub mod rest_router;
11/// SSE response construction.
12pub mod sse_response;
13
14pub use handler::A2aHandler;
15
16#[cfg(feature = "grpc-server")]
17pub use grpc_service::GrpcBridge;
18
19use std::net::SocketAddr;
20use std::sync::Arc;
21
22use crate::error::A2aError;
23use crate::jsonrpc::{JsonRpcRequest, METHOD_MESSAGE_STREAM, METHOD_TASKS_RESUBSCRIBE, RequestId};
24use crate::params::{SendMessageRequest, SubscribeToTaskRequest};
25
/// Largest request body, in bytes, that the HTTP front-end buffers before
/// rejecting the request (10 MiB).
const MAX_REQUEST_BODY_SIZE: usize = 10 << 20;
28
29/// Unified A2A server serving JSON-RPC + REST (HTTP) and optionally gRPC.
30pub struct A2aServer<H: A2aHandler> {
31    handler: Arc<H>,
32    addr: SocketAddr,
33    #[cfg(feature = "grpc-server")]
34    grpc_addr: Option<SocketAddr>,
35    shutdown: Option<tokio::sync::watch::Receiver<()>>,
36}
37
38impl<H: A2aHandler> A2aServer<H> {
39    /// Create a new server bound to `addr`.
40    pub fn new(handler: H, addr: SocketAddr) -> Self {
41        Self {
42            handler: Arc::new(handler),
43            addr,
44            #[cfg(feature = "grpc-server")]
45            grpc_addr: None,
46            shutdown: None,
47        }
48    }
49
50    /// Enable gRPC on a separate port.
51    #[cfg(feature = "grpc-server")]
52    pub fn with_grpc(mut self, grpc_addr: SocketAddr) -> Self {
53        self.grpc_addr = Some(grpc_addr);
54        self
55    }
56
57    /// Set a shutdown signal. When the sender is dropped or a value is sent,
58    /// the server will stop accepting new connections.
59    pub fn with_shutdown(mut self, rx: tokio::sync::watch::Receiver<()>) -> Self {
60        self.shutdown = Some(rx);
61        self
62    }
63
64    /// Run the server (blocks until shutdown signal or forever if none set).
65    pub async fn run(self) -> Result<(), A2aError> {
66        use hyper::body::Incoming;
67        use hyper::service::service_fn;
68        use hyper_util::rt::TokioIo;
69
70        let handler = self.handler.clone();
71        let listener = tokio::net::TcpListener::bind(self.addr)
72            .await
73            .map_err(|e| A2aError::internal(format!("Failed to bind: {e}")))?;
74
75        tracing::info!("A2A server listening on {}", self.addr);
76
77        // Optionally spawn gRPC server with bind error propagation
78        #[cfg(feature = "grpc-server")]
79        if let Some(grpc_addr) = self.grpc_addr {
80            let grpc_handler = self.handler.clone();
81            let shutdown_rx = self.shutdown.clone();
82            let (bind_tx, mut bind_rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
83            tokio::spawn(async move {
84                let bridge = GrpcBridge::new(grpc_handler);
85                let svc =
86                    crate::proto::lf_a2a_v1::a2a_service_server::A2aServiceServer::new(bridge);
87                tracing::info!("A2A gRPC server listening on {grpc_addr}");
88                let builder = tonic::transport::Server::builder().add_service(svc);
89                let result = if let Some(mut rx) = shutdown_rx {
90                    builder
91                        .serve_with_shutdown(grpc_addr, async move {
92                            let _ = rx.changed().await;
93                        })
94                        .await
95                } else {
96                    builder.serve(grpc_addr).await
97                };
98                match result {
99                    Ok(()) => {
100                        let _ = bind_tx.send(Ok(()));
101                    }
102                    Err(e) => {
103                        let msg = format!("gRPC server error: {e}");
104                        tracing::error!("{msg}");
105                        let _ = bind_tx.send(Err(msg));
106                    }
107                }
108            });
109            // Give gRPC a moment to fail on immediate bind errors
110            tokio::task::yield_now().await;
111            if let Ok(Err(msg)) = bind_rx.try_recv() {
112                return Err(A2aError::internal(msg));
113            }
114        }
115
116        let mut shutdown = self.shutdown;
117
118        loop {
119            let accept_result = if let Some(ref mut rx) = shutdown {
120                tokio::select! {
121                    result = listener.accept() => Some(result),
122                    _ = rx.changed() => None,
123                }
124            } else {
125                Some(listener.accept().await)
126            };
127
128            let (stream, _peer) = match accept_result {
129                None => {
130                    tracing::info!("A2A server shutting down");
131                    return Ok(());
132                }
133                Some(Ok(conn)) => conn,
134                Some(Err(e)) => {
135                    tracing::warn!("Accept error (continuing): {e}");
136                    continue;
137                }
138            };
139
140            let handler = handler.clone();
141            tokio::spawn(async move {
142                let io = TokioIo::new(stream);
143                let svc = service_fn(move |req: hyper::Request<Incoming>| {
144                    let handler = handler.clone();
145                    async move { handle_http_request(handler, req).await }
146                });
147                if let Err(e) = hyper_util::server::conn::auto::Builder::new(
148                    hyper_util::rt::TokioExecutor::new(),
149                )
150                .serve_connection(io, svc)
151                .await
152                {
153                    tracing::debug!("Connection error: {e}");
154                }
155            });
156        }
157    }
158}
159
160// ---------------------------------------------------------------------------
161// Response body type — supports both buffered and streaming responses
162// ---------------------------------------------------------------------------
163
/// Streaming half of `BoxBody`: a boxed, `Send` stream of body frames
/// (used for Server-Sent Events responses).
#[cfg(feature = "server")]
type SseFrameStream = std::pin::Pin<
    Box<dyn futures::Stream<Item = Result<http_body::Frame<bytes::Bytes>, std::io::Error>> + Send>,
>;

/// Response body shared by every HTTP handler in this module: either a fully
/// buffered payload (`Left`) or a stream of frames (`Right`).
#[cfg(feature = "server")]
type BoxBody = http_body_util::Either<
    http_body_util::Full<bytes::Bytes>,
    http_body_util::StreamBody<SseFrameStream>,
>;
176
/// Build a buffered `application/json` response carrying permissive CORS
/// headers.
///
/// `body` is sent verbatim; callers pass an empty string for bodiless replies
/// such as the `204` CORS preflight.
///
/// # Panics
///
/// Panics if `status` is not a valid HTTP status code (outside `100..=999`);
/// every call site in this module passes an in-range literal.
#[cfg(feature = "server")]
fn json_response(status: u16, body: String) -> hyper::Response<BoxBody> {
    hyper::Response::builder()
        .status(status)
        .header("Content-Type", "application/json")
        .header("Access-Control-Allow-Origin", "*")
        .header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
        .header(
            "Access-Control-Allow-Headers",
            "Content-Type, Authorization",
        )
        .body(http_body_util::Either::Left(http_body_util::Full::new(
            bytes::Bytes::from(body),
        )))
        .expect("response builder with valid status and headers cannot fail")
}
195
/// Build a `200 OK` Server-Sent Events response whose body is `stream`,
/// with the same CORS headers as the JSON responses.
///
/// No `Connection` header is set. Keep-alive is already the default in
/// HTTP/1.1, and connection-specific headers are forbidden in HTTP/2
/// (RFC 9113 §8.2.2). hyper's HTTP/2 server strips such headers and logs a
/// warning for each response that carries one.
#[cfg(feature = "server")]
fn sse_response(
    stream: std::pin::Pin<
        Box<
            dyn futures::Stream<Item = Result<http_body::Frame<bytes::Bytes>, std::io::Error>>
                + Send,
        >,
    >,
) -> hyper::Response<BoxBody> {
    const HEADERS: [(&str, &str); 5] = [
        ("Content-Type", "text/event-stream"),
        ("Cache-Control", "no-cache"),
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
    ];

    let body = http_body_util::Either::Right(http_body_util::StreamBody::new(stream));
    HEADERS
        .iter()
        .fold(hyper::Response::builder().status(200), |builder, (name, value)| {
            builder.header(*name, *value)
        })
        .body(body)
        .expect("response builder with valid status and headers cannot fail")
}
221
222// ---------------------------------------------------------------------------
223// HTTP request handler
224// ---------------------------------------------------------------------------
225
/// Route a single HTTP request to the matching A2A surface.
///
/// Routing order:
/// 1. `OPTIONS` on any path: CORS preflight, answered with an empty `204`.
/// 2. `GET /.well-known/agent-card.json`: the handler's agent card as JSON.
/// 3. The body is buffered up to `MAX_REQUEST_BODY_SIZE` bytes. Larger
///    bodies get `413`; bodies that fail to arrive intact get `400`.
/// 4. `POST /`: JSON-RPC, delegated to `handle_jsonrpc`.
/// 5. Anything else goes to `rest_router::dispatch_rest`. JSON results are
///    returned as `200`, streams as SSE, and router errors as `404`.
///
/// This function itself never produces `Err`; failures are reported as HTTP
/// responses.
#[cfg(feature = "server")]
async fn handle_http_request<H: A2aHandler>(
    handler: Arc<H>,
    req: hyper::Request<hyper::body::Incoming>,
) -> Result<hyper::Response<BoxBody>, hyper::Error> {
    use http_body_util::BodyExt;

    let method = req.method().clone();
    let path = req.uri().path().to_string();

    // CORS preflight
    if method == hyper::Method::OPTIONS {
        return Ok(json_response(204, String::new()));
    }

    // Agent card discovery (answered before the body is read)
    if method == hyper::Method::GET && path == "/.well-known/agent-card.json" {
        let card = handler.agent_card();
        let body = serde_json::to_string(card).unwrap_or_default();
        return Ok(json_response(200, body));
    }

    // Collect body with size limit. `Limited` reports an oversized body as a
    // boxed `LengthLimitError`. Any other error comes from the connection
    // itself (e.g. the client went away), so it must not be labelled
    // "too large".
    let limited = http_body_util::Limited::new(req.into_body(), MAX_REQUEST_BODY_SIZE);
    let body_bytes = match limited.collect().await {
        Ok(collected) => collected.to_bytes(),
        Err(e) if e.downcast_ref::<http_body_util::LengthLimitError>().is_some() => {
            let err = A2aError::invalid_request("Request body too large");
            let body = serde_json::to_string(&err).unwrap_or_default();
            return Ok(json_response(413, body));
        }
        Err(e) => {
            tracing::debug!("Failed to read request body: {e}");
            let err = A2aError::invalid_request("Failed to read request body");
            let body = serde_json::to_string(&err).unwrap_or_default();
            return Ok(json_response(400, body));
        }
    };

    // JSON-RPC: POST to /
    if method == hyper::Method::POST && path == "/" {
        return handle_jsonrpc(&handler, &body_bytes).await;
    }

    // REST routes
    match rest_router::dispatch_rest(&handler, method.as_str(), &path, &body_bytes).await {
        Ok(rest_router::RestResult::Json(val)) => {
            let body = serde_json::to_string(&val).unwrap_or_default();
            Ok(json_response(200, body))
        }
        Ok(rest_router::RestResult::Stream(stream)) => {
            let sse_stream = sse_response::stream_to_sse_rest(stream);
            Ok(sse_response(sse_stream))
        }
        Err(e) => {
            // NOTE(review): every REST error, including validation or internal
            // failures, is reported as 404; consider deriving the status from
            // the error itself.
            let body = serde_json::to_string(&e).unwrap_or_default();
            Ok(json_response(404, body))
        }
    }
}
281
/// Serialize `err` as a JSON-RPC error response for request `id`.
///
/// JSON-RPC reports errors inside the response body, so the HTTP status is
/// always `200`.
#[cfg(feature = "server")]
fn jsonrpc_error_response(id: RequestId, err: A2aError) -> hyper::Response<BoxBody> {
    let resp = crate::jsonrpc::JsonRpcResponse::error(id, err);
    let body = serde_json::to_string(&resp).unwrap_or_default();
    json_response(200, body)
}

/// Handle one JSON-RPC request read from the body of `POST /`.
///
/// * Unparseable JSON produces a JSON-RPC parse error.
/// * `METHOD_MESSAGE_STREAM` and `METHOD_TASKS_RESUBSCRIBE` decode their
///   params and return the handler's stream as SSE. Decode or handler errors
///   become JSON-RPC error responses.
/// * Every other method goes through `jsonrpc_router::dispatch`, and the
///   result is serialized as JSON.
///
/// All JSON-RPC responses, including errors, use HTTP status 200. This
/// function never returns `Err`.
#[cfg(feature = "server")]
async fn handle_jsonrpc<H: A2aHandler>(
    handler: &Arc<H>,
    body: &bytes::Bytes,
) -> Result<hyper::Response<BoxBody>, hyper::Error> {
    let request: JsonRpcRequest = match serde_json::from_slice(body) {
        Ok(r) => r,
        Err(e) => {
            // NOTE(review): JSON-RPC 2.0 requires `"id": null` when the id
            // cannot be determined; `0` is used here. Confirm whether
            // `RequestId` can represent null.
            return Ok(jsonrpc_error_response(
                RequestId::Number(0),
                A2aError::parse_error(e.to_string()),
            ));
        }
    };

    // Streaming methods. Both branches always return, so `id` and `params`
    // are moved out of `request` instead of being cloned.
    if request.method == METHOD_MESSAGE_STREAM {
        let id = request.id;
        let params = request.params.unwrap_or(serde_json::Value::Null);
        let req: SendMessageRequest = match serde_json::from_value(params) {
            Ok(r) => r,
            Err(e) => return Ok(jsonrpc_error_response(id, A2aError::from(e))),
        };
        return Ok(match handler.on_send_streaming_message(req).await {
            Ok(stream) => sse_response(sse_response::stream_to_sse(id, stream)),
            Err(e) => jsonrpc_error_response(id, e),
        });
    }

    if request.method == METHOD_TASKS_RESUBSCRIBE {
        let id = request.id;
        let params = request.params.unwrap_or(serde_json::Value::Null);
        let req: SubscribeToTaskRequest = match serde_json::from_value(params) {
            Ok(r) => r,
            Err(e) => return Ok(jsonrpc_error_response(id, A2aError::from(e))),
        };
        return Ok(match handler.on_subscribe_to_task(req).await {
            Ok(stream) => sse_response(sse_response::stream_to_sse(id, stream)),
            Err(e) => jsonrpc_error_response(id, e),
        });
    }

    // Non-streaming JSON-RPC
    match jsonrpc_router::dispatch(handler, &request).await {
        Ok(Some(response)) | Err(response) => {
            let body = serde_json::to_string(&response).unwrap_or_default();
            Ok(json_response(200, body))
        }
        // Should not happen: streaming methods are handled above.
        Ok(None) => Ok(jsonrpc_error_response(
            request.id,
            A2aError::internal("Unexpected routing state"),
        )),
    }
}