Skip to main content

libgrite_ipc/
framing.rs

1//! Length-prefixed message framing for Unix socket IPC
2//!
3//! Wire format: [4 bytes: payload length as u32 big-endian][payload bytes]
4
5use std::io::{self, Read, Write};
6
/// Largest payload, in bytes, that the framing functions in this module will
/// write or accept: 16 MiB (2^24 bytes).
///
/// Readers compare the announced length against this limit before touching
/// the payload, so a corrupt or hostile header cannot make them buffer more
/// than this many bytes for a single message.
const MAX_MESSAGE_SIZE: u32 = 16 << 20;
9
10/// Write a length-prefixed message to a synchronous stream
11pub fn write_framed(stream: &mut impl Write, data: &[u8]) -> io::Result<()> {
12    let len: u32 = data.len().try_into().map_err(|_| {
13        io::Error::new(
14            io::ErrorKind::InvalidInput,
15            format!(
16                "Message too large: {} bytes (max {})",
17                data.len(),
18                MAX_MESSAGE_SIZE
19            ),
20        )
21    })?;
22    if len > MAX_MESSAGE_SIZE {
23        return Err(io::Error::new(
24            io::ErrorKind::InvalidInput,
25            format!(
26                "Message too large: {} bytes (max {})",
27                len, MAX_MESSAGE_SIZE
28            ),
29        ));
30    }
31    stream.write_all(&len.to_be_bytes())?;
32    stream.write_all(data)?;
33    stream.flush()
34}
35
36/// Read a length-prefixed message from a synchronous stream
37pub fn read_framed(stream: &mut impl Read) -> io::Result<Vec<u8>> {
38    let mut len_buf = [0u8; 4];
39    stream.read_exact(&mut len_buf)?;
40    let len = u32::from_be_bytes(len_buf);
41
42    if len > MAX_MESSAGE_SIZE {
43        return Err(io::Error::new(
44            io::ErrorKind::InvalidData,
45            format!(
46                "Message too large: {} bytes (max {})",
47                len, MAX_MESSAGE_SIZE
48            ),
49        ));
50    }
51
52    let mut buf = vec![0u8; len as usize];
53    stream.read_exact(&mut buf)?;
54    Ok(buf)
55}
56
57/// Write a length-prefixed message to an async tokio stream
58pub async fn write_framed_async(
59    stream: &mut (impl tokio::io::AsyncWriteExt + Unpin),
60    data: &[u8],
61) -> io::Result<()> {
62    let len: u32 = data.len().try_into().map_err(|_| {
63        io::Error::new(
64            io::ErrorKind::InvalidInput,
65            format!(
66                "Message too large: {} bytes (max {})",
67                data.len(),
68                MAX_MESSAGE_SIZE
69            ),
70        )
71    })?;
72    if len > MAX_MESSAGE_SIZE {
73        return Err(io::Error::new(
74            io::ErrorKind::InvalidInput,
75            format!(
76                "Message too large: {} bytes (max {})",
77                len, MAX_MESSAGE_SIZE
78            ),
79        ));
80    }
81    stream.write_all(&len.to_be_bytes()).await?;
82    stream.write_all(data).await?;
83    stream.flush().await
84}
85
86/// Read a length-prefixed message from an async tokio stream
87pub async fn read_framed_async(
88    stream: &mut (impl tokio::io::AsyncReadExt + Unpin),
89) -> io::Result<Vec<u8>> {
90    let mut len_buf = [0u8; 4];
91    stream.read_exact(&mut len_buf).await?;
92    let len = u32::from_be_bytes(len_buf);
93
94    if len > MAX_MESSAGE_SIZE {
95        return Err(io::Error::new(
96            io::ErrorKind::InvalidData,
97            format!(
98                "Message too large: {} bytes (max {})",
99                len, MAX_MESSAGE_SIZE
100            ),
101        ));
102    }
103
104    let mut buf = vec![0u8; len as usize];
105    stream.read_exact(&mut buf).await?;
106    Ok(buf)
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_sync_roundtrip() {
        let data = b"hello world";
        let mut buf = Vec::new();
        write_framed(&mut buf, data).unwrap();

        let mut cursor = Cursor::new(buf);
        let result = read_framed(&mut cursor).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_empty_message() {
        let data = b"";
        let mut buf = Vec::new();
        write_framed(&mut buf, data).unwrap();

        let mut cursor = Cursor::new(buf);
        let result = read_framed(&mut cursor).unwrap();
        assert_eq!(result, data.to_vec());
    }

    #[test]
    fn test_multiple_frames_on_one_stream() {
        // Each read must stop exactly at its frame boundary.
        let mut buf = Vec::new();
        write_framed(&mut buf, b"first").unwrap();
        write_framed(&mut buf, b"").unwrap();
        write_framed(&mut buf, b"third").unwrap();

        let mut cursor = Cursor::new(buf);
        assert_eq!(read_framed(&mut cursor).unwrap(), b"first");
        assert_eq!(read_framed(&mut cursor).unwrap(), b"");
        assert_eq!(read_framed(&mut cursor).unwrap(), b"third");
        assert_eq!(
            read_framed(&mut cursor).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn test_message_too_large() {
        let len = (MAX_MESSAGE_SIZE + 1).to_be_bytes();
        let mut cursor = Cursor::new(len.to_vec());
        let result = read_framed(&mut cursor);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_write_rejects_oversized_message() {
        let data = vec![0u8; MAX_MESSAGE_SIZE as usize + 1];
        let mut buf = Vec::new();
        let err = write_framed(&mut buf, &data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
        assert!(buf.is_empty(), "a rejected message must not emit any bytes");
    }

    #[test]
    fn test_truncated_header() {
        let mut cursor = Cursor::new(vec![0u8, 0]);
        assert_eq!(
            read_framed(&mut cursor).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn test_truncated_payload() {
        let mut bytes = 5u32.to_be_bytes().to_vec();
        bytes.extend_from_slice(b"abc");
        let mut cursor = Cursor::new(bytes);
        assert_eq!(
            read_framed(&mut cursor).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[tokio::test]
    async fn test_async_roundtrip() {
        let data = b"hello async world";
        let mut buf = Vec::new();
        write_framed_async(&mut buf, data).await.unwrap();

        let mut cursor = Cursor::new(buf);
        let result = read_framed_async(&mut cursor).await.unwrap();
        assert_eq!(result, data);
    }

    #[tokio::test]
    async fn test_async_message_too_large() {
        let mut cursor = Cursor::new((MAX_MESSAGE_SIZE + 1).to_be_bytes().to_vec());
        let err = read_framed_async(&mut cursor).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_async_truncated_payload() {
        let mut bytes = 5u32.to_be_bytes().to_vec();
        bytes.extend_from_slice(b"abc");
        let mut cursor = Cursor::new(bytes);
        let err = read_framed_async(&mut cursor).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}
155}