Skip to main content

vercode_macros/
lib.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input};
5
6/// Helper struct for field information during code generation
7#[derive(Clone)]
8struct FieldInfo<'a> {
9    index: usize,
10    version: u32,
11    ty: &'a syn::Type,
12    ident: Option<&'a syn::Ident>,
13}
14
15impl<'a> FieldInfo<'a> {
16    fn temp_var(&self) -> syn::Ident {
17        syn::Ident::new(
18            &format!("__field{}", self.index),
19            proc_macro2::Span::call_site(),
20        )
21    }
22}
23
24/// Batches of fields grouped by version number
25struct VersionBatch<'a> {
26    version: u32,
27    fields: Vec<FieldInfo<'a>>,
28}
29
30fn parse_version_attribute(attrs: &[syn::Attribute]) -> u32 {
31    for attr in attrs {
32        if let Meta::List(list) = &attr.meta
33            && list.path.is_ident("version")
34        {
35            let ts = list.tokens.to_string();
36            let digits: String = ts.chars().filter(|c| c.is_ascii_digit()).collect();
37            if let Ok(v) = digits.parse::<u32>() {
38                return v;
39            }
40        }
41    }
42    0
43}
44
45fn field_version(field: &syn::Field) -> u32 {
46    parse_version_attribute(&field.attrs)
47}
48
49fn variant_version(variant: &syn::Variant) -> u32 {
50    parse_version_attribute(&variant.attrs)
51}
52
53/// Extract and organize field information from Fields
54fn extract_field_info(fields: &Fields) -> Vec<FieldInfo<'_>> {
55    match fields {
56        Fields::Named(named) => named
57            .named
58            .iter()
59            .enumerate()
60            .map(|(i, f)| FieldInfo {
61                index: i,
62                version: field_version(f),
63                ty: &f.ty,
64                ident: f.ident.as_ref(),
65            })
66            .collect(),
67        Fields::Unnamed(unnamed) => unnamed
68            .unnamed
69            .iter()
70            .enumerate()
71            .map(|(i, f)| FieldInfo {
72                index: i,
73                version: field_version(f),
74                ty: &f.ty,
75                ident: None,
76            })
77            .collect(),
78        Fields::Unit => vec![],
79    }
80}
81
82/// Sort fields by version, then by original index, and batch by version
83fn create_version_batches(mut field_infos: Vec<FieldInfo>) -> Vec<VersionBatch> {
84    // Sort by version first, then by original index
85    field_infos.sort_by_key(|f| (f.version, f.index));
86
87    // Group into batches by version
88    let mut batches: Vec<VersionBatch> = Vec::new();
89    for field in field_infos {
90        if let Some(last_batch) = batches.last_mut()
91            && last_batch.version == field.version
92        {
93            last_batch.fields.push(field);
94            continue;
95        }
96        batches.push(VersionBatch {
97            version: field.version,
98            fields: vec![field],
99        });
100    }
101    batches
102}
103
104/// Generate write code for a batch of fields
105fn generate_field_writes(
106    batches: &[VersionBatch],
107    is_named: bool,
108) -> Vec<proc_macro2::TokenStream> {
109    let mut writes = Vec::new();
110    let mut last_version = 0u32;
111
112    for batch in batches {
113        if batch.version != last_version {
114            last_version = batch.version;
115            let v = batch.version;
116            writes.push(quote! { if version < #v { return offset; } });
117        }
118
119        for field in &batch.fields {
120            let write_stmt = if is_named {
121                let ident = field.ident.unwrap();
122                quote! { offset += ::vercode::VerCodable::write_version(&self.#ident, version, &mut buf[offset..]); }
123            } else {
124                let idx = syn::Index::from(field.index);
125                quote! { offset += ::vercode::VerCodable::write_version(&self.#idx, version, &mut buf[offset..]); }
126            };
127            writes.push(write_stmt);
128        }
129    }
130    writes
131}
132
133/// Generate size calculation code for a batch of fields
134fn generate_field_sizes(batches: &[VersionBatch], is_named: bool) -> Vec<proc_macro2::TokenStream> {
135    let mut sizes = Vec::new();
136    let mut last_version = 0u32;
137
138    for batch in batches {
139        if batch.version != last_version {
140            last_version = batch.version;
141            let v = batch.version;
142            sizes.push(quote! { if version < #v { return total; } });
143        }
144
145        for field in &batch.fields {
146            let size_stmt = if is_named {
147                let ident = field.ident.unwrap();
148                quote! { total += ::vercode::VerCodable::size_version(&self.#ident, version); }
149            } else {
150                let idx = syn::Index::from(field.index);
151                quote! { total += ::vercode::VerCodable::size_version(&self.#idx, version); }
152            };
153            sizes.push(size_stmt);
154        }
155    }
156    sizes
157}
158
159/// Generate read code for version batches
160fn generate_field_reads(batches: &[VersionBatch]) -> Vec<proc_macro2::TokenStream> {
161    let mut reads = Vec::new();
162
163    for batch in batches {
164        let temp_vars: Vec<_> = batch.fields.iter().map(|f| f.temp_var()).collect();
165        let mut read_stmts = Vec::new();
166        let mut default_stmts = Vec::new();
167
168        for field in &batch.fields {
169            let temp_var = field.temp_var();
170            let ty = field.ty;
171            read_stmts.push(quote! {
172                (#temp_var, __temp_size) = <#ty as ::vercode::VerCodable>::read_version(version, &buf[offset..])?;
173                offset += __temp_size;
174            });
175            default_stmts.push(quote! {
176                #temp_var = <#ty as ::std::default::Default>::default();
177            });
178        }
179
180        if batch.version == 0 {
181            // Version 0 fields always read
182            reads.push(quote! {
183                #(let mut #temp_vars;)*
184                let mut __temp_size;
185                #(#read_stmts)*
186            });
187        } else {
188            let v = batch.version;
189            reads.push(quote! {
190                #(let mut #temp_vars;)*
191                let mut __temp_size;
192                if offset < length && version >= #v {
193                    #(#read_stmts)*
194                } else {
195                    #(#default_stmts)*
196                }
197            });
198        }
199    }
200    reads
201}
202
203/// Generate struct construction from field info
204fn generate_struct_construction(
205    name: &syn::Ident,
206    fields: &Fields,
207    field_infos: &[FieldInfo],
208) -> proc_macro2::TokenStream {
209    match fields {
210        Fields::Named(_) => {
211            let field_inits: Vec<_> = field_infos
212                .iter()
213                .map(|f| {
214                    let ident = f.ident.unwrap();
215                    let temp_var = f.temp_var();
216                    quote! { #ident: #temp_var }
217                })
218                .collect();
219            quote! { #name { #(#field_inits),* } }
220        }
221        Fields::Unnamed(_) => {
222            let field_values: Vec<_> = field_infos.iter().map(|f| f.temp_var()).collect();
223            quote! { #name ( #(#field_values),* ) }
224        }
225        Fields::Unit => quote! { #name },
226    }
227}
228
229/// Calculate maximum version expression from field infos
230fn calculate_max_version_expr(field_infos: &[FieldInfo]) -> proc_macro2::TokenStream {
231    // Calculate max from field version attributes
232    let field_attr_max = field_infos.iter().map(|f| f.version).max().unwrap_or(0);
233
234    // Generate expressions for each field type's MAX_VERSION
235    let field_type_exprs: Vec<_> = field_infos
236        .iter()
237        .map(|f| {
238            let ty = f.ty;
239            quote! { <#ty as ::vercode::VerCodable>::MAX_VERSION }
240        })
241        .collect();
242
243    // Generate a const expression that computes the max of all versions
244    if field_type_exprs.is_empty() {
245        quote! { #field_attr_max }
246    } else {
247        quote! {
248            {
249                let mut max = #field_attr_max;
250                #(
251                    if #field_type_exprs > max {
252                        max = #field_type_exprs;
253                    }
254                )*
255                max
256            }
257        }
258    }
259}
260
261/// Variant information for enum processing
262struct VariantInfo<'a> {
263    index: usize,
264    variant: &'a syn::Variant,
265    field_infos: Vec<FieldInfo<'a>>,
266    batches: Vec<VersionBatch<'a>>,
267}
268
269impl<'a> VariantInfo<'a> {
270    fn new(index: usize, variant: &'a syn::Variant) -> Self {
271        let field_infos = extract_field_info(&variant.fields);
272        let batches = create_version_batches(field_infos.clone());
273        VariantInfo {
274            index,
275            variant,
276            field_infos,
277            batches,
278        }
279    }
280
281    fn max_version_expr(&self) -> proc_macro2::TokenStream {
282        let variant_ver = variant_version(self.variant);
283
284        // Get field attribute versions
285        let field_attr_max = self
286            .field_infos
287            .iter()
288            .map(|f| f.version)
289            .max()
290            .unwrap_or(0);
291
292        // Generate expressions for each field type's MAX_VERSION
293        let field_type_exprs: Vec<_> = self
294            .field_infos
295            .iter()
296            .map(|f| {
297                let ty = f.ty;
298                quote! { <#ty as ::vercode::VerCodable>::MAX_VERSION }
299            })
300            .collect();
301
302        // Generate a const expression that computes the max of all versions
303        if field_type_exprs.is_empty() {
304            let max = variant_ver.max(field_attr_max);
305            quote! { #max }
306        } else {
307            quote! {
308                {
309                    let mut max = #variant_ver;
310                    if #field_attr_max > max {
311                        max = #field_attr_max;
312                    }
313                    #(
314                        if #field_type_exprs > max {
315                            max = #field_type_exprs;
316                        }
317                    )*
318                    max
319                }
320            }
321        }
322    }
323
324    /// Generate pattern match binding for this variant
325    fn match_pattern(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
326        let var_name = &self.variant.ident;
327        match &self.variant.fields {
328            Fields::Named(_) => {
329                let actual_names: Vec<_> =
330                    self.field_infos.iter().map(|f| f.ident.unwrap()).collect();
331                let temp_vars: Vec<_> = self.field_infos.iter().map(|f| f.temp_var()).collect();
332                quote! { #enum_name::#var_name { #(#actual_names: #temp_vars),* } }
333            }
334            Fields::Unnamed(_) => {
335                let temp_vars: Vec<_> = self.field_infos.iter().map(|f| f.temp_var()).collect();
336                quote! { #enum_name::#var_name(#(#temp_vars),*) }
337            }
338            Fields::Unit => quote! { #enum_name::#var_name },
339        }
340    }
341
342    /// Generate variant construction from temp variables
343    fn construct_variant(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
344        let var_name = &self.variant.ident;
345        match &self.variant.fields {
346            Fields::Named(_) => {
347                let actual_names: Vec<_> =
348                    self.field_infos.iter().map(|f| f.ident.unwrap()).collect();
349                let temp_vars: Vec<_> = self.field_infos.iter().map(|f| f.temp_var()).collect();
350                quote! { #enum_name::#var_name { #(#actual_names: #temp_vars),* } }
351            }
352            Fields::Unnamed(_) => {
353                let temp_vars: Vec<_> = self.field_infos.iter().map(|f| f.temp_var()).collect();
354                quote! { #enum_name::#var_name(#(#temp_vars),*) }
355            }
356            Fields::Unit => quote! { #enum_name::#var_name },
357        }
358    }
359
360    /// Generate write arm for this variant
361    fn write_arm(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
362        let idx_u32 = self.index as u32;
363        let pattern = self.match_pattern(enum_name);
364        let field_writes = generate_variant_field_writes(&self.batches);
365
366        quote! {
367            #pattern => {
368                buf[offset..offset+2].copy_from_slice(&(#idx_u32 as u16).to_le_bytes());
369                offset += 2;
370                #(#field_writes)*
371            }
372        }
373    }
374
375    /// Generate size arm for this variant
376    fn size_arm(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
377        let pattern = self.match_pattern(enum_name);
378        let field_sizes = generate_variant_field_sizes(&self.batches);
379
380        quote! {
381            #pattern => {
382                #(#field_sizes)*
383            }
384        }
385    }
386
387    /// Generate read arm for this variant
388    fn read_arm(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
389        let idx_u32 = self.index as u32;
390        let reads = generate_field_reads(&self.batches);
391        let construction = self.construct_variant(enum_name);
392
393        quote! {
394            #idx_u32 => {
395                #(#reads)*
396                #construction
397            }
398        }
399    }
400}
401
402/// Generate write statements for variant fields (using temp vars)
403fn generate_variant_field_writes(batches: &[VersionBatch]) -> Vec<proc_macro2::TokenStream> {
404    let mut writes = Vec::new();
405    let mut last_version = 0u32;
406
407    for batch in batches {
408        if batch.version != last_version {
409            last_version = batch.version;
410            let v = batch.version;
411            writes.push(quote! { if version < #v { return offset; } });
412        }
413
414        for field in &batch.fields {
415            let temp_var = field.temp_var();
416            writes.push(quote! {
417                offset += ::vercode::VerCodable::write_version(#temp_var, version, &mut buf[offset..]);
418            });
419        }
420    }
421    writes
422}
423
424/// Generate size statements for variant fields (using temp vars)
425fn generate_variant_field_sizes(batches: &[VersionBatch]) -> Vec<proc_macro2::TokenStream> {
426    let mut sizes = Vec::new();
427    let mut last_version = 0u32;
428
429    for batch in batches {
430        if batch.version != last_version {
431            last_version = batch.version;
432            let v = batch.version;
433            sizes.push(quote! { if version < #v { return total; } });
434        }
435
436        for field in &batch.fields {
437            let temp_var = field.temp_var();
438            sizes.push(quote! {
439                total += ::vercode::VerCodable::size_version(#temp_var, version);
440            });
441        }
442    }
443    sizes
444}
445
446#[proc_macro_derive(VercodeTransparent)]
447pub fn derive_vercode_transparent(input: TokenStream) -> TokenStream {
448    let input = parse_macro_input!(input as DeriveInput);
449    let name = &input.ident;
450    let generics = &input.generics;
451    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
452
453    // Verify it's a newtype struct and get field accessor
454    let (inner_type, field_accessor, construction) = match &input.data {
455        Data::Struct(s) => match &s.fields {
456            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
457                let ty = &fields.unnamed.first().unwrap().ty;
458                (ty, quote! { 0 }, quote! { Self(inner) })
459            }
460            Fields::Named(fields) if fields.named.len() == 1 => {
461                let field = fields.named.first().unwrap();
462                let ty = &field.ty;
463                let field_name = field.ident.as_ref().unwrap();
464                (
465                    ty,
466                    quote! { #field_name },
467                    quote! { Self { #field_name: inner } },
468                )
469            }
470            _ => panic!(
471                "VercodeTransparent can only be used on newtype structs with exactly one field"
472            ),
473        },
474        _ => panic!("VercodeTransparent can only be used on structs"),
475    };
476
477    let expanded = quote! {
478        impl #impl_generics ::vercode::VerCodable for #name #ty_generics #where_clause {
479            const MAX_VERSION: u32 = <#inner_type as ::vercode::VerCodable>::MAX_VERSION;
480
481            #[inline(always)]
482            fn write_version(&self, version: u32, buf: &mut [u8]) -> usize {
483                ::vercode::VerCodable::write_version(&self.#field_accessor, version, buf)
484            }
485
486            #[inline(always)]
487            fn read_version(version: u32, buf: &[u8]) -> ::std::result::Result<(Self, usize), ::vercode::InvalidEncoding> {
488                let (inner, size) = <#inner_type as ::vercode::VerCodable>::read_version(version, buf)?;
489                Ok((#construction, size))
490            }
491
492            #[inline(always)]
493            fn size_version(&self, version: u32) -> usize {
494                ::vercode::VerCodable::size_version(&self.#field_accessor, version)
495            }
496
497            #[inline(always)]
498            fn write_option(this: Option<&Self>, version: u32, buf: &mut [u8]) -> usize {
499                ::vercode::VerCodable::write_option(
500                    this.map(|this| &this.#field_accessor),
501                    version,
502                    buf,
503                )
504            }
505
506            #[inline(always)]
507            fn read_option(version: u32, buf: &[u8]) -> Result<(Option<Self>, usize), ::vercode::InvalidEncoding> {
508                let (inner_option, size) = ::vercode::VerCodable::read_option(version, buf)?;
509                let result_option = inner_option.map(|inner| #construction);
510                Ok((result_option, size))
511            }
512
513            #[inline(always)]
514            fn size_option_version(this: &Option<Self>, version: u32) -> usize {
515                let inner_option = this.as_ref().map(|this| this.#field_accessor);
516                <#inner_type as ::vercode::VerCodable>::size_option_version(&inner_option, version)
517            }
518        }
519    };
520
521    TokenStream::from(expanded)
522}
523
524#[proc_macro_derive(Vercode, attributes(version))]
525pub fn derive_vercode(input: TokenStream) -> TokenStream {
526    let input = parse_macro_input!(input as DeriveInput);
527    let name = &input.ident;
528    let generics = &input.generics;
529    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
530
531    match &input.data {
532        Data::Struct(s) => {
533            derive_struct(name, &impl_generics, &ty_generics, &where_clause, &s.fields)
534        }
535        Data::Enum(e) => derive_enum(
536            name,
537            &impl_generics,
538            &ty_generics,
539            &where_clause,
540            &e.variants,
541        ),
542        _ => panic!("Vercode only supports structs and enums"),
543    }
544}
545
546fn derive_struct(
547    name: &syn::Ident,
548    impl_generics: &syn::ImplGenerics,
549    ty_generics: &syn::TypeGenerics,
550    where_clause: &Option<&syn::WhereClause>,
551    fields: &Fields,
552) -> TokenStream {
553    let field_infos = extract_field_info(fields);
554    let batches = create_version_batches(field_infos.clone());
555
556    let max_version_expr = calculate_max_version_expr(&field_infos);
557
558    let is_named = matches!(fields, Fields::Named(_));
559    let writes = generate_field_writes(&batches, is_named);
560    let sizes = generate_field_sizes(&batches, is_named);
561    let reads = generate_field_reads(&batches);
562    let construction = generate_struct_construction(name, fields, &field_infos);
563
564    let expanded = quote! {
565        impl #impl_generics ::vercode::VerCodable for #name #ty_generics #where_clause {
566            const MAX_VERSION: u32 = #max_version_expr;
567
568            #[inline(always)]
569            fn write_version(&self, version: u32, buf: &mut [u8]) -> usize {
570                let total_data = self.size_version(version);
571                buf[..4].copy_from_slice(&(total_data as u32).to_le_bytes());
572                let mut offset = 4usize;
573                #(#writes)*
574                offset
575            }
576
577            #[inline(always)]
578            fn read_version(version: u32, buf: &[u8]) -> ::std::result::Result<(Self, usize), ::vercode::InvalidEncoding> {
579                if buf.len() < 4 { return Err(::vercode::InvalidEncoding); }
580                let length = u32::from_le_bytes(buf[..4].try_into().unwrap()) as usize;
581                let mut offset = 4usize;
582                #(#reads)*
583                let result = #construction;
584                Ok((result, offset))
585            }
586
587            #[inline(always)]
588            fn size_version(&self, version: u32) -> usize {
589                let mut total = 4usize;
590                #(#sizes)*
591                total
592            }
593        }
594    };
595    TokenStream::from(expanded)
596}
597
598fn derive_enum(
599    name: &syn::Ident,
600    impl_generics: &syn::ImplGenerics,
601    ty_generics: &syn::TypeGenerics,
602    where_clause: &Option<&syn::WhereClause>,
603    variants: &syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>,
604) -> TokenStream {
605    // Process all variants once
606    let variant_infos: Vec<VariantInfo> = variants
607        .iter()
608        .enumerate()
609        .map(|(idx, variant)| VariantInfo::new(idx, variant))
610        .collect();
611
612    // Calculate max version expression
613    // Generate const expression that computes max across all variants
614    let variant_max_exprs: Vec<_> = variant_infos.iter().map(|v| v.max_version_expr()).collect();
615
616    let max_version_expr = if variant_max_exprs.is_empty() {
617        quote! { 0 }
618    } else {
619        quote! {
620            {
621                let mut max = 0;
622                #(
623                    {
624                        let variant_max = #variant_max_exprs;
625                        if variant_max > max {
626                            max = variant_max;
627                        }
628                    }
629                )*
630                max
631            }
632        }
633    };
634
635    // Generate match arms using the processed variant info
636    let write_arms: Vec<_> = variant_infos.iter().map(|v| v.write_arm(name)).collect();
637    let size_arms: Vec<_> = variant_infos.iter().map(|v| v.size_arm(name)).collect();
638    let read_arms: Vec<_> = variant_infos.iter().map(|v| v.read_arm(name)).collect();
639
640    let expanded = quote! {
641        impl #impl_generics ::vercode::VerCodable for #name #ty_generics #where_clause {
642            const MAX_VERSION: u32 = #max_version_expr;
643
644            #[inline(always)]
645            fn write_version(&self, version: u32, buf: &mut [u8]) -> usize {
646                let total_data = self.size_version(version);
647                buf[..4].copy_from_slice(&(total_data as u32).to_le_bytes());
648                let mut offset = 4usize;
649                match self {
650                    #(#write_arms)*
651                }
652                offset
653            }
654
655            #[inline(always)]
656            fn read_version(version: u32, buf: &[u8]) -> ::std::result::Result<(Self, usize), ::vercode::InvalidEncoding> {
657                if buf.len() < 6 { return Err(::vercode::InvalidEncoding); }
658                let length = u32::from_le_bytes(buf[..4].try_into().unwrap()) as usize;
659                let discriminant = u16::from_le_bytes(buf[4..6].try_into().unwrap()) as u32;
660                let mut offset = 6usize;
661
662                let result = match discriminant {
663                    #(#read_arms,)*
664                    _ => return Err(::vercode::InvalidEncoding),
665                };
666                Ok((result, offset))
667            }
668
669            #[inline(always)]
670            fn size_version(&self, version: u32) -> usize {
671                let mut total = 6usize; // length prefix (4 bytes) + discriminant (2 bytes)
672                match self {
673                    #(#size_arms)*
674                }
675                total
676            }
677        }
678    };
679
680    TokenStream::from(expanded)
681}