// dtype_variant/lib.rs

1#![allow(clippy::approx_constant)]
2
3pub use dtype_variant_derive::{DType, build_dtype_tokens};
4
/// Type-directed access to the payload of one variant of an enum.
///
/// `VariantToken` is a token type (such as those produced by
/// `build_dtype_tokens!`) that selects a single variant of the implementing
/// enum, and [`Target`](Self::Target) is the payload type stored in that
/// variant. Every method returns `None` when the enum currently holds a
/// different variant.
pub trait EnumVariantDowncast<VariantToken> {
    /// The payload type held by the variant selected by `VariantToken`.
    type Target;

    /// Returns a reference to the target field if the enum is the target variant.
    fn downcast_ref(&self) -> Option<&Self::Target>;

    /// Returns a mutable reference to the target field if the enum is the
    /// target variant, allowing the payload to be modified in place.
    fn downcast_mut(&mut self) -> Option<&mut Self::Target>;

    /// Consumes the enum and returns the target field if it is the target
    /// variant.
    ///
    /// Because `self` is taken by value, the enum is consumed even when the
    /// result is `None`.
    fn downcast(self) -> Option<Self::Target>;
}
13
/// Associates a variant token with a constraint type for the implementing enum.
///
/// `VariantToken` selects one variant of the implementing enum, and
/// [`Constraint`](Self::Constraint) is the type attached to that variant.
///
/// NOTE(review): the way `#[dtype(constraint = "...")]` populates this
/// associated type is defined by the `DType` derive macro, which is not
/// visible here. Confirm against the macro before relying on specifics.
pub trait EnumVariantConstraint<VariantToken> {
    /// The constraint type for the selected variant. It must be `'static`.
    type Constraint: 'static;
}
18
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::TypeId;

    // Marker trait passed as `constraint` to the derived enums below.
    trait Constraint: 'static {}

    impl Constraint for u16 {}
    impl Constraint for u32 {}
    impl Constraint for u64 {}

    build_dtype_tokens!([U16, U32, U64]);

    // Payload-free enum whose variants mirror those of `MyEnum`.
    #[derive(Clone, Debug, Default, DType)]
    #[dtype(matcher = "match_my_enum_variant", tokens = "self", constraint = "Constraint")]
    pub enum MyEnumVariant {
        U16,
        U32,
        #[default]
        U64,
    }

    // Enum whose variants each hold a `Vec` of a different unsigned width.
    #[derive(Clone, Debug, DType, PartialEq, Eq)]
    #[dtype(
        matcher = "match_my_enum",
        tokens = "self",
        constraint = "Constraint",
        container = "Vec"
    )]
    enum MyEnum {
        U16(Vec<u16>),
        U32(Vec<u32>),
        U64(Vec<u64>),
    }

    impl MyEnum {
        /// Builds the `MyEnum` variant that corresponds to `kind`, holding a
        /// single `Default` element.
        fn from_default_variant(kind: MyEnumVariant) -> Self {
            match_my_enum_variant!(kind, MyEnumVariant<Variant>, MyEnum<Container, Constraint> => {
                // Inside the arm the macro binds `Constraint` to the matched
                // variant's element type (see `test_constraint`), shadowing
                // the trait of the same name.
                vec![Constraint::default()].into()
            })
        }
    }

    #[test]
    fn test_simple_enum() {
        let a = MyEnumVariant::U16;
        // The arm must see the token type that belongs to the matched variant.
        let is_u16_token = match_my_enum_variant!(a, MyEnumVariant<VariantToken> => {
            TypeId::of::<VariantToken>() == TypeId::of::<U16Variant>()
        });
        assert!(is_u16_token);
    }

    #[test]
    fn test_default_variant() {
        assert!(matches!(MyEnumVariant::default(), MyEnumVariant::U64));
    }

    #[test]
    fn test_end_to_end() {
        let x = MyEnum::from(vec![1_u16, 1, 2, 3, 5]);
        let bit_size = match_my_enum!(&x, MyEnum<T, VariantToken>(inner) => { inner.len() * T::BITS as usize });
        assert_eq!(bit_size, 80);
        let x = x.downcast::<U16Variant>().unwrap();
        assert_eq!(x[0], 1);

        // Consuming downcast to a different variant yields None.
        assert!(MyEnum::from(vec![1_u16]).downcast::<U32Variant>().is_none());
    }

    #[test]
    fn test_constraint() {
        assert_eq!(MyEnum::from_default_variant(MyEnumVariant::U16), MyEnum::U16(vec![0]));
        assert_eq!(MyEnum::from_default_variant(MyEnumVariant::U32), MyEnum::U32(vec![0]));
        assert_eq!(
            MyEnum::from_default_variant(MyEnumVariant::default()),
            MyEnum::U64(vec![0])
        );
    }

    #[test]
    fn test_token_based_downcast() {
        let mut x = MyEnum::from(vec![1_u16, 1, 2, 3, 5]);
        let first_element = x.downcast_ref::<U16Variant>().unwrap()[0];
        assert_eq!(first_element, 1_u16);

        // A token for a different variant must not match.
        assert!(x.downcast_ref::<U32Variant>().is_none());

        // Mutation through the downcast is visible on the enum itself.
        x.downcast_mut::<U16Variant>().unwrap().push(8);
        assert_eq!(x, MyEnum::U16(vec![1, 1, 2, 3, 5, 8]));
    }

    build_dtype_tokens!([I32, F32]);

    // Enum whose variants hold scalar payloads directly (no container).
    #[derive(Clone, Debug, DType)]
    #[dtype(matcher = "match_dyn_enum", tokens = "self")]
    enum DynChunk {
        I32(i32),
        F32(f32),
    }

    #[test]
    fn test_dyn_chunk() {
        let x = DynChunk::from(42_i32);
        if let DynChunk::I32(value) = x {
            assert_eq!(value, 42);
        } else {
            panic!("Expected DynChunk::I32");
        }

        let mut y = DynChunk::from(3.14_f32);
        if let DynChunk::F32(value) = y {
            assert_eq!(value, 3.14);
        } else {
            panic!("Expected DynChunk::F32");
        }

        let downcasted: Option<&i32> = x.downcast_ref::<I32Variant>();
        assert_eq!(*downcasted.unwrap(), 42);
        // Asking for the other variant's payload yields None instead of panicking.
        assert!(x.downcast_ref::<F32Variant>().is_none());

        let downcasted_mut: Option<&mut f32> = y.downcast_mut::<F32Variant>();
        *downcasted_mut.unwrap() = 2.71;
        // Fail loudly if the write somehow changed the variant.
        if let DynChunk::F32(value) = y {
            assert_eq!(value, 2.71);
        } else {
            panic!("Expected DynChunk::F32 after downcast_mut");
        }
    }

    #[test]
    fn test_match_dyn_enum_usage() {
        let x = DynChunk::from(42_i32);
        match_dyn_enum!(x, DynChunk<T, Token>(value) => {
            let str_repr = value.to_string();
            assert_eq!(str_repr, "42");
        });

        let y = DynChunk::from(3.14_f32);
        match_dyn_enum!(y, DynChunk<T, Token>(value) => {
            let str_repr = value.to_string();
            assert_eq!(str_repr, "3.14");
        });
    }
}