Skip to main content

packet_strata/
macros.rs

/// Generates a protocol-number newtype with named constants, a private name enum,
/// conversions, `Display`, and `serde` support.
///
/// Invocation shape (see the two entry-point arms at the bottom):
///
/// ```ignore
/// protocol_constants! {
///     /// Outer attributes / doc comments are forwarded to the generated struct.
///     MyProto, StorageType, u16:
///     #[default] FIRST = 0x0001;
///     SECOND = 0x0002;
/// }
/// ```
///
/// The storage type and the primitive interact as follows:
///
/// * When the storage type is literally `u8`, each value is stored as-is.
/// * Otherwise each value is built with `<StorageType>::new(value)`. That call is
///   used in `const` context, so `new` must be a `const fn`.
/// * The storage type must convert to and from the primitive with `From`/`Into`.
///
/// Each value is also used as a `match` pattern, so it should be a plain integer
/// literal.
///
/// A `#[default]` marker selects the `Default` value. Without it the default is the
/// value `0`.
///
/// Generated behavior:
///
/// * `Display` prints the kebab-case constant name for known values, and `0x{hex}`
///   for unknown ones.
/// * With the `strata_protocol_names` feature, serde uses the same string forms.
///   Deserialization accepts a name or a `0x`/`0X`-prefixed hex string.
/// * Without that feature, serde uses the raw primitive.
///
/// Caller requirements:
///
/// * The expansion names `FromBytes`, `IntoBytes`, `Immutable` and `KnownLayout`
///   unqualified, so they must be in scope where the macro is invoked.
/// * The crates `paste`, `strum` and `serde` must be resolvable from the invoking
///   crate.
/// * The `cfg(feature = ...)` checks are evaluated against the crate where the
///   macro is expanded.
#[macro_export]
macro_rules! protocol_constants {
    // 1. Constructor helper for `u8` storage: the value is already the storage type.
    (@construct_u8 $ztype:ty, $val:expr) => { $val };

    // 2. Constructor helper for wrapper storage types (U16 etc.).
    //    Angle brackets are required because `$ztype` is a `ty` metavariable.
    (@construct_new $ztype:ty, $val:expr) => { <$ztype>::new($val) };

    // 3. Body implementation shared by both entry points.
    (@impl $(#[$outer:meta])*, $type_name:ident, $ztype:ty, $primitive:ty, $strategy:ident, $( $(#[$default:ident])? $const_name:ident = $val:expr; )+ ) => {
        paste::paste! {
            // `///` comments are not substituted inside a macro body, so the type
            // name is spliced into the summary line with `concat!`/`stringify!`.
            #[doc = concat!(stringify!($type_name), " number.")]
            ///
            #[doc = concat!("A newtype wrapper around a ", stringify!($primitive), " representing an ", stringify!($type_name), " number.")]
            /// This type provides named constants for well-known protocols and implements
            /// `Display` to show human-readable protocol names.
            $(#[$outer])*
            #[derive(
                Clone,
                Copy,
                PartialEq,
                Eq,
                Hash,
                Debug,
                FromBytes,
                IntoBytes,
                Immutable,
                KnownLayout,
            )]
            pub struct $type_name(pub $ztype);

            // Named constants, built with the constructor strategy chosen by the entry point.
            impl $type_name {
                $(
                    pub const $const_name: $type_name =
                        $type_name($crate::protocol_constants!(@$strategy $ztype, $val));
                )+

                /// Returns `true` if the wrapped value equals one of the named constants.
                #[must_use]
                pub fn is_valid(&self) -> bool {
                    let raw: $primitive = self.0.into();
                    <[< $type_name Name >] as std::convert::TryFrom<$primitive>>::try_from(raw).is_ok()
                }
            }

            impl Default for $type_name {
                /// Returns the first constant tagged `#[default]`, or the value `0`
                /// if no constant is tagged.
                fn default() -> Self {
                    $( $(
                        // Only `#[default]` selects a constant; any other single-ident
                        // attribute in that position is accepted but ignored.
                        if stringify!($default) == "default" {
                            return Self::$const_name;
                        }
                    )? )+
                    Self($crate::protocol_constants!(@$strategy $ztype, 0))
                }
            }

            // Shadow enum that drives the strum name <-> variant machinery.
            #[derive(Debug, PartialEq, strum::EnumString, strum::IntoStaticStr, Clone, Copy)]
            #[strum(serialize_all = "kebab-case")]
            #[allow(non_camel_case_types)] // variants reuse the constant names (e.g. UPPER_CASE)
            enum [< $type_name Name >] {
                $(
                    $const_name,
                )+
            }

            // Enum variant -> primitive value.
            impl From<[< $type_name Name >]> for $primitive {
                fn from(v: [< $type_name Name >]) -> Self {
                    match v {
                        $(
                            [< $type_name Name >]::$const_name => $val,
                        )+
                    }
                }
            }

            // Primitive value -> enum variant.
            // Used by `is_valid`, `Display` and the name-based `Serialize`.
            // `$val` appears in pattern position here, so each value must be a literal.
            impl TryFrom<$primitive> for [< $type_name Name >] {
                type Error = ();
                fn try_from(v: $primitive) -> Result<Self, Self::Error> {
                    match v {
                        $(
                            $val => Ok([< $type_name Name >]::$const_name),
                        )+
                        _ => Err(()),
                    }
                }
            }

            // Conversion from primitive to struct.
            impl From<$primitive> for $type_name {
                fn from(v: $primitive) -> Self {
                    Self(v.into())
                }
            }

            // Conversion from struct to primitive.
            impl From<$type_name> for $primitive {
                fn from(v: $type_name) -> Self {
                    v.0.into()
                }
            }

            // Name-based Serialize (strata_protocol_names).
            // Known values -> kebab-case name, unknown -> "0x{hex}".
            #[cfg(feature = "strata_protocol_names")]
            impl serde::Serialize for $type_name {
                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                where
                    S: serde::Serializer,
                {
                    let val: $primitive = self.0.into();
                    match <[< $type_name Name >] as std::convert::TryFrom<$primitive>>::try_from(val) {
                        Ok(proto_enum) => {
                            let s: &'static str = proto_enum.into();
                            serializer.serialize_str(s)
                        }
                        // Same text as `format!("0x{:x}", val)`, without an intermediate String.
                        Err(()) => serializer.collect_str(&format_args!("0x{:x}", val)),
                    }
                }
            }

            // Name-based Deserialize (strata_protocol_names).
            // Accepts a kebab-case name or a `0x`/`0X`-prefixed hex string.
            #[cfg(feature = "strata_protocol_names")]
            impl<'de> serde::Deserialize<'de> for $type_name {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: serde::Deserializer<'de>,
                {
                    struct Visitor;

                    impl<'de> serde::de::Visitor<'de> for Visitor {
                        type Value = $type_name;

                        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                            formatter.write_str("a protocol name or hex value")
                        }

                        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
                        where
                            E: serde::de::Error,
                        {
                            if let Ok(variant) = <[< $type_name Name >] as std::str::FromStr>::from_str(value) {
                                let p: $primitive = variant.into();
                                return Ok($type_name(p.into()));
                            }

                            if let Some(digits) = value
                                .strip_prefix("0x")
                                .or_else(|| value.strip_prefix("0X"))
                            {
                                // `<$primitive>::` is required: a `ty` metavariable cannot be
                                // followed directly by `::` ("missing angle brackets").
                                let val = <$primitive>::from_str_radix(digits, 16)
                                    .map_err(|_| E::custom(format!("invalid hex: {}", value)))?;
                                return Ok($type_name(val.into()));
                            }

                            Err(E::custom(format!("unknown {}Proto: {}", stringify!($type_name), value)))
                        }
                    }

                    deserializer.deserialize_str(Visitor)
                }
            }

            // Display: kebab-case name for known values, "0x{hex}" otherwise.
            impl std::fmt::Display for $type_name {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    let val: $primitive = self.0.into();
                    match <[< $type_name Name >] as std::convert::TryFrom<$primitive>>::try_from(val) {
                        Ok(proto_enum) => {
                            let s: &'static str = proto_enum.into();
                            f.write_str(s)
                        }
                        Err(()) => write!(f, "0x{:x}", val),
                    }
                }
            }

            // Binary Serialize: the raw primitive value.
            // Fully qualified so the invoking module need not import `serde::Serialize`.
            #[cfg(not(feature = "strata_protocol_names"))]
            impl serde::Serialize for $type_name {
                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                where
                    S: serde::Serializer,
                {
                    let val: $primitive = self.0.into();
                    <$primitive as serde::Serialize>::serialize(&val, serializer)
                }
            }

            // Binary Deserialize: the raw primitive value.
            // Fully qualified for the same reason as the Serialize impl above it.
            #[cfg(not(feature = "strata_protocol_names"))]
            impl<'de> serde::Deserialize<'de> for $type_name {
                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                where
                    D: serde::Deserializer<'de>,
                {
                    let val = <$primitive as serde::Deserialize<'de>>::deserialize(deserializer)?;
                    Ok($type_name(val.into()))
                }
            }
        }
    };

    // 4. Entry point: `u8` storage, so values are used directly.
    (   $(#[$outer:meta])*
        $type_name:ident,
        u8,
        $primitive:ty:
        $( $(#[$default:ident])? $const_name:ident = $val:expr; )+
    ) => {
        $crate::protocol_constants!(@impl $(#[$outer])*, $type_name, u8, $primitive, construct_u8, $( $(#[$default])? $const_name = $val; )+ );
    };

    // 5. Entry point: generic storage (U16, etc.), so values go through `<$ztype>::new`.
    (   $(#[$outer:meta])*
        $type_name:ident,
        $ztype:ty,
        $primitive:ty:
        $( $(#[$default:ident])? $const_name:ident = $val:expr; )+
    ) => {
        $crate::protocol_constants!(@impl $(#[$outer])*, $type_name, $ztype, $primitive, construct_new, $( $(#[$default])? $const_name = $val; )+ );
    };
}