Skip to main content

darkbio_crypto_cbor_derive/
lib.rs

1// crypto-rs: cryptography primitives and wrappers
2// Copyright 2025 Dark Bio AG. All rights reserved.
3//
4// Use of this source code is governed by a BSD-style
5// license that can be found in the LICENSE file.
6
7//! Procedural macros for darkbio-crypto.
8//!
9//! Provides the `Cbor` derive macro for structs that generates both `Encode`
10//! and `Decode` implementations.
11//!
12//! # Struct encoding modes
13//!
14//! By default, structs encode as CBOR maps with integer keys specified via
15//! `#[cbor(key = N)]`. Use `#[cbor(array)]` to encode as a CBOR array instead.
16//!
17//! # Examples
18//!
19//! Map encoding (default):
20//! ```ignore
21//! #[derive(Cbor)]
22//! struct Data {
23//!     #[cbor(key = 1)]
24//!     name: String,
25//!     #[cbor(key = -1)]
26//!     value: u64,
27//! }
28//! // Encodes as: {1: name, -1: value} (sorted by bytewise key encoding)
29//! ```
30//!
31//! Array encoding:
32//! ```ignore
33//! #[derive(Cbor)]
34//! #[cbor(array)]
35//! struct Point {
36//!     x: u64,
37//!     y: u64,
38//! }
39//! // Encodes as: [x, y]
40//! ```
41
42mod cbor;
43
44use cbor::cbor_key_bytes;
45use proc_macro::TokenStream;
46use proc_macro2::TokenStream as TokenStream2;
47use quote::quote;
48use std::collections::BTreeSet;
49use syn::{Data, DeriveInput, Expr, Fields, Lit, parse_macro_input};
50
51/// Derives the CBOR encoder and decoder for structs tagged with #[derive(Cbor)]
52/// and internally fields tagged with #[cbor(...)].
53#[proc_macro_derive(Cbor, attributes(cbor))]
54pub fn derive_cbor(input: TokenStream) -> TokenStream {
55    let input = parse_macro_input!(input as DeriveInput);
56    let encode = derive_encode(&input);
57    let decode = derive_decode(&input);
58    match (encode, decode) {
59        (Ok(enc), Ok(dec)) => quote! { #enc #dec }.into(),
60        (Err(e), _) | (_, Err(e)) => e.to_compile_error().into(),
61    }
62}
63
64/// Generates the `Encode` trait implementation for a struct.
65fn derive_encode(input: &DeriveInput) -> syn::Result<TokenStream2> {
66    let fields = parse_fields(input)?;
67    if want_array(input) {
68        derive_encode_array(&input.ident, &fields)
69    } else {
70        derive_encode_map(&input.ident, &fields)
71    }
72}
73
74/// Generates array-mode `Encode` impl: fields encoded in declaration order.
75fn derive_encode_array(name: &syn::Ident, fields: &[FieldInfo]) -> syn::Result<TokenStream2> {
76    let cbor_crate = quote! { darkbio_crypto::cbor };
77    for field in fields {
78        if field.embed {
79            return Err(syn::Error::new_spanned(
80                &field.ident,
81                "#[cbor(embed)] is not supported on #[cbor(array)] structs",
82            ));
83        }
84    }
85    let len = fields.len();
86
87    // Generate code to encode each field in declaration order
88    let encode_fields: Vec<_> = fields
89        .iter()
90        .map(|f| {
91            let ident = &f.ident;
92            quote! { self.#ident.encode_cbor_to(buf)?; }
93        })
94        .collect();
95
96    Ok(quote! {
97        impl #cbor_crate::Encode for #name {
98            fn encode_cbor_to(&self, buf: &mut Vec<u8>) -> Result<(), #cbor_crate::Error> {
99                #cbor_crate::encode_array_header_to(buf, #len);
100                #(#encode_fields)*
101                Ok(())
102            }
103        }
104    })
105}
106
107/// Generates map-mode `Encode` impl: fields encoded as key-value pairs, sorted by key bytes.
108/// Option<T> fields are omitted when None (the key-value pair is not encoded).
109/// Fields with #[cbor(embed)] are flattened: their CBOR map entries are merged into the parent.
110fn derive_encode_map(name: &syn::Ident, fields: &[FieldInfo]) -> syn::Result<TokenStream2> {
111    let cbor_crate = quote! { darkbio_crypto::cbor };
112
113    // Validate field attributes: mutual exclusivity, duplicate direct keys,
114    // and reject embed on optional/nullable types.
115    let embed_count = fields.iter().filter(|f| f.embed).count();
116    let mut direct_keys = BTreeSet::new();
117    for field in fields {
118        if field.embed && field.key.is_some() {
119            return Err(syn::Error::new_spanned(
120                &field.ident,
121                "#[cbor(embed)] and #[cbor(key)] are mutually exclusive",
122            ));
123        }
124        if field.embed && extract_nullable_inner(&field.kind).is_some() {
125            return Err(syn::Error::new_spanned(
126                &field.ident,
127                "#[cbor(embed)] cannot be nullable (Option<Option<T>>)",
128            ));
129        }
130        if !field.embed && field.key.is_none() {
131            return Err(syn::Error::new_spanned(
132                &field.ident,
133                "map struct fields require #[cbor(key = N)], or use #[cbor(array)]",
134            ));
135        }
136        if !field.embed && !direct_keys.insert(field.key.unwrap()) {
137            return Err(syn::Error::new_spanned(
138                &field.ident,
139                format!("duplicate CBOR key {}", field.key.unwrap()),
140            ));
141        }
142    }
143    if embed_count > 0 {
144        let direct: Vec<_> = fields.iter().filter(|f| !f.embed).collect();
145        let embeds: Vec<_> = fields.iter().filter(|f| f.embed).collect();
146
147        let direct_key_lits: Vec<i64> = direct.iter().map(|f| f.key.unwrap()).collect();
148        let direct_field_count = direct.len();
149
150        let schema_eval: Vec<_> = embeds
151            .iter()
152            .map(|f| {
153                let embed_ty = extract_option_inner(&f.kind).unwrap_or(&f.kind);
154                quote! {
155                    {
156                        let embed_keys = <#embed_ty as #cbor_crate::MapDecode>::cbor_map_keys();
157                        estimated_entries += embed_keys.len();
158                        for k in embed_keys.iter().copied() {
159                            if dk.contains(&k) {
160                                return Err(k);
161                            }
162                            if sek.contains(&k) {
163                                return Err(k);
164                            }
165                            sek.push(k);
166                        }
167                    }
168                }
169            })
170            .collect();
171
172        let direct_entries: Vec<_> = direct
173            .iter()
174            .map(|f| {
175                let ident = &f.ident;
176                let key = f.key.unwrap();
177                if extract_option_inner(&f.kind).is_some() {
178                    quote! {
179                        enc.push_optional(#key, &self.#ident)?;
180                    }
181                } else {
182                    quote! {
183                        enc.push(#key, &self.#ident)?;
184                    }
185                }
186            })
187            .collect();
188
189        let embed_entries: Vec<_> = embeds
190            .iter()
191            .map(|f| {
192                let ident = &f.ident;
193                if let Some(inner_ty) = extract_option_inner(&f.kind) {
194                    quote! {
195                        if let Some(ref v) = self.#ident {
196                            <#inner_ty as #cbor_crate::MapEncode>::encode_map(v, enc)?;
197                        }
198                    }
199                } else {
200                    let ty = &f.kind;
201                    quote! {
202                        <#ty as #cbor_crate::MapEncode>::encode_map(&self.#ident, enc)?;
203                    }
204                }
205            })
206            .collect();
207
208        return Ok(quote! {
209            impl #cbor_crate::Encode for #name {
210                fn encode_cbor_to(&self, buf: &mut Vec<u8>) -> Result<(), #cbor_crate::Error> {
211                    static SCHEMA: std::sync::OnceLock<Result<usize, i64>> = std::sync::OnceLock::new();
212                    let estimated_entries = match SCHEMA.get_or_init(|| {
213                        let dk: &[i64] = &[#(#direct_key_lits),*];
214                        let mut sek: Vec<i64> = Vec::new();
215                        let mut estimated_entries: usize = #direct_field_count;
216                        #(#schema_eval)*
217                        Ok(estimated_entries)
218                    }) {
219                        Ok(v) => *v,
220                        Err(k) => return Err(#cbor_crate::Error::DuplicateMapKey(*k)),
221                    };
222
223                    let mut enc = #cbor_crate::MapEncodeBuffer::new(estimated_entries);
224                    <Self as #cbor_crate::MapEncode>::encode_map(self, &mut enc)?;
225                    enc.finish_to(buf)
226                }
227            }
228
229            impl #cbor_crate::MapEncode for #name {
230                fn encode_map(&self, enc: &mut #cbor_crate::MapEncodeBuffer) -> Result<(), #cbor_crate::Error> {
231                    #(#direct_entries)*
232                    #(#embed_entries)*
233                    Ok(())
234                }
235            }
236        });
237    }
238    // No embed fields — use direct encode (existing path)
239
240    // Sort fields by CBOR-encoded key bytes for deterministic encoding
241    let mut sorted: Vec<_> = fields.iter().collect();
242    sorted.sort_by(|a, b| {
243        let ka = cbor_key_bytes(a.key.unwrap());
244        let kb = cbor_key_bytes(b.key.unwrap());
245        ka.cmp(&kb)
246    });
247
248    // Check if any field is optional (affects whether map size is static or dynamic)
249    let has_optional = sorted
250        .iter()
251        .any(|f| extract_option_inner(&f.kind).is_some());
252
253    // Generate code to count and encode fields. Optional fields are only included
254    // when they are Some, so the map header count must be computed at runtime.
255    let count_fields: Vec<_> = sorted
256        .iter()
257        .map(|f| {
258            let ident = &f.ident;
259            if extract_option_inner(&f.kind).is_some() {
260                quote! { if self.#ident.is_some() { count += 1; } }
261            } else {
262                quote! { count += 1; }
263            }
264        })
265        .collect();
266
267    let encode_fields: Vec<_> = sorted
268        .iter()
269        .map(|f| {
270            let ident = &f.ident;
271            let key = f.key.unwrap();
272            if extract_option_inner(&f.kind).is_some() {
273                quote! {
274                    if let Some(ref v) = self.#ident {
275                        #cbor_crate::encode_int_to(buf, #key);
276                        v.encode_cbor_to(buf)?;
277                    }
278                }
279            } else {
280                quote! {
281                    #cbor_crate::encode_int_to(buf, #key);
282                    self.#ident.encode_cbor_to(buf)?;
283                }
284            }
285        })
286        .collect();
287
288    let map_encode_fields: Vec<_> = sorted
289        .iter()
290        .map(|f| {
291            let ident = &f.ident;
292            let key = f.key.unwrap();
293            if extract_option_inner(&f.kind).is_some() {
294                quote! {
295                    enc.push_optional(#key, &self.#ident)?;
296                }
297            } else {
298                quote! {
299                    enc.push(#key, &self.#ident)?;
300                }
301            }
302        })
303        .collect();
304
305    // If there are no optional fields, use a static map header size to avoid
306    // generating unnecessary runtime counting code.
307    if has_optional {
308        Ok(quote! {
309            impl #cbor_crate::Encode for #name {
310                fn encode_cbor_to(&self, buf: &mut Vec<u8>) -> Result<(), #cbor_crate::Error> {
311                    let mut count: usize = 0;
312                    #(#count_fields)*
313                    #cbor_crate::encode_map_header_to(buf, count);
314                    #(#encode_fields)*
315                    Ok(())
316                }
317            }
318
319            impl #cbor_crate::MapEncode for #name {
320                fn encode_map(&self, enc: &mut #cbor_crate::MapEncodeBuffer) -> Result<(), #cbor_crate::Error> {
321                    #(#map_encode_fields)*
322                    Ok(())
323                }
324            }
325        })
326    } else {
327        let len = sorted.len();
328        Ok(quote! {
329            impl #cbor_crate::Encode for #name {
330                fn encode_cbor_to(&self, buf: &mut Vec<u8>) -> Result<(), #cbor_crate::Error> {
331                    #cbor_crate::encode_map_header_to(buf, #len);
332                    #(#encode_fields)*
333                    Ok(())
334                }
335            }
336
337            impl #cbor_crate::MapEncode for #name {
338                fn encode_map(&self, enc: &mut #cbor_crate::MapEncodeBuffer) -> Result<(), #cbor_crate::Error> {
339                    #(#map_encode_fields)*
340                    Ok(())
341                }
342            }
343        })
344    }
345}
346
347/// Generates the `Decode` trait implementation for a struct.
348fn derive_decode(input: &DeriveInput) -> syn::Result<TokenStream2> {
349    let fields = parse_fields(input)?;
350    if want_array(input) {
351        derive_decode_array(&input.ident, &fields)
352    } else {
353        derive_decode_map(&input.ident, &fields)
354    }
355}
356
357/// Generates array-mode `Decode` impl: fields decoded in declaration order.
358fn derive_decode_array(name: &syn::Ident, fields: &[FieldInfo]) -> syn::Result<TokenStream2> {
359    let cbor_crate = quote! { darkbio_crypto::cbor };
360    for field in fields {
361        if field.embed {
362            return Err(syn::Error::new_spanned(
363                &field.ident,
364                "#[cbor(embed)] is not supported on #[cbor(array)] structs",
365            ));
366        }
367    }
368    let len = fields.len();
369    let field_idents: Vec<_> = fields.iter().map(|f| &f.ident).collect();
370
371    // Generate a decode statement for each field
372    let decode_fields: Vec<_> = fields
373        .iter()
374        .map(|f| {
375            let ident = &f.ident;
376            let ty = &f.kind;
377            quote! {
378                let #ident = <#ty as #cbor_crate::Decode>::decode_cbor_notrail(dec)?;
379            }
380        })
381        .collect();
382
383    Ok(quote! {
384        impl #cbor_crate::Decode for #name {
385            fn decode_cbor(data: &[u8]) -> Result<Self, #cbor_crate::Error> {
386                let mut dec = #cbor_crate::Decoder::new(data);
387                let result = Self::decode_cbor_notrail(&mut dec)?;
388                dec.finish()?; // Ensure no trailing data
389                Ok(result)
390            }
391
392            fn decode_cbor_notrail(dec: &mut #cbor_crate::Decoder) -> Result<Self, #cbor_crate::Error> {
393                let len = dec.decode_array_header()?;
394                if len != #len as u64 {
395                    return Err(#cbor_crate::Error::UnexpectedItemCount(len, #len));
396                }
397                #(#decode_fields)*
398                Ok(Self { #(#field_idents),* })
399            }
400        }
401    })
402}
403
404/// Generates map-mode `Decode` impl: fields decoded as key-value pairs, validating key order.
405/// Option<T> fields tolerate missing keys by defaulting to None.
406/// Fields with #[cbor(embed)] consume remaining map entries after direct fields are extracted.
407fn derive_decode_map(name: &syn::Ident, fields: &[FieldInfo]) -> syn::Result<TokenStream2> {
408    let cbor_crate = quote! { darkbio_crypto::cbor };
409
410    let mut direct_keys = BTreeSet::new();
411    for field in fields {
412        if field.embed && field.key.is_some() {
413            return Err(syn::Error::new_spanned(
414                &field.ident,
415                "#[cbor(embed)] and #[cbor(key)] are mutually exclusive",
416            ));
417        }
418        if field.embed && extract_nullable_inner(&field.kind).is_some() {
419            return Err(syn::Error::new_spanned(
420                &field.ident,
421                "#[cbor(embed)] cannot be nullable (Option<Option<T>>)",
422            ));
423        }
424        if !field.embed && field.key.is_none() {
425            return Err(syn::Error::new_spanned(
426                &field.ident,
427                "map struct fields require #[cbor(key = N)], or use #[cbor(array)]",
428            ));
429        }
430        if !field.embed && !direct_keys.insert(field.key.unwrap()) {
431            return Err(syn::Error::new_spanned(
432                &field.ident,
433                format!("duplicate CBOR key {}", field.key.unwrap()),
434            ));
435        }
436    }
437
438    let direct: Vec<_> = fields.iter().filter(|f| !f.embed).collect();
439    let embeds: Vec<_> = fields.iter().filter(|f| f.embed).collect();
440    let direct_key_lits: Vec<i64> = direct.iter().map(|f| f.key.unwrap()).collect();
441
442    let map_key_pushes: Vec<_> = fields
443        .iter()
444        .map(|f| {
445            if f.embed {
446                let embed_ty = extract_option_inner(&f.kind).unwrap_or(&f.kind);
447                quote! { keys.extend_from_slice(<#embed_ty as #cbor_crate::MapDecode>::cbor_map_keys()); }
448            } else {
449                let key = f.key.unwrap();
450                quote! { keys.push(#key); }
451            }
452        })
453        .collect();
454
455    let extract_direct: Vec<_> = direct
456        .iter()
457        .map(|f| {
458            let ident = &f.ident;
459            let ty = &f.kind;
460            let key = f.key.unwrap();
461            if let Some(inner_ty) = extract_option_inner(ty) {
462                quote! {
463                    let #ident: #ty = if let Some(raw) = entries.take(#key) {
464                        Some(<#inner_ty as #cbor_crate::Decode>::decode_cbor(raw)?)
465                    } else {
466                        None
467                    };
468                }
469            } else if let Some(inner_ty) = extract_nullable_inner(ty) {
470                quote! {
471                    let raw = entries.take(#key).ok_or(#cbor_crate::Error::DecodeFailed(
472                        format!("missing required key {}", #key)
473                    ))?;
474                    let #ident: #ty = Some(<#inner_ty as #cbor_crate::Decode>::decode_cbor(raw)?);
475                }
476            } else {
477                quote! {
478                    let raw = entries.take(#key).ok_or(#cbor_crate::Error::DecodeFailed(
479                        format!("missing required key {}", #key)
480                    ))?;
481                    let #ident: #ty = <#ty as #cbor_crate::Decode>::decode_cbor(raw)?;
482                }
483            }
484        })
485        .collect();
486
487    // Schema validation checks — each embed's keys are tested for overlap
488    // with direct keys and with every earlier embed's keys. These are static
489    // properties of the type, so they are evaluated once via OnceLock.
490    let schema_checks: Vec<_> = embeds
491        .iter()
492        .map(|f| {
493            let ident = &f.ident;
494            let embed_ty = extract_option_inner(&f.kind).unwrap_or(&f.kind);
495            quote! {
496                {
497                    let ek: &[i64] = <#embed_ty as #cbor_crate::MapDecode>::cbor_map_keys();
498                    if ek.is_empty() {
499                        return Some(#cbor_crate::Error::DecodeFailed(format!(
500                            "embedded field `{}` has no CBOR map keys", stringify!(#ident)
501                        )));
502                    }
503                    for k in ek.iter().copied() {
504                        if dk.contains(&k) || seen.contains(&k) {
505                            return Some(#cbor_crate::Error::DuplicateMapKey(k));
506                        }
507                        seen.push(k);
508                    }
509                }
510            }
511        })
512        .collect();
513
514    let embed_schema_validation = if embeds.is_empty() {
515        quote! {}
516    } else {
517        quote! {
518            static __EMBED_SCHEMA: std::sync::OnceLock<Option<#cbor_crate::Error>> = std::sync::OnceLock::new();
519            if let Some(err) = __EMBED_SCHEMA.get_or_init(|| {
520                let dk: &[i64] = &[#(#direct_key_lits),*];
521                let mut seen: Vec<i64> = Vec::new();
522                #(#schema_checks)*
523                None
524            }) {
525                return Err(err.clone());
526            }
527        }
528    };
529
530    // Per-call decode blocks — overlap was already validated above.
531    let extract_embeds: Vec<_> = embeds
532        .iter()
533        .map(|f| {
534            let ident = &f.ident;
535            let ty = &f.kind;
536            let is_optional = extract_option_inner(ty).is_some();
537            let embed_ty = extract_option_inner(ty).unwrap_or(ty);
538
539            if is_optional {
540                // Optional embed: check if any of the embed's keys are present.
541                // If none → None. If any → decode and wrap in Some (decode_map
542                // validates that all required keys within the embed are present,
543                // giving us all-or-none semantics).
544                quote! {
545                    let embed_keys: &[i64] = <#embed_ty as #cbor_crate::MapDecode>::cbor_map_keys();
546                    let #ident: #ty = {
547                        let mut found = false;
548                        for k in embed_keys.iter().copied() {
549                            if entries.contains(k) {
550                                found = true;
551                                break;
552                            }
553                        }
554                        if found {
555                            let mut subset = #cbor_crate::MapEntriesScoped::new(entries, embed_keys);
556                            let value = <#embed_ty as #cbor_crate::MapDecode>::decode_map(&mut subset)?;
557                            if !#cbor_crate::MapEntryAccess::is_empty(&subset) {
558                                let unknown: Vec<i64> = #cbor_crate::MapEntryAccess::remaining_keys(&subset);
559                                return Err(#cbor_crate::Error::DecodeFailed(
560                                    format!("unknown CBOR map keys: {:?}", unknown)
561                                ));
562                            }
563                            Some(value)
564                        } else {
565                            None
566                        }
567                    };
568                }
569            } else {
570                // Mandatory embed: all fields must be present.
571                quote! {
572                    let embed_keys: &[i64] = <#embed_ty as #cbor_crate::MapDecode>::cbor_map_keys();
573                    let mut subset = #cbor_crate::MapEntriesScoped::new(entries, embed_keys);
574                    let #ident = <#embed_ty as #cbor_crate::MapDecode>::decode_map(&mut subset)?;
575                    if !#cbor_crate::MapEntryAccess::is_empty(&subset) {
576                        let unknown: Vec<i64> = #cbor_crate::MapEntryAccess::remaining_keys(&subset);
577                        return Err(#cbor_crate::Error::DecodeFailed(
578                            format!("unknown CBOR map keys: {:?}", unknown)
579                        ));
580                    }
581                }
582            }
583        })
584        .collect();
585
586    let construct_fields: Vec<_> = fields.iter().map(|f| &f.ident).collect();
587
588    let map_decode_impl = quote! {
589        impl #cbor_crate::MapDecode for #name {
590            fn cbor_map_keys() -> &'static [i64] {
591                static KEYS: std::sync::OnceLock<Vec<i64>> = std::sync::OnceLock::new();
592                KEYS.get_or_init(|| {
593                    let mut keys = Vec::new();
594                    #(#map_key_pushes)*
595                    keys.sort_by(|a, b| #cbor_crate::cbor_key_cmp(*a, *b));
596                    keys
597                }).as_slice()
598            }
599
600            fn decode_map<'a, E: #cbor_crate::MapEntryAccess<'a>>(entries: &mut E) -> Result<Self, #cbor_crate::Error> {
601                #embed_schema_validation
602
603                #(#extract_direct)*
604                #(#extract_embeds)*
605                Ok(Self { #(#construct_fields),* })
606            }
607        }
608    };
609
610    if embeds.is_empty() {
611        let mut sorted: Vec<_> = fields.iter().collect();
612        sorted.sort_by(|a, b| {
613            let ka = cbor_key_bytes(a.key.unwrap());
614            let kb = cbor_key_bytes(b.key.unwrap());
615            ka.cmp(&kb)
616        });
617
618        let len = sorted.len();
619        let has_optional = sorted
620            .iter()
621            .any(|f| extract_option_inner(&f.kind).is_some());
622
623        if !has_optional {
624            let decode_fields: Vec<_> = sorted
625                .iter()
626                .map(|f| {
627                    let ident = &f.ident;
628                    let ty = &f.kind;
629                    let key = f.key.unwrap();
630                    let decode_expr = if let Some(inner_ty) = extract_nullable_inner(ty) {
631                        quote! { Some(<#inner_ty as #cbor_crate::Decode>::decode_cbor_notrail(dec)?) }
632                    } else {
633                        quote! { <#ty as #cbor_crate::Decode>::decode_cbor_notrail(dec)? }
634                    };
635                    quote! {
636                        let key = dec.decode_int()?;
637                        if key != #key {
638                            return Err(#cbor_crate::Error::InvalidMapKeyOrder(key, #key));
639                        }
640                        let #ident = #decode_expr;
641                    }
642                })
643                .collect();
644
645            return Ok(quote! {
646                impl #cbor_crate::Decode for #name {
647                    fn decode_cbor(data: &[u8]) -> Result<Self, #cbor_crate::Error> {
648                        let mut dec = #cbor_crate::Decoder::new(data);
649                        let result = Self::decode_cbor_notrail(&mut dec)?;
650                        dec.finish()?;
651                        Ok(result)
652                    }
653
654                    fn decode_cbor_notrail(dec: &mut #cbor_crate::Decoder) -> Result<Self, #cbor_crate::Error> {
655                        let len = dec.decode_map_header()?;
656                        if len != #len as u64 {
657                            return Err(#cbor_crate::Error::UnexpectedItemCount(len, #len));
658                        }
659                        #(#decode_fields)*
660                        Ok(Self { #(#construct_fields),* })
661                    }
662                }
663
664                #map_decode_impl
665            });
666        }
667
668        let decode_fields: Vec<_> = sorted
669            .iter()
670            .map(|f| {
671                let ident = &f.ident;
672                let key = f.key.unwrap();
673                if let Some(inner_ty) = extract_option_inner(&f.kind) {
674                    quote! {
675                        let #ident = if remaining > 0 && dec.peek_int()? == #key {
676                            dec.decode_int()?;
677                            remaining -= 1;
678                            Some(<#inner_ty as #cbor_crate::Decode>::decode_cbor_notrail(dec)?)
679                        } else {
680                            None
681                        };
682                    }
683                } else {
684                    let ty = &f.kind;
685                    let decode_expr = if let Some(inner_ty) = extract_nullable_inner(ty) {
686                        quote! { Some(<#inner_ty as #cbor_crate::Decode>::decode_cbor_notrail(dec)?) }
687                    } else {
688                        quote! { <#ty as #cbor_crate::Decode>::decode_cbor_notrail(dec)? }
689                    };
690                    quote! {
691                        if remaining == 0 {
692                            return Err(#cbor_crate::Error::InvalidMapKeyOrder(0, #key));
693                        }
694                        let key = dec.decode_int()?;
695                        if key != #key {
696                            return Err(#cbor_crate::Error::InvalidMapKeyOrder(key, #key));
697                        }
698                        remaining -= 1;
699                        let #ident = #decode_expr;
700                    }
701                }
702            })
703            .collect();
704
705        return Ok(quote! {
706            impl #cbor_crate::Decode for #name {
707                fn decode_cbor(data: &[u8]) -> Result<Self, #cbor_crate::Error> {
708                    let mut dec = #cbor_crate::Decoder::new(data);
709                    let result = Self::decode_cbor_notrail(&mut dec)?;
710                    dec.finish()?;
711                    Ok(result)
712                }
713
714                fn decode_cbor_notrail(dec: &mut #cbor_crate::Decoder) -> Result<Self, #cbor_crate::Error> {
715                    let map_len = dec.decode_map_header()?;
716                    if map_len > #len as u64 {
717                        return Err(#cbor_crate::Error::UnexpectedItemCount(map_len, #len));
718                    }
719                    let mut remaining = map_len;
720                    #(#decode_fields)*
721                    if remaining != 0 {
722                        return Err(#cbor_crate::Error::UnexpectedItemCount(map_len, (map_len - remaining) as usize));
723                    }
724                    Ok(Self { #(#construct_fields),* })
725                }
726            }
727
728            #map_decode_impl
729        });
730    }
731
732    Ok(quote! {
733        impl #cbor_crate::Decode for #name {
734            fn decode_cbor(data: &[u8]) -> Result<Self, #cbor_crate::Error> {
735                let mut dec = #cbor_crate::Decoder::new(data);
736                let result = Self::decode_cbor_notrail(&mut dec)?;
737                dec.finish()?;
738                Ok(result)
739            }
740
741            fn decode_cbor_notrail(dec: &mut #cbor_crate::Decoder) -> Result<Self, #cbor_crate::Error> {
742                let entries = #cbor_crate::decode_map_entries_slices_notrail(dec)?;
743                let mut remaining = #cbor_crate::MapEntries::new(entries);
744
745                let value = <Self as #cbor_crate::MapDecode>::decode_map(&mut remaining)?;
746
747                if !remaining.is_empty() {
748                    let unknown: Vec<i64> = remaining.remaining_keys();
749                    return Err(#cbor_crate::Error::DecodeFailed(
750                        format!("unknown CBOR map keys: {:?}", unknown)
751                    ));
752                }
753                Ok(value)
754            }
755        }
756
757        #map_decode_impl
758    })
759}
760
761/// Returns true if the struct has #[cbor(array)] attribute.
762fn want_array(input: &DeriveInput) -> bool {
763    for attr in &input.attrs {
764        if attr.path().is_ident("cbor") {
765            let mut is_array = false;
766            let _ = attr.parse_nested_meta(|meta| {
767                if meta.path.is_ident("array") {
768                    is_array = true;
769                }
770                Ok(())
771            });
772            if is_array {
773                return true;
774            }
775        }
776    }
777    false
778}
779
780/// Parsed information about a struct field.
781struct FieldInfo {
782    ident: syn::Ident,
783    kind: syn::Type,
784    key: Option<i64>, // CBOR map key from #[cbor(key = N)], None for array-mode
785    embed: bool,      // #[cbor(embed)] — flatten this field's map into the parent
786}
787
788/// Extracts field metadata from a struct, including names, types, and CBOR keys.
789fn parse_fields(input: &DeriveInput) -> syn::Result<Vec<FieldInfo>> {
790    // Ensure we're only tagging structs with plain fields
791    let fields = match &input.data {
792        Data::Struct(data) => match &data.fields {
793            Fields::Named(fields) => &fields.named,
794            _ => {
795                return Err(syn::Error::new_spanned(
796                    input,
797                    "only named fields supported",
798                ));
799            }
800        },
801        _ => return Err(syn::Error::new_spanned(input, "only structs supported")),
802    };
803    // Collect all the keys from all the fields
804    let mut result = Vec::new();
805
806    for field in fields {
807        let ident = field.ident.clone().unwrap();
808        let kind = field.ty.clone();
809        let mut key = None;
810        let mut embed = false;
811
812        for attr in &field.attrs {
813            if attr.path().is_ident("cbor") {
814                attr.parse_nested_meta(|meta| {
815                    if meta.path.is_ident("key") {
816                        let value: Expr = meta.value()?.parse()?;
817                        key = Some(parse_key(&value)?);
818                    }
819                    if meta.path.is_ident("embed") {
820                        embed = true;
821                    }
822                    Ok(())
823                })?;
824            }
825        }
826        result.push(FieldInfo {
827            ident,
828            kind,
829            key,
830            embed,
831        });
832    }
833    Ok(result)
834}
835
836/// Checks if a type is `Option<T>` (but not `Option<Option<T>>`) and returns
837/// the inner type `T`. Used to identify omittable map fields.
838///
839/// Returns `None` for `Option<Option<T>>`: nested options represent a
840/// non-omittable nullable field (always present, value is null or T).
841fn extract_option_inner(ty: &syn::Type) -> Option<&syn::Type> {
842    if let syn::Type::Path(type_path) = ty {
843        let segment = type_path.path.segments.last()?;
844        if segment.ident == "Option"
845            && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
846            && args.args.len() == 1
847            && let syn::GenericArgument::Type(inner) = args.args.first()?
848        {
849            // Option<Option<T>> is NOT omittable — it's a non-omittable
850            // nullable field. See extract_nullable_inner.
851            if is_option_type(inner) {
852                return None;
853            }
854            return Some(inner);
855        }
856    }
857    None
858}
859
860/// Checks if a type is `Option<Option<T>>` and returns the inner `Option<T>`.
861/// Used to identify non-omittable nullable map fields. During decoding, the
862/// derive generates `Some(<Option<T>>::decode(...))` so that null on the wire
863/// becomes `Some(None)` rather than `None`.
864fn extract_nullable_inner(ty: &syn::Type) -> Option<&syn::Type> {
865    if let syn::Type::Path(type_path) = ty {
866        let segment = type_path.path.segments.last()?;
867        if segment.ident == "Option"
868            && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
869            && args.args.len() == 1
870            && let syn::GenericArgument::Type(inner) = args.args.first()?
871            && is_option_type(inner)
872        {
873            return Some(inner);
874        }
875    }
876    None
877}
878
879/// Returns true if the type's last path segment is `Option`.
880fn is_option_type(ty: &syn::Type) -> bool {
881    if let syn::Type::Path(type_path) = ty
882        && let Some(segment) = type_path.path.segments.last()
883    {
884        return segment.ident == "Option";
885    }
886    false
887}
888
889/// Parses an integer expression, handling both positive literals and negation.
890fn parse_key(expr: &Expr) -> syn::Result<i64> {
891    match expr {
892        // Parse positive integers
893        Expr::Lit(lit) => match &lit.lit {
894            Lit::Int(i) => i.base10_parse(),
895            _ => Err(syn::Error::new_spanned(expr, "expected integer literal")),
896        },
897        // Parse negative integers
898        Expr::Unary(unary) => {
899            if let syn::UnOp::Neg(_) = unary.op
900                && let Expr::Lit(lit) = &*unary.expr
901                && let Lit::Int(i) = &lit.lit
902            {
903                let val: i64 = i.base10_parse()?;
904                return Ok(-val);
905            }
906            Err(syn::Error::new_spanned(expr, "expected integer literal"))
907        }
908        _ => Err(syn::Error::new_spanned(expr, "expected integer literal")),
909    }
910}