Skip to main content

afastdata_macro/
lib.rs

1//! # afastdata-macro
2//!
3//! afastdata 序列化框架的 derive 宏,为结构体和枚举自动生成
4//! [`AFastSerialize`] 和 [`AFastDeserialize`] trait 的实现。
5//!
6//! Derive macros for the afastdata serialization framework, automatically generating
//! implementations of [`AFastSerialize`] and [`AFastDeserialize`] traits for
8//! structs and enums.
9//!
10//! ## 支持的类型 / Supported Types
11//!
12//! - **命名字段结构体 (Named-field struct)**:逐字段序列化/反序列化
13//!   / Field-by-field serialization/deserialization
14//! - **元组结构体 (Tuple struct)**:按索引逐字段序列化/反序列化
15//!   / Index-based field serialization/deserialization
16//! - **单元结构体 (Unit struct)**:生成空实现(不产生任何字节)
17//!   / Generates an empty implementation (no bytes produced)
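//! - **枚举 (Enum)**:写入变体索引(默认 `u8`,可通过 `tag-u16` / `tag-u32` feature 切换为 `u16` / `u32`)+ 变体字段数据
//!   / Writes a variant index (`u8` by default; `u16` / `u32` with the `tag-u16` / `tag-u32` feature) + variant field data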
20//!
21//! ## 编码格式 / Encoding Format
22//!
23//! ### 结构体 / Struct
24//!
//! 所有字段按声明顺序依次调用 `to_bytes()` / `from_bytes()`,无额外前缀;
//! 标注 `#[afast(skip)]` 的字段不参与编码。
//!
//! All fields call `to_bytes()` / `from_bytes()` in declaration order, with no
//! additional prefix; fields marked `#[afast(skip)]` are not encoded.
29//!
30//! ### 枚举 / Enum
31//!
32//! | 编码内容 / Content | 类型 / Type | 说明 / Description |
33//! |---|---|---|
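//! | 变体索引 / Variant index | 默认 `u8`(`tag-u16` / `tag-u32` feature 切换为 `u16` / `u32`),little-endian / `u8` by default (`u16` / `u32` via the `tag-u16` / `tag-u32` feature), little-endian | 从 0 开始递增 / Starts from 0, incrementing |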
35//! | 变体字段 / Variant fields | 逐字段编码 / Field-wise encoding | 仅非 unit 变体 / Only for non-unit variants |
36//!
37//! ## 泛型支持 / Generic Support
38//!
39//! 泛型参数会自动添加 `AFastSerialize` 和(对于反序列化)`AFastDeserialize` trait 约束。
40//! 如果泛型参数仅用于某些字段,生成的约束可能过于严格,但这保证了实现的正确性。
41//!
42//! Generic parameters automatically receive `AFastSerialize` and (for deserialization)
43//! `AFastDeserialize` trait bounds. If a generic parameter is only used in certain fields,
44//! the generated bounds may be overly strict, but this ensures correctness.
45//!
46//! ## 示例 / Example
47//!
48//! ```
49//! extern crate afastdata;
50//! use afastdata::{AFastSerialize, AFastDeserialize};
51//!
52//! #[derive(AFastSerialize, AFastDeserialize, Debug, PartialEq)]
53//! struct Point {
54//!     x: i32,
55//!     y: i32,
56//! }
57//!
58//! #[derive(AFastSerialize, AFastDeserialize, Debug, PartialEq)]
59//! enum Shape {
60//!     Circle(f64),
61//!     Rectangle { width: f64, height: f64 },
62//!     Empty,
63//! }
64//!
65//! // 序列化 / Serialize
66//! let point = Point { x: 10, y: 20 };
67//! let bytes = point.to_bytes();
68//!
69//! // 反序列化 / Deserialize
70//! let (decoded, _) = Point::from_bytes(&bytes).unwrap();
71//! assert_eq!(point, decoded);
72//! ```
73
74use proc_macro::TokenStream;
75use quote::quote;
76use syn::{
77    Attribute, Data, DeriveInput, Fields, Index, Lit, LitInt, LitStr, Meta, Path, Token, Type,
78    TypePath,
79    parse::{Parse, ParseStream},
80    parse_macro_input,
81    punctuated::Punctuated,
82};
83
84/// 返回枚举变体标签的类型和字节大小,基于编译时 feature 配置。
85///
86/// Returns the enum variant tag type and byte size based on compile-time feature config.
87fn tag_type() -> (proc_macro2::TokenStream, usize) {
88    if cfg!(feature = "tag-u16") {
89        (quote! { u16 }, 2)
90    } else if cfg!(feature = "tag-u32") {
91        (quote! { u32 }, 4)
92    } else {
93        (quote! { u8 }, 1)
94    }
95}
96
97/// 为结构体或枚举生成 [`AFastSerialize`] trait 实现。
98///
99/// Generates an [`AFastSerialize`] trait implementation for a struct or enum.
100///
101/// # 生成的代码 / Generated Code
102///
103/// ## 结构体 / Struct
104///
105/// 为每个字段调用 `to_bytes()`,并将结果依次追加到字节缓冲区中。
106///
107/// Calls `to_bytes()` on each field, appending the results to a byte buffer
108/// sequentially.
109///
110/// ```text
111/// // 以下为生成代码的示意(非实际代码)
112/// // The following is an illustration of generated code (not actual code)
113/// impl AFastSerialize for MyStruct {
114///     fn to_bytes(&self) -> Vec<u8> {
115///         let mut bytes = Vec::new();
116///         bytes.extend(AFastSerialize::to_bytes(&self.field1));
117///         bytes.extend(AFastSerialize::to_bytes(&self.field2));
118///         // ... 依次处理每个字段 / remaining fields processed similarly
119///         bytes
120///     }
121/// }
122/// ```
123///
124/// ## 枚举 / Enum
125///
126/// 先写入 `u8` 变体索引(从 0 开始),再写入变体的字段数据。
127/// Unit 变体只写入索引。可通过 feature 切换为 `u16` 或 `u32`。
128///
129/// Writes a `u8` variant index (starting from 0), then the variant's field data.
130/// Unit variants only write the index. Switchable to `u16` or `u32` via features.
131///
132/// ```text
133/// // 以下为生成代码的示意(非实际代码)
134/// // The following is an illustration of generated code (not actual code)
135/// impl AFastSerialize for MyEnum {
136///     fn to_bytes(&self) -> Vec<u8> {
137///         let mut bytes = Vec::new();
138///         match self {
139///             MyEnum::Variant1 => {
140///                 bytes.extend(0u8.to_le_bytes());
141///             }
142///             MyEnum::Variant2(field) => {
143///                 bytes.extend(1u8.to_le_bytes());
144///                 bytes.extend(AFastSerialize::to_bytes(field));
145///             }
146///             // ...
147///         }
148///         bytes
149///     }
150/// }
151/// ```
152///
153/// # 泛型 / Generics
154///
155/// 如果目标类型包含泛型参数,生成的 `impl` 会自动为这些参数添加
156/// `AFastSerialize` 约束。
157///
158/// If the target type contains generic parameters, the generated `impl` automatically
159/// adds `AFastSerialize` bounds to those parameters.
160///
161/// # Panics
162///
163/// 对 union 类型使用此宏会触发编译 panic。
164///
165/// Using this macro on a union type will trigger a compile-time panic.
166#[proc_macro_derive(AFastSerialize, attributes(afast))]
167pub fn derive_serialize(input: TokenStream) -> TokenStream {
168    let input = parse_macro_input!(input as DeriveInput);
169    let name = &input.ident;
170    let generics = &input.generics;
171
172    // 为泛型参数添加 AFastSerialize trait 约束
173    // Add AFastSerialize trait bounds to generic parameters
174    let mut generics_with_bounds = generics.clone();
175    for param in &mut generics_with_bounds.params {
176        if let syn::GenericParam::Type(ref mut ty) = *param {
177            ty.bounds
178                .push(syn::parse_quote!(::afastdata::AFastSerialize));
179        }
180    }
181    let (impl_generics, _, _) = generics_with_bounds.split_for_impl();
182    let (_, ty_generics, _) = generics.split_for_impl();
183
184    let expanded = match &input.data {
185        Data::Struct(data) => {
186            let serialize_body = generate_serialize_fields(&data.fields, quote!(self));
187            quote! {
188                impl #impl_generics ::afastdata::AFastSerialize for #name #ty_generics {
189                    fn to_bytes(&self) -> Vec<u8> {
190                        let mut bytes = Vec::new();
191                        #(#serialize_body)*
192                        bytes
193                    }
194                }
195            }
196        }
197        Data::Enum(data) => {
198            let (tag_ty, _) = tag_type();
199            let mut arms = Vec::new();
200            for (i, variant) in data.variants.iter().enumerate() {
201                let variant_name = &variant.ident;
202
203                match &variant.fields {
204                    Fields::Unit => {
205                        arms.push(quote! {
206                            #name::#variant_name => {
207                                bytes.extend((#i as #tag_ty).to_le_bytes());
208                            }
209                        });
210                    }
211                    Fields::Unnamed(fields) => {
212                        let field_names: Vec<_> = (0..fields.unnamed.len())
213                            .map(|i| {
214                                syn::Ident::new(&format!("__f{}", i), variant_name.span())
215                            })
216                            .collect();
217                        let field_patterns = &field_names;
218                        let mut serialize_fields = Vec::new();
219                        for (i, f) in fields.unnamed.iter().enumerate() {
220                            if !has_skip_attr(&f.attrs).0 {
221                                let fname = &field_names[i];
222                                serialize_fields.push(quote! {
223                                    bytes.extend(::afastdata::AFastSerialize::to_bytes(#fname));
224                                });
225                            }
226                        }
227                        arms.push(quote! {
228                            #name::#variant_name(#(#field_patterns),*) => {
229                                bytes.extend((#i as #tag_ty).to_le_bytes());
230                                #(#serialize_fields)*
231                            }
232                        });
233                    }
234                    Fields::Named(fields) => {
235                        let all_field_names: Vec<_> = fields
236                            .named
237                            .iter()
238                            .map(|f| f.ident.as_ref().unwrap())
239                            .collect();
240                        let non_skip_names: Vec<_> = fields
241                            .named
242                            .iter()
243                            .filter(|f| !has_skip_attr(&f.attrs).0)
244                            .map(|f| f.ident.as_ref().unwrap())
245                            .collect();
246                        let has_skip = non_skip_names.len() < all_field_names.len();
247                        let mut serialize_fields = Vec::new();
248                        for fname in &non_skip_names {
249                            serialize_fields.push(quote! {
250                                bytes.extend(::afastdata::AFastSerialize::to_bytes(#fname));
251                            });
252                        }
253                        if has_skip {
254                            arms.push(quote! {
255                                #name::#variant_name { #(#non_skip_names),*, .. } => {
256                                    bytes.extend((#i as #tag_ty).to_le_bytes());
257                                    #(#serialize_fields)*
258                                }
259                            });
260                        } else {
261                            arms.push(quote! {
262                                #name::#variant_name { #(#non_skip_names),* } => {
263                                    bytes.extend((#i as #tag_ty).to_le_bytes());
264                                    #(#serialize_fields)*
265                                }
266                            });
267                        }
268                    }
269                }
270            }
271
272            quote! {
273                impl #impl_generics ::afastdata::AFastSerialize for #name #ty_generics {
274                    fn to_bytes(&self) -> Vec<u8> {
275                        let mut bytes = Vec::new();
276                        match self {
277                            #(#arms)*
278                        }
279                        bytes
280                    }
281                }
282            }
283        }
284        Data::Union(_) => panic!("AFastSerialize does not support unions"),
285    };
286
287    TokenStream::from(expanded)
288}
289
290/// 为结构体或枚举生成 [`AFastDeserialize`] trait 实现。
291///
292/// Generates an [`AFastDeserialize`] trait implementation for a struct or enum.
293///
294/// # 生成的代码 / Generated Code
295///
296/// ## 结构体 / Struct
297///
298/// 为每个字段依次调用 `from_bytes()`,并使用偏移量追踪已消耗的字节数,
299/// 最后构造结构体实例。
300///
301/// Calls `from_bytes()` on each field sequentially, using an offset to track
302/// consumed bytes, then constructs the struct instance.
303///
304/// ```text
305/// // 以下为生成代码的示意(非实际代码)
306/// // The following is an illustration of generated code (not actual code)
307/// impl AFastDeserialize for MyStruct {
308///     fn from_bytes(data: &[u8]) -> Result<(Self, usize), Error> {
309///         let mut offset: usize = 0;
310///         let (__val, __new_offset) = AFastDeserialize::from_bytes(&data[offset..])?;
311///         let field1 = __val;
312///         offset += __new_offset;
313///         // ... 更多字段 / more fields ...
314///         Ok((MyStruct { field1, ... }, offset))
315///     }
316/// }
317/// ```
318///
319/// ## 枚举 / Enum
320///
321/// 先读取 `u8` 变体索引,根据索引匹配对应变体,再反序列化该变体的字段。
322/// 可通过 feature 切换为 `u16` 或 `u32`。
323///
324/// First reads the `u8` variant index, matches the corresponding variant by index,
325/// then deserializes the variant's fields.
326/// Switchable to `u16` or `u32` via features.
327///
328/// ```text
329/// // 以下为生成代码的示意(非实际代码)
330/// // The following is an illustration of generated code (not actual code)
331/// impl AFastDeserialize for MyEnum {
332///     fn from_bytes(data: &[u8]) -> Result<(Self, usize), Error> {
333///         let mut offset: usize = 0;
334///         let (__tag, __new_offset) = <u8 as AFastDeserialize>::from_bytes(&data[offset..])?;
335///         offset += __new_offset;
336///         match __tag as usize {
337///             0 => Ok((MyEnum::Variant1, offset)),
338///             1 => {
339///                 let (__val, __new_offset) = AFastDeserialize::from_bytes(&data[offset..])?;
340///                 offset += __new_offset;
341///                 Ok((MyEnum::Variant2(__val), offset))
342///             }
343///             v => Err(format!("Unknown variant tag: {} for MyEnum", v)),
344///         }
345///     }
346/// }
347/// ```
348///
349/// # 泛型 / Generics
350///
351/// 如果目标类型包含泛型参数,生成的 `impl` 会同时添加 `AFastSerialize` 和
352/// `AFastDeserialize` 约束。双重约束确保泛型类型在序列化和反序列化两个方向
353/// 上都可用。
354///
355/// If the target type contains generic parameters, the generated `impl` adds both
356/// `AFastSerialize` and `AFastDeserialize` bounds. The dual bounds ensure the generic
357/// type is available in both serialization and deserialization directions.
358///
359/// # Panics
360///
361/// 对 union 类型使用此宏会触发编译 panic。
362///
363/// Using this macro on a union type will trigger a compile-time panic.
364#[proc_macro_derive(AFastDeserialize, attributes(afast))]
365pub fn derive_deserialize(input: TokenStream) -> TokenStream {
366    let input = parse_macro_input!(input as DeriveInput);
367    let name = &input.ident;
368    let generics = &input.generics;
369
370    // 为泛型参数添加 AFastSerialize + AFastDeserialize trait 约束
371    // Add AFastSerialize + AFastDeserialize trait bounds to generic parameters
372    let mut generics_with_bounds = generics.clone();
373    for param in &mut generics_with_bounds.params {
374        if let syn::GenericParam::Type(ref mut ty) = *param {
375            ty.bounds
376                .push(syn::parse_quote!(::afastdata::AFastSerialize));
377            ty.bounds
378                .push(syn::parse_quote!(::afastdata::AFastDeserialize));
379        }
380    }
381    let (impl_generics, _, _) = generics_with_bounds.split_for_impl();
382    let (_, ty_generics, _) = generics.split_for_impl();
383
384    let expanded = match &input.data {
385        Data::Struct(data) => {
386            let (construct, field_desers) =
387                generate_deserialize_fields(&data.fields, name, &ty_generics);
388            quote! {
389                impl #impl_generics ::afastdata::AFastDeserialize for #name #ty_generics {
390                    fn from_bytes(data: &[u8]) -> Result<(Self, usize), ::afastdata::Error> {
391                        let mut offset: usize = 0;
392                        #(#field_desers)*
393                        Ok((#construct, offset))
394                    }
395                }
396            }
397        }
398        Data::Enum(data) => {
399            let (tag_ty, _) = tag_type();
400            let mut arms = Vec::new();
401            for (i, variant) in data.variants.iter().enumerate() {
402                let variant_name = &variant.ident;
403
404                match &variant.fields {
405                    Fields::Unit => {
406                        arms.push(quote! {
407                            #i => {
408                                Ok((#name::#variant_name, offset))
409                            }
410                        });
411                    }
412                    Fields::Unnamed(fields) => {
413                        let mut field_desers = Vec::new();
414                        let mut field_names = Vec::new();
415                        for (i, f) in fields.unnamed.iter().enumerate() {
416                            let fname = syn::Ident::new(
417                                &format!("__f{}", i),
418                                variant_name.span(),
419                            );
420                            let ftype = &f.ty;
421                            let (skip, default_fn) = has_skip_attr(&f.attrs);
422                            if skip {
423                                if let Some(func_name) = default_fn {
424                                    match syn::parse_str::<syn::Ident>(&func_name) {
425                                        Ok(ident) => {
426                                            field_desers.push(quote! {
427                                                let #fname: #ftype = #ident();
428                                            });
429                                        }
430                                        Err(_) => {
431                                            field_desers.push(quote! {
432                                                compile_error!(concat!("invalid function name in skip: ", #func_name));
433                                            });
434                                        }
435                                    }
436                                } else {
437                                    field_desers.push(quote! {
438                                        let #fname: #ftype = <#ftype as ::std::default::Default>::default();
439                                    });
440                                }
441                            } else {
442                                let validates = parse_validations(&fname, ftype, &f.attrs);
443                                field_desers.push(quote! {
444                                    let (__val, __new_offset) = ::afastdata::AFastDeserialize::from_bytes(&data[offset..])?;
445                                    let #fname: #ftype = __val;
446                                    #(#validates)*
447                                    offset += __new_offset;
448                                });
449                            }
450                            field_names.push(fname);
451                        }
452                        arms.push(quote! {
453                            #i => {
454                                #(#field_desers)*
455                                Ok((#name::#variant_name(#(#field_names),*), offset))
456                            }
457                        });
458                    }
459                    Fields::Named(fields) => {
460                        let mut field_desers = Vec::new();
461                        let mut field_names = Vec::new();
462                        for f in &fields.named {
463                            let fname = f.ident.as_ref().unwrap();
464                            let ftype = &f.ty;
465                            let (skip, default_fn) = has_skip_attr(&f.attrs);
466                            if skip {
467                                if let Some(func_name) = default_fn {
468                                    match syn::parse_str::<syn::Ident>(&func_name) {
469                                        Ok(ident) => {
470                                            field_desers.push(quote! {
471                                                let #fname: #ftype = #ident();
472                                            });
473                                        }
474                                        Err(_) => {
475                                            field_desers.push(quote! {
476                                                compile_error!(concat!("invalid function name in skip: ", #func_name));
477                                            });
478                                        }
479                                    }
480                                } else {
481                                    field_desers.push(quote! {
482                                        let #fname: #ftype = <#ftype as ::std::default::Default>::default();
483                                    });
484                                }
485                            } else {
486                                let validates = parse_validations(fname, ftype, &f.attrs);
487                                field_desers.push(quote! {
488                                    let (__val, __new_offset) = ::afastdata::AFastDeserialize::from_bytes(&data[offset..])?;
489                                    let #fname: #ftype = __val;
490                                    #(#validates)*
491                                    offset += __new_offset;
492                                });
493                            }
494                            field_names.push(fname);
495                        }
496                        arms.push(quote! {
497                            #i => {
498                                #(#field_desers)*
499                                Ok((#name::#variant_name { #(#field_names),* }, offset))
500                            }
501                        });
502                    }
503                }
504            }
505
506            quote! {
507                impl #impl_generics ::afastdata::AFastDeserialize for #name #ty_generics {
508                    fn from_bytes(data: &[u8]) -> Result<(Self, usize), ::afastdata::Error> {
509                        let mut offset: usize = 0;
510                        let (__tag_bytes, __new_offset) = <#tag_ty as ::afastdata::AFastDeserialize>::from_bytes(&data[offset..])?;
511                        offset += __new_offset;
512                        match __tag_bytes as usize {
513                            #(#arms)*
514                            v => Err(::afastdata::Error::deserialize(format!("Unknown variant tag: {} for {}", v, ::std::stringify!(#name)))),
515                        }
516                    }
517                }
518            }
519        }
520        Data::Union(_) => panic!("AFastDeserialize does not support unions"),
521    };
522
523    TokenStream::from(expanded)
524}
525
526/// 为结构体的字段生成序列化代码。内部辅助函数。
527///
528/// Generates serialization code for struct fields. Internal helper.
529///
530/// # 参数 / Parameters
531///
532/// - `fields`:结构体的字段定义 / The struct's field definitions
533/// - `self_prefix`:访问字段时使用的前缀(如 `self` 或变体解构变量)
534///   / The prefix used to access fields (e.g., `self` or a variant destructure variable)
535///
536/// # 返回值 / Returns
537///
538/// 返回一个 `TokenStream` 列表,每个元素对应一个字段的序列化语句。
539///
540/// Returns a list of `TokenStream`s, each corresponding to a serialization statement
541/// for one field.
542///
543/// # 生成格式 / Generated Format
544///
545/// - **命名字段 (Named)**:`bytes.extend(AFastSerialize::to_bytes(&self.field_name));`
546/// - **元组字段 (Unnamed)**:`bytes.extend(AFastSerialize::to_bytes(&self.0));`
547/// - **单元字段 (Unit)**:不生成任何代码 / Generates no code
548fn generate_serialize_fields(
549    fields: &Fields,
550    self_prefix: proc_macro2::TokenStream,
551) -> Vec<proc_macro2::TokenStream> {
552    match fields {
553        Fields::Named(named) => named
554            .named
555            .iter()
556            .filter(|f| !has_skip_attr(&f.attrs).0)
557            .map(|f| {
558                let fname = f.ident.as_ref().unwrap();
559                quote! {
560                    bytes.extend(::afastdata::AFastSerialize::to_bytes(&#self_prefix.#fname));
561                }
562            })
563            .collect(),
564        Fields::Unnamed(unnamed) => unnamed
565            .unnamed
566            .iter()
567            .enumerate()
568            .filter(|(_, f)| !has_skip_attr(&f.attrs).0)
569            .map(|(i, _)| {
570                let idx = Index::from(i);
571                quote! {
572                    bytes.extend(::afastdata::AFastSerialize::to_bytes(&#self_prefix.#idx));
573                }
574            })
575            .collect(),
576        Fields::Unit => vec![],
577    }
578}
579
580fn has_skip_attr(attrs: &[Attribute]) -> (bool, Option<String>) {
581    for attr in attrs {
582        if attr.path().is_ident("afast")
583            && let Ok(nested) =
584                attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
585        {
586            for meta in nested {
587                match meta {
588                    Meta::Path(path) if path.is_ident("skip") => {
589                        return (true, None);
590                    }
591                    Meta::List(meta_list) if meta_list.path.is_ident("skip") => {
592                        if let Ok(lit_str) = syn::parse2::<LitStr>(meta_list.tokens.clone()) {
593                            return (true, Some(lit_str.value()));
594                        } else {
595                            return (true, None);
596                        }
597                    }
598                    _ => {}
599                }
600            }
601        }
602    }
603    (false, None)
604}
605
606struct Range {
607    int: LitInt,
608    _comma1: Token![,],
609    code: LitInt,
610    _comma2: Token![,],
611    msg: LitStr,
612}
613
614impl Parse for Range {
615    fn parse(input: ParseStream) -> syn::Result<Self> {
616        Ok(Range {
617            int: input.parse()?,
618            _comma1: input.parse()?,
619            code: input.parse()?,
620            _comma2: input.parse()?,
621            msg: input.parse()?,
622        })
623    }
624}
625
626struct Length {
627    min: LitInt,
628    _comma1: Token![,],
629    max: LitInt,
630    _comma2: Token![,],
631    code: LitInt,
632    _comma3: Token![,],
633    msg: LitStr,
634}
635
636impl Parse for Length {
637    fn parse(input: ParseStream) -> syn::Result<Self> {
638        Ok(Length {
639            min: input.parse()?,
640            _comma1: input.parse()?,
641            max: input.parse()?,
642            _comma2: input.parse()?,
643            code: input.parse()?,
644            _comma3: input.parse()?,
645            msg: input.parse()?,
646        })
647    }
648}
649
650#[derive(Clone)]
651enum ValidateValue {
652    Int(i64),
653    Float(f64),
654    Bool(bool),
655    Str(String),
656}
657
658impl ValidateValue {
659    /// 将值转换为代码生成中使用的 TokenStream
660    ///
661    /// 例如:
662    /// ValidateValue::Int(42) → quote! { 42 }
663    /// ValidateValue::Str("hello") → quote! { "hello" }
664    fn to_token_stream(&self) -> proc_macro2::TokenStream {
665        match self {
666            ValidateValue::Int(v) => quote! { #v },
667
668            ValidateValue::Float(v) => {
669                // 浮点数通过字符串解析来保持精度
670                let v_str = v.to_string();
671                v_str.parse().unwrap_or_else(|_| {
672                    // 如果字符串解析失败,使用直接值
673                    quote! { #v }
674                })
675            }
676
677            ValidateValue::Bool(v) => quote! { #v },
678            ValidateValue::Str(v) => quote! { #v },
679        }
680    }
681}
682
683struct OfValidator {
684    allowed_values: Vec<ValidateValue>,
685    code: syn::LitInt,
686    msg: syn::LitStr,
687}
688
689impl Parse for OfValidator {
690    fn parse(input: ParseStream) -> syn::Result<Self> {
691        let content;
692        syn::bracketed!(content in input);
693
694        let mut values = Vec::new();
695
696        if !content.is_empty() {
697            loop {
698                let lit = content.parse::<Lit>()?;
699
700                let value = match lit {
701                    Lit::Int(lit_int) => {
702                        let int_value: i64 = lit_int.base10_parse()?;
703                        ValidateValue::Int(int_value)
704                    }
705
706                    Lit::Float(lit_float) => {
707                        let float_value: f64 = lit_float.base10_parse()?;
708                        ValidateValue::Float(float_value)
709                    }
710
711                    Lit::Bool(lit_bool) => ValidateValue::Bool(lit_bool.value),
712
713                    Lit::Str(lit_str) => ValidateValue::Str(lit_str.value()),
714
715                    _ => {
716                        return Err(syn::Error::new_spanned(
717                            &lit,
718                            "unsupported literal type in 'of' validator; \
719                             only int, float, bool, and str literals are supported",
720                        ));
721                    }
722                };
723
724                values.push(value);
725
726                if !content.peek(Token![,]) {
727                    break;
728                }
729
730                content.parse::<Token![,]>()?;
731
732                if content.is_empty() {
733                    break;
734                }
735            }
736        }
737
738        input.parse::<Token![,]>()?;
739
740        let code = input.parse::<syn::LitInt>()?;
741
742        input.parse::<Token![,]>()?;
743
744        let msg = input.parse::<syn::LitStr>()?;
745
746        Ok(OfValidator {
747            allowed_values: values,
748            code,
749            msg,
750        })
751    }
752}
753
754/// 解析字段上的 `#[afast(...)]` 校验属性,生成校验代码。
755///
756/// Parses `#[afast(...)]` validation attributes on a field and generates
757/// validation code blocks.
758///
759/// # 参数 / Parameters
760///
761/// - `field_name`:字段名 / Field name
762/// - `field_type`:字段类型 / Field type
763/// - `attrs`:字段的属性列表 / Field attributes
764///
765/// # 返回值 / Returns
766///
767/// 返回校验语句的 `TokenStream` 列表。
768///
769/// Returns a list of validation `TokenStream` blocks.
770fn parse_validations(
771    field_name: &syn::Ident,
772    field_type: &Type,
773    attrs: &[Attribute],
774) -> Vec<proc_macro2::TokenStream> {
775    let mut validates = Vec::new();
776    for attr in attrs {
777        if attr.path().is_ident("afast") {
778            let nested = match attr
779                .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
780            {
781                Ok(n) => n,
782                Err(e) => {
783                    validates.push(e.to_compile_error());
784                    continue;
785                }
786            };
787            for meta in nested {
788                if let Meta::List(meta) = meta {
789                        if meta.path.is_ident("gt") {
790                            let inner = match meta.parse_args::<Range>() {
791                                Ok(v) => v,
792                                Err(e) => {
793                                    validates.push(e.to_compile_error());
794                                    continue;
795                                }
796                            };
797                            let gt_value = match inner.int.base10_parse::<i64>() {
798                                Ok(v) => v,
799                                Err(e) => {
800                                    validates.push(
801                                        syn::Error::new_spanned(&inner.int, format!("invalid integer value: {}", e))
802                                            .to_compile_error(),
803                                    );
804                                    continue;
805                                }
806                            };
807                            let code = match inner.code.base10_parse::<i64>() {
808                                Ok(v) => v,
809                                Err(e) => {
810                                    validates.push(
811                                        syn::Error::new_spanned(&inner.code, format!("invalid error code: {}", e))
812                                            .to_compile_error(),
813                                    );
814                                    continue;
815                                }
816                            };
817                            let err_msg = inner
818                                .msg
819                                .value()
820                                .replace("${field}", &field_name.to_string());
821                            validates.push(quote! {
822                                if #field_name <= #gt_value {
823                                    return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
824                                }
825                            });
826                        } else if meta.path.is_ident("gte") {
827                            let inner = match meta.parse_args::<Range>() {
828                                Ok(v) => v,
829                                Err(e) => {
830                                    validates.push(e.to_compile_error());
831                                    continue;
832                                }
833                            };
834                            let gt_value = match inner.int.base10_parse::<i64>() {
835                                Ok(v) => v,
836                                Err(e) => {
837                                    validates.push(
838                                        syn::Error::new_spanned(&inner.int, format!("invalid integer value: {}", e))
839                                            .to_compile_error(),
840                                    );
841                                    continue;
842                                }
843                            };
844                            let code = match inner.code.base10_parse::<i64>() {
845                                Ok(v) => v,
846                                Err(e) => {
847                                    validates.push(
848                                        syn::Error::new_spanned(&inner.code, format!("invalid error code: {}", e))
849                                            .to_compile_error(),
850                                    );
851                                    continue;
852                                }
853                            };
854                            let err_msg = inner
855                                .msg
856                                .value()
857                                .replace("${field}", &field_name.to_string());
858                            validates.push(quote! {
859                                if #field_name < #gt_value {
860                                    return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
861                                }
862                            });
863                        } else if meta.path.is_ident("lt") {
864                            let inner = match meta.parse_args::<Range>() {
865                                Ok(v) => v,
866                                Err(e) => {
867                                    validates.push(e.to_compile_error());
868                                    continue;
869                                }
870                            };
871                            let lt_value = match inner.int.base10_parse::<i64>() {
872                                Ok(v) => v,
873                                Err(e) => {
874                                    validates.push(
875                                        syn::Error::new_spanned(&inner.int, format!("invalid integer value: {}", e))
876                                            .to_compile_error(),
877                                    );
878                                    continue;
879                                }
880                            };
881                            let code = match inner.code.base10_parse::<i64>() {
882                                Ok(v) => v,
883                                Err(e) => {
884                                    validates.push(
885                                        syn::Error::new_spanned(&inner.code, format!("invalid error code: {}", e))
886                                            .to_compile_error(),
887                                    );
888                                    continue;
889                                }
890                            };
891                            let err_msg = inner
892                                .msg
893                                .value()
894                                .replace("${field}", &field_name.to_string());
895                            validates.push(quote! {
896                                if #field_name >= #lt_value {
897                                    return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
898                                }
899                            });
900                        } else if meta.path.is_ident("lte") {
901                            let inner = match meta.parse_args::<Range>() {
902                                Ok(v) => v,
903                                Err(e) => {
904                                    validates.push(e.to_compile_error());
905                                    continue;
906                                }
907                            };
908                            let lt_value = match inner.int.base10_parse::<i64>() {
909                                Ok(v) => v,
910                                Err(e) => {
911                                    validates.push(
912                                        syn::Error::new_spanned(&inner.int, format!("invalid integer value: {}", e))
913                                            .to_compile_error(),
914                                    );
915                                    continue;
916                                }
917                            };
918                            let code = match inner.code.base10_parse::<i64>() {
919                                Ok(v) => v,
920                                Err(e) => {
921                                    validates.push(
922                                        syn::Error::new_spanned(&inner.code, format!("invalid error code: {}", e))
923                                            .to_compile_error(),
924                                    );
925                                    continue;
926                                }
927                            };
928                            let err_msg = inner
929                                .msg
930                                .value()
931                                .replace("${field}", &field_name.to_string());
932                            validates.push(quote! {
933                                if #field_name > #lt_value {
934                                    return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
935                                }
936                            });
937                        } else if meta.path.is_ident("len") {
938                            let field_is_option = match field_type {
939                                Type::Path(TypePath {
940                                    path: Path { segments, .. },
941                                    ..
942                                }) => {
943                                    segments.len() == 1 && segments[0].ident == "Option"
944                                }
945                                _ => false,
946                            };
947
948                            let inner = match meta.parse_args::<Length>() {
949                                Ok(v) => v,
950                                Err(e) => {
951                                    validates.push(e.to_compile_error());
952                                    continue;
953                                }
954                            };
955                            let min_value = match inner.min.base10_parse::<i64>() {
956                                Ok(v) => v,
957                                Err(e) => {
958                                    validates.push(
959                                        syn::Error::new_spanned(&inner.min, format!("invalid min value: {}", e))
960                                            .to_compile_error(),
961                                    );
962                                    continue;
963                                }
964                            };
965                            let max_value = match inner.max.base10_parse::<i64>() {
966                                Ok(v) => v,
967                                Err(e) => {
968                                    validates.push(
969                                        syn::Error::new_spanned(&inner.max, format!("invalid max value: {}", e))
970                                            .to_compile_error(),
971                                    );
972                                    continue;
973                                }
974                            };
975                            let code = match inner.code.base10_parse::<i64>() {
976                                Ok(v) => v,
977                                Err(e) => {
978                                    validates.push(
979                                        syn::Error::new_spanned(&inner.code, format!("invalid error code: {}", e))
980                                            .to_compile_error(),
981                                    );
982                                    continue;
983                                }
984                            };
985                            let err_msg = inner
986                                .msg
987                                .value()
988                                .replace("${field}", &field_name.to_string());
989                            if min_value > max_value {
990                                validates.push(
991                                    syn::Error::new_spanned(
992                                        &meta.path,
993                                        format!(
994                                            "invalid len validation: min ({}) > max ({}) for field `{}`",
995                                            min_value, max_value, field_name
996                                        ),
997                                    )
998                                    .to_compile_error(),
999                                );
1000                                continue;
1001                            }
1002                            if min_value < 0 && max_value < 0 {
1003                                validates.push(
1004                                    syn::Error::new_spanned(
1005                                        &meta.path,
1006                                        format!(
1007                                            "invalid len validation: both min and max are negative for field `{}`",
1008                                            field_name
1009                                        ),
1010                                    )
1011                                    .to_compile_error(),
1012                                );
1013                                continue;
1014                            } else if min_value < 0 {
1015                                let max: usize = match max_value.try_into() {
1016                                    Ok(v) => v,
1017                                    Err(_) => {
1018                                        validates.push(
1019                                            syn::Error::new_spanned(&inner.max, "value too large for usize")
1020                                                .to_compile_error(),
1021                                        );
1022                                        continue;
1023                                    }
1024                                };
1025                                validates.push(quote! {
1026                                    if #field_name.len() > #max {
1027                                        return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
1028                                    }
1029                                });
1030                            } else if max_value < 0 {
1031                                let min: usize = match min_value.try_into() {
1032                                    Ok(v) => v,
1033                                    Err(_) => {
1034                                        validates.push(
1035                                            syn::Error::new_spanned(&inner.min, "value too large for usize")
1036                                                .to_compile_error(),
1037                                        );
1038                                        continue;
1039                                    }
1040                                };
1041                                validates.push(quote! {
1042                                    if #field_name.len() < #min {
1043                                        return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
1044                                    }
1045                                });
1046                            } else {
1047                                let min: usize = match min_value.try_into() {
1048                                    Ok(v) => v,
1049                                    Err(_) => {
1050                                        validates.push(
1051                                            syn::Error::new_spanned(&inner.min, "value too large for usize")
1052                                                .to_compile_error(),
1053                                        );
1054                                        continue;
1055                                    }
1056                                };
1057                                let max: usize = match max_value.try_into() {
1058                                    Ok(v) => v,
1059                                    Err(_) => {
1060                                        validates.push(
1061                                            syn::Error::new_spanned(&inner.max, "value too large for usize")
1062                                                .to_compile_error(),
1063                                        );
1064                                        continue;
1065                                    }
1066                                };
1067                                if field_is_option {
1068                                    validates.push(quote! {
1069                                        let length = match &#field_name {
1070                                            Some(s) => {
1071                                                let __length = s.len();
1072                                                if __length < #min || __length > #max {
1073                                                    return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
1074                                                }
1075                                            },
1076                                            None => {},
1077                                        };
1078                                    });
1079                                } else {
1080                                    validates.push(quote! {
1081                                        if #field_name.len() < #min || #field_name.len() > #max {
1082                                            return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
1083                                        }
1084                                    });
1085                                }
1086                            }
1087                        } else if meta.path.is_ident("of") {
1088                            let inner = match meta.parse_args::<OfValidator>() {
1089                                Ok(v) => v,
1090                                Err(e) => {
1091                                    validates.push(e.to_compile_error());
1092                                    continue;
1093                                }
1094                            };
1095                            let allowed_values = inner.allowed_values.clone();
1096                            let code = match inner.code.base10_parse::<i64>() {
1097                                Ok(v) => v,
1098                                Err(e) => {
1099                                    validates.push(
1100                                        syn::Error::new_spanned(&inner.code, format!("invalid error code: {}", e))
1101                                            .to_compile_error(),
1102                                    );
1103                                    continue;
1104                                }
1105                            };
1106                            let err_msg = inner
1107                                .msg
1108                                .value()
1109                                .replace("${field}", &field_name.to_string());
1110                            let values_tokens: Vec<_> = allowed_values
1111                                .iter()
1112                                .map(|v| v.to_token_stream())
1113                                .collect();
1114                            validates.push(quote! {
1115                                if !matches!(#field_name, #(#values_tokens)|*) {
1116                                    return Err(::afastdata::Error::validate(#code, #err_msg.to_string()));
1117                                }
1118                            });
1119                        } else if meta.path.is_ident("func") {
1120                            let inner = match meta.parse_args::<LitStr>() {
1121                                Ok(v) => v,
1122                                Err(e) => {
1123                                    validates.push(e.to_compile_error());
1124                                    continue;
1125                                }
1126                            };
1127                            let ident = match syn::parse_str::<syn::Ident>(&inner.value()) {
1128                                Ok(v) => v,
1129                                Err(e) => {
1130                                    validates.push(
1131                                        syn::Error::new_spanned(
1132                                            &inner,
1133                                            format!("invalid function name `{}`: {}", inner.value(), e),
1134                                        )
1135                                        .to_compile_error(),
1136                                    );
1137                                    continue;
1138                                }
1139                            };
1140                            let field = field_name.to_string();
1141                            validates.push(quote! {
1142                                match #ident(&#field_name, #field) {
1143                                    Ok(()) => {},
1144                                    Err(e) => return Err(e.to_afastdata_error()),
1145                                }
1146                            });
1147                        }
1148                }
1149            }
1150        }
1151    }
1152    validates
1153}
1154
1155/// 为结构体的字段生成反序列化代码以及构造表达式。内部辅助函数。
1156///
1157/// Generates deserialization code for struct fields along with the construction
1158/// expression. Internal helper.
1159///
1160/// # 参数 / Parameters
1161///
1162/// - `fields`:结构体的字段定义 / The struct's field definitions
1163/// - `name`:结构体类型的标识符 / The struct type's identifier
1164/// - `ty_generics`:类型的泛型参数(用于构造时的 turbofish 语法)
1165///   / The type's generic parameters (used for turbofish syntax during construction)
1166///
1167/// # 返回值 / Returns
1168///
1169/// 返回 `(构造表达式, 反序列化语句列表)`:
1170/// - 构造表达式:用于创建结构体实例的 `TokenStream`
1171/// - 反序列化语句:每个字段的 `from_bytes()` 调用和偏移量更新
1172///
1173/// Returns `(construction_expression, deserialization_statements)`:
1174/// - Construction expression: A `TokenStream` for creating the struct instance
1175/// - Deserialization statements: `from_bytes()` calls and offset updates for each field
1176///
1177/// # 泛型构造 / Generic Construction
1178///
1179/// 使用 `as_turbofish()` 生成正确的泛型语法。例如 `MyStruct::<T>` 而非
1180/// `MyStruct <T>`(后者会被解析为比较操作)。
1181///
1182/// Uses `as_turbofish()` to generate correct generic syntax. For example,
1183/// `MyStruct::<T>` instead of `MyStruct <T>` (which would be parsed as a
1184/// comparison operation).
1185fn generate_deserialize_fields(
1186    fields: &Fields,
1187    name: &syn::Ident,
1188    ty_generics: &syn::TypeGenerics,
1189) -> (proc_macro2::TokenStream, Vec<proc_macro2::TokenStream>) {
1190    // 在表达式上下文中使用 turbofish 语法:Name::<T>
1191    // In expression context, use turbofish syntax: Name::<T>
1192    let ty_params = ty_generics.as_turbofish();
1193    match fields {
1194        Fields::Named(named) => {
1195            let mut desers = Vec::new();
1196            let mut field_names = Vec::new();
1197            for f in &named.named {
1198                let fname = f.ident.as_ref().unwrap();
1199                let ftype = &f.ty;
1200                field_names.push(fname.clone());
1201
1202                let validates = parse_validations(fname, ftype, &f.attrs);
1203                let (skip, default) = has_skip_attr(&f.attrs);
1204                if skip {
1205                    if let Some(default) = default {
1206                        match syn::parse_str::<syn::Ident>(&default) {
1207                            Ok(ident) => {
1208                                desers.push(quote! {
1209                                    let #fname: #ftype = #ident();
1210                                });
1211                            }
1212                            Err(_) => {
1213                                desers.push(quote! {
1214                                    compile_error!(concat!("invalid function name in skip: ", #default));
1215                                });
1216                            }
1217                        }
1218                    } else {
1219                        desers.push(quote! {
1220                            let #fname: #ftype = #ftype::default();
1221                        });
1222                    }
1223                } else {
1224                    desers.push(quote! {
1225                        let (__val, __new_offset) = ::afastdata::AFastDeserialize::from_bytes(&data[offset..])?;
1226                        let #fname: #ftype = __val;
1227                        #(#validates)*
1228                        offset += __new_offset;
1229                    });
1230                }
1231            }
1232            let construct = quote! {
1233                #name #ty_params { #(#field_names),* }
1234            };
1235            (construct, desers)
1236        }
1237        Fields::Unnamed(unnamed) => {
1238            let mut desers = Vec::new();
1239            let mut field_names = Vec::new();
1240            for (i, f) in unnamed.unnamed.iter().enumerate() {
1241                let fname = syn::Ident::new(&format!("__f{}", i), name.span());
1242                let ftype = &f.ty;
1243                let (skip, default_fn) = has_skip_attr(&f.attrs);
1244                if skip {
1245                    if let Some(func_name) = default_fn {
1246                        match syn::parse_str::<syn::Ident>(&func_name) {
1247                            Ok(ident) => {
1248                                desers.push(quote! {
1249                                    let #fname: #ftype = #ident();
1250                                });
1251                            }
1252                            Err(_) => {
1253                                desers.push(quote! {
1254                                    compile_error!(concat!("invalid function name in skip: ", #func_name));
1255                                });
1256                            }
1257                        }
1258                    } else {
1259                        desers.push(quote! {
1260                            let #fname: #ftype = <#ftype as ::std::default::Default>::default();
1261                        });
1262                    }
1263                } else {
1264                    let validates = parse_validations(&fname, ftype, &f.attrs);
1265                    desers.push(quote! {
1266                        let (__val, __new_offset) = ::afastdata::AFastDeserialize::from_bytes(&data[offset..])?;
1267                        let #fname: #ftype = __val;
1268                        #(#validates)*
1269                        offset += __new_offset;
1270                    });
1271                }
1272                field_names.push(fname);
1273            }
1274            let construct = quote! {
1275                #name #ty_params ( #(#field_names),* )
1276            };
1277            (construct, desers)
1278        }
1279        Fields::Unit => {
1280            let construct = quote! { #name #ty_params };
1281            (construct, vec![])
1282        }
1283    }
1284}