binary_mirror_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields, LitStr};
4
5#[proc_macro_derive(BinaryMirror, attributes(bm))]
6pub fn binary_mirror_derive(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    impl_binary_mirror(&input)
9}
10
11#[proc_macro_derive(BinaryEnum, attributes(bv))]
12pub fn binary_enum_derive(input: TokenStream) -> TokenStream {
13    let input = parse_macro_input!(input as DeriveInput);
14    impl_binary_enum(&input)
15}
16
/// Options parsed from a field's `#[bm(...)]` attribute(s).
///
/// `Default` yields the "nothing specified" state: an empty `type_name`, every flag `false`
/// and every optional setting `None`.
#[derive(Debug, Clone, Default)]
struct FieldAttrs {
    /// `type = "..."`: how the raw bytes are decoded (`str`, `i32`, `decimal`, `date`,
    /// `enum`, `bytes`, ...). Empty when no `type` was given.
    type_name: String,
    /// `alias = "..."`: name used for the native field and its accessors instead of the
    /// origin field's name.
    alias: Option<String>,
    /// `format = "..."`: conversion format (e.g. a chrono pattern for `date`/`time`/
    /// `datetime`); a per-type default is used when absent.
    format: Option<String>,
    /// `datetime_with = "..."`: on a `date` or `time` field, names the partner field; the
    /// two are combined into one `chrono::NaiveDateTime` native field.
    datetime_with: Option<String>,
    /// `skip = true`: no native field and no accessors are generated for this field.
    skip: bool,
    /// `skip_native = true`: accessors are generated, but the field is left out of the
    /// native struct and filled with `default_byte` when converting back.
    skip_native: bool,
    /// `enum_type = "..."`: name of the enum type; required when `type = "enum"`.
    enum_type: Option<String>,
    /// `default_byte = b'.'`: padding byte for the raw array (`b' '` when absent).
    default_byte: Option<u8>,
    /// `ignore_warn = true`: build the native struct from the non-logging accessor.
    ignore_warn: bool,
    /// `default_func = "..."`: name of a default-value function, copied onto the native
    /// field; NOTE(review): its use lives outside this section — confirm semantics.
    default_func: Option<String>,
}
30
31#[derive(Debug, Clone)]
32struct OriginField {
33    name: syn::Ident,
34    size: usize,
35    attrs: Option<FieldAttrs>,
36}
37
38#[derive(Debug, Clone)]
39struct NativeField {
40    name: syn::Ident,
41    ty: proc_macro2::TokenStream,
42    type_name: String,
43    pure_ty: proc_macro2::TokenStream,
44    origin_fields: Vec<OriginField>,
45    is_combined_datetime: bool,
46    default_func: Option<String>,
47    skip_native: bool,
48}
49
50#[derive(Debug)]
51struct NativeField2OriginFieldMap {
52    origin_field: OriginField,
53    native_field: Option<NativeField>,
54}
55
56#[derive(Debug, Clone)]
57struct StructAttrs {
58    derives: Vec<syn::Path>,
59}
60
61fn get_struct_attrs(input: &DeriveInput) -> StructAttrs {
62    let attrs = &input.attrs;
63    let mut struct_attrs = StructAttrs { derives: vec![] };
64    for attr in attrs {
65        if attr.path().is_ident("bm") {
66            let _ = attr.parse_nested_meta(|meta| {
67                if meta.path.is_ident("derive") {
68                    let content;
69                    syn::parenthesized!(content in meta.input);
70                    let derives: syn::punctuated::Punctuated<syn::Path, syn::Token![,]> = content
71                        .parse_terminated(syn::parse::Parse::parse, syn::Token![,])
72                        .expect("derive");
73                    struct_attrs.derives = derives.into_iter().collect();
74                }
75                Ok(())
76            });
77        }
78    }
79    struct_attrs
80}
81
82fn get_field_attrs(attrs: &[syn::Attribute]) -> Option<FieldAttrs> {
83    for attr in attrs {
84        if attr.path().is_ident("bm") {
85            let mut field_attrs = FieldAttrs {
86                type_name: String::new(),
87                alias: None,
88                format: None,
89                datetime_with: None,
90                skip: false,
91                skip_native: false,
92                enum_type: None,
93                default_byte: None,
94                ignore_warn: false,
95                default_func: None,
96            };
97
98            let _ = attr.parse_nested_meta(|meta| {
99                if meta.path.is_ident("type") {
100                    let lit = meta.value()?.parse::<LitStr>()?;
101                    field_attrs.type_name = lit.value();
102                } else if meta.path.is_ident("alias") {
103                    let lit = meta.value()?.parse::<LitStr>()?;
104                    field_attrs.alias = Some(lit.value());
105                } else if meta.path.is_ident("format") {
106                    let lit = meta.value()?.parse::<LitStr>()?;
107                    field_attrs.format = Some(lit.value());
108                } else if meta.path.is_ident("datetime_with") {
109                    let lit = meta.value()?.parse::<LitStr>()?;
110                    field_attrs.datetime_with = Some(lit.value());
111                } else if meta.path.is_ident("skip") {
112                    field_attrs.skip = meta.value()?.parse::<syn::LitBool>()?.value();
113                } else if meta.path.is_ident("skip_native") {
114                    field_attrs.skip_native = meta.value()?.parse::<syn::LitBool>()?.value();
115                } else if meta.path.is_ident("enum_type") {
116                    let lit = meta.value()?.parse::<LitStr>()?;
117                    field_attrs.enum_type = Some(lit.value());
118                } else if meta.path.is_ident("default_byte") {
119                    let lit = meta.value()?.parse::<syn::LitByte>()?;
120                    field_attrs.default_byte = Some(lit.value());
121                } else if meta.path.is_ident("ignore_warn") {
122                    field_attrs.ignore_warn = meta.value()?.parse::<syn::LitBool>()?.value();
123                } else if meta.path.is_ident("default_func") {
124                    let lit = meta.value()?.parse::<syn::LitStr>()?;
125                    field_attrs.default_func = Some(lit.value());
126                }
127                Ok(())
128            });
129
130            if !field_attrs.type_name.is_empty() {
131                return Some(field_attrs);
132            }
133        }
134    }
135    None
136}
137
138fn get_origin_fields(input: &DeriveInput) -> Vec<OriginField> {
139    let fields = match &input.data {
140        Data::Struct(data) => match &data.fields {
141            Fields::Named(fields) => &fields.named,
142            _ => panic!("Only named fields are supported"),
143        },
144        _ => panic!("Only structs are supported"),
145    };
146
147    fields
148        .iter()
149        .map(|field| {
150            let name = field.ident.clone().unwrap();
151
152            // Check if field is [u8] array and get size
153            let size = if let syn::Type::Array(array) = &field.ty {
154                if let syn::Expr::Lit(syn::ExprLit {
155                    lit: syn::Lit::Int(ref lit_int),
156                    ..
157                }) = array.len
158                {
159                    lit_int
160                        .base10_parse::<usize>()
161                        .expect("Could not parse array length")
162                } else {
163                    panic!("Field {} array length must be a literal integer", name);
164                }
165            } else {
166                panic!("Field {} must be a [u8] array", name);
167            };
168
169            OriginField {
170                name,
171                size,
172                attrs: get_field_attrs(&field.attrs),
173            }
174        })
175        .collect()
176}
177
178fn get_native_fields_and_map(origin_fields: &[OriginField]) -> (Vec<NativeField>, Vec<NativeField2OriginFieldMap>) {
179    let mut native_fields = Vec::new();
180    let mut native_field_map = Vec::new();
181    let mut processed = std::collections::HashSet::new();
182
183    for field in origin_fields {
184        if let Some(attrs) = &field.attrs {
185            // Skip if this field has already been processed
186            if processed.contains(&field.name.to_string()) {
187                continue;
188            }
189
190            let field_name = if let Some(alias) = &attrs.alias {
191                quote::format_ident!("{}", alias)
192            } else {
193                field.name.clone()
194            };
195
196            match attrs.type_name.as_str() {
197                "date" | "time" if attrs.datetime_with.is_some() => {
198                    let other_field_name = attrs.datetime_with.as_ref().unwrap();
199                    let other_field = origin_fields
200                        .iter()
201                        .find(|f| f.name == quote::format_ident!("{}", other_field_name))
202                        .expect("Could not find datetime pair field");
203
204                    // Mark both fields as processed
205                    processed.insert(field.name.to_string());
206                    processed.insert(other_field.name.to_string());
207
208                    // Determine which is date and which is time
209                    let (date_field, time_field) = if attrs.type_name == "date" {
210                        (field, other_field)
211                    } else {
212                        (other_field, field)
213                    };
214
215                    let native_field = NativeField {
216                        name: field_name,
217                        ty: quote!(Option<chrono::NaiveDateTime>),
218                        type_name: "datetime".to_string(),
219                        pure_ty: quote!(chrono::NaiveDateTime),
220                        origin_fields: vec![date_field.clone(), time_field.clone()],
221                        is_combined_datetime: true,
222                        default_func: attrs.default_func.clone(),
223                        skip_native: attrs.skip_native,
224                    };
225
226                    native_fields.push(native_field.clone());
227                    native_field_map.push(NativeField2OriginFieldMap {
228                        origin_field: field.clone(),
229                        native_field: Some(native_field.clone()),
230                    });
231                    native_field_map.push(NativeField2OriginFieldMap {
232                        origin_field: other_field.clone(),
233                        native_field: Some(native_field),
234                    });
235                }
236                _ => {
237                    let (ty, pure_ty) = match attrs.type_name.as_str() {
238                        "str" => (quote!(Option<String>), quote!(String)),
239                        "compact_str" => (
240                            quote!(Option<compact_str::CompactString>), 
241                            quote!(compact_str::CompactString)
242                        ),
243                        // "hipstr" => (
244                        //     quote!(hipstr::HipStr<'borrow>),
245                        //     quote!(hipstr::HipStr<'borrow>)
246                        // ),
247                        "bytes" => {
248                            let size = field.size;
249                            (quote!([u8; #size]), quote!([u8; #size]))
250                        }
251                        "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" => {
252                            let type_ident = quote::format_ident!("{}", attrs.type_name);
253                            (quote!(Option<#type_ident>), quote!(#type_ident))
254                        }
255                        "decimal" => (
256                            quote!(Option<rust_decimal::Decimal>),
257                            quote!(rust_decimal::Decimal),
258                        ),
259                        "datetime" => (
260                            quote!(Option<chrono::NaiveDateTime>),
261                            quote!(chrono::NaiveDateTime),
262                        ),
263                        "date" => (quote!(Option<chrono::NaiveDate>), quote!(chrono::NaiveDate)),
264                        "time" => (quote!(Option<chrono::NaiveTime>), quote!(chrono::NaiveTime)),
265                        "enum" => {
266                            let enum_type = attrs.enum_type.as_ref();
267                            match enum_type {
268                                Some(enum_type) => {
269                                    let enum_ident = quote::format_ident!("{}", enum_type);
270                                    (quote!(Option<#enum_ident>), quote!(#enum_ident))
271                                }
272                                None => panic!("enum_type is required for enum field"),
273                            }
274                        }
275                        _ => continue,
276                    };
277                    let native_field = NativeField {
278                        name: field_name,
279                        ty,
280                        type_name: attrs.type_name.clone(),
281                        pure_ty,
282                        origin_fields: vec![field.clone()],
283                        is_combined_datetime: false,
284                        default_func: attrs.default_func.clone(),
285                        skip_native: attrs.skip_native,
286                    };
287                    if !attrs.skip {
288                        native_fields.push(native_field.clone());
289                        native_field_map.push(NativeField2OriginFieldMap {
290                            origin_field: field.clone(),
291                            native_field: Some(native_field),
292                        });
293                    } else {
294                        native_field_map.push(NativeField2OriginFieldMap {
295                            origin_field: field.clone(),
296                            native_field: None,
297                        });
298                    }
299                }
300            }
301        } else {
302            native_field_map.push(NativeField2OriginFieldMap {
303                origin_field: field.clone(),
304                native_field: None,
305            });
306        }
307    }
308
309    (native_fields, native_field_map)
310}
311
312fn get_debug_fields(origin_fields: &[OriginField]) -> Vec<proc_macro2::TokenStream> {
313    origin_fields
314        .iter()
315        .map(|field| {
316            let field_name = &field.name;
317            quote! {
318                .field(
319                    stringify!(#field_name),
320                    &format_args!("hex: [{}], bytes: \"{}\"",
321                        binary_mirror::to_hex_repr(&self.#field_name),
322                        binary_mirror::to_bytes_repr(&self.#field_name)
323                    )
324                )
325            }
326        })
327        .collect()
328}
329
330fn get_methods(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
331    native_fields
332        .iter()
333        .map(|field| {
334            let name = &field.name;
335            let origin_field = &field.origin_fields[0].name;
336
337            let method_with_warn_name = quote::format_ident!("{}_with_warn", name);
338
339            let debug_bytes = quote! {
340                tracing::warn!("Failed to parse {} in {:?}", stringify!(#name), self);
341            };
342
343            if field.is_combined_datetime {
344                let date_field = &field.origin_fields[0].name;
345                let time_field = &field.origin_fields[1].name;
346                let date_format = field.origin_fields[0]
347                    .attrs
348                    .as_ref()
349                    .and_then(|attrs| attrs.format.as_ref())
350                    .map(String::as_str)
351                    .unwrap_or("%Y%m%d");
352                let time_format = field.origin_fields[1]
353                    .attrs
354                    .as_ref()
355                    .and_then(|attrs| attrs.format.as_ref())
356                    .map(String::as_str)
357                    .unwrap_or("%H%M%S");
358
359                quote! {
360                    pub fn #name(&self) -> Option<chrono::NaiveDateTime> {
361                        let date = chrono::NaiveDate::parse_from_str(
362                            std::str::from_utf8(&self.#date_field.trim_ascii()).ok()?,
363                            #date_format
364                        ).ok()?;
365                        let time = chrono::NaiveTime::parse_from_str(
366                            std::str::from_utf8(&self.#time_field.trim_ascii()).ok()?,
367                            #time_format
368                        ).ok()?;
369                        Some(chrono::NaiveDateTime::new(date, time))
370                    }
371
372                    pub fn #method_with_warn_name(&self) -> Option<chrono::NaiveDateTime> {
373                        match self.#name() {
374                            Some(dt) => Some(dt),
375                            None => {
376                                #debug_bytes
377                                return None;
378                            }
379                        }
380                    }
381                }
382            } else {
383                let attrs = field.origin_fields[0].attrs.as_ref().unwrap();
384                // let expect_str = format!("Failed to convert {} to string", name);
385                // let expect_lit = syn::LitStr::new(&expect_str, proc_macro2::Span::call_site());
386                // TODO string also return Option<String>
387                match attrs.type_name.as_str() {
388                    "str" => quote! {
389                        pub fn #name(&self) -> Option<String> {
390                            std::str::from_utf8(&self.#origin_field.trim_ascii()).ok().map(|s| s.to_string())
391                        }
392
393                        pub fn #method_with_warn_name(&self) -> Option<String> {
394                            match self.#name() {
395                                Some(s) => Some(s),
396                                None => {
397                                    #debug_bytes
398                                    return None;
399                                }
400                            }
401                        }
402                    },
403                    "compact_str" => {
404                        quote! {
405                            pub fn #name(&self) -> Option<compact_str::CompactString> {
406                                compact_str::CompactString::from_utf8(&self.#origin_field.trim_ascii()).ok()
407                            }
408
409                            pub fn #method_with_warn_name(&self) -> Option<compact_str::CompactString> {
410                                match self.#name() {
411                                    Some(s) => Some(s),
412                                    None => {
413                                        #debug_bytes
414                                        return None;
415                                    }
416                                }
417                            }
418                        }
419                    },
420                    // "hipstr" => {
421                    //     quote! {
422                    //         pub fn #name(&self) -> hipstr::HipStr {
423                    //             hipstr::HipStr::from_utf8_lossy(hipstr::HipByt::borrowed(&self.#origin_field.trim_ascii()))
424                    //         }
425
426                    //         pub fn #method_with_warn_name(&self) -> hipstr::HipStr {
427                    //             hipstr::HipStr::from_utf8_lossy(hipstr::HipByt::borrowed(&self.#origin_field.trim_ascii()))
428                    //         }
429                    //     }
430                    // },
431                    "bytes" => {
432                        let size = field.origin_fields[0].size;
433                        quote! {
434                            pub fn #name(&self) -> [u8; #size] {
435                                self.#origin_field
436                            }
437
438                            pub fn #method_with_warn_name(&self) -> [u8; #size] {
439                                self.#origin_field
440                            }
441                        }
442                    }
443                    "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" => {
444                        let type_ident = quote::format_ident!("{}", attrs.type_name);
445                        quote! {
446                            pub fn #name(&self) -> Option<#type_ident> {
447                                std::str::from_utf8(&self.#origin_field.trim_ascii())
448                                    .ok()?
449                                    .parse::<#type_ident>()
450                                    .ok()
451                            }
452
453                            pub fn #method_with_warn_name(&self) -> Option<#type_ident> {
454                                match self.#name() {
455                                    Some(val) => Some(val),
456                                    None => {
457                                        #debug_bytes
458                                        None
459                                    }
460                                }
461                            }
462                        }
463                    }
464                    "decimal" => quote! {
465                        pub fn #name(&self) -> Option<rust_decimal::Decimal> {
466                            std::str::from_utf8(&self.#origin_field.trim_ascii())
467                                .ok()?
468                                .parse::<rust_decimal::Decimal>()
469                                .ok()
470                                .map(|d| d.normalize())
471                        }
472                        pub fn #method_with_warn_name(&self) -> Option<rust_decimal::Decimal> {
473                            match self.#name() {
474                                Some(d) => Some(d),
475                                None => {
476                                    #debug_bytes
477                                    None
478                                }
479                            }
480                        }
481
482                    },
483                    "datetime" => {
484                        let format = attrs
485                            .format
486                            .as_ref()
487                            .map(String::as_str)
488                            .unwrap_or("%Y%m%d%H%M%S");
489                        quote! {
490                            pub fn #name(&self) -> Option<chrono::NaiveDateTime> {
491                                chrono::NaiveDateTime::parse_from_str(
492                                    std::str::from_utf8(&self.#origin_field.trim_ascii()).ok()?,
493                                    #format
494                                ).ok()
495                            }
496
497                            pub fn #method_with_warn_name(&self) -> Option<chrono::NaiveDateTime> {
498                                match self.#name() {
499                                    Some(dt) => Some(dt),
500                                    None => {
501                                        #debug_bytes
502                                        None
503                                    }
504                                }
505                            }
506
507
508                        }
509                    }
510                    "date" => {
511                        let format = attrs
512                            .format
513                            .as_ref()
514                            .map(String::as_str)
515                            .unwrap_or("%Y%m%d");
516                        quote! {
517                            pub fn #name(&self) -> Option<chrono::NaiveDate> {
518                                chrono::NaiveDate::parse_from_str(
519                                    std::str::from_utf8(&self.#origin_field.trim_ascii()).ok()?,
520                                    #format
521                                )
522                                .ok()
523                            }
524                            pub fn #method_with_warn_name(&self) -> Option<chrono::NaiveDate> {
525                                match self.#name() {
526                                    Some(d) => Some(d),
527                                    None => {
528                                        #debug_bytes
529                                        None
530                                    }
531                                }
532                            }
533                        }
534                    }
535                    "time" => {
536                        let format = attrs
537                            .format
538                            .as_ref()
539                            .map(String::as_str)
540                            .unwrap_or("%H%M%S");
541                        quote! {
542                            pub fn #name(&self) -> Option<chrono::NaiveTime> {
543                                chrono::NaiveTime::parse_from_str(
544                                    std::str::from_utf8(&self.#origin_field.trim_ascii()).ok()?,
545                                    #format
546                                )
547                                .ok()
548                            }
549                            pub fn #method_with_warn_name(&self) -> Option<chrono::NaiveTime> {
550                                match self.#name() {
551                                    Some(t) => Some(t),
552                                    None => {
553                                        #debug_bytes
554                                        None
555                                    }
556                                }
557                            }
558                        }
559                    }
560                    "enum" => {
561                        let enum_type = attrs.enum_type.as_ref().unwrap();
562                        let enum_ident = quote::format_ident!("{}", enum_type);
563                        quote! {
564                            pub fn #name(&self) -> Option<#enum_ident> {
565                                #enum_ident::from_bytes(&self.#origin_field)
566                            }
567
568                            pub fn #method_with_warn_name(&self) -> Option<#enum_ident> {
569                                match self.#name() {
570                                    Some(v) => Some(v),
571                                    None => {
572                                        #debug_bytes
573                                        None
574                                    }
575                                }
576                            }
577
578                        }
579                    }
580                    _ => panic!("Unsupported type: {}", attrs.type_name),
581                }
582            }
583        })
584        .collect()
585}
586
587fn get_display_fields(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
588    native_fields
589        .iter()
590        .filter_map(|field| {
591            let name = &field.name;
592            let method_name = &field.name;
593            let attrs = &field.origin_fields[0].attrs.as_ref()?;
594            let origin_field = &field.origin_fields[0].name;
595
596            // Skip if marked with skip (except datetime fields)
597            if attrs.skip && !field.is_combined_datetime {
598                return None;
599            }
600
601            Some(match attrs.type_name.as_str() {
602                // | "hipstr" 
603                // "str" | "compact_str" => quote! {
604                //     write!(f, "{}: {}", stringify!(#name), self.#method_name())?;
605                // },
606                "str" | "compact_str" | "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" | "decimal"
607                | "datetime" | "date" | "time" => quote! {
608                    match self.#method_name() {
609                        Some(val) => write!(f, "{}: {}", stringify!(#name), val)?,
610                        None => write!(f, "{}: Error<bytes: \"{}\">",
611                            stringify!(#name),
612                            binary_mirror::to_bytes_repr(&self.#origin_field)
613                        )?,
614                    }
615                },
616                "enum" => quote! {
617                    match self.#method_name() {
618                        Some(val) => write!(f, "{}: {:?}", stringify!(#name), val)?,
619                        None => write!(f, "{}: Error<bytes: \"{}\">",
620                            stringify!(#name),
621                            binary_mirror::to_bytes_repr(&self.#origin_field)
622                        )?,
623                    }
624                },
625                _ => quote! {},
626            })
627        })
628        .collect()
629}
630
631fn get_native_fields_token(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
632    native_fields
633        .iter()
634        .filter(|field| !field.skip_native)
635        .map(|field| {
636            let name = &field.name;
637            let ty = &field.ty;
638
639            quote! {
640                pub #name: #ty
641            }
642        })
643        .collect()
644}
645
646fn get_to_native_fields(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
647    native_fields
648        .iter()
649        .filter(|field| !field.skip_native)
650        .map(|field| {
651            let name = &field.name;
652            let ignore_warn = field.origin_fields[0]
653                .attrs
654                .as_ref()
655                .map(|attrs| attrs.ignore_warn)
656                .unwrap_or(false);
657
658            if ignore_warn {
659                quote! { #name: self.#name() }
660            } else {
661                let method_name = quote::format_ident!("{}_with_warn", name);
662                quote! { #name: self.#method_name() }
663            }
664        })
665        .collect()
666}
667
668fn get_from_native_fields(
669    native_field_map: &[NativeField2OriginFieldMap],
670) -> Vec<proc_macro2::TokenStream> {
671    native_field_map.iter().map(|mapping| {
672        let field_name = &mapping.origin_field.name;
673        let size = mapping.origin_field.size;
674        let default_byte = mapping.origin_field.attrs
675            .as_ref()
676            .and_then(|attrs| attrs.default_byte)
677            .unwrap_or(b' ');
678        
679
680        if let Some(native_field) = &mapping.native_field {
681            let native_name = &native_field.name;
682            let attrs = mapping.origin_field.attrs.as_ref().unwrap();
683            let format = attrs.format.as_ref().map(String::as_str);
684            let skip_native = native_field.skip_native;
685            if skip_native {
686                return quote! {
687                    #field_name: [#default_byte; #size]
688                };
689            }
690            match attrs.type_name.as_str() {
691                // | "hipstr" 
692                "str" | "compact_str" => quote! {
693                    #field_name: {
694                        let mut bytes = [#default_byte; #size];  // Use default_byte here
695                        if let Some(s) = &native.#native_name {
696                            let s = s.as_bytes();
697                            bytes[..s.len().min(#size)].copy_from_slice(&s[..s.len().min(#size)]);
698                        }
699                        bytes
700                    }
701                },
702                "enum" => quote! {
703                    #field_name: {
704                        let mut bytes = [#default_byte; #size];
705                        if let Some(enum_val) = &native.#native_name {
706                            let s = enum_val.as_bytes();
707                            bytes[..s.len().min(#size)].copy_from_slice(&s[..s.len().min(#size)]);
708                        }
709                        bytes
710                    }
711                },
712                "datetime" => {
713                    let format = attrs.format.as_ref()
714                        .map(String::as_str)
715                        .unwrap_or("%Y-%m-%d %H:%M:%S");
716                    quote! {
717                        #field_name: {
718                            let mut bytes = [#default_byte; #size];
719                            if let Some(dt) = native.#native_name {
720                                let s = dt.format(#format).to_string();
721                                let b = s.as_bytes();
722                                bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
723                            }
724                            bytes
725                        }
726                    }
727                }
728                "date" => {
729                    let format = attrs.format.as_ref()
730                        .map(String::as_str)
731                        .unwrap_or("%Y-%m-%d");
732                    quote! {
733                        #field_name: {
734                            let mut bytes = [#default_byte; #size];
735                            if let Some(dt) = native.#native_name {
736                                let s = dt.format(#format).to_string();
737                                let b = s.as_bytes();
738                                bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
739                            }
740                            bytes
741                        }
742                    }
743                },
744                "time" => {
745                    let format = attrs.format.as_ref()
746                        .map(String::as_str)
747                        .unwrap_or("%H%M%S");
748                    quote! {
749                        #field_name: {
750                            let mut bytes = [#default_byte; #size];
751                            if let Some(dt) = native.#native_name {
752                                let s = dt.format(#format).to_string();
753                                let b = s.as_bytes();
754                                bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
755                            }
756                            bytes
757                        }
758                    }
759                },
760                "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" | "decimal" => {
761                    if let Some(fmt) = format {
762                        quote! {
763                            #field_name: {
764                                let mut bytes = [#default_byte; #size];
765                                if let Some(val) = &native.#native_name {
766                                    let s = format!(#fmt, val);
767                                    let b = s.as_bytes();
768                                    bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
769                                }
770                                bytes
771                            }
772                        }
773                    } else {
774                        quote! {
775                            #field_name: {
776                                let mut bytes = [#default_byte; #size];
777                                if let Some(val) = &native.#native_name {
778                                    let s = val.to_string();
779                                    let b = s.as_bytes();
780                                    bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
781                                }
782                                bytes
783                            }
784                        }
785                    }
786                },
787                "bytes" => quote! {
788                    #field_name: native.#native_name
789                },
790                _ => quote! {
791                    #field_name: {
792                        let mut bytes = [#default_byte; #size];
793                        if let Some(val) = &native.#native_name {
794                            let s = val.to_string();
795                            let b = s.as_bytes();
796                            bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
797                        }
798                        bytes
799                    }
800                }
801            }
802        } else {
803            // Field without attributes, use default byte
804            quote! {
805                #field_name: [#default_byte; #size]
806            }
807        }
808    }).collect()
809}
810
811fn get_native_methods(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
812    native_fields
813        .iter()
814        .filter(|field| !field.skip_native)
815        .map(|field| {
816            let name = &field.name;
817            let method_name = quote::format_ident!("with_{}", name);
818            let ty = &field.pure_ty;
819            let type_name = &field.type_name;
820
821            match type_name.as_str() {
822                "str" => quote! {
823                    pub fn #method_name(mut self, value: impl Into<String>) -> Self {
824                        self.#name = Some(value.into());
825                        self
826                    }
827                },
828                "compact_str" => quote! {
829                    pub fn #method_name(mut self, value: impl Into<compact_str::CompactString>) -> Self {
830                        self.#name = Some(value.into());
831                        self
832                    }
833                },
834                // "hipstr" => quote! {
835                //     pub fn #method_name<'borrow>(mut self, value: impl Into<hipstr::HipStr<'borrow>>) -> Self {
836                //         self.#name = value.into();
837                //         self
838                //     }
839                // },
840                "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" | "decimal"
841                | "datetime" | "date" | "time" | "enum" => {
842                    quote! {
843                        pub fn #method_name(mut self, value: #ty) -> Self {
844                            self.#name = Some(value);
845                            self
846                        }
847                    }
848                }
849                _ => quote! {
850                    pub fn #method_name(mut self, value: #ty) -> Self {
851                        self.#name = value;
852                        self
853                    }
854                },
855            }
856        })
857        .collect()
858}
859
860fn get_field_spec_methods(origin_fields: &[OriginField]) -> proc_macro2::TokenStream {
861    let mut cumulative_size = 0;
862    let size_methods = origin_fields.iter().map(|field| {
863        let field_name = &field.name;
864        let field_size = field.size;
865        let offset = cumulative_size;
866        let limit = offset + field_size;
867        cumulative_size = limit;
868        let method_name = quote::format_ident!("{}_spec", field_name);
869
870        quote! {
871            pub fn #method_name() -> binary_mirror::FieldSpec {
872                binary_mirror::FieldSpec {
873                    offset: #offset,
874                    limit: #limit,
875                    size: #field_size,
876                }
877            }
878        }
879    });
880
881    quote! {
882        #(#size_methods)*
883    }
884}
885
886fn get_native_default_impl(
887    native_fields: &[NativeField],
888    native_name: &proc_macro2::Ident,
889) -> proc_macro2::TokenStream {
890    let default_fields = native_fields.iter().filter(|field| !field.skip_native).map(|field| {
891        let name = &field.name;
892        if let Some(default) = &field.default_func {
893            // let default_quote = quote! { #default };
894            let default_quote = quote::format_ident!("{}", default.as_str());
895            match field.type_name.as_str() {
896                // "hipstr" => quote! {
897                //     #name: Some(#default_quote())
898                // },
899                "str"| "compact_str" | "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" | "datetime"
900                | "date" | "time" | "enum" | "decimal" => {
901                    quote! {
902                        #name: Some(#default_quote())
903                    }
904                }
905                _ => quote! {
906                    #name: Default::default()
907                },
908            }
909        } else {
910            quote! {
911                #name: Default::default()
912            }
913        }
914    });
915
916    quote! {
917        impl Default for #native_name {
918            fn default() -> Self {
919                Self {
920                    #(#default_fields,)*
921                }
922            }
923        }
924    }
925}
926
927fn get_native_to_raw_impl(
928    name: &syn::Ident,
929    native_name: &proc_macro2::Ident,
930) -> proc_macro2::TokenStream {
931    quote! {
932        impl #native_name {
933            pub fn to_raw(&self) -> #name {
934                #name::from_native(self)
935            }
936        }
937    }
938}
939
940fn get_native_struct_code(
941    name: &syn::Ident,
942    native_fields: &[NativeField],
943) -> proc_macro2::TokenStream {
944    let native_name = quote::format_ident!("{}Native", name);
945    let fields_code = native_fields
946        .iter()
947        .filter(|field| !field.skip_native)
948        .map(|field| {
949            let name = &field.name;
950            let ty = &field.ty;
951            // Convert TokenStream to string and normalize whitespace
952            let ty_str = ty
953                .to_string()
954                .replace(" :: ", "::")
955                .replace(" < ", "<")
956                .replace(" > ", ">")
957                .replace(" >", ">");
958            format!("    pub {}: {},", name, ty_str)
959        })
960        .collect::<Vec<_>>()
961        .join("\n");
962
963    quote! {
964        impl binary_mirror::NativeStructCode for #name {
965            fn native_struct_code() -> String {
966                format!(
967                    "pub struct {} {{\n{}\n}}",
968                    stringify!(#native_name),
969                    #fields_code
970                )
971            }
972        }
973    }
974}
975
976fn get_native_derives(struct_attrs: &StructAttrs) -> proc_macro2::TokenStream {
977    if struct_attrs.derives.is_empty() {
978        quote!(Debug, PartialEq, Serialize, Deserialize)
979    } else {
980        let native_derives = struct_attrs
981            .derives
982            .iter()
983            .map(|derive| quote!(#derive))
984            .collect::<Vec<_>>();
985        quote!(#(#native_derives),*)
986    }
987}
988
989fn impl_binary_mirror(input: &DeriveInput) -> TokenStream {
990    let name = &input.ident;
991    let native_name = quote::format_ident!("{}Native", name);
992    let struct_attrs = get_struct_attrs(input);
993
994    let origin_fields = get_origin_fields(input);
995    let (native_fields, native_field_map) = get_native_fields_and_map(&origin_fields);
996    let debug_fields_token = get_debug_fields(&origin_fields);
997    let display_fields_token = get_display_fields(&native_fields);
998    let methods = get_methods(&native_fields);
999    let native_fields_token = get_native_fields_token(&native_fields);
1000    let to_native_fields_token = get_to_native_fields(&native_fields);
1001    let from_native_fields_token = get_from_native_fields(&native_field_map);
1002    let native_methods = get_native_methods(&native_fields);
1003    let field_spec_methods = get_field_spec_methods(&origin_fields);
1004    let native_default_impl = get_native_default_impl(&native_fields, &native_name);
1005    let native_to_raw_impl = get_native_to_raw_impl(name, &native_name);
1006    let native_derives = get_native_derives(&struct_attrs);
1007    let native_struct_code = get_native_struct_code(name, &native_fields);
1008
1009    let gen = quote! {
1010        impl #name {
1011            #(#methods)*
1012            /// Get the size of the struct in bytes
1013            pub const fn size() -> usize {
1014                std::mem::size_of::<Self>()
1015            }
1016            #field_spec_methods
1017        }
1018
1019        #[derive(#native_derives)]
1020        pub struct #native_name {
1021            #(#native_fields_token,)*
1022        }
1023
1024        impl #native_name {
1025            #(#native_methods)*
1026        }
1027
1028        impl std::fmt::Debug for #name {
1029            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1030                f.debug_struct(stringify!(#name))
1031                    #(#debug_fields_token)*
1032                    .finish()
1033            }
1034        }
1035
1036        impl std::fmt::Display for #name {
1037            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1038                write!(f, "{} {{ ", stringify!(#name))?;
1039                let mut first = true;
1040                #(
1041                    if first {
1042                        first = false;
1043                    } else {
1044                        write!(f, ", ")?;
1045                    }
1046                    #display_fields_token
1047                )*
1048                write!(f, " }}")
1049            }
1050        }
1051
1052        #native_default_impl
1053        #native_to_raw_impl
1054        #native_struct_code
1055
1056        impl binary_mirror::FromBytes for #name {
1057            const SIZE: usize = std::mem::size_of::<Self>();
1058
1059            fn from_bytes(bytes: &[u8]) -> Result<&Self, binary_mirror::BytesSizeError> {
1060                let expected = Self::SIZE;
1061                let actual = bytes.len();
1062                if actual != expected {
1063                    return Err(binary_mirror::BytesSizeError::new(
1064                        expected,
1065                        actual,
1066                        bytes.iter()
1067                        .map(|&b| {
1068                            match b {
1069                                0x0A => "\\n".to_string(),
1070                                0x0D => "\\r".to_string(),
1071                                0x09 => "\\t".to_string(),
1072                                0x20..=0x7E => (b as char).to_string(),
1073                                _ => format!("\\x{:02x}", b),
1074                            }
1075                        })
1076                        .collect::<Vec<String>>()
1077                        .join("")
1078                    ));
1079                }
1080                // Safety:
1081                // 1. We've verified the size matches
1082                // 2. The struct is #[repr(C)]
1083                // 3. The alignment is handled by the compiler
1084                Ok(unsafe { &*(bytes.as_ptr() as *const Self) })
1085            }
1086
1087        }
1088
1089        impl binary_mirror::ToBytes for #name {
1090            fn to_bytes(&self) -> &[u8] {
1091                // Safety:
1092                // 1. The struct is #[repr(C)]
1093                // 2. We're reading the exact size of the struct
1094                // 3. All fields are byte arrays
1095                // 4. The returned slice lifetime is tied to self
1096                unsafe {
1097                    std::slice::from_raw_parts(
1098                        (self as *const Self) as *const u8,
1099                        Self::size()
1100                    )
1101                }
1102            }
1103
1104            fn to_bytes_owned(&self) -> Vec<u8> {
1105                self.to_bytes().to_vec()
1106            }
1107        }
1108
1109        impl binary_mirror::ToNative for #name {
1110            type Native = #native_name;
1111
1112            fn to_native(&self) -> Self::Native {
1113                #native_name {
1114                    #(#to_native_fields_token,)*
1115                }
1116            }
1117        }
1118
1119        impl binary_mirror::FromNative<#native_name> for #name {
1120            fn from_native(native: &#native_name) -> Self {
1121                Self {
1122                    #(#from_native_fields_token,)*
1123                }
1124            }
1125        }
1126    };
1127
1128    gen.into()
1129}
1130
1131fn get_variant_value(attrs: &[syn::Attribute]) -> Option<Vec<u8>> {
1132    for attr in attrs {
1133        if attr.path().is_ident("bv") {
1134            let mut byte_value = None;
1135            let _ = attr.parse_nested_meta(|meta| {
1136                if meta.path.is_ident("value") {
1137                    let lit = meta.value()?.parse::<syn::LitByteStr>()?;
1138                    byte_value = Some(lit.value().to_vec());
1139                }
1140                Ok(())
1141            });
1142            return byte_value;
1143        }
1144    }
1145    None
1146}
1147
1148fn impl_binary_enum(input: &DeriveInput) -> TokenStream {
1149    let name = &input.ident;
1150
1151    let variants = match &input.data {
1152        Data::Enum(data) => &data.variants,
1153        _ => panic!("BinaryEnum can only be derived for enums"),
1154    };
1155
1156    let match_arms_from = variants.iter().map(|variant| {
1157        let variant_ident = &variant.ident;
1158        let byte_value = get_variant_value(&variant.attrs).unwrap_or_else(|| {
1159            let variant_str = variant_ident.to_string().to_uppercase();
1160            vec![variant_str.chars().next().unwrap() as u8]
1161        });
1162        let byte_len = byte_value.len();
1163
1164        quote! {
1165            if bytes.len() >= #byte_len && &bytes[..#byte_len] == &[#(#byte_value),*] {
1166                Some(Self::#variant_ident)
1167            } else
1168        }
1169    });
1170
1171    let match_arms_to = variants.iter().map(|variant| {
1172        let variant_ident = &variant.ident;
1173        let byte_value = get_variant_value(&variant.attrs).unwrap_or_else(|| {
1174            let variant_str = variant_ident.to_string().to_uppercase();
1175            vec![variant_str.chars().next().unwrap() as u8]
1176        });
1177
1178        quote! {
1179            Self::#variant_ident => &[#(#byte_value),*],
1180        }
1181    });
1182
1183    let gen = quote! {
1184        impl #name {
1185            pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
1186                #(#match_arms_from)* {
1187                    None
1188                }
1189            }
1190
1191            pub fn as_bytes(&self) -> &'static [u8] {
1192                match self {
1193                    #(#match_arms_to)*
1194                }
1195            }
1196        }
1197    };
1198
1199    gen.into()
1200}
1201
#[cfg(test)]
mod tests {
    //! Unit tests for the token-generation helpers that return `proc_macro2` streams.
    //! (Functions returning `proc_macro::TokenStream` cannot run outside a macro
    //! invocation, so they are exercised by downstream integration tests instead.)
    use super::*;
    use quote::quote;

    fn ident(name: &str) -> syn::Ident {
        syn::Ident::new(name, proc_macro2::Span::call_site())
    }

    #[test]
    fn variant_value_reads_byte_string() {
        let attrs: Vec<syn::Attribute> = vec![syn::parse_quote!(#[bv(value = b"AB")])];
        assert_eq!(get_variant_value(&attrs), Some(b"AB".to_vec()));
    }

    #[test]
    fn variant_value_is_none_without_bv_attribute() {
        let attrs: Vec<syn::Attribute> = vec![syn::parse_quote!(#[doc = "x"])];
        assert_eq!(get_variant_value(&attrs), None);
    }

    #[test]
    fn field_specs_use_cumulative_offsets() {
        let fields = vec![
            OriginField { name: ident("a"), size: 3, attrs: None },
            OriginField { name: ident("b"), size: 2, attrs: None },
        ];
        let expected = quote! {
            pub fn a_spec() -> binary_mirror::FieldSpec {
                binary_mirror::FieldSpec { offset: 0usize, limit: 3usize, size: 3usize, }
            }
            pub fn b_spec() -> binary_mirror::FieldSpec {
                binary_mirror::FieldSpec { offset: 3usize, limit: 5usize, size: 2usize, }
            }
        };
        assert_eq!(
            get_field_spec_methods(&fields).to_string(),
            expected.to_string()
        );
    }

    #[test]
    fn native_derives_fall_back_to_defaults() {
        let attrs = StructAttrs { derives: vec![] };
        assert_eq!(
            get_native_derives(&attrs).to_string(),
            quote!(Debug, PartialEq, Serialize, Deserialize).to_string()
        );
    }
}