1use std::io::{self, Read, Write};
9
/// Largest payload, in bytes, that a single pkt-line may carry: a packet of
/// at most 65520 bytes minus its 4-byte hexadecimal length header.
pub const MAX_PACKET_DATA: usize = 65_520 - 4;
12
13pub struct Reader<R: Read> {
15 inner: R,
16}
17
18impl<R: Read> Reader<R> {
19 pub fn new(inner: R) -> Self {
20 Self { inner }
21 }
22
23 pub fn read_packet(&mut self) -> io::Result<Option<Vec<u8>>> {
27 let mut hdr = [0u8; 4];
28 self.inner.read_exact(&mut hdr)?;
29 let len_str = std::str::from_utf8(&hdr)
30 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "non-ASCII pktline length"))?;
31 let len = u32::from_str_radix(len_str, 16)
32 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid pktline length"))?;
33 if len == 0 {
34 return Ok(None);
35 }
36 if len < 4 {
37 return Err(io::Error::new(
38 io::ErrorKind::InvalidData,
39 format!("pktline length {len} < 4"),
40 ));
41 }
42 let body_len = (len - 4) as usize;
43 let mut buf = vec![0u8; body_len];
44 self.inner.read_exact(&mut buf)?;
45 Ok(Some(buf))
46 }
47
48 pub fn read_text(&mut self) -> io::Result<Option<String>> {
51 let Some(mut bytes) = self.read_packet()? else {
52 return Ok(None);
53 };
54 if bytes.last() == Some(&b'\n') {
55 bytes.pop();
56 }
57 String::from_utf8(bytes)
58 .map(Some)
59 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
60 }
61}
62
63pub struct Writer<W: Write> {
65 inner: W,
66}
67
68impl<W: Write> Writer<W> {
69 pub fn new(inner: W) -> Self {
70 Self { inner }
71 }
72
73 pub fn write_packet(&mut self, data: &[u8]) -> io::Result<()> {
75 if data.len() > MAX_PACKET_DATA {
76 return Err(io::Error::new(
77 io::ErrorKind::InvalidInput,
78 format!("packet of {} bytes exceeds {MAX_PACKET_DATA}", data.len()),
79 ));
80 }
81 let total = data.len() + 4;
82 write!(self.inner, "{total:04x}")?;
83 self.inner.write_all(data)?;
84 Ok(())
85 }
86
87 pub fn write_text(&mut self, text: &str) -> io::Result<()> {
89 let mut buf = String::with_capacity(text.len() + 1);
90 buf.push_str(text);
91 buf.push('\n');
92 self.write_packet(buf.as_bytes())
93 }
94
95 pub fn write_flush(&mut self) -> io::Result<()> {
97 self.inner.write_all(b"0000")
98 }
99
100 pub fn flush(&mut self) -> io::Result<()> {
102 self.inner.flush()
103 }
104}
105
106pub struct Sink<'a, W: Write> {
113 writer: &'a mut Writer<W>,
114 buf: Vec<u8>,
115}
116
117impl<'a, W: Write> Sink<'a, W> {
118 pub fn new(writer: &'a mut Writer<W>) -> Self {
119 Self {
120 writer,
121 buf: Vec::with_capacity(MAX_PACKET_DATA),
122 }
123 }
124}
125
126impl<W: Write> Write for Sink<'_, W> {
127 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
128 let space = MAX_PACKET_DATA - self.buf.len();
129 let n = buf.len().min(space);
130 self.buf.extend_from_slice(&buf[..n]);
131 if self.buf.len() == MAX_PACKET_DATA {
132 self.writer.write_packet(&self.buf)?;
133 self.buf.clear();
134 }
135 Ok(n)
136 }
137
138 fn flush(&mut self) -> io::Result<()> {
139 if !self.buf.is_empty() {
140 self.writer.write_packet(&self.buf)?;
141 self.buf.clear();
142 }
143 Ok(())
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Wraps an owned copy of `wire` in a `Reader` for decoding.
    fn reader_over(wire: &[u8]) -> Reader<Cursor<Vec<u8>>> {
        Reader::new(Cursor::new(wire.to_vec()))
    }

    #[test]
    fn round_trip_text_packet() {
        let mut wire = Vec::new();
        Writer::new(&mut wire).write_text("hello").unwrap();
        assert_eq!(wire, b"000ahello\n");

        let decoded = reader_over(&wire).read_text().unwrap();
        assert_eq!(decoded.as_deref(), Some("hello"));
    }

    #[test]
    fn flush_round_trip() {
        let mut wire = Vec::new();
        Writer::new(&mut wire).write_flush().unwrap();
        assert_eq!(wire, b"0000");

        let decoded = reader_over(&wire).read_packet().unwrap();
        assert!(decoded.is_none());
    }

    #[test]
    fn binary_packet_round_trips() {
        let payload: &[u8] = b"\x00\x01\x02\xffbytes";
        let mut wire = Vec::new();
        Writer::new(&mut wire).write_packet(payload).unwrap();

        let decoded = reader_over(&wire).read_packet().unwrap();
        assert_eq!(decoded.as_deref(), Some(payload));
    }

    #[test]
    fn rejects_oversized_packet() {
        let too_big = vec![0u8; MAX_PACKET_DATA + 1];
        let outcome = Writer::new(Vec::new()).write_packet(&too_big);
        assert!(matches!(outcome, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput));
    }

    #[test]
    fn invalid_length_header() {
        let outcome = reader_over(b"zzzz").read_packet();
        assert!(matches!(outcome, Err(ref e) if e.kind() == io::ErrorKind::InvalidData));
    }

    #[test]
    fn sink_chunks_at_packet_boundary() {
        let mut wire = Vec::new();
        let mut writer = Writer::new(&mut wire);
        {
            // Scope the sink so its borrow of `writer` ends before the flush-pkt.
            let mut sink = Sink::new(&mut writer);
            let bulk = vec![b'x'; MAX_PACKET_DATA + 100];
            sink.write_all(&bulk).unwrap();
            sink.flush().unwrap();
        }
        writer.write_flush().unwrap();

        let mut reader = reader_over(&wire);
        let lengths: Vec<Option<usize>> = (0..3)
            .map(|_| reader.read_packet().unwrap().map(|packet| packet.len()))
            .collect();
        assert_eq!(lengths, [Some(MAX_PACKET_DATA), Some(100), None]);
    }
}