pmcp/shared/
batch.rs

1//! Batch request handling for JSON-RPC 2.0.
2//!
3//! This module provides support for processing multiple JSON-RPC requests
4//! in a single batch, as per the JSON-RPC 2.0 specification.
5
6use crate::error::{Error, Result};
7use crate::types::{JSONRPCRequest, JSONRPCResponse};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11/// A batch of JSON-RPC requests.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(untagged)]
14pub enum BatchRequest {
15    /// A single request
16    Single(JSONRPCRequest),
17    /// Multiple requests in a batch
18    Batch(Vec<JSONRPCRequest>),
19}
20
21/// A batch of JSON-RPC responses.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(untagged)]
24pub enum BatchResponse {
25    /// A single response
26    Single(JSONRPCResponse),
27    /// Multiple responses in a batch
28    Batch(Vec<JSONRPCResponse>),
29}
30
31impl BatchRequest {
32    /// Parse a JSON value into a batch request.
33    pub fn from_value(value: Value) -> Result<Self> {
34        serde_json::from_value(value)
35            .map_err(|e| Error::parse(format!("Invalid batch request format: {}", e)))
36    }
37
38    /// Convert the batch request to a JSON value.
39    pub fn to_value(&self) -> Result<Value> {
40        serde_json::to_value(self)
41            .map_err(|e| Error::internal(format!("Failed to serialize batch request: {}", e)))
42    }
43
44    /// Check if this is a batch request (multiple requests).
45    pub fn is_batch(&self) -> bool {
46        matches!(self, Self::Batch(_))
47    }
48
49    /// Get the requests as a vector.
50    pub fn into_requests(self) -> Vec<JSONRPCRequest> {
51        match self {
52            Self::Single(req) => vec![req],
53            Self::Batch(reqs) => reqs,
54        }
55    }
56
57    /// Get the number of requests in the batch.
58    pub fn len(&self) -> usize {
59        match self {
60            Self::Single(_) => 1,
61            Self::Batch(reqs) => reqs.len(),
62        }
63    }
64
65    /// Check if the batch is empty.
66    pub fn is_empty(&self) -> bool {
67        match self {
68            Self::Single(_) => false,
69            Self::Batch(reqs) => reqs.is_empty(),
70        }
71    }
72}
73
74impl BatchResponse {
75    /// Create a batch response from a vector of responses.
76    ///
77    /// # Panics
78    ///
79    /// Panics if the vector has exactly 1 element but `next()` returns `None` (which should never happen).
80    pub fn from_responses(responses: Vec<JSONRPCResponse>) -> Self {
81        match responses.len() {
82            0 => Self::Batch(vec![]),
83            1 => Self::Single(responses.into_iter().next().unwrap()),
84            _ => Self::Batch(responses),
85        }
86    }
87
88    /// Convert the batch response to a JSON value.
89    pub fn to_value(&self) -> Result<Value> {
90        serde_json::to_value(self)
91            .map_err(|e| Error::internal(format!("Failed to serialize batch response: {}", e)))
92    }
93
94    /// Get the responses as a vector.
95    pub fn into_responses(self) -> Vec<JSONRPCResponse> {
96        match self {
97            Self::Single(resp) => vec![resp],
98            Self::Batch(resps) => resps,
99        }
100    }
101}
102
103/// Process a batch of requests.
104///
105/// This function takes a batch request and a handler function, processes each
106/// request (potentially in parallel), and returns a batch response.
107pub async fn process_batch_request<F, Fut>(batch: BatchRequest, handler: F) -> Result<BatchResponse>
108where
109    F: Fn(JSONRPCRequest) -> Fut + Clone + Send + Sync + 'static,
110    Fut: std::future::Future<Output = JSONRPCResponse> + Send + 'static,
111{
112    let requests = batch.into_requests();
113
114    // Empty batch should return empty array
115    if requests.is_empty() {
116        return Ok(BatchResponse::Batch(vec![]));
117    }
118
119    // Process all requests
120    // Process requests in parallel while maintaining order
121    #[cfg(not(target_arch = "wasm32"))]
122    let responses = if requests.len() > 1 {
123        // Use parallel processing for multiple requests
124        let config = crate::utils::parallel_batch::ParallelBatchConfig::default();
125        crate::utils::parallel_batch::process_batch_parallel(requests, handler, config).await?
126    } else {
127        // For single request, process directly
128        let mut responses = Vec::with_capacity(requests.len());
129        for request in requests {
130            let response = handler(request).await;
131            responses.push(response);
132        }
133        responses
134    };
135
136    // For WASM, always process sequentially
137    #[cfg(target_arch = "wasm32")]
138    let responses = {
139        let mut responses = Vec::with_capacity(requests.len());
140        for request in requests {
141            let response = handler(request).await;
142            responses.push(response);
143        }
144        responses
145    };
146
147    Ok(BatchResponse::from_responses(responses))
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{jsonrpc::ResponsePayload, RequestId};
    use serde_json::json;

    /// Build a parameterless request with the given method name and numeric id.
    fn request(method: &str, id: i64) -> JSONRPCRequest {
        JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params: None,
            id: RequestId::from(id),
        }
    }

    /// A bare JSON object parses as a single (non-batch) request.
    #[test]
    fn test_single_request_parsing() {
        let raw = json!({
            "jsonrpc": "2.0",
            "method": "test",
            "params": {"value": 42},
            "id": 1
        });

        let parsed = BatchRequest::from_value(raw).unwrap();
        assert!(!parsed.is_batch());
        assert_eq!(parsed.len(), 1);
    }

    /// A JSON array of two requests parses as a batch of length two.
    #[test]
    fn test_batch_request_parsing() {
        let raw = json!([
            {
                "jsonrpc": "2.0",
                "method": "test1",
                "id": 1
            },
            {
                "jsonrpc": "2.0",
                "method": "test2",
                "id": 2
            }
        ]);

        let parsed = BatchRequest::from_value(raw).unwrap();
        assert!(parsed.is_batch());
        assert_eq!(parsed.len(), 2);
    }

    /// An empty JSON array parses as an empty batch.
    #[test]
    fn test_empty_batch() {
        let parsed = BatchRequest::from_value(json!([])).unwrap();
        assert!(parsed.is_batch());
        assert!(parsed.is_empty());
    }

    /// Processing a two-request batch returns two responses in request order.
    #[tokio::test]
    async fn test_process_batch() {
        let batch = BatchRequest::Batch(vec![request("test1", 1), request("test2", 2)]);

        // Answers each request with a success payload that echoes its method.
        let handler = |req: JSONRPCRequest| async move {
            JSONRPCResponse {
                jsonrpc: "2.0".to_string(),
                id: req.id.clone(),
                payload: ResponsePayload::Result(json!({
                    "method": req.method,
                    "success": true
                })),
            }
        };

        let outcome = process_batch_request(batch, handler).await.unwrap();

        let ids: Vec<RequestId> = outcome
            .into_responses()
            .into_iter()
            .map(|resp| resp.id)
            .collect();
        assert_eq!(ids, vec![RequestId::from(1i64), RequestId::from(2i64)]);
    }
}