1mod response;
13
14use std::collections::HashMap;
15use std::convert::Infallible;
16use std::sync::Arc;
17
18use bytes::Bytes;
19use http_body_util::combinators::BoxBody;
20use hyper::body::Incoming;
21
22use a2a_protocol_types::jsonrpc::{
23 JsonRpcError, JsonRpcErrorResponse, JsonRpcId, JsonRpcRequest, JsonRpcSuccessResponse,
24 JsonRpcVersion,
25};
26
27use crate::agent_card::StaticAgentCardHandler;
28use crate::dispatch::cors::CorsConfig;
29use crate::error::ServerError;
30use crate::handler::{RequestHandler, SendMessageResult};
31use crate::serve::Dispatcher;
32use crate::streaming::build_sse_response;
33
34use response::{
35 error_response, error_response_bytes, extract_headers, json_response, parse_error_response,
36 parse_params, read_body_limited, success_response, success_response_bytes,
37};
38
39pub struct JsonRpcDispatcher {
47 handler: Arc<RequestHandler>,
48 card_handler: Option<StaticAgentCardHandler>,
49 cors: Option<CorsConfig>,
50 config: super::DispatchConfig,
51}
52
53impl JsonRpcDispatcher {
54 #[must_use]
57 pub fn new(handler: Arc<RequestHandler>) -> Self {
58 Self::with_config(handler, super::DispatchConfig::default())
59 }
60
61 #[must_use]
63 pub fn with_config(handler: Arc<RequestHandler>, config: super::DispatchConfig) -> Self {
64 let card_handler = handler
65 .agent_card
66 .as_ref()
67 .and_then(|card| StaticAgentCardHandler::new(card).ok());
68 Self {
69 handler,
70 card_handler,
71 cors: None,
72 config,
73 }
74 }
75
76 #[must_use]
81 pub fn with_cors(mut self, cors: CorsConfig) -> Self {
82 self.cors = Some(cors);
83 self
84 }
85
86 pub async fn dispatch(
93 &self,
94 req: hyper::Request<Incoming>,
95 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
96 if req.method() == "OPTIONS" {
98 if let Some(ref cors) = self.cors {
99 return cors.preflight_response();
100 }
101 return json_response(204, Vec::new());
102 }
103
104 if req.method() == "GET" && req.uri().path() == "/.well-known/agent.json" {
107 let mut resp = self.card_handler.as_ref().map_or_else(
108 || json_response(404, br#"{"error":"agent card not configured"}"#.to_vec()),
109 |h| h.handle(&req).map(http_body_util::BodyExt::boxed),
110 );
111 if let Some(ref cors) = self.cors {
112 cors.apply_headers(&mut resp);
113 }
114 return resp;
115 }
116
117 let mut resp = self.dispatch_inner(req).await;
118 if let Some(ref cors) = self.cors {
119 cors.apply_headers(&mut resp);
120 }
121 resp
122 }
123
124 #[allow(clippy::too_many_lines)]
126 async fn dispatch_inner(
127 &self,
128 req: hyper::Request<Incoming>,
129 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
130 if let Some(ct) = req.headers().get("content-type") {
132 let ct_str = ct.to_str().unwrap_or("");
133 if !ct_str.starts_with("application/json")
134 && !ct_str.starts_with(a2a_protocol_types::A2A_CONTENT_TYPE)
135 {
136 return parse_error_response(
137 None,
138 &format!("unsupported Content-Type: {ct_str}; expected application/json or application/a2a+json"),
139 );
140 }
141 }
142
143 let headers = extract_headers(req.headers());
145
146 let body_bytes = match read_body_limited(
148 req.into_body(),
149 self.config.max_request_body_size,
150 self.config.body_read_timeout,
151 )
152 .await
153 {
154 Ok(bytes) => bytes,
155 Err(msg) => return parse_error_response(None, &msg),
156 };
157
158 let raw: serde_json::Value = match serde_json::from_slice(&body_bytes) {
160 Ok(v) => v,
161 Err(e) => return parse_error_response(None, &e.to_string()),
162 };
163
164 if raw.is_array() {
165 let serde_json::Value::Array(items) = raw else {
167 unreachable!()
168 };
169 if items.is_empty() {
170 return parse_error_response(None, "empty batch request");
171 }
172 if items.len() > self.config.max_batch_size {
174 return parse_error_response(
175 None,
176 &format!(
177 "batch too large: {} requests exceeds {} limit",
178 items.len(),
179 self.config.max_batch_size
180 ),
181 );
182 }
183 let mut responses: Vec<serde_json::Value> = Vec::with_capacity(items.len());
184 for item in items {
185 let rpc_req: JsonRpcRequest = match serde_json::from_value(item) {
186 Ok(r) => r,
187 Err(e) => {
188 let err_resp = JsonRpcErrorResponse::new(
190 None,
191 JsonRpcError::new(
192 a2a_protocol_types::error::ErrorCode::ParseError.as_i32(),
193 format!("Parse error: {e}"),
194 ),
195 );
196 if let Ok(v) = serde_json::to_value(&err_resp) {
197 responses.push(v);
198 }
199 continue;
200 }
201 };
202 let resp_body = self.dispatch_single_request(&rpc_req, &headers).await;
203 if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&resp_body) {
204 responses.push(v);
205 }
206 }
207 let body = serde_json::to_vec(&responses).unwrap_or_default();
208 json_response(200, body)
209 } else {
210 let rpc_req: JsonRpcRequest = match serde_json::from_value(raw) {
212 Ok(r) => r,
213 Err(e) => return parse_error_response(None, &e.to_string()),
214 };
215 self.dispatch_single_request_http(&rpc_req, &headers).await
216 }
217 }
218
219 #[allow(clippy::too_many_lines)]
223 async fn dispatch_single_request_http(
224 &self,
225 rpc_req: &JsonRpcRequest,
226 headers: &HashMap<String, String>,
227 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
228 let id = rpc_req.id.clone();
229 trace_info!(method = %rpc_req.method, "dispatching JSON-RPC request");
230
231 match rpc_req.method.as_str() {
233 "SendStreamingMessage" | "message/stream" => {
234 return self.dispatch_send_message(id, rpc_req, true, headers).await;
235 }
236 "SubscribeToTask" | "tasks/subscribe" => {
237 return match parse_params::<a2a_protocol_types::params::TaskIdParams>(rpc_req) {
238 Ok(p) => match self.handler.on_resubscribe(p, Some(headers)).await {
239 Ok(reader) => build_sse_response(
240 reader,
241 Some(self.config.sse_keep_alive_interval),
242 Some(self.config.sse_channel_capacity),
243 ),
244 Err(e) => error_response(id, &e),
245 },
246 Err(e) => error_response(id, &e),
247 };
248 }
249 _ => {}
250 }
251
252 let body = self.dispatch_single_request(rpc_req, headers).await;
253 json_response(200, body)
254 }
255
256 #[allow(clippy::too_many_lines)]
260 async fn dispatch_single_request(
261 &self,
262 rpc_req: &JsonRpcRequest,
263 headers: &HashMap<String, String>,
264 ) -> Vec<u8> {
265 let id = rpc_req.id.clone();
266
267 match rpc_req.method.as_str() {
268 "SendMessage" | "message/send" => {
269 match self
270 .dispatch_send_message_inner(id.clone(), rpc_req, false, headers)
271 .await
272 {
273 Ok(resp) => serde_json::to_vec(&resp).unwrap_or_default(),
274 Err(body) => body,
275 }
276 }
277 "SendStreamingMessage" | "message/stream" => {
278 let err = ServerError::InvalidParams(
280 "SendStreamingMessage not supported in batch requests".into(),
281 );
282 let a2a_err = err.to_a2a_error();
283 let resp = JsonRpcErrorResponse::new(
284 id,
285 JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
286 );
287 serde_json::to_vec(&resp).unwrap_or_default()
288 }
289 "GetTask" | "tasks/get" => {
290 match parse_params::<a2a_protocol_types::params::TaskQueryParams>(rpc_req) {
291 Ok(p) => match self.handler.on_get_task(p, Some(headers)).await {
292 Ok(r) => success_response_bytes(id, &r),
293 Err(e) => error_response_bytes(id, &e),
294 },
295 Err(e) => error_response_bytes(id, &e),
296 }
297 }
298 "ListTasks" | "tasks/list" => {
299 match parse_params::<a2a_protocol_types::params::ListTasksParams>(rpc_req) {
300 Ok(p) => match self.handler.on_list_tasks(p, Some(headers)).await {
301 Ok(r) => success_response_bytes(id, &r),
302 Err(e) => error_response_bytes(id, &e),
303 },
304 Err(e) => error_response_bytes(id, &e),
305 }
306 }
307 "CancelTask" | "tasks/cancel" => {
308 match parse_params::<a2a_protocol_types::params::CancelTaskParams>(rpc_req) {
309 Ok(p) => match self.handler.on_cancel_task(p, Some(headers)).await {
310 Ok(r) => success_response_bytes(id, &r),
311 Err(e) => error_response_bytes(id, &e),
312 },
313 Err(e) => error_response_bytes(id, &e),
314 }
315 }
316 "SubscribeToTask" | "tasks/subscribe" => {
317 let err = ServerError::InvalidParams(
318 "SubscribeToTask not supported in batch requests".into(),
319 );
320 error_response_bytes(id, &err)
321 }
322 "CreateTaskPushNotificationConfig" | "tasks/pushNotificationConfig/set" => {
323 match parse_params::<a2a_protocol_types::push::TaskPushNotificationConfig>(rpc_req)
324 {
325 Ok(p) => match self.handler.on_set_push_config(p, Some(headers)).await {
326 Ok(r) => success_response_bytes(id, &r),
327 Err(e) => error_response_bytes(id, &e),
328 },
329 Err(e) => error_response_bytes(id, &e),
330 }
331 }
332 "GetTaskPushNotificationConfig" | "tasks/pushNotificationConfig/get" => {
333 match parse_params::<a2a_protocol_types::params::GetPushConfigParams>(rpc_req) {
334 Ok(p) => match self.handler.on_get_push_config(p, Some(headers)).await {
335 Ok(r) => success_response_bytes(id, &r),
336 Err(e) => error_response_bytes(id, &e),
337 },
338 Err(e) => error_response_bytes(id, &e),
339 }
340 }
341 "ListTaskPushNotificationConfigs" | "tasks/pushNotificationConfig/list" => {
342 match parse_params::<a2a_protocol_types::params::ListPushConfigsParams>(rpc_req) {
343 Ok(p) => match self
344 .handler
345 .on_list_push_configs(&p.task_id, p.tenant.as_deref(), Some(headers))
346 .await
347 {
348 Ok(configs) => {
349 let resp = a2a_protocol_types::responses::ListPushConfigsResponse {
350 configs,
351 next_page_token: None,
352 };
353 success_response_bytes(id, &resp)
354 }
355 Err(e) => error_response_bytes(id, &e),
356 },
357 Err(e) => error_response_bytes(id, &e),
358 }
359 }
360 "DeleteTaskPushNotificationConfig" | "tasks/pushNotificationConfig/delete" => {
361 match parse_params::<a2a_protocol_types::params::DeletePushConfigParams>(rpc_req) {
362 Ok(p) => match self.handler.on_delete_push_config(p, Some(headers)).await {
363 Ok(()) => success_response_bytes(id, &serde_json::json!({})),
364 Err(e) => error_response_bytes(id, &e),
365 },
366 Err(e) => error_response_bytes(id, &e),
367 }
368 }
369 "GetExtendedAgentCard" | "agent/authenticatedExtendedCard" => {
370 match self.handler.on_get_extended_agent_card(Some(headers)).await {
371 Ok(r) => success_response_bytes(id, &r),
372 Err(e) => error_response_bytes(id, &e),
373 }
374 }
375 other => {
376 let err = ServerError::MethodNotFound(other.to_owned());
377 error_response_bytes(id, &err)
378 }
379 }
380 }
381
382 async fn dispatch_send_message_inner(
385 &self,
386 id: JsonRpcId,
387 rpc_req: &JsonRpcRequest,
388 streaming: bool,
389 headers: &HashMap<String, String>,
390 ) -> Result<JsonRpcSuccessResponse<serde_json::Value>, Vec<u8>> {
391 let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(rpc_req) {
392 Ok(p) => p,
393 Err(e) => return Err(error_response_bytes(id, &e)),
394 };
395 match self
396 .handler
397 .on_send_message(params, streaming, Some(headers))
398 .await
399 {
400 Ok(SendMessageResult::Response(resp)) => {
401 let result = serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null);
402 Ok(JsonRpcSuccessResponse {
403 jsonrpc: JsonRpcVersion,
404 id,
405 result,
406 })
407 }
408 Ok(SendMessageResult::Stream(_)) => {
409 let err = ServerError::Internal("unexpected stream response".into());
411 Err(error_response_bytes(id, &err))
412 }
413 Err(e) => Err(error_response_bytes(id, &e)),
414 }
415 }
416
417 async fn dispatch_send_message(
418 &self,
419 id: JsonRpcId,
420 rpc_req: &JsonRpcRequest,
421 streaming: bool,
422 headers: &HashMap<String, String>,
423 ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
424 let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(rpc_req) {
425 Ok(p) => p,
426 Err(e) => return error_response(id, &e),
427 };
428 match self
429 .handler
430 .on_send_message(params, streaming, Some(headers))
431 .await
432 {
433 Ok(SendMessageResult::Response(resp)) => success_response(id, &resp),
434 Ok(SendMessageResult::Stream(reader)) => build_sse_response(
435 reader,
436 Some(self.config.sse_keep_alive_interval),
437 Some(self.config.sse_channel_capacity),
438 ),
439 Err(e) => error_response(id, &e),
440 }
441 }
442}
443
444impl std::fmt::Debug for JsonRpcDispatcher {
445 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
446 f.debug_struct("JsonRpcDispatcher").finish()
447 }
448}
449
450impl Dispatcher for JsonRpcDispatcher {
453 fn dispatch(
454 &self,
455 req: hyper::Request<Incoming>,
456 ) -> std::pin::Pin<
457 Box<dyn std::future::Future<Output = crate::serve::DispatchResponse> + Send + '_>,
458 > {
459 Box::pin(self.dispatch(req))
460 }
461}