Skip to main content

facet_core/impls/core/
result.rs

1//! Facet implementation for Result<T, E>
2
3use core::cmp::Ordering;
4
5use crate::{
6    Def, Facet, HashProxy, OxPtrConst, OxPtrMut, OxRef, PtrConst, PtrMut, ResultDef, ResultVTable,
7    Shape, ShapeBuilder, Type, TypeOpsIndirect, TypeParam, UserType, VTableIndirect, Variance,
8    VarianceDep, VarianceDesc,
9};
10
11/// Extract the ResultDef from a shape, returns None if not a Result
12#[inline]
13const fn get_result_def(shape: &'static Shape) -> Option<&'static ResultDef> {
14    match shape.def {
15        Def::Result(ref def) => Some(def),
16        _ => None,
17    }
18}
19
20#[inline]
21unsafe fn result_get_ok_ptr(def: &ResultDef, ptr: PtrConst) -> Option<PtrConst> {
22    let raw = unsafe { (def.vtable.get_ok)(ptr) };
23    if raw.is_null() {
24        None
25    } else {
26        Some(PtrConst::new_sized(raw))
27    }
28}
29
30#[inline]
31unsafe fn result_get_err_ptr(def: &ResultDef, ptr: PtrConst) -> Option<PtrConst> {
32    let raw = unsafe { (def.vtable.get_err)(ptr) };
33    if raw.is_null() {
34        None
35    } else {
36        Some(PtrConst::new_sized(raw))
37    }
38}
39
40fn result_type_name(
41    shape: &'static Shape,
42    f: &mut core::fmt::Formatter<'_>,
43    opts: crate::TypeNameOpts,
44) -> core::fmt::Result {
45    write!(f, "Result")?;
46    if let Some(opts) = opts.for_children() {
47        write!(f, "<")?;
48        if let Some(t) = shape.type_params.first() {
49            t.shape.write_type_name(f, opts)?;
50        }
51        if let Some(e) = shape.type_params.get(1) {
52            write!(f, ", ")?;
53            e.shape.write_type_name(f, opts)?;
54        }
55        write!(f, ">")?;
56    } else {
57        write!(f, "<…>")?;
58    }
59    Ok(())
60}
61
62/// Debug for Result<T, E> - delegates to inner T/E's debug if available
63unsafe fn result_debug(
64    ox: OxPtrConst,
65    f: &mut core::fmt::Formatter<'_>,
66) -> Option<core::fmt::Result> {
67    let shape = ox.shape();
68    let def = get_result_def(shape)?;
69    let ptr = ox.ptr();
70
71    if unsafe { (def.vtable.is_ok)(ptr) } {
72        // SAFETY: is_ok returned true, so get_ok returns a valid pointer.
73        // The caller guarantees the OxPtrConst points to a valid Result.
74        let ok_ptr = unsafe { result_get_ok_ptr(def, ptr)? };
75        let ok_ox = unsafe { OxRef::new(ok_ptr, def.t) };
76        Some(f.debug_tuple("Ok").field(&ok_ox).finish())
77    } else {
78        // SAFETY: is_ok returned false, so get_err returns a valid pointer.
79        let err_ptr = unsafe { result_get_err_ptr(def, ptr)? };
80        let err_ox = unsafe { OxRef::new(err_ptr, def.e) };
81        Some(f.debug_tuple("Err").field(&err_ox).finish())
82    }
83}
84
85/// Hash for Result<T, E> - delegates to inner T/E's hash if available
86unsafe fn result_hash(ox: OxPtrConst, hasher: &mut HashProxy<'_>) -> Option<()> {
87    let shape = ox.shape();
88    let def = get_result_def(shape)?;
89    let ptr = ox.ptr();
90
91    use core::hash::Hash;
92    if unsafe { (def.vtable.is_ok)(ptr) } {
93        0u8.hash(hasher);
94        let ok_ptr = unsafe { result_get_ok_ptr(def, ptr)? };
95        unsafe { def.t.call_hash(ok_ptr, hasher)? };
96    } else {
97        1u8.hash(hasher);
98        let err_ptr = unsafe { result_get_err_ptr(def, ptr)? };
99        unsafe { def.e.call_hash(err_ptr, hasher)? };
100    }
101    Some(())
102}
103
104/// PartialEq for Result<T, E>
105unsafe fn result_partial_eq(a: OxPtrConst, b: OxPtrConst) -> Option<bool> {
106    let shape = a.shape();
107    let def = get_result_def(shape)?;
108
109    let a_ptr = a.ptr();
110    let b_ptr = b.ptr();
111    let a_is_ok = unsafe { (def.vtable.is_ok)(a_ptr) };
112    let b_is_ok = unsafe { (def.vtable.is_ok)(b_ptr) };
113
114    Some(match (a_is_ok, b_is_ok) {
115        (true, true) => {
116            let a_ok = unsafe { result_get_ok_ptr(def, a_ptr)? };
117            let b_ok = unsafe { result_get_ok_ptr(def, b_ptr)? };
118            unsafe { def.t.call_partial_eq(a_ok, b_ok)? }
119        }
120        (false, false) => {
121            let a_err = unsafe { result_get_err_ptr(def, a_ptr)? };
122            let b_err = unsafe { result_get_err_ptr(def, b_ptr)? };
123            unsafe { def.e.call_partial_eq(a_err, b_err)? }
124        }
125        _ => false,
126    })
127}
128
129/// PartialOrd for Result<T, E>
130unsafe fn result_partial_cmp(a: OxPtrConst, b: OxPtrConst) -> Option<Option<Ordering>> {
131    let shape = a.shape();
132    let def = get_result_def(shape)?;
133
134    let a_ptr = a.ptr();
135    let b_ptr = b.ptr();
136    let a_is_ok = unsafe { (def.vtable.is_ok)(a_ptr) };
137    let b_is_ok = unsafe { (def.vtable.is_ok)(b_ptr) };
138
139    Some(match (a_is_ok, b_is_ok) {
140        (true, true) => {
141            let a_ok = unsafe { result_get_ok_ptr(def, a_ptr)? };
142            let b_ok = unsafe { result_get_ok_ptr(def, b_ptr)? };
143            unsafe { def.t.call_partial_cmp(a_ok, b_ok)? }
144        }
145        (false, false) => {
146            let a_err = unsafe { result_get_err_ptr(def, a_ptr)? };
147            let b_err = unsafe { result_get_err_ptr(def, b_ptr)? };
148            unsafe { def.e.call_partial_cmp(a_err, b_err)? }
149        }
150        // Ok is greater than Err (following std::cmp::Ord for Result)
151        (true, false) => Some(Ordering::Greater),
152        (false, true) => Some(Ordering::Less),
153    })
154}
155
156/// Ord for Result<T, E>
157unsafe fn result_cmp(a: OxPtrConst, b: OxPtrConst) -> Option<Ordering> {
158    let shape = a.shape();
159    let def = get_result_def(shape)?;
160
161    let a_ptr = a.ptr();
162    let b_ptr = b.ptr();
163    let a_is_ok = unsafe { (def.vtable.is_ok)(a_ptr) };
164    let b_is_ok = unsafe { (def.vtable.is_ok)(b_ptr) };
165
166    Some(match (a_is_ok, b_is_ok) {
167        (true, true) => {
168            let a_ok = unsafe { result_get_ok_ptr(def, a_ptr)? };
169            let b_ok = unsafe { result_get_ok_ptr(def, b_ptr)? };
170            unsafe { def.t.call_cmp(a_ok, b_ok)? }
171        }
172        (false, false) => {
173            let a_err = unsafe { result_get_err_ptr(def, a_ptr)? };
174            let b_err = unsafe { result_get_err_ptr(def, b_ptr)? };
175            unsafe { def.e.call_cmp(a_err, b_err)? }
176        }
177        // Ok is greater than Err (following std::cmp::Ord for Result)
178        (true, false) => Ordering::Greater,
179        (false, true) => Ordering::Less,
180    })
181}
182
183/// Drop for Result<T, E>
184unsafe fn result_drop(ox: OxPtrMut) {
185    let shape = ox.shape();
186    let Some(def) = get_result_def(shape) else {
187        return;
188    };
189    let ptr = ox.ptr();
190
191    if unsafe { (def.vtable.is_ok)(ptr.as_const()) } {
192        let Some(ok_ptr) = (unsafe { result_get_ok_ptr(def, ptr.as_const()) }) else {
193            return;
194        };
195        let ok_ptr_mut = PtrMut::new(ok_ptr.as_byte_ptr() as *mut u8);
196        unsafe { def.t.call_drop_in_place(ok_ptr_mut) };
197    } else {
198        let Some(err_ptr) = (unsafe { result_get_err_ptr(def, ptr.as_const()) }) else {
199            return;
200        };
201        let err_ptr_mut = PtrMut::new(err_ptr.as_byte_ptr() as *mut u8);
202        unsafe { def.e.call_drop_in_place(err_ptr_mut) };
203    }
204}
205
206// Shared vtable for all Result<T, E>
207const RESULT_VTABLE: VTableIndirect = VTableIndirect {
208    display: None,
209    debug: Some(result_debug),
210    hash: Some(result_hash),
211    invariants: None,
212    parse: None,
213    parse_bytes: None,
214    try_from: None,
215    try_into_inner: None,
216    try_borrow_inner: None,
217    partial_eq: Some(result_partial_eq),
218    partial_cmp: Some(result_partial_cmp),
219    cmp: Some(result_cmp),
220};
221
222// Type operations for all Result<T, E>
223static RESULT_TYPE_OPS: TypeOpsIndirect = TypeOpsIndirect {
224    drop_in_place: result_drop,
225    default_in_place: None,
226    clone_into: None,
227    is_truthy: None,
228};
229
230/// Check if Result<T, E> is Ok
231unsafe extern "C" fn result_is_ok<T, E>(result: PtrConst) -> bool {
232    unsafe { result.get::<Result<T, E>>().is_ok() }
233}
234
235/// Get the Ok value from Result<T, E> if present
236unsafe extern "C" fn result_get_ok<T, E>(result: PtrConst) -> *const u8 {
237    unsafe {
238        result
239            .get::<Result<T, E>>()
240            .as_ref()
241            .ok()
242            .map_or(core::ptr::null(), |t| t as *const T as *const u8)
243    }
244}
245
246/// Get the Err value from Result<T, E> if present
247unsafe extern "C" fn result_get_err<T, E>(result: PtrConst) -> *const u8 {
248    unsafe {
249        result
250            .get::<Result<T, E>>()
251            .as_ref()
252            .err()
253            .map_or(core::ptr::null(), |e| e as *const E as *const u8)
254    }
255}
256
257/// Initialize Result<T, E> with Ok(value)
258unsafe extern "C" fn result_init_ok<T, E>(result: crate::PtrUninit, value: PtrMut) -> PtrMut {
259    unsafe { result.put(Result::<T, E>::Ok(value.read::<T>())) }
260}
261
262/// Initialize Result<T, E> with Err(value)
263unsafe extern "C" fn result_init_err<T, E>(result: crate::PtrUninit, value: PtrMut) -> PtrMut {
264    unsafe { result.put(Result::<T, E>::Err(value.read::<E>())) }
265}
266
/// `Facet` implementation for `Result<T, E>`.
///
/// The shape is registered as an opaque user type (`UserType::Opaque`) whose
/// Ok/Err structure is exposed through `Def::Result`. The `ResultVTable` built
/// here is monomorphized for this `T`/`E` pair. The debug/hash/comparison
/// vtable (`RESULT_VTABLE`) and the drop type operation (`RESULT_TYPE_OPS`) are
/// single statics shared by every instantiation.
unsafe impl<'a, T: Facet<'a>, E: Facet<'a>> Facet<'a> for Result<T, E> {
    const SHAPE: &'static Shape = &const {
        // Items nested inside this generic impl cannot name the impl's `T`/`E`,
        // so the vtable builder declares its own generic parameters and is
        // instantiated explicitly with `::<T, E>` below.
        const fn build_result_vtable<T, E>() -> ResultVTable {
            ResultVTable::builder()
                .is_ok(result_is_ok::<T, E>)
                .get_ok(result_get_ok::<T, E>)
                .get_err(result_get_err::<T, E>)
                .init_ok(result_init_ok::<T, E>)
                .init_err(result_init_err::<T, E>)
                .build()
        }

        ShapeBuilder::for_sized::<Result<T, E>>("Result")
            .module_path("core::result")
            // Renders `Result<T, E>`, or `Result<…>` when child names are elided.
            .type_name(result_type_name)
            .ty(Type::User(UserType::Opaque))
            // Per-instantiation vtable plus the payload shapes for Ok (T) and Err (E).
            .def(Def::Result(ResultDef::new(
                &const { build_result_vtable::<T, E>() },
                T::SHAPE,
                E::SHAPE,
            )))
            .type_params(&[
                TypeParam {
                    name: "T",
                    shape: T::SHAPE,
                },
                TypeParam {
                    name: "E",
                    shape: E::SHAPE,
                },
            ])
            // Result<T, E> combines T and E variances
            // NOTE(review): presumably the Bivariant base is narrowed by each
            // covariant dependency, making Result covariant in both T and E —
            // confirm against `VarianceDesc`.
            .variance(VarianceDesc {
                base: Variance::Bivariant,
                deps: &const {
                    [
                        VarianceDep::covariant(T::SHAPE),
                        VarianceDep::covariant(E::SHAPE),
                    ]
                },
            })
            .vtable_indirect(&RESULT_VTABLE)
            .type_ops_indirect(&RESULT_TYPE_OPS)
            .build()
    };
}