codama_attributes/codama_directives/
discriminator_directive.rs1use crate::{
2 utils::{FromMeta, SetOnce},
3 Attribute, Attributes, CodamaAttribute, CodamaDirective, TryFromFilter,
4};
5use codama_errors::CodamaError;
6use codama_nodes::{
7 BytesEncoding, CamelCaseString, ConstantDiscriminatorNode, ConstantValueNode,
8 DiscriminatorNode, FieldDiscriminatorNode, SizeDiscriminatorNode,
9};
10use codama_syn_helpers::{extensions::*, Meta};
11
12#[derive(Debug, PartialEq)]
13pub struct DiscriminatorDirective {
14 pub discriminator: DiscriminatorNode,
15}
16
17impl DiscriminatorDirective {
18 pub fn parse(meta: &Meta) -> syn::Result<Self> {
19 let pl = meta.assert_directive("discriminator")?.as_path_list()?;
20
21 let kind = pl
22 .parse_metas()?
23 .iter()
24 .find_map(|m| match m.path_str().as_str() {
25 "bytes" => Some(DiscriminatorKind::Constant),
26 "field" => Some(DiscriminatorKind::Field),
27 "size" => Some(DiscriminatorKind::Size),
28 _ => None,
29 })
30 .ok_or_else(|| meta.error("discriminator must specify one of: bytes, field, size"))?;
31
32 let mut encoding_is_set: bool = false;
33 let mut bytes_is_array: bool = false;
34 let mut bytes = SetOnce::<BytesValue>::new("bytes");
35 let mut encoding =
36 SetOnce::<BytesEncoding>::new("encoding").initial_value(BytesEncoding::Base16);
37 let mut field = SetOnce::<CamelCaseString>::new("field");
38 let mut offset = SetOnce::<usize>::new("offset").initial_value(0);
39 let mut size = SetOnce::<usize>::new("size");
40 pl.each(|ref meta| match meta.path_str().as_str() {
41 "bytes" => {
42 if kind != DiscriminatorKind::Constant {
43 return Err(meta.error(format!("bytes cannot be used when {kind} is set")));
44 }
45 let value = BytesValue::from_meta(meta)?;
46 if let BytesValue::Array(_) = value {
47 bytes_is_array = true;
48 if encoding_is_set {
49 return Err(meta.error("bytes must be a string when encoding is set"));
50 }
51 };
52 bytes.set(value, meta)
53 }
54 "encoding" => {
55 if kind != DiscriminatorKind::Constant {
56 return Err(meta.error(format!("encoding cannot be used when {kind} is set")));
57 }
58 let value = BytesEncoding::from_meta(meta)?;
59 encoding_is_set = true;
60 if bytes_is_array {
61 return Err(meta.error("encoding cannot be set when bytes is an array"));
62 }
63 encoding.set(value, meta)
64 }
65 "field" => {
66 if kind != DiscriminatorKind::Field {
67 return Err(meta.error(format!("field cannot be used when {kind} is set")));
68 }
69 field.set(meta.as_value()?.as_expr()?.as_string()?.into(), meta)
70 }
71 "offset" => {
72 if kind == DiscriminatorKind::Size {
73 return Err(meta.error(format!("offset cannot be used when {kind} is set")));
74 }
75 offset.set(meta.as_value()?.as_expr()?.as_unsigned_integer()?, meta)
76 }
77 "size" => {
78 if kind != DiscriminatorKind::Size {
79 return Err(meta.error(format!("size cannot be used when {kind} is set")));
80 }
81 size.set(meta.as_value()?.as_expr()?.as_unsigned_integer()?, meta)
82 }
83 _ => Err(meta.error("unrecognized attribute")),
84 })?;
85
86 Ok(DiscriminatorDirective {
87 discriminator: match kind {
88 DiscriminatorKind::Constant => ConstantDiscriminatorNode::new(
89 ConstantValueNode::bytes(encoding.take(meta)?, bytes.take(meta)?),
90 offset.take(meta)?,
91 )
92 .into(),
93 DiscriminatorKind::Field => {
94 FieldDiscriminatorNode::new(field.take(meta)?, offset.take(meta)?).into()
95 }
96 DiscriminatorKind::Size => SizeDiscriminatorNode::new(size.take(meta)?).into(),
97 },
98 })
99 }
100}
101
102impl<'a> TryFrom<&'a CodamaAttribute<'a>> for &'a DiscriminatorDirective {
103 type Error = CodamaError;
104
105 fn try_from(attribute: &'a CodamaAttribute) -> Result<Self, Self::Error> {
106 match attribute.directive {
107 CodamaDirective::Discriminator(ref a) => Ok(a),
108 _ => Err(CodamaError::InvalidCodamaDirective {
109 expected: "discriminator".to_string(),
110 actual: attribute.directive.name().to_string(),
111 }),
112 }
113 }
114}
115
116impl<'a> TryFrom<&'a Attribute<'a>> for &'a DiscriminatorDirective {
117 type Error = CodamaError;
118
119 fn try_from(attribute: &'a Attribute) -> Result<Self, Self::Error> {
120 <&CodamaAttribute>::try_from(attribute)?.try_into()
121 }
122}
123
124impl From<&DiscriminatorDirective> for DiscriminatorNode {
125 fn from(directive: &DiscriminatorDirective) -> Self {
126 directive.discriminator.clone()
127 }
128}
129
130impl DiscriminatorDirective {
131 pub fn nodes(attributes: &Attributes) -> Vec<DiscriminatorNode> {
132 attributes
133 .iter()
134 .filter_map(DiscriminatorDirective::filter)
135 .map(Into::into)
136 .collect()
137 }
138}
139
/// The kind of discriminator a `discriminator(...)` directive describes.
///
/// The kind is selected by the first `bytes`, `field` or `size` key of the
/// directive. It is a plain fieldless tag, so it is `Copy` and `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DiscriminatorKind {
    /// Selected by `bytes`; builds a `ConstantDiscriminatorNode`.
    Constant,
    /// Selected by `field`; builds a `FieldDiscriminatorNode`.
    Field,
    /// Selected by `size`; builds a `SizeDiscriminatorNode`.
    Size,
}
146
147impl std::fmt::Display for DiscriminatorKind {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 DiscriminatorKind::Constant => write!(f, "bytes"),
151 DiscriminatorKind::Field => write!(f, "field"),
152 DiscriminatorKind::Size => write!(f, "size"),
153 }
154 }
155}
156
/// The raw value of a `bytes` key, before it is turned into a string.
enum BytesValue {
    /// A string, interpreted according to the directive's `encoding`.
    Encoded(String),
    /// A literal byte array such as `[1, 2, 3]`.
    Array(Vec<u8>),
}
161
162impl FromMeta for BytesValue {
163 fn from_meta(meta: &Meta) -> syn::Result<Self> {
164 let expr = match meta {
165 Meta::Expr(expr) => Ok(expr.clone()),
166 Meta::PathList(pl) => Ok(pl.as_expr_array()?.into()),
167 _ => meta.as_value()?.as_expr().cloned(),
168 }?;
169
170 if let Ok(s) = expr.as_string() {
171 return Ok(BytesValue::Encoded(s));
172 }
173 if let Ok(arr) = expr.as_u8_array() {
174 return Ok(BytesValue::Array(arr));
175 }
176 Err(expr.error("expected a string or a byte array"))
177 }
178}
179
180impl From<BytesValue> for String {
181 fn from(value: BytesValue) -> Self {
182 match value {
183 BytesValue::Array(bytes) => {
184 let mut s = String::with_capacity(bytes.len() * 2);
185 for byte in bytes {
186 s.push_str(&format!("{:02x}", byte));
187 }
188 s
189 }
190 BytesValue::Encoded(s) => s,
191 }
192 }
193}
194
#[cfg(test)]
mod tests {
    //! Parsing tests for `discriminator(...)` directives, covering each kind,
    //! defaults, duplicate keys, and keys used with the wrong kind.

    use super::*;

    #[test]
    fn constant_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "01020304") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "01020304"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_byte_array() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4]) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "01020304"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_byte_array_and_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = [1, 2], offset = 8) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "0102"),
                    8
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_encoding() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "hello", encoding = "utf8") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Utf8, "hello"),
                    0
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "ffff", offset = 42) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: ConstantDiscriminatorNode::new(
                    ConstantValueNode::bytes(BytesEncoding::Base16, "ffff"),
                    42
                )
                .into(),
            }
        );
    }

    #[test]
    fn constant_discriminator_with_byte_array_and_encoding() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4], encoding = "utf8") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "encoding cannot be set when bytes is an array"
        );
    }

    #[test]
    fn constant_discriminator_with_too_many_bytes() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = [1, 2, 3, 4], bytes = "01020304") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "bytes is already set");
    }

    #[test]
    fn constant_discriminator_with_too_many_encoding() {
        let meta: Meta = syn::parse_quote! {
            discriminator(bytes = "01020304", encoding = "utf8", encoding = "base64")
        };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "encoding is already set");
    }

    #[test]
    fn constant_discriminator_with_too_many_offsets() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = "01020304", offset = 42, offset = 43) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset is already set");
    }

    #[test]
    fn constant_discriminator_with_encoding_and_byte_array() {
        let meta: Meta =
            syn::parse_quote! { discriminator(encoding = "utf8", bytes = [1, 2, 3, 4]) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "bytes must be a string when encoding is set"
        );
    }

    #[test]
    fn constant_discriminator_with_another_discriminator_kind() {
        let meta: Meta =
            syn::parse_quote! { discriminator(bytes = "01020304", field = "account_type") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "field cannot be used when bytes is set");
    }

    #[test]
    fn constant_discriminator_with_size() {
        let meta: Meta = syn::parse_quote! { discriminator(bytes = "01020304", size = 4) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "size cannot be used when bytes is set");
    }

    #[test]
    fn field_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type") };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: FieldDiscriminatorNode::new("AccountType", 0).into(),
            }
        );
    }

    #[test]
    fn field_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type", offset = 42) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: FieldDiscriminatorNode::new("AccountType", 42).into(),
            }
        );
    }

    #[test]
    fn field_discriminator_with_too_many_field_names() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", field = "user_type") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "field is already set");
    }

    #[test]
    fn field_discriminator_with_too_many_offsets() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", offset = 42, offset = 43) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset is already set");
    }

    #[test]
    fn field_discriminator_with_another_discriminator_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type", size = 100) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "size cannot be used when field is set");
    }

    #[test]
    fn field_discriminator_with_encoding() {
        let meta: Meta =
            syn::parse_quote! { discriminator(field = "account_type", encoding = "utf8") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "encoding cannot be used when field is set");
    }

    #[test]
    fn field_discriminator_with_bytes() {
        let meta: Meta = syn::parse_quote! { discriminator(field = "account_type", bytes = "ff") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "bytes cannot be used when field is set");
    }

    #[test]
    fn size_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100) };
        let directive = DiscriminatorDirective::parse(&meta).unwrap();
        assert_eq!(
            directive,
            DiscriminatorDirective {
                discriminator: SizeDiscriminatorNode::new(100).into(),
            }
        );
    }

    #[test]
    fn size_discriminator_with_too_many_sizes() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, size = 200) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "size is already set");
    }

    #[test]
    fn size_discriminator_with_offset() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, offset = 42) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "offset cannot be used when size is set");
    }

    #[test]
    fn size_discriminator_with_another_discriminator_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, bytes = [1, 2, 3]) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "bytes cannot be used when size is set");
    }

    #[test]
    fn size_discriminator_with_field() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, field = "account_type") };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "field cannot be used when size is set");
    }

    #[test]
    fn discriminator_with_unrecognized_attribute() {
        let meta: Meta = syn::parse_quote! { discriminator(size = 100, foo = 1) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(error.to_string(), "unrecognized attribute");
    }

    #[test]
    fn empty_discriminator() {
        let meta: Meta = syn::parse_quote! { discriminator() };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "discriminator must specify one of: bytes, field, size"
        );
    }

    #[test]
    fn discriminator_with_no_kind() {
        let meta: Meta = syn::parse_quote! { discriminator(encoding = "utf8", offset = 42) };
        let error = DiscriminatorDirective::parse(&meta).unwrap_err();
        assert_eq!(
            error.to_string(),
            "discriminator must specify one of: bytes, field, size"
        );
    }
}