Skip to main content

facet_core/impls/core/
array.rs

1//! Facet implementation for [T; N] arrays
2
3use core::{cmp::Ordering, fmt};
4
5use crate::{
6    ArrayDef, ArrayVTable, Def, Facet, HashProxy, OxPtrConst, OxPtrMut, OxRef, PtrConst, PtrMut,
7    Shape, ShapeBuilder, Type, TypeNameOpts, TypeOpsIndirect, TypeParam, VTableIndirect, Variance,
8    VarianceDep, VarianceDesc,
9};
10
11/// Extract the ArrayDef from a shape, returns None if not an array
12#[inline]
13const fn get_array_def(shape: &'static Shape) -> Option<&'static ArrayDef> {
14    match shape.def {
15        Def::Array(ref def) => Some(def),
16        _ => None,
17    }
18}
19
20/// Type-erased type_name for arrays - reads T and N from the shape
21fn array_type_name(
22    shape: &'static Shape,
23    f: &mut fmt::Formatter<'_>,
24    opts: TypeNameOpts,
25) -> fmt::Result {
26    let def = match &shape.def {
27        Def::Array(def) => def,
28        _ => return write!(f, "[?; ?]"),
29    };
30
31    if let Some(opts) = opts.for_children() {
32        write!(f, "[")?;
33        def.t.write_type_name(f, opts)?;
34        write!(f, "; {}]", def.n)
35    } else {
36        write!(f, "[…; {}]", def.n)
37    }
38}
39
40/// Debug for [T; N] - formats as array literal
41unsafe fn array_debug(
42    ox: OxPtrConst,
43    f: &mut core::fmt::Formatter<'_>,
44) -> Option<core::fmt::Result> {
45    let shape = ox.shape();
46    let def = get_array_def(shape)?;
47    let ptr = ox.ptr();
48
49    let mut list = f.debug_list();
50    let slice_ptr = unsafe { (def.vtable.as_ptr)(ptr) };
51    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
52
53    for i in 0..def.n {
54        // SAFETY: We're iterating within bounds of the array, and the caller
55        // guarantees the OxPtrConst points to a valid array.
56        let elem_ptr = unsafe { PtrConst::new((slice_ptr.as_byte_ptr()).add(i * stride)) };
57        let elem_ox = unsafe { OxRef::new(elem_ptr, def.t) };
58        list.entry(&elem_ox);
59    }
60    Some(list.finish())
61}
62
63/// Hash for [T; N] - hashes each element
64unsafe fn array_hash(ox: OxPtrConst, hasher: &mut HashProxy<'_>) -> Option<()> {
65    let shape = ox.shape();
66    let def = get_array_def(shape)?;
67    let ptr = ox.ptr();
68
69    let slice_ptr = unsafe { (def.vtable.as_ptr)(ptr) };
70    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
71
72    for i in 0..def.n {
73        let elem_ptr = unsafe { PtrConst::new((slice_ptr.as_byte_ptr()).add(i * stride)) };
74        unsafe { def.t.call_hash(elem_ptr, hasher)? };
75    }
76    Some(())
77}
78
79/// PartialEq for [T; N]
80unsafe fn array_partial_eq(a: OxPtrConst, b: OxPtrConst) -> Option<bool> {
81    let shape = a.shape();
82    let def = get_array_def(shape)?;
83
84    let a_ptr = unsafe { (def.vtable.as_ptr)(a.ptr()) };
85    let b_ptr = unsafe { (def.vtable.as_ptr)(b.ptr()) };
86    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
87
88    for i in 0..def.n {
89        let a_elem = unsafe { PtrConst::new((a_ptr.as_byte_ptr()).add(i * stride)) };
90        let b_elem = unsafe { PtrConst::new((b_ptr.as_byte_ptr()).add(i * stride)) };
91        if !unsafe { def.t.call_partial_eq(a_elem, b_elem)? } {
92            return Some(false);
93        }
94    }
95    Some(true)
96}
97
98/// PartialOrd for [T; N]
99unsafe fn array_partial_cmp(a: OxPtrConst, b: OxPtrConst) -> Option<Option<Ordering>> {
100    let shape = a.shape();
101    let def = get_array_def(shape)?;
102
103    let a_ptr = unsafe { (def.vtable.as_ptr)(a.ptr()) };
104    let b_ptr = unsafe { (def.vtable.as_ptr)(b.ptr()) };
105    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
106
107    for i in 0..def.n {
108        let a_elem = unsafe { PtrConst::new((a_ptr.as_byte_ptr()).add(i * stride)) };
109        let b_elem = unsafe { PtrConst::new((b_ptr.as_byte_ptr()).add(i * stride)) };
110        match unsafe { def.t.call_partial_cmp(a_elem, b_elem)? } {
111            Some(Ordering::Equal) => continue,
112            other => return Some(other),
113        }
114    }
115    Some(Some(Ordering::Equal))
116}
117
118/// Ord for [T; N]
119unsafe fn array_cmp(a: OxPtrConst, b: OxPtrConst) -> Option<Ordering> {
120    let shape = a.shape();
121    let def = get_array_def(shape)?;
122
123    let a_ptr = unsafe { (def.vtable.as_ptr)(a.ptr()) };
124    let b_ptr = unsafe { (def.vtable.as_ptr)(b.ptr()) };
125    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
126
127    for i in 0..def.n {
128        let a_elem = unsafe { PtrConst::new((a_ptr.as_byte_ptr()).add(i * stride)) };
129        let b_elem = unsafe { PtrConst::new((b_ptr.as_byte_ptr()).add(i * stride)) };
130        match unsafe { def.t.call_cmp(a_elem, b_elem)? } {
131            Ordering::Equal => continue,
132            other => return Some(other),
133        }
134    }
135    Some(Ordering::Equal)
136}
137
138/// Drop for [T; N]
139unsafe fn array_drop(ox: OxPtrMut) {
140    let shape = ox.shape();
141    let Some(def) = get_array_def(shape) else {
142        return;
143    };
144    let ptr = ox.ptr();
145
146    let slice_ptr = unsafe { (def.vtable.as_mut_ptr)(ptr) };
147    let Some(stride) = def
148        .t
149        .layout
150        .sized_layout()
151        .ok()
152        .map(|l| l.pad_to_align().size())
153    else {
154        return;
155    };
156
157    for i in 0..def.n {
158        let elem_ptr = unsafe { PtrMut::new((slice_ptr.as_byte_ptr() as *mut u8).add(i * stride)) };
159        unsafe { def.t.call_drop_in_place(elem_ptr) };
160    }
161}
162
163/// Default for [T; N] - default-initializes each element
164unsafe fn array_default(ox: OxPtrMut) {
165    let shape = ox.shape();
166    let Some(def) = get_array_def(shape) else {
167        return;
168    };
169    let ptr = ox.ptr();
170
171    let slice_ptr = unsafe { (def.vtable.as_mut_ptr)(ptr) };
172    let Some(stride) = def
173        .t
174        .layout
175        .sized_layout()
176        .ok()
177        .map(|l| l.pad_to_align().size())
178    else {
179        return;
180    };
181
182    for i in 0..def.n {
183        let elem_ptr = unsafe { PtrMut::new((slice_ptr.as_byte_ptr() as *mut u8).add(i * stride)) };
184        if unsafe { def.t.call_default_in_place(elem_ptr) }.is_none() {
185            return;
186        }
187    }
188}
189
190/// Clone for [T; N] - clones each element
191unsafe fn array_clone(src: OxPtrConst, dst: OxPtrMut) {
192    let shape = src.shape();
193    let Some(def) = get_array_def(shape) else {
194        return;
195    };
196
197    let src_ptr = unsafe { (def.vtable.as_ptr)(src.ptr()) };
198    let dst_ptr = unsafe { (def.vtable.as_mut_ptr)(dst.ptr()) };
199    let Some(stride) = def
200        .t
201        .layout
202        .sized_layout()
203        .ok()
204        .map(|l| l.pad_to_align().size())
205    else {
206        return;
207    };
208
209    for i in 0..def.n {
210        let src_elem = unsafe { PtrConst::new((src_ptr.as_byte_ptr()).add(i * stride)) };
211        let dst_elem = unsafe { PtrMut::new((dst_ptr.as_byte_ptr() as *mut u8).add(i * stride)) };
212        if unsafe { def.t.call_clone_into(src_elem, dst_elem) }.is_none() {
213            return;
214        }
215    }
216}
217
218// Shared vtable for all [T; N]
219const ARRAY_VTABLE: VTableIndirect = VTableIndirect {
220    display: None,
221    debug: Some(array_debug),
222    hash: Some(array_hash),
223    invariants: None,
224    parse: None,
225    parse_bytes: None,
226    try_from: None,
227    try_into_inner: None,
228    try_borrow_inner: None,
229    partial_eq: Some(array_partial_eq),
230    partial_cmp: Some(array_partial_cmp),
231    cmp: Some(array_cmp),
232};
233
234/// Get pointer to array data buffer
235unsafe fn array_as_ptr<T, const N: usize>(ptr: PtrConst) -> PtrConst {
236    let array = unsafe { ptr.get::<[T; N]>() };
237    PtrConst::new(array.as_ptr() as *const u8)
238}
239
240/// Get mutable pointer to array data buffer
241unsafe fn array_as_mut_ptr<T, const N: usize>(ptr: PtrMut) -> PtrMut {
242    let array = unsafe { ptr.as_mut::<[T; N]>() };
243    PtrMut::new(array.as_mut_ptr() as *mut u8)
244}
245
246unsafe impl<'a, T, const N: usize> Facet<'a> for [T; N]
247where
248    T: Facet<'a>,
249{
250    const SHAPE: &'static Shape = &const {
251        const fn build_array_vtable<T, const N: usize>() -> ArrayVTable {
252            ArrayVTable::builder()
253                .as_ptr(array_as_ptr::<T, N>)
254                .as_mut_ptr(array_as_mut_ptr::<T, N>)
255                .build()
256        }
257
258        ShapeBuilder::for_sized::<[T; N]>("[T; N]")
259            .decl_id(crate::DeclId::new(crate::decl_id_hash("#array#[T; N]")))
260            .type_name(array_type_name)
261            .ty(Type::Sequence(crate::SequenceType::Array(
262                crate::ArrayType { t: T::SHAPE, n: N },
263            )))
264            .def(Def::Array(ArrayDef::new(
265                &const { build_array_vtable::<T, N>() },
266                T::SHAPE,
267                N,
268            )))
269            .type_params(&[TypeParam {
270                name: "T",
271                shape: T::SHAPE,
272            }])
273            .inner(T::SHAPE)
274            // [T; N] propagates T's variance
275            .variance(VarianceDesc {
276                base: Variance::Bivariant,
277                deps: &const { [VarianceDep::covariant(T::SHAPE)] },
278            })
279            .vtable_indirect(&ARRAY_VTABLE)
280            .type_ops_indirect(
281                &const {
282                    unsafe fn truthy<const N: usize>(_: PtrConst) -> bool {
283                        N != 0
284                    }
285
286                    TypeOpsIndirect {
287                        drop_in_place: array_drop,
288                        default_in_place: Some(array_default),
289                        clone_into: Some(array_clone),
290                        is_truthy: Some(truthy::<N>),
291                    }
292                },
293            )
294            .build()
295    };
296}