// File: a2a_protocol_server/dispatch/jsonrpc.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2026 Tom F.

//! JSON-RPC 2.0 dispatcher.
//!
//! [`JsonRpcDispatcher`] reads JSON-RPC requests (single or batch) from HTTP
//! bodies, routes them to the appropriate [`RequestHandler`] method, and
//! serializes the response (or streams SSE for streaming methods).
9
10use std::convert::Infallible;
11use std::sync::Arc;
12
13use bytes::Bytes;
14use http_body_util::combinators::BoxBody;
15use http_body_util::{BodyExt, Full};
16use hyper::body::Incoming;
17
18use a2a_protocol_types::jsonrpc::{
19    JsonRpcError, JsonRpcErrorResponse, JsonRpcId, JsonRpcRequest, JsonRpcSuccessResponse,
20    JsonRpcVersion,
21};
22
23use crate::dispatch::cors::CorsConfig;
24use crate::error::ServerError;
25use crate::handler::{RequestHandler, SendMessageResult};
26use crate::streaming::build_sse_response;
27
28/// JSON-RPC 2.0 request dispatcher.
29///
30/// Routes incoming JSON-RPC requests to the underlying [`RequestHandler`].
31/// Optionally applies CORS headers to all responses.
32pub struct JsonRpcDispatcher {
33    handler: Arc<RequestHandler>,
34    cors: Option<CorsConfig>,
35}
36
37impl JsonRpcDispatcher {
38    /// Creates a new dispatcher wrapping the given handler.
39    #[must_use]
40    pub const fn new(handler: Arc<RequestHandler>) -> Self {
41        Self {
42            handler,
43            cors: None,
44        }
45    }
46
47    /// Sets CORS configuration for this dispatcher.
48    ///
49    /// When set, all responses will include CORS headers, and `OPTIONS` preflight
50    /// requests will be handled automatically.
51    #[must_use]
52    pub fn with_cors(mut self, cors: CorsConfig) -> Self {
53        self.cors = Some(cors);
54        self
55    }
56
57    /// Dispatches a JSON-RPC request and returns an HTTP response.
58    ///
59    /// For `SendStreamingMessage` and `SubscribeToTask`, the response uses
60    /// SSE (`text/event-stream`). All other methods return JSON.
61    ///
62    /// JSON-RPC errors are always returned as HTTP 200 with an error body.
63    pub async fn dispatch(
64        &self,
65        req: hyper::Request<Incoming>,
66    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
67        // Handle CORS preflight requests.
68        if req.method() == "OPTIONS" {
69            if let Some(ref cors) = self.cors {
70                return cors.preflight_response();
71            }
72            return json_response(204, Vec::new());
73        }
74
75        let mut resp = self.dispatch_inner(req).await;
76        if let Some(ref cors) = self.cors {
77            cors.apply_headers(&mut resp);
78        }
79        resp
80    }
81
82    /// Inner dispatch logic (separated to allow CORS wrapping).
83    #[allow(clippy::too_many_lines)]
84    async fn dispatch_inner(
85        &self,
86        req: hyper::Request<Incoming>,
87    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
88        // Validate Content-Type if present.
89        if let Some(ct) = req.headers().get("content-type") {
90            let ct_str = ct.to_str().unwrap_or("");
91            if !ct_str.starts_with("application/json")
92                && !ct_str.starts_with(a2a_protocol_types::A2A_CONTENT_TYPE)
93            {
94                return parse_error_response(
95                    None,
96                    &format!("unsupported Content-Type: {ct_str}; expected application/json or application/a2a+json"),
97                );
98            }
99        }
100
101        // Read body with size limit (default 4 MiB).
102        let body_bytes = match read_body_limited(req.into_body(), MAX_REQUEST_BODY_SIZE).await {
103            Ok(bytes) => bytes,
104            Err(msg) => return parse_error_response(None, &msg),
105        };
106
107        // JSON-RPC 2.0 §6.3: detect batch (array) vs single (object) request.
108        let raw: serde_json::Value = match serde_json::from_slice(&body_bytes) {
109            Ok(v) => v,
110            Err(e) => return parse_error_response(None, &e.to_string()),
111        };
112
113        if let Some(items) = raw.as_array() {
114            // Batch request: dispatch each element, collect responses.
115            if items.is_empty() {
116                return parse_error_response(None, "empty batch request");
117            }
118            let mut responses: Vec<serde_json::Value> = Vec::with_capacity(items.len());
119            for item in items {
120                let rpc_req: JsonRpcRequest = match serde_json::from_value(item.clone()) {
121                    Ok(r) => r,
122                    Err(e) => {
123                        // Invalid request within batch — return individual parse error.
124                        let err_resp = JsonRpcErrorResponse::new(
125                            None,
126                            JsonRpcError::new(
127                                a2a_protocol_types::error::ErrorCode::ParseError.as_i32(),
128                                format!("Parse error: {e}"),
129                            ),
130                        );
131                        if let Ok(v) = serde_json::to_value(&err_resp) {
132                            responses.push(v);
133                        }
134                        continue;
135                    }
136                };
137                let resp_body = self.dispatch_single_request(&rpc_req).await;
138                if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&resp_body) {
139                    responses.push(v);
140                }
141            }
142            let body = serde_json::to_vec(&responses).unwrap_or_default();
143            json_response(200, body)
144        } else {
145            // Single request.
146            let rpc_req: JsonRpcRequest = match serde_json::from_value(raw) {
147                Ok(r) => r,
148                Err(e) => return parse_error_response(None, &e.to_string()),
149            };
150            self.dispatch_single_request_http(&rpc_req).await
151        }
152    }
153
154    /// Dispatches a single JSON-RPC request and returns an HTTP response.
155    ///
156    /// For streaming methods, the response is SSE. For non-streaming, JSON.
157    #[allow(clippy::too_many_lines)]
158    async fn dispatch_single_request_http(
159        &self,
160        rpc_req: &JsonRpcRequest,
161    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
162        let id = rpc_req.id.clone();
163        trace_info!(method = %rpc_req.method, "dispatching JSON-RPC request");
164
165        // Streaming methods return SSE, not JSON.
166        match rpc_req.method.as_str() {
167            "SendStreamingMessage" => return self.dispatch_send_message(id, rpc_req, true).await,
168            "SubscribeToTask" => {
169                return match parse_params::<a2a_protocol_types::params::TaskIdParams>(rpc_req) {
170                    Ok(p) => match self.handler.on_resubscribe(p).await {
171                        Ok(reader) => build_sse_response(reader, None),
172                        Err(e) => error_response(id, &e),
173                    },
174                    Err(e) => error_response(id, &e),
175                };
176            }
177            _ => {}
178        }
179
180        let body = self.dispatch_single_request(rpc_req).await;
181        json_response(200, body)
182    }
183
184    /// Dispatches a single JSON-RPC request and returns the response body bytes.
185    ///
186    /// Used for both single and batch requests.
187    #[allow(clippy::too_many_lines)]
188    async fn dispatch_single_request(&self, rpc_req: &JsonRpcRequest) -> Vec<u8> {
189        let id = rpc_req.id.clone();
190
191        match rpc_req.method.as_str() {
192            "SendMessage" => {
193                match self
194                    .dispatch_send_message_inner(id.clone(), rpc_req, false)
195                    .await
196                {
197                    Ok(resp) => serde_json::to_vec(&resp).unwrap_or_default(),
198                    Err(body) => body,
199                }
200            }
201            "SendStreamingMessage" => {
202                // In batch context, streaming is not supported — return error.
203                let err = ServerError::InvalidParams(
204                    "SendStreamingMessage not supported in batch requests".into(),
205                );
206                let a2a_err = err.to_a2a_error();
207                let resp = JsonRpcErrorResponse::new(
208                    id,
209                    JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
210                );
211                serde_json::to_vec(&resp).unwrap_or_default()
212            }
213            "GetTask" => match parse_params::<a2a_protocol_types::params::TaskQueryParams>(rpc_req)
214            {
215                Ok(p) => match self.handler.on_get_task(p).await {
216                    Ok(r) => success_response_bytes(id, &r),
217                    Err(e) => error_response_bytes(id, &e),
218                },
219                Err(e) => error_response_bytes(id, &e),
220            },
221            "ListTasks" => {
222                match parse_params::<a2a_protocol_types::params::ListTasksParams>(rpc_req) {
223                    Ok(p) => match self.handler.on_list_tasks(p).await {
224                        Ok(r) => success_response_bytes(id, &r),
225                        Err(e) => error_response_bytes(id, &e),
226                    },
227                    Err(e) => error_response_bytes(id, &e),
228                }
229            }
230            "CancelTask" => {
231                match parse_params::<a2a_protocol_types::params::CancelTaskParams>(rpc_req) {
232                    Ok(p) => match self.handler.on_cancel_task(p).await {
233                        Ok(r) => success_response_bytes(id, &r),
234                        Err(e) => error_response_bytes(id, &e),
235                    },
236                    Err(e) => error_response_bytes(id, &e),
237                }
238            }
239            "SubscribeToTask" => {
240                let err = ServerError::InvalidParams(
241                    "SubscribeToTask not supported in batch requests".into(),
242                );
243                error_response_bytes(id, &err)
244            }
245            "CreateTaskPushNotificationConfig" => {
246                match parse_params::<a2a_protocol_types::push::TaskPushNotificationConfig>(rpc_req)
247                {
248                    Ok(p) => match self.handler.on_set_push_config(p).await {
249                        Ok(r) => success_response_bytes(id, &r),
250                        Err(e) => error_response_bytes(id, &e),
251                    },
252                    Err(e) => error_response_bytes(id, &e),
253                }
254            }
255            "GetTaskPushNotificationConfig" => {
256                match parse_params::<a2a_protocol_types::params::GetPushConfigParams>(rpc_req) {
257                    Ok(p) => match self.handler.on_get_push_config(p).await {
258                        Ok(r) => success_response_bytes(id, &r),
259                        Err(e) => error_response_bytes(id, &e),
260                    },
261                    Err(e) => error_response_bytes(id, &e),
262                }
263            }
264            "ListTaskPushNotificationConfigs" => {
265                match parse_params::<a2a_protocol_types::params::TaskIdParams>(rpc_req) {
266                    Ok(p) => match self.handler.on_list_push_configs(&p.id).await {
267                        Ok(r) => success_response_bytes(id, &r),
268                        Err(e) => error_response_bytes(id, &e),
269                    },
270                    Err(e) => error_response_bytes(id, &e),
271                }
272            }
273            "DeleteTaskPushNotificationConfig" => {
274                match parse_params::<a2a_protocol_types::params::DeletePushConfigParams>(rpc_req) {
275                    Ok(p) => match self.handler.on_delete_push_config(p).await {
276                        Ok(()) => success_response_bytes(id, &serde_json::json!({})),
277                        Err(e) => error_response_bytes(id, &e),
278                    },
279                    Err(e) => error_response_bytes(id, &e),
280                }
281            }
282            "GetExtendedAgentCard" => match self.handler.on_get_extended_agent_card().await {
283                Ok(r) => success_response_bytes(id, &r),
284                Err(e) => error_response_bytes(id, &e),
285            },
286            other => {
287                let err = ServerError::MethodNotFound(other.to_owned());
288                error_response_bytes(id, &err)
289            }
290        }
291    }
292
293    /// Helper for dispatching `SendMessage` that returns either a success response
294    /// value (for batch) or the body bytes on error.
295    async fn dispatch_send_message_inner(
296        &self,
297        id: JsonRpcId,
298        rpc_req: &JsonRpcRequest,
299        streaming: bool,
300    ) -> Result<JsonRpcSuccessResponse<serde_json::Value>, Vec<u8>> {
301        let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(rpc_req) {
302            Ok(p) => p,
303            Err(e) => return Err(error_response_bytes(id, &e)),
304        };
305        match self.handler.on_send_message(params, streaming).await {
306            Ok(SendMessageResult::Response(resp)) => {
307                let result = serde_json::to_value(&resp).unwrap_or(serde_json::Value::Null);
308                Ok(JsonRpcSuccessResponse {
309                    jsonrpc: JsonRpcVersion,
310                    id,
311                    result,
312                })
313            }
314            Ok(SendMessageResult::Stream(_)) => {
315                // Shouldn't happen in non-streaming mode.
316                let err = ServerError::Internal("unexpected stream response".into());
317                Err(error_response_bytes(id, &err))
318            }
319            Err(e) => Err(error_response_bytes(id, &e)),
320        }
321    }
322
323    async fn dispatch_send_message(
324        &self,
325        id: JsonRpcId,
326        rpc_req: &JsonRpcRequest,
327        streaming: bool,
328    ) -> hyper::Response<BoxBody<Bytes, Infallible>> {
329        let params = match parse_params::<a2a_protocol_types::params::MessageSendParams>(rpc_req) {
330            Ok(p) => p,
331            Err(e) => return error_response(id, &e),
332        };
333        match self.handler.on_send_message(params, streaming).await {
334            Ok(SendMessageResult::Response(resp)) => success_response(id, &resp),
335            Ok(SendMessageResult::Stream(reader)) => build_sse_response(reader, None),
336            Err(e) => error_response(id, &e),
337        }
338    }
339}
340
341/// Serializes a success response to bytes (for batch request support).
342fn success_response_bytes<T: serde::Serialize>(id: JsonRpcId, result: &T) -> Vec<u8> {
343    let resp = JsonRpcSuccessResponse {
344        jsonrpc: JsonRpcVersion,
345        id,
346        result: serde_json::to_value(result).unwrap_or(serde_json::Value::Null),
347    };
348    serde_json::to_vec(&resp).unwrap_or_default()
349}
350
351/// Serializes an error response to bytes (for batch request support).
352fn error_response_bytes(id: JsonRpcId, err: &ServerError) -> Vec<u8> {
353    let a2a_err = err.to_a2a_error();
354    let resp = JsonRpcErrorResponse::new(
355        id,
356        JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
357    );
358    serde_json::to_vec(&resp).unwrap_or_default()
359}
360
361impl std::fmt::Debug for JsonRpcDispatcher {
362    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363        f.debug_struct("JsonRpcDispatcher").finish()
364    }
365}
366
367// ── Free functions ───────────────────────────────────────────────────────────
368
369fn parse_params<T: serde::de::DeserializeOwned>(
370    rpc_req: &JsonRpcRequest,
371) -> Result<T, ServerError> {
372    let params = rpc_req
373        .params
374        .as_ref()
375        .ok_or_else(|| ServerError::InvalidParams("missing params".into()))?;
376    serde_json::from_value(params.clone())
377        .map_err(|e| ServerError::InvalidParams(format!("invalid params: {e}")))
378}
379
380fn success_response<T: serde::Serialize>(
381    id: JsonRpcId,
382    result: &T,
383) -> hyper::Response<BoxBody<Bytes, Infallible>> {
384    let resp = JsonRpcSuccessResponse {
385        jsonrpc: JsonRpcVersion,
386        id: id.clone(),
387        result: serde_json::to_value(result).unwrap_or(serde_json::Value::Null),
388    };
389    match serde_json::to_vec(&resp) {
390        Ok(body) => json_response(200, body),
391        Err(e) => internal_serialization_error(id, &e),
392    }
393}
394
395fn error_response(id: JsonRpcId, err: &ServerError) -> hyper::Response<BoxBody<Bytes, Infallible>> {
396    let a2a_err = err.to_a2a_error();
397    let resp = JsonRpcErrorResponse::new(
398        id.clone(),
399        JsonRpcError::new(a2a_err.code.as_i32(), a2a_err.message),
400    );
401    match serde_json::to_vec(&resp) {
402        Ok(body) => json_response(200, body),
403        Err(e) => internal_serialization_error(id, &e),
404    }
405}
406
407fn parse_error_response(
408    id: JsonRpcId,
409    message: &str,
410) -> hyper::Response<BoxBody<Bytes, Infallible>> {
411    let resp = JsonRpcErrorResponse::new(
412        id.clone(),
413        JsonRpcError::new(
414            a2a_protocol_types::error::ErrorCode::ParseError.as_i32(),
415            format!("Parse error: {message}"),
416        ),
417    );
418    match serde_json::to_vec(&resp) {
419        Ok(body) => json_response(200, body),
420        Err(e) => internal_serialization_error(id, &e),
421    }
422}
423
424/// Fallback response when JSON-RPC serialization itself fails.
425fn internal_serialization_error(
426    _id: JsonRpcId,
427    _err: &serde_json::Error,
428) -> hyper::Response<BoxBody<Bytes, Infallible>> {
429    trace_error!(error = %_err, "JSON-RPC response serialization failed");
430    // Hand-craft a minimal JSON-RPC error to avoid further serialization failures.
431    let body = br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"internal serialization error"}}"#;
432    json_response(500, body.to_vec())
433}
434
/// Upper bound on an accepted request body, in bytes: 4 MiB (4 × 2^20).
const MAX_REQUEST_BODY_SIZE: usize = 4 << 20;

/// Longest time allowed to receive a complete request body. This bounds how long
/// a slow-sending client ("slow loris") can hold a request open.
const BODY_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
440
441/// Reads a request body with a size limit and timeout.
442///
443/// Returns an error message if the body exceeds the limit, times out, or cannot be read.
444async fn read_body_limited(body: Incoming, max_size: usize) -> Result<Bytes, String> {
445    // Check Content-Length header upfront if present.
446    let size_hint = <Incoming as hyper::body::Body>::size_hint(&body);
447    if let Some(upper) = size_hint.upper() {
448        if upper > max_size as u64 {
449            return Err(format!(
450                "request body too large: {upper} bytes exceeds {max_size} byte limit"
451            ));
452        }
453    }
454
455    let collected = tokio::time::timeout(BODY_READ_TIMEOUT, body.collect())
456        .await
457        .map_err(|_| "request body read timed out".to_owned())?
458        .map_err(|e| e.to_string())?;
459    let bytes = collected.to_bytes();
460    if bytes.len() > max_size {
461        return Err(format!(
462            "request body too large: {} bytes exceeds {max_size} byte limit",
463            bytes.len()
464        ));
465    }
466    Ok(bytes)
467}
468
469/// Builds a JSON HTTP response with the given status and body.
470fn json_response(status: u16, body: Vec<u8>) -> hyper::Response<BoxBody<Bytes, Infallible>> {
471    hyper::Response::builder()
472        .status(status)
473        .header("content-type", a2a_protocol_types::A2A_CONTENT_TYPE)
474        .header(a2a_protocol_types::A2A_VERSION_HEADER, a2a_protocol_types::A2A_VERSION)
475        .body(Full::new(Bytes::from(body)).boxed())
476        .unwrap_or_else(|_| {
477            // Fallback: plain 500 response if builder fails (should never happen
478            // with valid static header names).
479            hyper::Response::new(
480                Full::new(Bytes::from_static(
481                    br#"{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"response build error"}}"#,
482                ))
483                .boxed(),
484            )
485        })
486}