1use proc_macro::TokenStream;
3use quote::quote;
4use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input};
5
6#[derive(Clone)]
8struct FieldInfo<'a> {
9 index: usize,
10 version: u32,
11 ty: &'a syn::Type,
12 ident: Option<&'a syn::Ident>,
13}
14
15impl<'a> FieldInfo<'a> {
16 fn temp_var(&self) -> syn::Ident {
17 syn::Ident::new(
18 &format!("__field{}", self.index),
19 proc_macro2::Span::call_site(),
20 )
21 }
22}
23
24struct VersionBatch<'a> {
26 version: u32,
27 fields: Vec<FieldInfo<'a>>,
28}
29
30fn parse_version_attribute(attrs: &[syn::Attribute]) -> u32 {
31 for attr in attrs {
32 if let Meta::List(list) = &attr.meta
33 && list.path.is_ident("version")
34 {
35 let ts = list.tokens.to_string();
36 let digits: String = ts.chars().filter(|c| c.is_ascii_digit()).collect();
37 if let Ok(v) = digits.parse::<u32>() {
38 return v;
39 }
40 }
41 }
42 0
43}
44
45fn field_version(field: &syn::Field) -> u32 {
46 parse_version_attribute(&field.attrs)
47}
48
49fn variant_version(variant: &syn::Variant) -> u32 {
50 parse_version_attribute(&variant.attrs)
51}
52
53fn extract_field_info(fields: &Fields) -> Vec<FieldInfo<'_>> {
55 match fields {
56 Fields::Named(named) => named
57 .named
58 .iter()
59 .enumerate()
60 .map(|(i, f)| FieldInfo {
61 index: i,
62 version: field_version(f),
63 ty: &f.ty,
64 ident: f.ident.as_ref(),
65 })
66 .collect(),
67 Fields::Unnamed(unnamed) => unnamed
68 .unnamed
69 .iter()
70 .enumerate()
71 .map(|(i, f)| FieldInfo {
72 index: i,
73 version: field_version(f),
74 ty: &f.ty,
75 ident: None,
76 })
77 .collect(),
78 Fields::Unit => vec![],
79 }
80}
81
82fn create_version_batches(mut field_infos: Vec<FieldInfo>) -> Vec<VersionBatch> {
84 field_infos.sort_by_key(|f| (f.version, f.index));
86
87 let mut batches: Vec<VersionBatch> = Vec::new();
89 for field in field_infos {
90 if let Some(last_batch) = batches.last_mut()
91 && last_batch.version == field.version
92 {
93 last_batch.fields.push(field);
94 continue;
95 }
96 batches.push(VersionBatch {
97 version: field.version,
98 fields: vec![field],
99 });
100 }
101 batches
102}
103
104fn generate_field_writes(
106 batches: &[VersionBatch],
107 is_named: bool,
108) -> Vec<proc_macro2::TokenStream> {
109 let mut writes = Vec::new();
110 let mut last_version = 0u32;
111
112 for batch in batches {
113 if batch.version != last_version {
114 last_version = batch.version;
115 let v = batch.version;
116 writes.push(quote! { if version < #v { return offset; } });
117 }
118
119 for field in &batch.fields {
120 let write_stmt = if is_named {
121 let ident = field.ident.unwrap();
122 quote! { offset += ::vercode::VerCodable::write_version(&self.#ident, version, &mut buf[offset..]); }
123 } else {
124 let idx = syn::Index::from(field.index);
125 quote! { offset += ::vercode::VerCodable::write_version(&self.#idx, version, &mut buf[offset..]); }
126 };
127 writes.push(write_stmt);
128 }
129 }
130 writes
131}
132
133fn generate_field_sizes(batches: &[VersionBatch], is_named: bool) -> Vec<proc_macro2::TokenStream> {
135 let mut sizes = Vec::new();
136 let mut last_version = 0u32;
137
138 for batch in batches {
139 if batch.version != last_version {
140 last_version = batch.version;
141 let v = batch.version;
142 sizes.push(quote! { if version < #v { return total; } });
143 }
144
145 for field in &batch.fields {
146 let size_stmt = if is_named {
147 let ident = field.ident.unwrap();
148 quote! { total += ::vercode::VerCodable::size_version(&self.#ident, version); }
149 } else {
150 let idx = syn::Index::from(field.index);
151 quote! { total += ::vercode::VerCodable::size_version(&self.#idx, version); }
152 };
153 sizes.push(size_stmt);
154 }
155 }
156 sizes
157}
158
159fn generate_field_reads(batches: &[VersionBatch]) -> Vec<proc_macro2::TokenStream> {
161 let mut reads = Vec::new();
162
163 for batch in batches {
164 let temp_vars: Vec<_> = batch.fields.iter().map(|f| f.temp_var()).collect();
165 let mut read_stmts = Vec::new();
166 let mut default_stmts = Vec::new();
167
168 for field in &batch.fields {
169 let temp_var = field.temp_var();
170 let ty = field.ty;
171 read_stmts.push(quote! {
172 (#temp_var, __temp_size) = <#ty as ::vercode::VerCodable>::read_version(version, &buf[offset..])?;
173 offset += __temp_size;
174 });
175 default_stmts.push(quote! {
176 #temp_var = <#ty as ::std::default::Default>::default();
177 });
178 }
179
180 if batch.version == 0 {
181 reads.push(quote! {
183 #(let mut #temp_vars;)*
184 let mut __temp_size;
185 #(#read_stmts)*
186 });
187 } else {
188 let v = batch.version;
189 reads.push(quote! {
190 #(let mut #temp_vars;)*
191 let mut __temp_size;
192 if offset < length && version >= #v {
193 #(#read_stmts)*
194 } else {
195 #(#default_stmts)*
196 }
197 });
198 }
199 }
200 reads
201}
202
203fn generate_struct_construction(
205 name: &syn::Ident,
206 fields: &Fields,
207 field_infos: &[FieldInfo],
208) -> proc_macro2::TokenStream {
209 match fields {
210 Fields::Named(_) => {
211 let field_inits: Vec<_> = field_infos
212 .iter()
213 .map(|f| {
214 let ident = f.ident.unwrap();
215 let temp_var = f.temp_var();
216 quote! { #ident: #temp_var }
217 })
218 .collect();
219 quote! { #name { #(#field_inits),* } }
220 }
221 Fields::Unnamed(_) => {
222 let field_values: Vec<_> = field_infos.iter().map(|f| f.temp_var()).collect();
223 quote! { #name ( #(#field_values),* ) }
224 }
225 Fields::Unit => quote! { #name },
226 }
227}
228
229fn calculate_max_version_expr(field_infos: &[FieldInfo]) -> proc_macro2::TokenStream {
231 let field_attr_max = field_infos.iter().map(|f| f.version).max().unwrap_or(0);
233
234 let field_type_exprs: Vec<_> = field_infos
236 .iter()
237 .map(|f| {
238 let ty = f.ty;
239 quote! { <#ty as ::vercode::VerCodable>::MAX_VERSION }
240 })
241 .collect();
242
243 if field_type_exprs.is_empty() {
245 quote! { #field_attr_max }
246 } else {
247 quote! {
248 {
249 let mut max = #field_attr_max;
250 #(
251 if #field_type_exprs > max {
252 max = #field_type_exprs;
253 }
254 )*
255 max
256 }
257 }
258 }
259}
260
261struct VariantInfo<'a> {
263 index: usize,
264 variant: &'a syn::Variant,
265 field_infos: Vec<FieldInfo<'a>>,
266 batches: Vec<VersionBatch<'a>>,
267}
268
269impl<'a> VariantInfo<'a> {
270 fn new(index: usize, variant: &'a syn::Variant) -> Self {
271 let field_infos = extract_field_info(&variant.fields);
272 let batches = create_version_batches(field_infos.clone());
273 VariantInfo {
274 index,
275 variant,
276 field_infos,
277 batches,
278 }
279 }
280
281 fn max_version_expr(&self) -> proc_macro2::TokenStream {
282 let variant_ver = variant_version(self.variant);
283
284 let field_attr_max = self
286 .field_infos
287 .iter()
288 .map(|f| f.version)
289 .max()
290 .unwrap_or(0);
291
292 let field_type_exprs: Vec<_> = self
294 .field_infos
295 .iter()
296 .map(|f| {
297 let ty = f.ty;
298 quote! { <#ty as ::vercode::VerCodable>::MAX_VERSION }
299 })
300 .collect();
301
302 if field_type_exprs.is_empty() {
304 let max = variant_ver.max(field_attr_max);
305 quote! { #max }
306 } else {
307 quote! {
308 {
309 let mut max = #variant_ver;
310 if #field_attr_max > max {
311 max = #field_attr_max;
312 }
313 #(
314 if #field_type_exprs > max {
315 max = #field_type_exprs;
316 }
317 )*
318 max
319 }
320 }
321 }
322 }
323
324 fn match_pattern(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
326 let var_name = &self.variant.ident;
327 match &self.variant.fields {
328 Fields::Named(_) => {
329 let actual_names: Vec<_> =
330 self.field_infos.iter().map(|f| f.ident.unwrap()).collect();
331 let temp_vars: Vec<_> = self.field_infos.iter().map(|f| f.temp_var()).collect();
332 quote! { #enum_name::#var_name { #(#actual_names: #temp_vars),* } }
333 }
334 Fields::Unnamed(_) => {
335 let temp_vars: Vec<_> = self.field_infos.iter().map(|f| f.temp_var()).collect();
336 quote! { #enum_name::#var_name(#(#temp_vars),*) }
337 }
338 Fields::Unit => quote! { #enum_name::#var_name },
339 }
340 }
341
342 fn construct_variant(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
344 let var_name = &self.variant.ident;
345 match &self.variant.fields {
346 Fields::Named(_) => {
347 let actual_names: Vec<_> =
348 self.field_infos.iter().map(|f| f.ident.unwrap()).collect();
349 let temp_vars: Vec<_> = self.field_infos.iter().map(|f| f.temp_var()).collect();
350 quote! { #enum_name::#var_name { #(#actual_names: #temp_vars),* } }
351 }
352 Fields::Unnamed(_) => {
353 let temp_vars: Vec<_> = self.field_infos.iter().map(|f| f.temp_var()).collect();
354 quote! { #enum_name::#var_name(#(#temp_vars),*) }
355 }
356 Fields::Unit => quote! { #enum_name::#var_name },
357 }
358 }
359
360 fn write_arm(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
362 let idx_u32 = self.index as u32;
363 let pattern = self.match_pattern(enum_name);
364 let field_writes = generate_variant_field_writes(&self.batches);
365
366 quote! {
367 #pattern => {
368 buf[offset..offset+2].copy_from_slice(&(#idx_u32 as u16).to_le_bytes());
369 offset += 2;
370 #(#field_writes)*
371 }
372 }
373 }
374
375 fn size_arm(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
377 let pattern = self.match_pattern(enum_name);
378 let field_sizes = generate_variant_field_sizes(&self.batches);
379
380 quote! {
381 #pattern => {
382 #(#field_sizes)*
383 }
384 }
385 }
386
387 fn read_arm(&self, enum_name: &syn::Ident) -> proc_macro2::TokenStream {
389 let idx_u32 = self.index as u32;
390 let reads = generate_field_reads(&self.batches);
391 let construction = self.construct_variant(enum_name);
392
393 quote! {
394 #idx_u32 => {
395 #(#reads)*
396 #construction
397 }
398 }
399 }
400}
401
402fn generate_variant_field_writes(batches: &[VersionBatch]) -> Vec<proc_macro2::TokenStream> {
404 let mut writes = Vec::new();
405 let mut last_version = 0u32;
406
407 for batch in batches {
408 if batch.version != last_version {
409 last_version = batch.version;
410 let v = batch.version;
411 writes.push(quote! { if version < #v { return offset; } });
412 }
413
414 for field in &batch.fields {
415 let temp_var = field.temp_var();
416 writes.push(quote! {
417 offset += ::vercode::VerCodable::write_version(#temp_var, version, &mut buf[offset..]);
418 });
419 }
420 }
421 writes
422}
423
424fn generate_variant_field_sizes(batches: &[VersionBatch]) -> Vec<proc_macro2::TokenStream> {
426 let mut sizes = Vec::new();
427 let mut last_version = 0u32;
428
429 for batch in batches {
430 if batch.version != last_version {
431 last_version = batch.version;
432 let v = batch.version;
433 sizes.push(quote! { if version < #v { return total; } });
434 }
435
436 for field in &batch.fields {
437 let temp_var = field.temp_var();
438 sizes.push(quote! {
439 total += ::vercode::VerCodable::size_version(#temp_var, version);
440 });
441 }
442 }
443 sizes
444}
445
446#[proc_macro_derive(VercodeTransparent)]
447pub fn derive_vercode_transparent(input: TokenStream) -> TokenStream {
448 let input = parse_macro_input!(input as DeriveInput);
449 let name = &input.ident;
450 let generics = &input.generics;
451 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
452
453 let (inner_type, field_accessor, construction) = match &input.data {
455 Data::Struct(s) => match &s.fields {
456 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
457 let ty = &fields.unnamed.first().unwrap().ty;
458 (ty, quote! { 0 }, quote! { Self(inner) })
459 }
460 Fields::Named(fields) if fields.named.len() == 1 => {
461 let field = fields.named.first().unwrap();
462 let ty = &field.ty;
463 let field_name = field.ident.as_ref().unwrap();
464 (
465 ty,
466 quote! { #field_name },
467 quote! { Self { #field_name: inner } },
468 )
469 }
470 _ => panic!(
471 "VercodeTransparent can only be used on newtype structs with exactly one field"
472 ),
473 },
474 _ => panic!("VercodeTransparent can only be used on structs"),
475 };
476
477 let expanded = quote! {
478 impl #impl_generics ::vercode::VerCodable for #name #ty_generics #where_clause {
479 const MAX_VERSION: u32 = <#inner_type as ::vercode::VerCodable>::MAX_VERSION;
480
481 #[inline(always)]
482 fn write_version(&self, version: u32, buf: &mut [u8]) -> usize {
483 ::vercode::VerCodable::write_version(&self.#field_accessor, version, buf)
484 }
485
486 #[inline(always)]
487 fn read_version(version: u32, buf: &[u8]) -> ::std::result::Result<(Self, usize), ::vercode::InvalidEncoding> {
488 let (inner, size) = <#inner_type as ::vercode::VerCodable>::read_version(version, buf)?;
489 Ok((#construction, size))
490 }
491
492 #[inline(always)]
493 fn size_version(&self, version: u32) -> usize {
494 ::vercode::VerCodable::size_version(&self.#field_accessor, version)
495 }
496
497 #[inline(always)]
498 fn write_option(this: Option<&Self>, version: u32, buf: &mut [u8]) -> usize {
499 ::vercode::VerCodable::write_option(
500 this.map(|this| &this.#field_accessor),
501 version,
502 buf,
503 )
504 }
505
506 #[inline(always)]
507 fn read_option(version: u32, buf: &[u8]) -> Result<(Option<Self>, usize), ::vercode::InvalidEncoding> {
508 let (inner_option, size) = ::vercode::VerCodable::read_option(version, buf)?;
509 let result_option = inner_option.map(|inner| #construction);
510 Ok((result_option, size))
511 }
512
513 #[inline(always)]
514 fn size_option_version(this: &Option<Self>, version: u32) -> usize {
515 let inner_option = this.as_ref().map(|this| this.#field_accessor);
516 <#inner_type as ::vercode::VerCodable>::size_option_version(&inner_option, version)
517 }
518 }
519 };
520
521 TokenStream::from(expanded)
522}
523
524#[proc_macro_derive(Vercode, attributes(version))]
525pub fn derive_vercode(input: TokenStream) -> TokenStream {
526 let input = parse_macro_input!(input as DeriveInput);
527 let name = &input.ident;
528 let generics = &input.generics;
529 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
530
531 match &input.data {
532 Data::Struct(s) => {
533 derive_struct(name, &impl_generics, &ty_generics, &where_clause, &s.fields)
534 }
535 Data::Enum(e) => derive_enum(
536 name,
537 &impl_generics,
538 &ty_generics,
539 &where_clause,
540 &e.variants,
541 ),
542 _ => panic!("Vercode only supports structs and enums"),
543 }
544}
545
546fn derive_struct(
547 name: &syn::Ident,
548 impl_generics: &syn::ImplGenerics,
549 ty_generics: &syn::TypeGenerics,
550 where_clause: &Option<&syn::WhereClause>,
551 fields: &Fields,
552) -> TokenStream {
553 let field_infos = extract_field_info(fields);
554 let batches = create_version_batches(field_infos.clone());
555
556 let max_version_expr = calculate_max_version_expr(&field_infos);
557
558 let is_named = matches!(fields, Fields::Named(_));
559 let writes = generate_field_writes(&batches, is_named);
560 let sizes = generate_field_sizes(&batches, is_named);
561 let reads = generate_field_reads(&batches);
562 let construction = generate_struct_construction(name, fields, &field_infos);
563
564 let expanded = quote! {
565 impl #impl_generics ::vercode::VerCodable for #name #ty_generics #where_clause {
566 const MAX_VERSION: u32 = #max_version_expr;
567
568 #[inline(always)]
569 fn write_version(&self, version: u32, buf: &mut [u8]) -> usize {
570 let total_data = self.size_version(version);
571 buf[..4].copy_from_slice(&(total_data as u32).to_le_bytes());
572 let mut offset = 4usize;
573 #(#writes)*
574 offset
575 }
576
577 #[inline(always)]
578 fn read_version(version: u32, buf: &[u8]) -> ::std::result::Result<(Self, usize), ::vercode::InvalidEncoding> {
579 if buf.len() < 4 { return Err(::vercode::InvalidEncoding); }
580 let length = u32::from_le_bytes(buf[..4].try_into().unwrap()) as usize;
581 let mut offset = 4usize;
582 #(#reads)*
583 let result = #construction;
584 Ok((result, offset))
585 }
586
587 #[inline(always)]
588 fn size_version(&self, version: u32) -> usize {
589 let mut total = 4usize;
590 #(#sizes)*
591 total
592 }
593 }
594 };
595 TokenStream::from(expanded)
596}
597
598fn derive_enum(
599 name: &syn::Ident,
600 impl_generics: &syn::ImplGenerics,
601 ty_generics: &syn::TypeGenerics,
602 where_clause: &Option<&syn::WhereClause>,
603 variants: &syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>,
604) -> TokenStream {
605 let variant_infos: Vec<VariantInfo> = variants
607 .iter()
608 .enumerate()
609 .map(|(idx, variant)| VariantInfo::new(idx, variant))
610 .collect();
611
612 let variant_max_exprs: Vec<_> = variant_infos.iter().map(|v| v.max_version_expr()).collect();
615
616 let max_version_expr = if variant_max_exprs.is_empty() {
617 quote! { 0 }
618 } else {
619 quote! {
620 {
621 let mut max = 0;
622 #(
623 {
624 let variant_max = #variant_max_exprs;
625 if variant_max > max {
626 max = variant_max;
627 }
628 }
629 )*
630 max
631 }
632 }
633 };
634
635 let write_arms: Vec<_> = variant_infos.iter().map(|v| v.write_arm(name)).collect();
637 let size_arms: Vec<_> = variant_infos.iter().map(|v| v.size_arm(name)).collect();
638 let read_arms: Vec<_> = variant_infos.iter().map(|v| v.read_arm(name)).collect();
639
640 let expanded = quote! {
641 impl #impl_generics ::vercode::VerCodable for #name #ty_generics #where_clause {
642 const MAX_VERSION: u32 = #max_version_expr;
643
644 #[inline(always)]
645 fn write_version(&self, version: u32, buf: &mut [u8]) -> usize {
646 let total_data = self.size_version(version);
647 buf[..4].copy_from_slice(&(total_data as u32).to_le_bytes());
648 let mut offset = 4usize;
649 match self {
650 #(#write_arms)*
651 }
652 offset
653 }
654
655 #[inline(always)]
656 fn read_version(version: u32, buf: &[u8]) -> ::std::result::Result<(Self, usize), ::vercode::InvalidEncoding> {
657 if buf.len() < 6 { return Err(::vercode::InvalidEncoding); }
658 let length = u32::from_le_bytes(buf[..4].try_into().unwrap()) as usize;
659 let discriminant = u16::from_le_bytes(buf[4..6].try_into().unwrap()) as u32;
660 let mut offset = 6usize;
661
662 let result = match discriminant {
663 #(#read_arms,)*
664 _ => return Err(::vercode::InvalidEncoding),
665 };
666 Ok((result, offset))
667 }
668
669 #[inline(always)]
670 fn size_version(&self, version: u32) -> usize {
671 let mut total = 6usize; match self {
673 #(#size_arms)*
674 }
675 total
676 }
677 }
678 };
679
680 TokenStream::from(expanded)
681}