Skip to main content

diskann_quantization/alloc/
poly.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    alloc::Layout,
8    mem::MaybeUninit,
9    ops::{Deref, DerefMut},
10    ptr::NonNull,
11};
12
13#[cfg(feature = "flatbuffers")]
14use super::Allocator;
15use super::{AllocatorCore, AllocatorError, GlobalAllocator, TryClone};
16
17/// An owning pointer type like `std::Box` that supports custom allocators.
18///
19/// # Examples
20///
21/// ## Creating and Mutating a Simple Type
22///
23/// This example demonstrates that `Poly` behaves pretty much like a `Box` with allocator
24/// support for simple types.
25/// ```
26/// use diskann_quantization::alloc::{Poly, GlobalAllocator};
27///
28/// // `Poly` constructors can fail due to allocator errors, so return `Results`.
29/// let mut x = Poly::new(10usize, GlobalAllocator).unwrap();
30/// assert_eq!(*x, 10);
31///
32/// *x = 50;
33/// assert_eq!(*x, 50);
34/// ```
35///
36/// ## Creating and Mutating a Slice
37///
38/// The standard library trait [`FromIterator`] is not implemented for `Poly` because an
39/// allocator is required for all construction operations. Instead, the inherent method
40/// [`Poly::from_iter`] can be used, provided the iterator is one of the select few for
41/// which [`TrustedIter`] is implemented, indicating that the length of the iterator can
42/// be relied on in unsafe code.
43///
44/// ```
45/// use diskann_quantization::alloc::{Poly, GlobalAllocator};
46/// let v = vec![
47///     "foo".to_string(),
48///     "bar".to_string(),
49///     "baz".to_string(),
50/// ];
51///
52/// let poly = Poly::from_iter(v.into_iter(), GlobalAllocator).unwrap();
53/// assert_eq!(poly.len(), 3);
54/// assert_eq!(poly[0], "foo");
55/// assert_eq!(poly[1], "bar");
56/// assert_eq!(poly[2], "baz");
57/// ```
58///
59/// ## Using a Custom Allocator
60///
61/// This crate provides a handful of custom allocators, including the [`super::BumpAllocator`].
62/// It can be used to group together allocations into a single arena.
63///
64/// ```
65/// use diskann_quantization::{
66///     alloc::{Poly, BumpAllocator},
67///     num::PowerOfTwo
68/// };
69///
70/// let dim = 10;
71///
72/// // Estimate how many bytes are needed to create two such slices. We can control the
73/// // alignment of the base pointer for the `BumpAllocator` to avoid extra memory used
74/// // to satisfy alignment.
75/// let alloc = BumpAllocator::new(
76///     dim * (std::mem::size_of::<f64>() + std::mem::size_of::<u8>()),
77///     PowerOfTwo::new(64).unwrap(),
78/// ).unwrap();
79///
80/// let foo = Poly::<[f64], _>::from_iter(
81///    (0..dim).map(|i| i as f64),
82///    alloc.clone(),
83/// ).unwrap();
84///
85/// let bar = Poly::<[u8], _>::from_iter(
86///    (0..dim).map(|i| i as u8),
87///    alloc.clone(),
88/// ).unwrap();
89///
90/// // The base pointer for the allocated object in `foo` is the base pointer of the arena
91/// // owned by the BumpAllocator.
92/// assert_eq!(alloc.as_ptr(), Poly::as_ptr(&foo).cast::<u8>());
93///
94/// // The base pointer for the allocated object in `bar` is also within the arena owned
95/// // by the BumpAllocator as well.
96/// assert_eq!(
97///     unsafe { alloc.as_ptr().add(std::mem::size_of_val(&*foo)) },
98///     Poly::as_ptr(&bar).cast::<u8>(),
99/// );
100///
101/// // The allocator is now full - so further allocations will fail.
102/// assert!(Poly::new(10usize, alloc.clone()).is_err());
103///
104/// // If we drop the allocator, the clones inside the `Poly` containers will keep the
105/// // backing memory alive.
106/// std::mem::forget(alloc);
107/// assert!(foo.iter().enumerate().all(|(i, v)| i as f64 == *v));
108/// assert!(bar.iter().enumerate().all(|(i, v)| i as u8 == *v));
109/// ```
110///
111/// ## Using Trait Object
112///
113/// `Poly` is compatible with trait objects as well using the [`crate::poly!`] macro.
114/// A macro is needed because traits such as
115/// [`Unsize`](https://doc.rust-lang.org/std/marker/trait.Unsize.html) and
116/// [`CoerceUnsized`](https://doc.rust-lang.org/std/ops/trait.CoerceUnsized.html) are
117/// unstable.
118///
119/// ```
120/// use diskann_quantization::{
121///     poly,
122///     alloc::BumpAllocator,
123///     num::PowerOfTwo,
124/// };
125/// use std::fmt::Display;
126///
127/// let message = "hello world";
128///
129/// let alloc = BumpAllocator::new(512, PowerOfTwo::new(64).unwrap()).unwrap();
130///
131/// // Create a new `Poly` trait object for `std::fmt::Display`. Due to limitations in the
132/// // macro matching rules, identifiers are needed for the object and allocator.
133/// let clone = alloc.clone();
134/// let poly = poly!(std::fmt::Display, message, clone).unwrap();
135/// assert_eq!(poly.to_string(), "hello world");
136///
137/// // Here - we demonstrate the full type of the returned `Poly`.
138/// let clone = alloc.clone();
139/// let _: diskann_quantization::alloc::Poly<dyn Display, _> = poly!(
140///     Display,
141///     message,
142///     clone
143/// ).unwrap();
144///
145/// // If additional auto traits are needed like `Send`, the brace-style syntax can be used
146/// let clone = alloc.clone();
147/// let poly = poly!({ std::fmt::Display + Send + Sync }, message, clone).unwrap();
148///
149/// // Existing `Poly` objects can be converted using the same macro.
150/// let poly = diskann_quantization::alloc::Poly::new(message, alloc.clone()).unwrap();
151/// let poly = poly!(std::fmt::Display, poly);
152/// assert_eq!(poly.to_string(), "hello world");
153/// ```
154///
155/// Naturally, the implementation of the trait is checked for validity.
156#[derive(Debug)]
157#[repr(C)]
158pub struct Poly<T, A = GlobalAllocator>
159where
160    T: ?Sized,
161    A: AllocatorCore,
162{
163    ptr: NonNull<T>,
164    allocator: A,
165}
166
167// SAFETY: `Poly` is `Send` when the pointed-to object and allocator are `Send`.
168unsafe impl<T, A> Send for Poly<T, A>
169where
170    T: ?Sized + Send,
171    A: AllocatorCore + Send,
172{
173}
174
175// SAFETY: `Poly` is `Sync` when the pointed-to object and allocator are `Sync`.
176unsafe impl<T, A> Sync for Poly<T, A>
177where
178    T: ?Sized + Sync,
179    A: AllocatorCore + Sync,
180{
181}
182
183/// Error type returned from [`Poly::new_with`].
184#[derive(Debug, Clone, Copy)]
185pub enum CompoundError<E> {
186    /// An allocator error occurred while allocating the base `Poly`.
187    Allocator(AllocatorError),
188    /// An error occurred while running the closure.
189    Constructor(E),
190}
191
192impl<T, A> Poly<T, A>
193where
194    A: AllocatorCore,
195{
196    /// Allocate memory from `allocator` and place `value` into that location.
197    pub fn new(value: T, allocator: A) -> Result<Self, AllocatorError> {
198        if std::mem::size_of::<T>() == 0 {
199            Ok(Self {
200                ptr: NonNull::dangling(),
201                allocator,
202            })
203        } else {
204            let ptr = allocator.allocate(Layout::new::<T>())?;
205
206            // SAFETY: On success, [`Allocator::allocate`] is required to return a suitable
207            // aligned pointer to a slice of size at least `std::mem::size_of::<T>()`.
208            //
209            // Therefore, the cast is valid.
210            //
211            // The write is safe because there is no existing object at the pointed to location.
212            let ptr: NonNull<T> = unsafe {
213                let ptr = ptr.cast::<T>();
214                ptr.as_ptr().write(value);
215                ptr
216            };
217
218            Ok(Self { ptr, allocator })
219        }
220    }
221
222    /// Allocate memory from `allocator` for `T`, then run the provided closure, moving
223    /// the result into the allocated memory.
224    ///
225    /// Because this allocates the storage for the object first, it can be used in
226    /// situations where the object to be allocated will use the same allocator, but should
227    /// be allocated after the base.
228    pub fn new_with<F, E>(f: F, allocator: A) -> Result<Self, CompoundError<E>>
229    where
230        F: FnOnce(A) -> Result<T, E>,
231        A: Clone,
232    {
233        // Construct an uninitialized version of `Self` first before invoking the constructor
234        // closure.
235        let mut this = Self::new_uninit(allocator.clone()).map_err(CompoundError::Allocator)?;
236        this.write(f(allocator).map_err(CompoundError::Constructor)?);
237
238        // SAFETY: We wrote to the `MaybeUninit` with the valid object returned from `f`.
239        Ok(unsafe { this.assume_init() })
240    }
241
242    /// Construct a new [`Poly`] with uninitialized contents in `allocator`.
243    pub fn new_uninit(allocator: A) -> Result<Poly<MaybeUninit<T>, A>, AllocatorError> {
244        if std::mem::size_of::<T>() == 0 {
245            Ok(Poly {
246                ptr: NonNull::dangling(),
247                allocator,
248            })
249        } else {
250            let ptr = allocator.allocate(Layout::new::<MaybeUninit<T>>())?;
251
252            // SAFETY: This cast is valid because
253            //
254            // 1. [`Allocator::allocate]` is required to on success to return a pointer to a
255            //    slice compatible with the provided layout.
256            //
257            // 2. This memory has not been initialized. Since `MaybeUninit` does not `Drop`
258            //    its contents, it is okay to hand out.
259            let ptr: NonNull<MaybeUninit<T>> = ptr.cast::<MaybeUninit<T>>();
260            Ok(Poly { ptr, allocator })
261        }
262    }
263}
264
265impl<T, A> Poly<T, A>
266where
267    T: ?Sized,
268    A: AllocatorCore,
269{
270    /// Consume `Self`, returning the wrapped pointer and allocator.
271    ///
272    /// This function does not trigger any drop logic nor deallocation.
273    pub fn into_raw(this: Self) -> (NonNull<T>, A) {
274        let ptr = this.ptr;
275
276        // SAFETY: This creates a bit-size copy of the allocator in `this`. Since we
277        // immediately forget `this`, this behaves like moving the allocator out of `this`,
278        // which is safe because `this` is taken by value.
279        let allocator = unsafe { std::ptr::read(&this.allocator) };
280        std::mem::forget(this);
281        (ptr, allocator)
282    }
283
284    /// Construct a [`Poly`] from a raw pointer and allocator.
285    ///
286    /// After calling this function, the returned [`Poly`] will assume ownership of the
287    /// provided pointer.
288    ///
289    /// # Safety
290    ///
291    /// The value of `ptr` must have runtime alignment compatible with
292    /// ```text
293    /// std::mem::Layout::for_value(&*ptr.as_ptr())
294    /// ```
295    /// and point to a valid object of type `T`.
296    ///
297    /// The pointer must point to memory currently allocated in `allocator`.
298    pub unsafe fn from_raw(ptr: NonNull<T>, allocator: A) -> Self {
299        Poly { ptr, allocator }
300    }
301
302    /// Return a pointer to the object managed by this `Poly`.
303    pub fn as_ptr(this: &Self) -> *const T {
304        this.ptr.as_ptr().cast_const()
305    }
306
307    /// Return a reference to the underlying allocator.
308    pub fn allocator(&self) -> &A {
309        &self.allocator
310    }
311}
312
313impl<T, A> Poly<MaybeUninit<T>, A>
314where
315    A: AllocatorCore,
316{
317    /// Converts to `Poly<T, A>`.
318    ///
319    /// # Safety
320    ///
321    /// The caller must ensure that the value has truly been initialized.
322    pub unsafe fn assume_init(self) -> Poly<T, A> {
323        let (ptr, allocator) = Poly::into_raw(self);
324        // SAFETY: It's the caller's responsibility to ensure that the pointed-to value
325        // has truely been initialized.
326        unsafe { Poly::from_raw(ptr.cast::<T>(), allocator) }
327    }
328}
329
330impl<T, A> Poly<[T], A>
331where
332    A: AllocatorCore,
333{
334    /// Construct a new `Poly` containing an uninitialized slice of length `len` with
335    /// memory allocated from `allocator`.
336    pub fn new_uninit_slice(
337        len: usize,
338        allocator: A,
339    ) -> Result<Poly<[MaybeUninit<T>], A>, AllocatorError> {
340        let layout = Layout::array::<T>(len).map_err(|_| AllocatorError)?;
341        let ptr = if layout.size() == 0 {
342            // SAFETY: We're either constructing a slice of zero sized types, or a slice
343            // of length zero. In either case, `NonNull::dangling()` ensures proper
344            // alignment of the non-null base pointer.
345            unsafe {
346                NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(
347                    NonNull::dangling().as_ptr(),
348                    len,
349                ))
350            }
351        } else {
352            let ptr = allocator.allocate(layout)?;
353            debug_assert_eq!(ptr.len(), layout.size());
354
355            // SAFETY: `Allocator` is required to provide a properly sized and aligned
356            // slice upon success. Wrapping the raw memory in `MaybeUninit` is okay because
357            // we will not try to `Drop` values of type `T` until they've been properly
358            // initialized.
359            unsafe {
360                NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(
361                    ptr.as_ptr().cast::<MaybeUninit<T>>(),
362                    len,
363                ))
364            }
365        };
366
367        // SAFETY: `ptr` points to a properly initialized object allocated from `allocator`.
368        Ok(unsafe { Poly::from_raw(ptr, allocator) })
369    }
370
371    /// Construct a new `Poly` from the iterator.
372    pub fn from_iter<I>(iter: I, allocator: A) -> Result<Self, AllocatorError>
373    where
374        I: TrustedIter<Item = T>,
375    {
376        // A guard that drops the initialized portion of a partially constructed slice
377        // in the event that `iter` panics.
378        struct Guard<'a, T, A>
379        where
380            A: AllocatorCore,
381        {
382            uninit: &'a mut Poly<[MaybeUninit<T>], A>,
383            initialized_to: usize,
384        }
385
386        impl<T, A> Drop for Guard<'_, T, A>
387        where
388            A: AllocatorCore,
389        {
390            fn drop(&mut self) {
391                // Performance optimization: skip if `T` doesn't need to be dropped.
392                //
393                // Not needed for release builds since the drop loop will be optimized away,
394                // but can make debug build run a little faster.
395                //
396                // See: https://doc.rust-lang.org/std/mem/fn.needs_drop.html
397                if std::mem::needs_drop::<T>() {
398                    self.uninit
399                        .iter_mut()
400                        .take(self.initialized_to)
401                        .for_each(|u|
402                            // SAFETY: `self.initialized_to` is only incremented after a
403                            // successful write and therefore `u` is initialized.
404                            unsafe { u.assume_init_drop() });
405                }
406            }
407        }
408
409        let mut uninit = Poly::<[T], A>::new_uninit_slice(iter.len(), allocator)?;
410
411        let mut guard = Guard {
412            uninit: &mut uninit,
413            initialized_to: 0,
414        };
415
416        std::iter::zip(iter, guard.uninit.iter_mut()).for_each(|(src, dst)| {
417            dst.write(src);
418            guard.initialized_to += 1;
419        });
420
421        debug_assert_eq!(
422            guard.initialized_to,
423            guard.uninit.len(),
424            "an incorrect number of elements was initialized",
425        );
426
427        // Forget the guard so its destructor doesn't run and there-by ruin all the good
428        // work we just did.
429        std::mem::forget(guard);
430
431        // SAFETY: Since `iter` has a trusted length, we know every element in `uninit` has
432        // been properly initialized.
433        Ok(unsafe { uninit.assume_init() })
434    }
435}
436
437impl<T, A> Poly<[MaybeUninit<T>], A>
438where
439    A: AllocatorCore,
440{
441    /// Converts to `Poly<[T], A>`.
442    ///
443    /// # Safety
444    ///
445    /// The caller must ensure that the value has truly been initialized.
446    pub unsafe fn assume_init(self) -> Poly<[T], A> {
447        let len = self.deref().len();
448        let (ptr, allocator) = Poly::into_raw(self);
449
450        // SAFETY: The slice cast is valid because
451        //
452        // 1. The caller has asserted that `self` has been initialized.
453        // 2. `MaybeUninit<T>` is ABI compatible with `T`.
454        //
455        // The unchecked `NonNull` is valid because `self.ptr` is `NonNull`.
456        let ptr = unsafe {
457            NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(
458                ptr.as_ptr().cast::<T>(),
459                len,
460            ))
461        };
462
463        // SAFETY: The memory layout and pointed-to contents of of `[T]` is exactly the
464        // same as `[MaybeUninit<T>]`. So it is acceptable to deallocate this derived `[T]`
465        // from `allocator`.
466        //
467        // Additionally, the caller has asserted that the pointed-to slice is valid.
468        unsafe { Poly::<[T], A>::from_raw(ptr, allocator) }
469    }
470}
471
472impl<T, A> Poly<[T], A>
473where
474    A: AllocatorCore,
475    T: Clone,
476{
477    /// Construct a new `Poly` slice with each entry initialized to `value`.
478    pub fn broadcast(value: T, len: usize, allocator: A) -> Result<Self, AllocatorError> {
479        Self::from_iter((0..len).map(|_| value.clone()), allocator)
480    }
481}
482
483impl<T, A> Drop for Poly<T, A>
484where
485    T: ?Sized,
486    A: AllocatorCore,
487{
488    fn drop(&mut self) {
489        // SAFETY: Because `self` hasn't been dropped quite yet, the pointed to object is
490        // still valid.
491        let layout = Layout::for_value(unsafe { self.ptr.as_ref() });
492
493        // SAFETY: `Poly` owns the pointed-to object, so we can drop it when the `Poly`
494        // is dropped.
495        unsafe { std::ptr::drop_in_place(self.ptr.as_ptr()) };
496
497        if layout.size() != 0 {
498            // This cast is safe because `u8`'s alignment requirements are equal to or less
499            // than `T`. Additionally, the layout was derived from the pointed to object.
500            let as_slice =
501                std::ptr::slice_from_raw_parts_mut(self.ptr.as_ptr().cast::<u8>(), layout.size());
502
503            // SAFETY: `self.ptr` was non-null.
504            let ptr = unsafe { NonNull::new_unchecked(as_slice) };
505
506            // SAFETY: The construction of `Poly` means that the pointer is always passed
507            // around with its corresponding allocator.
508            unsafe { self.allocator.deallocate(ptr, layout) }
509        }
510    }
511}
512
513impl<T, A> Deref for Poly<T, A>
514where
515    T: ?Sized,
516    A: AllocatorCore,
517{
518    type Target = T;
519    fn deref(&self) -> &Self::Target {
520        // SAFETY: As long as `Self` is alive, the pointed-to object is valid.
521        unsafe { self.ptr.as_ref() }
522    }
523}
524
525impl<T, A> DerefMut for Poly<T, A>
526where
527    T: ?Sized,
528    A: AllocatorCore,
529{
530    fn deref_mut(&mut self) -> &mut Self::Target {
531        // SAFETY: As long as `Self` is alive, the pointed-to object is valid.
532        //
533        // Since there is a mutable reference to `Self`, the access to the pointed-to
534        // object is exclusive, so it is safe to return a mutable reference.
535        unsafe { self.ptr.as_mut() }
536    }
537}
538
539///////////////
540// From Iter //
541///////////////
542
/// Marker trait for iterators whose reported length can be relied on by `unsafe` code.
///
/// [`Poly::from_iter`] sizes its allocation from [`ExactSizeIterator::len`] and depends
/// on the iterator producing that many items.
///
/// # Safety
///
/// Implementors must guarantee that [`ExactSizeIterator::len`] is exact: the iterator
/// yields precisely that many elements, so unsafe code may rely on the value.
pub unsafe trait TrustedIter: ExactSizeIterator {}
550
551//---------------//
552// Raw Iterators //
553//---------------//
554
555// SAFETY: `std::slice` is trusted.
556unsafe impl<T> TrustedIter for std::slice::Iter<'_, T> {}
557// SAFETY: `std::vec` is trusted.
558unsafe impl<T> TrustedIter for std::vec::IntoIter<T> {}
559// SAFETY: `std::ops::Range` is trusted.
560unsafe impl TrustedIter for std::ops::Range<usize> {}
561// SAFETY: `std::array::IntoIter` is trusted.
562unsafe impl<T, const N: usize> TrustedIter for std::array::IntoIter<T, N> {}
563// SAFETY: `rand::seq::index::IndexVecIntoIter` is trusted.
564unsafe impl TrustedIter for rand::seq::index::IndexVecIntoIter {}
565
566#[cfg(feature = "flatbuffers")]
567// SAFETY: We trust the implementors of `flatbuffer` return trustable lengths for this iterator.
568unsafe impl<'a, T> TrustedIter for flatbuffers::VectorIter<'a, T> where T: flatbuffers::Follow<'a> {}
569
570//---------------//
571// Map Iterators //
572//---------------//
573
574// SAFETY: Maps of trusted iterators are trusted.
575unsafe impl<I, U, F> TrustedIter for std::iter::Map<I, F>
576where
577    I: TrustedIter,
578    F: FnMut(I::Item) -> U,
579{
580}
581
582// SAFETY: Enumerates of trusted iterators are trusted.
583unsafe impl<I> TrustedIter for std::iter::Enumerate<I> where I: TrustedIter {}
584
585// SAFETY: Clones of trusted iterators are trusted.
586unsafe impl<'a, I, T> TrustedIter for std::iter::Cloned<I>
587where
588    I: TrustedIter<Item = &'a T>,
589    T: 'a + Clone,
590{
591}
592
593// SAFETY: Copies of trusted iterators are trusted.
594unsafe impl<'a, I, T> TrustedIter for std::iter::Copied<I>
595where
596    I: TrustedIter<Item = &'a T>,
597    T: 'a + Copy,
598{
599}
600
601// SAFETY: Zip of trusted iterators is trusted.
602unsafe impl<T, U> TrustedIter for std::iter::Zip<T, U>
603where
604    T: TrustedIter,
605    U: TrustedIter,
606{
607}
608
609//////////////////
610// Trait Object //
611//////////////////
612
613#[macro_export]
614macro_rules! poly {
615    // Creating a new poly types.
616    ({ $($traits:tt)+ }, $v:ident, $alloc:ident) => {{
617        $crate::alloc::Poly::new($v, $alloc).map(|poly| {
618            $crate::alloc::poly!({ $($traits)+ }, poly)
619        })
620    }};
621    ($trait:path, $v:ident, $alloc:ident) => {{
622        $crate::alloc::poly!({ $trait }, $v, $alloc)
623    }};
624
625    // Coercing an existing `poly`.
626    ({ $($traits:tt)+ }, $poly:ident) => {{
627        let (ptr, alloc) = $crate::alloc::Poly::into_raw($poly);
628
629        // The deduction chain goes that we need to coerce the pointer from `*const T` for
630        // some concrete type `T` to `*const dyn $traits...`.
631        //
632        // Putting the dyn trait in the turbo-fish for the `Poly::from_raw` forces the
633        // corresponding pointer argument to be the dynamic pointer.
634        //
635        // As such, Rust will check to see if Unsized coercion applies. If not, we get
636        // a compilation error.
637        //
638        // SAFETY: The unsafe part is the call to `from_raw`, which is safe because we just
639        // obtained `ptr` from `into_raw` and the pointed-to object is still the same, so
640        // may be safely deallocated with `alloc`.
641        unsafe { $crate::alloc::Poly::<dyn $($traits)*, _>::from_raw(ptr, alloc) }
642    }};
643    ($trait:path, $poly:ident) => {{
644        $crate::alloc::poly!({ $trait }, $poly)
645    }};
646
647    // Literal array constructor.
648    ([$($x:expr),* $(,)?], $alloc:ident) => {{
649        Poly::from_iter([$($x,)*].into_iter(), $alloc)
650    }}
651}
652
653pub use poly;
654
655///////////////
656// Try Clone //
657///////////////
658
659impl<T, A> TryClone for Poly<T, A>
660where
661    T: Clone,
662    A: AllocatorCore + Clone,
663{
664    fn try_clone(&self) -> Result<Self, AllocatorError> {
665        let clone = (*self).clone();
666        Poly::new(clone, self.allocator().clone())
667    }
668}
669
670impl<T, A> TryClone for Poly<[T], A>
671where
672    T: Clone,
673    A: AllocatorCore + Clone,
674{
675    fn try_clone(&self) -> Result<Self, AllocatorError> {
676        Poly::from_iter(self.iter().cloned(), self.allocator().clone())
677    }
678}
679
680impl<T, A> TryClone for Option<Poly<T, A>>
681where
682    T: ?Sized,
683    A: AllocatorCore,
684    Poly<T, A>: super::TryClone,
685{
686    fn try_clone(&self) -> Result<Self, AllocatorError> {
687        Ok(match self {
688            Some(v) => Some(v.try_clone()?),
689            None => None,
690        })
691    }
692}
693
694////////////////////////////////
695// Recursively Defined Traits //
696////////////////////////////////
697
698impl<T, A> PartialEq for Poly<T, A>
699where
700    T: ?Sized + PartialEq,
701    A: AllocatorCore,
702{
703    #[inline]
704    fn eq(&self, other: &Self) -> bool {
705        PartialEq::eq(&**self, &**other)
706    }
707}
708
709////////////////
710// Conversion //
711////////////////
712
713impl<T> From<Box<[T]>> for Poly<[T], GlobalAllocator> {
714    fn from(value: Box<[T]>) -> Self {
715        // SAFETY: The underlying pointer for `Box` is always `NonNull`, and
716        // `GlobalAllocator` is the same as the global allocator used by `std::Box`.
717        unsafe {
718            Poly::from_raw(
719                NonNull::new_unchecked(Box::into_raw(value)),
720                GlobalAllocator,
721            )
722        }
723    }
724}
725
726////////////////////////
727// Flatbuffer Support //
728////////////////////////
729
#[cfg(feature = "flatbuffers")]
// SAFETY: `len` reports the exact number of bytes in the buffer. `grow_downwards` keeps
// the existing contents at the end of the larger buffer, so the buffer grows downwards.
//
// Also - clippy doesn't work if the safety comment comes before the `cfg`.
unsafe impl<A> flatbuffers::Allocator for Poly<[u8], A>
where
    A: Allocator,
{
    type Error = AllocatorError;

    /// Doubles the buffer (or grows an empty buffer to one byte). The existing bytes are
    /// placed at the tail of the new, zero-filled buffer.
    fn grow_downwards(&mut self) -> Result<(), Self::Error> {
        let old_len = (**self).len();
        // A slice never spans more than `isize::MAX` bytes, so doubling its length cannot
        // overflow `usize`.
        let new_len = std::cmp::max(2 * old_len, 1);
        let mut grown = Poly::broadcast(0u8, new_len, self.allocator().clone())?;
        grown[new_len - old_len..].copy_from_slice(&**self);
        *self = grown;
        Ok(())
    }

    fn len(&self) -> usize {
        let bytes: &[u8] = self;
        bytes.len()
    }
}
755
756///////////
757// Tests //
758///////////
759
760#[cfg(test)]
761mod tests {
762    use super::*;
763    use crate::test_util::AlwaysFails;
764
765    struct HasHoles {
766        s: String,
767        a: u32,
768        b: u8,
769    }
770
771    impl HasHoles {
772        fn new(s: String, a: u32, b: u8) -> Self {
773            Self { s, a, b }
774        }
775    }
776
777    fn assert_is_send<T>(_: &T)
778    where
779        T: Send,
780    {
781    }
782
783    //-------//
784    // Sizes //
785    //-------//
786
787    #[test]
788    fn size_check() {
789        assert_eq!(std::mem::size_of::<Poly<usize>>(), 8);
790        assert_eq!(std::mem::size_of::<Option<Poly<usize>>>(), 8);
791    }
792
793    //-------//
794    // Basic //
795    //-------//
796
797    #[test]
798    fn basic_test_copy() {
799        let x = 10usize;
800        let poly = Poly::new(x, GlobalAllocator).unwrap();
801        assert_eq!(*poly, 10);
802    }
803
804    #[test]
805    fn basic_test_borrow() {
806        let x = &10usize;
807        let poly = Poly::<&usize>::new(x, GlobalAllocator).unwrap();
808        assert_eq!(**poly, 10);
809    }
810
811    #[test]
812    fn test_with_drop() {
813        let poly = Poly::<String>::new("hello world".to_string(), GlobalAllocator).unwrap();
814        assert_eq!(&**poly, "hello world");
815    }
816
817    #[test]
818    fn test_mutate() {
819        let mut poly = Poly::<String>::new("foo".to_string(), GlobalAllocator).unwrap();
820        assert_eq!(&**poly, "foo");
821        *poly = "bar".to_string();
822        assert_eq!(&**poly, "bar");
823    }
824
825    //------------------//
826    // Zero Sized Types //
827    //------------------//
828
829    #[test]
830    fn zero_sized() {
831        let _ = Poly::<()>::new((), GlobalAllocator).unwrap();
832    }
833
834    #[test]
835    fn zero_sized_raw() {
836        let x = Poly::<()>::new((), GlobalAllocator).unwrap();
837        let (ptr, alloc) = Poly::into_raw(x);
838        // SAFETY: `ptr` and `alloc` were obtained from `into_raw`.
839        let _ = unsafe { Poly::from_raw(ptr, alloc) };
840    }
841
842    #[test]
843    fn zero_sized_uninit() {
844        let _ = Poly::<()>::new_uninit(GlobalAllocator).unwrap();
845    }
846
847    #[test]
848    fn zero_sized_uninit_to_init() {
849        let x = Poly::<()>::new_uninit(GlobalAllocator).unwrap();
850        // SAFETY: No initialization is required for zero sized types.
851        let _ = unsafe { x.assume_init() };
852    }
853
854    #[test]
855    fn zero_sized_slice() {
856        let x = Poly::<[()]>::from_iter((0..0).map(|_| ()), GlobalAllocator).unwrap();
857        assert!(x.is_empty());
858
859        let x = Poly::<[()]>::from_iter((0..10).map(|_| ()), GlobalAllocator).unwrap();
860        assert_eq!(x.len(), 10);
861
862        let x = Poly::<[usize]>::from_iter(0..0, GlobalAllocator).unwrap();
863        assert!(x.is_empty());
864
865        let x =
866            Poly::<[String]>::from_iter((0..0).map(|i| i.to_string()), GlobalAllocator).unwrap();
867        assert!(x.is_empty());
868    }
869
870    //--------//
871    // Uninit //
872    //--------//
873
874    #[test]
875    fn dropping_uninit_is_okay() {
876        // `String` has a non-trivial `Drop` implementation.
877        //
878        // If we return an uninitialized `Poly<String>` and do not initialize the
879        // contents, dropping the `Poly<MaybeUninit<String>>` should not trigger undefined
880        // behavior.
881        let _ = Poly::<HasHoles>::new_uninit(GlobalAllocator).unwrap();
882    }
883
884    #[test]
885    fn test_assume_init() {
886        let mut poly = Poly::<HasHoles>::new_uninit(GlobalAllocator).unwrap();
887        poly.write(HasHoles::new("hello world".into(), 10, 5));
888        // SAFETY: We just initialized `poly`.
889        let poly: Poly<HasHoles> = unsafe { poly.assume_init() };
890        assert_eq!(poly.s, "hello world");
891        assert_eq!(poly.a, 10);
892        assert_eq!(poly.b, 5);
893    }
894
895    #[test]
896    fn test_assume_init_slice_copy() {
897        let mut poly = Poly::<[usize]>::new_uninit_slice(10, GlobalAllocator).unwrap();
898        assert_eq!(poly.len(), 10);
899        for (i, v) in poly.iter_mut().enumerate() {
900            v.write(i);
901        }
902        // SAFETY: We just initialized `poly`.
903        let poly: Poly<[usize]> = unsafe { poly.assume_init() };
904
905        for (i, v) in poly.iter().enumerate() {
906            assert_eq!(*v, i);
907        }
908    }
909
910    #[test]
911    fn test_assume_init_slice_drop() {
912        let mut poly = Poly::<[HasHoles]>::new_uninit_slice(10, GlobalAllocator).unwrap();
913        assert_eq!(poly.len(), 10);
914        for (i, v) in poly.iter_mut().enumerate() {
915            v.write(HasHoles::new(
916                i.to_string(),
917                i.try_into().unwrap(),
918                i.try_into().unwrap(),
919            ));
920        }
921        // SAFETY: We just initialized `poly`.
922        let poly: Poly<[HasHoles]> = unsafe { poly.assume_init() };
923
924        for (i, v) in poly.iter().enumerate() {
925            assert_eq!(v.s, i.to_string());
926            assert_eq!(v.a as usize, i);
927            assert_eq!(v.b as usize, i);
928        }
929    }
930
931    //-----------//
932    // From Iter //
933    //-----------//
934
935    #[test]
936    fn from_iter_strings() {
937        let p =
938            Poly::<[String], _>::from_iter((0..5).map(|i| i.to_string()), GlobalAllocator).unwrap();
939
940        assert_eq!(&*p, &["0", "1", "2", "3", "4"])
941    }
942
943    /// Test for undefined behavior if an iterator panics on the first item. In this
944    /// situation nothing should be dropped.
945    ///
946    /// This must be tested using Miri.
947    #[test]
948    #[should_panic(expected = "first")]
949    fn from_iter_cleanup_first() {
950        Poly::<[String], _>::from_iter((0..5).map(|_| panic!("first")), GlobalAllocator).unwrap();
951    }
952
953    /// This test induces a panic in `from_iter` in the middle of iteration to test the
954    /// incremental drop logic.
955    ///
956    /// A non-compliant implementation will leak memory, which Miri can detect.
957    #[test]
958    #[should_panic(expected = "middle")]
959    fn from_iter_cleanup_middle() {
960        Poly::<[String], _>::from_iter(
961            (0..5).map(|i| {
962                if i == 3 {
963                    panic!("middle");
964                } else {
965                    i.to_string()
966                }
967            }),
968            GlobalAllocator,
969        )
970        .unwrap();
971    }
972
973    /// This test is like `from_iter_cleanup_middle` but just panics at the very end.
974    ///
975    /// A non-compliant implementation will leak memory, which Miri can detect.
976    #[test]
977    #[should_panic(expected = "last")]
978    fn from_iter_cleanup_last() {
979        Poly::<[String], _>::from_iter(
980            (0..5).map(|i| {
981                let string = i.to_string();
982                if i == 4 {
983                    panic!("last");
984                }
985                string
986            }),
987            GlobalAllocator,
988        )
989        .unwrap();
990    }
991
992    //------------------//
993    // Allocator Errors //
994    //------------------//
995
996    #[test]
997    fn new_error() {
998        let _ = Poly::new(10usize, AlwaysFails).unwrap_err();
999    }
1000
1001    #[test]
1002    fn new_with_error() {
1003        let err = Poly::new_with(
1004            |_| -> Result<u8, std::convert::Infallible> { Ok(0) },
1005            AlwaysFails,
1006        )
1007        .unwrap_err();
1008        assert!(matches!(err, CompoundError::Allocator(_)));
1009
1010        let err = Poly::new_with(
1011            |_| -> Result<u8, std::num::TryFromIntError> {
1012                let x: u8 = (1000usize).try_into()?;
1013                Ok(x)
1014            },
1015            GlobalAllocator,
1016        )
1017        .unwrap_err();
1018        assert!(matches!(
1019            err,
1020            CompoundError::Constructor(std::num::TryFromIntError { .. })
1021        ));
1022    }
1023
1024    #[test]
1025    fn new_uninit_error() {
1026        let _ = Poly::<String, _>::new_uninit(AlwaysFails).unwrap_err();
1027    }
1028
1029    #[test]
1030    fn new_uninit_slice_error() {
1031        let _ = Poly::<[usize], _>::new_uninit_slice(10, AlwaysFails).unwrap_err();
1032    }
1033
1034    #[test]
1035    fn new_from_iter_error() {
1036        let _ = Poly::<[usize], _>::from_iter(0..10, AlwaysFails).unwrap_err();
1037    }
1038
1039    //---------------//
1040    // Trait Objects //
1041    //---------------//
1042
1043    trait Describe {
1044        fn describe(&self) -> String;
1045        fn describe_mut(&mut self) -> String;
1046    }
1047
1048    struct ImplsDescribe;
1049
1050    impl Describe for ImplsDescribe {
1051        fn describe(&self) -> String {
1052            "describe const".to_string()
1053        }
1054
1055        fn describe_mut(&mut self) -> String {
1056            "describe mut".to_string()
1057        }
1058    }
1059
1060    struct AlsoImplsDescribe(String);
1061
1062    impl Describe for AlsoImplsDescribe {
1063        fn describe(&self) -> String {
1064            format!("describe const: {}", self.0)
1065        }
1066
1067        fn describe_mut(&mut self) -> String {
1068            format!("describe mut: {}", self.0)
1069        }
1070    }
1071
1072    struct DescribeLifetime<'a>(&'a str);
1073
1074    impl Describe for DescribeLifetime<'_> {
1075        fn describe(&self) -> String {
1076            format!("describe const: {}", self.0)
1077        }
1078
1079        fn describe_mut(&mut self) -> String {
1080            format!("describe mut: {}", self.0)
1081        }
1082    }
1083
1084    trait Foo<T> {
1085        fn foo(&self, v: T) -> T;
1086    }
1087
1088    impl Foo<f32> for f32 {
1089        fn foo(&self, v: f32) -> f32 {
1090            *self + v
1091        }
1092    }
1093
    /// Exercises the `poly!` macro for building `Poly` trait objects.
    ///
    /// Covered cases:
    /// - construction from a value plus an allocator (fallible, hence `unwrap`);
    /// - conversion of an existing `Poly<T>` (two-argument form, no `unwrap`);
    /// - extra bounds such as `+ Send`;
    /// - traits with generic parameters, including a generic parameter of an enclosing
    ///   function.
    #[test]
    fn test_dyn_trait() {
        // Traits without generic parameters
        {
            // Bare (unbraced) trait form: build the trait object from a value and an allocator.
            let mut poly0 = poly!(Describe, ImplsDescribe, GlobalAllocator).unwrap();

            // Braced form: a trait plus additional bounds such as `Send`.
            let also = AlsoImplsDescribe("foo".to_string());
            let mut poly1 = poly!({ Describe + Send }, also, GlobalAllocator).unwrap();
            assert_is_send::<Poly<dyn Describe + Send, _>>(&poly1);

            // Both `&self` and `&mut self` methods dispatch through the trait object.
            assert_eq!(poly1.describe(), "describe const: foo");
            assert_eq!(poly1.describe_mut(), "describe mut: foo");

            assert_eq!(poly0.describe(), "describe const");
            assert_eq!(poly0.describe_mut(), "describe mut");
        }

        {
            // Convert a concrete `Poly<T>` into a `Poly<dyn Trait>`.
            let mut poly =
                Poly::new(AlsoImplsDescribe("bar".to_string()), GlobalAllocator).unwrap();
            assert_is_send::<Poly<AlsoImplsDescribe>>(&poly);

            assert_eq!(poly.describe(), "describe const: bar");
            assert_eq!(poly.describe_mut(), "describe mut: bar");

            // Two-argument form: takes an existing `Poly` and its result is used directly
            // (no `unwrap`).
            let mut poly = poly!({ Describe + Send }, poly);

            assert_is_send::<Poly<dyn Describe + Send>>(&poly);

            // The converted object still sees the original value.
            assert_eq!(poly.describe(), "describe const: bar");
            assert_eq!(poly.describe_mut(), "describe mut: bar");
        }

        // Traits with generic parameters
        {
            let f = 1.0f32;
            let poly = poly!({ Foo<f32> }, f, GlobalAllocator).unwrap();
            assert_eq!(poly.foo(2.0), 3.0);
        }

        {
            // Same as above, but converting an existing `Poly<f32>` and adding `+ Send`.
            let poly = Poly::new(1.0f32, GlobalAllocator).unwrap();
            let poly = poly!({ Foo<f32> + Send }, poly);

            assert_is_send::<Poly<dyn Foo<f32> + Send>>(&poly);

            assert_eq!(poly.foo(2.0), 3.0);
        }

        // Traits with generic parameters in a function.
        //
        // An improper implementation of the `poly!` macro won't be able to use the generic
        // parameter `T` due to the "cannot use generic parameter from outer item"
        // constraint.
        //
        // The `'a` bound lets the returned trait object hold a non-`'static` `T`.
        fn test<'a, T>(x: T) -> Poly<dyn Foo<T> + 'a>
        where
            T: Foo<T> + 'a,
        {
            poly!({ Foo<T> }, x, GlobalAllocator).unwrap()
        }

        {
            let x = test(1.0f32);
            assert_eq!(x.foo(2.0), 3.0);
        }
    }
1161
1162    #[test]
1163    fn test_dyn_trait_with_lifetime() {
1164        let base: String = "foo".into();
1165        let describe = DescribeLifetime(&base);
1166
1167        let mut poly: Poly<dyn Describe> = poly!({ Describe }, describe, GlobalAllocator).unwrap();
1168        assert_eq!(poly.describe(), "describe const: foo");
1169        assert_eq!(poly.describe_mut(), "describe mut: foo");
1170    }
1171
1172    #[test]
1173    fn test_try_clone_item() {
1174        let x = Poly::<String>::new("hello".to_string(), GlobalAllocator).unwrap();
1175        let y = x.try_clone().unwrap();
1176        assert_eq!(x, y);
1177    }
1178
1179    #[test]
1180    fn test_try_clone_slice() {
1181        let x = Poly::<[String]>::from_iter(
1182            ["foo".to_string(), "bar".to_string(), "baz".to_string()].into_iter(),
1183            GlobalAllocator,
1184        )
1185        .unwrap();
1186
1187        let y = x.try_clone().unwrap();
1188        assert_eq!(x, y);
1189    }
1190
1191    #[test]
1192    fn test_try_clone_option() {
1193        let mut x = Some(Poly::<String>::new("hello".to_string(), GlobalAllocator).unwrap());
1194        let y = x.try_clone().unwrap();
1195        assert_eq!(x, y);
1196
1197        x = None;
1198        let y = x.try_clone().unwrap();
1199        assert_eq!(x, y);
1200    }
1201
1202    #[cfg(feature = "flatbuffers")]
1203    #[test]
1204    fn test_grow_downwards() {
1205        let mut x = Poly::from_iter([1u8, 2u8, 3u8].into_iter(), GlobalAllocator).unwrap();
1206        <_ as flatbuffers::Allocator>::grow_downwards(&mut x).unwrap();
1207        assert_eq!(&*x, &[0, 0, 0, 1, 2, 3]);
1208
1209        <_ as flatbuffers::Allocator>::grow_downwards(&mut x).unwrap();
1210        assert_eq!(&*x, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3]);
1211    }
1212}