polymock/
bytes.rs

use core::fmt::{self, Debug, Formatter};
use core::ops::Deref;
use core::ptr::NonNull;
use core::slice;
use core::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};

use alloc::boxed::Box;
use alloc::vec::Vec;

use crate::arena::{ChunkInner, ChunkRef};

/// A cheaply cloneable view into a slice of memory allocated by an [`Arena`].
///
/// `Bytes` is a mostly drop-in replacement for [`Bytes`] from the [`bytes`] crate and behaves the
/// same internally.
///
/// `Bytes` is expected to be deprecated once the `bytes` crate allows creating [`Bytes`] values
/// using a custom vtable.
///
/// [`Bytes`]: https://docs.rs/bytes/latest/bytes/struct.Bytes.html
/// [`bytes`]: https://docs.rs/bytes/latest/bytes/index.html
/// [`Arena`]: crate::Arena
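///
/// # Examples
///
/// A minimal usage sketch, using only APIs defined in this module and
/// assuming `Bytes` is re-exported at the crate root:
///
/// ```
/// use polymock::Bytes;
///
/// let a = Bytes::copy_from_slice(b"hello world");
///
/// // Cloning is cheap: both values share the same reference-counted buffer.
/// let b = a.clone();
/// assert_eq!(a, b);
///
/// // Truncation only shrinks the view; the buffer itself is untouched.
/// let mut c = a.clone();
/// c.truncate(5);
/// assert_eq!(c.as_ref(), b"hello");
/// ```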
pub struct Bytes {
    pub(crate) ptr: *const u8,
    pub(crate) len: usize,
    data: AtomicPtr<()>,
    vtable: &'static Vtable,
}

impl Bytes {
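    /// Creates a new `Bytes` by copying `data` into a new heap allocation.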
    pub fn copy_from_slice(data: &[u8]) -> Self {
        data.to_vec().into()
    }

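    /// Shortens the view to `len` bytes. This is a no-op if `len` is greater
    /// than or equal to the current length; the underlying buffer is never
    /// modified.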
    #[inline]
    pub fn truncate(&mut self, len: usize) {
        if self.len > len {
            self.len = len;
        }
    }

    /// Forces the length of the `Bytes` to `len`.
    ///
    /// Note that the length of `Bytes` is the same as the length of [`BytesMut`].
    ///
    /// # Safety
    ///
    /// `len` must not exceed the length that the `Bytes` was originally
    /// allocated with.
    ///
    /// [`BytesMut`]: crate::BytesMut
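    ///
    /// # Examples
    ///
    /// A sketch, assuming `Bytes` is re-exported at the crate root:
    ///
    /// ```
    /// use polymock::Bytes;
    ///
    /// let mut b = Bytes::copy_from_slice(b"hello");
    /// // SAFETY: 3 does not exceed the original length of 5.
    /// unsafe { b.set_len(3) };
    /// assert_eq!(b.as_ref(), b"hel");
    /// ```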
    #[inline]
    pub unsafe fn set_len(&mut self, len: usize) {
        self.len = len;
    }

    #[inline]
    pub(crate) unsafe fn from_raw_parts(chunk: ChunkRef, ptr: NonNull<u8>, len: usize) -> Self {
        let data = AtomicPtr::new(chunk.into_raw() as *mut ());

        Bytes {
            ptr: ptr.as_ptr(),
            len,
            data,
            vtable: &CHUNK_VTABLE,
        }
    }

    #[inline]
    fn as_slice(&self) -> &[u8] {
        unsafe { slice::from_raw_parts(self.ptr, self.len) }
    }
}

impl Clone for Bytes {
    #[inline]
    fn clone(&self) -> Self {
        unsafe { (self.vtable.clone)(&self.data, self.ptr, self.len) }
    }
}

impl Drop for Bytes {
    #[inline]
    fn drop(&mut self) {
        unsafe { (self.vtable.drop)(&self.data, self.ptr, self.len) };
    }
}

impl From<Bytes> for Vec<u8> {
    #[inline]
    fn from(value: Bytes) -> Self {
        // `to_vec` already releases the reference held by `value`, so `Drop`
        // must not run again afterwards.
        let value = core::mem::ManuallyDrop::new(value);
        unsafe { (value.vtable.to_vec)(&value.data, value.ptr, value.len) }
    }
}

impl Deref for Bytes {
    type Target = [u8];

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl AsRef<[u8]> for Bytes {
    #[inline]
    fn as_ref(&self) -> &[u8] {
        self.as_slice()
    }
}

impl Debug for Bytes {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("Bytes")
            .field("buffer", &self.as_slice())
            .finish()
    }
}

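// SAFETY: `Bytes` never mutates the buffer it points to, and the shared
// ownership state behind `data` is managed with atomic reference counts.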
unsafe impl Send for Bytes {}
unsafe impl Sync for Bytes {}

impl PartialEq for Bytes {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Both buffers are equal when they point to the same memory.
        if self.ptr == other.ptr && self.len == other.len {
            true
        } else {
            self.as_slice() == other.as_slice()
        }
    }
}

impl PartialEq<[u8]> for Bytes {
    #[inline]
    fn eq(&self, other: &[u8]) -> bool {
        self.as_slice() == other
    }
}

impl Eq for Bytes {}

impl PartialOrd for Bytes {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Bytes {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.as_slice().cmp(other.as_slice())
    }
}

impl PartialOrd<[u8]> for Bytes {
    fn partial_cmp(&self, other: &[u8]) -> Option<core::cmp::Ordering> {
        Some(self.as_slice().cmp(other))
    }
}

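/// Manual vtable dispatch, mirroring how the `bytes` crate works internally:
/// every `Bytes` carries a pointer to the vtable matching its backing storage,
/// either a reference-counted heap slice (`ArcSlice`) or an arena chunk
/// (`ChunkInner`).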
pub(crate) struct Vtable {
    /// fn(data, ptr, len)
    clone: unsafe fn(&AtomicPtr<()>, *const u8, usize) -> Bytes,
    /// fn(data, ptr, len)
    to_vec: unsafe fn(&AtomicPtr<()>, *const u8, usize) -> Vec<u8>,
    /// fn(data, ptr, len)
    drop: unsafe fn(&AtomicPtr<()>, *const u8, usize),
}

// === impl ARC_SLICE ===

/// A reference-counted, heap-allocated slice, analogous to `Arc<[u8]>`.
struct ArcSlice {
    ptr: *mut [u8],
    ref_count: AtomicUsize,
}

impl Drop for ArcSlice {
    fn drop(&mut self) {
        unsafe {
            drop(Box::from_raw(self.ptr));
        }
    }
}

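// Vtable for `Bytes` values backed by a leaked, reference-counted `ArcSlice`.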
static ARC_VTABLE: Vtable = Vtable {
    clone: arc_clone,
    to_vec: arc_to_vec,
    drop: arc_drop,
};

fn arc_clone(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> Bytes {
    let inner = data.load(Ordering::Relaxed);

    let slice = unsafe { &*(inner as *mut ArcSlice) };

    // Guard against reference count overflow, like `Arc` does.
    if slice.ref_count.fetch_add(1, Ordering::Relaxed) > usize::MAX >> 1 {
        crate::abort();
    }

    Bytes {
        ptr,
        len,
        data: AtomicPtr::new(inner),
        vtable: &ARC_VTABLE,
    }
}

fn arc_to_vec(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> Vec<u8> {
    // FIXME: The last Bytes can be optimized by taking ownership of the buffer.
    let vec = unsafe { core::slice::from_raw_parts(ptr, len).to_vec() };
    arc_drop(data, ptr, len);
    vec
}

fn arc_drop(data: &AtomicPtr<()>, _ptr: *const u8, _len: usize) {
    let inner = data.load(Ordering::Relaxed) as *mut ArcSlice;
    let slice = unsafe { &*inner };

    // Release on the decrement makes all prior uses of the buffer by this
    // handle visible to the thread that frees it (the same scheme as `Arc`).
    let rc = slice.ref_count.fetch_sub(1, Ordering::Release);
    if rc != 1 {
        return;
    }

    // The Acquire load pairs with the Release decrements of all other handles
    // before the allocation is freed.
    slice.ref_count.load(Ordering::Acquire);

    unsafe {
        drop(Box::from_raw(inner));
    }
}

impl From<Box<[u8]>> for Bytes {
    fn from(value: Box<[u8]>) -> Self {
        let slice_len = value.len();
        let slice_ptr = value.as_ptr();
        let ptr = Box::into_raw(value);

        // Leak the `ArcSlice` with an initial reference count of 1; it is
        // reclaimed in `arc_drop` once the last handle goes away.
        let data = Box::into_raw(Box::new(ArcSlice {
            ptr,
            ref_count: AtomicUsize::new(1),
        }));

        Bytes {
            ptr: slice_ptr,
            len: slice_len,
            data: AtomicPtr::new(data as *mut ()),
            vtable: &ARC_VTABLE,
        }
    }
}

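/// Conversion from an owned buffer. A round trip back to `Vec<u8>` goes
/// through the `to_vec` vtable entry and copies the contents. A small sketch,
/// assuming `Bytes` is re-exported at the crate root:
///
/// ```
/// use polymock::Bytes;
///
/// let bytes = Bytes::from(vec![1u8, 2, 3]);
/// let v: Vec<u8> = bytes.into();
/// assert_eq!(v, [1, 2, 3]);
/// ```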
impl From<Vec<u8>> for Bytes {
    fn from(value: Vec<u8>) -> Self {
        value.into_boxed_slice().into()
    }
}

// === impl CHUNK ===

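// Vtable for `Bytes` values that borrow their memory from an arena chunk.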
static CHUNK_VTABLE: Vtable = Vtable {
    clone: chunk_clone,
    to_vec: chunk_to_vec,
    drop: chunk_drop,
};

fn chunk_clone(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> Bytes {
    let data_ptr = data.load(Ordering::Relaxed);
    let chunk = unsafe { &*(data_ptr as *mut ChunkInner) };

    chunk.increment_reference_count();

    Bytes {
        ptr,
        len,
        data: AtomicPtr::new(data_ptr),
        vtable: &CHUNK_VTABLE,
    }
}

fn chunk_to_vec(data: &AtomicPtr<()>, ptr: *const u8, len: usize) -> Vec<u8> {
    // Note that the whole buffer always needs to be copied. The chunk is still
    // owned by the arena.
    let buf = unsafe { core::slice::from_raw_parts(ptr, len).to_vec() };
    chunk_drop(data, ptr, len);
    buf
}

fn chunk_drop(data: &AtomicPtr<()>, _ptr: *const u8, _len: usize) {
    let chunk = data.load(Ordering::Relaxed) as *mut ChunkInner;
    let chunk_ref = unsafe { &*chunk };

    // Release/Acquire reference counting, following the same scheme as
    // `arc_drop` above.
    let old_rc = chunk_ref.ref_count.fetch_sub(1, Ordering::Release);
    if old_rc != 1 {
        return;
    }

    chunk_ref.ref_count.load(Ordering::Acquire);

    // Take ownership of the chunk.
    // SAFETY: The chunk was leaked when it was first created. Reaching this
    // point means the last reference to it was just dropped.
    unsafe {
        drop(ChunkRef::from_ptr(chunk));
    }
}

#[cfg(all(not(loom), test))]
mod tests {
    use std::sync::atomic::Ordering;

    use super::Bytes;
    use crate::arena::ChunkRef;

    #[test]
    fn test_bytes() {
        let chunk = ChunkRef::new(1000);
        let ptr = chunk.alloc(100).unwrap();

        assert_eq!(chunk.ref_count.load(Ordering::Relaxed), 1);

        let bytes = unsafe { Bytes::from_raw_parts(chunk.clone(), ptr, 100) };

        assert_eq!(bytes.len, 100);

        assert_eq!(chunk.ref_count.load(Ordering::Relaxed), 2);

        let bytes2 = bytes.clone();
        assert_eq!(bytes.ptr, bytes2.ptr);
        assert_eq!(bytes.len, bytes2.len);

        assert_eq!(chunk.ref_count.load(Ordering::Relaxed), 3);

        drop(bytes);
        assert_eq!(chunk.ref_count.load(Ordering::Relaxed), 2);

        drop(bytes2);
        assert_eq!(chunk.ref_count.load(Ordering::Relaxed), 1);
    }
}

#[cfg(all(test, loom))]
mod loom_tests {
    use std::vec::Vec;

    use loom::sync::atomic::Ordering;
    use loom::thread;

    use super::Bytes;
    use crate::arena::ChunkRef;

    const THREADS: usize = 2;
    const ITERATIONS: usize = 5;

    #[test]
    fn test_bytes() {
        loom::model(|| {
            let chunk = ChunkRef::new(1000);
            let ptr = chunk.alloc(100).unwrap();
            let bytes = unsafe { Bytes::from_raw_parts(chunk.clone(), ptr, 100) };

            let threads: Vec<_> = (0..THREADS)
                .map(|_| {
                    let bytes = bytes.clone();

                    thread::spawn(move || {
                        let mut bufs = Vec::with_capacity(ITERATIONS);

                        for _ in 0..ITERATIONS {
                            bufs.push(bytes.clone());
                        }

                        for buf in bufs.into_iter() {
                            drop(buf);
                        }
                    })
                })
                .collect();

            for th in threads {
                th.join().unwrap();
            }

            // Only `chunk` and `bytes` are left holding references here.
            assert_eq!(chunk.ref_count.load(Ordering::Relaxed), 2);
        });
    }
}