matrixcode_core/matrixrpc/transport/
codec.rs1use std::io;
16
17use crate::matrixrpc::protocol::JsonRpcMessage;
18
/// Default upper bound, in bytes, on a single message body (16 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;

/// Encoder/decoder for `Content-Length`-framed JSON-RPC messages.
///
/// A frame is a header section terminated by `\r\n\r\n`, followed by a JSON
/// body whose byte length is announced by the `Content-Length` header.
#[derive(Debug, Clone)]
pub struct FrameCodec {
    /// Upper bound, in bytes, on a message body handled by this codec.
    max_message_size: usize,
}

impl Default for FrameCodec {
    /// Builds a codec with the 16 MiB default limit.
    ///
    /// Implemented by hand because a derived `Default` would set the limit to
    /// 0, making the codec reject every non-empty message.
    fn default() -> Self {
        Self {
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        }
    }
}
27
28impl FrameCodec {
29 pub fn new() -> Self {
31 Self {
32 max_message_size: 16 * 1024 * 1024, }
34 }
35
36 pub fn with_max_size(max_message_size: usize) -> Self {
38 Self { max_message_size }
39 }
40
41 pub fn encode(&self, message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
45 let json = message.to_json().map_err(|e| {
46 io::Error::new(
47 io::ErrorKind::InvalidData,
48 format!("JSON encode error: {}", e),
49 )
50 })?;
51
52 let json_bytes = json.into_bytes();
53 if json_bytes.len() > self.max_message_size {
54 return Err(io::Error::new(
55 io::ErrorKind::InvalidData,
56 format!(
57 "Message size {} exceeds maximum {}",
58 json_bytes.len(),
59 self.max_message_size
60 ),
61 ));
62 }
63
64 let header = format!("Content-Length: {}\r\n\r\n", json_bytes.len());
65 let mut frame = header.into_bytes();
66 frame.extend(json_bytes);
67
68 Ok(frame)
69 }
70
71 pub fn encode_to_writer<W: std::io::Write>(
73 &self,
74 writer: &mut W,
75 message: &JsonRpcMessage,
76 ) -> io::Result<()> {
77 let frame = self.encode(message)?;
78 writer.write_all(&frame)?;
79 writer.flush()?;
80 Ok(())
81 }
82
83 pub fn decode_from_buffer<'a>(
88 &self,
89 buffer: &'a [u8],
90 ) -> io::Result<(&'a [u8], Option<JsonRpcMessage>)> {
91 let header_end = match find_header_end(buffer) {
93 Some(pos) => pos,
94 None => return Ok((buffer, None)), };
96
97 let header_str = std::str::from_utf8(&buffer[..header_end]).map_err(|e| {
99 io::Error::new(
100 io::ErrorKind::InvalidData,
101 format!("Invalid UTF-8 in headers: {}", e),
102 )
103 })?;
104
105 let content_length = parse_content_length(header_str)?;
106
107 let body_start = header_end + 4; if buffer.len() < body_start + content_length {
110 return Ok((buffer, None)); }
112
113 let body = &buffer[body_start..body_start + content_length];
115 let json_str = std::str::from_utf8(body).map_err(|e| {
116 io::Error::new(
117 io::ErrorKind::InvalidData,
118 format!("Invalid UTF-8 in body: {}", e),
119 )
120 })?;
121
122 let message = JsonRpcMessage::from_json(json_str).map_err(|e| {
123 io::Error::new(
124 io::ErrorKind::InvalidData,
125 format!("JSON decode error: {}", e),
126 )
127 })?;
128
129 let remaining = &buffer[body_start + content_length..];
131 Ok((remaining, Some(message)))
132 }
133
134 pub fn max_message_size(&self) -> usize {
136 self.max_message_size
137 }
138}
139
/// Returns the index of the first `\r\n\r\n` header terminator in `buffer`.
///
/// The index points at the terminator's first `\r`; the body starts four
/// bytes later. Returns `None` when no terminator is present, including for
/// buffers shorter than four bytes.
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buffer
        .windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
}
154
/// Extracts the `Content-Length` value from a block of header lines.
///
/// Header names are matched case-insensitively, and surrounding whitespace
/// on names and values is ignored. Only the first `Content-Length` header is
/// consulted; if its value does not parse, that is an error even when a later
/// header would have parsed.
///
/// # Errors
///
/// Returns `InvalidData` when no `Content-Length` header is present or when
/// its value is not a valid `usize`.
fn parse_content_length(headers: &str) -> io::Result<usize> {
    let raw_value = headers
        .lines()
        .filter_map(|line| line.trim().split_once(':'))
        .find(|&(name, _)| name.trim().eq_ignore_ascii_case("Content-Length"))
        .map(|(_, value)| value.trim())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "Missing Content-Length header",
            )
        })?;

    raw_value.parse::<usize>().map_err(|err| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Invalid Content-Length: {}", err),
        )
    })
}
176
177#[allow(dead_code)]
178pub fn encode_message(message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
180#[allow(dead_code)]
181 FrameCodec::new().encode(message)
182}
183
184#[allow(dead_code)]
185#[allow(dead_code)]
189pub fn decode_message_from_buffer(buffer: &[u8]) -> io::Result<(Vec<u8>, JsonRpcMessage)> {
190 let codec = FrameCodec::new();
191 let (remaining, message) = codec.decode_from_buffer(buffer)?;
192 match message {
193 Some(msg) => Ok((remaining.to_vec(), msg)),
194 None => Err(io::Error::new(
195 io::ErrorKind::UnexpectedEof,
196 "Incomplete message",
197 )),
198 }
199}
200
#[cfg(test)]
mod tests {
    // Unit tests for FrameCodec framing: encoding, incremental decoding,
    // header parsing edge cases, and the convenience wrappers.
    use super::*;
    use serde_json::json;

    #[test]
    fn test_encode_message() {
        let request = JsonRpcMessage::Request(
            crate::matrixrpc::protocol::JsonRpcRequest::new("test_method")
                .params(json!({"key": "value"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&request).unwrap();

        let frame_str = String::from_utf8_lossy(&frame);
        assert!(frame_str.starts_with("Content-Length:"));
        assert!(frame_str.contains("\r\n\r\n"));
        assert!(frame_str.contains("\"method\":\"test_method\""));
    }

    #[test]
    fn test_decode_message() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();

        assert!(message.is_some());
        let msg = message.unwrap();
        assert!(msg.is_request());
        assert!(remaining.is_empty());
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let request = JsonRpcMessage::Request(
            crate::matrixrpc::protocol::JsonRpcRequest::with_id("test_method", 42)
                .params(json!({"arg": "value"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&request).unwrap();

        let (_, decoded) = codec.decode_from_buffer(&frame).unwrap();
        let decoded = decoded.unwrap();

        assert_eq!(
            decoded.as_request().unwrap().method,
            request.as_request().unwrap().method
        );
    }

    #[test]
    fn test_max_message_size() {
        let codec = FrameCodec::with_max_size(10);
        let request =
            JsonRpcMessage::Request(crate::matrixrpc::protocol::JsonRpcRequest::new("test"));

        let result = codec.encode(&request);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err().kind(),
            io::ErrorKind::InvalidData
        ));
    }

    #[test]
    fn test_incomplete_message() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let partial_frame = format!("Content-Length: {}\r\n\r\n", json.len());
        let codec = FrameCodec::new();
        let result = codec.decode_from_buffer(partial_frame.as_bytes()).unwrap();

        assert!(result.1.is_none());
    }

    #[test]
    fn test_multiple_messages_in_buffer() {
        let json1 = r#"{"jsonrpc":"2.0","method":"test1","id":1}"#;
        let json2 = r#"{"jsonrpc":"2.0","method":"test2","id":2}"#;

        let codec = FrameCodec::new();
        let mut buffer = Vec::new();
        buffer.extend(
            codec
                .encode(&JsonRpcMessage::Request(
                    crate::matrixrpc::protocol::JsonRpcRequest::from_json(json1).unwrap(),
                ))
                .unwrap(),
        );
        buffer.extend(
            codec
                .encode(&JsonRpcMessage::Request(
                    crate::matrixrpc::protocol::JsonRpcRequest::from_json(json2).unwrap(),
                ))
                .unwrap(),
        );

        let (remaining1, msg1) = codec.decode_from_buffer(&buffer).unwrap();
        let msg1 = msg1.unwrap();
        assert_eq!(msg1.as_request().unwrap().method, "test1");

        let (_, msg2) = codec.decode_from_buffer(remaining1).unwrap();
        let msg2 = msg2.unwrap();
        assert_eq!(msg2.as_request().unwrap().method, "test2");
    }

    #[test]
    fn test_convenience_functions() {
        let request = JsonRpcMessage::Request(crate::matrixrpc::protocol::JsonRpcRequest::new(
            "test_method",
        ));

        let encoded = encode_message(&request).unwrap();
        let (_, decoded) = decode_message_from_buffer(&encoded).unwrap();

        assert!(decoded.is_request());
    }

    #[test]
    fn test_decode_missing_content_length() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let frame = format!("Content-Type: application/json\r\n\r\n{}", json);

        let codec = FrameCodec::new();
        let result = codec.decode_from_buffer(frame.as_bytes());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Missing Content-Length"));
    }

    #[test]
    fn test_decode_malformed_content_length() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let frame = format!("Content-Length: abc\r\n\r\n{}", json);

        let codec = FrameCodec::new();
        let result = codec.decode_from_buffer(frame.as_bytes());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid Content-Length"));
    }

    #[test]
    fn test_decode_negative_content_length() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let frame = format!("Content-Length: -1\r\n\r\n{}", json);

        let codec = FrameCodec::new();
        let result = codec.decode_from_buffer(frame.as_bytes());
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_case_insensitive_header() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        for header in [
            "content-length",
            "CONTENT-LENGTH",
            "Content-length",
            "CONTENT-length",
        ] {
            let frame = format!("{}: {}\r\n\r\n{}", header, json.len(), json);
            let codec = FrameCodec::new();
            let (_, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();
            assert!(
                message.is_some(),
                "Failed to parse with header: {}",
                header
            );
        }
    }

    #[test]
    fn test_decode_with_extra_headers() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let frame = format!(
            "Content-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
            json.len(),
            json
        );

        let codec = FrameCodec::new();
        let (_, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();
        assert!(message.is_some());
    }

    #[test]
    fn test_decode_zero_content_length() {
        let frame = "Content-Length: 0\r\n\r\n";

        let codec = FrameCodec::new();
        let result = codec.decode_from_buffer(frame.as_bytes());
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_invalid_utf8_in_header() {
        let invalid_bytes = b"Content-Length: \xFF\xFE\r\n\r\n{}";

        let codec = FrameCodec::new();
        let result = codec.decode_from_buffer(invalid_bytes);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid UTF-8"));
    }

    #[test]
    fn test_decode_invalid_json_body() {
        let invalid_json = r#"{"jsonrpc":"2.0","method":}"#;
        let frame = format!(
            "Content-Length: {}\r\n\r\n{}",
            invalid_json.len(),
            invalid_json
        );

        let codec = FrameCodec::new();
        let result = codec.decode_from_buffer(frame.as_bytes());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("JSON decode error"));
    }

    #[test]
    fn test_decode_empty_buffer() {
        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(b"").unwrap();
        assert!(message.is_none());
        assert!(remaining.is_empty());
    }

    #[test]
    fn test_decode_partial_header() {
        let partial = b"Content-Length: 10";

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(partial).unwrap();
        assert!(message.is_none());
        assert_eq!(remaining, partial);
    }

    #[test]
    fn test_decode_partial_body() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let partial = format!("Content-Length: 100\r\n\r\n{}", json);

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(partial.as_bytes()).unwrap();
        assert!(message.is_none());
        assert!(!remaining.is_empty());
    }

    #[test]
    fn test_encode_response_message() {
        let response = JsonRpcMessage::Response(
            crate::matrixrpc::protocol::JsonRpcResponse::success(1, json!({"result": "ok"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&response).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        assert!(frame_str.contains("\"result\":"));
        assert!(frame_str.contains("\"ok\""));
    }

    #[test]
    fn test_encode_error_response() {
        let error = JsonRpcMessage::Response(crate::matrixrpc::protocol::JsonRpcResponse::error(
            1,
            crate::matrixrpc::protocol::JsonRpcError::method_not_found("unknown"),
        ));

        let codec = FrameCodec::new();
        let frame = codec.encode(&error).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        assert!(frame_str.contains("\"error\""));
        assert!(frame_str.contains("Method 'unknown' not found"));
    }

    #[test]
    fn test_encode_batch_message() {
        let batch = JsonRpcMessage::Batch(vec![
            JsonRpcMessage::Request(crate::matrixrpc::protocol::JsonRpcRequest::new("method1")),
            JsonRpcMessage::Request(crate::matrixrpc::protocol::JsonRpcRequest::new("method2")),
        ]);

        let codec = FrameCodec::new();
        let frame = codec.encode(&batch).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        // The frame always begins with the Content-Length header, so the JSON
        // array can only appear after it.
        assert!(frame_str.contains('['));
        assert!(frame_str.contains("method1"));
        assert!(frame_str.contains("method2"));
    }

    #[test]
    fn test_encode_notification() {
        let notification = JsonRpcMessage::Request(
            crate::matrixrpc::protocol::JsonRpcRequest::notification("notify_event")
                .params(json!({"event": "test"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&notification).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        let body_start = frame_str.find("\r\n\r\n").unwrap() + 4;
        let body = &frame_str[body_start..];
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert!(parsed.get("id").is_none());
        assert_eq!(parsed["method"], "notify_event");
    }

    #[test]
    fn test_decode_with_trailing_data() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let frame = format!("Content-Length: {}\r\n\r\n{}extra_data", json.len(), json);

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();

        assert!(message.is_some());
        assert_eq!(remaining, b"extra_data");
    }

    #[test]
    fn test_decode_message_from_buffer_incomplete() {
        let partial = b"Content-Length: 100\r\n\r\n{}";
        let result = decode_message_from_buffer(partial);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_content_length_whitespace() {
        let json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
        let frame = format!("Content-Length: {} \r\n\r\n{}", json.len(), json);

        let codec = FrameCodec::new();
        let (_, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();
        assert!(message.is_some());
    }

    #[test]
    fn test_large_message_within_limit() {
        let large_params = "x".repeat(1024 * 1024);
        let request = JsonRpcMessage::Request(
            crate::matrixrpc::protocol::JsonRpcRequest::new("test")
                .params(json!({"data": large_params})),
        );

        let codec = FrameCodec::new();
        let result = codec.encode(&request);
        assert!(result.is_ok());
    }
}
581}