Skip to main content

cbor2_derive/
lib.rs

1//! Derive support for protocol-shaped CBOR with `cbor2`.
2//!
3//! This crate provides the implementation behind `#[derive(cbor2::Cbor)]`.
4//! Users normally enable it through the `derive` feature of the `cbor2`
5//! crate:
6//!
7//! ```toml
8//! [dependencies]
//! cbor2 = { version = "1", features = ["derive"] }
//! serde = { version = "1", features = ["derive"] } # the generated code refers to `::serde`
//! serde_bytes = "0.11" # only needed for binary fields like the example below
11//! ```
12//!
13//! The derive generates `serde::Serialize` and `serde::Deserialize` impls
14//! for CBOR protocols that need integer map keys, field-order arrays and
15//! semantic tags, such as COSE (RFC 9052). Map-shaped structs can also use
16//! `#[serde(flatten)]` for extension fields beside the registered integer-key
//! subset. The derive also implements `cbor2::Cbor`, exposing the declared
//! keys, tag and array shape as the associated constants `KEYS`, `TAG` and
//! `ARRAY`. The original Rust field names stay intact for JSON and other
//! serde formats.
20//!
21//! ```ignore
22//! use cbor2::Cbor;
23//!
24//! #[derive(Debug, PartialEq, Cbor)]
25//! #[cbor(tag = 18)]
26//! struct CoseHeader {
27//!     #[cbor(key = 1)]
28//!     alg: i8,
29//!     #[cbor(key = 4)]
30//!     #[serde(with = "serde_bytes")]
31//!     kid: Vec<u8>,
32//! }
33//!
34//! assert_eq!(CoseHeader::KEYS, &[("alg", 1), ("kid", 4)]);
35//! assert_eq!(CoseHeader::TAG, Some(18));
36//! ```
37
38use core::fmt::Write as _;
39
40use proc_macro2::TokenStream;
41use quote::{format_ident, quote};
42use syn::parse::{Parse, ParseStream};
43use syn::spanned::Spanned as _;
44
// The marker prefix recognized by the `cbor2` serializers. Keep in sync
// with `cbor2::ser::STRUCT_MARKER`; the integration tests of the `cbor2`
// crate pin the resulting wire bytes.
//
// `marker` below starts every generated container name with this prefix,
// and the shadow type receives that name via `#[serde(rename = ...)]`.
// Because the marker uses `@@` as its separator, `field_entries` refuses
// key-table names containing '@'.
const MARKER: &str = "@@CBOR@@";
49
50/// Derives `serde::Serialize` and `serde::Deserialize` with CBOR protocol
51/// details: integer map keys (`#[cbor(key = <integer>)]` on fields),
52/// field-order array structs (`#[cbor(array)]` on the container) and a
53/// CBOR tag (`#[cbor(tag = <integer>)]` on the container). The tag is
54/// written on encode and transparent on decode, so input is accepted with
55/// or without it. The declared details are also exposed through an
56/// implementation of the `cbor2::Cbor` trait, so the generated code
57/// requires the `cbor2` crate under that name.
58///
59/// Do not also derive serde's `Serialize`/`Deserialize`: this macro
60/// generates both impls (the implementations would conflict).
61#[proc_macro_derive(Cbor, attributes(cbor, serde))]
62pub fn derive_cbor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
63    expand(item.into())
64        .unwrap_or_else(|err| err.to_compile_error())
65        .into()
66}
67
68fn expand(item: TokenStream) -> syn::Result<TokenStream> {
69    let input: syn::DeriveInput = syn::parse2(item)?;
70
71    if let Some(lifetime) = input
72        .generics
73        .lifetimes()
74        .find(|def| def.lifetime.ident == "de")
75    {
76        return Err(syn::Error::new(
77            lifetime.lifetime.span(),
78            "#[derive(Cbor)] cannot support a lifetime named 'de because serde's \
79             Deserialize derive reserves that name; rename the lifetime",
80        ));
81    }
82
83    let container = container_attrs(&input.attrs)?;
84    let serde = scan_serde(&input.attrs);
85    if let Some(span) = serde.rename.map(|(_, span)| span).or(serde.split_rename) {
86        return Err(syn::Error::new(
87            span,
88            "#[derive(Cbor)] does not support a container-level #[serde(rename = ...)]; \
89             rename the type itself",
90        ));
91    }
92
93    let mut entries = Vec::new();
94    let mut flatten = false;
95    match &input.data {
96        syn::Data::Struct(data) => {
97            for entry in field_entries(&data.fields)? {
98                merge_entry(&mut entries, entry)?;
99            }
100            if let Some(span) = fields_have_flatten(&data.fields) {
101                flatten = true;
102                if !matches!(data.fields, syn::Fields::Named(..)) {
103                    return Err(syn::Error::new(
104                        span,
105                        "#[serde(flatten)] with #[derive(Cbor)] requires a struct with named fields",
106                    ));
107                }
108                if let Some(array) = container.array {
109                    return Err(syn::Error::new(
110                        array,
111                        "#[serde(flatten)] cannot be used with #[cbor(array)]",
112                    ));
113                }
114            }
115
116            if let Some(span) = container.array {
117                if !matches!(data.fields, syn::Fields::Named(..)) {
118                    return Err(syn::Error::new(
119                        span,
120                        "#[cbor(array)] requires a struct with named fields",
121                    ));
122                }
123                if let Some(entry) = entries.first() {
124                    return Err(syn::Error::new(
125                        entry.span,
126                        "#[cbor(key = ...)] cannot be used with #[cbor(array)]",
127                    ));
128                }
129            }
130
131            if !entries.is_empty() {
132                if let Some(span) = serde.rename_all {
133                    return Err(syn::Error::new(
134                        span,
135                        "#[serde(rename_all = ...)] is not supported with \
136                         #[cbor(key = ...)]; rename the fields explicitly",
137                    ));
138                }
139            }
140        }
141
142        syn::Data::Enum(data) => {
143            if let Some(tag) = &container.tag {
144                return Err(syn::Error::new(
145                    tag.span,
146                    "`tag = ...` is not supported on enums",
147                ));
148            }
149            if let Some(span) = container.array {
150                return Err(syn::Error::new(span, "`array` is not supported on enums"));
151            }
152
153            for variant in &data.variants {
154                if let Some(attr) = variant.attrs.iter().find(|a| a.path().is_ident("cbor")) {
155                    return Err(syn::Error::new(
156                        attr.span(),
157                        "#[cbor(...)] is not supported on enum variants",
158                    ));
159                }
160
161                let keyed = field_entries(&variant.fields)?;
162                if let Some(span) = fields_have_flatten(&variant.fields) {
163                    return Err(syn::Error::new(
164                        span,
165                        "#[serde(flatten)] with #[derive(Cbor)] is supported only on structs",
166                    ));
167                }
168                if !keyed.is_empty() {
169                    if let Some(span) = scan_serde(&variant.attrs).rename_all {
170                        return Err(syn::Error::new(
171                            span,
172                            "#[serde(rename_all = ...)] is not supported with \
173                             #[cbor(key = ...)]; rename the fields explicitly",
174                        ));
175                    }
176                }
177                for entry in keyed {
178                    merge_entry(&mut entries, entry)?;
179                }
180            }
181
182            if !entries.is_empty() {
183                if let Some(span) = serde.rename_all_fields {
184                    return Err(syn::Error::new(
185                        span,
186                        "#[serde(rename_all_fields = ...)] is not supported with \
187                         #[cbor(key = ...)]; rename the fields explicitly",
188                    ));
189                }
190                if let Some(span) = serde.enum_repr {
191                    return Err(syn::Error::new(
192                        span,
193                        "only externally tagged enums support #[cbor(key = ...)]",
194                    ));
195                }
196            }
197        }
198
199        syn::Data::Union(data) => {
200            return Err(syn::Error::new(
201                data.union_token.span(),
202                "Cbor supports structs and enums",
203            ));
204        }
205    }
206
207    Ok(generate(
208        &input,
209        container.tag.as_ref().map(|tag| tag.value),
210        container.array.is_some(),
211        flatten,
212        &entries,
213    ))
214}
215
// Generates the serde impls: a hidden *shadow* of the item carrying the
// marker rename plus `#[serde(remote = ...)]`, and two impls delegating
// to the shadow's generated functions. The shadow accesses the real
// type's fields directly, so nothing is copied at runtime, and the real
// type's name and field names stay exactly as written.
//
// Parameters: `input` is the already-validated item. `tag`, `array` and
// `entries` are the declared CBOR details; they feed both the marker and
// the `cbor2::Cbor` constants. `flatten` selects the `cbor2::Value`-based
// impls instead of plain delegation. All generated items live inside an
// anonymous `const _: () = { ... };` block, so the helper type names
// (`__CborShadow`, `__CborShadowRef`, `__CborShadowOwned`) do not leak
// into the user's namespace.
fn generate(
    input: &syn::DeriveInput,
    tag: Option<u64>,
    array: bool,
    flatten: bool,
    entries: &[Entry],
) -> TokenStream {
    let ident = &input.ident;
    let shadow_ident = format_ident!("__CborShadow");

    // The shadow is a structural copy of the item. Only serde/cfg
    // attributes are kept (see `copied_attrs`), on the container, every
    // variant and every field.
    let mut shadow = input.clone();
    shadow.ident = shadow_ident.clone();
    shadow.attrs = copied_attrs(&input.attrs);
    match &mut shadow.data {
        syn::Data::Struct(data) => {
            for field in data.fields.iter_mut() {
                field.attrs = copied_attrs(&field.attrs);
            }
        }
        syn::Data::Enum(data) => {
            for variant in data.variants.iter_mut() {
                variant.attrs = copied_attrs(&variant.attrs);
                for field in variant.fields.iter_mut() {
                    field.attrs = copied_attrs(&field.attrs);
                }
            }
        }
        syn::Data::Union(..) => unreachable!("rejected above"),
    }

    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // The remote path: the real type, as seen from inside the const
    // block. serde applies the shadow's own generics to it, so the path
    // itself must not carry generic arguments.
    let remote = ident.to_string();

    // The derive, `remote` and (when needed) the marker rename come first.
    // The user's own serde/cfg attributes follow them.
    let mut head = vec![
        syn::parse_quote!(#[derive(::serde::Serialize, ::serde::Deserialize)]),
        syn::parse_quote!(#[serde(remote = #remote)]),
        syn::parse_quote!(#[automatically_derived]),
    ];
    if let Some(marker) = marker(tag, array, entries, ident) {
        head.push(syn::parse_quote!(#[serde(rename = #marker)]));
    }
    head.append(&mut shadow.attrs);
    shadow.attrs = head;

    // `T: Serialize` / `T: Deserialize<'de>` bounds, like serde's derive.
    let mut ser_generics = input.generics.clone();
    for param in ser_generics.type_params_mut() {
        param.bounds.push(syn::parse_quote!(::serde::Serialize));
    }
    let (ser_impl_generics, ..) = ser_generics.split_for_impl();

    // The deserializer lifetime is prepended to the user's generics and
    // bounded by every user lifetime (`'__de: 'a + 'b + ...`).
    let de_lifetime = fresh_de_lifetime(&input.generics);
    let mut de_generics = input.generics.clone();
    for param in de_generics.type_params_mut() {
        param
            .bounds
            .push(syn::parse_quote!(::serde::Deserialize<#de_lifetime>));
    }
    let mut de_lifetime_param = syn::LifetimeParam::new(de_lifetime.clone());
    de_lifetime_param
        .bounds
        .extend(input.generics.lifetimes().map(|def| def.lifetime.clone()));
    de_generics
        .params
        .insert(0, syn::GenericParam::Lifetime(de_lifetime_param));
    let (de_impl_generics, ..) = de_generics.split_for_impl();

    let serde_impls = if flatten {
        // Flattened structs in non-human-readable formats go through
        // `cbor2::Value`. On encode, the shadow is serialized into a Value
        // that is passed, together with the declared TAG and KEYS, to
        // `cbor2::__private::__cbor2_flatten_serialize`. On decode, the
        // reverse path uses `__cbor2_flatten_deserialize` with KEYS.
        // NOTE(review): what those helpers do to the Value is defined in
        // `cbor2`, not here — confirm there. Human-readable formats use the
        // shadow directly.
        let cbor_lifetime = fresh_lifetime(&input.generics, "__cbor");

        // `__CborShadowRef<'__cbor, ...>` borrows the real value, so it can
        // be handed to `Value::serialized` as a `Serialize` type that
        // delegates to the shadow.
        let mut shadow_ref_generics = input.generics.clone();
        shadow_ref_generics.params.insert(
            0,
            syn::GenericParam::Lifetime(syn::LifetimeParam::new(cbor_lifetime.clone())),
        );
        let (shadow_ref_impl_generics, _shadow_ref_ty_generics, shadow_ref_where_clause) =
            shadow_ref_generics.split_for_impl();

        let mut ser_ref_generics = ser_generics.clone();
        ser_ref_generics.params.insert(
            0,
            syn::GenericParam::Lifetime(syn::LifetimeParam::new(cbor_lifetime.clone())),
        );
        let (ser_ref_impl_generics, ser_ref_ty_generics, ser_ref_where_clause) =
            ser_ref_generics.split_for_impl();

        // `__CborShadowOwned` wraps the real type, so the shadow's
        // `deserialize` is reachable through the `Deserialize` trait.
        quote! {
            struct __CborShadowRef #shadow_ref_impl_generics #shadow_ref_where_clause {
                value: &#cbor_lifetime #ident #ty_generics,
            }

            impl #ser_ref_impl_generics ::serde::Serialize for __CborShadowRef #ser_ref_ty_generics #ser_ref_where_clause {
                fn serialize<__S>(&self, serializer: __S) -> ::core::result::Result<__S::Ok, __S::Error>
                where
                    __S: ::serde::Serializer,
                {
                    #shadow_ident::serialize(self.value, serializer)
                }
            }

            struct __CborShadowOwned #impl_generics (#ident #ty_generics) #where_clause;

            impl #de_impl_generics ::serde::Deserialize<#de_lifetime> for __CborShadowOwned #ty_generics #where_clause {
                fn deserialize<__D>(deserializer: __D) -> ::core::result::Result<Self, __D::Error>
                where
                    __D: ::serde::Deserializer<#de_lifetime>,
                {
                    #shadow_ident::deserialize(deserializer).map(Self)
                }
            }

            #[automatically_derived]
            impl #ser_impl_generics ::serde::Serialize for #ident #ty_generics #where_clause {
                fn serialize<__S>(&self, serializer: __S) -> ::core::result::Result<__S::Ok, __S::Error>
                where
                    __S: ::serde::Serializer,
                {
                    if serializer.is_human_readable() {
                        return #shadow_ident::serialize(self, serializer);
                    }

                    let __value = ::cbor2::Value::serialized(&__CborShadowRef { value: self })
                        .map_err(::serde::ser::Error::custom)?;
                    let __value = ::cbor2::__private::__cbor2_flatten_serialize(
                        __value,
                        <#ident #ty_generics as ::cbor2::Cbor>::TAG,
                        <#ident #ty_generics as ::cbor2::Cbor>::KEYS,
                    )
                    .map_err(::serde::ser::Error::custom)?;
                    ::serde::Serialize::serialize(&__value, serializer)
                }
            }

            #[automatically_derived]
            impl #de_impl_generics ::serde::Deserialize<#de_lifetime> for #ident #ty_generics #where_clause {
                fn deserialize<__D>(deserializer: __D) -> ::core::result::Result<Self, __D::Error>
                where
                    __D: ::serde::Deserializer<#de_lifetime>,
                {
                    if deserializer.is_human_readable() {
                        return #shadow_ident::deserialize(deserializer);
                    }

                    let __value: ::cbor2::Value =
                        ::serde::Deserialize::deserialize(deserializer)?;
                    let __value = ::cbor2::__private::__cbor2_flatten_deserialize(
                        __value,
                        <#ident #ty_generics as ::cbor2::Cbor>::KEYS,
                    )
                    .map_err(::serde::de::Error::custom)?;
                    let __value: __CborShadowOwned #ty_generics =
                        ::cbor2::__private::__cbor2_flatten_deserialize_value(&__value)
                        .map_err(::serde::de::Error::custom)?;
                    ::core::result::Result::Ok(__value.0)
                }
            }
        }
    } else {
        // Without flatten, both impls forward straight to the shadow's
        // remote functions.
        quote! {
            #[automatically_derived]
            impl #ser_impl_generics ::serde::Serialize for #ident #ty_generics #where_clause {
                fn serialize<__S>(&self, serializer: __S) -> ::core::result::Result<__S::Ok, __S::Error>
                where
                    __S: ::serde::Serializer,
                {
                    #shadow_ident::serialize(self, serializer)
                }
            }

            #[automatically_derived]
            impl #de_impl_generics ::serde::Deserialize<#de_lifetime> for #ident #ty_generics #where_clause {
                fn deserialize<__D>(deserializer: __D) -> ::core::result::Result<Self, __D::Error>
                where
                    __D: ::serde::Deserializer<#de_lifetime>,
                {
                    #shadow_ident::deserialize(deserializer)
                }
            }
        }
    };

    // The `cbor2::Cbor` trait exposes the declared protocol details.
    // Keys render as `("name", <n>i128)` and the tag as `Some(<n>u64)`.
    let key_pairs = entries.iter().map(|entry| {
        let name = &entry.name;
        let key = entry.key;
        quote!((#name, #key))
    });
    let tag_const = match tag {
        Some(tag) => quote!(::core::option::Option::Some(#tag)),
        None => quote!(::core::option::Option::None),
    };
    let array_const = array;

    quote! {
        #[doc(hidden)]
        const _: () = {
            #shadow

            #serde_impls

            #[automatically_derived]
            impl #impl_generics ::cbor2::Cbor for #ident #ty_generics #where_clause {
                const KEYS: &'static [(&'static str, i128)] = &[#(#key_pairs),*];
                const TAG: ::core::option::Option<u64> = #tag_const;
                const ARRAY: bool = #array_const;
            }
        };
    }
}
434
435// Picks an internal deserializer lifetime that cannot collide with the
436// user's generics. User code may legitimately name a lifetime `'de`.
437fn fresh_de_lifetime(generics: &syn::Generics) -> syn::Lifetime {
438    fresh_lifetime(generics, "__de")
439}
440
441fn fresh_lifetime(generics: &syn::Generics, base: &str) -> syn::Lifetime {
442    let mut name = String::from(base);
443    while generics.lifetimes().any(|def| def.lifetime.ident == name) {
444        name.push('_');
445    }
446
447    syn::Lifetime::new(&format!("'{name}"), proc_macro2::Span::call_site())
448}
449
450// The attributes that carry over to the shadow: serde configuration and
451// conditional compilation. Everything else — docs, derives, `#[cbor]` —
452// stays behind.
453fn copied_attrs(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
454    attrs
455        .iter()
456        .filter(|attr| {
457            let path = attr.path();
458            path.is_ident("serde") || path.is_ident("cfg") || path.is_ident("cfg_attr")
459        })
460        .cloned()
461        .collect()
462}
463
464// The `@@CBOR@@<tag>@@<keys>@@<name>` container marker, when the item
465// declares a tag, array shape or integer keys.
466fn marker(tag: Option<u64>, array: bool, entries: &[Entry], ident: &syn::Ident) -> Option<String> {
467    if tag.is_none() && entries.is_empty() && !array {
468        return None;
469    }
470
471    let mut marker = String::from(MARKER);
472    if let Some(tag) = tag {
473        let _ = write!(&mut marker, "{tag}");
474    }
475    marker.push_str("@@");
476    for (i, entry) in entries.iter().enumerate() {
477        if i > 0 {
478            marker.push(';');
479        }
480        let _ = write!(&mut marker, "{}={}", entry.name, entry.key);
481    }
482    marker.push_str("@@");
483    if array {
484        marker.push_str("array@@");
485    }
486    let name = ident.to_string();
487    marker.push_str(name.strip_prefix("r#").unwrap_or(&name));
488
489    Some(marker)
490}
491
492// `tag = <integer>` inside the container's `#[cbor(...)]`.
493struct TagArg {
494    value: u64,
495    span: proc_macro2::Span,
496}
497
498// `key = <integer>` inside a field's `#[cbor(...)]`.
499struct KeyArg {
500    value: i128,
501    span: proc_macro2::Span,
502}
503
504impl Parse for KeyArg {
505    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
506        let name: syn::Ident = input.parse()?;
507        if name != "key" {
508            return Err(syn::Error::new(name.span(), "expected `key = <integer>`"));
509        }
510        input.parse::<syn::Token![=]>()?;
511
512        let negative = input.peek(syn::Token![-]);
513        if negative {
514            input.parse::<syn::Token![-]>()?;
515        }
516        let literal: syn::LitInt = input.parse()?;
517
518        let magnitude: i128 = literal.base10_parse()?;
519        let value = if negative { -magnitude } else { magnitude };
520
521        Ok(KeyArg {
522            value,
523            span: literal.span(),
524        })
525    }
526}
527
528struct ContainerAttrs {
529    tag: Option<TagArg>,
530    array: Option<proc_macro2::Span>,
531}
532
533// Reads the container-level `#[cbor(tag = ..., array)]` attribute.
534fn container_attrs(attrs: &[syn::Attribute]) -> syn::Result<ContainerAttrs> {
535    let mut out = ContainerAttrs {
536        tag: None,
537        array: None,
538    };
539
540    for attr in attrs {
541        if !attr.path().is_ident("cbor") {
542            continue;
543        }
544
545        attr.parse_nested_meta(|meta| {
546            if meta.path.is_ident("tag") {
547                const RANGE: &str = "tag must fit a CBOR tag (0 ..= 2^64 - 1)";
548                let value = meta.value()?;
549                if value.peek(syn::Token![-]) {
550                    return Err(syn::Error::new(value.span(), RANGE));
551                }
552                let literal: syn::LitInt = value.parse()?;
553                let tag = TagArg {
554                    value: literal
555                        .base10_parse()
556                        .map_err(|_| syn::Error::new(literal.span(), RANGE))?,
557                    span: literal.span(),
558                };
559                if out.tag.replace(tag).is_some() {
560                    return Err(syn::Error::new(
561                        meta.path.span(),
562                        "duplicate #[cbor(tag = ...)] attribute",
563                    ));
564                }
565                Ok(())
566            } else if meta.path.is_ident("array") {
567                if meta.input.peek(syn::Token![=]) {
568                    return Err(syn::Error::new(
569                        meta.path.span(),
570                        "expected `array` without a value",
571                    ));
572                }
573                if out.array.replace(meta.path.span()).is_some() {
574                    return Err(syn::Error::new(
575                        meta.path.span(),
576                        "duplicate #[cbor(array)] attribute",
577                    ));
578                }
579                Ok(())
580            } else {
581                Err(syn::Error::new(
582                    meta.path.span(),
583                    "expected `tag = <integer>` or `array`",
584                ))
585            }
586        })?;
587    }
588
589    Ok(out)
590}
591
592// One `<name>=<key>` entry of the marker's key table.
593struct Entry {
594    name: String,
595    key: i128,
596    span: proc_macro2::Span,
597}
598
599// Adds an entry, rejecting ambiguous mappings. Identical mappings merge,
600// so enum variants may share a field.
601fn merge_entry(entries: &mut Vec<Entry>, entry: Entry) -> syn::Result<()> {
602    match entries
603        .iter()
604        .find(|e| e.name == entry.name || e.key == entry.key)
605    {
606        Some(e) if e.name == entry.name && e.key == entry.key => Ok(()),
607        Some(e) if e.name == entry.name => Err(syn::Error::new(
608            entry.span,
609            format!(
610                "field `{}` maps to conflicting keys {} and {}",
611                entry.name, e.key, entry.key
612            ),
613        )),
614        Some(e) => Err(syn::Error::new(
615            entry.span,
616            format!("key {} is already mapped to field `{}`", entry.key, e.name),
617        )),
618        None => {
619            entries.push(entry);
620            Ok(())
621        }
622    }
623}
624
625// Reads the `#[cbor(key = ...)]` field attributes into key table entries
626// under the fields' serde names.
627fn field_entries(fields: &syn::Fields) -> syn::Result<Vec<Entry>> {
628    let mut entries = Vec::new();
629
630    for field in fields {
631        let mut key: Option<KeyArg> = None;
632        for attr in &field.attrs {
633            if !attr.path().is_ident("cbor") {
634                continue;
635            }
636
637            let arg: KeyArg = attr.parse_args()?;
638            if key.replace(arg).is_some() {
639                return Err(syn::Error::new(
640                    attr.span(),
641                    "duplicate #[cbor(key = ...)] attribute",
642                ));
643            }
644        }
645        let serde = scan_serde(&field.attrs);
646        if let (Some(..), Some(span)) = (&key, serde.flatten) {
647            return Err(syn::Error::new(
648                span,
649                "#[serde(flatten)] cannot be combined with #[cbor(key = ...)]",
650            ));
651        }
652
653        let Some(key) = key else { continue };
654
655        if field.ident.is_none() {
656            return Err(syn::Error::new(
657                key.span,
658                "#[cbor(key = ...)] requires a named field",
659            ));
660        }
661
662        // CBOR integer keys span major types 0 and 1.
663        if key.value > u64::MAX as i128 || key.value < -(u64::MAX as i128) - 1 {
664            return Err(syn::Error::new(
665                key.span,
666                "#[cbor(key = ...)] must fit a CBOR integer (-2^64 ..= 2^64 - 1)",
667            ));
668        }
669
670        if let Some(span) = serde.split_rename {
671            return Err(syn::Error::new(
672                span,
673                "split serialize/deserialize renames are not supported with \
674                 #[cbor(key = ...)]",
675            ));
676        }
677
678        // The key table is consulted with the field's *serde* name, so an
679        // explicit rename carries over.
680        let name = match serde.rename {
681            Some((name, _)) => name,
682            None => {
683                let ident = field.ident.as_ref().expect("checked above").to_string();
684                ident.strip_prefix("r#").unwrap_or(&ident).to_string()
685            }
686        };
687
688        if name.is_empty() || name.contains(['@', ';', '=']) {
689            return Err(syn::Error::new(
690                key.span,
691                "the serde name of a keyed field may not be empty or contain '@', ';' or '='",
692            ));
693        }
694
695        entries.push(Entry {
696            name,
697            key: key.value,
698            span: key.span,
699        });
700    }
701
702    Ok(entries)
703}
704
705fn fields_have_flatten(fields: &syn::Fields) -> Option<proc_macro2::Span> {
706    fields
707        .iter()
708        .find_map(|field| scan_serde(&field.attrs).flatten)
709}
710
711// The serde attribute metas the marker must coordinate with.
712#[derive(Default)]
713struct SerdeAttrs {
714    rename: Option<(String, proc_macro2::Span)>,
715    split_rename: Option<proc_macro2::Span>,
716    rename_all: Option<proc_macro2::Span>,
717    rename_all_fields: Option<proc_macro2::Span>,
718    enum_repr: Option<proc_macro2::Span>,
719    flatten: Option<proc_macro2::Span>,
720}
721
722// Scans `#[serde(...)]` attributes, tolerating any meta shapes we do not
723// understand — the serde derive validates them later anyway.
724fn scan_serde(attrs: &[syn::Attribute]) -> SerdeAttrs {
725    let mut out = SerdeAttrs::default();
726
727    for attr in attrs {
728        if !attr.path().is_ident("serde") {
729            continue;
730        }
731
732        let _ = attr.parse_nested_meta(|meta| {
733            if meta.path.is_ident("rename") {
734                if meta.input.peek(syn::Token![=]) {
735                    let expr: syn::Expr = meta.value()?.parse()?;
736                    if let syn::Expr::Lit(syn::ExprLit {
737                        lit: syn::Lit::Str(s),
738                        ..
739                    }) = expr
740                    {
741                        out.rename = Some((s.value(), meta.path.span()));
742                    }
743                    return Ok(());
744                }
745                out.split_rename = Some(meta.path.span());
746            } else if meta.path.is_ident("rename_all") {
747                out.rename_all = Some(meta.path.span());
748            } else if meta.path.is_ident("rename_all_fields") {
749                out.rename_all_fields = Some(meta.path.span());
750            } else if meta.path.is_ident("flatten") {
751                out.flatten = Some(meta.path.span());
752            } else if meta.path.is_ident("tag")
753                || meta.path.is_ident("untagged")
754                || meta.path.is_ident("content")
755            {
756                out.enum_repr = Some(meta.path.span());
757            }
758
759            if meta.input.peek(syn::token::Paren) {
760                let content;
761                syn::parenthesized!(content in meta.input);
762                let _: TokenStream = content.parse()?;
763            } else if !meta.input.is_empty() && !meta.input.peek(syn::Token![,]) {
764                let _: syn::Expr = meta.value()?.parse()?;
765            }
766
767            Ok(())
768        });
769    }
770
771    out
772}
773
774#[cfg(test)]
775mod tests {
776    use quote::quote;
777
778    use super::*;
779
    // Expands `item` and renders the generated tokens with `TokenStream`'s
    // `to_string` spacing (e.g. `:: serde :: Serialize`), which the test
    // assertions match against. Panics if expansion fails.
    fn expanded(item: TokenStream) -> String {
        expand(item).unwrap().to_string()
    }

    // Expands `item`, expecting a validation error, and returns its message.
    // Panics if expansion succeeds.
    fn error(item: TokenStream) -> String {
        expand(item).unwrap_err().to_string()
    }
787
    // A tagged struct with integer keys: the shadow gets the marker rename
    // and `remote`, field serde attributes survive, no `#[cbor]` attribute
    // appears in the output, and the `cbor2::Cbor` impl carries the declared
    // key table and tag.
    #[test]
    fn generates_a_marked_remote_shadow() {
        let out = expanded(quote! {
            #[cbor(tag = 123)]
            struct ProtectedHeader {
                #[cbor(key = 1)]
                alg: i8,
                #[cbor(key = 4)]
                #[serde(with = "serde_bytes")]
                kid: Vec<u8>,
                plain: bool,
            }
        });

        assert!(
            out.contains(r#"rename = "@@CBOR@@123@@alg=1;kid=4@@ProtectedHeader""#),
            "{out}"
        );
        assert!(out.contains(r#"remote = "ProtectedHeader""#), "{out}");
        assert!(out.contains(r#"with = "serde_bytes""#), "{out}");
        assert!(
            out.contains("impl :: serde :: Serialize for ProtectedHeader"),
            "{out}"
        );
        assert!(
            out.contains("impl < '__de > :: serde :: Deserialize < '__de > for ProtectedHeader"),
            "{out}"
        );
        // The #[cbor(...)] attributes stay off the shadow.
        assert!(!out.contains("# [cbor"), "{out}");

        // The declared details surface through the cbor2::Cbor trait.
        assert!(
            out.contains("impl :: cbor2 :: Cbor for ProtectedHeader"),
            "{out}"
        );
        assert!(
            out.contains(r#"const KEYS : & 'static [(& 'static str , i128)] = & [("alg" , 1i128) , ("kid" , 4i128)] ;"#),
            "{out}"
        );
        assert!(
            out.contains(":: core :: option :: Option :: Some (123u64)"),
            "{out}"
        );
    }
833
834    #[test]
835    fn generates_plain_serde_impls_without_cbor_attributes() {
836        let out = expanded(quote! {
837            struct Plain {
838                a: u8,
839            }
840        });
841
842        assert!(!out.contains("@@CBOR@@"), "{out}");
843        assert!(out.contains(r#"remote = "Plain""#), "{out}");
844        assert!(
845            out.contains("impl :: serde :: Serialize for Plain"),
846            "{out}"
847        );
848
849        // The trait impl is still generated, with an empty table.
850        assert!(
851            out.contains(r#"const KEYS : & 'static [(& 'static str , i128)] = & [] ;"#),
852            "{out}"
853        );
854        assert!(out.contains(":: core :: option :: Option :: None"), "{out}");
855    }
856
857    #[test]
858    fn uses_the_serde_rename_as_the_key_table_name() {
859        let out = expanded(quote! {
860            struct S {
861                #[cbor(key = 1)]
862                #[serde(rename = "alg", default)]
863                algorithm: i8,
864            }
865        });
866
867        assert!(out.contains(r#"rename = "@@CBOR@@@@alg=1@@S""#), "{out}");
868        assert!(out.contains(r#"rename = "alg""#), "{out}");
869        assert!(out.contains("default"), "{out}");
870    }
871
872    #[test]
873    fn supports_field_order_array_structs() {
874        let out = expanded(quote! {
875            #[cbor(tag = 18, array)]
876            struct Sign1 {
877                protected: Vec<u8>,
878                unprotected: u8,
879                payload: Vec<u8>,
880                signature: Vec<u8>,
881            }
882        });
883
884        assert!(
885            out.contains(r#"rename = "@@CBOR@@18@@@@array@@Sign1""#),
886            "{out}"
887        );
888        assert!(out.contains("const ARRAY : bool = true"), "{out}");
889        assert!(out.contains("impl :: cbor2 :: Cbor for Sign1"), "{out}");
890    }
891
892    #[test]
893    fn supports_flattened_map_structs() {
894        let out = expanded(quote! {
895            #[cbor(tag = 61)]
896            struct Claims {
897                #[cbor(key = 1)]
898                #[serde(rename = "iss")]
899                issuer: String,
900                #[serde(flatten)]
901                extra: BTreeMap<String, cbor2::Value>,
902            }
903        });
904
905        assert!(out.contains("__cbor2_flatten_serialize"), "{out}");
906        assert!(out.contains("__cbor2_flatten_deserialize"), "{out}");
907        assert!(
908            out.contains(r#"rename = "@@CBOR@@61@@iss=1@@Claims""#),
909            "{out}"
910        );
911        assert!(
912            out.contains(
913                r#"const KEYS : & 'static [(& 'static str , i128)] = & [("iss" , 1i128)] ;"#
914            ),
915            "{out}"
916        );
917    }
918
919    #[test]
920    fn strips_raw_identifier_prefixes() {
921        let out = expanded(quote! {
922            struct S {
923                #[cbor(key = 1)]
924                r#type: u8,
925            }
926        });
927
928        assert!(out.contains(r#"rename = "@@CBOR@@@@type=1@@S""#), "{out}");
929    }
930
931    #[test]
932    fn merges_enum_variant_fields() {
933        let out = expanded(quote! {
934            enum Message {
935                Signed {
936                    #[cbor(key = 1)]
937                    payload: u8,
938                },
939                Verified {
940                    #[cbor(key = 1)]
941                    payload: u8,
942                    #[cbor(key = 2)]
943                    peer: u8,
944                },
945                Unit,
946            }
947        });
948
949        assert!(
950            out.contains(r#"rename = "@@CBOR@@@@payload=1;peer=2@@Message""#),
951            "{out}"
952        );
953    }
954
955    #[test]
956    fn keeps_generics_and_their_bounds() {
957        let out = expanded(quote! {
958            #[cbor(tag = 7)]
959            struct Wrap<T: Clone> {
960                #[cbor(key = 1)]
961                inner: T,
962            }
963        });
964
965        assert!(out.contains(r#"remote = "Wrap""#), "{out}");
966        assert!(
967            out.contains(
968                "impl < T : Clone + :: serde :: Serialize > :: serde :: Serialize for Wrap < T >"
969            ),
970            "{out}"
971        );
972        assert!(
973            out.contains("impl < '__de , T : Clone + :: serde :: Deserialize < '__de > >"),
974            "{out}"
975        );
976        // The trait impl carries the original generics, without serde bounds.
977        assert!(
978            out.contains("impl < T : Clone > :: cbor2 :: Cbor for Wrap < T >"),
979            "{out}"
980        );
981    }
982
983    #[test]
984    fn avoids_deserialize_lifetime_collisions() {
985        let out = expanded(quote! {
986            struct Borrowed<'a, '__de> {
987                #[cbor(key = 1)]
988                value: &'a str,
989                other: &'__de str,
990            }
991        });
992
993        assert!(
994            out.contains(
995                "impl < '__de_ : 'a + '__de , 'a , '__de > :: serde :: Deserialize < '__de_ > for Borrowed < 'a , '__de >"
996            ),
997            "{out}"
998        );
999    }
1000
1001    #[test]
1002    fn rejects_user_lifetime_named_de() {
1003        let msg = error(quote! {
1004            struct Borrowed<'de> {
1005                value: &'de str,
1006            }
1007        });
1008
1009        assert!(msg.contains("lifetime named 'de"), "{msg}");
1010    }
1011
1012    #[test]
1013    fn accepts_the_full_integer_ranges() {
1014        let out = expanded(quote! {
1015            #[cbor(tag = 18446744073709551615)]
1016            struct Edges {
1017                #[cbor(key = 0)]
1018                zero: u8,
1019                #[cbor(key = 18446744073709551615)]
1020                hi: u8,
1021                #[cbor(key = -18446744073709551616)]
1022                lo: u8,
1023            }
1024        });
1025
1026        assert!(
1027            out.contains(
1028                r#"rename = "@@CBOR@@18446744073709551615@@zero=0;hi=18446744073709551615;lo=-18446744073709551616@@Edges""#
1029            ),
1030            "{out}"
1031        );
1032    }
1033
1034    #[test]
1035    fn rejects_invalid_uses() {
1036        let msg = error(quote! {
1037            struct S {
1038                #[cbor(key = 18446744073709551616)]
1039                a: u8,
1040            }
1041        });
1042        assert!(msg.contains("must fit a CBOR integer"), "{msg}");
1043
1044        let msg = error(quote! {
1045            #[cbor(tag = 18446744073709551616)]
1046            struct S;
1047        });
1048        assert!(msg.contains("must fit a CBOR tag"), "{msg}");
1049
1050        let msg = error(quote! {
1051            #[cbor(tag = -1)]
1052            struct S;
1053        });
1054        assert!(msg.contains("must fit a CBOR tag"), "{msg}");
1055
1056        let msg = error(quote! {
1057            #[cbor(tag = 1)]
1058            #[cbor(tag = 2)]
1059            struct S;
1060        });
1061        assert!(msg.contains("duplicate #[cbor(tag = ...)]"), "{msg}");
1062
1063        let msg = error(quote! {
1064            #[cbor(tag = 1)]
1065            enum E { A }
1066        });
1067        assert!(msg.contains("not supported on enums"), "{msg}");
1068
1069        let msg = error(quote! {
1070            #[cbor(array)]
1071            enum E { A }
1072        });
1073        assert!(msg.contains("`array` is not supported on enums"), "{msg}");
1074
1075        let msg = error(quote! {
1076            #[cbor(array)]
1077            struct S(u8);
1078        });
1079        assert!(msg.contains("requires a struct with named fields"), "{msg}");
1080
1081        let msg = error(quote! {
1082            #[cbor(array)]
1083            struct S {
1084                #[cbor(key = 1)]
1085                a: u8,
1086            }
1087        });
1088        assert!(msg.contains("cannot be used with #[cbor(array)]"), "{msg}");
1089
1090        let msg = error(quote! {
1091            #[cbor(array)]
1092            struct S {
1093                a: u8,
1094                #[serde(flatten)]
1095                extra: BTreeMap<String, u8>,
1096            }
1097        });
1098        assert!(msg.contains("cannot be used with #[cbor(array)]"), "{msg}");
1099
1100        let msg = error(quote! {
1101            struct S {
1102                #[cbor(key = 1)]
1103                #[cbor(key = 2)]
1104                a: u8,
1105            }
1106        });
1107        assert!(msg.contains("duplicate #[cbor(key = ...)]"), "{msg}");
1108
1109        let msg = error(quote! {
1110            struct S(#[cbor(key = 1)] u8);
1111        });
1112        assert!(msg.contains("named field"), "{msg}");
1113
1114        let msg = error(quote! {
1115            struct S {
1116                #[cbor(name = 1)]
1117                a: u8,
1118            }
1119        });
1120        assert!(msg.contains("expected `key = <integer>`"), "{msg}");
1121
1122        let msg = error(quote! {
1123            struct S {
1124                #[cbor(key = 1)]
1125                a: u8,
1126                #[cbor(key = 9)]
1127                #[serde(flatten)]
1128                extra: BTreeMap<String, u8>,
1129            }
1130        });
1131        assert!(msg.contains("cannot be combined with #[cbor(key"), "{msg}");
1132
1133        let msg = error(quote! {
1134            enum E {
1135                A {
1136                    #[serde(flatten)]
1137                    extra: BTreeMap<String, u8>,
1138                },
1139            }
1140        });
1141        assert!(msg.contains("supported only on structs"), "{msg}");
1142
1143        let msg = error(quote! {
1144            #[cbor(key = 1)]
1145            struct S {
1146                a: u8,
1147            }
1148        });
1149        assert!(msg.contains("expected `tag = <integer>`"), "{msg}");
1150
1151        let msg = error(quote! {
1152            union U { a: u8 }
1153        });
1154        assert!(msg.contains("supports structs and enums"), "{msg}");
1155
1156        let msg = error(quote! {
1157            #[cbor(tag = 1, foo)]
1158            struct S {
1159                a: u8,
1160            }
1161        });
1162        assert!(
1163            msg.contains("expected `tag = <integer>` or `array`"),
1164            "{msg}"
1165        );
1166    }
1167
1168    #[test]
1169    fn rejects_serde_conflicts() {
1170        let msg = error(quote! {
1171            #[serde(rename = "Other")]
1172            struct S {
1173                #[cbor(key = 1)]
1174                a: u8,
1175            }
1176        });
1177        assert!(msg.contains("container-level #[serde(rename"), "{msg}");
1178
1179        let msg = error(quote! {
1180            #[serde(rename_all = "camelCase")]
1181            struct S {
1182                #[cbor(key = 1)]
1183                a_b: u8,
1184            }
1185        });
1186        assert!(msg.contains("rename_all"), "{msg}");
1187
1188        let msg = error(quote! {
1189            struct S {
1190                #[cbor(key = 1)]
1191                #[serde(rename(serialize = "x", deserialize = "y"))]
1192                a: u8,
1193            }
1194        });
1195        assert!(msg.contains("split serialize/deserialize renames"), "{msg}");
1196
1197        let msg = error(quote! {
1198            #[serde(tag = "type")]
1199            enum E {
1200                A {
1201                    #[cbor(key = 1)]
1202                    a: u8,
1203                },
1204            }
1205        });
1206        assert!(msg.contains("externally tagged"), "{msg}");
1207
1208        let msg = error(quote! {
1209            struct S {
1210                #[cbor(key = 1)]
1211                a: u8,
1212                #[cbor(key = 1)]
1213                b: u8,
1214            }
1215        });
1216        assert!(msg.contains("already mapped"), "{msg}");
1217
1218        let msg = error(quote! {
1219            enum E {
1220                A {
1221                    #[cbor(key = 1)]
1222                    x: u8,
1223                },
1224                B {
1225                    #[cbor(key = 2)]
1226                    x: u8,
1227                },
1228            }
1229        });
1230        assert!(msg.contains("conflicting keys"), "{msg}");
1231
1232        let msg = error(quote! {
1233            enum E {
1234                #[cbor(tag = 1)]
1235                A,
1236            }
1237        });
1238        assert!(msg.contains("not supported on enum variants"), "{msg}");
1239
1240        // A rename whose value would corrupt the marker grammar.
1241        let msg = error(quote! {
1242            struct S {
1243                #[cbor(key = 1)]
1244                #[serde(rename = "a=b")]
1245                a: u8,
1246            }
1247        });
1248        assert!(msg.contains("may not be empty or contain"), "{msg}");
1249    }
1250}