1use std::io::{self, Read, Write};
9
/// Largest payload, in bytes, that a single packet may carry.
///
/// The on-wire length header counts its own 4 bytes, so a packet of this
/// payload size is announced as `fff0` (65520 bytes in total).
pub const MAX_PACKET_DATA: usize = 65520 - 4;
12
/// Reads git pkt-line framed packets from an underlying byte stream.
///
/// Every packet starts with a 4-character hexadecimal length that counts the
/// header itself, followed by `length - 4` payload bytes. The header `0000`
/// is a flush packet and carries no payload.
pub struct Reader<R: Read> {
    inner: R,
}

impl<R: Read> Reader<R> {
    /// Wraps `inner`; nothing is read until a `read_*` method is called.
    pub fn new(inner: R) -> Self {
        Self { inner }
    }

    /// Reads the next packet.
    ///
    /// Returns `Ok(Some(payload))` for a data packet and `Ok(None)` for a
    /// flush packet (`0000`). Any length up to `0xffff` is accepted; the
    /// writer-side `MAX_PACKET_DATA` limit is not enforced here.
    ///
    /// # Errors
    ///
    /// * `InvalidData` if the header is not exactly four hex digits (signed
    ///   forms such as `+000` are rejected) or declares a length of 1 to 3.
    /// * Any error from the underlying reader; a stream that ends inside a
    ///   header or payload surfaces as `UnexpectedEof` from `read_exact`.
    pub fn read_packet(&mut self) -> io::Result<Option<Vec<u8>>> {
        let mut hdr = [0u8; 4];
        self.inner.read_exact(&mut hdr)?;
        match parse_length_header(hdr)? {
            0 => Ok(None),
            len @ 1..=3 => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("pktline length {len} < 4"),
            )),
            len => {
                // The declared length includes the 4 header bytes already consumed.
                let mut buf = vec![0u8; len - 4];
                self.inner.read_exact(&mut buf)?;
                Ok(Some(buf))
            }
        }
    }

    /// Reads the next packet as UTF-8 text, removing one trailing `\n` if present.
    ///
    /// Returns `Ok(None)` for a flush packet.
    ///
    /// # Errors
    ///
    /// Everything `read_packet` can return, plus `InvalidData` if the payload
    /// is not valid UTF-8.
    pub fn read_text(&mut self) -> io::Result<Option<String>> {
        let Some(mut bytes) = self.read_packet()? else {
            return Ok(None);
        };
        if bytes.last() == Some(&b'\n') {
            bytes.pop();
        }
        String::from_utf8(bytes)
            .map(Some)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
    }
}

/// Decodes a pkt-line length header of exactly four ASCII hex digits (either case).
///
/// `u32::from_str_radix` is deliberately not used: it also accepts a leading
/// `+`, which would let a malformed header such as `+000` pass as a flush packet.
fn parse_length_header(hdr: [u8; 4]) -> io::Result<usize> {
    if !hdr.is_ascii() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "non-ASCII pktline length",
        ));
    }
    let mut len = 0usize;
    for &b in &hdr {
        let digit = char::from(b).to_digit(16).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "invalid pktline length")
        })?;
        len = (len << 4) | digit as usize;
    }
    Ok(len)
}
63
64pub struct Writer<W: Write> {
66 inner: W,
67}
68
69impl<W: Write> Writer<W> {
70 pub fn new(inner: W) -> Self {
71 Self { inner }
72 }
73
74 pub fn write_packet(&mut self, data: &[u8]) -> io::Result<()> {
76 if data.len() > MAX_PACKET_DATA {
77 return Err(io::Error::new(
78 io::ErrorKind::InvalidInput,
79 format!("packet of {} bytes exceeds {MAX_PACKET_DATA}", data.len()),
80 ));
81 }
82 let total = data.len() + 4;
83 write!(self.inner, "{total:04x}")?;
84 self.inner.write_all(data)?;
85 Ok(())
86 }
87
88 pub fn write_text(&mut self, text: &str) -> io::Result<()> {
90 let mut buf = String::with_capacity(text.len() + 1);
91 buf.push_str(text);
92 buf.push('\n');
93 self.write_packet(buf.as_bytes())
94 }
95
96 pub fn write_flush(&mut self) -> io::Result<()> {
98 self.inner.write_all(b"0000")
99 }
100
101 pub fn flush(&mut self) -> io::Result<()> {
103 self.inner.flush()
104 }
105}
106
107pub struct Sink<'a, W: Write> {
114 writer: &'a mut Writer<W>,
115 buf: Vec<u8>,
116}
117
118impl<'a, W: Write> Sink<'a, W> {
119 pub fn new(writer: &'a mut Writer<W>) -> Self {
120 Self {
121 writer,
122 buf: Vec::with_capacity(MAX_PACKET_DATA),
123 }
124 }
125}
126
127impl<W: Write> Write for Sink<'_, W> {
128 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
129 let space = MAX_PACKET_DATA - self.buf.len();
130 let n = buf.len().min(space);
131 self.buf.extend_from_slice(&buf[..n]);
132 if self.buf.len() == MAX_PACKET_DATA {
133 self.writer.write_packet(&self.buf)?;
134 self.buf.clear();
135 }
136 Ok(n)
137 }
138
139 fn flush(&mut self) -> io::Result<()> {
140 if !self.buf.is_empty() {
141 self.writer.write_packet(&self.buf)?;
142 self.buf.clear();
143 }
144 Ok(())
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Wraps an in-memory byte slice in a pkt-line reader.
    fn reader_over(bytes: &[u8]) -> Reader<Cursor<&[u8]>> {
        Reader::new(Cursor::new(bytes))
    }

    #[test]
    fn round_trip_text_packet() {
        let mut wire = Vec::new();
        Writer::new(&mut wire).write_text("hello").unwrap();
        assert_eq!(wire, b"000ahello\n");

        let decoded = reader_over(&wire).read_text().unwrap();
        assert_eq!(decoded.as_deref(), Some("hello"));
    }

    #[test]
    fn flush_round_trip() {
        let mut wire = Vec::new();
        Writer::new(&mut wire).write_flush().unwrap();
        assert_eq!(wire, b"0000");

        assert_eq!(reader_over(&wire).read_packet().unwrap(), None);
    }

    #[test]
    fn binary_packet_round_trips() {
        let payload: &[u8] = b"\x00\x01\x02\xffbytes";
        let mut wire = Vec::new();
        Writer::new(&mut wire).write_packet(payload).unwrap();

        let decoded = reader_over(&wire).read_packet().unwrap();
        assert_eq!(decoded.as_deref(), Some(payload));
    }

    #[test]
    fn rejects_oversized_packet() {
        let oversized = vec![0u8; MAX_PACKET_DATA + 1];
        let mut writer = Writer::new(Vec::new());
        let err = writer.write_packet(&oversized).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn invalid_length_header() {
        let err = reader_over(b"zzzz").read_packet().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn sink_chunks_at_packet_boundary() {
        let mut wire = Vec::new();
        let mut writer = Writer::new(&mut wire);
        {
            // Scope the sink so its borrow of `writer` ends before the flush packet.
            let mut sink = Sink::new(&mut writer);
            let oversized = vec![b'x'; MAX_PACKET_DATA + 100];
            sink.write_all(&oversized).unwrap();
            sink.flush().unwrap();
        }
        writer.write_flush().unwrap();

        let mut reader = reader_over(&wire);
        let lengths: Vec<Option<usize>> = (0..3)
            .map(|_| reader.read_packet().unwrap().map(|packet| packet.len()))
            .collect();
        assert_eq!(lengths, [Some(MAX_PACKET_DATA), Some(100), None]);
    }
}