pulseengine_mcp_transport/
batch.rs1use crate::{RequestHandler, TransportError, validation::validate_batch};
4use pulseengine_mcp_protocol::{Request, Response};
5use serde_json::Value;
6use tracing::debug;
7
8#[derive(Debug, Clone)]
10pub enum JsonRpcMessage {
11 Single(Value),
12 Batch(Vec<Value>),
13}
14
15#[derive(Debug)]
17pub struct BatchResult {
18 pub responses: Vec<Response>,
19 pub has_notifications: bool,
20}
21
22impl JsonRpcMessage {
23 pub fn parse(text: &str) -> Result<Self, serde_json::Error> {
29 let value: Value = serde_json::from_str(text)?;
30
31 if let Some(array) = value.as_array() {
32 Ok(JsonRpcMessage::Batch(array.clone()))
33 } else {
34 Ok(JsonRpcMessage::Single(value))
35 }
36 }
37
38 pub fn to_string(&self) -> Result<String, serde_json::Error> {
44 match self {
45 JsonRpcMessage::Single(value) => serde_json::to_string(value),
46 JsonRpcMessage::Batch(values) => serde_json::to_string(values),
47 }
48 }
49
50 pub fn validate(&self) -> Result<(), TransportError> {
56 match self {
57 JsonRpcMessage::Single(value) => {
58 crate::validation::validate_jsonrpc_message(value)
59 .map_err(|e| TransportError::Protocol(e.to_string()))?;
60 Ok(())
61 }
62 JsonRpcMessage::Batch(values) => {
63 if values.is_empty() {
64 return Err(TransportError::Protocol(
65 "Batch cannot be empty".to_string(),
66 ));
67 }
68
69 validate_batch(values).map_err(|e| TransportError::Protocol(e.to_string()))?;
70 Ok(())
71 }
72 }
73 }
74
75 pub fn extract_requests(&self) -> Result<Vec<Request>, TransportError> {
81 let mut requests = Vec::new();
82
83 match self {
84 JsonRpcMessage::Single(value) => {
85 if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
86 if request.id.is_some() {
88 requests.push(request);
89 }
90 }
91 }
92 JsonRpcMessage::Batch(values) => {
93 for value in values {
94 if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
95 if request.id.is_some() {
97 requests.push(request);
98 }
99 }
100 }
101 }
102 }
103
104 Ok(requests)
105 }
106
107 pub fn extract_notifications(&self) -> Result<Vec<Request>, TransportError> {
113 let mut notifications = Vec::new();
114
115 match self {
116 JsonRpcMessage::Single(value) => {
117 if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
118 if request.id.is_none() {
120 notifications.push(request);
121 }
122 }
123 }
124 JsonRpcMessage::Batch(values) => {
125 for value in values {
126 if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
127 if request.id.is_none() {
129 notifications.push(request);
130 }
131 }
132 }
133 }
134 }
135
136 Ok(notifications)
137 }
138
139 pub fn has_requests(&self) -> bool {
141 match self {
142 JsonRpcMessage::Single(value) => {
143 if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
144 request.id.is_some()
145 } else {
146 false
147 }
148 }
149 JsonRpcMessage::Batch(values) => values.iter().any(|value| {
150 if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
151 request.id.is_some()
152 } else {
153 false
154 }
155 }),
156 }
157 }
158}
159
160pub async fn process_batch(
162 message: JsonRpcMessage,
163 handler: &RequestHandler,
164) -> Result<Option<JsonRpcMessage>, TransportError> {
165 debug!("Processing batch message");
166
167 message.validate()?;
169
170 let requests = message.extract_requests()?;
172 let notifications = message.extract_notifications()?;
173
174 debug!(
175 "Batch contains {} requests and {} notifications",
176 requests.len(),
177 notifications.len()
178 );
179
180 for notification in notifications {
182 debug!("Processing notification: {}", notification.method);
183 let _response = handler(notification).await;
184 }
186
187 if requests.is_empty() {
189 return Ok(None);
190 }
191
192 let mut responses = Vec::new();
194
195 for request in requests {
196 debug!(
197 "Processing request: {} (ID: {:?})",
198 request.method, request.id
199 );
200 let response = handler(request).await;
201 responses.push(response);
202 }
203
204 let response_message = if responses.len() == 1 && !matches!(message, JsonRpcMessage::Batch(_)) {
206 let response_value = serde_json::to_value(&responses[0])
208 .map_err(|e| TransportError::Protocol(format!("Failed to serialize response: {e}")))?;
209 JsonRpcMessage::Single(response_value)
210 } else {
211 let response_values: Result<Vec<Value>, _> =
213 responses.iter().map(serde_json::to_value).collect();
214
215 let response_values = response_values.map_err(|e| {
216 TransportError::Protocol(format!("Failed to serialize batch response: {e}"))
217 })?;
218
219 JsonRpcMessage::Batch(response_values)
220 };
221
222 Ok(Some(response_message))
223}
224
225pub fn create_error_response(
227 error: pulseengine_mcp_protocol::Error,
228 request_id: Option<pulseengine_mcp_protocol::NumberOrString>,
229) -> Response {
230 Response {
231 jsonrpc: "2.0".to_string(),
232 id: request_id,
233 result: None,
234 error: Some(error),
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
    use serde_json::json;

    /// Test handler that echoes the request's id back and puts the request's
    /// method name into `result`.
    fn mock_handler(
        request: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> {
        Box::pin(async move {
            Response {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result: Some(json!({"method": request.method})),
                error: None,
            }
        })
    }

    #[test]
    fn test_jsonrpc_message_parsing() {
        let single_json = r#"{"jsonrpc": "2.0", "method": "test", "id": 1}"#;
        let single_msg = JsonRpcMessage::parse(single_json).unwrap();
        assert!(matches!(single_msg, JsonRpcMessage::Single(_)));

        let batch_json = r#"[{"jsonrpc": "2.0", "method": "test1", "id": 1}, {"jsonrpc": "2.0", "method": "test2"}]"#;
        let batch_msg = JsonRpcMessage::parse(batch_json).unwrap();
        assert!(matches!(batch_msg, JsonRpcMessage::Batch(ref items) if items.len() == 2));
    }

    #[test]
    fn test_parse_rejects_invalid_json() {
        assert!(JsonRpcMessage::parse("{not json").is_err());
    }

    #[test]
    fn test_validate_rejects_empty_batch() {
        let empty = JsonRpcMessage::parse("[]").unwrap();
        assert!(matches!(empty, JsonRpcMessage::Batch(ref items) if items.is_empty()));
        assert!(empty.validate().is_err());
    }

    #[test]
    fn test_extract_requests_and_notifications() {
        let batch_json = r#"[
            {"jsonrpc": "2.0", "method": "request1", "id": 1},
            {"jsonrpc": "2.0", "method": "notification1"},
            {"jsonrpc": "2.0", "method": "request2", "id": 2}
        ]"#;

        let message = JsonRpcMessage::parse(batch_json).unwrap();

        let requests = message.extract_requests().unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].method, "request1");
        assert_eq!(requests[1].method, "request2");

        let notifications = message.extract_notifications().unwrap();
        assert_eq!(notifications.len(), 1);
        assert_eq!(notifications[0].method, "notification1");

        assert!(message.has_requests());
    }

    #[test]
    fn test_has_requests_false_for_notification_only() {
        let notification =
            JsonRpcMessage::parse(r#"{"jsonrpc": "2.0", "method": "notify"}"#).unwrap();
        assert!(!notification.has_requests());
    }

    #[tokio::test]
    async fn test_process_batch() {
        let handler: RequestHandler = Box::new(mock_handler);

        // A lone request is answered with a lone response object.
        let single_json = r#"{"jsonrpc": "2.0", "method": "test", "id": 1}"#;
        let single_msg = JsonRpcMessage::parse(single_json).unwrap();
        let result = process_batch(single_msg, &handler).await.unwrap();
        assert!(matches!(result, Some(JsonRpcMessage::Single(_))));

        // A lone notification produces no reply at all.
        let notification_json = r#"{"jsonrpc": "2.0", "method": "test"}"#;
        let notification_msg = JsonRpcMessage::parse(notification_json).unwrap();
        let result = process_batch(notification_msg, &handler).await.unwrap();
        assert!(result.is_none());

        // A mixed batch yields one response per request, in request order.
        let batch_json = r#"[
            {"jsonrpc": "2.0", "method": "request1", "id": 1},
            {"jsonrpc": "2.0", "method": "notification1"},
            {"jsonrpc": "2.0", "method": "request2", "id": 2}
        ]"#;
        let batch_msg = JsonRpcMessage::parse(batch_json).unwrap();
        let result = process_batch(batch_msg, &handler).await.unwrap();

        match result {
            Some(JsonRpcMessage::Batch(responses)) => {
                assert_eq!(responses.len(), 2);
                assert_eq!(responses[0]["result"]["method"], "request1");
                assert_eq!(responses[1]["result"]["method"], "request2");
            }
            other => panic!("Expected batch response, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_process_batch_of_one_still_returns_batch() {
        let handler: RequestHandler = Box::new(mock_handler);

        let batch_msg =
            JsonRpcMessage::parse(r#"[{"jsonrpc": "2.0", "method": "only", "id": 7}]"#).unwrap();
        let result = process_batch(batch_msg, &handler).await.unwrap();

        // Batch input must get batch-shaped output even with a single request.
        assert!(matches!(result, Some(JsonRpcMessage::Batch(ref items)) if items.len() == 1));
    }

    #[test]
    fn test_create_error_response() {
        let error = McpError::parse_error("Test error");
        let response = create_error_response(
            error,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(123)),
        );

        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(
            response.id,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(123))
        );
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }
}