// The unit tests compare against float literals such as `3.14` and `2.71`,
// which `clippy::approx_constant` would report as approximations of PI and E.
// Those literals only exist under `cfg(test)`, so the allow is scoped there
// instead of silencing the lint for the whole library.
#![cfg_attr(test, allow(clippy::approx_constant))]
2
3pub use dtype_variant_derive::{DType, build_dtype_tokens};
4
/// Fallible, typed access to the payload of one variant of an enum.
///
/// The type parameter `VariantToken` selects which variant is meant.
///
/// In this crate's tests the implementations are reached through types that
/// use `#[derive(DType)]`. The tokens are presumably the `*Variant` types
/// produced by `build_dtype_tokens!`. For example,
/// `downcast_ref::<I32Variant>()` on an enum with an `I32(i32)` variant yields
/// an `Option<&i32>`.
///
/// Implementations are expected to return `None` when `self` holds a different
/// variant. Because these methods have no other effect, ignoring their result
/// is always a mistake, hence `#[must_use]`. For
/// [`downcast`](Self::downcast), ignoring the result also means the consumed
/// value was dropped unused.
pub trait EnumVariantDowncast<VariantToken> {
    /// The payload type stored in the variant selected by `VariantToken`.
    type Target;

    /// Borrows the payload if `self` is the selected variant.
    #[must_use]
    fn downcast_ref(&self) -> Option<&Self::Target>;

    /// Mutably borrows the payload if `self` is the selected variant.
    #[must_use]
    fn downcast_mut(&mut self) -> Option<&mut Self::Target>;

    /// Consumes `self`, returning the payload if it is the selected variant.
    #[must_use]
    fn downcast(self) -> Option<Self::Target>;
}
13
/// Maps a variant token to a `'static` type associated with that variant.
///
/// NOTE(review): judging by the tests, `#[derive(DType)]` with
/// `constraint = "..."` implements this trait for the derived enum. The
/// `match_*!` macros then seem to expose the per-variant type inside match
/// arms (e.g. as an element type on which `::default()` is called). Confirm
/// against `dtype_variant_derive`.
pub trait EnumVariantConstraint<Token> {
    /// The type tied to the variant that `Token` selects.
    type Constraint: 'static;
}
18
#[cfg(test)]
mod tests {
    //! End-to-end tests for `#[derive(DType)]` and `build_dtype_tokens!`.
    //!
    //! NOTE(review): the `match_*!` macros and the `*Variant` token types are
    //! generated by `dtype_variant_derive`; their exact expansion is not
    //! visible from this file.

    use super::*;

    /// Element bound referenced by the `constraint = "Constraint"` attributes below.
    trait Constraint: 'static {}

    impl Constraint for u16 {}
    impl Constraint for u32 {}
    impl Constraint for u64 {}

    build_dtype_tokens!([U16, U32, U64]);

    /// Fieldless "kind" enum whose variants mirror those of `MyEnum`.
    #[derive(Clone, Debug, Default, DType)]
    #[dtype(matcher = "match_my_enum_variant", tokens = "self", constraint = "Constraint")]
    pub enum MyEnumVariant {
        U16,
        U32,
        #[default]
        U64,
    }

    /// Enum whose variants each hold a `Vec` of a different integer width.
    #[derive(Clone, Debug, DType, PartialEq, Eq)]
    #[dtype(
        matcher = "match_my_enum",
        tokens = "self",
        constraint = "Constraint",
        container = "Vec"
    )]
    enum MyEnum {
        U16(Vec<u16>),
        U32(Vec<u32>),
        U64(Vec<u64>),
    }

    impl MyEnum {
        /// Builds the `MyEnum` variant selected by `kind`, holding a single
        /// default-valued element (e.g. `MyEnumVariant::U16` -> `MyEnum::U16(vec![0])`).
        fn from_default_variant(kind: MyEnumVariant) -> Self {
            // NOTE(review): inside the arm the macro presumably binds `Constraint`
            // to the matched element type (u16/u32/u64), shadowing the trait above.
            match_my_enum_variant!(kind, MyEnumVariant<Variant>, MyEnum<Container, Constraint> => {
                vec![Constraint::default()].into()
            })
        }
    }

    #[test]
    fn test_simple_enum() {
        // The matcher should bind the token type of whichever variant is matched.
        let a = MyEnumVariant::U16;
        let b = MyEnumVariant::U32;
        let token_of_a = match_my_enum_variant!(a, MyEnumVariant<VariantToken> => {
            std::any::TypeId::of::<VariantToken>()
        });
        let token_of_b = match_my_enum_variant!(b, MyEnumVariant<VariantToken> => {
            std::any::TypeId::of::<VariantToken>()
        });
        assert_eq!(token_of_a, std::any::TypeId::of::<U16Variant>());
        assert_eq!(token_of_b, std::any::TypeId::of::<U32Variant>());
    }

    #[test]
    fn test_end_to_end() {
        let x = MyEnum::from(vec![1_u16, 1, 2, 3, 5]);
        // 5 elements * 16 bits each.
        let bit_size = match_my_enum!(&x, MyEnum<T, VariantToken>(inner) => {
            inner.len() * T::BITS as usize
        });
        assert_eq!(bit_size, 80);
        let x = x.downcast::<U16Variant>().expect("x was built from a Vec<u16>");
        assert_eq!(x[0], 1);
    }

    #[test]
    fn test_constraint() {
        let x = MyEnumVariant::U16;
        let my_enum = MyEnum::from_default_variant(x);
        assert_eq!(my_enum, MyEnum::U16(vec![0]));
    }

    #[test]
    fn test_token_based_downcast() {
        let x = MyEnum::from(vec![1_u16, 1, 2, 3, 5]);
        let first_element = x.downcast_ref::<U16Variant>().unwrap()[0];
        assert_eq!(first_element, 1_u16);
    }

    build_dtype_tokens!([I32, F32]);

    /// Enum with plain scalar payloads and no `container`/`constraint` options.
    #[derive(Clone, Debug, DType)]
    #[dtype(matcher = "match_dyn_enum", tokens = "self")]
    enum DynChunk {
        I32(i32),
        F32(f32),
    }

    #[test]
    fn test_dyn_chunk() {
        let x = DynChunk::from(42_i32);
        if let DynChunk::I32(value) = x {
            assert_eq!(value, 42);
        } else {
            panic!("Expected DynChunk::I32");
        }

        let mut y = DynChunk::from(3.14_f32);
        if let DynChunk::F32(value) = y {
            assert_eq!(value, 3.14);
        } else {
            panic!("Expected DynChunk::F32");
        }

        let downcasted: Option<&i32> = x.downcast_ref::<I32Variant>();
        assert_eq!(*downcasted.unwrap(), 42);

        let downcasted_mut: Option<&mut f32> = y.downcast_mut::<F32Variant>();
        *downcasted_mut.unwrap() = 2.71;
        // Fail loudly rather than silently skipping the assertion if `y` is
        // not the F32 variant here.
        if let DynChunk::F32(value) = y {
            assert_eq!(value, 2.71);
        } else {
            panic!("Expected DynChunk::F32");
        }
    }

    #[test]
    fn test_match_dyn_enum_usage() {
        let x = DynChunk::from(42_i32);
        match_dyn_enum!(x, DynChunk<T, Token>(value) => {
            let str_repr = value.to_string();
            assert_eq!(str_repr, "42");
        });

        let y = DynChunk::from(3.14_f32);
        match_dyn_enum!(y, DynChunk<T, Token>(value) => {
            let str_repr = value.to_string();
            assert_eq!(str_repr, "3.14");
        });
    }
}