codama_attributes/codama_directives/discriminator_directive.rs

1use crate::{
2    utils::{get_expr_from_meta_with_path_lists_as_arrays, FromMeta, SetOnce},
3    Attribute, Attributes, CodamaAttribute, CodamaDirective, TryFromFilter,
4};
5use codama_errors::CodamaError;
6use codama_nodes::{
7    BytesEncoding, CamelCaseString, ConstantDiscriminatorNode, ConstantValueNode,
8    DiscriminatorNode, FieldDiscriminatorNode, SizeDiscriminatorNode,
9};
10use codama_syn_helpers::{extensions::*, Meta};
11
12#[derive(Debug, PartialEq)]
13pub struct DiscriminatorDirective {
14    pub discriminator: DiscriminatorNode,
15}
16
17impl DiscriminatorDirective {
18    pub fn parse(meta: &Meta) -> syn::Result<Self> {
19        let pl = meta.assert_directive("discriminator")?.as_path_list()?;
20
21        let kind = pl
22            .parse_metas()?
23            .iter()
24            .find_map(|m| match m.path_str().as_str() {
25                "bytes" => Some(DiscriminatorKind::Constant),
26                "field" => Some(DiscriminatorKind::Field),
27                "size" => Some(DiscriminatorKind::Size),
28                _ => None,
29            })
30            .ok_or_else(|| meta.error("discriminator must specify one of: bytes, field, size"))?;
31
32        let mut encoding_is_set: bool = false;
33        let mut bytes_is_array: bool = false;
34        let mut bytes = SetOnce::<BytesValue>::new("bytes");
35        let mut encoding =
36            SetOnce::<BytesEncoding>::new("encoding").initial_value(BytesEncoding::Base16);
37        let mut field = SetOnce::<CamelCaseString>::new("field");
38        let mut offset = SetOnce::<usize>::new("offset").initial_value(0);
39        let mut size = SetOnce::<usize>::new("size");
40        pl.each(|ref meta| match meta.path_str().as_str() {
41            "bytes" => {
42                if kind != DiscriminatorKind::Constant {
43                    return Err(meta.error(format!("bytes cannot be used when {kind} is set")));
44                }
45                let value = BytesValue::from_meta(meta)?;
46                if let BytesValue::Array(_) = value {
47                    bytes_is_array = true;
48                    if encoding_is_set {
49                        return Err(meta.error("bytes must be a string when encoding is set"));
50                    }
51                };
52                bytes.set(value, meta)
53            }
54            "encoding" => {
55                if kind != DiscriminatorKind::Constant {
56                    return Err(meta.error(format!("encoding cannot be used when {kind} is set")));
57                }
58                let value = BytesEncoding::from_meta(meta)?;
59                encoding_is_set = true;
60                if bytes_is_array {
61                    return Err(meta.error("encoding cannot be set when bytes is an array"));
62                }
63                encoding.set(value, meta)
64            }
65            "field" => {
66                if kind != DiscriminatorKind::Field {
67                    return Err(meta.error(format!("field cannot be used when {kind} is set")));
68                }
69                field.set(String::from_meta(meta)?.into(), meta)
70            }
71            "offset" => {
72                if kind == DiscriminatorKind::Size {
73                    return Err(meta.error(format!("offset cannot be used when {kind} is set")));
74                }
75                offset.set(usize::from_meta(meta)?, meta)
76            }
77            "size" => {
78                if kind != DiscriminatorKind::Size {
79                    return Err(meta.error(format!("size cannot be used when {kind} is set")));
80                }
81                size.set(usize::from_meta(meta)?, meta)
82            }
83            _ => Err(meta.error("unrecognized attribute")),
84        })?;
85
86        Ok(DiscriminatorDirective {
87            discriminator: match kind {
88                DiscriminatorKind::Constant => ConstantDiscriminatorNode::new(
89                    ConstantValueNode::bytes(encoding.take(meta)?, bytes.take(meta)?),
90                    offset.take(meta)?,
91                )
92                .into(),
93                DiscriminatorKind::Field => {
94                    FieldDiscriminatorNode::new(field.take(meta)?, offset.take(meta)?).into()
95                }
96                DiscriminatorKind::Size => SizeDiscriminatorNode::new(size.take(meta)?).into(),
97            },
98        })
99    }
100}
101
102impl<'a> TryFrom<&'a CodamaAttribute<'a>> for &'a DiscriminatorDirective {
103    type Error = CodamaError;
104
105    fn try_from(attribute: &'a CodamaAttribute) -> Result<Self, Self::Error> {
106        match attribute.directive {
107            CodamaDirective::Discriminator(ref a) => Ok(a),
108            _ => Err(CodamaError::InvalidCodamaDirective {
109                expected: "discriminator".to_string(),
110                actual: attribute.directive.name().to_string(),
111            }),
112        }
113    }
114}
115
116impl<'a> TryFrom<&'a Attribute<'a>> for &'a DiscriminatorDirective {
117    type Error = CodamaError;
118
119    fn try_from(attribute: &'a Attribute) -> Result<Self, Self::Error> {
120        <&CodamaAttribute>::try_from(attribute)?.try_into()
121    }
122}
123
124impl From<&DiscriminatorDirective> for DiscriminatorNode {
125    fn from(directive: &DiscriminatorDirective) -> Self {
126        directive.discriminator.clone()
127    }
128}
129
130impl DiscriminatorDirective {
131    pub fn nodes(attributes: &Attributes) -> Vec<DiscriminatorNode> {
132        attributes
133            .iter()
134            .filter_map(DiscriminatorDirective::filter)
135            .map(Into::into)
136            .collect()
137    }
138}
139
/// The kind of discriminator described by a `discriminator(...)` directive.
#[derive(Debug, PartialEq)]
enum DiscriminatorKind {
    /// Selected by `bytes`; builds a constant discriminator.
    Constant,
    /// Selected by `field`; builds a field discriminator.
    Field,
    /// Selected by `size`; builds a size discriminator.
    Size,
}

impl std::fmt::Display for DiscriminatorKind {
    /// Writes the directive keyword that selects this kind (`bytes`, `field` or `size`).
    /// Error messages use this text to name the active kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Constant => "bytes",
            Self::Field => "field",
            Self::Size => "size",
        };
        f.write_str(keyword)
    }
}
156
/// The raw value given to the `bytes` option of a `discriminator(...)` directive.
///
/// Derives `Debug`/`PartialEq` so parsed values can be inspected and compared in
/// diagnostics and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
enum BytesValue {
    /// A list of byte values, e.g. `bytes = [1, 2, 3]`.
    Array(Vec<u8>),
    /// A string, interpreted according to the directive's `encoding`.
    Encoded(String),
}
161
162impl FromMeta for BytesValue {
163    fn from_meta(meta: &Meta) -> syn::Result<Self> {
164        let expr = get_expr_from_meta_with_path_lists_as_arrays(meta)?;
165        if let Ok(s) = expr.as_string() {
166            return Ok(BytesValue::Encoded(s));
167        }
168        match expr.as_u8_array() {
169            Ok(arr) => Ok(BytesValue::Array(arr)),
170            Err(_) => Err(expr.error("expected a string or a byte array")),
171        }
172    }
173}
174
175impl From<BytesValue> for String {
176    fn from(value: BytesValue) -> Self {
177        match value {
178            BytesValue::Array(bytes) => {
179                let mut s = String::with_capacity(bytes.len() * 2);
180                for byte in bytes {
181                    s.push_str(&format!("{:02x}", byte));
182                }
183                s
184            }
185            BytesValue::Encoded(s) => s,
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constant_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "01020304") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "01020304"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_byte_array() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4]) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "01020304"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_encoding() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "hello", encoding = "utf8") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Utf8, "hello"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "ffff", offset = 42) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "ffff"),
                    42
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_offset_before_bytes() {
        // The kind comes from the first kind-defining entry, not from the first entry.
        let meta: Meta = syn::parse_quote! { discriminator(offset = 4, bytes = [255, 0]) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    // Byte arrays are rendered as zero-padded, lowercase base16.
                    ConstantValueNode::bytes(BytesEncoding::Base16, "ff00"),
                    4
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_byte_array_and_encoding() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4], encoding = "utf8") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "encoding cannot be set when bytes is an array"
        );
    }

    #[test]
    fn constant_discriminator_with_too_many_bytes() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4], bytes = "01020304") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "bytes is already set");
    }

    #[test]
    fn constant_discriminator_with_too_many_encoding() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "01020304", encoding = "utf8", encoding = "base64") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "encoding is already set");
    }

    #[test]
    fn constant_discriminator_with_too_many_offsets() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = "01020304", offset = 42, offset = 43) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset is already set");
    }

    #[test]
    fn constant_discriminator_with_encoding_and_byte_array() {
        let meta: Meta =
            syn::parse_quote! { discriminator(encoding = "utf8", bytes = [1, 2, 3, 4]) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "bytes must be a string when encoding is set"
        );
    }

    #[test]
    fn constant_discriminator_with_invalid_bytes() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = 42) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "expected a string or a byte array");
    }

    #[test]
    fn constant_discriminator_with_another_discriminator_kind() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = "01020304", field = "account_type") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "field cannot be used when bytes is set");
    }

    #[test]
    fn field_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: FieldDiscriminatorNode::new("AccountType", 0).into(),
            }
        );
    }

    #[test]
    fn field_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type", offset = 42) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: FieldDiscriminatorNode::new("AccountType", 42).into(),
            }
        );
    }

    #[test]
    fn field_discriminator_with_too_many_field_names() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", field = "user_type") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "field is already set");
    }

    #[test]
    fn field_discriminator_with_too_many_offsets() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", offset = 42, offset = 43) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset is already set");
    }

    #[test]
    fn field_discriminator_with_encoding() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", encoding = "utf8") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "encoding cannot be used when field is set");
    }

    #[test]
    fn field_discriminator_with_another_discriminator_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type", size = 100) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "size cannot be used when field is set");
    }

    #[test]
    fn size_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: SizeDiscriminatorNode::new(100).into(),
            }
        );
    }

    #[test]
    fn size_discriminator_with_too_many_sizes() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, size = 200) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "size is already set");
    }

    #[test]
    fn size_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, offset = 42) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset cannot be used when size is set");
    }

    #[test]
    fn size_discriminator_with_offset_first() {
        let meta: Meta = syn::parse_quote! { discriminator(offset = 42, size = 100) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset cannot be used when size is set");
    }

    #[test]
    fn size_discriminator_with_another_discriminator_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, bytes = [1, 2, 3]) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "bytes cannot be used when size is set");
    }

    #[test]
    fn discriminator_with_unrecognized_attribute() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, unknown = 1) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "unrecognized attribute");
    }

    #[test]
    fn empty_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator() };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "discriminator must specify one of: bytes, field, size"
        );
    }

    #[test]
    fn discriminator_with_no_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(encoding = "utf8", offset = 42) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "discriminator must specify one of: bytes, field, size"
        );
    }
}