1use std::convert::Infallible;
11use std::sync::Arc;
12
13use bytes::Bytes;
14use http_body_util::combinators::BoxBody;
15use http_body_util::{BodyExt, Full};
16use hyper::body::Incoming;
17
18use a2a_protocol_types::jsonrpc::{
19 JsonRpcError, JsonRpcErrorResponse, JsonRpcId, JsonRpcRequest, JsonRpcSuccessResponse,
20 JsonRpcVersion,
21};
22
23use crate::dispatch::cors::CorsConfig;
24use crate::error::ServerError;
25use crate::handler::{RequestHandler, SendMessageResult};
26use crate::streaming::build_sse_response;
27
28pub struct JsonRpcDispatcher {
33 handler: Arc<RequestHandler>,
34 cors: Option<CorsConfig>,
35}
36
37impl JsonRpcDispatcher {
38 #[must_use]
40 pub const fn new(handler: Arc<RequestHandler>) -> Self {
41 Self {
42 handler,
43 cors: None,
44 }
45 }
46
47 #[must_use]
52 pub fn with_cors(mut self, cors: CorsConfig) -> Self {
53 self.cors = Some(cors);
54 self
55 }
56
57 pub async fn dispatch(
64 &self,
65 req: hyper::Request<Incoming>,
66 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
67 if req.method() == "OPTIONS" {
69 if let Some(ref cors) = self.cors {
70 return cors.preflight_response();
71 }
72 return json_response(204, Vec::new());
73 }
74
75 let mut resp = self.dispatch_inner(req).await;
76 if let Some(ref cors) = self.cors {
77 cors.apply_headers(&mut resp);
78 }
79 resp
80 }
81
82 #[allow(clippy::too_many_lines)]
84 async fn dispatch_inner(
85 &self,
86 req: hyper::Request<Incoming>,
87 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
88 if let Some(ct) = req.headers().get("content-type") {
90 let ct_str = ct.to_str().unwrap_or("");
91 if !ct_str.starts_with("application/json")
92 && !ct_str.starts_with(a2a_protocol_types::A2A_CONTENT_TYPE)
93 {
94 return parse_error_response(
95 None,
96 &format!("unsupported Content-Type: {ct_str}; expected application/json or application/a2a+json"),
97 );
98 }
99 }
100
101 let body_bytes = match read_body_limited(req.into_body(), MAX_REQUEST_BODY_SIZE).await {
103 Ok(bytes) => bytes,
104 Err(msg) => return parse_error_response(None, &msg),
105 };
106
107 let raw: serde_json::Value = match serde_json::from_slice(&body_bytes) {
109 Ok(v) => v,
110 Err(e) => return parse_error_response(None, &e.to_string()),
111 };
112
113 if let Some(items) = raw.as_array() {
114 if items.is_empty() {
116 return parse_error_response(None, "empty batch request");
117 }
118 let mut responses: Vec<serde_json::Value> = Vec::with_capacity(items.len());
119 for item in items {
120 let rpc_req: JsonRpcRequest = match serde_json::from_value(item.clone()) {
121 Ok(r) => r,
122 Err(e) => {
123 let err_resp = JsonRpcErrorResponse::new(
125 None,
126 JsonRpcError::new(
127 a2a_protocol_types::error::ErrorCode::ParseError.as_i32(),
128 format!("Parse error: {e}"),
129 ),
130 );
131 if let Ok(v) = serde_json::to_value(&err_resp) {
132 responses.push(v);
133 }
134 continue;
135 }
136 };
137 let resp_body = self.dispatch_single_request(&rpc_req).await;
138 if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&resp_body) {
139 responses.push(v);
140 }
141 }
142 let body = serde_json::to_vec(&responses).unwrap_or_default();
143 json_response(200, body)
144 } else {
145 let rpc_req: JsonRpcRequest = match serde_json::from_value(raw) {
147 Ok(r) => r,
148 Err(e) => return parse_error_response(None, &e.to_string()),
149 };
150 self.dispatch_single_request_http(&rpc_req).await
151 }
152 }
153
154 #[allow(clippy::too_many_lines)]
158 async fn dispatch_single_request_http(
159 &self,
160 rpc_req: &JsonRpcRequest,
161 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
162 let id = rpc_req.id.clone();
163 trace_info!(method = %rpc_req.method, "dispatching JSON-RPC request");
164
165 match rpc_req.method.as_str() {
167 "SendStreamingMessage" => return self.dispatch_send_message(id, rpc_req, true).await,
168 "SubscribeToTask" => {
169 return match parse_params::<a2a_protocol_types::params::TaskIdParams>(rpc_req) {
170 Ok(p) => match self.handler.on_resubscribe(p).await {
171 Ok(reader) => build_sse_response(reader, None),
172 Err(e) => error_response(id, &e),
173 },
174 Err(e) => error_response(id, &e),
175 };
176 }
177 _ => {}
178 }
179
180 let body = self.dispatch_single_request(rpc_req).await;
181 json_response(200, body)
182 }
183
184 #[allow(clippy::too_many_lines)]
188 async fn dispatch_single_request(&self, rpc_req: &JsonRpcRequest) -> Vec<u8> {
189 let id = rpc_req.id.clone();
190
191 match rpc_req.method.as_str() {
192 "SendMessage" => {
193 match self
194 .dispatch_send_message_inner(id.clone(), rpc_req, false)
195 .await
196 {
197 Ok(resp) => serde_json::to_vec(&resp).unwrap_or_default(),
198 Err(body) => body,
199 }
200 }
201 "SendStreamingMessage" => {
202 let err = ServerError::InvalidParams(
204 "SendStreamingMessage not supported in batch requests".into(),
205 );
206 let a2a_err = err.to_a2a_error();
207 let resp = JsonRpcErrorResponse::new(
208 id,
209 JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
210 );
211 serde_json::to_vec(&resp).unwrap_or_default()
212 }
213 "GetTask" => match parse_params::<a2a_protocol_types::params::TaskQueryParams>(rpc_req)
214 {
215 Ok(p) => match self.handler.on_get_task(p).await {
216 Ok(r) => success_response_bytes(id, &r),
217 Err(e) => error_response_bytes(id, &e),
218 },
219 Err(e) => error_response_bytes(id, &e),
220 },
221 "ListTasks" => {
222 match parse_params::<a2a_protocol_types::params::ListTasksParams>(rpc_req) {
223 Ok(p) => match self.handler.on_list_tasks(p).await {
224 Ok(r) => success_response_bytes(id, &r),
225 Err(e) => error_response_bytes(id, &e),
226 },
227 Err(e) => error_response_bytes(id, &e),
228 }
229 }
230 "CancelTask" => {
231 match parse_params::<a2a_protocol_types::params::CancelTaskParams>(rpc_req) {
232 Ok(p) => match self.handler.on_cancel_task(p).await {
233 Ok(r) => success_response_bytes(id, &r),
234 Err(e) => error_response_bytes(id, &e),
235 },
236 Err(e) => error_response_bytes(id, &e),
237 }
238 }
239 "SubscribeToTask" => {
240 let err = ServerError::InvalidParams(
241 "SubscribeToTask not supported in batch requests".into(),
242 );
243 error_response_bytes(id, &err)
244 }
245 "CreateTaskPushNotificationConfig" => {
246 match parse_params::<a2a_protocol_types::push::TaskPushNotificationConfig>(rpc_req)
247 {
248 Ok(p) => match self.handler.on_set_push_config(p).await {
249 Ok(r) => success_response_bytes(id, &r),
250 Err(e) => error_response_bytes(id, &e),
251 },
252 Err(e) => error_response_bytes(id, &e),
253 }
254 }
255 "GetTaskPushNotificationConfig" => {
256 match parse_params::<a2a_protocol_types::params::GetPushConfigParams>(rpc_req) {
257 Ok(p) => match self.handler.on_get_push_config(p).await {
258 Ok(r) => success_response_bytes(id, &r),
259 Err(e) => error_response_bytes(id, &e),
260 },
261 Err(e) => error_response_bytes(id, &e),
262 }
263 }
264 "ListTaskPushNotificationConfigs" => {
265 match parse_params::<a2a_protocol_types::params::TaskIdParams>(rpc_req) {
266 Ok(p) => match self.handler.on_list_push_configs(&p.id).await {
267 Ok(r) => success_response_bytes(id, &r),
268 Err(e) => error_response_bytes(id, &e),
269 },
270 Err(e) => error_response_bytes(id, &e),
271 }
272 }
273 "DeleteTaskPushNotificationConfig" => {
274 match parse_params::<a2a_protocol_types::params::DeletePushConfigParams>(rpc_req) {
275 Ok(p) => match self.handler.on_delete_push_config(p).await {
276 Ok(()) => success_response_bytes(id, &serde_json::json!({})),
277 Err(e) => error_response_bytes(id, &e),
278 },
279 Err(e) => error_response_bytes(id, &e),
280 }
281 }
282 "GetExtendedAgentCard" => match self.handler.on_get_extended_agent_card().await {
283 Ok(r) => success_response_bytes(id, &r),
284 Err(e) => error_response_bytes(id, &e),
285 },
286 other => {
287 let err = ServerError::MethodNotFound(other.to_owned());
288 error_response_bytes(id, &err)
289 }
290 }
291 }
292
293 async fn dispatch_send_message_inner(
296 &self,
297 id: JsonRpcId,
298 rpc_req: &JsonRpcRequest,
299 streaming: bool,
300 ) -> Result<JsonRpcSuccessResponse<serde_json::Value>, Vec<u8>> {
301 let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(rpc_req) {
302 Ok(p) => p,
303 Err(e) => return Err(error_response_bytes(id, &e)),
304 };
305 match self.handler.on_send_message(params, streaming).await {
306 Ok(SendMessageResult::Response(resp)) => {
307 let result = serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null);
308 Ok(JsonRpcSuccessResponse {
309 jsonrpc: JsonRpcVersion,
310 id,
311 result,
312 })
313 }
314 Ok(SendMessageResult::Stream(_)) => {
315 let err = ServerError::Internal("unexpected stream response".into());
317 Err(error_response_bytes(id, &err))
318 }
319 Err(e) => Err(error_response_bytes(id, &e)),
320 }
321 }
322
323 async fn dispatch_send_message(
324 &self,
325 id: JsonRpcId,
326 rpc_req: &JsonRpcRequest,
327 streaming: bool,
328 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
329 let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(rpc_req) {
330 Ok(p) => p,
331 Err(e) => return error_response(id, &e),
332 };
333 match self.handler.on_send_message(params, streaming).await {
334 Ok(SendMessageResult::Response(resp)) => success_response(id, &resp),
335 Ok(SendMessageResult::Stream(reader)) => build_sse_response(reader, None),
336 Err(e) => error_response(id, &e),
337 }
338 }
339}
340
341fn success_response_bytes<T: serde::Serialize>(id: JsonRpcId, result: &T) -> Vec<u8> {
343 let resp = JsonRpcSuccessResponse {
344 jsonrpc: JsonRpcVersion,
345 id,
346 result: serde_json::to_value(result).unwrap_or(serde_json::Value::Null),
347 };
348 serde_json::to_vec(&resp).unwrap_or_default()
349}
350
351fn error_response_bytes(id: JsonRpcId, err: &ServerError) -> Vec<u8> {
353 let a2a_err = err.to_a2a_error();
354 let resp = JsonRpcErrorResponse::new(
355 id,
356 JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
357 );
358 serde_json::to_vec(&resp).unwrap_or_default()
359}
360
361impl std::fmt::Debug for JsonRpcDispatcher {
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 f.debug_struct("JsonRpcDispatcher").finish()
364 }
365}
366
367fn parse_params<T: serde::de::DeserializeOwned>(
370 rpc_req: &JsonRpcRequest,
371) -> Result<T, ServerError> {
372 let params = rpc_req
373 .params
374 .as_ref()
375 .ok_or_else(|| ServerError::InvalidParams("missing params".into()))?;
376 serde_json::from_value(params.clone())
377 .map_err(|e| ServerError::InvalidParams(format!("invalid params: {e}")))
378}
379
380fn success_response<T: serde::Serialize>(
381 id: JsonRpcId,
382 result: &T,
383) -> hyper::Response<BoxBody<Bytes, Infallible>> {
384 let resp = JsonRpcSuccessResponse {
385 jsonrpc: JsonRpcVersion,
386 id: id.clone(),
387 result: serde_json::to_value(result).unwrap_or(serde_json::Value::Null),
388 };
389 match serde_json::to_vec(&resp) {
390 Ok(body) => json_response(200, body),
391 Err(e) => internal_serialization_error(id, &e),
392 }
393}
394
395fn error_response(id: JsonRpcId, err: &ServerError) -> hyper::Response<BoxBody<Bytes, Infallible>> {
396 let a2a_err = err.to_a2a_error();
397 let resp = JsonRpcErrorResponse::new(
398 id.clone(),
399 JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
400 );
401 match serde_json::to_vec(&resp) {
402 Ok(body) => json_response(200, body),
403 Err(e) => internal_serialization_error(id, &e),
404 }
405}
406
407fn parse_error_response(
408 id: JsonRpcId,
409 message: &str,
410) -> hyper::Response<BoxBody<Bytes, Infallible>> {
411 let resp = JsonRpcErrorResponse::new(
412 id.clone(),
413 JsonRpcError::new(
414 a2a_protocol_types::error::ErrorCode::ParseError.as_i32(),
415 format!("Parse error: {message}"),
416 ),
417 );
418 match serde_json::to_vec(&resp) {
419 Ok(body) => json_response(200, body),
420 Err(e) => internal_serialization_error(id, &e),
421 }
422}
423
424fn internal_serialization_error(
426 _id: JsonRpcId,
427 _err: &serde_json::Error,
428) -> hyper::Response<BoxBody<Bytes, Infallible>> {
429 trace_error!(error = %_err, "JSON-RPC response serialization failed");
430 let body = br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"internal serialization error"}}"#;
432 json_response(500, body.to_vec())
433}
434
/// Largest request body, in bytes, that the dispatcher will accept (4 MiB).
const MAX_REQUEST_BODY_SIZE: usize = 4 << 20;

/// Maximum time allowed for receiving the complete request body.
const BODY_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
440
441async fn read_body_limited(body: Incoming, max_size: usize) -> Result<Bytes, String> {
445 let size_hint = <Incoming as hyper::body::Body>::size_hint(&body);
447 if let Some(upper) = size_hint.upper() {
448 if upper > max_size as u64 {
449 return Err(format!(
450 "request body too large: {upper} bytes exceeds {max_size} byte limit"
451 ));
452 }
453 }
454
455 let collected = tokio::time::timeout(BODY_READ_TIMEOUT, body.collect())
456 .await
457 .map_err(|_| "request body read timed out".to_owned())?
458 .map_err(|e| e.to_string())?;
459 let bytes = collected.to_bytes();
460 if bytes.len() > max_size {
461 return Err(format!(
462 "request body too large: {} bytes exceeds {max_size} byte limit",
463 bytes.len()
464 ));
465 }
466 Ok(bytes)
467}
468
469fn json_response(status: u16, body: Vec<u8>) -> hyper::Response<BoxBody<Bytes, Infallible>> {
471 hyper::Response::builder()
472 .status(status)
473 .header("content-type", a2a_protocol_types::A2A_CONTENT_TYPE)
474 .header(a2a_protocol_types::A2A_VERSION_HEADER, a2a_protocol_types::A2A_VERSION)
475 .body(Full::new(Bytes::from(body)).boxed())
476 .unwrap_or_else(|_| {
477 hyper::Response::new(
480 Full::new(Bytes::from_static(
481 br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"response build error"}}"#,
482 ))
483 .boxed(),
484 )
485 })
486}