Skip to main content

facet_core/impls/core/
array.rs

1//! Facet implementation for [T; N] arrays
2
3use core::{cmp::Ordering, fmt};
4
5use crate::{
6    ArrayDef, ArrayVTable, ConstParam, ConstParamKind, Def, Facet, HashProxy, OxPtrConst, OxPtrMut,
7    OxPtrUninit, OxRef, PtrConst, PtrMut, PtrUninit, Shape, ShapeBuilder, Type, TypeNameOpts,
8    TypeOpsIndirect, TypeParam, VTableIndirect, Variance, VarianceDep, VarianceDesc,
9};
10
11/// Extract the ArrayDef from a shape, returns None if not an array
12#[inline]
13const fn get_array_def(shape: &'static Shape) -> Option<&'static ArrayDef> {
14    match shape.def {
15        Def::Array(ref def) => Some(def),
16        _ => None,
17    }
18}
19
20/// Type-erased type_name for arrays - reads T and N from the shape
21fn array_type_name(
22    shape: &'static Shape,
23    f: &mut fmt::Formatter<'_>,
24    opts: TypeNameOpts,
25) -> fmt::Result {
26    let def = match &shape.def {
27        Def::Array(def) => def,
28        _ => return write!(f, "[?; ?]"),
29    };
30
31    if let Some(opts) = opts.for_children() {
32        write!(f, "[")?;
33        def.t.write_type_name(f, opts)?;
34        write!(f, "; {}]", def.n)
35    } else {
36        write!(f, "[…; {}]", def.n)
37    }
38}
39
40/// Debug for [T; N] - formats as array literal
41unsafe fn array_debug(
42    ox: OxPtrConst,
43    f: &mut core::fmt::Formatter<'_>,
44) -> Option<core::fmt::Result> {
45    let shape = ox.shape();
46    let def = get_array_def(shape)?;
47    let ptr = ox.ptr();
48
49    let mut list = f.debug_list();
50    let slice_ptr = unsafe { (def.vtable.as_ptr)(ptr) };
51    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
52
53    for i in 0..def.n {
54        // SAFETY: We're iterating within bounds of the array, and the caller
55        // guarantees the OxPtrConst points to a valid array.
56        let elem_ptr = unsafe { PtrConst::new((slice_ptr.as_byte_ptr()).add(i * stride)) };
57        let elem_ox = unsafe { OxRef::new(elem_ptr, def.t) };
58        list.entry(&elem_ox);
59    }
60    Some(list.finish())
61}
62
63/// Hash for [T; N] - hashes each element
64unsafe fn array_hash(ox: OxPtrConst, hasher: &mut HashProxy<'_>) -> Option<()> {
65    let shape = ox.shape();
66    let def = get_array_def(shape)?;
67    let ptr = ox.ptr();
68
69    let slice_ptr = unsafe { (def.vtable.as_ptr)(ptr) };
70    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
71
72    for i in 0..def.n {
73        let elem_ptr = unsafe { PtrConst::new((slice_ptr.as_byte_ptr()).add(i * stride)) };
74        unsafe { def.t.call_hash(elem_ptr, hasher)? };
75    }
76    Some(())
77}
78
79/// PartialEq for [T; N]
80unsafe fn array_partial_eq(a: OxPtrConst, b: OxPtrConst) -> Option<bool> {
81    let shape = a.shape();
82    let def = get_array_def(shape)?;
83
84    let a_ptr = unsafe { (def.vtable.as_ptr)(a.ptr()) };
85    let b_ptr = unsafe { (def.vtable.as_ptr)(b.ptr()) };
86    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
87
88    for i in 0..def.n {
89        let a_elem = unsafe { PtrConst::new((a_ptr.as_byte_ptr()).add(i * stride)) };
90        let b_elem = unsafe { PtrConst::new((b_ptr.as_byte_ptr()).add(i * stride)) };
91        if !unsafe { def.t.call_partial_eq(a_elem, b_elem)? } {
92            return Some(false);
93        }
94    }
95    Some(true)
96}
97
98/// PartialOrd for [T; N]
99unsafe fn array_partial_cmp(a: OxPtrConst, b: OxPtrConst) -> Option<Option<Ordering>> {
100    let shape = a.shape();
101    let def = get_array_def(shape)?;
102
103    let a_ptr = unsafe { (def.vtable.as_ptr)(a.ptr()) };
104    let b_ptr = unsafe { (def.vtable.as_ptr)(b.ptr()) };
105    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
106
107    for i in 0..def.n {
108        let a_elem = unsafe { PtrConst::new((a_ptr.as_byte_ptr()).add(i * stride)) };
109        let b_elem = unsafe { PtrConst::new((b_ptr.as_byte_ptr()).add(i * stride)) };
110        match unsafe { def.t.call_partial_cmp(a_elem, b_elem)? } {
111            Some(Ordering::Equal) => continue,
112            other => return Some(other),
113        }
114    }
115    Some(Some(Ordering::Equal))
116}
117
118/// Ord for [T; N]
119unsafe fn array_cmp(a: OxPtrConst, b: OxPtrConst) -> Option<Ordering> {
120    let shape = a.shape();
121    let def = get_array_def(shape)?;
122
123    let a_ptr = unsafe { (def.vtable.as_ptr)(a.ptr()) };
124    let b_ptr = unsafe { (def.vtable.as_ptr)(b.ptr()) };
125    let stride = def.t.layout.sized_layout().ok()?.pad_to_align().size();
126
127    for i in 0..def.n {
128        let a_elem = unsafe { PtrConst::new((a_ptr.as_byte_ptr()).add(i * stride)) };
129        let b_elem = unsafe { PtrConst::new((b_ptr.as_byte_ptr()).add(i * stride)) };
130        match unsafe { def.t.call_cmp(a_elem, b_elem)? } {
131            Ordering::Equal => continue,
132            other => return Some(other),
133        }
134    }
135    Some(Ordering::Equal)
136}
137
138/// Drop for [T; N]
139unsafe fn array_drop(ox: OxPtrMut) {
140    let shape = ox.shape();
141    let Some(def) = get_array_def(shape) else {
142        return;
143    };
144    let ptr = ox.ptr();
145
146    let slice_ptr = unsafe { (def.vtable.as_mut_ptr)(ptr) };
147    let Some(stride) = def
148        .t
149        .layout
150        .sized_layout()
151        .ok()
152        .map(|l| l.pad_to_align().size())
153    else {
154        return;
155    };
156
157    for i in 0..def.n {
158        let elem_ptr = unsafe { PtrMut::new((slice_ptr.as_byte_ptr() as *mut u8).add(i * stride)) };
159        unsafe { def.t.call_drop_in_place(elem_ptr) };
160    }
161}
162
163/// Default for [T; N] - default-initializes each element
164unsafe fn array_default(ox: OxPtrUninit) -> bool {
165    let shape = ox.shape();
166    let Some(def) = get_array_def(shape) else {
167        return false;
168    };
169    let ptr = ox.ptr();
170
171    // Arrays are laid out contiguously starting at offset 0, so we can use
172    // the pointer directly without going through the vtable (which requires
173    // initialized memory).
174    let base_ptr = ptr.as_mut_byte_ptr();
175    let Some(stride) = def
176        .t
177        .layout
178        .sized_layout()
179        .ok()
180        .map(|l| l.pad_to_align().size())
181    else {
182        return false;
183    };
184
185    for i in 0..def.n {
186        let elem_ptr = unsafe { PtrUninit::new(base_ptr.add(i * stride)) };
187        if unsafe { def.t.call_default_in_place(elem_ptr) }.is_none() {
188            // Drop already-initialized elements before returning
189            for j in 0..i {
190                let drop_ptr = unsafe { PtrMut::new(base_ptr.add(j * stride)) };
191                unsafe { def.t.call_drop_in_place(drop_ptr) };
192            }
193            return false;
194        }
195    }
196    true
197}
198
199/// Clone for [T; N] - clones each element
200unsafe fn array_clone(src: OxPtrConst, dst: OxPtrMut) {
201    let shape = src.shape();
202    let Some(def) = get_array_def(shape) else {
203        return;
204    };
205
206    // Source is initialized, so we can use the vtable
207    let src_ptr = unsafe { (def.vtable.as_ptr)(src.ptr()) };
208    // Destination is uninitialized (clone_into writes to uninit memory),
209    // so use the pointer directly instead of the vtable
210    let dst_base = dst.ptr().as_mut_byte_ptr();
211    let Some(stride) = def
212        .t
213        .layout
214        .sized_layout()
215        .ok()
216        .map(|l| l.pad_to_align().size())
217    else {
218        return;
219    };
220
221    for i in 0..def.n {
222        let src_elem = unsafe { PtrConst::new((src_ptr.as_byte_ptr()).add(i * stride)) };
223        let dst_elem = unsafe { PtrMut::new(dst_base.add(i * stride)) };
224        if unsafe { def.t.call_clone_into(src_elem, dst_elem) }.is_none() {
225            return;
226        }
227    }
228}
229
230// Shared vtable for all [T; N]
231const ARRAY_VTABLE: VTableIndirect = VTableIndirect {
232    display: None,
233    debug: Some(array_debug),
234    hash: Some(array_hash),
235    invariants: None,
236    parse: None,
237    parse_bytes: None,
238    try_from: None,
239    try_into_inner: None,
240    try_borrow_inner: None,
241    partial_eq: Some(array_partial_eq),
242    partial_cmp: Some(array_partial_cmp),
243    cmp: Some(array_cmp),
244};
245
246/// Get pointer to array data buffer
247unsafe extern "C" fn array_as_ptr<T, const N: usize>(ptr: PtrConst) -> PtrConst {
248    let array = unsafe { ptr.get::<[T; N]>() };
249    PtrConst::new(array.as_ptr() as *const u8)
250}
251
252/// Get mutable pointer to array data buffer
253unsafe extern "C" fn array_as_mut_ptr<T, const N: usize>(ptr: PtrMut) -> PtrMut {
254    let array = unsafe { ptr.as_mut::<[T; N]>() };
255    PtrMut::new(array.as_mut_ptr() as *mut u8)
256}
257
258unsafe impl<'a, T, const N: usize> Facet<'a> for [T; N]
259where
260    T: Facet<'a>,
261{
262    const SHAPE: &'static Shape = &const {
263        const fn build_array_vtable<T, const N: usize>() -> ArrayVTable {
264            ArrayVTable::builder()
265                .as_ptr(array_as_ptr::<T, N>)
266                .as_mut_ptr(array_as_mut_ptr::<T, N>)
267                .build()
268        }
269
270        ShapeBuilder::for_sized::<[T; N]>("[T; N]")
271            .decl_id(crate::DeclId::new(crate::decl_id_hash("#array#[T; N]")))
272            .type_name(array_type_name)
273            .ty(Type::Sequence(crate::SequenceType::Array(
274                crate::ArrayType { t: T::SHAPE, n: N },
275            )))
276            .def(Def::Array(ArrayDef::new(
277                &const { build_array_vtable::<T, N>() },
278                T::SHAPE,
279                N,
280            )))
281            .type_params(&[TypeParam {
282                name: "T",
283                shape: T::SHAPE,
284            }])
285            .const_params(&[ConstParam {
286                name: "N",
287                value: N as u64,
288                kind: ConstParamKind::Usize,
289            }])
290            .inner(T::SHAPE)
291            // [T; N] propagates T's variance
292            .variance(VarianceDesc {
293                base: Variance::Bivariant,
294                deps: &const { [VarianceDep::covariant(T::SHAPE)] },
295            })
296            .vtable_indirect(&ARRAY_VTABLE)
297            .type_ops_indirect(
298                &const {
299                    unsafe fn truthy<const N: usize>(_: PtrConst) -> bool {
300                        N != 0
301                    }
302
303                    TypeOpsIndirect {
304                        drop_in_place: array_drop,
305                        default_in_place: Some(array_default),
306                        clone_into: Some(array_clone),
307                        is_truthy: Some(truthy::<N>),
308                    }
309                },
310            )
311            .build()
312    };
313}