linkedbytes/
lib.rs

//! [`LinkedBytes`] is a linked list of [`Bytes`] and [`BytesMut`] (though it is currently
//! implemented with a `VecDeque`).
//!
//! It is primarily used to manage [`Bytes`] and [`BytesMut`] and to build a `&[IoSlice<'_>]`
//! that can be passed to `writev`.
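//!
//! A minimal usage sketch (not from the original docs; it assumes the crate is consumed
//! as `linkedbytes` and that any Tokio `AsyncWrite + Unpin` writer is available):
//!
//! ```no_run
//! # async fn demo() -> std::io::Result<()> {
//! use bytes::{BufMut, Bytes};
//! use faststr::FastStr;
//! use linkedbytes::LinkedBytes;
//!
//! let mut lb = LinkedBytes::new();
//! lb.put_slice(b"length-prefix ");           // copied into the active BytesMut
//! lb.insert(Bytes::from_static(b"payload")); // linked in without copying
//! lb.insert_faststr(FastStr::new(" trailer"));
//!
//! // Any `AsyncWrite + Unpin` writer works; a `TcpStream` is the typical target.
//! let mut writer = tokio::io::sink();
//! lb.write_all_vectored(&mut writer).await?;
//! lb.reset(); // reclaim the split-off buffers for reuse
//! # Ok(())
//! # }
//! ```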
use std::{collections::VecDeque, io::IoSlice};

use bytes::{BufMut, Bytes, BytesMut};
use faststr::FastStr;
use tokio::io::{AsyncWrite, AsyncWriteExt};

const DEFAULT_BUFFER_SIZE: usize = 8192; // 8KB
const DEFAULT_DEQUE_SIZE: usize = 16;

pub struct LinkedBytes {
    // This is used to avoid allocating a new Vec every time `write_all_vectored`
    // (or `sync_write_all_vectored`) is called.
    // It is self-referential in fact, but we can guarantee that it is safe,
    // so we just use `'static` here.
    // `ioslice` must be the first field, so that it is dropped before `list`
    // and `bytes` to keep soundness.
    ioslice: Vec<IoSlice<'static>>,

    bytes: BytesMut,
    list: VecDeque<Node>,
}

/// A single segment held by [`LinkedBytes`].
pub enum Node {
    Bytes(Bytes),
    BytesMut(BytesMut),
    FastStr(FastStr),
}

impl AsRef<[u8]> for Node {
    #[inline]
    fn as_ref(&self) -> &[u8] {
        match self {
            Node::Bytes(b) => b.as_ref(),
            Node::BytesMut(b) => b.as_ref(),
            Node::FastStr(s) => s.as_ref(),
        }
    }
}

impl LinkedBytes {
    #[inline]
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_BUFFER_SIZE)
    }

    #[inline]
    pub fn with_capacity(cap: usize) -> Self {
        let bytes = BytesMut::with_capacity(cap);
        let list = VecDeque::with_capacity(DEFAULT_DEQUE_SIZE);
        Self {
            list,
            bytes,
            ioslice: Vec::with_capacity(DEFAULT_DEQUE_SIZE),
        }
    }

    #[inline]
    pub const fn bytes(&self) -> &BytesMut {
        &self.bytes
    }

    #[inline]
    pub const fn bytes_mut(&mut self) -> &mut BytesMut {
        &mut self.bytes
    }

    #[inline]
    pub fn reserve(&mut self, additional: usize) {
        self.bytes.reserve(additional);
    }

    pub fn len(&self) -> usize {
        let mut len = 0;
        for node in self.list.iter() {
            len += node.as_ref().len();
        }
        len + self.bytes.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Appends `bytes` as a new node without copying its contents.
    ///
    /// The data already written to the active buffer is split off and pushed
    /// as a node first, so the overall byte order is preserved.
    pub fn insert(&mut self, bytes: Bytes) {
        let node = Node::Bytes(bytes);
        // split current bytes
        let prev = self.bytes.split();

        self.list.push_back(Node::BytesMut(prev));
        self.list.push_back(node);
    }

    /// Appends `fast_str` as a new node without copying its contents.
    ///
    /// Like [`insert`](Self::insert), the active buffer is split off first to
    /// preserve byte order.
    pub fn insert_faststr(&mut self, fast_str: FastStr) {
        let node = Node::FastStr(fast_str);
        // split current bytes
        let prev = self.bytes.split();

        self.list.push_back(Node::BytesMut(prev));
        self.list.push_back(node);
    }

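    /// Collects the non-empty segments, plus the active buffer, into a fresh
    /// `Vec<IoSlice>` borrowing from `self`, suitable for a vectored write.
    ///
    /// A small sketch (not from the original docs; assumes the crate is consumed
    /// as `linkedbytes` and uses `Vec<u8>` as a stand-in `std::io::Write` target):
    ///
    /// ```
    /// use std::io::Write;
    ///
    /// use bytes::Bytes;
    /// use linkedbytes::LinkedBytes;
    ///
    /// let mut lb = LinkedBytes::new();
    /// lb.insert(Bytes::from_static(b"hello"));
    /// let slices = lb.io_slice();
    ///
    /// let mut out: Vec<u8> = Vec::new();
    /// out.write_vectored(&slices).unwrap();
    /// ```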
    pub fn io_slice(&self) -> Vec<IoSlice<'_>> {
        let mut ioslice = Vec::with_capacity(self.list.len() + 1);
        for node in self.list.iter() {
            let bytes = node.as_ref();
            if bytes.is_empty() {
                continue;
            }
            ioslice.push(IoSlice::new(bytes));
        }
        if !self.bytes.is_empty() {
            ioslice.push(IoSlice::new(self.bytes.as_ref()));
        }
        ioslice
    }

    /// Writes every segment to `writer` using vectored writes, advancing across
    /// partial writes until all bytes have been flushed.
    ///
    /// Returns a [`WriteZero`](std::io::ErrorKind::WriteZero) error if the writer
    /// stops accepting data.
    // TODO: use write_all_vectored when stable
    pub async fn write_all_vectored<W: AsyncWrite + Unpin>(
        &mut self,
        writer: &mut W,
    ) -> std::io::Result<()> {
        self.ioslice.clear();
        self.ioslice.reserve(self.list.len() + 1);
        // prepare ioslice
        for node in self.list.iter() {
            let bytes = node.as_ref();
            if bytes.is_empty() {
                continue;
            }
            // SAFETY: we can guarantee that the lifetime of `bytes` can't outlive self
            self.ioslice
                .push(IoSlice::new(unsafe { &*(bytes as *const _) }));
        }
        if !self.bytes.is_empty() {
            self.ioslice
                .push(IoSlice::new(unsafe { &*(self.bytes.as_ref() as *const _) }));
        }

        // do write_all_vectored
        // we store the pointer as a usize so the future stays `Send` (`*mut IoSlice` is not `Send`)
        let (mut base_ptr, mut len) = (self.ioslice.as_mut_ptr() as usize, self.ioslice.len());
        while len != 0 {
            let ioslice = unsafe { std::slice::from_raw_parts(base_ptr as *mut IoSlice, len) };
            let n = writer.write_vectored(ioslice).await?;
            if n == 0 {
                return Err(std::io::ErrorKind::WriteZero.into());
            }
            // Number of buffers to remove.
            let mut remove = 0;
            // Total length of all the buffers to be removed.
            let mut accumulated_len = 0;
            for buf in ioslice.iter() {
                if accumulated_len + buf.len() > n {
                    break;
                } else {
                    accumulated_len += buf.len();
                    remove += 1;
                }
            }

            // adjust the outer [IoSlice]
            base_ptr = unsafe { (base_ptr as *mut IoSlice).add(remove) as usize };
            len -= remove;
            if len == 0 {
                assert!(
                    n == accumulated_len,
                    "advancing io slices beyond their length"
                );
            } else {
                // adjust the inner IoSlice
                let inner_slice = unsafe { &mut *(base_ptr as *mut IoSlice) };
                let (inner_ptr, inner_len) = (inner_slice.as_ptr(), inner_slice.len());
                let remaining = n - accumulated_len;
                assert!(
                    remaining <= inner_len,
                    "advancing io slice beyond its length"
                );
                let new_ptr = unsafe { inner_ptr.add(remaining) };
                let new_len = inner_len - remaining;
                *inner_slice =
                    IoSlice::new(unsafe { std::slice::from_raw_parts(new_ptr, new_len) });
            }
        }
        self.ioslice.clear();
        Ok(())
    }

    /// Synchronous counterpart of [`write_all_vectored`](Self::write_all_vectored)
    /// for [`std::io::Write`] writers.
    // TODO: use write_all_vectored when stable
    pub fn sync_write_all_vectored<W: std::io::Write>(
        &mut self,
        writer: &mut W,
    ) -> std::io::Result<()> {
        self.ioslice.clear();
        self.ioslice.reserve(self.list.len() + 1);
        // prepare ioslice
        for node in self.list.iter() {
            let bytes = node.as_ref();
            if bytes.is_empty() {
                continue;
            }
            // SAFETY: we can guarantee that the lifetime of `bytes` can't outlive self
            self.ioslice
                .push(IoSlice::new(unsafe { &*(bytes as *const _) }));
        }
        if !self.bytes.is_empty() {
            self.ioslice
                .push(IoSlice::new(unsafe { &*(self.bytes.as_ref() as *const _) }));
        }

        // do write_all_vectored
        let (mut base_ptr, mut len) = (self.ioslice.as_mut_ptr(), self.ioslice.len());
        while len != 0 {
            let ioslice = unsafe { std::slice::from_raw_parts(base_ptr, len) };
            let n = writer.write_vectored(ioslice)?;
            if n == 0 {
                return Err(std::io::ErrorKind::WriteZero.into());
            }
            // Number of buffers to remove.
            let mut remove = 0;
            // Total length of all the buffers to be removed.
            let mut accumulated_len = 0;
            for buf in ioslice.iter() {
                if accumulated_len + buf.len() > n {
                    break;
                } else {
                    accumulated_len += buf.len();
                    remove += 1;
                }
            }

            // adjust the outer [IoSlice]
            base_ptr = unsafe { base_ptr.add(remove) };
            len -= remove;
            if len == 0 {
                assert!(
                    n == accumulated_len,
                    "advancing io slices beyond their length"
                );
            } else {
                // adjust the inner IoSlice
                let inner_slice = unsafe { &mut *base_ptr };
                let (inner_ptr, inner_len) = (inner_slice.as_ptr(), inner_slice.len());
                let remaining = n - accumulated_len;
                assert!(
                    remaining <= inner_len,
                    "advancing io slice beyond its length"
                );
                let new_ptr = unsafe { inner_ptr.add(remaining) };
                let new_len = inner_len - remaining;
                *inner_slice =
                    IoSlice::new(unsafe { std::slice::from_raw_parts(new_ptr, new_len) });
            }
        }
        self.ioslice.clear();
        Ok(())
    }

    /// Resets the buffer for reuse: the split-off `BytesMut` nodes are merged back
    /// into a single buffer (any `Bytes`/`FastStr` nodes are dropped), and the
    /// resulting buffer is cleared.
    pub fn reset(&mut self) {
        // ioslice must be cleared before list
        self.ioslice.clear();

        if self.list.is_empty() {
            // only clear bytes
            self.bytes.clear();
            return;
        }

        let Node::BytesMut(mut head) = self.list.pop_front().unwrap() else {
            // this should not happen
            panic!("head is not BytesMut");
        };

        while let Some(node) = self.list.pop_front() {
            if let Node::BytesMut(next_buf) = node {
                head.unsplit(next_buf);
            }
        }

        // don't forget to unsplit self.bytes
        // we have to do this in a slightly tricky way, because we can't move `self.bytes` out
        unsafe {
            self.bytes.set_len(self.bytes.capacity());
        }
        let remaining = self.bytes.split();
        head.unsplit(remaining);
        self.bytes = head;

        self.bytes.clear();
    }
}

// Unstable APIs
impl LinkedBytes {
    /// Splits the current `bytes_mut` and pushes it to the list.
    /// This is an unstable API that may change in the future, don't rely on this.
    /// Returns the index of the node.
    #[doc(hidden)]
    #[inline]
    pub fn split(&mut self) -> usize {
        let prev = self.bytes.split();
        let node = Node::BytesMut(prev);
        self.list.push_back(node);
        self.list.len() - 1
    }

    /// Gets the node at the given index.
    /// If you want to get the current bytes_mut, use `bytes_mut()` instead.
    /// This is an unstable API that may change in the future, don't rely on this.
    #[doc(hidden)]
    #[inline]
    pub fn get_list_mut(&mut self, index: usize) -> Option<&mut Node> {
        self.list.get_mut(index)
    }

    /// Returns an iterator over the nodes in the list.
    /// This is an unstable API that may change in the future, don't rely on this.
    #[doc(hidden)]
    #[inline]
    pub fn iter_list(&self) -> impl Iterator<Item = &Node> {
        self.list.iter()
    }

    /// Consumes `self` and converts the list (including the current `bytes`) into an iterator.
    /// This is an unstable API that may change in the future, don't rely on this.
    #[doc(hidden)]
    #[inline]
    pub fn into_iter_list(mut self) -> impl Iterator<Item = Node> {
        let node = Node::BytesMut(self.bytes);
        self.list.push_back(node);
        self.list.into_iter()
    }
}

impl Default for LinkedBytes {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

unsafe impl BufMut for LinkedBytes {
    #[inline]
    fn remaining_mut(&self) -> usize {
        self.bytes.remaining_mut()
    }

    #[inline]
    unsafe fn advance_mut(&mut self, cnt: usize) {
        self.bytes.advance_mut(cnt)
    }

    #[inline]
    fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
        self.bytes.chunk_mut()
    }
}
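
// A minimal test sketch, not part of the original file: it exercises the
// IoSlice-advancing logic in `sync_write_all_vectored` with a hypothetical
// `ChunkedWriter` that accepts only a few bytes per call, forcing partial
// vectored writes.
#[cfg(test)]
mod tests {
    use super::*;

    /// Writer that accepts at most `limit` bytes per call.
    struct ChunkedWriter {
        out: Vec<u8>,
        limit: usize,
    }

    impl std::io::Write for ChunkedWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let n = buf.len().min(self.limit);
            self.out.extend_from_slice(&buf[..n]);
            Ok(n)
        }

        fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
            let mut written = 0;
            for buf in bufs {
                if written == self.limit {
                    break;
                }
                let n = (self.limit - written).min(buf.len());
                self.out.extend_from_slice(&buf[..n]);
                written += n;
                if n < buf.len() {
                    break;
                }
            }
            Ok(written)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn partial_vectored_writes_are_completed() {
        let mut lb = LinkedBytes::new();
        lb.put_slice(b"hello ");
        lb.insert(Bytes::from_static(b"linked "));
        lb.insert_faststr(FastStr::new("bytes"));

        let mut writer = ChunkedWriter {
            out: Vec::new(),
            limit: 3,
        };
        lb.sync_write_all_vectored(&mut writer).unwrap();
        assert_eq!(writer.out, b"hello linked bytes".to_vec());

        // After a reset the buffer can be reused.
        lb.reset();
        assert!(lb.is_empty());
    }
}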