yikes_intenum/
lib.rs

//! Provides [`yikes_intenum!`], a macro that defines an `enum` mapped to and from an integer
//! type: one variant per named integer value, plus a catch-all `Unknown` variant that holds
//! every other value.
3
/// Defines an `enum` that maps to and from an integer type, with one variant per named value
/// plus a catch-all `Unknown` variant for every value that matches no named variant.
///
/// # Expansion
///
/// For `pub enum Name(Int) { A = a, B = b, ... }` the macro generates:
///
/// - `pub enum Name { A, B, ..., Unknown { value: Int, _private: Sealed } }`. `Sealed` lives
///   in a private module named `_<name in snake_case>_private`, so code outside the invoking
///   module (and its descendants) cannot build `Unknown` by hand; use `Name::from(int)`.
/// - `From<Int> for Name`: each listed value maps to its variant, anything else becomes
///   `Unknown { value, .. }`.
/// - `From<Name> for Int` and `From<&Name> for Int`: the inverse mapping.
/// - `Debug`, which prints the variant name, or `Unknown(<value>)` for unmapped values.
/// - `PartialEq`/`Eq`, `PartialOrd`/`Ord` and `Hash`, all computed on the underlying integer.
///   Consequently an `Unknown` holding `a` is equal to (and hashes like) `A`.
/// - `Clone` and `Copy`.
///
/// The generated trait impls use fully-qualified paths, so the invoking module does not need
/// any traits (e.g. `Hash`) in scope.
///
/// # Requirements
///
/// - `Int` is placed in `#[repr(Int)]`, so it must be a primitive integer type. Note that the
///   enum's own discriminants are the implicit `0, 1, 2, ...`, not the mapped values.
/// - Each value is used as a `match` pattern, so it must be a pattern-compatible expression
///   such as an integer literal.
/// - The expansion calls `paste::paste!` by path, so the invoking crate must be able to
///   resolve the `paste` crate.
///
/// # Examples
///
/// ```rust
/// yikes_intenum::yikes_intenum! {
///     /// IP datagram encapsulated protocol.
///     pub enum Protocol(u8) {
///         HopByHop  = 0x00,
///         Icmp      = 0x01,
///         Igmp      = 0x02,
///         Tcp       = 0x06,
///         Udp       = 0x11,
///         Ipv6Route = 0x2b,
///         Ipv6Frag  = 0x2c,
///         IpSecEsp  = 0x32,
///         IpSecAh   = 0x33,
///         Icmpv6    = 0x3a,
///         Ipv6NoNxt = 0x3b,
///         Ipv6Opts  = 0x3c
///     }
/// }
///
/// assert_eq!(Protocol::from(0x06_u8), Protocol::Tcp);
/// assert_eq!(u8::from(Protocol::Tcp), 0x06);
///
/// let other = Protocol::from(0xff_u8);
/// assert!(matches!(other, Protocol::Unknown { value: 0xff, .. }));
/// assert_eq!(format!("{other:?}"), "Unknown(255)");
/// assert_eq!(u8::from(&other), 0xff);
/// ```
#[macro_export]
macro_rules! yikes_intenum {
    (
        $( #[$enum_attr:meta] )*
        pub enum $name:ident($ty:ty) {
            $(
                $( #[$variant_attr:meta] )*
                $variant:ident = $value:expr
            ),+ $(,)?
        }
    ) => {
        paste::paste! {
            // Private module holding the token type that seals `Unknown`: only the invoking
            // module and its descendants can name `Sealed`, so everyone else has to go
            // through the `From<integer>` conversion to obtain an `Unknown`.
            mod [< _ $name:snake _private >] {
                #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
                pub struct Sealed;
            }

            // `Debug`, `PartialEq`, `PartialOrd`, `Ord` and `Hash` are implemented by hand
            // below (on the underlying integer), so only these are derived.
            #[derive(Eq, Clone, Copy)]
            $( #[$enum_attr] )*
            #[repr($ty)]
            pub enum $name {
                $(
                    $( #[$variant_attr] )*
                    $variant
                ),*,
                Unknown {
                    value: $ty,
                    _private: [< _ $name:snake _private >]::Sealed
                }
            }

            impl ::core::fmt::Debug for $name {
                #[inline]
                fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                    match self {
                        $(
                            $name::$variant => {
                                ::core::fmt::Formatter::write_str(f, ::core::stringify!($variant))
                            }
                        )*
                        $name::Unknown { value: other, .. } => {
                            ::core::write!(f, "Unknown({})", other)
                        }
                    }
                }
            }

            // Equality is on the underlying integer, so it agrees with `Ord` and `Hash`
            // below (`Eq` is derived above).
            impl ::core::cmp::PartialEq for $name {
                #[inline]
                fn eq(&self, other: &$name) -> bool {
                    <$ty>::from(self) == <$ty>::from(other)
                }
            }

            impl ::core::cmp::PartialOrd for $name {
                #[inline]
                fn partial_cmp(
                    &self,
                    other: &$name,
                ) -> ::core::option::Option<::core::cmp::Ordering> {
                    ::core::option::Option::Some(::core::cmp::Ord::cmp(self, other))
                }
            }

            impl ::core::cmp::Ord for $name {
                #[inline]
                fn cmp(&self, other: &$name) -> ::core::cmp::Ordering {
                    ::core::cmp::Ord::cmp(&<$ty>::from(self), &<$ty>::from(other))
                }
            }

            impl ::core::hash::Hash for $name {
                #[inline]
                fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) {
                    // Fully-qualified call: the `Hash` trait is not in the prelude, so a
                    // `.hash(..)` method call would fail in modules that don't import it.
                    ::core::hash::Hash::hash(&<$ty>::from(self), state);
                }
            }

            impl ::core::convert::From<$ty> for $name {
                fn from(value: $ty) -> Self {
                    match value {
                        $( $value => $name::$variant, )*
                        other => $name::Unknown {
                            value: other,
                            _private: [< _ $name:snake _private >]::Sealed,
                        },
                    }
                }
            }

            impl ::core::convert::From<&$name> for $ty {
                fn from(value: &$name) -> Self {
                    match value {
                        $( &$name::$variant => $value, )*
                        &$name::Unknown { value: other, .. } => other,
                    }
                }
            }

            impl ::core::convert::From<$name> for $ty {
                fn from(value: $name) -> Self {
                    <Self as ::core::convert::From<&$name>>::from(&value)
                }
            }
        } // paste::paste!
    };
} // macro_rules! yikes_intenum
128
// Without `#[macro_export]`, the macro could instead be re-exported with
// `pub use yikes_intenum;` (or `pub(crate) use yikes_intenum;`).
130
#[cfg(test)]
mod tests {
    use core::cmp::Ordering;
    use core::hash::{BuildHasher, Hash, Hasher};

    use super::*;

    yikes_intenum! {
        pub enum TestIpProtocol(u8) {
            Icmp = 0x01_u8,
            Tcp = 0x06_u8,
        }
    }

    /// Builds `TestIpProtocol::Unknown` holding `value` directly, bypassing `From<u8>`, so
    /// tests can compare it against whatever `From<u8>` produces for the same integer.
    fn unknown(value: u8) -> TestIpProtocol {
        TestIpProtocol::Unknown {
            value,
            _private: _test_ip_protocol_private::Sealed,
        }
    }

    /// `u8 -> enum -> u8 -> enum -> u8` must be lossless for every `u8`.
    #[test]
    fn test_ipprotocol_roundtrip() {
        for i in 0..=u8::MAX {
            let a: TestIpProtocol = i.into();
            let b: u8 = a.into();
            let c: TestIpProtocol = b.into();
            let d: u8 = c.into();
            assert_eq!(b, i);
            assert_eq!(a, c);
            assert_eq!(c, a);
            assert_eq!(b, d);
            assert_eq!(d, b);
        }
    }

    /// `Debug` prints the variant name, or `Unknown(<value>)` for unmapped values.
    #[test]
    fn test_ipprotocol_debug() {
        for i in 0..=u8::MAX {
            let a: TestIpProtocol = i.into();
            let expected = match &a {
                TestIpProtocol::Icmp => "Icmp".to_string(),
                TestIpProtocol::Tcp => "Tcp".to_string(),
                TestIpProtocol::Unknown { value, .. } => format!("Unknown({value})"),
            };
            assert_eq!(expected, format!("{a:?}"));
        }
    }

    /// A hand-built `Unknown` equals the `From<u8>` result for the same integer, even when
    /// that result is a named variant.
    #[test]
    fn test_ipprotocol_eq() {
        for i in 0..=u8::MAX {
            let a: TestIpProtocol = i.into();
            let b = unknown(i);
            assert!(
                a.eq(&b),
                "a {a:?} ({}) != b {b:?} ({})",
                u8::from(&a),
                u8::from(&b)
            );
            assert!(
                b.eq(&a),
                "b {b:?} ({}) != a {a:?} ({})",
                u8::from(&b),
                u8::from(&a)
            );
        }
    }

    /// Ordering follows the underlying integer regardless of which variant holds it.
    #[test]
    fn test_ipprotocol_cmp() {
        #![allow(clippy::similar_names)]
        // (`a_cmp_b`, `b_cmp_a`, ... are deliberately parallel names.)

        // `i` stops at `u8::MAX - 1` because every `j` must be strictly greater than `i`.
        for i in 0..u8::MAX {
            for a in [i.into(), unknown(i)] {
                for j in (i + 1)..=u8::MAX {
                    for b in [j.into(), unknown(j)] {
                        let a_cmp_b = a.cmp(&b);
                        let a_pcmp_b = a.partial_cmp(&b);
                        let b_cmp_a = b.cmp(&a);
                        let b_pcmp_a = b.partial_cmp(&a);
                        let a_int = u8::from(&a);
                        let b_int = u8::from(&b);

                        assert_eq!(
                            a_cmp_b,
                            Ordering::Less,
                            "[cmp] a {a:?} ({a_int}) !< b {b:?} ({b_int})"
                        );
                        assert_eq!(
                            a_pcmp_b,
                            Some(Ordering::Less),
                            "[pcmp] a {a:?} ({a_int}) !< b {b:?} ({b_int})"
                        );
                        assert_eq!(
                            b_cmp_a,
                            Ordering::Greater,
                            "[cmp] b {b:?} ({b_int}) !> a {a:?} ({a_int})"
                        );
                        assert_eq!(
                            b_pcmp_a,
                            Some(Ordering::Greater),
                            "[pcmp] b {b:?} ({b_int}) !> a {a:?} ({a_int})"
                        );

                        // Invariants every correct `Ord`/`PartialOrd` pair must satisfy,
                        // kept in case the checks above are ever loosened.

                        // `cmp` and `partial_cmp` agree.
                        assert_eq!(a_pcmp_b, Some(a_cmp_b), "a cmp b != a pcmp b");
                        assert_eq!(b_pcmp_a, Some(b_cmp_a), "b cmp a != b pcmp a!");

                        // Swapping the arguments reverses the result.
                        assert_eq!(a_cmp_b, b_cmp_a.reverse());
                        assert_eq!(b_cmp_a, a_cmp_b.reverse());
                    }
                }
            }
        }
    }

    /// Equal values must hash identically: the `From<u8>` result and a hand-built `Unknown`
    /// holding the same integer are equal, so their hashes must match, also when the hasher
    /// has already absorbed other data.
    #[test]
    fn test_ipprotocol_hash() {
        // Fixed values fed to both hashers before the enum; `None` means "no prefix".
        #[allow(clippy::unreadable_literal)] // arbitrary constants; digit grouping adds nothing
        let noises: &[Option<u64>] = &[
            None,
            Some(18223650421099562965_u64),
            Some(579348513557276885_u64),
            Some(6018745257231369041_u64),
            Some(4974397919804797078_u64),
            Some(6574880736321336287_u64),
            Some(8334883869055102477_u64),
            Some(8077341428061256032_u64),
            Some(8702568753483328048_u64),
        ];

        for i in 0..=u8::MAX {
            let a: TestIpProtocol = i.into();
            let b = unknown(i);

            // Both hold the integer `i`, so they must compare equal; only then does the
            // "equal values hash equally" check below mean anything.
            assert_eq!(
                a,
                b,
                "a {a:?} ({}) != b {b:?} ({}), although both hold the same integer",
                u8::from(&a),
                u8::from(&b)
            );

            // Two fresh instances per hasher kind: `a` is fed to the first, `b` to the second.
            #[allow(clippy::type_complexity)] // ad-hoc test table; a type alias would not help
            let hashers: Vec<(&str, Box<dyn Hasher>, Box<dyn Hasher>)> = vec![
                (
                    "std::collections::hash_map::DefaultHasher",
                    Box::new(std::collections::hash_map::DefaultHasher::new()) as _,
                    Box::new(std::collections::hash_map::DefaultHasher::new()) as _,
                ),
                (
                    "FnvHasher",
                    Box::new(fnv::FnvBuildHasher::default().build_hasher()) as _,
                    Box::new(fnv::FnvBuildHasher::default().build_hasher()) as _,
                ),
            ];

            for (hasher_kind, mut hasher_a, mut hasher_b) in hashers {
                for noise in noises {
                    let assert_hashers_match =
                        |hasher_a: &mut dyn Hasher, hasher_b: &mut dyn Hasher| {
                            let ha = hasher_a.finish();
                            let hb = hasher_b.finish();
                            assert_eq!(
                                ha,
                                hb,
                                "[kind={hasher_kind}, noise={noise:?}] hash{{a {a:?} ({})}} {ha} != hash{{b {b:?} ({})}} {hb}",
                                u8::from(&a),
                                u8::from(&b),
                            );
                        };

                    if let Some(noise) = noise {
                        noise.hash(&mut hasher_a);
                        noise.hash(&mut hasher_b);
                        assert_hashers_match(&mut hasher_a, &mut hasher_b);
                    }
                    a.hash(&mut hasher_a);
                    b.hash(&mut hasher_b);
                    assert_hashers_match(&mut hasher_a, &mut hasher_b);
                }
            }
        }
    }

    /// Values holding different integers should hash differently under the
    /// (deterministically keyed) `DefaultHasher`, whichever variant represents them.
    #[test]
    fn test_ipprotocol_hash_different() {
        for i in 0..=u8::MAX {
            for a in [i.into(), unknown(i)] {
                for j in (0..=u8::MAX).filter(|&j| j != i) {
                    for b in [j.into(), unknown(j)] {
                        let mut hasher_a = std::collections::hash_map::DefaultHasher::new();
                        let mut hasher_b = std::collections::hash_map::DefaultHasher::new();
                        a.hash(&mut hasher_a);
                        b.hash(&mut hasher_b);
                        let ha = hasher_a.finish();
                        let hb = hasher_b.finish();
                        // a != b almost surely implies hash(a) != hash(b).
                        assert_ne!(
                            ha,
                            hb,
                            "Different values yielded same hashes: hash{{a {a:?} ({int_a})}} ({ha}) == hash{{b {b:?} ({int_b})}} ({hb})",
                            int_a = u8::from(&a),
                            int_b = u8::from(&b),
                        );
                    }
                }
            }
        }
    }
}