//! `bytes_handoff/read.rs` — read-side handoff buffer: accumulates bytes from
//! an `AsyncRead` up to a configured limit and hands them off piecewise as
//! `Bytes`/`BytesMut`.

1use bytes::{Buf, Bytes, BytesMut};
2use std::future::poll_fn;
3use std::pin::Pin;
4use tokio::io::{AsyncRead, ReadBuf};
5
6use crate::BufferError;
7
/// Largest prefix (in bytes) eligible for the copy-out fast path in
/// `split_prefix`.
const SMALL_PREFIX_COPY_MAX: usize = 256;
/// Minimum number of bytes that must remain buffered after the split for the
/// copy fast path to be taken.
const SMALL_PREFIX_COPY_REMAINING_MIN: usize = 4 * 1024;
10
/// Tuning knobs for a [`HandoffBuffer`]: a hard size cap plus the granularity
/// used when reserving space for each read.
#[derive(Clone, Copy, Debug)]
pub struct HandoffBufferConfig {
    /// Hard upper bound on the number of buffered bytes.
    pub max_len: usize,
    /// Maximum number of bytes requested from the reader per read call.
    pub read_reserve: usize,
}

impl HandoffBufferConfig {
    /// Builds a configuration capped at `max_len` with the default 16 KiB
    /// per-read reservation.
    pub fn new(max_len: usize) -> Self {
        const DEFAULT_READ_RESERVE: usize = 16 * 1024;
        HandoffBufferConfig {
            max_len,
            read_reserve: DEFAULT_READ_RESERVE,
        }
    }

    /// Returns a copy of this configuration with `read_reserve` overridden.
    pub fn with_read_reserve(self, read_reserve: usize) -> Self {
        Self {
            read_reserve,
            ..self
        }
    }
}
30
/// Growable byte buffer with a hard size cap, filled from an `AsyncRead` and
/// consumed piecewise via prefix splits.
#[derive(Debug)]
pub struct HandoffBuffer {
    // Accumulated, not-yet-consumed bytes.
    buf: BytesMut,
    // Size cap and per-read reservation policy.
    config: HandoffBufferConfig,
}
36
impl HandoffBuffer {
    /// Creates an empty buffer capped at `max_len` bytes with the default
    /// read reserve (see [`HandoffBufferConfig::new`]).
    pub fn new(max_len: usize) -> Self {
        Self::with_config(HandoffBufferConfig::new(max_len))
    }

    /// Creates an empty buffer from an explicit configuration.
    pub fn with_config(config: HandoffBufferConfig) -> Self {
        Self {
            buf: BytesMut::new(),
            config,
        }
    }

    /// Adopts previously buffered bytes (e.g. a tail left over by an earlier
    /// owner — see [`HandoffBuffer::take_tail`]) as the initial contents.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::LimitExceeded`] if `tail` already exceeds
    /// `config.max_len`.
    pub fn from_tail(tail: BytesMut, config: HandoffBufferConfig) -> Result<Self, BufferError> {
        if tail.len() > config.max_len {
            return Err(BufferError::LimitExceeded {
                attempted: tail.len(),
                limit: config.max_len,
            });
        }
        Ok(Self { buf: tail, config })
    }

    /// Number of buffered (unconsumed) bytes.
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// True when no bytes are buffered.
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }

    /// Current allocation capacity of the underlying `BytesMut`. May exceed
    /// `len()` and is unrelated to the configured `max_len`.
    pub fn capacity(&self) -> usize {
        self.buf.capacity()
    }

    /// Borrows the buffered bytes without consuming them.
    pub fn peek(&self) -> &[u8] {
        &self.buf
    }

    /// Pre-allocates room for `additional` more bytes.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::LimitExceeded`] if `len() + additional` would
    /// exceed the configured `max_len`.
    pub fn reserve_read_capacity(&mut self, additional: usize) -> Result<(), BufferError> {
        self.check_limit(additional)?;
        self.buf.reserve(additional);
        Ok(())
    }

    /// Performs a single `poll_read` against `reader`, appending whatever
    /// bytes arrive (at most `min(remaining capacity, read_reserve)` per
    /// call).
    ///
    /// Returns the number of bytes appended; `Ok(0)` means the reader
    /// reported EOF.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::LimitExceeded`] when the buffer is already at
    /// `max_len` (note: also when `read_reserve` is configured as 0), or the
    /// reader's I/O error converted into a [`BufferError`].
    pub async fn read_available<R>(&mut self, reader: &mut R) -> Result<usize, BufferError>
    where
        R: AsyncRead + Unpin,
    {
        // Never request more than the limit allows, and chunk requests at
        // `read_reserve` so a single call can't balloon the allocation.
        let reserve = self.remaining_capacity().min(self.config.read_reserve);
        if reserve == 0 {
            return Err(BufferError::LimitExceeded {
                attempted: self.buf.len() + 1,
                limit: self.config.max_len,
            });
        }
        // Grow only when the existing spare capacity can't hold the request.
        if self.buf.capacity() - self.buf.len() < reserve {
            self.buf.reserve(reserve);
        }
        let len = self.buf.len();
        let read = poll_fn(|cx| {
            // Expose exactly `reserve` uninitialized spare bytes to the
            // reader; `ReadBuf` tracks how many of them get initialized.
            let spare = &mut self.buf.spare_capacity_mut()[..reserve];
            let mut read_buf = ReadBuf::uninit(spare);
            match Pin::new(&mut *reader).poll_read(cx, &mut read_buf) {
                std::task::Poll::Ready(Ok(())) => {
                    std::task::Poll::Ready(Ok(read_buf.filled().len()))
                }
                std::task::Poll::Ready(Err(err)) => std::task::Poll::Ready(Err(err)),
                std::task::Poll::Pending => std::task::Poll::Pending,
            }
        })
        .await?;
        // SAFETY: `poll_read` initialized exactly `read` bytes in the spare
        // capacity exposed through `ReadBuf`, so extending the length by
        // `read` only covers initialized memory.
        unsafe {
            self.buf.set_len(len + read);
        }
        Ok(read)
    }

    /// Splits off and returns the first `n` bytes as immutable `Bytes`.
    ///
    /// For a small prefix ahead of a large tail the bytes are copied out
    /// instead of split, so the returned `Bytes` does not share (and keep
    /// alive) the large backing allocation.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::SplitOutOfBounds`] if `n` exceeds the buffered
    /// length.
    pub fn split_prefix(&mut self, n: usize) -> Result<Bytes, BufferError> {
        if n > self.buf.len() {
            return Err(BufferError::SplitOutOfBounds {
                requested: n,
                available: self.buf.len(),
            });
        }
        if should_copy_prefix(n, self.buf.len() - n) {
            let prefix = Bytes::copy_from_slice(&self.buf[..n]);
            self.buf.advance(n);
            return Ok(prefix);
        }
        Ok(self.buf.split_to(n).freeze())
    }

    /// Splits off and returns the first `n` bytes as mutable `BytesMut`
    /// (no copy fast path; always splits the underlying buffer).
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::SplitOutOfBounds`] if `n` exceeds the buffered
    /// length.
    pub fn split_prefix_mut(&mut self, n: usize) -> Result<BytesMut, BufferError> {
        if n > self.buf.len() {
            return Err(BufferError::SplitOutOfBounds {
                requested: n,
                available: self.buf.len(),
            });
        }
        Ok(self.buf.split_to(n))
    }

    /// Removes and returns all buffered bytes as immutable `Bytes`, leaving
    /// the buffer empty.
    pub fn freeze_all(&mut self) -> Bytes {
        self.buf.split().freeze()
    }

    /// Removes and returns all buffered bytes as `BytesMut`, leaving the
    /// buffer empty. Pairs with [`HandoffBuffer::from_tail`] for handing
    /// state to a new owner.
    pub fn take_tail(&mut self) -> BytesMut {
        self.buf.split()
    }

    /// Discards the first `cnt` buffered bytes.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::SplitOutOfBounds`] if `cnt` exceeds the
    /// buffered length.
    pub fn advance(&mut self, cnt: usize) -> Result<(), BufferError> {
        if cnt > self.buf.len() {
            return Err(BufferError::SplitOutOfBounds {
                requested: cnt,
                available: self.buf.len(),
            });
        }
        self.buf.advance(cnt);
        Ok(())
    }

    // Bytes still allowed before hitting `max_len` (saturates at 0 if a
    // `from_tail` buffer somehow exceeds the limit).
    fn remaining_capacity(&self) -> usize {
        self.config.max_len.saturating_sub(self.buf.len())
    }

    // Errors if growing by `additional` would exceed `max_len`; saturating
    // add keeps the check sound for huge `additional` values.
    fn check_limit(&self, additional: usize) -> Result<(), BufferError> {
        let attempted = self.buf.len().saturating_add(additional);
        if attempted > self.config.max_len {
            return Err(BufferError::LimitExceeded {
                attempted,
                limit: self.config.max_len,
            });
        }
        Ok(())
    }
}
175
176fn should_copy_prefix(prefix_len: usize, remaining_len: usize) -> bool {
177    prefix_len <= SMALL_PREFIX_COPY_MAX && remaining_len >= SMALL_PREFIX_COPY_REMAINING_MIN
178}
179
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use tokio::io::AsyncWriteExt;

    use super::*;

    // End-to-end: read a chunk, split a newline-delimited frame, keep the
    // partial tail, read more, and drain everything.
    #[tokio::test]
    async fn reads_incrementally_and_preserves_tail() {
        let (mut client, mut server) = tokio::io::duplex(64);
        let mut buffer = HandoffBuffer::new(128);

        client
            .write_all(b"hello\npar")
            .await
            .expect("write to duplex");
        assert_eq!(
            buffer
                .read_available(&mut server)
                .await
                .expect("read first chunk"),
            9
        );

        let newline = buffer
            .peek()
            .iter()
            .position(|b| *b == b'\n')
            .expect("newline present");
        let frame = buffer.split_prefix(newline + 1).expect("split frame");
        assert_eq!(frame, Bytes::from_static(b"hello\n"));
        assert_eq!(buffer.peek(), b"par");

        client
            .write_all(b"tial\n")
            .await
            .expect("write second chunk");
        assert_eq!(
            buffer
                .read_available(&mut server)
                .await
                .expect("read second chunk"),
            5
        );
        assert_eq!(buffer.freeze_all(), Bytes::from_static(b"partial\n"));
    }

    // A full buffer must fail fast with LimitExceeded instead of polling the
    // reader again.
    #[tokio::test]
    async fn enforces_buffer_limit_before_reading_more() {
        let (mut client, mut server) = tokio::io::duplex(64);
        let mut buffer =
            HandoffBuffer::with_config(HandoffBufferConfig::new(4).with_read_reserve(4));

        client.write_all(b"abcd").await.expect("write within limit");
        assert_eq!(
            buffer
                .read_available(&mut server)
                .await
                .expect("read within limit"),
            4
        );

        let err = buffer
            .read_available(&mut server)
            .await
            .expect_err("buffer is full");
        assert!(matches!(
            err,
            BufferError::LimitExceeded {
                attempted: 5,
                limit: 4
            }
        ));
    }

    // take_tail → from_tail round-trip hands buffered state to a new owner.
    #[test]
    fn take_tail_moves_buffered_state() {
        let mut buffer = HandoffBuffer::new(64);
        buffer.buf.extend_from_slice(b"stateful bytes");

        let tail = buffer.take_tail();
        assert!(buffer.is_empty());
        assert_eq!(&tail[..], b"stateful bytes");

        let inherited =
            HandoffBuffer::from_tail(tail, HandoffBufferConfig::new(64)).expect("tail fits");
        assert_eq!(inherited.peek(), b"stateful bytes");
    }

    // Requesting more than is buffered must yield SplitOutOfBounds.
    #[test]
    fn split_prefix_checks_bounds() {
        let mut buffer = HandoffBuffer::new(64);
        buffer.buf.extend_from_slice(b"abc");

        let err = buffer.split_prefix(4).expect_err("prefix too large");
        assert!(matches!(
            err,
            BufferError::SplitOutOfBounds {
                requested: 4,
                available: 3
            }
        ));
    }

    // split_prefix_mut yields writable bytes and leaves the rest buffered.
    #[test]
    fn split_prefix_mut_returns_mutable_bytes_without_freezing() {
        let mut buffer = HandoffBuffer::new(64);
        buffer.buf.extend_from_slice(b"abcdef");

        let mut prefix = buffer.split_prefix_mut(3).expect("split prefix");
        prefix[0] = b'X';

        assert_eq!(&prefix[..], b"Xbc");
        assert_eq!(buffer.peek(), b"def");
    }

    // Exercises the copy fast path: small prefix, >=4 KiB tail remaining.
    #[test]
    fn split_prefix_copies_small_prefix_before_large_tail() {
        let mut buffer = HandoffBuffer::new(8 * 1024);
        buffer.buf.extend_from_slice(b"route\n");
        buffer.buf.extend_from_slice(&vec![b'x'; 4 * 1024]);

        let prefix = buffer.split_prefix(6).expect("split small prefix");

        assert_eq!(prefix, Bytes::from_static(b"route\n"));
        assert_eq!(buffer.len(), 4 * 1024);
    }
}
307}