// pulseengine_mcp_transport/validation.rs

use serde_json::Value;
use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum ValidationError {
8 #[error("Message contains embedded newlines")]
9 EmbeddedNewlines,
10
11 #[error("Message is not valid UTF-8: {0}")]
12 InvalidUtf8(String),
13
14 #[error("Request ID cannot be null")]
15 NullRequestId,
16
17 #[error("Notification cannot have an ID")]
18 NotificationWithId,
19
20 #[error("Message exceeds maximum size: {size} > {max}")]
21 MessageTooLarge { size: usize, max: usize },
22
23 #[error("Invalid JSON-RPC format: {0}")]
24 InvalidFormat(String),
25}
26
27pub fn validate_message_string(
29 message: &str,
30 max_size: Option<usize>,
31) -> Result<(), ValidationError> {
32 if message.contains('\n') || message.contains('\r') {
34 return Err(ValidationError::EmbeddedNewlines);
35 }
36
37 if let Some(max) = max_size {
39 if message.len() > max {
40 return Err(ValidationError::MessageTooLarge {
41 size: message.len(),
42 max,
43 });
44 }
45 }
46
47 if !message.is_ascii() {
49 if let Err(e) = std::str::from_utf8(message.as_bytes()) {
51 return Err(ValidationError::InvalidUtf8(e.to_string()));
52 }
53 }
54
55 Ok(())
56}
57
58pub fn validate_jsonrpc_message(value: &Value) -> Result<MessageType, ValidationError> {
60 let obj = value.as_object().ok_or_else(|| {
61 ValidationError::InvalidFormat("Message must be a JSON object".to_string())
62 })?;
63
64 if obj.get("jsonrpc").and_then(|v| v.as_str()) != Some("2.0") {
66 return Err(ValidationError::InvalidFormat(
67 "Missing or invalid jsonrpc field".to_string(),
68 ));
69 }
70
71 if obj.contains_key("method") {
73 let method = obj
76 .get("method")
77 .and_then(|v| v.as_str())
78 .ok_or_else(|| ValidationError::InvalidFormat("Method must be a string".to_string()))?;
79
80 if method.is_empty() {
81 return Err(ValidationError::InvalidFormat(
82 "Method cannot be empty".to_string(),
83 ));
84 }
85
86 let has_id = obj.contains_key("id");
87 let id_value = obj.get("id");
88
89 if has_id {
90 if id_value == Some(&Value::Null) {
92 return Err(ValidationError::NullRequestId);
93 }
94 Ok(MessageType::Request)
95 } else {
96 Ok(MessageType::Notification)
98 }
99 } else if obj.contains_key("result") || obj.contains_key("error") {
100 if !obj.contains_key("id") {
102 return Err(ValidationError::InvalidFormat(
103 "Response must have an ID".to_string(),
104 ));
105 }
106 Ok(MessageType::Response)
107 } else {
108 Err(ValidationError::InvalidFormat(
109 "Unknown message type".to_string(),
110 ))
111 }
112}
113
114pub fn extract_id_from_malformed(text: &str) -> Option<pulseengine_mcp_protocol::NumberOrString> {
116 use pulseengine_mcp_protocol::NumberOrString;
117
118 if let Ok(value) = serde_json::from_str::<Value>(text) {
120 if let Some(obj) = value.as_object() {
121 if let Some(id) = obj.get("id") {
122 return NumberOrString::from_json_value(id.clone());
123 }
124 }
125 }
126
127 if let Some(id_match) = extract_id_with_regex(text) {
129 return NumberOrString::from_json_value(id_match);
130 }
131
132 None
134}
135
136pub fn validate_batch(batch: &[Value]) -> Result<Vec<MessageType>, ValidationError> {
138 if batch.is_empty() {
139 return Err(ValidationError::InvalidFormat(
140 "Batch cannot be empty".to_string(),
141 ));
142 }
143
144 let mut types = Vec::new();
145 for message in batch {
146 types.push(validate_jsonrpc_message(message)?);
147 }
148
149 Ok(types)
150}
151
/// Kind of JSON-RPC message, as classified by `validate_jsonrpc_message`.
///
/// The enum has no payload, so it is `Copy` and can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// A message with a `method` and an `id`.
    Request,
    /// A message with a `result` or `error` member and an `id`.
    Response,
    /// A message with a `method` and no `id`.
    Notification,
}
159
160fn extract_id_with_regex(text: &str) -> Option<Value> {
162 use regex::Regex;
163
164 let patterns = [
166 r#""id"\s*:\s*"([^"]+)""#, r#""id"\s*:\s*(\d+)"#, r#""id"\s*:\s*(null)"#, ];
170
171 for pattern in &patterns {
172 if let Ok(re) = Regex::new(pattern) {
173 if let Some(captures) = re.captures(text) {
174 if let Some(id_str) = captures.get(1) {
175 let id_text = id_str.as_str();
176
177 if let Ok(num) = id_text.parse::<i64>() {
179 return Some(Value::Number(num.into()));
180 }
181
182 if id_text == "null" {
184 return Some(Value::Null);
185 }
186
187 return Some(Value::String(id_text.to_string()));
189 }
190 }
191 }
192 }
193
194 None
195}
196
197pub fn validate_json_rpc_message(message: &str) -> Result<MessageType, ValidationError> {
199 validate_message_string(message, None)?;
201
202 let value = serde_json::from_str(message)
204 .map_err(|e| ValidationError::InvalidFormat(format!("Invalid JSON: {e}")))?;
205
206 validate_jsonrpc_message(&value)
208}
209
210pub fn validate_json_rpc_batch(batch_str: &str) -> Result<Vec<MessageType>, ValidationError> {
212 validate_message_string(batch_str, None)?;
214
215 let batch_value = serde_json::from_str::<Value>(batch_str)
217 .map_err(|e| ValidationError::InvalidFormat(format!("Invalid JSON: {e}")))?;
218
219 let batch_array = batch_value
220 .as_array()
221 .ok_or_else(|| ValidationError::InvalidFormat("Batch must be an array".to_string()))?;
222
223 if batch_array.is_empty() {
224 return Err(ValidationError::InvalidFormat(
225 "Empty batch not allowed".to_string(),
226 ));
227 }
228
229 validate_batch(batch_array)
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_validate_message_string() {
        assert!(validate_message_string("hello world", None).is_ok());

        // Non-ASCII text is fine, and the size limit is inclusive.
        assert!(validate_message_string("héllo wörld", None).is_ok());
        assert!(validate_message_string("hello", Some(5)).is_ok());

        assert!(matches!(
            validate_message_string("hello\nworld", None),
            Err(ValidationError::EmbeddedNewlines)
        ));

        assert!(matches!(
            validate_message_string("hello\rworld", None),
            Err(ValidationError::EmbeddedNewlines)
        ));

        assert!(matches!(
            validate_message_string("hello world", Some(5)),
            Err(ValidationError::MessageTooLarge { .. })
        ));
    }

    #[test]
    fn test_validate_jsonrpc_message() {
        let request = json!({
            "jsonrpc": "2.0",
            "method": "test",
            "id": 1
        });
        assert_eq!(
            validate_jsonrpc_message(&request).unwrap(),
            MessageType::Request
        );

        let notification = json!({
            "jsonrpc": "2.0",
            "method": "test"
        });
        assert_eq!(
            validate_jsonrpc_message(&notification).unwrap(),
            MessageType::Notification
        );

        let response = json!({
            "jsonrpc": "2.0",
            "result": "ok",
            "id": 1
        });
        assert_eq!(
            validate_jsonrpc_message(&response).unwrap(),
            MessageType::Response
        );

        let invalid_request = json!({
            "jsonrpc": "2.0",
            "method": "test",
            "id": null
        });
        assert!(matches!(
            validate_jsonrpc_message(&invalid_request),
            Err(ValidationError::NullRequestId)
        ));
    }

    #[test]
    fn test_extract_id_from_malformed() {
        // Well-formed JSON: the id comes from the parsed object.
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": 123}"#;
        assert_eq!(
            extract_id_from_malformed(text),
            Some(pulseengine_mcp_protocol::NumberOrString::Number(123))
        );

        // Missing closing brace: the id is recovered by pattern matching.
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": "abc""#;
        assert_eq!(
            extract_id_from_malformed(text),
            Some(pulseengine_mcp_protocol::NumberOrString::String(
                std::sync::Arc::from("abc")
            ))
        );

        let text = r#"{"jsonrpc": "2.0", "method": "test"}"#;
        assert_eq!(extract_id_from_malformed(text), None);
    }

    #[test]
    fn test_validate_batch() {
        let batch = vec![
            json!({"jsonrpc": "2.0", "method": "test1", "id": 1}),
            json!({"jsonrpc": "2.0", "method": "test2"}),
        ];

        let types = validate_batch(&batch).unwrap();
        assert_eq!(types, vec![MessageType::Request, MessageType::Notification]);

        assert!(validate_batch(&[]).is_err());
    }

    #[test]
    fn test_validate_json_rpc_message() {
        assert_eq!(
            validate_json_rpc_message(r#"{"jsonrpc": "2.0", "method": "ping", "id": 1}"#).unwrap(),
            MessageType::Request
        );

        assert!(matches!(
            validate_json_rpc_message("not json"),
            Err(ValidationError::InvalidFormat(_))
        ));

        // String-level checks run before JSON parsing.
        assert!(matches!(
            validate_json_rpc_message("{\"jsonrpc\": \"2.0\",\n\"method\": \"ping\"}"),
            Err(ValidationError::EmbeddedNewlines)
        ));
    }

    #[test]
    fn test_validate_json_rpc_batch() {
        let batch = r#"[{"jsonrpc": "2.0", "method": "a", "id": 1}, {"jsonrpc": "2.0", "result": null, "id": 2}]"#;
        assert_eq!(
            validate_json_rpc_batch(batch).unwrap(),
            vec![MessageType::Request, MessageType::Response]
        );

        assert!(matches!(
            validate_json_rpc_batch("[]"),
            Err(ValidationError::InvalidFormat(_))
        ));

        assert!(matches!(
            validate_json_rpc_batch(r#"{"jsonrpc": "2.0", "method": "a"}"#),
            Err(ValidationError::InvalidFormat(_))
        ));
    }
}