Skip to main content

facet_core/impls/core/
result.rs

1//! Facet implementation for Result<T, E>
2
3use core::cmp::Ordering;
4
5use crate::{
6    Def, Facet, HashProxy, OxPtrConst, OxPtrMut, OxRef, PtrConst, PtrMut, ResultDef, ResultVTable,
7    Shape, ShapeBuilder, Type, TypeOpsIndirect, TypeParam, UserType, VTableIndirect, Variance,
8    VarianceDep, VarianceDesc,
9};
10
11/// Extract the ResultDef from a shape, returns None if not a Result
12#[inline]
13const fn get_result_def(shape: &'static Shape) -> Option<&'static ResultDef> {
14    match shape.def {
15        Def::Result(ref def) => Some(def),
16        _ => None,
17    }
18}
19
20/// Debug for Result<T, E> - delegates to inner T/E's debug if available
21unsafe fn result_debug(
22    ox: OxPtrConst,
23    f: &mut core::fmt::Formatter<'_>,
24) -> Option<core::fmt::Result> {
25    let shape = ox.shape();
26    let def = get_result_def(shape)?;
27    let ptr = ox.ptr();
28
29    if unsafe { (def.vtable.is_ok)(ptr) } {
30        // SAFETY: is_ok returned true, so get_ok returns a valid pointer.
31        // The caller guarantees the OxPtrConst points to a valid Result.
32        let ok_ptr = unsafe { (def.vtable.get_ok)(ptr)? };
33        let ok_ox = unsafe { OxRef::new(ok_ptr, def.t) };
34        Some(f.debug_tuple("Ok").field(&ok_ox).finish())
35    } else {
36        // SAFETY: is_ok returned false, so get_err returns a valid pointer.
37        let err_ptr = unsafe { (def.vtable.get_err)(ptr)? };
38        let err_ox = unsafe { OxRef::new(err_ptr, def.e) };
39        Some(f.debug_tuple("Err").field(&err_ox).finish())
40    }
41}
42
43/// Hash for Result<T, E> - delegates to inner T/E's hash if available
44unsafe fn result_hash(ox: OxPtrConst, hasher: &mut HashProxy<'_>) -> Option<()> {
45    let shape = ox.shape();
46    let def = get_result_def(shape)?;
47    let ptr = ox.ptr();
48
49    use core::hash::Hash;
50    if unsafe { (def.vtable.is_ok)(ptr) } {
51        0u8.hash(hasher);
52        let ok_ptr = unsafe { (def.vtable.get_ok)(ptr)? };
53        unsafe { def.t.call_hash(ok_ptr, hasher)? };
54    } else {
55        1u8.hash(hasher);
56        let err_ptr = unsafe { (def.vtable.get_err)(ptr)? };
57        unsafe { def.e.call_hash(err_ptr, hasher)? };
58    }
59    Some(())
60}
61
62/// PartialEq for Result<T, E>
63unsafe fn result_partial_eq(a: OxPtrConst, b: OxPtrConst) -> Option<bool> {
64    let shape = a.shape();
65    let def = get_result_def(shape)?;
66
67    let a_ptr = a.ptr();
68    let b_ptr = b.ptr();
69    let a_is_ok = unsafe { (def.vtable.is_ok)(a_ptr) };
70    let b_is_ok = unsafe { (def.vtable.is_ok)(b_ptr) };
71
72    Some(match (a_is_ok, b_is_ok) {
73        (true, true) => {
74            let a_ok = unsafe { (def.vtable.get_ok)(a_ptr)? };
75            let b_ok = unsafe { (def.vtable.get_ok)(b_ptr)? };
76            unsafe { def.t.call_partial_eq(a_ok, b_ok)? }
77        }
78        (false, false) => {
79            let a_err = unsafe { (def.vtable.get_err)(a_ptr)? };
80            let b_err = unsafe { (def.vtable.get_err)(b_ptr)? };
81            unsafe { def.e.call_partial_eq(a_err, b_err)? }
82        }
83        _ => false,
84    })
85}
86
87/// PartialOrd for Result<T, E>
88unsafe fn result_partial_cmp(a: OxPtrConst, b: OxPtrConst) -> Option<Option<Ordering>> {
89    let shape = a.shape();
90    let def = get_result_def(shape)?;
91
92    let a_ptr = a.ptr();
93    let b_ptr = b.ptr();
94    let a_is_ok = unsafe { (def.vtable.is_ok)(a_ptr) };
95    let b_is_ok = unsafe { (def.vtable.is_ok)(b_ptr) };
96
97    Some(match (a_is_ok, b_is_ok) {
98        (true, true) => {
99            let a_ok = unsafe { (def.vtable.get_ok)(a_ptr)? };
100            let b_ok = unsafe { (def.vtable.get_ok)(b_ptr)? };
101            unsafe { def.t.call_partial_cmp(a_ok, b_ok)? }
102        }
103        (false, false) => {
104            let a_err = unsafe { (def.vtable.get_err)(a_ptr)? };
105            let b_err = unsafe { (def.vtable.get_err)(b_ptr)? };
106            unsafe { def.e.call_partial_cmp(a_err, b_err)? }
107        }
108        // Ok is greater than Err (following std::cmp::Ord for Result)
109        (true, false) => Some(Ordering::Greater),
110        (false, true) => Some(Ordering::Less),
111    })
112}
113
114/// Ord for Result<T, E>
115unsafe fn result_cmp(a: OxPtrConst, b: OxPtrConst) -> Option<Ordering> {
116    let shape = a.shape();
117    let def = get_result_def(shape)?;
118
119    let a_ptr = a.ptr();
120    let b_ptr = b.ptr();
121    let a_is_ok = unsafe { (def.vtable.is_ok)(a_ptr) };
122    let b_is_ok = unsafe { (def.vtable.is_ok)(b_ptr) };
123
124    Some(match (a_is_ok, b_is_ok) {
125        (true, true) => {
126            let a_ok = unsafe { (def.vtable.get_ok)(a_ptr)? };
127            let b_ok = unsafe { (def.vtable.get_ok)(b_ptr)? };
128            unsafe { def.t.call_cmp(a_ok, b_ok)? }
129        }
130        (false, false) => {
131            let a_err = unsafe { (def.vtable.get_err)(a_ptr)? };
132            let b_err = unsafe { (def.vtable.get_err)(b_ptr)? };
133            unsafe { def.e.call_cmp(a_err, b_err)? }
134        }
135        // Ok is greater than Err (following std::cmp::Ord for Result)
136        (true, false) => Ordering::Greater,
137        (false, true) => Ordering::Less,
138    })
139}
140
141/// Drop for Result<T, E>
142unsafe fn result_drop(ox: OxPtrMut) {
143    let shape = ox.shape();
144    let Some(def) = get_result_def(shape) else {
145        return;
146    };
147    let ptr = ox.ptr();
148
149    if unsafe { (def.vtable.is_ok)(ptr.as_const()) } {
150        let Some(ok_ptr) = (unsafe { (def.vtable.get_ok)(ptr.as_const()) }) else {
151            return;
152        };
153        let ok_ptr_mut = PtrMut::new(ok_ptr.as_byte_ptr() as *mut u8);
154        unsafe { def.t.call_drop_in_place(ok_ptr_mut) };
155    } else {
156        let Some(err_ptr) = (unsafe { (def.vtable.get_err)(ptr.as_const()) }) else {
157            return;
158        };
159        let err_ptr_mut = PtrMut::new(err_ptr.as_byte_ptr() as *mut u8);
160        unsafe { def.e.call_drop_in_place(err_ptr_mut) };
161    }
162}
163
164// Shared vtable for all Result<T, E>
165const RESULT_VTABLE: VTableIndirect = VTableIndirect {
166    display: None,
167    debug: Some(result_debug),
168    hash: Some(result_hash),
169    invariants: None,
170    parse: None,
171    parse_bytes: None,
172    try_from: None,
173    try_into_inner: None,
174    try_borrow_inner: None,
175    partial_eq: Some(result_partial_eq),
176    partial_cmp: Some(result_partial_cmp),
177    cmp: Some(result_cmp),
178};
179
180// Type operations for all Result<T, E>
181static RESULT_TYPE_OPS: TypeOpsIndirect = TypeOpsIndirect {
182    drop_in_place: result_drop,
183    default_in_place: None,
184    clone_into: None,
185    is_truthy: None,
186};
187
188/// Check if Result<T, E> is Ok
189unsafe fn result_is_ok<T, E>(result: PtrConst) -> bool {
190    unsafe { result.get::<Result<T, E>>().is_ok() }
191}
192
193/// Get the Ok value from Result<T, E> if present
194unsafe fn result_get_ok<T, E>(result: PtrConst) -> Option<PtrConst> {
195    unsafe {
196        result
197            .get::<Result<T, E>>()
198            .as_ref()
199            .ok()
200            .map(|t| PtrConst::new(t as *const T))
201    }
202}
203
204/// Get the Err value from Result<T, E> if present
205unsafe fn result_get_err<T, E>(result: PtrConst) -> Option<PtrConst> {
206    unsafe {
207        result
208            .get::<Result<T, E>>()
209            .as_ref()
210            .err()
211            .map(|e| PtrConst::new(e as *const E))
212    }
213}
214
215/// Initialize Result<T, E> with Ok(value)
216unsafe fn result_init_ok<T, E>(result: crate::PtrUninit, value: PtrConst) -> PtrMut {
217    unsafe { result.put(Result::<T, E>::Ok(value.read::<T>())) }
218}
219
220/// Initialize Result<T, E> with Err(value)
221unsafe fn result_init_err<T, E>(result: crate::PtrUninit, value: PtrConst) -> PtrMut {
222    unsafe { result.put(Result::<T, E>::Err(value.read::<E>())) }
223}
224
225unsafe impl<'a, T: Facet<'a>, E: Facet<'a>> Facet<'a> for Result<T, E> {
226    const SHAPE: &'static Shape = &const {
227        const fn build_result_vtable<T, E>() -> ResultVTable {
228            ResultVTable::builder()
229                .is_ok(result_is_ok::<T, E>)
230                .get_ok(result_get_ok::<T, E>)
231                .get_err(result_get_err::<T, E>)
232                .init_ok(result_init_ok::<T, E>)
233                .init_err(result_init_err::<T, E>)
234                .build()
235        }
236
237        ShapeBuilder::for_sized::<Result<T, E>>("Result")
238            .module_path("core::result")
239            .ty(Type::User(UserType::Opaque))
240            .def(Def::Result(ResultDef::new(
241                &const { build_result_vtable::<T, E>() },
242                T::SHAPE,
243                E::SHAPE,
244            )))
245            .type_params(&[
246                TypeParam {
247                    name: "T",
248                    shape: T::SHAPE,
249                },
250                TypeParam {
251                    name: "E",
252                    shape: E::SHAPE,
253                },
254            ])
255            // Result<T, E> combines T and E variances
256            .variance(VarianceDesc {
257                base: Variance::Bivariant,
258                deps: &const {
259                    [
260                        VarianceDep::covariant(T::SHAPE),
261                        VarianceDep::covariant(E::SHAPE),
262                    ]
263                },
264            })
265            .vtable_indirect(&RESULT_VTABLE)
266            .type_ops_indirect(&RESULT_TYPE_OPS)
267            .build()
268    };
269}