1use std::io;
2
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4
/// Largest frame body, in bytes, that `read_message` accepts (1 MiB).
// NOTE(review): Chrome's native-messaging docs cap host→browser messages at
// 1 MB and browser→host messages at 64 MiB. If this module runs inside the
// native host, these two limits look swapped relative to that — confirm the
// intended role of this side of the pipe.
const MAX_INPUT_SIZE: usize = 1 << 20;

/// Largest serialized JSON body, in bytes, that `write_message` sends (64 MiB).
const MAX_OUTPUT_SIZE: usize = 64 << 20;
7
8pub async fn read_message(
18 reader: &mut (impl AsyncReadExt + Unpin),
19) -> Result<serde_json::Value, NativeMessageError> {
20 let mut len_bytes = [0u8; 4];
21 reader.read_exact(&mut len_bytes).await.map_err(|e| {
22 if e.kind() == io::ErrorKind::UnexpectedEof {
23 NativeMessageError::Disconnected
24 } else {
25 NativeMessageError::Io(e)
26 }
27 })?;
28
29 let len = u32::from_le_bytes(len_bytes) as usize;
30 if len > MAX_INPUT_SIZE {
31 return Err(NativeMessageError::TooLarge {
32 size: len,
33 max: MAX_INPUT_SIZE,
34 });
35 }
36
37 let mut buf = vec![0u8; len];
38 reader
39 .read_exact(&mut buf)
40 .await
41 .map_err(NativeMessageError::Io)?;
42
43 serde_json::from_slice(&buf).map_err(NativeMessageError::Json)
44}
45
46pub async fn write_message(
55 writer: &mut (impl AsyncWriteExt + Unpin),
56 msg: &serde_json::Value,
57) -> Result<(), NativeMessageError> {
58 let bytes = serde_json::to_vec(msg).map_err(NativeMessageError::Json)?;
59 if bytes.len() > MAX_OUTPUT_SIZE {
60 return Err(NativeMessageError::TooLarge {
61 size: bytes.len(),
62 max: MAX_OUTPUT_SIZE,
63 });
64 }
65
66 let len_bytes = (bytes.len() as u32).to_le_bytes();
67 writer
68 .write_all(&len_bytes)
69 .await
70 .map_err(NativeMessageError::Io)?;
71 writer
72 .write_all(&bytes)
73 .await
74 .map_err(NativeMessageError::Io)?;
75 writer.flush().await.map_err(NativeMessageError::Io)?;
76 Ok(())
77}
78
79#[derive(Debug, thiserror::Error)]
80pub enum NativeMessageError {
81 #[error("native messaging peer disconnected")]
82 Disconnected,
83
84 #[error("message size {size} exceeds limit {max}")]
85 TooLarge { size: usize, max: usize },
86
87 #[error("IO error: {0}")]
88 Io(#[from] io::Error),
89
90 #[error("JSON error: {0}")]
91 Json(#[from] serde_json::Error),
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::BufReader;

    /// Builds a raw frame: a 4-byte little-endian prefix declaring `len`,
    /// followed by `body`. `len` deliberately need not equal `body.len()`, so
    /// tests can hand-craft truncated or oversized frames.
    fn raw_frame(len: u32, body: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(4 + body.len());
        buf.extend_from_slice(&len.to_le_bytes());
        buf.extend_from_slice(body);
        buf
    }

    /// Builds a well-formed frame whose prefix matches `body.len()`.
    fn frame(body: &[u8]) -> Vec<u8> {
        let len = u32::try_from(body.len()).expect("test bodies fit in a u32 prefix");
        raw_frame(len, body)
    }

    /// Serializes `msg` into a fresh buffer via `write_message`.
    async fn encode(msg: &serde_json::Value) -> Vec<u8> {
        let mut buf = Vec::new();
        write_message(&mut buf, msg).await.unwrap();
        buf
    }

    /// Reads exactly one message from `bytes` via `read_message`.
    async fn decode(bytes: &[u8]) -> Result<serde_json::Value, NativeMessageError> {
        let mut reader = BufReader::new(bytes);
        read_message(&mut reader).await
    }

    #[tokio::test]
    async fn roundtrip_message() {
        let msg = serde_json::json!({"type": "execute", "id": "abc123", "method": "snapshot"});
        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert_eq!(msg, decoded);
    }

    #[tokio::test]
    async fn rejects_oversized_input() {
        // No body follows the prefix, so this only passes if the limit is
        // enforced from the prefix alone; reading the body first would
        // surface as an Io EOF instead of TooLarge.
        let len = u32::try_from(MAX_INPUT_SIZE + 1).unwrap();
        let result = decode(&raw_frame(len, &[])).await;
        assert!(matches!(
            result,
            Err(NativeMessageError::TooLarge { size, max })
                if size == MAX_INPUT_SIZE + 1 && max == MAX_INPUT_SIZE
        ));
    }

    #[tokio::test]
    async fn detects_disconnect() {
        let result = decode(&[]).await;
        assert!(matches!(result, Err(NativeMessageError::Disconnected)));
    }

    #[tokio::test]
    async fn handles_empty_json() {
        let msg = serde_json::json!({});
        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert_eq!(decoded, serde_json::json!({}));
    }

    #[tokio::test]
    async fn handles_multiple_messages() {
        let msgs = [
            serde_json::json!({"id": "1"}),
            serde_json::json!({"id": "2", "data": "hello"}),
            serde_json::json!({"id": "3", "nested": {"key": "value"}}),
        ];

        let mut buf = Vec::new();
        for msg in &msgs {
            write_message(&mut buf, msg).await.unwrap();
        }

        let mut reader = BufReader::new(buf.as_slice());
        for expected in &msgs {
            let decoded = read_message(&mut reader).await.unwrap();
            assert_eq!(&decoded, expected);
        }
    }

    #[tokio::test]
    async fn partial_length_prefix_is_disconnect() {
        let result = decode(&[0x02, 0x00]).await;
        assert!(matches!(result, Err(NativeMessageError::Disconnected)));
    }

    #[tokio::test]
    async fn three_byte_length_prefix_is_disconnect() {
        // One byte short of a full prefix is still a disconnect, not an Io error.
        let result = decode(&[0x10, 0x00, 0x00]).await;
        assert!(matches!(result, Err(NativeMessageError::Disconnected)));
    }

    #[tokio::test]
    async fn invalid_json_returns_error() {
        let result = decode(&frame(b"not json at all")).await;
        assert!(matches!(result, Err(NativeMessageError::Json(_))));
    }

    #[tokio::test]
    async fn zero_length_message() {
        let result = decode(&frame(&[])).await;
        assert!(matches!(result, Err(NativeMessageError::Json(_))));
    }

    #[tokio::test]
    async fn whitespace_only_body_is_invalid_json() {
        let result = decode(&frame(b"   ")).await;
        assert!(matches!(result, Err(NativeMessageError::Json(_))));
    }

    #[tokio::test]
    async fn unicode_message_roundtrip() {
        let msg =
            serde_json::json!({"emoji": "🔥🚀", "cjk": "日本語テスト", "mixed": "hello 世界"});
        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert_eq!(decoded["emoji"], "🔥🚀");
        assert_eq!(decoded["cjk"], "日本語テスト");
    }

    #[tokio::test]
    async fn large_message_near_limit() {
        let msg = serde_json::json!({"data": "x".repeat(500_000)});
        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert_eq!(decoded["data"].as_str().unwrap().len(), 500_000);
    }

    #[tokio::test]
    async fn write_message_length_prefix_correct() {
        let msg = serde_json::json!({"a": 1});
        let expected_json = serde_json::to_vec(&msg).unwrap();

        let buf = encode(&msg).await;

        let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, expected_json.len());
        assert_eq!(&buf[4..], &expected_json);
    }

    #[tokio::test]
    async fn truncated_body_is_io_error() {
        let result = decode(&raw_frame(100, b"short")).await;
        assert!(matches!(result, Err(NativeMessageError::Io(_))));
    }

    #[tokio::test]
    async fn missing_body_is_io_error() {
        // A complete prefix with no body at all: EOF mid-frame is Io, not Disconnected.
        let result = decode(&raw_frame(100, &[])).await;
        assert!(matches!(result, Err(NativeMessageError::Io(_))));
    }

    #[tokio::test]
    async fn message_one_byte_over_limit_rejected() {
        // Unlike `rejects_oversized_input`, the full body is present here.
        let body = vec![b'x'; MAX_INPUT_SIZE + 1];
        let result = decode(&frame(&body)).await;
        assert!(matches!(result, Err(NativeMessageError::TooLarge { .. })));
    }

    #[tokio::test]
    async fn message_one_byte_under_limit_accepted() {
        // `{"d":"` + padding + `"}` wraps the padding in 8 bytes of JSON.
        let padding = "x".repeat(MAX_INPUT_SIZE - 9);
        let json_str = format!(r#"{{"d":"{padding}"}}"#);
        assert_eq!(json_str.len(), MAX_INPUT_SIZE - 1);

        let decoded = decode(&frame(json_str.as_bytes())).await.unwrap();
        assert_eq!(decoded["d"].as_str().unwrap().len(), padding.len());
    }

    #[tokio::test]
    async fn length_prefix_exactly_at_boundary() {
        // 8 bytes of JSON around MAX_INPUT_SIZE - 8 bytes of padding.
        let json_body = format!("{{\"x\":\"{}\"}}", "a".repeat(MAX_INPUT_SIZE - 8));
        assert_eq!(json_body.len(), MAX_INPUT_SIZE);

        let result = decode(&frame(json_body.as_bytes())).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn deeply_nested_json_1000_levels() {
        // Relies on serde_json's default recursion limit, which rejects this
        // depth instead of recursing until the stack overflows.
        let json = format!("{}null{}", r#"{"n":"#.repeat(1000), "}".repeat(1000));
        assert!(json.len() <= MAX_INPUT_SIZE);

        let result = decode(&frame(json.as_bytes())).await;
        assert!(matches!(result, Err(NativeMessageError::Json(_))));
    }

    #[tokio::test]
    async fn json_with_null_bytes_in_strings() {
        let msg = serde_json::json!({"data": "hello\u{0000}world"});
        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert!(decoded["data"].as_str().unwrap().contains('\u{0000}'));
    }

    #[tokio::test]
    async fn rapid_sequential_50_messages() {
        let mut buf = Vec::new();
        for i in 0..50 {
            let msg = serde_json::json!({"seq": i, "ts": "2024-01-01T00:00:00Z"});
            write_message(&mut buf, &msg).await.unwrap();
        }

        let mut reader = BufReader::new(buf.as_slice());
        for i in 0..50 {
            let decoded = read_message(&mut reader).await.unwrap();
            assert_eq!(decoded["seq"], i);
        }
        let eof = read_message(&mut reader).await;
        assert!(matches!(eof, Err(NativeMessageError::Disconnected)));
    }

    #[tokio::test]
    async fn json_with_10000_keys() {
        let map: serde_json::Map<String, serde_json::Value> = (0..10_000)
            .map(|i| (format!("key_{i}"), serde_json::Value::Number(i.into())))
            .collect();
        let msg = serde_json::Value::Object(map);

        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert_eq!(decoded.as_object().unwrap().len(), 10_000);
    }

    #[tokio::test]
    async fn output_over_64mb_limit_rejected() {
        let msg = serde_json::json!({"huge": "x".repeat(MAX_OUTPUT_SIZE + 1)});

        let mut buf = Vec::new();
        let result = write_message(&mut buf, &msg).await;
        assert!(matches!(
            result,
            Err(NativeMessageError::TooLarge { max, .. }) if max == MAX_OUTPUT_SIZE
        ));
        // A rejected message must not leave a partial frame (even a prefix) behind.
        assert!(buf.is_empty());
    }

    #[tokio::test]
    async fn message_with_all_json_value_types() {
        let msg = serde_json::json!({
            "null": null,
            "bool_true": true,
            "bool_false": false,
            "int": 42,
            "float": 1.23456,
            "negative": -999,
            "string": "hello",
            "array": [1, "two", null, [3]],
            "object": {"nested": {"deep": true}},
            "empty_array": [],
            "empty_object": {},
            "big_int": 9007199254740992_i64,
        });

        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert_eq!(decoded, msg);
    }

    #[tokio::test]
    async fn length_prefix_u32_max_rejected() {
        let result = decode(&raw_frame(u32::MAX, &[0u8; 100])).await;
        assert!(matches!(result, Err(NativeMessageError::TooLarge { .. })));
    }

    #[tokio::test]
    async fn interleaved_write_read_sequence() {
        let mut buf = Vec::new();

        let msg1 = serde_json::json!({"phase": "init"});
        write_message(&mut buf, &msg1).await.unwrap();
        let msg2 = serde_json::json!({"phase": "execute", "data": [1,2,3]});
        write_message(&mut buf, &msg2).await.unwrap();
        let msg3 = serde_json::json!({"phase": "complete", "ok": true});
        write_message(&mut buf, &msg3).await.unwrap();

        let mut reader = BufReader::new(buf.as_slice());
        assert_eq!(read_message(&mut reader).await.unwrap()["phase"], "init");
        assert_eq!(read_message(&mut reader).await.unwrap()["phase"], "execute");
        assert_eq!(
            read_message(&mut reader).await.unwrap()["phase"],
            "complete"
        );
    }

    #[tokio::test]
    async fn unicode_surrogate_edge_cases() {
        let msg = serde_json::json!({
            "emoji_sequence": "👨👩👧👦",
            "zalgo": "h̷̡̢̧e̵̢̧̛l̸̨̧̛l̵̡̢̧ơ̷̢̧",
            "rtl": "مرحبا",
            "combining": "a\u{0300}\u{0301}\u{0302}",
        });

        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert_eq!(decoded["emoji_sequence"], "👨👩👧👦");
    }

    #[tokio::test]
    async fn write_then_read_back_large_near_limit() {
        let msg = serde_json::json!({"data": "a".repeat(900_000)});
        let decoded = decode(&encode(&msg).await).await.unwrap();
        assert_eq!(decoded["data"].as_str().unwrap().len(), 900_000);
    }

    #[tokio::test]
    async fn multiple_messages_then_disconnect() {
        let mut buf = Vec::new();
        write_message(&mut buf, &serde_json::json!({"n": 1}))
            .await
            .unwrap();
        write_message(&mut buf, &serde_json::json!({"n": 2}))
            .await
            .unwrap();

        let mut reader = BufReader::new(buf.as_slice());
        assert_eq!(read_message(&mut reader).await.unwrap()["n"], 1);
        assert_eq!(read_message(&mut reader).await.unwrap()["n"], 2);
        assert!(matches!(
            read_message(&mut reader).await,
            Err(NativeMessageError::Disconnected)
        ));
    }

    #[tokio::test]
    async fn invalid_utf8_in_body_is_json_error() {
        let result = decode(&frame(&[0xFF, 0xFE, 0xFD, 0xFC, 0xFB])).await;
        assert!(matches!(result, Err(NativeMessageError::Json(_))));
    }

    #[tokio::test]
    async fn valid_json_non_object_types() {
        let values = [
            serde_json::json!([1, 2, 3]),
            serde_json::json!("just a string"),
            serde_json::json!(42),
            serde_json::json!(true),
            serde_json::json!(null),
        ];

        for val in values {
            let decoded = decode(&encode(&val).await).await.unwrap();
            assert_eq!(decoded, val);
        }
    }

    #[tokio::test]
    async fn write_verifies_length_prefix_correctness() {
        // ~10 MB is far above MAX_INPUT_SIZE but within MAX_OUTPUT_SIZE: the
        // write-side limit is independent of the read-side one.
        let msg = serde_json::json!({"data": "x".repeat(10_000_000)});

        let mut buf = Vec::new();
        write_message(&mut buf, &msg).await.unwrap();

        let written_len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(written_len + 4, buf.len());
        assert!(written_len > MAX_INPUT_SIZE);
    }
}