Skip to main content

multi_array_list/
lib.rs

//! A `MultiArrayList` stores a list of structs.
2//!
3//! **Experimental**: Only a small subset of the array list API is implemented.
4//!
5//! ---
6//!
7//! > Instead of storing a single list of items, `MultiArrayList` stores separate lists for each field of the struct.
8//! > This allows for memory savings if the struct has padding,
9//! > and also improves cache usage if only some fields are needed for a computation.
10//!
11//! The primary API for accessing fields is the [`items::<name>`][`MultiArrayList::items()`] function.
12//!
13//! ---
14//! _inspired by [Zig's `MultiArrayList`](https://ziglang.org/documentation/master/std/#std.MultiArrayList)._
15//!
16//! # Example
17//! ```rust
18//! #![allow(incomplete_features)]
19//! #![feature(generic_const_exprs)]
20//! # use multi_array_list::MultiArrayList;
21//! struct Pizza {
22//!     radius: u32,
23//!     toppings: Vec<Topping>,
24//! }
25//!
26//! enum Topping {
27//!     Tomato,
28//!     Mozzarella,
29//!     Anchovies,
30//! }
31//!
32//! let mut order = MultiArrayList::<Pizza>::new();
33//!
34//! let margherita = Pizza {
35//!     radius: 12,
36//!     toppings: vec![Topping::Tomato],
37//! };
38//! order.push(margherita);
39//!
40//! let napoli = Pizza {
41//!     radius: 12,
42//!     toppings: vec![Topping::Tomato, Topping::Anchovies],
43//! };
44//! order.push(napoli);
45//!
46//! for topping in order.items_mut::<"toppings", Vec<Topping>>() {
47//!     topping.push(Topping::Mozzarella);
48//! }
49//! ```
50#![allow(incomplete_features)]
51#![feature(adt_const_params)]
52#![feature(generic_const_exprs)]
53#![feature(type_info)]
54#![feature(unsized_const_params)]
55#![feature(const_cmp)]
56#![feature(const_trait_impl)]
57#![warn(missing_docs)]
58
59use std::alloc::{self, Layout};
60use std::any::TypeId;
61use std::marker::PhantomData;
62use std::mem::type_info::{Field, Type, TypeKind};
63use std::mem::{self, ManuallyDrop, MaybeUninit};
64use std::ptr;
65use std::slice;
66
67/// Alignment of various types, imprecise for anything but numbers, chars and bools.
68const fn align_of(kind: TypeKind) -> usize {
69    use TypeKind::*;
70    match kind {
71        Int(int) => (int.bits / 8) as usize,
72        Float(float) => (float.bits / 8) as usize,
73        Char(_) => 4,
74        Bool(_) => 1,
75        _ => mem::size_of::<usize>(),
76    }
77}
78
79/// Number of fields in `T`
80pub const fn field_len<T>() -> usize {
81    assert_valid_fields::<T>();
82    let typeinfo = Type::of::<T>();
83
84    typeinfo.size.unwrap();
85    let TypeKind::Struct(struct_info) = typeinfo.kind else {
86        panic!("MultiArrayList only works for structs");
87    };
88    struct_info.fields.len()
89}
90
91const fn assert_valid_fields<T>() {
92    let typeinfo = Type::of::<T>();
93
94    typeinfo.size.unwrap();
95    let TypeKind::Struct(struct_info) = typeinfo.kind else {
96        panic!("MultiArrayList only works for structs");
97    };
98
99    assert!(struct_info.generics.is_empty());
100
101    let mut i = 0;
102    while i < struct_info.fields.len() {
103        let field = &struct_info.fields[i];
104        let field_info = field.ty.info();
105
106        assert!(field_info.size.is_some(), "known size required");
107        assert!(matches!(
108            field_info.kind,
109            TypeKind::Bool(_)
110                | TypeKind::Int(_)
111                | TypeKind::Char(_)
112                | TypeKind::Float(_)
113                | TypeKind::Tuple(_)
114                | TypeKind::Struct(_)
115                | TypeKind::Enum(_)
116        ));
117
118        i += 1;
119    }
120}
121
122const fn field_sizes<T>() -> [usize; field_len::<T>()] {
123    let typeinfo = Type::of::<T>();
124
125    let TypeKind::Struct(struct_info) = typeinfo.kind else {
126        panic!("MultiArrayList only works for structs");
127    };
128
129    let fields = struct_info.fields;
130    let mut new_fields = [0; field_len::<T>()];
131
132    let mut i = 0;
133    while i < fields.len() {
134        let field = &fields[i];
135        let field_info = field.ty.info();
136        let element_size = field_info.size.unwrap();
137        new_fields[i] = element_size;
138        i += 1;
139    }
140
141    new_fields
142}
143
144const fn field_aligns<T>() -> [usize; field_len::<T>()] {
145    let typeinfo = Type::of::<T>();
146
147    let TypeKind::Struct(struct_info) = typeinfo.kind else {
148        panic!("MultiArrayList only works for structs");
149    };
150
151    let fields = struct_info.fields;
152    let mut new_fields = [0; field_len::<T>()];
153
154    let mut i = 0;
155    while i < fields.len() {
156        let field = &fields[i];
157        let align = align_of(field.ty.info().kind);
158        new_fields[i] = align;
159        i += 1;
160    }
161
162    new_fields
163}
164
165/// A `MultiArrayList` stores a list of a struct.
166///
167/// > Instead of storing a single list of items, `MultiArrayList` stores separate lists for each field of the struct.
168/// > This allows for memory savings if the struct has padding,
169/// > and also improves cache usage if only some fields are needed for a computation.
170///
171/// The primary API for accessing fields is the [`items(name)`][`MultiArrayList::items()`] function.
172#[derive(Debug)]
173pub struct MultiArrayList<T>
174where
175    T: 'static,
176    [(); field_len::<T>()]:,
177{
178    elems: [*mut u8; field_len::<T>()],
179    cap: usize,
180    len: usize,
181
182    _t: PhantomData<T>,
183}
184
185impl<T> MultiArrayList<T>
186where
187    T: 'static,
188    [(); field_len::<T>()]:,
189{
190    /// Constructs a new, empty `MultiArrayList<T>`.
191    ///
192    /// # Examples
193    ///
194    /// ```rust
195    /// #![allow(incomplete_features)]
196    /// #![feature(generic_const_exprs)]
197    /// # use multi_array_list::MultiArrayList;
198    /// struct Point {
199    ///     x: i32,
200    ///     y: i32
201    /// }
202    /// let mut list: MultiArrayList<Point> = MultiArrayList::new();
203    /// ```
204    ///
205    /// # Type can't be not a struct
206    ///
207    /// ```compile_fail
208    /// #![allow(incomplete_features)]
209    /// #![feature(generic_const_exprs)]
210    /// use multi_array_list::MultiArrayList;
211    /// #[repr(u8)]
212    /// enum Random {
213    ///     Four,
214    /// }
215    /// let m = MultiArrayList::<Random>::new();
216    /// ```
217    ///
218    /// # Type can't be empty
219    ///
220    /// ```compile_fail
221    /// #![allow(incomplete_features)]
222    /// #![feature(generic_const_exprs)]
223    /// use multi_array_list::MultiArrayList;
224    /// let m = MultiArrayList::<()>::new();
225    /// ```
226    pub fn new() -> MultiArrayList<T> {
227        let elems = [ptr::null_mut(); field_len::<T>()];
228
229        MultiArrayList {
230            elems,
231            cap: 0,
232            len: 0,
233            _t: PhantomData,
234        }
235    }
236
237    /// Constructs a new, empty `MultiArrayList<T>` with at least the specified capacity.
238    ///
239    /// # Examples
240    ///
241    /// ```rust
242    /// #![allow(incomplete_features)]
243    /// #![feature(generic_const_exprs)]
244    /// # use multi_array_list::MultiArrayList;
245    /// struct Point {
246    ///     x: i32,
247    ///     y: i32
248    /// }
249    /// let mut list: MultiArrayList<Point> = MultiArrayList::with_capacity(10);
250    /// ```
251    pub fn with_capacity(capacity: usize) -> MultiArrayList<T> {
252        let mut list = MultiArrayList::<T>::new();
253        list.grow(capacity);
254        list
255    }
256
257    fn elem_ptrs(&self) -> &[*mut u8] {
258        &self.elems[..]
259    }
260
261    const fn fields() -> &'static [Field] {
262        let typeinfo = Type::of::<T>();
263
264        let TypeKind::Struct(struct_info) = typeinfo.kind else {
265            panic!("MultiArrayList only works for structs");
266        };
267
268        struct_info.fields
269    }
270
271    /// Returns the total number of elements the vector can hold without reallocating.
272    pub fn capacity(&self) -> usize {
273        self.cap
274    }
275
276    /// Returns the number of elements in the `MultiArrayList`, also referred to as its 'length'.
277    pub fn len(&self) -> usize {
278        self.len
279    }
280
281    /// Appends an element to the back of a collection.
282    pub fn push(&mut self, value: T) {
283        let len = self.len;
284
285        // We will consume this value, don't drop it.
286        let value = ManuallyDrop::new(value);
287
288        if len == self.capacity() {
289            self.grow(1);
290        }
291
292        let sizes = const { field_sizes::<T>() };
293
294        unsafe {
295            let value: *const T = &*value as *const _;
296            let value_ptr: *const u8 = value.cast();
297
298            let elem_ptrs = self.elem_ptrs();
299            let fields = const { Self::fields() };
300
301            for (idx, (field, elem)) in fields.iter().zip(elem_ptrs).enumerate() {
302                let offset = field.offset;
303                let size = sizes[idx];
304
305                if size > 0 {
306                    let field_elem = elem.offset((size * self.len()) as isize);
307                    let src = value_ptr.offset(offset as isize);
308                    ptr::copy_nonoverlapping(src, field_elem, size);
309                }
310            }
311        }
312
313        self.len += 1;
314    }
315
316    /// Removes the last element from the `MultiArrayList` and returns it, or `None` if it is empty.
317    pub fn pop(&mut self) -> Option<Box<T>> {
318        if self.len() == 0 {
319            return None;
320        }
321        let idx = self.len() - 1;
322        let wip: Box<MaybeUninit<T>> = Box::new(MaybeUninit::uninit());
323
324        let wip = unsafe {
325            let elem_ptrs = self.elem_ptrs();
326            let wip = Box::into_raw(wip);
327            let wip_ptr: *mut u8 = wip.cast();
328
329            let sizes = const { field_sizes::<T>() };
330
331            for (i, (field, elem)) in Self::fields().iter().zip(elem_ptrs).enumerate() {
332                let offset = field.offset;
333                let size = sizes[i];
334                if size > 0 {
335                    let field_elem = elem.offset((size * idx) as isize);
336                    let dst_ptr = wip_ptr.offset(offset as isize);
337                    ptr::copy_nonoverlapping(field_elem, dst_ptr, size);
338                }
339            }
340
341            Box::from_raw(wip)
342        };
343
344        self.len -= 1;
345
346        // SAFETY: We initialized all fields at their correct offset and size
347        // based on values we previously copied into our element array.
348        unsafe { Some(wip.assume_init()) }
349    }
350
351    /// Returns an iterator over the `MultiArrayList`.
352    ///
353    /// The iterator yields all items from start to end.
354    ///
355    /// ```
356    /// #![allow(incomplete_features)]
357    /// #![feature(generic_const_exprs)]
358    /// use multi_array_list::MultiArrayList;
359    ///
360    /// struct Point {
361    ///     x: i32,
362    ///     y: i32,
363    /// }
364    ///
365    /// let mut points = MultiArrayList::<Point>::new();
366    /// points.push(Point { x: 1, y: 0 });
367    /// for point in points.iter() {
368    ///     assert_eq!(1, *point.get::<"x", _>());
369    ///     assert_eq!(0, *point.get::<"y", _>());
370    /// }
371    /// ```
372    pub fn iter<'a>(&'a self) -> Iter<'a, T> {
373        Iter {
374            array: self,
375            pos: 0,
376        }
377    }
378
379    const fn get_field_by_name<const NAME: &'static str, V: 'static>() -> (usize, usize) {
380        let mut idx = 0;
381        while idx < Self::fields().len() {
382            let field = &Self::fields()[idx];
383            if field.name.eq_ignore_ascii_case(NAME) {
384                break;
385            }
386            idx += 1
387        }
388
389        if idx == Self::fields().len() {
390            panic!("unknown item name");
391        }
392
393        assert!(idx < Self::fields().len());
394        let field = &Self::fields()[idx];
395        assert!(TypeId::of::<V>() == field.ty, "types don't match");
396
397        let elem_size = mem::size_of::<V>();
398        (idx, elem_size)
399    }
400
401    /// Get an iterator of values for a specified field.
402    ///
403    /// # Compile errors
404    ///
405    /// * Fails to compile if the requested field does not exist.
406    /// * Fails to compile if the requested type does not match the found field (by size).
407    ///
408    /// ```compile_fail
409    /// #![allow(incomplete_features)]
410    /// #![feature(generic_const_exprs)]
411    /// use multi_array_list::MultiArrayList;
412    /// struct HereBeDragons {
413    ///     x: u32,
414    ///     y: (u8, u16),
415    /// }
416    /// let mut list = MultiArrayList::<HereBeDragons>::new();
417    /// _ = list.items::<"y", (u16, u8)>(); // will not compile
418    /// ```
419    pub fn items<'a, const NAME: &'static str, V: 'static>(&'a self) -> Slice<'a, V> {
420        let (idx, elem_size) = const { Self::get_field_by_name::<NAME, V>() };
421
422        unsafe {
423            let elem_ptrs = self.elem_ptrs();
424            let ptr = elem_ptrs[idx];
425            let end = ptr.offset((self.len * elem_size) as isize);
426
427            return Slice {
428                ptr,
429                end,
430                typ: PhantomData,
431            };
432        }
433    }
434
435    /// Get an iterator of mutable values for a specified field.
436    ///
437    /// # Compile errors
438    ///
439    /// * Fails to compile if the requested field does not exist.
440    /// * Fails to compile if the requested type does not match the found field (by size).
441    pub fn items_mut<'a, const NAME: &'static str, V: 'static>(&'a mut self) -> SliceMut<'a, V> {
442        let (idx, elem_size) = const { Self::get_field_by_name::<NAME, V>() };
443
444        unsafe {
445            let elem_ptrs = self.elem_ptrs();
446            let ptr = elem_ptrs[idx];
447            let end = ptr.offset((self.len * elem_size) as isize);
448
449            return SliceMut {
450                ptr,
451                end,
452                typ: PhantomData,
453            };
454        }
455    }
456
457    fn grow(&mut self, additional: usize) {
458        let old_cap = self.capacity();
459        let new_cap = old_cap + additional;
460
461        let sizes = const { field_sizes::<T>() };
462        let aligns = const { field_aligns::<T>() };
463
464        for idx in 0..Self::fields().len() {
465            let element_size = sizes[idx];
466            let align = aligns[idx];
467            let old_array_size = element_size * old_cap;
468            let new_array_size = element_size * new_cap;
469
470            if new_array_size == 0 {
471                continue;
472            }
473
474            unsafe {
475                let ptr = &mut self.elems[idx];
476                if old_cap == 0 {
477                    let layout = Layout::from_size_align_unchecked(new_array_size, align);
478                    *ptr = alloc::alloc(layout);
479                } else {
480                    let layout = Layout::from_size_align_unchecked(old_array_size, align);
481                    *ptr = alloc::realloc(*ptr, layout, new_array_size);
482                }
483            }
484        }
485
486        self.cap = new_cap;
487    }
488
489    /// Shrinks the capacity of the vector as much as possible.
490    pub fn shrink_to_fit(&mut self) {
491        let cur_cap = self.capacity();
492        let len = self.len();
493        if len == cur_cap {
494            return;
495        }
496
497        let sizes = const { field_sizes::<T>() };
498        let aligns = const { field_aligns::<T>() };
499
500        unsafe {
501            for idx in 0..Self::fields().len() {
502                let element_size = sizes[idx];
503                let align = aligns[idx];
504                let old_array_size = element_size * cur_cap;
505                let new_array_size = element_size * len;
506
507                if element_size == 0 {
508                    continue;
509                }
510
511                let ptr = &mut self.elems[idx];
512                let layout = Layout::from_size_align_unchecked(old_array_size, align);
513                if new_array_size == 0 {
514                    alloc::dealloc(*ptr, layout);
515                    *ptr = ptr::null_mut();
516                } else {
517                    *ptr = alloc::realloc(*ptr, layout, new_array_size);
518                }
519            }
520        }
521
522        self.cap = len;
523    }
524}
525
526impl<T> Drop for MultiArrayList<T>
527where
528    T: 'static,
529    [(); field_len::<T>()]:,
530{
531    fn drop(&mut self) {
532        while let Some(elem) = self.pop() {
533            drop(elem);
534        }
535
536        self.shrink_to_fit();
537    }
538}
539
540/// An iterator over all values of a specified field.
541pub struct Slice<'a, V> {
542    ptr: *const u8,
543    end: *const u8,
544    typ: PhantomData<&'a V>,
545}
546
547impl<'a, V> Iterator for Slice<'a, V> {
548    type Item = &'a V;
549
550    fn next(&mut self) -> Option<Self::Item> {
551        if self.ptr >= self.end {
552            return None;
553        }
554        let elem_size = mem::size_of::<V>();
555
556        unsafe {
557            let slice: &[u8] = slice::from_raw_parts(self.ptr, elem_size);
558            let val = mem::transmute_copy::<&[u8], &V>(&slice);
559            self.ptr = self.ptr.offset(elem_size as isize);
560            return Some(val);
561        }
562    }
563}
564
565/// An iterator over mutable values of a specified field.
566pub struct SliceMut<'a, V> {
567    ptr: *mut u8,
568    end: *mut u8,
569    typ: PhantomData<&'a mut V>,
570}
571
572impl<'a, V> Iterator for SliceMut<'a, V> {
573    type Item = &'a mut V;
574
575    fn next(&mut self) -> Option<Self::Item> {
576        if self.ptr >= self.end {
577            return None;
578        }
579        let elem_size = mem::size_of::<V>();
580
581        unsafe {
582            let mut slice: &mut [u8] = slice::from_raw_parts_mut(self.ptr, elem_size);
583            let val = mem::transmute_copy::<&mut [u8], &mut V>(&mut slice);
584            self.ptr = self.ptr.offset(elem_size as isize);
585            return Some(val);
586        }
587    }
588}
589
590/// An iterator over all elements of the `MultiArrayList`.
591pub struct Iter<'a, T>
592where
593    T: 'static,
594    [(); field_len::<T>()]:,
595{
596    array: &'a MultiArrayList<T>,
597    pos: usize,
598}
599
600impl<'a, T> Iterator for Iter<'a, T>
601where
602    T: 'static,
603    [(); field_len::<T>()]:,
604{
605    type Item = IterRef<'a, T>;
606
607    fn next(&mut self) -> Option<Self::Item> {
608        if self.pos >= self.array.len() {
609            return None;
610        }
611
612        let elems_slice = &self.array.elems[..];
613
614        self.pos += 1;
615        Some(IterRef {
616            ptr: elems_slice,
617            idx: self.pos - 1,
618            elem: PhantomData,
619        })
620    }
621}
622
623/// A reference to a single element of `T`.
624///
625/// Use [`get(field)`][`IterRef::get`] to access a single field.
626pub struct IterRef<'a, T>
627where
628    T: 'static,
629{
630    ptr: &'a [*mut u8],
631    idx: usize,
632    elem: PhantomData<&'a T>,
633}
634
635impl<'a, T> IterRef<'a, T>
636where
637    [(); field_len::<T>()]:,
638{
639    /// Get a specified field of `T`.
640    ///
641    /// # Compile errors
642    ///
643    /// * Fails to compile if the requested field does not exist.
644    /// * Fails to compile if the requested type does not match the found field (by size).
645    ///
646    /// ## It needs the correct type or it fails to compile
647    ///
648    /// ```compile_fail
649    /// #![allow(incomplete_features)]
650    /// #![feature(generic_const_exprs)]
651    /// use multi_array_list::MultiArrayList;
652    ///
653    /// struct Point {
654    ///     x: i32,
655    ///     y: i32,
656    /// }
657    ///
658    /// let mut points = MultiArrayList::<Point>::new();
659    /// points.push(Point { x: 1, y: 0 });
660    /// for point in points.iter() {
661    ///     point.get::<"x", bool>();
662    /// }
663    /// ```
664    pub fn get<const NAME: &'static str, V: 'static>(&self) -> &V {
665        let (idx, elem_size) = const { MultiArrayList::<T>::get_field_by_name::<NAME, V>() };
666
667        unsafe {
668            let ptr = self.ptr[idx].offset((self.idx * elem_size) as isize);
669            let slice: &[u8] = slice::from_raw_parts(ptr, elem_size);
670            let val = mem::transmute_copy::<&[u8], &V>(&slice);
671            return val;
672        }
673    }
674}
675
676#[cfg(test)]
677mod test {
678    use super::*;
679    use std::mem;
680
681    #[derive(Debug, PartialEq, Eq)]
682    struct Point {
683        x: i32,
684        y: i32,
685    }
686
687    #[repr(u8)]
688    #[allow(dead_code)]
689    enum Random {
690        Four,
691    }
692
693    #[test]
694    fn size() {
695        assert_eq!(32, mem::size_of::<MultiArrayList<Point>>());
696    }
697
698    #[test]
699    fn empty() {
700        let list = MultiArrayList::<Point>::new();
701        assert_eq!(0, list.len());
702        assert_eq!(0, list.capacity());
703    }
704
705    #[test]
706    fn push() {
707        let mut list = MultiArrayList::<Point>::new();
708        assert_eq!(0, list.len());
709        assert_eq!(0, list.capacity());
710
711        let p1 = Point { x: 2, y: 8 };
712        list.push(p1);
713        assert_eq!(1, list.len());
714        assert!(0 < list.capacity());
715    }
716
717    #[test]
718    fn pop() {
719        let mut list = MultiArrayList::<Point>::new();
720        assert_eq!(0, list.len());
721        assert_eq!(0, list.capacity());
722
723        let p1 = Point { x: 2, y: 8 };
724        list.push(p1);
725        assert_eq!(1, list.len());
726        assert!(0 < list.capacity());
727
728        let p2 = list.pop().unwrap();
729        assert_eq!(0, list.len());
730        assert_eq!(&Point { x: 2, y: 8 }, &*p2);
731    }
732
733    #[test]
734    fn capacity() {
735        let mut list = MultiArrayList::<Point>::with_capacity(10);
736        assert_eq!(0, list.len());
737        assert_eq!(10, list.capacity());
738
739        let p1 = Point { x: 2, y: 8 };
740        list.push(p1);
741        assert_eq!(1, list.len());
742        assert_eq!(10, list.capacity());
743    }
744
745    #[test]
746    fn zst() {
747        #[allow(dead_code)]
748        #[derive(Debug, PartialEq, Eq)]
749        struct Empty {
750            x: (),
751        }
752
753        let mut list = MultiArrayList::<Empty>::new();
754        assert_eq!(0, list.len());
755        assert_eq!(0, list.capacity());
756
757        list.push(Empty { x: () });
758
759        let elem = list.pop().unwrap();
760        assert_eq!(Empty { x: () }, *elem);
761    }
762}
763
764#[doc = include_str!("../README.md")]
765#[cfg(doctest)]
766pub struct ReadmeDoctests;