libdd_tinybytes/lib.rs

// Copyright 2024-Present Datadog, Inc. https://www.datadoghq.com/
// SPDX-License-Identifier: Apache-2.0

#![cfg_attr(not(test), deny(clippy::panic))]
#![cfg_attr(not(test), deny(clippy::unwrap_used))]
#![cfg_attr(not(test), deny(clippy::expect_used))]
#![cfg_attr(not(test), deny(clippy::todo))]
#![cfg_attr(not(test), deny(clippy::unimplemented))]

use std::{
    borrow, cmp, fmt, hash,
    ops::{self, RangeBounds},
    ptr::NonNull,
    sync::atomic::AtomicUsize,
};

#[cfg(feature = "serde")]
use serde::Serialize;

/// Immutable bytes type with zero copy cloning and slicing.
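///
/// # Examples
///
/// A minimal sketch of zero copy cloning and slicing (both only create new handles to the same
/// underlying buffer):
///
/// ```
/// use libdd_tinybytes::Bytes;
///
/// let bytes = Bytes::copy_from_slice(b"hello world");
///
/// // Cloning does not copy the data, it only bumps the reference count.
/// let clone = bytes.clone();
/// assert_eq!(clone, bytes);
///
/// // Slicing returns a new `Bytes` handle into the same buffer.
/// let hello = bytes.slice(0..5);
/// assert_eq!(hello.as_ref(), b"hello");
/// ```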
#[derive(Clone)]
pub struct Bytes {
    ptr: NonNull<u8>,
    len: usize,
    // The `bytes` field is used to ensure that the underlying bytes are freed when there are no
    // more references to the `Bytes` object. For static buffers the field is `None`.
    bytes: Option<RefCountedCell>,
}

/// The underlying bytes that the `Bytes` object references.
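///
/// # Examples
///
/// A sketch of plugging in a custom buffer type; `Chunk` is a hypothetical example type, and any
/// `AsRef<[u8]> + Send + Sync + 'static` owner of bytes can be used the same way:
///
/// ```
/// use libdd_tinybytes::{Bytes, UnderlyingBytes};
///
/// // Hypothetical owner of a byte buffer.
/// struct Chunk(Vec<u8>);
///
/// impl AsRef<[u8]> for Chunk {
///     fn as_ref(&self) -> &[u8] {
///         &self.0
///     }
/// }
///
/// impl UnderlyingBytes for Chunk {}
///
/// let bytes = Bytes::from_underlying(Chunk(b"custom buffer".to_vec()));
/// assert_eq!(bytes.as_ref(), b"custom buffer");
/// ```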
pub trait UnderlyingBytes: AsRef<[u8]> + Send + Sync + 'static {}

/// Since the `Bytes` type is immutable, and `UnderlyingBytes` is `Send + Sync`, it is safe to share
/// `Bytes` across threads.
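///
/// A minimal sketch of sharing the same buffer with another thread:
///
/// ```
/// use libdd_tinybytes::Bytes;
///
/// let bytes = Bytes::copy_from_slice(b"shared");
/// let clone = bytes.clone();
/// let handle = std::thread::spawn(move || {
///     assert_eq!(clone.as_ref(), b"shared");
/// });
/// handle.join().unwrap();
/// assert_eq!(bytes.as_ref(), b"shared");
/// ```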
unsafe impl Send for Bytes {}
unsafe impl Sync for Bytes {}

impl Bytes {
    #[inline]
    /// Creates a new `Bytes` from the given slice data and the refcount
    ///
    /// # Safety
    ///
    /// * the pointer should be valid for the given length
    /// * the pointer should be valid for reads as long as the refcount, or any clone of it, has
    ///   not been dropped
    pub const unsafe fn from_raw_refcount(
        ptr: NonNull<u8>,
        len: usize,
        refcount: RefCountedCell,
    ) -> Self {
        Self {
            ptr,
            len,
            bytes: Some(refcount),
        }
    }

    /// Creates empty `Bytes`.
    #[inline]
    pub const fn empty() -> Self {
        Self::from_static(b"")
    }

    /// Creates `Bytes` from a static slice.
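    ///
    /// # Examples
    ///
    /// A minimal sketch; static data is referenced directly, without allocation or refcounting:
    ///
    /// ```
    /// use libdd_tinybytes::Bytes;
    ///
    /// static GREETING: &[u8] = b"hello";
    ///
    /// let bytes = Bytes::from_static(GREETING);
    /// assert_eq!(bytes.as_ref(), b"hello");
    /// ```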
    #[inline]
    pub const fn from_static(value: &'static [u8]) -> Self {
        Self {
69            // SAFETY: static slice always have a valid pointer and length
70            ptr: unsafe { NonNull::new_unchecked(value.as_ptr().cast_mut()) },
71            len: value.len(),
72            bytes: None,
73        }
74    }
75
76    /// Creates `Bytes` from a slice, by copying.
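    ///
    /// # Examples
    ///
    /// A minimal sketch; the slice is copied into a new owned buffer:
    ///
    /// ```
    /// use libdd_tinybytes::Bytes;
    ///
    /// let bytes = Bytes::copy_from_slice(b"hello");
    /// assert_eq!(bytes.as_ref(), b"hello");
    /// ```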
    pub fn copy_from_slice(data: &[u8]) -> Self {
        Self::from_underlying(data.to_vec())
    }

    /// Returns the length of the `Bytes`.
    #[inline]
    pub const fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the `Bytes` is empty.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns a slice of self for the provided range.
    ///
    /// This will return a new `Bytes` handle set to the slice, and will not copy the underlying
    /// data.
    ///
    /// This operation is `O(1)`.
    ///
    /// # Panics
    ///
    /// Slicing will panic if the range does not conform to `start <= end` and `end <= self.len()`.
    ///
    /// # Examples
    ///
    /// ```
    /// use libdd_tinybytes::Bytes;
    ///
    /// let bytes = Bytes::copy_from_slice(b"hello world");
    /// let slice = bytes.slice(0..5);
    /// assert_eq!(slice.as_ref(), b"hello");
    ///
    /// let slice = bytes.slice(6..11);
    /// assert_eq!(slice.as_ref(), b"world");
    /// ```
    pub fn slice(&self, range: impl RangeBounds<usize>) -> Self {
        use std::ops::Bound;

        let len = self.len();

        #[allow(clippy::expect_used)]
        let start = match range.start_bound() {
            Bound::Included(&n) => n,
            Bound::Excluded(&n) => n.checked_add(1).expect("range start overflow"),
            Bound::Unbounded => 0,
        };

        #[allow(clippy::expect_used)]
        let end = match range.end_bound() {
            Bound::Included(&n) => n.checked_add(1).expect("range end overflow"),
            Bound::Excluded(&n) => n,
            Bound::Unbounded => len,
        };

        assert!(
            start <= end,
            "range start must not be greater than end: {start:?} > {end:?}"
        );
        assert!(
            end <= len,
            "range end must not be greater than length: {end:?} > {len:?}"
        );

        if end == start {
            Bytes::empty()
        } else {
            self.safe_slice_ref(start, end)
        }
    }

    /// Returns a slice of self that is equivalent to the given `subset`, if it is a subset.
    ///
    /// When processing a `Bytes` buffer with other tools, one often gets a
    /// `&[u8]` which is in fact a slice of the `Bytes`, i.e. a subset of it.
    /// This function turns that `&[u8]` into another `Bytes`, as if one had
    /// called `self.slice()` with the range that corresponds to `subset`.
    ///
    /// This operation is `O(1)`.
    ///
    /// # Examples
    ///
    /// ```
    /// use libdd_tinybytes::Bytes;
    ///
    /// let bytes = Bytes::copy_from_slice(b"hello world");
    /// let subset = &bytes.as_ref()[0..5];
    /// let slice = bytes.slice_ref(subset).unwrap();
    /// assert_eq!(slice.as_ref(), b"hello");
    ///
    /// let subset = &bytes.as_ref()[6..11];
    /// let slice = bytes.slice_ref(subset).unwrap();
    /// assert_eq!(slice.as_ref(), b"world");
    ///
    /// let invalid_subset = b"invalid";
    /// assert!(bytes.slice_ref(invalid_subset).is_none());
    /// ```
    pub fn slice_ref(&self, subset: &[u8]) -> Option<Bytes> {
        // An empty slice can be a subset of any slice.
        if subset.is_empty() {
            return Some(Bytes::empty());
        }

        let subset_start = subset.as_ptr() as usize;
        let subset_end = subset_start + subset.len();
        let self_start = self.ptr.as_ptr() as usize;
        let self_end = self_start + self.len;
        if subset_start >= self_start && subset_end <= self_end {
            Some(self.safe_slice_ref(subset_start - self_start, subset_end - self_start))
        } else {
            None
        }
    }

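    /// Creates `Bytes` from any [`UnderlyingBytes`] value, without copying the data.
    ///
    /// The value is moved behind a reference count and its buffer is referenced directly.
    ///
    /// # Examples
    ///
    /// A minimal sketch using a `Vec<u8>` as the underlying buffer:
    ///
    /// ```
    /// use libdd_tinybytes::Bytes;
    ///
    /// let bytes = Bytes::from_underlying(b"hello world".to_vec());
    /// assert_eq!(bytes.as_ref(), b"hello world");
    /// ```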
    pub fn from_underlying<T: UnderlyingBytes>(value: T) -> Self {
        unsafe {
            let refcounted = make_refcounted(value);
            let a = refcounted.data.cast::<CustomArc<T>>().as_ptr();

            // SAFETY:
            // * the pointer associated with a slice is non null and valid for the length of the
            //   slice
            // * it stays valid as long as the refcounted cell that now owns the value is not
            //   dropped
            let data: &T = &(*a).data;
            let (ptr, len) = {
                let s = data.as_ref();
                (NonNull::new_unchecked(s.as_ptr().cast_mut()), s.len())
            };
            Self::from_raw_refcount(ptr, len, refcounted)
        }
    }

    #[inline]
    fn safe_slice_ref(&self, start: usize, end: usize) -> Self {
        if !(start <= end && end <= self.len) {
            #[allow(clippy::panic)]
            {
                panic!("Out of bound slicing of Bytes instance")
            }
        }
        // SAFETY:
        // * start is at most len, so the resulting pointer points either inside the allocation or
        //   one past its end
        // * we have 0 <= start <= end <= len, so 0 <= end - start <= len - start. Since the new
        //   ptr points to ptr + start, the new memory span ends at (ptr + start) + (end - start) =
        //   ptr + end, which is at most ptr + len
        Self {
            ptr: unsafe { self.ptr.add(start) },
            len: end - start,
            bytes: self.bytes.clone(),
        }
    }

    #[inline]
    fn as_slice(&self) -> &[u8] {
        // SAFETY: ptr is valid for the associated length
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr().cast_const(), self.len()) }
    }
}

// Implementations of `UnderlyingBytes` for common types.
impl UnderlyingBytes for Vec<u8> {}
impl UnderlyingBytes for Box<[u8]> {}
impl UnderlyingBytes for String {}

// Implementations of common traits for `Bytes`.
impl Default for Bytes {
    fn default() -> Self {
        Self::empty()
    }
}

impl<T: UnderlyingBytes> From<T> for Bytes {
    fn from(value: T) -> Self {
        Self::from_underlying(value)
    }
}

impl AsRef<[u8]> for Bytes {
    #[inline]
    fn as_ref(&self) -> &[u8] {
        self.as_slice()
    }
}

impl borrow::Borrow<[u8]> for Bytes {
    #[inline]
    fn borrow(&self) -> &[u8] {
        self.as_slice()
    }
}

impl ops::Deref for Bytes {
    type Target = [u8];
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<T: AsRef<[u8]>> PartialEq<T> for Bytes {
    #[inline]
    fn eq(&self, other: &T) -> bool {
        self.as_slice() == other.as_ref()
    }
}

impl Eq for Bytes {}

impl<T: AsRef<[u8]>> PartialOrd<T> for Bytes {
    fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
        self.as_slice().partial_cmp(other.as_ref())
    }
}

impl Ord for Bytes {
    fn cmp(&self, other: &Bytes) -> cmp::Ordering {
        self.as_slice().cmp(other.as_slice())
    }
}

impl hash::Hash for Bytes {
    // TODO should we cache the hash since we know the bytes are immutable?
    #[inline]
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.as_slice().hash(state);
    }
}

impl fmt::Debug for Bytes {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(self.as_slice(), f)
    }
}

#[cfg(feature = "serde")]
impl Serialize for Bytes {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(self.as_slice())
    }
}

pub struct RefCountedCell {
    data: NonNull<()>,
    vtable: &'static RefCountedCellVTable,
}

unsafe impl Send for RefCountedCell {}
unsafe impl Sync for RefCountedCell {}

impl RefCountedCell {
    #[inline]
    /// Creates a new `RefCountedCell` from the given data and vtable.
    ///
    /// The data pointer can be used to store arbitrary data, which won't be dropped until the last
    /// clone of the `RefCountedCell` is dropped.
    /// The vtable customizes the clone and drop behavior of the cell: whenever the `RefCountedCell`
    /// is cloned or dropped, the corresponding function from the vtable is called with the `data`
    /// pointer.
    ///
    /// # Safety
    ///
    /// * The value pointed to by `data` must be 'static + Send + Sync
    pub const unsafe fn from_raw(data: NonNull<()>, vtable: &'static RefCountedCellVTable) -> Self {
        RefCountedCell { data, vtable }
    }
}

impl Clone for RefCountedCell {
    fn clone(&self) -> Self {
        unsafe { (self.vtable.clone)(self.data) }
    }
}

impl Drop for RefCountedCell {
    fn drop(&mut self) {
        unsafe { (self.vtable.drop)(self.data) }
    }
}

pub struct RefCountedCellVTable {
    pub clone: unsafe fn(NonNull<()>) -> RefCountedCell,
    pub drop: unsafe fn(NonNull<()>),
}

/// A custom Arc implementation that contains only the strong count
///
/// This struct is not exposed outside of this module and is
/// only interacted with through the `RefCountedCell` API.
struct CustomArc<T> {
    rc: AtomicUsize,
    #[allow(unused)]
    data: T,
}

/// Creates a refcounted cell.
///
/// The data passed to this cell will only be dropped when the last
/// clone of the cell is dropped.
fn make_refcounted<T: Send + Sync + 'static>(data: T) -> RefCountedCell {
    unsafe fn custom_arc_clone<T>(data: NonNull<()>) -> RefCountedCell {
        let custom_arc = data.cast::<CustomArc<T>>().as_ref();
        custom_arc
            .rc
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        RefCountedCell::from_raw(
            data,
            &RefCountedCellVTable {
                clone: custom_arc_clone::<T>,
                drop: custom_arc_drop::<T>,
            },
        )
    }

    unsafe fn custom_arc_drop<T>(data: NonNull<()>) {
        let custom_arc = data.cast::<CustomArc<T>>().as_ref();
        if custom_arc
            .rc
            .fetch_sub(1, std::sync::atomic::Ordering::Release)
            != 1
        {
            return;
        }

        // Run drop + free memory on the data manually rather than casting back to a box
        // because otherwise miri complains

        // See standard library documentation for std::sync::Arc to see why this is needed.
        // https://github.com/rust-lang/rust/blob/2a5da7acd4c3eae638aa1c46f3a537940e60a0e4/library/alloc/src/sync.rs#L2647-L2675
        std::sync::atomic::fence(std::sync::atomic::Ordering::Acquire);
        {
            let custom_arc = data.cast::<CustomArc<T>>().as_mut();
            std::ptr::drop_in_place(custom_arc);
        }

        std::alloc::dealloc(
            data.as_ptr() as *mut u8,
            std::alloc::Layout::new::<CustomArc<T>>(),
        );
    }

    let rc = Box::leak(Box::new(CustomArc {
        rc: AtomicUsize::new(1),
        data,
    })) as *mut _ as *const ();
    RefCountedCell {
        data: unsafe { NonNull::new_unchecked(rc as *mut ()) },
        vtable: &RefCountedCellVTable {
            clone: custom_arc_clone::<T>,
            drop: custom_arc_drop::<T>,
        },
    }
}

#[cfg(feature = "bytes_string")]
mod bytes_string;
#[cfg(feature = "bytes_string")]
pub use bytes_string::BytesString;

#[cfg(test)]
mod test;