1pub mod grpc_service;
5pub mod handler;
7pub mod jsonrpc_router;
9pub mod rest_router;
11pub mod sse_response;
13
14pub use handler::A2aHandler;
15
16#[cfg(feature = "grpc-server")]
17pub use grpc_service::GrpcBridge;
18
19use std::net::SocketAddr;
20use std::sync::Arc;
21
22use crate::error::A2aError;
23use crate::jsonrpc::{JsonRpcRequest, METHOD_MESSAGE_STREAM, METHOD_TASKS_RESUBSCRIBE, RequestId};
24use crate::params::{SendMessageRequest, SubscribeToTaskRequest};
25
/// Upper bound, in bytes, on a request body the HTTP server will buffer (10 MiB).
///
/// Larger bodies are rejected with `413 Payload Too Large` before any routing happens.
const MAX_REQUEST_BODY_SIZE: usize = 10 << 20;
28
29pub struct A2aServer<H: A2aHandler> {
31 handler: Arc<H>,
32 addr: SocketAddr,
33 #[cfg(feature = "grpc-server")]
34 grpc_addr: Option<SocketAddr>,
35 shutdown: Option<tokio::sync::watch::Receiver<()>>,
36}
37
impl<H: A2aHandler> A2aServer<H> {
    /// Creates a server that will serve `handler` over HTTP on `addr`.
    ///
    /// The handler is wrapped in an `Arc` so it can be shared across connections.
    /// No gRPC address and no shutdown signal are configured by default.
    pub fn new(handler: H, addr: SocketAddr) -> Self {
        Self {
            handler: Arc::new(handler),
            addr,
            #[cfg(feature = "grpc-server")]
            grpc_addr: None,
            shutdown: None,
        }
    }

    /// Additionally serves the A2A gRPC service on `grpc_addr` when [`run`](Self::run)
    /// is called.
    #[cfg(feature = "grpc-server")]
    pub fn with_grpc(mut self, grpc_addr: SocketAddr) -> Self {
        self.grpc_addr = Some(grpc_addr);
        self
    }

    /// Installs a shutdown signal.
    ///
    /// `run` stops accepting HTTP connections and returns `Ok(())` once `rx.changed()`
    /// resolves. This happens when a new value is sent, and also when the sender is
    /// dropped, because `changed()` then returns `Err` and the result is ignored.
    /// A clone of the receiver drives the gRPC server's graceful shutdown.
    pub fn with_shutdown(mut self, rx: tokio::sync::watch::Receiver<()>) -> Self {
        self.shutdown = Some(rx);
        self
    }

    /// Binds the HTTP listener and serves requests until the shutdown signal fires.
    ///
    /// Each accepted TCP connection is served on its own spawned task. The task uses
    /// hyper-util's `auto` connection builder (HTTP/1 or HTTP/2) and routes every
    /// request through `handle_http_request`. When a gRPC address is configured, the
    /// tonic server runs on a separate spawned task.
    ///
    /// # Errors
    ///
    /// Returns an internal `A2aError` in two cases:
    /// - the HTTP listener cannot be bound;
    /// - with `grpc-server`, the gRPC task has already reported an error when it is
    ///   checked, immediately after being spawned.
    ///
    /// Accept errors are logged and the loop continues.
    pub async fn run(self) -> Result<(), A2aError> {
        use hyper::body::Incoming;
        use hyper::service::service_fn;
        use hyper_util::rt::TokioIo;

        let handler = self.handler.clone();
        let listener = tokio::net::TcpListener::bind(self.addr)
            .await
            .map_err(|e| A2aError::internal(format!("Failed to bind: {e}")))?;

        tracing::info!("A2A server listening on {}", self.addr);

        #[cfg(feature = "grpc-server")]
        if let Some(grpc_addr) = self.grpc_addr {
            let grpc_handler = self.handler.clone();
            let shutdown_rx = self.shutdown.clone();
            // Carries the gRPC server's final outcome back to `run`. In practice only
            // an early failure is observed, via the `try_recv` below.
            let (bind_tx, mut bind_rx) = tokio::sync::oneshot::channel::<Result<(), String>>();
            tokio::spawn(async move {
                let bridge = GrpcBridge::new(grpc_handler);
                let svc =
                    crate::proto::lf_a2a_v1::a2a_service_server::A2aServiceServer::new(bridge);
                // NOTE(review): this is logged before tonic has actually bound the address.
                tracing::info!("A2A gRPC server listening on {grpc_addr}");
                let builder = tonic::transport::Server::builder().add_service(svc);
                let result = if let Some(mut rx) = shutdown_rx {
                    builder
                        .serve_with_shutdown(grpc_addr, async move {
                            let _ = rx.changed().await;
                        })
                        .await
                } else {
                    builder.serve(grpc_addr).await
                };
                match result {
                    Ok(()) => {
                        let _ = bind_tx.send(Ok(()));
                    }
                    Err(e) => {
                        let msg = format!("gRPC server error: {e}");
                        tracing::error!("{msg}");
                        let _ = bind_tx.send(Err(msg));
                    }
                }
            });
            // Give the spawned gRPC task a chance to run, so an immediate failure can be
            // returned from `run`.
            // NOTE(review): this is best-effort. A single yield does not guarantee the
            // task has been polled far enough to attempt its bind, especially on a
            // multi-threaded runtime. A failure reported later is only logged by the task.
            tokio::task::yield_now().await;
            if let Ok(Err(msg)) = bind_rx.try_recv() {
                return Err(A2aError::internal(msg));
            }
        }

        // Moved out of `self` only after the gRPC branch above has taken its own clone.
        let mut shutdown = self.shutdown;

        loop {
            // `None` means the shutdown signal fired; `Some` wraps the accept result.
            let accept_result = if let Some(ref mut rx) = shutdown {
                tokio::select! {
                    result = listener.accept() => Some(result),
                    _ = rx.changed() => None,
                }
            } else {
                Some(listener.accept().await)
            };

            let (stream, _peer) = match accept_result {
                None => {
                    // Stops accepting new connections. Connection tasks already spawned
                    // below are neither awaited nor drained here.
                    tracing::info!("A2A server shutting down");
                    return Ok(());
                }
                Some(Ok(conn)) => conn,
                Some(Err(e)) => {
                    // NOTE(review): this retries immediately with no backoff. A persistent
                    // accept error (e.g. file-descriptor exhaustion) could make the loop
                    // spin; consider a short delay.
                    tracing::warn!("Accept error (continuing): {e}");
                    continue;
                }
            };

            let handler = handler.clone();
            tokio::spawn(async move {
                let io = TokioIo::new(stream);
                let svc = service_fn(move |req: hyper::Request<Incoming>| {
                    let handler = handler.clone();
                    async move { handle_http_request(handler, req).await }
                });
                if let Err(e) = hyper_util::server::conn::auto::Builder::new(
                    hyper_util::rt::TokioExecutor::new(),
                )
                .serve_connection(io, svc)
                .await
                {
                    tracing::debug!("Connection error: {e}");
                }
            });
        }
    }
}
159
/// Pinned, boxed, `Send` stream of body frames used to stream Server-Sent Events.
#[cfg(feature = "server")]
type SseFrameStream = std::pin::Pin<
    Box<dyn futures::Stream<Item = Result<http_body::Frame<bytes::Bytes>, std::io::Error>> + Send>,
>;

/// Body type shared by every HTTP response this server produces.
///
/// `Left` holds a fully buffered JSON payload; `Right` streams SSE frames.
#[cfg(feature = "server")]
type BoxBody = http_body_util::Either<
    http_body_util::Full<bytes::Bytes>,
    http_body_util::StreamBody<SseFrameStream>,
>;
176
/// Builds a JSON HTTP response with the given `status` and pre-serialized `body`.
///
/// Every response carries permissive CORS headers, so browser clients can call the
/// server directly:
/// - any origin;
/// - methods `GET`, `POST`, `DELETE` and `OPTIONS`;
/// - the `Content-Type` and `Authorization` request headers.
///
/// This function also produces the empty `204` reply to CORS preflight requests.
///
/// # Panics
///
/// Panics if `status` is not a valid HTTP status code. The callers in this file only
/// pass fixed, valid codes.
#[cfg(feature = "server")]
fn json_response(status: u16, body: String) -> hyper::Response<BoxBody> {
    hyper::Response::builder()
        .status(status)
        .header("Content-Type", "application/json")
        .header("Access-Control-Allow-Origin", "*")
        .header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
        .header(
            "Access-Control-Allow-Headers",
            "Content-Type, Authorization",
        )
        .body(http_body_util::Either::Left(http_body_util::Full::new(
            bytes::Bytes::from(body),
        )))
        .expect("response builder with valid status and headers cannot fail")
}
195
/// Wraps an SSE frame stream in a `200 OK` streaming response.
///
/// The response is marked as `text/event-stream` with caching disabled and a
/// keep-alive connection hint. It carries the same permissive CORS headers as the
/// JSON responses.
#[cfg(feature = "server")]
fn sse_response(
    stream: std::pin::Pin<
        Box<
            dyn futures::Stream<Item = Result<http_body::Frame<bytes::Bytes>, std::io::Error>>
                + Send,
        >,
    >,
) -> hyper::Response<BoxBody> {
    // Applied in this exact order.
    const SSE_HEADERS: [(&str, &str); 6] = [
        ("Content-Type", "text/event-stream"),
        ("Cache-Control", "no-cache"),
        ("Connection", "keep-alive"),
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
    ];

    let streaming_body: BoxBody =
        http_body_util::Either::Right(http_body_util::StreamBody::new(stream));

    SSE_HEADERS
        .iter()
        .fold(
            hyper::Response::builder().status(200),
            |builder, (name, value)| builder.header(*name, *value),
        )
        .body(streaming_body)
        .expect("response builder with valid status and headers cannot fail")
}
221
/// Routes a single HTTP request to the matching A2A transport.
///
/// Routing order:
/// 1. `OPTIONS` on any path returns an empty `204` CORS preflight reply.
/// 2. `GET /.well-known/agent-card.json` returns the handler's agent card as JSON.
/// 3. Otherwise the body is buffered, capped at `MAX_REQUEST_BODY_SIZE` bytes.
///    `POST /` is then dispatched as JSON-RPC, and everything else goes to the REST
///    router.
///
/// Body-read failures are reported as follows:
/// - `413` when the size cap is exceeded;
/// - `400` for any other error while receiving the body, such as an aborted upload.
///
/// Previously, every read error was mislabelled as "too large". All failure paths
/// visible here produce an HTTP response rather than `Err`.
#[cfg(feature = "server")]
async fn handle_http_request<H: A2aHandler>(
    handler: Arc<H>,
    req: hyper::Request<hyper::body::Incoming>,
) -> Result<hyper::Response<BoxBody>, hyper::Error> {
    use http_body_util::BodyExt;

    let method = req.method().clone();
    let path = req.uri().path().to_string();

    // CORS preflight: the CORS headers come from `json_response`.
    if method == hyper::Method::OPTIONS {
        return Ok(json_response(204, String::new()));
    }

    // Agent discovery: served without reading the request body.
    if method == hyper::Method::GET && path == "/.well-known/agent-card.json" {
        let card = handler.agent_card();
        let body = serde_json::to_string(card).unwrap_or_default();
        return Ok(json_response(200, body));
    }

    let limited = http_body_util::Limited::new(req.into_body(), MAX_REQUEST_BODY_SIZE);
    let body_bytes = match limited.collect().await {
        Ok(collected) => collected.to_bytes(),
        // `Limited` reports an oversized body with a boxed `LengthLimitError`.
        Err(e) if e.is::<http_body_util::LengthLimitError>() => {
            let err = A2aError::invalid_request("Request body too large");
            let body = serde_json::to_string(&err).unwrap_or_default();
            return Ok(json_response(413, body));
        }
        // Any other error comes from the underlying connection while receiving the body.
        Err(e) => {
            tracing::debug!("Failed to read request body: {e}");
            let err = A2aError::invalid_request("Failed to read request body");
            let body = serde_json::to_string(&err).unwrap_or_default();
            return Ok(json_response(400, body));
        }
    };

    if method == hyper::Method::POST && path == "/" {
        return handle_jsonrpc(&handler, &body_bytes).await;
    }

    let method_str = method.as_str();
    match rest_router::dispatch_rest(&handler, method_str, &path, &body_bytes).await {
        Ok(rest_router::RestResult::Json(val)) => {
            let body = serde_json::to_string(&val).unwrap_or_default();
            Ok(json_response(200, body))
        }
        Ok(rest_router::RestResult::Stream(stream)) => {
            let sse_stream = sse_response::stream_to_sse_rest(stream);
            Ok(sse_response(sse_stream))
        }
        Err(e) => {
            // NOTE(review): every REST error becomes 404, including what may be
            // invalid-parameter or internal errors. Consider mapping the error kind
            // to a status once the `A2aError` code API is confirmed.
            let body = serde_json::to_string(&e).unwrap_or_default();
            Ok(json_response(404, body))
        }
    }
}
281
/// Serializes a JSON-RPC error envelope for `id` into an HTTP `200` JSON response.
///
/// JSON-RPC reports failures inside the response body, so the HTTP status stays 200.
#[cfg(feature = "server")]
fn jsonrpc_error_response(id: RequestId, err: A2aError) -> hyper::Response<BoxBody> {
    let resp = crate::jsonrpc::JsonRpcResponse::error(id, err);
    let body = serde_json::to_string(&resp).unwrap_or_default();
    json_response(200, body)
}

/// Handles a JSON-RPC request posted to `/`.
///
/// Request handling:
/// - A body that does not deserialize into a `JsonRpcRequest` yields a parse-error
///   envelope.
/// - The streaming methods (`METHOD_MESSAGE_STREAM`, `METHOD_TASKS_RESUBSCRIBE`) are
///   answered with an SSE stream on success, or a JSON-RPC error envelope on failure.
/// - Every other method goes through `jsonrpc_router::dispatch`.
///
/// All responses, errors included, use HTTP 200. The streaming branches move `params`
/// out of the request instead of cloning them, because message payloads can be large.
#[cfg(feature = "server")]
async fn handle_jsonrpc<H: A2aHandler>(
    handler: &Arc<H>,
    body: &bytes::Bytes,
) -> Result<hyper::Response<BoxBody>, hyper::Error> {
    let mut request: JsonRpcRequest = match serde_json::from_slice(body) {
        Ok(r) => r,
        Err(e) => {
            // NOTE(review): JSON-RPC 2.0 requires `id: null` when the id could not be
            // determined, and a different code (-32600) for valid JSON with the wrong
            // shape. Both cases land here with id 0. Confirm whether `RequestId` can
            // express null.
            return Ok(jsonrpc_error_response(
                RequestId::Number(0),
                A2aError::parse_error(e.to_string()),
            ));
        }
    };

    if request.method == METHOD_MESSAGE_STREAM {
        let id = request.id.clone();
        let params = request.params.take().unwrap_or(serde_json::Value::Null);
        let req: SendMessageRequest = match serde_json::from_value(params) {
            Ok(r) => r,
            Err(e) => return Ok(jsonrpc_error_response(id, A2aError::from(e))),
        };
        return Ok(match handler.on_send_streaming_message(req).await {
            Ok(stream) => sse_response(sse_response::stream_to_sse(id, stream)),
            Err(e) => jsonrpc_error_response(id, e),
        });
    }

    if request.method == METHOD_TASKS_RESUBSCRIBE {
        let id = request.id.clone();
        let params = request.params.take().unwrap_or(serde_json::Value::Null);
        let req: SubscribeToTaskRequest = match serde_json::from_value(params) {
            Ok(r) => r,
            Err(e) => return Ok(jsonrpc_error_response(id, A2aError::from(e))),
        };
        return Ok(match handler.on_subscribe_to_task(req).await {
            Ok(stream) => sse_response(sse_response::stream_to_sse(id, stream)),
            Err(e) => jsonrpc_error_response(id, e),
        });
    }

    let response = match jsonrpc_router::dispatch(handler, &request).await {
        Ok(Some(resp)) => resp,
        // `Ok(None)` should not occur for the non-streaming methods that reach this point.
        Ok(None) => crate::jsonrpc::JsonRpcResponse::error(
            request.id,
            A2aError::internal("Unexpected routing state"),
        ),
        Err(resp) => resp,
    };

    let body = serde_json::to_string(&response).unwrap_or_default();
    Ok(json_response(200, body))
}