dtype_dispatch/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(unreachable_patterns)]
3
/// Produces two macros: an enum definer and an enum matcher.
///
/// The definer accepts three forms:
///
/// * `#[attr] vis Name` — a fieldless enum with one variant per listed type,
///   plus `Name::new::<T>()`.
/// * `#[attr] #[repr(int)] vis Name = CONST` — like the fieldless form, but
///   each variant's discriminant is the associated constant `CONST` of its
///   type (`int` must be a primitive integer type name); adds
///   `Name::from_discriminant`.
/// * `#[attr] vis Name(Container)` — each variant wraps `Container<T>`; adds
///   `new`, `downcast`, `downcast_ref` and `downcast_mut`.
///
/// The matcher evaluates a block with the matched variant's type bound to a
/// local type alias: `matcher!(value, Name<T> => { ... })`, or
/// `matcher!(value, Name<T>(inner) => { ... })` to also bind the container.
///
/// Dispatch uses `std::any::TypeId`, so the constraint trait must imply
/// `'static`.
///
/// See the crate-level documentation for more info.
#[macro_export]
macro_rules! build_dtype_macros {
  (
    $(#[$definer_attrs: meta])*
    $definer: ident,
    $(#[$matcher_attrs: meta])*
    $matcher: ident,
    $constraint: path,
    {$($variant: ident => $t: ty,)+}$(,)?
  ) => {
    $(#[$definer_attrs])*
    macro_rules! $definer {
      // Fieldless enum: one unit variant per dtype.
      (#[$enum_attrs: meta] $vis: vis $name: ident) => {
        #[$enum_attrs]
        #[non_exhaustive]
        $vis enum $name {
          $($variant,)+
        }

        impl $name {
          /// Returns the variant corresponding to `T`, or `None` if `T` is
          /// not one of this enum's types.
          #[inline]
          pub fn new<T: $constraint>() -> Option<Self> {
            let type_id = std::any::TypeId::of::<T>();
            $(
              if type_id == std::any::TypeId::of::<$t>() {
                return Some($name::$variant);
              }
            )+
            None
          }
        }
      };
      // Fieldless enum whose discriminants are an associated constant of each
      // dtype. The repr type is matched as an `ident`: `#[repr(...)]` expects a
      // bare integer-type name, which an `ident` fragment always produces,
      // whereas an interpolated `ty` fragment is not reliably accepted there.
      (#[$enum_attrs: meta] #[repr($desc_t: ident)] $vis: vis $name: ident = $desc_val: ident) => {
        #[$enum_attrs]
        #[repr($desc_t)]
        #[non_exhaustive]
        $vis enum $name {
          $($variant = <$t>::$desc_val,)+
        }

        impl $name {
          /// Returns the variant corresponding to `T`, or `None` if `T` is
          /// not one of this enum's types.
          #[inline]
          pub fn new<T: $constraint>() -> Option<Self> {
            let type_id = std::any::TypeId::of::<T>();
            $(
              if type_id == std::any::TypeId::of::<$t>() {
                return Some($name::$variant);
              }
            )+
            None
          }

          /// Returns the variant whose discriminant equals `desc`, or `None`
          /// if no variant has that discriminant.
          #[inline]
          pub fn from_discriminant(desc: $desc_t) -> Option<Self> {
            match desc {
              $(<$t>::$desc_val => Some(Self::$variant),)+
              _ => None
            }
          }

          /// Misspelled alias of `from_discriminant`, kept so that existing
          /// callers continue to compile.
          #[inline]
          pub fn from_descriminant(desc: $desc_t) -> Option<Self> {
            Self::from_discriminant(desc)
          }
        }
      };
      // Container enum: each variant wraps `Container<dtype>`.
      (#[$enum_attrs: meta] $vis: vis $name: ident($container: ident)) => {
        #[$enum_attrs]
        #[non_exhaustive]
        $vis enum $name {
          $($variant($container<$t>),)+
        }

        impl $name {
          /// Wraps `inner` in the variant whose element type is `S`.
          ///
          /// Returns `None` (dropping `inner`) if `S` is not one of this
          /// enum's types.
          #[inline]
          pub fn new<S: $constraint>(inner: $container<S>) -> Option<Self> {
            let type_id = std::any::TypeId::of::<S>();
            $(
              if type_id == std::any::TypeId::of::<$t>() {
                // `transmute` rejects containers whose size depends on the
                // element type, so reinterpret through a pointer instead; see
                // https://users.rust-lang.org/t/transmuting-a-generic-array/45645/6
                let inner = std::mem::ManuallyDrop::new(inner);
                // SAFETY: `S` and this variant's dtype have equal `TypeId`s, so
                // they are the same type and the pointer cast only renames it.
                // `inner` is wrapped in `ManuallyDrop`, so the bitwise copy made
                // by `read` becomes the value's sole owner (no double drop).
                let typed = unsafe {
                  (&*inner as *const $container<S> as *const $container<$t>).read()
                };
                return Some($name::$variant(typed));
              }
            )+
            None
          }

          /// Consumes `self`, returning the wrapped container if its element
          /// type is `T`. On a mismatch the container is dropped and `None`
          /// is returned.
          pub fn downcast<T: $constraint>(self) -> Option<$container<T>> {
            match self {
              $(
                Self::$variant(inner) => {
                  if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
                    let inner = std::mem::ManuallyDrop::new(inner);
                    // SAFETY: same reasoning as in `new`: the types are
                    // identical and `ManuallyDrop` prevents a double drop.
                    Some(unsafe {
                      (&*inner as *const $container<$t> as *const $container<T>).read()
                    })
                  } else {
                    None
                  }
                }
              )+
            }
          }

          /// Borrows the wrapped container if its element type is `T`.
          pub fn downcast_ref<T: $constraint>(&self) -> Option<&$container<T>> {
            match self {
              $(
                Self::$variant(inner) => {
                  if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
                    // SAFETY: `T` is this variant's dtype (equal `TypeId`s), so
                    // the cast pointee type is identical and the resulting
                    // reference borrows `self` exactly as `inner` does.
                    Some(unsafe { &*(inner as *const $container<$t> as *const $container<T>) })
                  } else {
                    None
                  }
                }
              )+
            }
          }

          /// Mutably borrows the wrapped container if its element type is `T`.
          pub fn downcast_mut<T: $constraint>(&mut self) -> Option<&mut $container<T>> {
            match self {
              $(
                Self::$variant(inner) => {
                  if std::any::TypeId::of::<T>() == std::any::TypeId::of::<$t>() {
                    // SAFETY: as in `downcast_ref`; the unique borrow of `self`
                    // is transferred unchanged to the returned reference.
                    Some(unsafe { &mut *(inner as *mut $container<$t> as *mut $container<T>) })
                  } else {
                    None
                  }
                }
              )+
            }
          }
        }
      };
    }

    $(#[$matcher_attrs])*
    macro_rules! $matcher {
      ($value: expr, $enum_: ident<$generic: ident> => $block: block) => {
        match $value {
          $($enum_::$variant => {
            type $generic = $t;
            $block
          })+
          // Generated enums are `#[non_exhaustive]`, so a match outside the
          // defining crate needs a wildcard arm; inside it every variant is
          // listed above and this arm is unreachable.
          #[allow(unreachable_patterns)]
          _ => unreachable!()
        }
      };
      ($value: expr, $enum_: ident<$generic: ident>($inner: ident) => $block: block) => {
        match $value {
          $($enum_::$variant($inner) => {
            type $generic = $t;
            $block
          })+
          // See the fieldless arm above for why this wildcard exists.
          #[allow(unreachable_patterns)]
          _ => unreachable!()
        }
      };
    }
  };
}
168
// The generated enums carry more variants and methods than these tests call,
// so dead-code warnings are silenced for the whole test module.
#[allow(dead_code)]
#[cfg(test)]
mod tests {
  use std::collections::HashMap;

  trait Constraint: 'static {}

  // `u8` satisfies the constraint but is not one of the enum dtypes, which
  // lets the tests exercise the rejection (`None`) paths.
  impl Constraint for u8 {}
  impl Constraint for u16 {}
  impl Constraint for u32 {}
  impl Constraint for u64 {}

  build_dtype_macros!(
    define_enum,
    match_enum,
    crate::tests::Constraint,
    {
      U16 => u16,
      U32 => u32,
      U64 => u64,
    }
  );

  define_enum!(
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    DType
  );

  define_enum!(
    #[derive(Clone, Debug)]
    MyEnum(Vec)
  );

  type Counter<T> = HashMap<T, usize>;

  define_enum!(
    #[derive(Clone, Debug)]
    AnotherContainerEnumInSameScope(Counter)
  );

  // we use this helper just to prove that we can handle generic types, not
  // just concrete types
  fn generic_new<T: Constraint>(inner: Vec<T>) -> MyEnum {
    MyEnum::new(inner).unwrap()
  }

  #[test]
  fn test_end_to_end() {
    let x = generic_new(vec![1_u16, 1, 2, 3, 5]);
    let bit_size = match_enum!(&x, MyEnum<L>(inner) => { inner.len() * L::BITS as usize });
    assert_eq!(bit_size, 80);
    let x = x.downcast::<u16>().unwrap();
    assert_eq!(x[0], 1);
  }

  #[test]
  fn test_multiple_enums_defined_in_same_scope() {
    // This was really tested during compilation, but I'm just using the new
    // enum here to ensure the code doesn't die.
    AnotherContainerEnumInSameScope::new(HashMap::<u16, usize>::new()).unwrap();
  }

  #[test]
  fn test_unit_enum() {
    assert_eq!(DType::new::<u32>(), Some(DType::U32));
    assert_eq!(DType::new::<u8>(), None);
    let byte_size = match_enum!(DType::U64, DType<T> => { std::mem::size_of::<T>() });
    assert_eq!(byte_size, 8);
  }

  #[test]
  fn test_new_rejects_unlisted_type() {
    assert!(MyEnum::new(vec![1_u8, 2]).is_none());
  }

  #[test]
  fn test_downcast_to_wrong_type() {
    let x = generic_new(vec![7_u32]);
    assert!(x.downcast_ref::<u16>().is_none());
    assert!(x.clone().downcast::<u64>().is_none());
    assert_eq!(x.downcast::<u32>(), Some(vec![7]));
  }

  #[test]
  fn test_downcast_mut() {
    let mut x = generic_new(vec![1_u64]);
    x.downcast_mut::<u64>().unwrap().push(2);
    assert_eq!(x.downcast_ref::<u64>(), Some(&vec![1, 2]));
    assert!(x.downcast_mut::<u32>().is_none());
  }
}