codama_attributes/codama_directives/discriminator_directive.rs

1use crate::{
2    utils::{FromMeta, SetOnce},
3    Attribute, Attributes, CodamaAttribute, CodamaDirective, TryFromFilter,
4};
5use codama_errors::CodamaError;
6use codama_nodes::{
7    BytesEncoding, CamelCaseString, ConstantDiscriminatorNode, ConstantValueNode,
8    DiscriminatorNode, FieldDiscriminatorNode, SizeDiscriminatorNode,
9};
10use codama_syn_helpers::{extensions::*, Meta};
11
12#[derive(Debug, PartialEq)]
13pub struct DiscriminatorDirective {
14    pub discriminator: DiscriminatorNode,
15}
16
17impl DiscriminatorDirective {
18    pub fn parse(meta: &Meta) -> syn::Result<Self> {
19        let pl = meta.assert_directive("discriminator")?.as_path_list()?;
20
21        let kind = pl
22            .parse_metas()?
23            .iter()
24            .find_map(|m| match m.path_str().as_str() {
25                "bytes" => Some(DiscriminatorKind::Constant),
26                "field" => Some(DiscriminatorKind::Field),
27                "size" => Some(DiscriminatorKind::Size),
28                _ => None,
29            })
30            .ok_or_else(|| meta.error("discriminator must specify one of: bytes, field, size"))?;
31
32        let mut encoding_is_set: bool = false;
33        let mut bytes_is_array: bool = false;
34        let mut bytes = SetOnce::<BytesValue>::new("bytes");
35        let mut encoding =
36            SetOnce::<BytesEncoding>::new("encoding").initial_value(BytesEncoding::Base16);
37        let mut field = SetOnce::<CamelCaseString>::new("field");
38        let mut offset = SetOnce::<usize>::new("offset").initial_value(0);
39        let mut size = SetOnce::<usize>::new("size");
40        pl.each(|ref meta| match meta.path_str().as_str() {
41            "bytes" => {
42                if kind != DiscriminatorKind::Constant {
43                    return Err(meta.error(format!("bytes cannot be used when {kind} is set")));
44                }
45                let value = BytesValue::from_meta(meta)?;
46                if let BytesValue::Array(_) = value {
47                    bytes_is_array = true;
48                    if encoding_is_set {
49                        return Err(meta.error("bytes must be a string when encoding is set"));
50                    }
51                };
52                bytes.set(value, meta)
53            }
54            "encoding" => {
55                if kind != DiscriminatorKind::Constant {
56                    return Err(meta.error(format!("encoding cannot be used when {kind} is set")));
57                }
58                let value = BytesEncoding::from_meta(meta)?;
59                encoding_is_set = true;
60                if bytes_is_array {
61                    return Err(meta.error("encoding cannot be set when bytes is an array"));
62                }
63                encoding.set(value, meta)
64            }
65            "field" => {
66                if kind != DiscriminatorKind::Field {
67                    return Err(meta.error(format!("field cannot be used when {kind} is set")));
68                }
69                field.set(meta.as_value()?.as_expr()?.as_string()?.into(), meta)
70            }
71            "offset" => {
72                if kind == DiscriminatorKind::Size {
73                    return Err(meta.error(format!("offset cannot be used when {kind} is set")));
74                }
75                offset.set(meta.as_value()?.as_expr()?.as_unsigned_integer()?, meta)
76            }
77            "size" => {
78                if kind != DiscriminatorKind::Size {
79                    return Err(meta.error(format!("size cannot be used when {kind} is set")));
80                }
81                size.set(meta.as_value()?.as_expr()?.as_unsigned_integer()?, meta)
82            }
83            _ => Err(meta.error("unrecognized attribute")),
84        })?;
85
86        Ok(DiscriminatorDirective {
87            discriminator: match kind {
88                DiscriminatorKind::Constant => ConstantDiscriminatorNode::new(
89                    ConstantValueNode::bytes(encoding.take(meta)?, bytes.take(meta)?),
90                    offset.take(meta)?,
91                )
92                .into(),
93                DiscriminatorKind::Field => {
94                    FieldDiscriminatorNode::new(field.take(meta)?, offset.take(meta)?).into()
95                }
96                DiscriminatorKind::Size => SizeDiscriminatorNode::new(size.take(meta)?).into(),
97            },
98        })
99    }
100}
101
102impl<'a> TryFrom<&'a CodamaAttribute<'a>> for &'a DiscriminatorDirective {
103    type Error = CodamaError;
104
105    fn try_from(attribute: &'a CodamaAttribute) -> Result<Self, Self::Error> {
106        match attribute.directive {
107            CodamaDirective::Discriminator(ref a) => Ok(a),
108            _ => Err(CodamaError::InvalidCodamaDirective {
109                expected: "discriminator".to_string(),
110                actual: attribute.directive.name().to_string(),
111            }),
112        }
113    }
114}
115
116impl<'a> TryFrom<&'a Attribute<'a>> for &'a DiscriminatorDirective {
117    type Error = CodamaError;
118
119    fn try_from(attribute: &'a Attribute) -> Result<Self, Self::Error> {
120        <&CodamaAttribute>::try_from(attribute)?.try_into()
121    }
122}
123
124impl From<&DiscriminatorDirective> for DiscriminatorNode {
125    fn from(directive: &DiscriminatorDirective) -> Self {
126        directive.discriminator.clone()
127    }
128}
129
130impl DiscriminatorDirective {
131    pub fn nodes(attributes: &Attributes) -> Vec<DiscriminatorNode> {
132        attributes
133            .iter()
134            .filter_map(DiscriminatorDirective::filter)
135            .map(Into::into)
136            .collect()
137    }
138}
139
/// The kind of discriminator a `discriminator(...)` directive describes.
///
/// The kind is chosen by the first `bytes`, `field` or `size` argument the
/// directive contains.
#[derive(Debug, PartialEq)]
enum DiscriminatorKind {
    /// Selected by a `bytes` argument.
    Constant,
    /// Selected by a `field` argument.
    Field,
    /// Selected by a `size` argument.
    Size,
}

/// Renders the argument name that selects each kind.
///
/// This lets error messages read e.g. "offset cannot be used when size is set".
impl std::fmt::Display for DiscriminatorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key = match self {
            Self::Constant => "bytes",
            Self::Field => "field",
            Self::Size => "size",
        };
        f.write_str(key)
    }
}
156
/// The raw value of a `bytes = …` argument, before it becomes the string stored
/// in a constant discriminator.
enum BytesValue {
    /// A string literal such as `bytes = "01020304"`. It is paired with the
    /// directive's `encoding` when the constant value node is built.
    Encoded(String),
    /// A byte-array literal such as `bytes = [1, 2, 3, 4]`.
    Array(Vec<u8>),
}
161
162impl FromMeta for BytesValue {
163    fn from_meta(meta: &Meta) -> syn::Result<Self> {
164        let expr = match meta {
165            Meta::Expr(expr) => Ok(expr.clone()),
166            Meta::PathList(pl) => Ok(pl.as_expr_array()?.into()),
167            _ => meta.as_value()?.as_expr().cloned(),
168        }?;
169
170        if let Ok(s) = expr.as_string() {
171            return Ok(BytesValue::Encoded(s));
172        }
173        if let Ok(arr) = expr.as_u8_array() {
174            return Ok(BytesValue::Array(arr));
175        }
176        Err(expr.error("expected a string or a byte array"))
177    }
178}
179
180impl From<BytesValue> for String {
181    fn from(value: BytesValue) -> Self {
182        match value {
183            BytesValue::Array(bytes) => {
184                let mut s = String::with_capacity(bytes.len() * 2);
185                for byte in bytes {
186                    s.push_str(&format!("{:02x}", byte));
187                }
188                s
189            }
190            BytesValue::Encoded(s) => s,
191        }
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn constant_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "01020304") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "01020304"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_byte_array() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4]) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "01020304"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_zero_padded_byte_array() {
        // Bytes below 0x10 must still render as two hex digits.
        let meta: Meta = syn::parse_quote! { discriminator(bytes = [0, 15, 16, 255]) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "000f10ff"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_encoding() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "hello", encoding = "utf8") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Utf8, "hello"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "ffff", offset = 42) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "ffff"),
                    42
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_encoding_and_offset() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = "hello", encoding = "utf8", offset = 8) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Utf8, "hello"),
                    8
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_byte_array_and_encoding() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4], encoding = "utf8") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "encoding cannot be set when bytes is an array"
        );
    }

    #[test]
    fn constant_discriminator_with_too_many_bytes() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4], bytes = "01020304") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "bytes is already set");
    }

    #[test]
    fn constant_discriminator_with_too_many_encoding() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "01020304", encoding = "utf8", encoding = "base64") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "encoding is already set");
    }

    #[test]
    fn constant_discriminator_with_too_many_offsets() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = "01020304", offset = 42, offset = 43) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset is already set");
    }

    #[test]
    fn constant_discriminator_with_encoding_and_byte_array() {
        let meta: Meta =
            syn::parse_quote! { discriminator(encoding = "utf8", bytes = [1, 2, 3, 4]) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "bytes must be a string when encoding is set"
        );
    }

    #[test]
    fn constant_discriminator_with_another_discriminator_kind() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = "01020304", field = "account_type") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "field cannot be used when bytes is set");
    }

    #[test]
    fn field_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: FieldDiscriminatorNode::new("AccountType", 0).into(),
            }
        );
    }

    #[test]
    fn field_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type", offset = 42) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: FieldDiscriminatorNode::new("AccountType", 42).into(),
            }
        );
    }

    #[test]
    fn discriminator_kind_is_selected_by_first_kind_argument() {
        // A leading non-kind argument must not prevent kind detection.
        let meta: Meta = syn::parse_quote! { discriminator(offset = 4, field = "account_type") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: FieldDiscriminatorNode::new("AccountType", 4).into(),
            }
        );
    }

    #[test]
    fn field_discriminator_with_too_many_field_names() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", field = "user_type") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "field is already set");
    }

    #[test]
    fn field_discriminator_with_too_many_offsets() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", offset = 42, offset = 43) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset is already set");
    }

    #[test]
    fn field_discriminator_with_another_discriminator_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type", size = 100) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "size cannot be used when field is set");
    }

    #[test]
    fn field_discriminator_with_bytes() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", bytes = "01020304") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "bytes cannot be used when field is set");
    }

    #[test]
    fn field_discriminator_with_encoding() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", encoding = "utf8") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "encoding cannot be used when field is set"
        );
    }

    #[test]
    fn size_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: SizeDiscriminatorNode::new(100).into(),
            }
        );
    }

    #[test]
    fn size_discriminator_with_too_many_sizes() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, size = 200) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "size is already set");
    }

    #[test]
    fn size_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, offset = 42) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset cannot be used when size is set");
    }

    #[test]
    fn size_discriminator_with_another_discriminator_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, bytes = [1, 2, 3]) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "bytes cannot be used when size is set");
    }

    #[test]
    fn size_discriminator_with_field() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, field = "account_type") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "field cannot be used when size is set");
    }

    #[test]
    fn size_discriminator_with_encoding() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, encoding = "utf8") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "encoding cannot be used when size is set"
        );
    }

    #[test]
    fn discriminator_with_unrecognized_attribute() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, unknown = 42) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "unrecognized attribute");
    }

    #[test]
    fn empty_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator() };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "discriminator must specify one of: bytes, field, size"
        );
    }

    #[test]
    fn discriminator_with_no_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(encoding = "utf8", offset = 42) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "discriminator must specify one of: bytes, field, size"
        );
    }
}