1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
//! This module implements the trait object used to check const generics.

use crate::{
    abi_stability::{
        check_layout_compatibility,
        extra_checks::{ExtraChecksError, TypeCheckerMut},
    },
    erased_types::{
        c_functions::{adapt_std_fmt, debug_impl, partial_eq_impl},
        FormattingMode,
    },
    marker_type::ErasedObject,
    prefix_type::WithMetadata,
    sabi_types::RRef,
    std_types::{RErr, ROk, RResult, RString},
    type_layout::TypeLayout,
    StableAbi,
};

use std::{
    cmp::{Eq, PartialEq},
    fmt::{self, Debug},
};

///////////////////////////////////////////////////////////////////////////////

/// A trait object used to check equality between const generic parameters.
///
/// It pairs a type-erased `'static` reference with a vtable generated for the
/// referenced type. This lets two `ConstGeneric`s be compared with each other (after
/// a type-layout compatibility check) and be `Debug`-formatted without knowing the
/// concrete type.
#[repr(C)]
#[derive(Copy, Clone, StableAbi)]
pub struct ConstGeneric {
    /// Type-erased pointer to the `'static` value passed to `ConstGeneric::new`.
    ptr: RRef<'static, ErasedObject>,
    /// Vtable generated for the type of the value behind `ptr`.
    vtable: ConstGenericVTable_Ref,
}

// SAFETY: `ConstGeneric::new` only accepts `T: Send + Sync + 'static` and stores
// nothing but a type-erased `&'static T` plus the vtable for that `T`. Sharing the
// handle between threads therefore only shares access to a `Sync` value. This assumes
// every `ConstGeneric` is built through `new` (its fields are private).
unsafe impl Sync for ConstGeneric {}
// SAFETY: same reasoning as above; moving the handle to another thread only moves a
// shared reference to a `Send + Sync` value.
unsafe impl Send for ConstGeneric {}

impl ConstGeneric {
    /// Constructs a ConstGeneric from a reference.
    pub const fn new<T>(this: &'static T) -> Self
    where
        T: StableAbi + Eq + PartialEq + Debug + Send + Sync + 'static,
    {
        Self {
            ptr: unsafe { RRef::from_raw(this as *const T as *const ErasedObject) },
            vtable: ConstGenericVTableFor::<T>::VTABLE,
        }
    }

    /// Compares this to another `ConstGeneric` for equality,
    /// returning an error if the type layout of `self` and `other` is not compatible.
    pub fn is_equal(
        &self,
        other: &Self,
        mut checker: TypeCheckerMut<'_>,
    ) -> Result<bool, ExtraChecksError> {
        match checker.check_compatibility(self.vtable.layout(), other.vtable.layout()) {
            ROk(_) => unsafe { Ok(self.vtable.partial_eq()(self.ptr, other.ptr)) },
            RErr(e) => Err(e),
        }
    }
}

impl Debug for ConstGeneric {
    /// Formats the erased value with the `Debug` implementation of its original type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let debug_fn = self.vtable.debug();
        // SAFETY: `self.ptr` and `self.vtable` were created together in
        // `ConstGeneric::new` from the same `T`, so `debug_fn` receives a pointer to the
        // type it was generated for.
        unsafe { adapt_std_fmt::<ErasedObject>(self.ptr, debug_fn, f) }
    }
}

// This impl calls `check_layout_compatibility`, so it must never be invoked from
// inside `check_layout_compatibility` itself; doing so would recurse forever.
impl PartialEq for ConstGeneric {
    /// Returns `true` only when both values have compatible type layouts and `self`'s
    /// `PartialEq` implementation reports them as equal.
    fn eq(&self, other: &Self) -> bool {
        let layouts_incompatible =
            check_layout_compatibility(self.vtable.layout(), other.vtable.layout()).is_err();
        if layouts_incompatible {
            return false;
        }
        // SAFETY: the layouts were just checked to be compatible, which is what allows
        // `other.ptr` to be handed to the comparison generated for `self`'s type.
        unsafe { (self.vtable.partial_eq())(self.ptr, other.ptr) }
    }
}

// `Eq` is justified because `ConstGeneric::new` requires the erased `T: Eq`.
// NOTE(review): reflexivity also assumes `check_layout_compatibility` accepts a layout
// checked against itself (otherwise `x == x` would be `false`); confirm.
impl Eq for ConstGeneric {}

///////////////////////////////////////////////////////////////////////////////

/// The vtable of `ConstGeneric`: the type-erased operations for the value it points to.
///
/// This is an abi_stable prefix type (`#[sabi(kind(Prefix))]`). Per abi_stable's
/// `missing_field(panic)` attribute, accessing a field that is absent from the
/// loaded version of the vtable panics. Field order is part of the ABI.
#[repr(C)]
#[derive(StableAbi)]
#[sabi(kind(Prefix))]
#[sabi(missing_field(panic))]
struct ConstGenericVTable {
    /// Layout of the erased type, used to check that two `ConstGeneric`s are comparable.
    layout: &'static TypeLayout,
    /// Equality comparison of two erased values (instantiated as `partial_eq_impl::<T>`).
    partial_eq: unsafe extern "C" fn(RRef<'_, ErasedObject>, RRef<'_, ErasedObject>) -> bool,
    /// Debug-formatting of the erased value (instantiated as `debug_impl::<T>`);
    /// presumably writes its output into the `RString` according to the `FormattingMode`.
    #[sabi(last_prefix_field)]
    debug: unsafe extern "C" fn(
        RRef<'_, ErasedObject>,
        FormattingMode,
        &mut RString,
    ) -> RResult<(), ()>,
}

/// Namespace for the vtable stored in a `ConstGeneric` constructed from a `T`.
///
/// Rust has no generic `static`s, so the per-`T` vtable lives in associated constants
/// of this generic type instead. An associated `const` initialized with `&{ ... }`
/// yields a `&'static` reference, giving every `T` its own `'static` vtable that the
/// `const fn ConstGeneric::new` can use. The type is only used as a namespace here; the
/// `T` field just carries the type parameter.
struct ConstGenericVTableFor<T>(T);

impl<T> ConstGenericVTableFor<T>
where
    T: StableAbi + Eq + PartialEq + Debug + Send + Sync + 'static,
{
    /// The `'static` vtable for `T`, wrapped in `WithMetadata` so that it can be
    /// viewed as a prefix type.
    const _VTABLE_STATIC: &'static WithMetadata<ConstGenericVTable> = &{
        WithMetadata::new(ConstGenericVTable {
            layout: <T as StableAbi>::LAYOUT,
            partial_eq: partial_eq_impl::<T>,
            debug: debug_impl::<T>,
        })
    };

    /// The prefix-type reference to `T`'s vtable, stored in every `ConstGeneric`
    /// created from a `&'static T`.
    const VTABLE: ConstGenericVTable_Ref =
        ConstGenericVTable_Ref(Self::_VTABLE_STATIC.static_as_prefix());
}