Skip to main content

multi_array_list/
lib.rs

1//! A `MultiArrayList` stores a list of a struct.
2//!
3//! **Experimental**: Only a small subset of the array list API is implemented.
4//!
5//! ---
6//!
7//! > Instead of storing a single list of items, `MultiArrayList` stores separate lists for each field of the struct.
8//! > This allows for memory savings if the struct has padding,
9//! > and also improves cache usage if only some fields are needed for a computation.
10//!
11//! The primary API for accessing fields is the [`items::<name>`][`MultiArrayList::items()`] function.
12//!
13//! ---
14//! _inspired by [Zig's `MultiArrayList`](https://ziglang.org/documentation/master/std/#std.MultiArrayList)._
15//!
16//! # Example
17//! ```rust
18//! # use multi_array_list::MultiArrayList;
19//! struct Pizza {
20//!     radius: u32,
21//!     toppings: Vec<Topping>,
22//! }
23//!
24//! enum Topping {
25//!     Tomato,
26//!     Mozzarella,
27//!     Anchovies,
28//! }
29//!
30//! let mut order = MultiArrayList::<Pizza>::new();
31//!
32//! let margherita = Pizza {
33//!     radius: 12,
34//!     toppings: vec![Topping::Tomato],
35//! };
36//! order.push(margherita);
37//!
38//! let napoli = Pizza {
39//!     radius: 12,
40//!     toppings: vec![Topping::Tomato, Topping::Anchovies],
41//! };
42//! order.push(napoli);
43//!
44//! for topping in order.items_mut::<"toppings", Vec<Topping>>() {
45//!     topping.push(Topping::Mozzarella);
46//! }
47//! ```
48#![allow(incomplete_features)]
49#![feature(type_info)]
50#![feature(unsized_const_params)]
51#![feature(adt_const_params)]
52#![warn(missing_docs)]
53
54use std::alloc::{self, Layout};
55use std::marker::PhantomData;
56use std::mem::type_info::{Field, Type, TypeKind};
57use std::mem::{self, ManuallyDrop, MaybeUninit};
58use std::ptr;
59use std::slice;
60
/// Maximum number of fields supported in the element struct.
///
/// Per-field layout tables are fixed-size arrays of this length computed at compile time,
/// so no heap allocation is needed to describe the columns.
const MAX_FIELDS: usize = 8;
63
64/// Alignment of various types, imprecise for anything but numbers, chars and bools.
65const fn align_of(kind: TypeKind) -> usize {
66    use TypeKind::*;
67    match kind {
68        Int(int) => (int.bits / 8) as usize,
69        Float(float) => (float.bits / 8) as usize,
70        Char(_) => 4,
71        Bool(_) => 1,
72        _ => mem::size_of::<usize>(),
73    }
74}
75
/// A `MultiArrayList` stores a list of structs.
///
/// > Instead of storing a single list of items, `MultiArrayList` stores separate lists for each field of the struct.
/// > This allows for memory savings if the struct has padding,
/// > and also improves cache usage if only some fields are needed for a computation.
///
/// The primary API for accessing fields is the [`items::<NAME, V>()`][`MultiArrayList::items()`] function.
#[derive(Debug)]
pub struct MultiArrayList<T>
where
    T: 'static,
{
    /// Table of column base pointers: entry `i` points at the array holding field `i`
    /// of every element.
    elems: *mut *mut u8,
    /// Number of elements every column has room for.
    cap: usize,
    /// Number of initialized elements.
    len: usize,

    /// Marks the list as logically owning values of `T`.
    _t: PhantomData<T>,
}
94
95impl<T> MultiArrayList<T>
96where
97    T: 'static,
98{
99    /// Constructs a new, empty `MultiArrayList<T>`.
100    ///
101    /// # Examples
102    ///
103    /// ```rust
104    /// # use multi_array_list::MultiArrayList;
105    /// struct Point {
106    ///     x: i32,
107    ///     y: i32
108    /// }
109    /// let mut list: MultiArrayList<Point> = MultiArrayList::new();
110    /// ```
111    pub fn new() -> MultiArrayList<T> {
112        let elemn = const {
113            let typeinfo = Type::of::<T>();
114
115            typeinfo.size.unwrap();
116            let TypeKind::Struct(struct_info) = typeinfo.kind else {
117                panic!("MultiArrayList only works for structs");
118            };
119
120            assert!(struct_info.generics.is_empty());
121            assert!(struct_info.fields.len() <= MAX_FIELDS);
122
123            let mut i = 0;
124            while i < struct_info.fields.len() {
125                let field = &struct_info.fields[i];
126                let field_info = field.ty.info();
127
128                assert!(field_info.size.is_some(), "known size required");
129                assert!(matches!(
130                    field_info.kind,
131                    TypeKind::Bool(_)
132                        | TypeKind::Int(_)
133                        | TypeKind::Char(_)
134                        | TypeKind::Float(_)
135                        | TypeKind::Tuple(_)
136                        | TypeKind::Struct(_)
137                        | TypeKind::Enum(_)
138                ));
139
140                i += 1;
141            }
142
143            struct_info.fields.len()
144        };
145
146        // SAFETY:
147        // * We create the layout
148        // * We alloc that
149        unsafe {
150            let layout = Layout::array::<*mut u8>(elemn).unwrap();
151            let elems = alloc::alloc(layout) as *mut *mut u8;
152
153            MultiArrayList {
154                elems,
155                cap: 0,
156                len: 0,
157                _t: PhantomData,
158            }
159        }
160    }
161
162    /// Constructs a new, empty `MultiArrayList<T>` with at least the specified capacity.
163    ///
164    /// # Examples
165    ///
166    /// ```rust
167    /// # use multi_array_list::MultiArrayList;
168    /// struct Point {
169    ///     x: i32,
170    ///     y: i32
171    /// }
172    /// let mut list: MultiArrayList<Point> = MultiArrayList::with_capacity(10);
173    /// ```
174    pub fn with_capacity(capacity: usize) -> MultiArrayList<T> {
175        let mut list = MultiArrayList::<T>::new();
176        list.grow(capacity);
177        list
178    }
179
180    fn elem_ptrs(&self) -> &[*const u8] {
181        unsafe { slice::from_raw_parts(self.elems as *const *const u8, Self::fields().len()) }
182    }
183
184    fn elem_ptrs_mut(&self) -> &[*mut u8] {
185        unsafe { slice::from_raw_parts(self.elems as *const *mut u8, Self::fields().len()) }
186    }
187
188    const fn fields() -> &'static [Field] {
189        let typeinfo = Type::of::<T>();
190
191        let TypeKind::Struct(struct_info) = typeinfo.kind else {
192            panic!("MultiArrayList only works for structs");
193        };
194
195        struct_info.fields
196    }
197
198    /// Returns the total number of elements the vector can hold without reallocating.
199    pub fn capacity(&self) -> usize {
200        self.cap
201    }
202
203    /// Returns the number of elements in the `MultiArrayList`, also referred to as its 'length'.
204    pub fn len(&self) -> usize {
205        self.len
206    }
207
208    /// Appends an element to the back of a collection.
209    pub fn push(&mut self, value: T) {
210        let len = self.len;
211
212        // We will consume this value, don't drop it.
213        let value = ManuallyDrop::new(value);
214
215        if len == self.capacity() {
216            self.grow(1);
217        }
218
219        let fields = const {
220            let fields = Self::fields();
221            let mut new_fields = [(0, 0); MAX_FIELDS];
222
223            let mut i = 0;
224            while i < fields.len() {
225                let field = &fields[i];
226                let field_info = field.ty.info();
227                let element_size = field_info.size.unwrap();
228                new_fields[i] = (field.offset, element_size);
229                i += 1;
230            }
231
232            new_fields
233        };
234
235        unsafe {
236            let value: *const T = &*value as *const _;
237            let value_ptr: *const u8 = value.cast();
238
239            let elem_ptrs = self.elem_ptrs_mut();
240
241            for (&field_layout, elem) in fields[0..Self::fields().len()].iter().zip(elem_ptrs) {
242                let offset = field_layout.0;
243                let size = field_layout.1;
244
245                if size > 0 {
246                    let field_elem = elem.offset((size * self.len()) as isize);
247                    let src = value_ptr.offset(offset as isize);
248                    ptr::copy_nonoverlapping(src, field_elem, size);
249                }
250            }
251        }
252
253        self.len += 1;
254    }
255
256    /// Removes the last element from the `MultiArrayList` and returns it, or `None` if it is empty.
257    pub fn pop(&mut self) -> Option<Box<T>> {
258        if self.len() == 0 {
259            return None;
260        }
261        let idx = self.len() - 1;
262        let wip: Box<MaybeUninit<T>> = Box::new(MaybeUninit::uninit());
263
264        let fields = const {
265            let fields = Self::fields();
266            let mut new_fields = [(0, 0); MAX_FIELDS];
267
268            let mut i = 0;
269            while i < fields.len() {
270                let field = &fields[i];
271                let field_info = field.ty.info();
272                let element_size = field_info.size.unwrap();
273                new_fields[i] = (field.offset, element_size);
274                i += 1;
275            }
276
277            new_fields
278        };
279
280        let wip = unsafe {
281            let elem_ptrs = self.elem_ptrs_mut();
282            let wip = Box::into_raw(wip);
283            let wip_ptr: *mut u8 = wip.cast();
284
285            for (&field_layout, elem) in fields.iter().zip(elem_ptrs) {
286                let offset = field_layout.0;
287                let size = field_layout.1;
288                if size > 0 {
289                    let field_elem = elem.offset((size * idx) as isize);
290                    let dst_ptr = wip_ptr.offset(offset as isize);
291                    ptr::copy_nonoverlapping(field_elem, dst_ptr, size);
292                }
293            }
294
295            Box::from_raw(wip)
296        };
297
298        self.len -= 1;
299
300        // SAFETY: We initialized all fields at their correct offset and size
301        // based on values we previously copied into our element array.
302        unsafe { Some(wip.assume_init()) }
303    }
304
305    /// Returns an iterator over the `MultiArrayList`.
306    ///
307    /// The iterator yields all items from start to end.
308    pub fn iter<'a>(&'a self) -> Iter<'a, T> {
309        Iter {
310            array: self,
311            pos: 0,
312        }
313    }
314
315    const fn get_field_by_name<const NAME: &'static str, V>() -> (usize, usize) {
316        let mut idx = 0;
317        while idx < Self::fields().len() {
318            let field = &Self::fields()[idx];
319            if field.name.eq_ignore_ascii_case(NAME) {
320                break;
321            }
322            idx += 1
323        }
324
325        if idx == Self::fields().len() {
326            panic!("unknown item name");
327        }
328
329        let elem_size = mem::size_of::<V>();
330        assert!(idx < Self::fields().len());
331        let field = &Self::fields()[idx];
332        let field_ty = field.ty.info();
333        assert!(elem_size == field_ty.size.unwrap());
334        (idx, elem_size)
335    }
336
337    /// Get an iterator of values for a specified field.
338    ///
339    /// # Compile errors
340    ///
341    /// * Fails to compile if the requested field does not exist.
342    /// * Fails to compile if the requested type does not match the found field (by size).
343    pub fn items<'a, const NAME: &'static str, V>(&'a self) -> Slice<'a, V> {
344        let (idx, elem_size) = const { Self::get_field_by_name::<NAME, V>() };
345
346        unsafe {
347            let elem_ptrs = self.elem_ptrs();
348            let ptr = elem_ptrs[idx];
349            let end = ptr.offset((self.len * elem_size) as isize);
350
351            return Slice {
352                ptr,
353                end,
354                typ: PhantomData,
355            };
356        }
357    }
358
359    /// Get an iterator of mutable values for a specified field.
360    ///
361    /// # Compile errors
362    ///
363    /// * Fails to compile if the requested field does not exist.
364    /// * Fails to compile if the requested type does not match the found field (by size).
365    pub fn items_mut<'a, const NAME: &'static str, V>(&'a mut self) -> SliceMut<'a, V> {
366        let (idx, elem_size) = const { Self::get_field_by_name::<NAME, V>() };
367
368        unsafe {
369            let elem_ptrs = self.elem_ptrs_mut();
370            let ptr = elem_ptrs[idx];
371            let end = ptr.offset((self.len * elem_size) as isize);
372
373            return SliceMut {
374                ptr,
375                end,
376                typ: PhantomData,
377            };
378        }
379    }
380
381    fn grow(&mut self, additional: usize) {
382        let old_cap = self.capacity();
383        let new_cap = old_cap + additional;
384
385        let fields = const {
386            let fields = Self::fields();
387            let mut new_fields = [(0, 0); MAX_FIELDS];
388
389            let mut i = 0;
390            while i < fields.len() {
391                let field_info = fields[i].ty.info();
392                let element_size = field_info.size.unwrap();
393                let align = align_of(field_info.kind);
394                new_fields[i] = (element_size, align);
395                i += 1;
396            }
397
398            new_fields
399        };
400
401        for (idx, &field) in fields[0..Self::fields().len()].iter().enumerate() {
402            let (element_size, align) = field;
403            let old_array_size = element_size * old_cap;
404            let new_array_size = element_size * new_cap;
405
406            if new_array_size == 0 {
407                continue;
408            }
409
410            unsafe {
411                let ptr = self.elems.offset(idx as isize);
412                if old_cap == 0 {
413                    let layout = Layout::from_size_align_unchecked(new_array_size, align);
414                    *ptr = alloc::alloc(layout);
415                } else {
416                    let layout = Layout::from_size_align_unchecked(old_array_size, align);
417                    *ptr = alloc::realloc(*ptr, layout, new_array_size);
418                }
419            }
420        }
421
422        self.cap = new_cap;
423    }
424
425    /// Shrinks the capacity of the vector as much as possible.
426    pub fn shrink_to_fit(&mut self) {
427        let cur_cap = self.capacity();
428        let len = self.len();
429        if len == cur_cap {
430            return;
431        }
432
433        let fields = const {
434            let fields = Self::fields();
435            let mut new_fields = [(0, 0); MAX_FIELDS];
436
437            let mut i = 0;
438            while i < fields.len() {
439                let field = &fields[i];
440                let field_info = field.ty.info();
441                let element_size = field_info.size.unwrap();
442                let align = align_of(field_info.kind);
443                new_fields[i] = (element_size, align);
444                i += 1;
445            }
446
447            new_fields
448        };
449
450        unsafe {
451            for (idx, &field) in fields[0..Self::fields().len()].iter().enumerate() {
452                let (element_size, align) = field;
453                let old_array_size = element_size * cur_cap;
454                let new_array_size = element_size * len;
455
456                if element_size == 0 {
457                    continue;
458                }
459
460                let ptr = self.elems.offset(idx as isize);
461                let layout = Layout::from_size_align_unchecked(old_array_size, align);
462                if new_array_size == 0 {
463                    alloc::dealloc(*ptr, layout);
464                    *ptr = ptr::null_mut();
465                } else {
466                    *ptr = alloc::realloc(*ptr, layout, new_array_size);
467                }
468            }
469        }
470
471        self.cap = len;
472    }
473}
474
475impl<T> Drop for MultiArrayList<T>
476where
477    T: 'static,
478{
479    fn drop(&mut self) {
480        while let Some(elem) = self.pop() {
481            drop(elem);
482        }
483
484        unsafe {
485            self.shrink_to_fit();
486
487            let elemn = Self::fields().len();
488            let layout = Layout::array::<*mut u8>(elemn).unwrap();
489            alloc::dealloc(self.elems as *mut u8, layout);
490        }
491    }
492}
493
/// An iterator over all values of a specified field.
///
/// Walks one field column from `ptr` up to, but not including, `end`.
pub struct Slice<'a, V> {
    /// Address of the next value to yield.
    ptr: *const u8,
    /// One-past-the-end address of the column.
    end: *const u8,
    typ: PhantomData<&'a V>,
}

impl<'a, V> Slice<'a, V> {
    /// Number of values not yet yielded.
    fn remaining(&self) -> usize {
        match mem::size_of::<V>() {
            // A zero-sized column has `ptr == end`, so `next` never yields anything.
            0 => 0,
            size => self.end.addr().saturating_sub(self.ptr.addr()) / size,
        }
    }
}

impl<'a, V> Iterator for Slice<'a, V> {
    type Item = &'a V;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ptr >= self.end {
            return None;
        }
        // SAFETY: `ptr < end`, and `[ptr, end)` is an aligned, initialized column of `V`
        // values that stays borrowed for `'a`. Reading through a cast pointer avoids
        // depending on the (unspecified) layout of fat slice references.
        unsafe {
            let val = &*self.ptr.cast::<V>();
            self.ptr = self.ptr.add(mem::size_of::<V>());
            Some(val)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining();
        (n, Some(n))
    }
}

impl<V> ExactSizeIterator for Slice<'_, V> {}
518
/// An iterator over mutable values of a specified field.
///
/// Walks one field column from `ptr` up to, but not including, `end`, yielding each value once.
pub struct SliceMut<'a, V> {
    /// Address of the next value to yield.
    ptr: *mut u8,
    /// One-past-the-end address of the column.
    end: *mut u8,
    typ: PhantomData<&'a mut V>,
}

impl<'a, V> SliceMut<'a, V> {
    /// Number of values not yet yielded.
    fn remaining(&self) -> usize {
        match mem::size_of::<V>() {
            // A zero-sized column has `ptr == end`, so `next` never yields anything.
            0 => 0,
            size => self.end.addr().saturating_sub(self.ptr.addr()) / size,
        }
    }
}

impl<'a, V> Iterator for SliceMut<'a, V> {
    type Item = &'a mut V;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ptr >= self.end {
            return None;
        }
        // SAFETY: `ptr < end`, and `[ptr, end)` is an aligned, initialized column of `V`
        // values exclusively borrowed for `'a`. Each slot is yielded at most once, so the
        // returned mutable references never alias.
        unsafe {
            let val = &mut *self.ptr.cast::<V>();
            self.ptr = self.ptr.add(mem::size_of::<V>());
            Some(val)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining();
        (n, Some(n))
    }
}

impl<V> ExactSizeIterator for SliceMut<'_, V> {}
543
544/// An iterator over all elements of the `MultiArrayList`.
545pub struct Iter<'a, T>
546where
547    T: 'static,
548{
549    array: &'a MultiArrayList<T>,
550    pos: usize,
551}
552
553impl<'a, T> Iterator for Iter<'a, T> {
554    type Item = IterRef<'a, T>;
555
556    fn next(&mut self) -> Option<Self::Item> {
557        if self.pos >= self.array.len() {
558            return None;
559        }
560
561        unsafe {
562            let elems_slice = slice::from_raw_parts(
563                self.array.elems as *const *const u8,
564                MultiArrayList::<T>::fields().len(),
565            );
566
567            self.pos += 1;
568            Some(IterRef {
569                ptr: elems_slice,
570                idx: self.pos - 1,
571                elem: PhantomData,
572            })
573        }
574    }
575}
576
577/// A reference to a single element of `T`.
578///
579/// Use [`get(field)`][`IterRef::get`] to access a single field.
580pub struct IterRef<'a, T>
581where
582    T: 'static,
583{
584    ptr: &'a [*const u8],
585    idx: usize,
586    elem: PhantomData<&'a T>,
587}
588
589impl<'a, T> IterRef<'a, T> {
590    /// Get a specified field of `T`.
591    ///
592    /// # Compile errors
593    ///
594    /// * Fails to compile if the requested field does not exist.
595    /// * Fails to compile if the requested type does not match the found field (by size).
596    pub fn get<const NAME: &'static str, V>(&self) -> &V {
597        let (idx, elem_size) = const { MultiArrayList::<T>::get_field_by_name::<NAME, V>() };
598
599        unsafe {
600            let ptr = self.ptr[idx].offset((self.idx * elem_size) as isize);
601            let slice: &[u8] = slice::from_raw_parts(ptr, elem_size);
602            let val = mem::transmute_copy::<&[u8], &V>(&slice);
603            return val;
604        }
605    }
606}
607
#[cfg(test)]
mod test {
    use super::*;
    use std::mem;
    use std::rc::Rc;

    #[derive(Debug, PartialEq, Eq)]
    struct Point {
        x: i32,
        y: i32,
    }

    // Non-struct element types (enums, `()`, ...) are rejected by a compile-time assertion in
    // `MultiArrayList::new`, so they cannot be exercised with `#[should_panic]` tests.

    #[test]
    fn size() {
        // The column table pointer, the capacity and the length.
        assert_eq!(
            3 * mem::size_of::<usize>(),
            mem::size_of::<MultiArrayList<Point>>()
        );
    }

    #[test]
    fn empty() {
        let list = MultiArrayList::<Point>::new();
        assert_eq!(0, list.len());
        assert_eq!(0, list.capacity());
    }

    #[test]
    fn push() {
        let mut list = MultiArrayList::<Point>::new();
        assert_eq!(0, list.len());
        assert_eq!(0, list.capacity());

        list.push(Point { x: 2, y: 8 });
        assert_eq!(1, list.len());
        assert!(0 < list.capacity());
    }

    #[test]
    fn pop() {
        let mut list = MultiArrayList::<Point>::new();
        assert_eq!(0, list.len());
        assert_eq!(0, list.capacity());

        list.push(Point { x: 2, y: 8 });
        assert_eq!(1, list.len());
        assert!(0 < list.capacity());

        let p2 = list.pop().unwrap();
        assert_eq!(0, list.len());
        assert_eq!(&Point { x: 2, y: 8 }, &*p2);
        assert!(list.pop().is_none());
    }

    #[test]
    fn capacity() {
        let mut list = MultiArrayList::<Point>::with_capacity(10);
        assert_eq!(0, list.len());
        assert_eq!(10, list.capacity());

        list.push(Point { x: 2, y: 8 });
        assert_eq!(1, list.len());
        assert_eq!(10, list.capacity());
    }

    #[test]
    fn zst() {
        #[allow(dead_code)]
        #[derive(Debug, PartialEq, Eq)]
        struct Empty {
            x: (),
        }

        let mut list = MultiArrayList::<Empty>::new();
        assert_eq!(0, list.len());
        assert_eq!(0, list.capacity());

        list.push(Empty { x: () });

        let elem = list.pop().unwrap();
        assert_eq!(Empty { x: () }, *elem);
    }

    #[test]
    fn items_and_items_mut() {
        let mut list = MultiArrayList::<Point>::new();
        list.push(Point { x: 1, y: 10 });
        list.push(Point { x: 2, y: 20 });
        list.push(Point { x: 3, y: 30 });

        let xs: Vec<i32> = list.items::<"x", i32>().copied().collect();
        assert_eq!(vec![1, 2, 3], xs);

        for y in list.items_mut::<"y", i32>() {
            *y += 1;
        }
        let ys: Vec<i32> = list.items::<"y", i32>().copied().collect();
        assert_eq!(vec![11, 21, 31], ys);
    }

    #[test]
    fn iter_get() {
        let mut list = MultiArrayList::<Point>::new();
        list.push(Point { x: 4, y: 5 });
        list.push(Point { x: 6, y: 7 });

        let pairs: Vec<(i32, i32)> = list
            .iter()
            .map(|elem| (*elem.get::<"x", i32>(), *elem.get::<"y", i32>()))
            .collect();
        assert_eq!(vec![(4, 5), (6, 7)], pairs);
    }

    #[test]
    fn many_pushes_pop_in_reverse() {
        let mut list = MultiArrayList::<Point>::new();
        for i in 0..100 {
            list.push(Point { x: i, y: -i });
        }
        assert_eq!(100, list.len());
        assert!(list.capacity() >= 100);

        for i in (0..100).rev() {
            assert_eq!(Point { x: i, y: -i }, *list.pop().unwrap());
        }
        assert!(list.pop().is_none());
    }

    #[test]
    fn shrink_to_fit() {
        let mut list = MultiArrayList::<Point>::with_capacity(8);
        list.push(Point { x: 1, y: 2 });
        list.push(Point { x: 3, y: 4 });

        list.shrink_to_fit();
        assert_eq!(2, list.capacity());
        let xs: Vec<i32> = list.items::<"x", i32>().copied().collect();
        assert_eq!(vec![1, 3], xs);

        list.pop();
        list.pop();
        list.shrink_to_fit();
        assert_eq!(0, list.capacity());

        // The list must still be usable after releasing all column storage.
        list.push(Point { x: 5, y: 6 });
        assert_eq!(Point { x: 5, y: 6 }, *list.pop().unwrap());
    }

    #[test]
    fn drops_remaining_elements() {
        // `token` is only held for its drop side effect.
        #[allow(dead_code)]
        struct Tracked {
            id: u32,
            token: Rc<()>,
        }

        let token = Rc::new(());
        {
            let mut list = MultiArrayList::<Tracked>::new();
            for id in 0..3 {
                list.push(Tracked {
                    id,
                    token: Rc::clone(&token),
                });
            }
            assert_eq!(4, Rc::strong_count(&token));

            let last = list.pop().unwrap();
            assert_eq!(2, last.id);
        }
        // Both the popped element and the two left in the list have been dropped.
        assert_eq!(1, Rc::strong_count(&token));
    }
}