pulseengine_mcp_transport/
batch.rs

1//! JSON-RPC batch message handling
2
3use crate::{RequestHandler, TransportError, validation::validate_batch};
4use pulseengine_mcp_protocol::{Request, Response};
5use serde_json::Value;
6use tracing::debug;
7
8/// Represents a JSON-RPC message that can be either single or batch
9#[derive(Debug, Clone)]
10pub enum JsonRpcMessage {
11    Single(Value),
12    Batch(Vec<Value>),
13}
14
15/// Represents a processed batch result
16#[derive(Debug)]
17pub struct BatchResult {
18    pub responses: Vec<Response>,
19    pub has_notifications: bool,
20}
21
22impl JsonRpcMessage {
23    /// Parse a JSON string into a `JsonRpcMessage`
24    ///
25    /// # Errors
26    ///
27    /// Returns an error if the JSON is invalid
28    pub fn parse(text: &str) -> Result<Self, serde_json::Error> {
29        let value: Value = serde_json::from_str(text)?;
30
31        if let Some(array) = value.as_array() {
32            Ok(JsonRpcMessage::Batch(array.clone()))
33        } else {
34            Ok(JsonRpcMessage::Single(value))
35        }
36    }
37
38    /// Convert to JSON string
39    ///
40    /// # Errors
41    ///
42    /// Returns an error if serialization fails
43    pub fn to_string(&self) -> Result<String, serde_json::Error> {
44        match self {
45            JsonRpcMessage::Single(value) => serde_json::to_string(value),
46            JsonRpcMessage::Batch(values) => serde_json::to_string(values),
47        }
48    }
49
50    /// Validate the message according to JSON-RPC and MCP specs
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the message is invalid according to JSON-RPC or MCP specifications
55    pub fn validate(&self) -> Result<(), TransportError> {
56        match self {
57            JsonRpcMessage::Single(value) => {
58                crate::validation::validate_jsonrpc_message(value)
59                    .map_err(|e| TransportError::Protocol(e.to_string()))?;
60                Ok(())
61            }
62            JsonRpcMessage::Batch(values) => {
63                if values.is_empty() {
64                    return Err(TransportError::Protocol(
65                        "Batch cannot be empty".to_string(),
66                    ));
67                }
68
69                validate_batch(values).map_err(|e| TransportError::Protocol(e.to_string()))?;
70                Ok(())
71            }
72        }
73    }
74
75    /// Extract requests from the message (filtering out notifications)
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if request extraction fails
80    pub fn extract_requests(&self) -> Result<Vec<Request>, TransportError> {
81        let mut requests = Vec::new();
82
83        match self {
84            JsonRpcMessage::Single(value) => {
85                if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
86                    // Only include if it has an ID (requests, not notifications)
87                    if request.id.is_some() {
88                        requests.push(request);
89                    }
90                }
91            }
92            JsonRpcMessage::Batch(values) => {
93                for value in values {
94                    if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
95                        // Only include if it has an ID (requests, not notifications)
96                        if request.id.is_some() {
97                            requests.push(request);
98                        }
99                    }
100                }
101            }
102        }
103
104        Ok(requests)
105    }
106
107    /// Extract notifications from the message
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if notification extraction fails
112    pub fn extract_notifications(&self) -> Result<Vec<Request>, TransportError> {
113        let mut notifications = Vec::new();
114
115        match self {
116            JsonRpcMessage::Single(value) => {
117                if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
118                    // Only include if it doesn't have an ID (notifications)
119                    if request.id.is_none() {
120                        notifications.push(request);
121                    }
122                }
123            }
124            JsonRpcMessage::Batch(values) => {
125                for value in values {
126                    if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
127                        // Only include if it doesn't have an ID (notifications)
128                        if request.id.is_none() {
129                            notifications.push(request);
130                        }
131                    }
132                }
133            }
134        }
135
136        Ok(notifications)
137    }
138
139    /// Check if this message contains any requests (vs only notifications)
140    pub fn has_requests(&self) -> bool {
141        match self {
142            JsonRpcMessage::Single(value) => {
143                if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
144                    request.id.is_some()
145                } else {
146                    false
147                }
148            }
149            JsonRpcMessage::Batch(values) => values.iter().any(|value| {
150                if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
151                    request.id.is_some()
152                } else {
153                    false
154                }
155            }),
156        }
157    }
158}
159
160/// Process a batch of requests through a handler
161pub async fn process_batch(
162    message: JsonRpcMessage,
163    handler: &RequestHandler,
164) -> Result<Option<JsonRpcMessage>, TransportError> {
165    debug!("Processing batch message");
166
167    // Validate the message first
168    message.validate()?;
169
170    // Extract requests and notifications
171    let requests = message.extract_requests()?;
172    let notifications = message.extract_notifications()?;
173
174    debug!(
175        "Batch contains {} requests and {} notifications",
176        requests.len(),
177        notifications.len()
178    );
179
180    // Process notifications (no response expected)
181    for notification in notifications {
182        debug!("Processing notification: {}", notification.method);
183        let _response = handler(notification).await;
184        // Notifications don't generate responses, so we ignore the result
185    }
186
187    // If no requests, return None (no response needed)
188    if requests.is_empty() {
189        return Ok(None);
190    }
191
192    // Process requests and collect responses
193    let mut responses = Vec::new();
194
195    for request in requests {
196        debug!(
197            "Processing request: {} (ID: {:?})",
198            request.method, request.id
199        );
200        let response = handler(request).await;
201        responses.push(response);
202    }
203
204    // Return appropriate response format
205    let response_message = if responses.len() == 1 && !matches!(message, JsonRpcMessage::Batch(_)) {
206        // Single request, single response
207        let response_value = serde_json::to_value(&responses[0])
208            .map_err(|e| TransportError::Protocol(format!("Failed to serialize response: {e}")))?;
209        JsonRpcMessage::Single(response_value)
210    } else {
211        // Batch response
212        let response_values: Result<Vec<Value>, _> =
213            responses.iter().map(serde_json::to_value).collect();
214
215        let response_values = response_values.map_err(|e| {
216            TransportError::Protocol(format!("Failed to serialize batch response: {e}"))
217        })?;
218
219        JsonRpcMessage::Batch(response_values)
220    };
221
222    Ok(Some(response_message))
223}
224
225/// Create an error response for a malformed request
226pub fn create_error_response(
227    error: pulseengine_mcp_protocol::Error,
228    request_id: Option<pulseengine_mcp_protocol::NumberOrString>,
229) -> Response {
230    Response {
231        jsonrpc: "2.0".to_string(),
232        id: request_id,
233        result: None,
234        error: Some(error),
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
    use serde_json::json;

    // Mock handler for testing: echoes the request id and reports the
    // request's method name as the result.
    fn mock_handler(
        request: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> {
        Box::pin(async move {
            Response {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result: Some(json!({"method": request.method})),
                error: None,
            }
        })
    }

    #[test]
    fn test_jsonrpc_message_parsing() {
        // Single message
        let single_json = r#"{"jsonrpc": "2.0", "method": "test", "id": 1}"#;
        let single_msg = JsonRpcMessage::parse(single_json).unwrap();
        assert!(matches!(single_msg, JsonRpcMessage::Single(_)));

        // Batch message
        let batch_json = r#"[{"jsonrpc": "2.0", "method": "test1", "id": 1}, {"jsonrpc": "2.0", "method": "test2"}]"#;
        let batch_msg = JsonRpcMessage::parse(batch_json).unwrap();
        match batch_msg {
            JsonRpcMessage::Batch(values) => assert_eq!(values.len(), 2),
            JsonRpcMessage::Single(_) => panic!("Expected batch message"),
        }

        // Invalid JSON is reported as an error, not silently accepted
        assert!(JsonRpcMessage::parse("{not json").is_err());
    }

    #[test]
    fn test_to_string_round_trip() {
        let batch_json = r#"[{"jsonrpc": "2.0", "method": "test1", "id": 1}, {"jsonrpc": "2.0", "method": "test2"}]"#;
        let message = JsonRpcMessage::parse(batch_json).unwrap();

        let reparsed = JsonRpcMessage::parse(&message.to_string().unwrap()).unwrap();
        match reparsed {
            JsonRpcMessage::Batch(values) => assert_eq!(values.len(), 2),
            JsonRpcMessage::Single(_) => panic!("Expected batch message after round trip"),
        }
    }

    #[test]
    fn test_extract_requests_and_notifications() {
        let batch_json = r#"[
            {"jsonrpc": "2.0", "method": "request1", "id": 1},
            {"jsonrpc": "2.0", "method": "notification1"},
            {"jsonrpc": "2.0", "method": "request2", "id": 2}
        ]"#;

        let message = JsonRpcMessage::parse(batch_json).unwrap();

        let requests = message.extract_requests().unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].method, "request1");
        assert_eq!(requests[1].method, "request2");

        let notifications = message.extract_notifications().unwrap();
        assert_eq!(notifications.len(), 1);
        assert_eq!(notifications[0].method, "notification1");
    }

    #[test]
    fn test_has_requests() {
        let only_notifications = JsonRpcMessage::parse(
            r#"[{"jsonrpc": "2.0", "method": "n1"}, {"jsonrpc": "2.0", "method": "n2"}]"#,
        )
        .unwrap();
        assert!(!only_notifications.has_requests());

        let mixed = JsonRpcMessage::parse(
            r#"[{"jsonrpc": "2.0", "method": "n1"}, {"jsonrpc": "2.0", "method": "r1", "id": 3}]"#,
        )
        .unwrap();
        assert!(mixed.has_requests());
    }

    #[test]
    fn test_validate_rejects_empty_batch() {
        assert!(JsonRpcMessage::Batch(Vec::new()).validate().is_err());
    }

    #[tokio::test]
    async fn test_process_batch() {
        let handler: RequestHandler = Box::new(mock_handler);

        // Single request: answered with a single (non-array) response
        let single_json = r#"{"jsonrpc": "2.0", "method": "test", "id": 1}"#;
        let single_msg = JsonRpcMessage::parse(single_json).unwrap();

        let result = process_batch(single_msg, &handler).await.unwrap();
        assert!(
            matches!(result, Some(JsonRpcMessage::Single(_))),
            "Expected single response, got {result:?}"
        );

        // Test notification only (should return None)
        let notification_json = r#"{"jsonrpc": "2.0", "method": "test"}"#;
        let notification_msg = JsonRpcMessage::parse(notification_json).unwrap();

        let result = process_batch(notification_msg, &handler).await.unwrap();
        assert!(result.is_none());

        // Test batch with mixed requests and notifications
        let batch_json = r#"[
            {"jsonrpc": "2.0", "method": "request1", "id": 1},
            {"jsonrpc": "2.0", "method": "notification1"},
            {"jsonrpc": "2.0", "method": "request2", "id": 2}
        ]"#;
        let batch_msg = JsonRpcMessage::parse(batch_json).unwrap();

        let result = process_batch(batch_msg, &handler).await.unwrap();
        match result {
            // Only requests generate responses
            Some(JsonRpcMessage::Batch(responses)) => assert_eq!(responses.len(), 2),
            other => panic!("Expected batch response, got {other:?}"),
        }

        // A batch holding one request must still be answered with an array
        let one_element_batch =
            JsonRpcMessage::parse(r#"[{"jsonrpc": "2.0", "method": "only", "id": 7}]"#).unwrap();
        match process_batch(one_element_batch, &handler).await.unwrap() {
            Some(JsonRpcMessage::Batch(responses)) => assert_eq!(responses.len(), 1),
            other => panic!("Expected batch response, got {other:?}"),
        }

        // A batch of only notifications produces no response at all
        let notification_batch = JsonRpcMessage::parse(
            r#"[{"jsonrpc": "2.0", "method": "n1"}, {"jsonrpc": "2.0", "method": "n2"}]"#,
        )
        .unwrap();
        let result = process_batch(notification_batch, &handler).await.unwrap();
        assert!(result.is_none());

        // An empty batch is rejected by validation
        assert!(process_batch(JsonRpcMessage::Batch(Vec::new()), &handler)
            .await
            .is_err());
    }

    #[test]
    fn test_create_error_response() {
        let error = McpError::parse_error("Test error");
        let response = create_error_response(
            error,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(123)),
        );

        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(
            response.id,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(123))
        );
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }
}