// pulseengine_mcp_transport/validation.rs

1//! Message validation utilities for MCP transport compliance
2
3use serde_json::Value;
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum ValidationError {
8    #[error("Message contains embedded newlines")]
9    EmbeddedNewlines,
10
11    #[error("Message is not valid UTF-8: {0}")]
12    InvalidUtf8(String),
13
14    #[error("Request ID cannot be null")]
15    NullRequestId,
16
17    #[error("Notification cannot have an ID")]
18    NotificationWithId,
19
20    #[error("Message exceeds maximum size: {size} > {max}")]
21    MessageTooLarge { size: usize, max: usize },
22
23    #[error("Invalid JSON-RPC format: {0}")]
24    InvalidFormat(String),
25}
26
27/// Validates a raw message string for MCP compliance
28pub fn validate_message_string(
29    message: &str,
30    max_size: Option<usize>,
31) -> Result<(), ValidationError> {
32    // Check for embedded newlines (MCP spec requirement)
33    if message.contains('\n') || message.contains('\r') {
34        return Err(ValidationError::EmbeddedNewlines);
35    }
36
37    // Check message size limit
38    if let Some(max) = max_size {
39        if message.len() > max {
40            return Err(ValidationError::MessageTooLarge {
41                size: message.len(),
42                max,
43            });
44        }
45    }
46
47    // UTF-8 validation is implicit in Rust strings, but we validate the bytes
48    if !message.is_ascii() {
49        // For non-ASCII, ensure it's valid UTF-8 by checking byte validity
50        if let Err(e) = std::str::from_utf8(message.as_bytes()) {
51            return Err(ValidationError::InvalidUtf8(e.to_string()));
52        }
53    }
54
55    Ok(())
56}
57
58/// Validates JSON-RPC message structure and ID requirements
59pub fn validate_jsonrpc_message(value: &Value) -> Result<MessageType, ValidationError> {
60    let obj = value.as_object().ok_or_else(|| {
61        ValidationError::InvalidFormat("Message must be a JSON object".to_string())
62    })?;
63
64    // Check for required jsonrpc field
65    if obj.get("jsonrpc").and_then(|v| v.as_str()) != Some("2.0") {
66        return Err(ValidationError::InvalidFormat(
67            "Missing or invalid jsonrpc field".to_string(),
68        ));
69    }
70
71    // Determine message type and validate ID requirements
72    if obj.contains_key("method") {
73        // This is a request or notification
74        // Validate method field
75        let method = obj
76            .get("method")
77            .and_then(|v| v.as_str())
78            .ok_or_else(|| ValidationError::InvalidFormat("Method must be a string".to_string()))?;
79
80        if method.is_empty() {
81            return Err(ValidationError::InvalidFormat(
82                "Method cannot be empty".to_string(),
83            ));
84        }
85
86        let has_id = obj.contains_key("id");
87        let id_value = obj.get("id");
88
89        if has_id {
90            // Request: ID cannot be null
91            if id_value == Some(&Value::Null) {
92                return Err(ValidationError::NullRequestId);
93            }
94            Ok(MessageType::Request)
95        } else {
96            // Notification: should not have ID
97            Ok(MessageType::Notification)
98        }
99    } else if obj.contains_key("result") || obj.contains_key("error") {
100        // Response: must have ID
101        if !obj.contains_key("id") {
102            return Err(ValidationError::InvalidFormat(
103                "Response must have an ID".to_string(),
104            ));
105        }
106        Ok(MessageType::Response)
107    } else {
108        Err(ValidationError::InvalidFormat(
109            "Unknown message type".to_string(),
110        ))
111    }
112}
113
114/// Attempts to extract ID from a malformed JSON request for error responses
115pub fn extract_id_from_malformed(text: &str) -> Option<pulseengine_mcp_protocol::NumberOrString> {
116    use pulseengine_mcp_protocol::NumberOrString;
117
118    // Try to parse as JSON object and extract ID
119    if let Ok(value) = serde_json::from_str::<Value>(text) {
120        if let Some(obj) = value.as_object() {
121            if let Some(id) = obj.get("id") {
122                return NumberOrString::from_json_value(id.clone());
123            }
124        }
125    }
126
127    // Try regex-based extraction as fallback
128    if let Some(id_match) = extract_id_with_regex(text) {
129        return NumberOrString::from_json_value(id_match);
130    }
131
132    // Default to None if we can't extract
133    None
134}
135
136/// Validates a batch of JSON-RPC messages
137pub fn validate_batch(batch: &[Value]) -> Result<Vec<MessageType>, ValidationError> {
138    if batch.is_empty() {
139        return Err(ValidationError::InvalidFormat(
140            "Batch cannot be empty".to_string(),
141        ));
142    }
143
144    let mut types = Vec::new();
145    for message in batch {
146        types.push(validate_jsonrpc_message(message)?);
147    }
148
149    Ok(types)
150}
151
/// Kind of a single JSON-RPC 2.0 message, as classified by [`validate_jsonrpc_message`].
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum MessageType {
    /// Has a `method` member and an `id` member.
    Request,
    /// Has a `result` or `error` member together with an `id` member.
    Response,
    /// Has a `method` member but no `id` member.
    Notification,
}
159
160/// Regex-based ID extraction for malformed JSON (fallback)
161fn extract_id_with_regex(text: &str) -> Option<Value> {
162    use regex::Regex;
163
164    // Try to match common ID patterns
165    let patterns = [
166        r#""id"\s*:\s*"([^"]+)""#, // String ID
167        r#""id"\s*:\s*(\d+)"#,     // Number ID
168        r#""id"\s*:\s*(null)"#,    // Null ID
169    ];
170
171    for pattern in &patterns {
172        if let Ok(re) = Regex::new(pattern) {
173            if let Some(captures) = re.captures(text) {
174                if let Some(id_str) = captures.get(1) {
175                    let id_text = id_str.as_str();
176
177                    // Try to parse as number first
178                    if let Ok(num) = id_text.parse::<i64>() {
179                        return Some(Value::Number(num.into()));
180                    }
181
182                    // Check for null
183                    if id_text == "null" {
184                        return Some(Value::Null);
185                    }
186
187                    // Default to string
188                    return Some(Value::String(id_text.to_string()));
189                }
190            }
191        }
192    }
193
194    None
195}
196
197/// Validates a JSON-RPC message from string input
198pub fn validate_json_rpc_message(message: &str) -> Result<MessageType, ValidationError> {
199    // First validate the message string
200    validate_message_string(message, None)?;
201
202    // Parse as JSON
203    let value = serde_json::from_str(message)
204        .map_err(|e| ValidationError::InvalidFormat(format!("Invalid JSON: {e}")))?;
205
206    // Validate JSON-RPC structure
207    validate_jsonrpc_message(&value)
208}
209
210/// Validates a JSON-RPC batch from string input
211pub fn validate_json_rpc_batch(batch_str: &str) -> Result<Vec<MessageType>, ValidationError> {
212    // First validate the message string
213    validate_message_string(batch_str, None)?;
214
215    // Parse as JSON array
216    let batch_value = serde_json::from_str::<Value>(batch_str)
217        .map_err(|e| ValidationError::InvalidFormat(format!("Invalid JSON: {e}")))?;
218
219    let batch_array = batch_value
220        .as_array()
221        .ok_or_else(|| ValidationError::InvalidFormat("Batch must be an array".to_string()))?;
222
223    if batch_array.is_empty() {
224        return Err(ValidationError::InvalidFormat(
225            "Empty batch not allowed".to_string(),
226        ));
227    }
228
229    // Validate each message in the batch
230    validate_batch(batch_array)
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use pulseengine_mcp_protocol::NumberOrString;
    use serde_json::json;

    #[test]
    fn test_validate_message_string() {
        // A plain single-line message passes.
        assert!(validate_message_string("hello world", None).is_ok());

        // Both LF and CR count as embedded newlines.
        for framed in ["hello\nworld", "hello\rworld"] {
            assert!(matches!(
                validate_message_string(framed, None),
                Err(ValidationError::EmbeddedNewlines)
            ));
        }

        // An 11-byte message exceeds a 5-byte limit.
        assert!(matches!(
            validate_message_string("hello world", Some(5)),
            Err(ValidationError::MessageTooLarge { .. })
        ));
    }

    #[test]
    fn test_validate_jsonrpc_message() {
        let cases = [
            (
                json!({"jsonrpc": "2.0", "method": "test", "id": 1}),
                MessageType::Request,
            ),
            (
                json!({"jsonrpc": "2.0", "method": "test"}),
                MessageType::Notification,
            ),
            (
                json!({"jsonrpc": "2.0", "result": "ok", "id": 1}),
                MessageType::Response,
            ),
        ];
        for (message, expected) in cases {
            assert_eq!(validate_jsonrpc_message(&message).unwrap(), expected);
        }

        // A request must not use a null id.
        let null_id = json!({"jsonrpc": "2.0", "method": "test", "id": null});
        assert!(matches!(
            validate_jsonrpc_message(&null_id),
            Err(ValidationError::NullRequestId)
        ));
    }

    #[test]
    fn test_extract_id_from_malformed() {
        // Well-formed JSON: the id is read directly.
        let well_formed = r#"{"jsonrpc": "2.0", "method": "test", "id": 123}"#;
        assert_eq!(
            extract_id_from_malformed(well_formed),
            Some(NumberOrString::Number(123))
        );

        // Missing closing brace: only the regex fallback can recover the id.
        let truncated = r#"{"jsonrpc": "2.0", "method": "test", "id": "abc""#;
        assert_eq!(
            extract_id_from_malformed(truncated),
            Some(NumberOrString::String(std::sync::Arc::from("abc")))
        );

        // Nothing to extract.
        let no_id = r#"{"jsonrpc": "2.0", "method": "test"}"#;
        assert_eq!(extract_id_from_malformed(no_id), None);
    }

    #[test]
    fn test_validate_batch() {
        let batch = [
            json!({"jsonrpc": "2.0", "method": "test1", "id": 1}),
            json!({"jsonrpc": "2.0", "method": "test2"}),
        ];
        assert_eq!(
            validate_batch(&batch).unwrap(),
            [MessageType::Request, MessageType::Notification]
        );

        // An empty batch is rejected.
        assert!(validate_batch(&[]).is_err());
    }
}