1use bytes::{Buf, Bytes, BytesMut};
2use std::future::poll_fn;
3use std::pin::Pin;
4use tokio::io::{AsyncRead, ReadBuf};
5
6use crate::BufferError;
7
/// Longest prefix, in bytes (256), that `should_copy_prefix` still treats as
/// small enough to copy out of the buffer.
const SMALL_PREFIX_COPY_MAX: usize = 1 << 8;
/// Fewest bytes (4 KiB) that must remain behind the prefix before
/// `should_copy_prefix` selects the copy path.
const SMALL_PREFIX_COPY_REMAINING_MIN: usize = 1 << 12;
10
/// Tunable limits for a `HandoffBuffer`.
#[derive(Debug, Clone, Copy)]
pub struct HandoffBufferConfig {
    /// Maximum number of bytes the buffer is allowed to hold.
    pub max_len: usize,
    /// Upper bound on how many bytes a single read may request.
    pub read_reserve: usize,
}

impl HandoffBufferConfig {
    /// Read reserve applied by [`HandoffBufferConfig::new`] (16 KiB).
    const DEFAULT_READ_RESERVE: usize = 16 * 1024;

    /// Builds a configuration with the given size limit and the default
    /// read reserve of 16 KiB.
    pub fn new(max_len: usize) -> Self {
        let read_reserve = Self::DEFAULT_READ_RESERVE;
        Self {
            max_len,
            read_reserve,
        }
    }

    /// Returns a copy of this configuration with `read_reserve` replaced,
    /// leaving `max_len` untouched.
    pub fn with_read_reserve(self, read_reserve: usize) -> Self {
        Self {
            read_reserve,
            ..self
        }
    }
}
30
31#[derive(Debug)]
32pub struct HandoffBuffer {
33 buf: BytesMut,
34 config: HandoffBufferConfig,
35}
36
37impl HandoffBuffer {
38 pub fn new(max_len: usize) -> Self {
39 Self::with_config(HandoffBufferConfig::new(max_len))
40 }
41
42 pub fn with_config(config: HandoffBufferConfig) -> Self {
43 Self {
44 buf: BytesMut::new(),
45 config,
46 }
47 }
48
49 pub fn from_tail(tail: BytesMut, config: HandoffBufferConfig) -> Result<Self, BufferError> {
50 if tail.len() > config.max_len {
51 return Err(BufferError::LimitExceeded {
52 attempted: tail.len(),
53 limit: config.max_len,
54 });
55 }
56 Ok(Self { buf: tail, config })
57 }
58
59 pub fn len(&self) -> usize {
60 self.buf.len()
61 }
62
63 pub fn is_empty(&self) -> bool {
64 self.buf.is_empty()
65 }
66
67 pub fn capacity(&self) -> usize {
68 self.buf.capacity()
69 }
70
71 pub fn peek(&self) -> &[u8] {
72 &self.buf
73 }
74
75 pub fn reserve_read_capacity(&mut self, additional: usize) -> Result<(), BufferError> {
76 self.check_limit(additional)?;
77 self.buf.reserve(additional);
78 Ok(())
79 }
80
81 pub async fn read_available<R>(&mut self, reader: &mut R) -> Result<usize, BufferError>
82 where
83 R: AsyncRead + Unpin,
84 {
85 let reserve = self.remaining_capacity().min(self.config.read_reserve);
86 if reserve == 0 {
87 return Err(BufferError::LimitExceeded {
88 attempted: self.buf.len() + 1,
89 limit: self.config.max_len,
90 });
91 }
92 if self.buf.capacity() - self.buf.len() < reserve {
93 self.buf.reserve(reserve);
94 }
95 let len = self.buf.len();
96 let read = poll_fn(|cx| {
97 let spare = &mut self.buf.spare_capacity_mut()[..reserve];
98 let mut read_buf = ReadBuf::uninit(spare);
99 match Pin::new(&mut *reader).poll_read(cx, &mut read_buf) {
100 std::task::Poll::Ready(Ok(())) => {
101 std::task::Poll::Ready(Ok(read_buf.filled().len()))
102 }
103 std::task::Poll::Ready(Err(err)) => std::task::Poll::Ready(Err(err)),
104 std::task::Poll::Pending => std::task::Poll::Pending,
105 }
106 })
107 .await?;
108 unsafe {
111 self.buf.set_len(len + read);
112 }
113 Ok(read)
114 }
115
116 pub fn split_prefix(&mut self, n: usize) -> Result<Bytes, BufferError> {
117 if n > self.buf.len() {
118 return Err(BufferError::SplitOutOfBounds {
119 requested: n,
120 available: self.buf.len(),
121 });
122 }
123 if should_copy_prefix(n, self.buf.len() - n) {
124 let prefix = Bytes::copy_from_slice(&self.buf[..n]);
125 self.buf.advance(n);
126 return Ok(prefix);
127 }
128 Ok(self.buf.split_to(n).freeze())
129 }
130
131 pub fn split_prefix_mut(&mut self, n: usize) -> Result<BytesMut, BufferError> {
132 if n > self.buf.len() {
133 return Err(BufferError::SplitOutOfBounds {
134 requested: n,
135 available: self.buf.len(),
136 });
137 }
138 Ok(self.buf.split_to(n))
139 }
140
141 pub fn freeze_all(&mut self) -> Bytes {
142 self.buf.split().freeze()
143 }
144
145 pub fn take_tail(&mut self) -> BytesMut {
146 self.buf.split()
147 }
148
149 pub fn advance(&mut self, cnt: usize) -> Result<(), BufferError> {
150 if cnt > self.buf.len() {
151 return Err(BufferError::SplitOutOfBounds {
152 requested: cnt,
153 available: self.buf.len(),
154 });
155 }
156 self.buf.advance(cnt);
157 Ok(())
158 }
159
160 fn remaining_capacity(&self) -> usize {
161 self.config.max_len.saturating_sub(self.buf.len())
162 }
163
164 fn check_limit(&self, additional: usize) -> Result<(), BufferError> {
165 let attempted = self.buf.len().saturating_add(additional);
166 if attempted > self.config.max_len {
167 return Err(BufferError::LimitExceeded {
168 attempted,
169 limit: self.config.max_len,
170 });
171 }
172 Ok(())
173 }
174}
175
176fn should_copy_prefix(prefix_len: usize, remaining_len: usize) -> bool {
177 prefix_len <= SMALL_PREFIX_COPY_MAX && remaining_len >= SMALL_PREFIX_COPY_REMAINING_MIN
178}
179
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use tokio::io::AsyncWriteExt;

    use super::*;

    /// Two partial writes are read one at a time; splitting off the first
    /// frame must leave the unfinished remainder buffered for the next read.
    #[tokio::test]
    async fn reads_incrementally_and_preserves_tail() {
        let (mut writer, mut reader) = tokio::io::duplex(64);
        let mut handoff = HandoffBuffer::new(128);

        writer
            .write_all(b"hello\npar")
            .await
            .expect("write to duplex");
        let first_read = handoff
            .read_available(&mut reader)
            .await
            .expect("read first chunk");
        assert_eq!(first_read, 9);

        let newline_at = handoff
            .peek()
            .iter()
            .position(|byte| *byte == b'\n')
            .expect("newline present");
        let frame = handoff.split_prefix(newline_at + 1).expect("split frame");
        assert_eq!(frame, Bytes::from_static(b"hello\n"));
        assert_eq!(handoff.peek(), b"par");

        writer
            .write_all(b"tial\n")
            .await
            .expect("write second chunk");
        let second_read = handoff
            .read_available(&mut reader)
            .await
            .expect("read second chunk");
        assert_eq!(second_read, 5);
        assert_eq!(handoff.freeze_all(), Bytes::from_static(b"partial\n"));
    }

    /// Once the buffer holds `max_len` bytes, a further read must fail with
    /// `LimitExceeded` instead of touching the reader.
    #[tokio::test]
    async fn enforces_buffer_limit_before_reading_more() {
        let (mut writer, mut reader) = tokio::io::duplex(64);
        let config = HandoffBufferConfig::new(4).with_read_reserve(4);
        let mut handoff = HandoffBuffer::with_config(config);

        writer.write_all(b"abcd").await.expect("write within limit");
        let filled = handoff
            .read_available(&mut reader)
            .await
            .expect("read within limit");
        assert_eq!(filled, 4);

        let overflow = handoff
            .read_available(&mut reader)
            .await
            .expect_err("buffer is full");
        assert!(matches!(
            overflow,
            BufferError::LimitExceeded {
                attempted: 5,
                limit: 4
            }
        ));
    }

    /// `take_tail` empties the source buffer and the bytes survive a round
    /// trip through `from_tail`.
    #[test]
    fn take_tail_moves_buffered_state() {
        let mut source = HandoffBuffer::new(64);
        source.buf.extend_from_slice(b"stateful bytes");

        let moved = source.take_tail();
        assert!(source.is_empty());
        assert_eq!(&moved[..], b"stateful bytes");

        let successor =
            HandoffBuffer::from_tail(moved, HandoffBufferConfig::new(64)).expect("tail fits");
        assert_eq!(successor.peek(), b"stateful bytes");
    }

    /// Asking for more bytes than are buffered reports both the requested and
    /// the available amounts.
    #[test]
    fn split_prefix_checks_bounds() {
        let mut handoff = HandoffBuffer::new(64);
        handoff.buf.extend_from_slice(b"abc");

        let out_of_bounds = handoff.split_prefix(4).expect_err("prefix too large");
        assert!(matches!(
            out_of_bounds,
            BufferError::SplitOutOfBounds {
                requested: 4,
                available: 3
            }
        ));
    }

    /// The mutable split hands back writable bytes and leaves the rest of the
    /// buffer in place.
    #[test]
    fn split_prefix_mut_returns_mutable_bytes_without_freezing() {
        let mut handoff = HandoffBuffer::new(64);
        handoff.buf.extend_from_slice(b"abcdef");

        let mut head = handoff.split_prefix_mut(3).expect("split prefix");
        head[0] = b'X';

        assert_eq!(&head[..], b"Xbc");
        assert_eq!(handoff.peek(), b"def");
    }

    /// A short prefix in front of a 4 KiB remainder is returned intact and
    /// the remainder stays buffered.
    #[test]
    fn split_prefix_copies_small_prefix_before_large_tail() {
        let filler = [b'x'; 4 * 1024];
        let mut handoff = HandoffBuffer::new(8 * 1024);
        handoff.buf.extend_from_slice(b"route\n");
        handoff.buf.extend_from_slice(&filler);

        let head = handoff.split_prefix(6).expect("split small prefix");

        assert_eq!(head, Bytes::from_static(b"route\n"));
        assert_eq!(handoff.len(), 4 * 1024);
    }
}