Skip to main content

travsr_plugin_protocol/
codec.rs

1use std::io::{self, Read, Write};
2
/// Upper bound, in bytes, on the payload of a single frame: 256 MiB.
///
/// The peer on the other end of the wire is a plugin and is not trusted. The
/// frame header is a 4-byte length that the peer fully controls, so without a
/// cap a hostile or buggy plugin could announce `0xFFFF_FFFF` and push the
/// daemon into a ~4 GiB allocation (and an OOM). Frames announcing more than
/// this many payload bytes are refused before any buffer is allocated.
///
/// 256 MiB comfortably exceeds the largest realistic
/// ParseOutput/InvokeResponse.
pub const MAX_FRAME_LEN: usize = 256 << 20;
9
10pub fn encode_message<T: serde::Serialize>(msg: &T) -> io::Result<Vec<u8>> {
11    let payload =
12        serde_json::to_vec(msg).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
13    if payload.len() > MAX_FRAME_LEN {
14        return Err(io::Error::new(
15            io::ErrorKind::InvalidData,
16            format!(
17                "frame payload {} bytes exceeds MAX_FRAME_LEN ({MAX_FRAME_LEN} bytes)",
18                payload.len()
19            ),
20        ));
21    }
22    let len = u32::try_from(payload.len())
23        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "payload too large"))?;
24    let mut frame = Vec::with_capacity(4 + payload.len());
25    frame.extend_from_slice(&len.to_be_bytes());
26    frame.extend_from_slice(&payload);
27    Ok(frame)
28}
29
30pub fn decode_message<T: serde::de::DeserializeOwned>(reader: &mut impl Read) -> io::Result<T> {
31    let mut len_buf = [0u8; 4];
32    reader.read_exact(&mut len_buf)?;
33    let len = u32::from_be_bytes(len_buf) as usize;
34    // Reject oversized frames BEFORE allocating — prevents a hostile plugin from
35    // triggering a multi-gigabyte allocation via a forged length prefix (DoS).
36    if len > MAX_FRAME_LEN {
37        return Err(io::Error::new(
38            io::ErrorKind::InvalidData,
39            format!("frame length {len} exceeds MAX_FRAME_LEN ({MAX_FRAME_LEN} bytes)"),
40        ));
41    }
42    let mut payload = vec![0u8; len];
43    reader.read_exact(&mut payload)?;
44    serde_json::from_slice(&payload).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
45}
46
47pub fn write_message<T: serde::Serialize>(writer: &mut impl Write, msg: &T) -> io::Result<()> {
48    let frame = encode_message(msg)?;
49    writer.write_all(&frame)?;
50    writer.flush()
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ParseRequest;
    use std::io::Cursor;
    use std::path::PathBuf;

    /// Encoding then decoding a `ParseRequest` yields the same field values,
    /// and the frame is laid out as a 4-byte big-endian length plus payload.
    #[test]
    fn round_trip_parse_request() {
        let req = ParseRequest {
            path: PathBuf::from("src/main.ts"),
            vname_path: "src/main.ts".into(),
            corpus: "github.com/acme/foo".into(),
            package: "acme".into(),
            source: None,
        };
        let payload = serde_json::to_vec(&req).unwrap();
        let encoded = encode_message(&req).unwrap();
        // Frame must be exactly 4-byte length prefix + payload.
        assert_eq!(encoded.len(), 4 + payload.len());
        // First 4 bytes must be the big-endian payload length.
        let len = u32::from_be_bytes(encoded[..4].try_into().unwrap()) as usize;
        assert_eq!(len, payload.len());
        let mut cursor = Cursor::new(&encoded);
        let decoded: ParseRequest = decode_message(&mut cursor).unwrap();
        assert_eq!(decoded.path, req.path);
        assert_eq!(decoded.vname_path, req.vname_path);
        assert_eq!(decoded.corpus, req.corpus);
    }

    #[test]
    fn decode_rejects_oversized_frame_without_allocating() {
        // Forge a length prefix of u32::MAX (~4 GiB) with no payload behind it.
        // decode_message must reject it on the length check, never attempting the
        // multi-gigabyte allocation, and never blocking on read_exact for a body.
        let mut framed = (u32::MAX).to_be_bytes().to_vec();
        framed.extend_from_slice(b"only a few bytes follow");
        let mut cursor = Cursor::new(framed);
        let err = decode_message::<ParseRequest>(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(
            err.to_string().contains("MAX_FRAME_LEN"),
            "expected MAX_FRAME_LEN rejection, got: {err}"
        );
    }

    #[test]
    fn decode_rejects_frame_one_byte_past_the_cap() {
        // The real boundary: a prefix of MAX_FRAME_LEN + 1 is the smallest
        // length that must be refused. Only the 4-byte header is supplied, so
        // the check is exercised without allocating a payload.
        let over = u32::try_from(MAX_FRAME_LEN + 1).expect("MAX_FRAME_LEN + 1 fits in u32");
        let mut cursor = Cursor::new(over.to_be_bytes().to_vec());
        let err = decode_message::<String>(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(
            err.to_string().contains("MAX_FRAME_LEN"),
            "expected MAX_FRAME_LEN rejection, got: {err}"
        );
    }

    #[test]
    fn decode_reports_truncated_body_as_unexpected_eof() {
        // Header announces 10 payload bytes but only 3 follow.
        let mut framed = 10u32.to_be_bytes().to_vec();
        framed.extend_from_slice(b"\"ab");
        let mut cursor = Cursor::new(framed);
        let err = decode_message::<String>(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn round_trip_string_payload() {
        // A 1 KiB string (far below the cap) must survive encode -> decode
        // unchanged. A payload at the cap itself is not exercised here because
        // it would need several 256 MiB buffers.
        let big = "x".repeat(1024);
        let encoded = encode_message(&big).unwrap();
        let mut cursor = Cursor::new(encoded);
        let decoded: String = decode_message(&mut cursor).unwrap();
        assert_eq!(decoded, big);
    }
}