multi_array_list/
lib.rs

1//! A `MultiArrayList` stores a list of a struct.
2//!
//! **Experimental**: Only a small subset of the array list API is implemented.
4//!
5//! ---
6//!
7//! > Instead of storing a single list of items, `MultiArrayList` stores separate lists for each field of the struct.
8//! > This allows for memory savings if the struct has padding,
9//! > and also improves cache usage if only some fields are needed for a computation.
10//!
11//! The primary API for accessing fields is the [`items(name)`][`MultiArrayList::items()`] function.
12//!
13//! ---
14//! _inspired by [Zig's `MultiArrayList`](https://ziglang.org/documentation/master/std/#std.MultiArrayList)._
15//!
16//! # Example
17//! ```rust
18//! # use multi_array_list::MultiArrayList;
19//! # use facet::Facet;
20//! #[derive(Facet, Clone)]
21//! struct Pizza {
22//!     radius: u32,
23//!     toppings: Vec<Topping>,
24//! }
25//!
26//! #[derive(Facet, Clone, Copy)]
27//! #[repr(u8)]
28//! enum Topping {
29//!     Tomato,
30//!     Mozzarella,
31//!     Anchovies,
32//! }
33//!
34//! let mut order = MultiArrayList::<Pizza>::new();
35//!
36//! let margherita = Pizza {
37//!     radius: 12,
38//!     toppings: vec![Topping::Tomato],
39//! };
40//! order.push(margherita);
41//!
42//! let napoli = Pizza {
43//!     radius: 12,
44//!     toppings: vec![Topping::Tomato, Topping::Anchovies],
45//! };
46//! order.push(napoli);
47//!
48//! for topping in order.items_mut::<Vec<Topping>>("toppings") {
49//!     topping.push(Topping::Mozzarella);
50//! }
51//! ```
52#![warn(missing_docs)]
53
54use std::alloc::{self, Layout};
55use std::marker::PhantomData;
56use std::mem::{self, ManuallyDrop};
57use std::ptr;
58use std::slice;
59
60use facet::{Def, Facet, Field, PtrConst, ShapeLayout};
61use facet_reflect::{Peek, Wip};
62
// Evaluates to the sized layout stored in a shape's `layout` field.
//
// Panics with "Can't handle unsized fields" when the shape does not carry
// a `ShapeLayout::Sized` layout.
macro_rules! sized_layout {
    ($shape:expr) => {
        match $shape.layout {
            ShapeLayout::Sized(sized) => sized,
            _ => panic!("Can't handle unsized fields"),
        }
    };
}
71
72/// A `MultiArrayList` stores a list of a struct.
73///
74/// > Instead of storing a single list of items, `MultiArrayList` stores separate lists for each field of the struct.
75/// > This allows for memory savings if the struct has padding,
76/// > and also improves cache usage if only some fields are needed for a computation.
77///
78/// The primary API for accessing fields is the [`items(name)`][`MultiArrayList::items()`] function.
79#[derive(Debug)]
80pub struct MultiArrayList<T>
81where
82    T: Facet<'static>,
83{
84    elems: *mut *mut u8,
85    cap: usize,
86    len: usize,
87
88    _t: PhantomData<T>,
89}
90
91impl<T> MultiArrayList<T>
92where
93    T: Facet<'static>,
94{
95    /// Constructs a new, empty `MultiArrayList<T>`.
96    ///
97    /// # Examples
98    ///
99    /// ```rust
100    /// # use multi_array_list::MultiArrayList;
101    /// # use facet::Facet;
102    /// #[derive(Facet)]
103    /// struct Point {
104    ///     x: i32,
105    ///     y: i32
106    /// }
107    /// let mut list: MultiArrayList<Point> = MultiArrayList::new();
108    /// ```
109    pub fn new() -> MultiArrayList<T> {
110        let fields = Self::fields();
111
112        for field in fields {
113            _ = sized_layout!(field.shape());
114        }
115
116        // SAFETY:
117        // * We create the layout
118        // * We alloc that
119        unsafe {
120            let elemn = fields.len();
121            let layout = Layout::array::<*mut u8>(elemn).unwrap();
122            let elems = alloc::alloc(layout) as *mut *mut u8;
123
124            MultiArrayList {
125                elems,
126                cap: 0,
127                len: 0,
128                _t: PhantomData,
129            }
130        }
131    }
132
133    fn elem_ptrs(&self) -> &[*const u8] {
134        unsafe { slice::from_raw_parts(self.elems as *const *const u8, Self::fields().len()) }
135    }
136
137    fn elem_ptrs_mut(&self) -> &[*mut u8] {
138        unsafe { slice::from_raw_parts(self.elems as *const *mut u8, Self::fields().len()) }
139    }
140
141    fn fields() -> &'static [Field] {
142        let Def::Struct(sdef) = T::SHAPE.def else {
143            panic!("MultiArrayList only works for structs");
144        };
145
146        sdef.fields
147    }
148
149    /// Returns the total number of elements the vector can hold without reallocating.
150    pub fn capacity(&self) -> usize {
151        self.cap
152    }
153
154    /// Returns the number of elements in the `MultiArrayList`, also referred to as its ‘length’.
155    pub fn len(&self) -> usize {
156        self.len
157    }
158
159    /// Appends an element to the back of a collection.
160    pub fn push(&mut self, value: T) {
161        let len = self.len;
162
163        // We will consume this value, don't drop it.
164        let value = ManuallyDrop::new(value);
165
166        if len == self.capacity() {
167            self.grow_one();
168        }
169
170        unsafe {
171            let value = Peek::new(&*value);
172            let value_ptr = value.data().as_byte_ptr();
173
174            let elem_ptrs = self.elem_ptrs_mut();
175
176            for (field, elem) in Self::fields().iter().zip(elem_ptrs) {
177                let field_layout = sized_layout!(field.shape());
178
179                let offset = field.offset;
180                let size = field_layout.size();
181
182                let field_elem = elem.offset((size * self.len()) as isize);
183                let src = value_ptr.offset(offset as isize);
184                ptr::copy_nonoverlapping(src, field_elem, size);
185            }
186        }
187
188        self.len += 1;
189    }
190
191    /// Removes the last element from the `MultiArrayList` and returns it, or `None` if it is empty.
192    pub fn pop(&mut self) -> Option<T> {
193        if self.len() == 0 {
194            return None;
195        }
196        let idx = self.len() - 1;
197        let mut wip = Wip::alloc::<T>().unwrap();
198
199        unsafe {
200            let elem_ptrs = self.elem_ptrs_mut();
201
202            for (field, elem) in Self::fields().iter().zip(elem_ptrs) {
203                let shape = field.shape();
204                let field_layout = sized_layout!(field.shape());
205
206                let size = field_layout.size();
207                let field_elem = elem.offset((size * idx) as isize);
208                let ptr = PtrConst::new(field_elem);
209                wip = wip.put_shape(ptr, shape).unwrap();
210            }
211        }
212
213        let v = wip.build().unwrap();
214        let v = v.materialize().unwrap();
215        self.len -= 1;
216        Some(v)
217    }
218
219    /// Returns an iterator over the `MultiArrayList`.
220    ///
221    /// The iterator yields all items from start to end.
222    pub fn iter<'a>(&'a self) -> Iter<'a, T> {
223        Iter {
224            array: self,
225            pos: 0,
226        }
227    }
228
229    /// Get an iterator of values for a specified field.
230    ///
231    /// # Panics
232    ///
233    /// * Panics if the requested field does not exit.
234    /// * Panics if the requested type does not match the found field.
235    pub fn items<'a, V>(&'a self, name: &'static str) -> Slice<'a, V>
236    where
237        V: Facet<'static>,
238    {
239        let fields = Self::fields();
240        let elem_ptrs = self.elem_ptrs();
241        for (idx, field) in fields.iter().enumerate() {
242            if field.name == name {
243                let typeid = field.shape().id;
244                assert_eq!(V::SHAPE.id, typeid, "wrong type requested");
245
246                let field_layout = sized_layout!(field.shape());
247                let elem_size = field_layout.size();
248
249                unsafe {
250                    let ptr = elem_ptrs[idx];
251                    let end = ptr.offset((self.len * elem_size) as isize);
252
253                    return Slice {
254                        ptr,
255                        end,
256                        typ: PhantomData,
257                    };
258                }
259            }
260        }
261
262        panic!("unknown item name");
263    }
264
265    /// Get an iterator of mutable values for a specified field.
266    ///
267    /// # Panics
268    ///
269    /// * Panics if the requested field does not exit.
270    /// * Panics if the requested type does not match the found field.
271    pub fn items_mut<'a, V>(&'a mut self, name: &'static str) -> SliceMut<'a, V>
272    where
273        V: Facet<'static>,
274    {
275        let fields = Self::fields();
276        let elem_ptrs = self.elem_ptrs_mut();
277        for (idx, field) in fields.iter().enumerate() {
278            if field.name == name {
279                let typeid = field.shape().id;
280                assert_eq!(V::SHAPE.id, typeid, "wrong type requested");
281
282                let field_layout = sized_layout!(field.shape());
283                let elem_size = field_layout.size();
284
285                unsafe {
286                    let ptr = elem_ptrs[idx];
287                    let end = ptr.offset((self.len * elem_size) as isize);
288
289                    return SliceMut {
290                        ptr,
291                        end,
292                        typ: PhantomData,
293                    };
294                }
295            }
296        }
297
298        panic!("unknown item name");
299    }
300
301    fn grow_one(&mut self) {
302        let old_cap = self.capacity();
303        let new_cap = old_cap + 1;
304
305        let fields = Self::fields();
306
307        for (idx, field) in fields.iter().enumerate() {
308            let field_layout = sized_layout!(field.shape());
309
310            let element_size = field_layout.size();
311            let align = field_layout.align();
312            let old_array_size = element_size * old_cap;
313            let new_array_size = element_size * new_cap;
314
315            unsafe {
316                let ptr = self.elems.offset(idx as isize);
317                if old_cap == 0 {
318                    let layout = Layout::from_size_align_unchecked(new_array_size, align);
319                    *ptr = alloc::alloc(layout);
320                } else {
321                    let layout = Layout::from_size_align_unchecked(old_array_size, align);
322                    *ptr = alloc::realloc(*ptr, layout, new_array_size);
323                }
324            }
325        }
326
327        self.cap = new_cap;
328    }
329
330    /// Shrinks the capacity of the vector as much as possible.
331    pub fn shrink_to_fit(&mut self) {
332        let cur_cap = self.capacity();
333        let len = self.len();
334        if len == cur_cap {
335            return;
336        }
337
338        unsafe {
339            for (idx, field) in Self::fields().iter().enumerate() {
340                let field_layout = sized_layout!(field.shape());
341
342                let element_size = field_layout.size();
343                let align = field_layout.align();
344                let old_array_size = element_size * cur_cap;
345                let new_array_size = element_size * len;
346
347                let ptr = self.elems.offset(idx as isize);
348                let layout = Layout::from_size_align_unchecked(old_array_size, align);
349                if new_array_size == 0 {
350                    alloc::dealloc(*ptr, layout);
351                    *ptr = ptr::null_mut();
352                } else {
353                    *ptr = alloc::realloc(*ptr, layout, new_array_size);
354                }
355            }
356        }
357
358        self.cap = len;
359    }
360}
361
362impl<T> Drop for MultiArrayList<T>
363where
364    T: Facet<'static>,
365{
366    fn drop(&mut self) {
367        while let Some(elem) = self.pop() {
368            drop(elem);
369        }
370
371        let fields = Self::fields();
372
373        unsafe {
374            self.shrink_to_fit();
375
376            let elemn = fields.len();
377            let layout = Layout::array::<*mut u8>(elemn).unwrap();
378            alloc::dealloc(self.elems as *mut u8, layout);
379        }
380    }
381}
382
/// Iterator yielding shared references to every value of one field.
pub struct Slice<'a, V> {
    // Next value to yield.
    ptr: *const u8,
    // One past the last value.
    end: *const u8,
    // Ties the raw pointers to the borrow of the list and to the value type.
    typ: PhantomData<&'a V>,
}
389
390impl<'a, V> Iterator for Slice<'a, V>
391where
392    V: Facet<'static>,
393{
394    type Item = &'a V;
395
396    fn next(&mut self) -> Option<Self::Item> {
397        if self.ptr >= self.end {
398            return None;
399        }
400        let layout = sized_layout!(V::SHAPE);
401        let elem_size = layout.size();
402
403        unsafe {
404            let slice: &[u8] = slice::from_raw_parts(self.ptr, elem_size);
405            let val = mem::transmute_copy::<&[u8], &V>(&slice);
406            self.ptr = self.ptr.offset(elem_size as isize);
407            return Some(val);
408        }
409    }
410}
411
/// Iterator yielding mutable references to every value of one field.
pub struct SliceMut<'a, V> {
    // Next value to yield.
    ptr: *mut u8,
    // One past the last value.
    end: *mut u8,
    // Ties the raw pointers to the exclusive borrow of the list and to the value type.
    typ: PhantomData<&'a mut V>,
}
418
419impl<'a, V> Iterator for SliceMut<'a, V>
420where
421    V: Facet<'static>,
422{
423    type Item = &'a mut V;
424
425    fn next(&mut self) -> Option<Self::Item> {
426        if self.ptr >= self.end {
427            return None;
428        }
429        let layout = sized_layout!(V::SHAPE);
430        let elem_size = layout.size();
431
432        unsafe {
433            let mut slice: &mut [u8] = slice::from_raw_parts_mut(self.ptr, elem_size);
434            let val = mem::transmute_copy::<&mut [u8], &mut V>(&mut slice);
435            self.ptr = self.ptr.offset(elem_size as isize);
436            return Some(val);
437        }
438    }
439}
440
441/// An iterator over all elements of the `MultiArrayList`.
442pub struct Iter<'a, T>
443where
444    T: Facet<'static>,
445{
446    array: &'a MultiArrayList<T>,
447    pos: usize,
448}
449
450impl<'a, T> Iterator for Iter<'a, T>
451where
452    T: Facet<'static>,
453{
454    type Item = IterRef<'a, T>;
455
456    fn next(&mut self) -> Option<Self::Item> {
457        if self.pos >= self.array.len() {
458            return None;
459        }
460
461        let Def::Struct(sdef) = T::SHAPE.def else {
462            panic!("MultiArrayList only works for structs");
463        };
464
465        let fields = sdef.fields;
466        unsafe {
467            let elems_slice =
468                slice::from_raw_parts(self.array.elems as *const *const u8, fields.len());
469
470            self.pos += 1;
471            Some(IterRef {
472                ptr: elems_slice,
473                idx: self.pos - 1,
474                elem: PhantomData,
475            })
476        }
477    }
478}
479
480/// A reference to a single element of `T`.
481///
482/// Use [`get(field)`][`IterRef::get`] to access a single field.
483pub struct IterRef<'a, T>
484where
485    T: Facet<'static>,
486{
487    ptr: &'a [*const u8],
488    idx: usize,
489    elem: PhantomData<&'a T>,
490}
491
492impl<'a, T> IterRef<'a, T>
493where
494    T: Facet<'static>,
495{
496    /// Get a specified field of `T`.
497    ///
498    /// # Panics
499    ///
500    /// * Panics if the requested field does not exit.
501    /// * Panics if the requested type does not match the found field.
502    pub fn get<V>(&self, name: &'static str) -> &V
503    where
504        V: Facet<'static>,
505    {
506        let Def::Struct(sdef) = T::SHAPE.def else {
507            panic!("MultiArrayList only works for structs");
508        };
509        let fields = sdef.fields;
510        for (field_idx, field) in fields.iter().enumerate() {
511            if field.name == name {
512                let typeid = field.shape().id;
513                assert_eq!(V::SHAPE.id, typeid, "wrong type requested");
514
515                let field_layout = sized_layout!(field.shape());
516                let elem_size = field_layout.size();
517
518                unsafe {
519                    let ptr = self.ptr[field_idx].offset((self.idx * elem_size) as isize);
520                    let slice: &[u8] = slice::from_raw_parts(ptr, elem_size);
521                    let val = mem::transmute_copy::<&[u8], &V>(&slice);
522                    return val;
523                }
524            }
525        }
526
527        panic!("unknown item name {}", name);
528    }
529}
530
#[cfg(test)]
mod test {
    use super::*;
    use std::mem;

    // Plain two-field struct used as the element type in the tests.
    #[derive(Facet)]
    struct Point {
        x: i32,
        y: i32,
    }

    // An enum, used to check that non-struct types are rejected.
    #[derive(Facet)]
    #[repr(u8)]
    #[allow(dead_code)]
    enum Random {
        Four,
    }

    /// The list itself is one pointer plus two `usize` counters.
    ///
    /// Computed from the target's sizes instead of a hard-coded 24, so the
    /// test also holds on targets where pointers are not 8 bytes wide.
    #[test]
    fn size() {
        let expected = mem::size_of::<*mut *mut u8>() + 2 * mem::size_of::<usize>();
        assert_eq!(expected, mem::size_of::<MultiArrayList<Point>>());
    }

    /// A fresh list holds nothing and has reserved nothing.
    #[test]
    fn empty() {
        let list = MultiArrayList::<Point>::new();
        assert_eq!(0, list.len());
        assert_eq!(0, list.capacity());
    }

    /// Enums are rejected at construction.
    #[test]
    #[should_panic = "structs"]
    fn struct_only() {
        let _m = MultiArrayList::<Random>::new();
    }

    /// The unit type is rejected at construction.
    #[test]
    #[should_panic = "structs"]
    fn not_empty() {
        let _m = MultiArrayList::<()>::new();
    }
}