1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields, LitStr};
4
5#[proc_macro_derive(BinaryMirror, attributes(bm))]
6pub fn binary_mirror_derive(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8 impl_binary_mirror(&input)
9}
10
11#[proc_macro_derive(BinaryEnum, attributes(bv))]
12pub fn binary_enum_derive(input: TokenStream) -> TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
14 impl_binary_enum(&input)
15}
16
/// Options collected from a single `#[bm(...)]` attribute on a struct field.
#[derive(Debug, Clone)]
struct FieldAttrs {
    /// Value of `type = "..."`; selects the native type and the accessor that is generated.
    /// An attribute whose `type` is empty is ignored by `get_field_attrs`.
    type_name: String,
    /// Value of `alias = "..."`; when present it replaces the field's own name for the
    /// native field and accessor.
    alias: Option<String>,
    /// Value of `format = "..."`; used as the chrono format for date/time kinds and as the
    /// `format!` string when numeric kinds are written back to bytes.
    format: Option<String>,
    /// Value of `datetime_with = "..."`; names the sibling field that a `date`/`time` field
    /// is combined with into one `NaiveDateTime`.
    datetime_with: Option<String>,
    /// Value of `skip = <bool>`; a skipped (non-combined) field gets no native field.
    skip: bool,
    /// Value of `skip_native = <bool>`; the field is left out of the generated `*Native`
    /// struct and is filled with the default byte in `from_native`.
    skip_native: bool,
    /// Value of `enum_type = "..."`; identifier of the enum used when `type = "enum"`.
    enum_type: Option<String>,
    /// Value of `default_byte = b'..'`; fill byte used by `from_native` (falls back to `b' '`).
    default_byte: Option<u8>,
    /// Value of `ignore_warn = <bool>`; makes `to_native` call the plain accessor instead of
    /// the `_with_warn` variant.
    ignore_warn: bool,
    /// Value of `default_func = "..."`; name of a function called to produce the field's
    /// value in the generated `Default` impl of the native struct.
    default_func: Option<String>,
}
30
/// One field of the annotated (raw) struct, as read by `get_origin_fields`.
#[derive(Debug, Clone)]
struct OriginField {
    /// Identifier of the field in the raw struct.
    name: syn::Ident,
    /// Array length taken from the field's array type (must be an integer literal).
    size: usize,
    /// Parsed `#[bm(...)]` options, or `None` when the field has no usable `bm` attribute.
    attrs: Option<FieldAttrs>,
}
37
/// One field of the generated `*Native` struct, built by `get_native_fields_and_map`.
#[derive(Debug, Clone)]
struct NativeField {
    /// Name of the native field and of its accessor (the alias when one is given).
    name: syn::Ident,
    /// Full field type as tokens (`Option<...>` for every kind except `bytes`).
    ty: proc_macro2::TokenStream,
    /// Kind string; `"datetime"` for a combined date/time pair, otherwise the `type` option.
    type_name: String,
    /// Field type without the `Option` wrapper.
    pure_ty: proc_macro2::TokenStream,
    /// Raw fields backing this native field: `[date, time]` for a combined datetime,
    /// otherwise a single element.
    origin_fields: Vec<OriginField>,
    /// True when the field merges a date field and a time field.
    is_combined_datetime: bool,
    /// Copied from the `default_func` option of the defining raw field.
    default_func: Option<String>,
    /// Copied from the `skip_native` option of the defining raw field.
    skip_native: bool,
}
49
/// Links a raw field to the native field (if any) it is rebuilt from in `from_native`.
#[derive(Debug)]
struct NativeField2OriginFieldMap {
    /// The raw field being initialised.
    origin_field: OriginField,
    /// Source native field; `None` means the raw field is filled with its default byte.
    native_field: Option<NativeField>,
}
55
/// Struct-level options collected from `#[bm(...)]` on the annotated struct.
#[derive(Debug, Clone)]
struct StructAttrs {
    /// Paths listed in `#[bm(derive(...))]`; they become the derives of the native struct.
    derives: Vec<syn::Path>,
}
60
61fn get_struct_attrs(input: &DeriveInput) -> StructAttrs {
62 let attrs = &input.attrs;
63 let mut struct_attrs = StructAttrs { derives: vec![] };
64 for attr in attrs {
65 if attr.path().is_ident("bm") {
66 let _ = attr.parse_nested_meta(|meta| {
67 if meta.path.is_ident("derive") {
68 let content;
69 syn::parenthesized!(content in meta.input);
70 let derives: syn::punctuated::Punctuated<syn::Path, syn::Token![,]> = content
71 .parse_terminated(syn::parse::Parse::parse, syn::Token![,])
72 .expect("derive");
73 struct_attrs.derives = derives.into_iter().collect();
74 }
75 Ok(())
76 });
77 }
78 }
79 struct_attrs
80}
81
82fn get_field_attrs(attrs: &[syn::Attribute]) -> Option<FieldAttrs> {
83 for attr in attrs {
84 if attr.path().is_ident("bm") {
85 let mut field_attrs = FieldAttrs {
86 type_name: String::new(),
87 alias: None,
88 format: None,
89 datetime_with: None,
90 skip: false,
91 skip_native: false,
92 enum_type: None,
93 default_byte: None,
94 ignore_warn: false,
95 default_func: None,
96 };
97
98 let _ = attr.parse_nested_meta(|meta| {
99 if meta.path.is_ident("type") {
100 let lit = meta.value()?.parse::<LitStr>()?;
101 field_attrs.type_name = lit.value();
102 } else if meta.path.is_ident("alias") {
103 let lit = meta.value()?.parse::<LitStr>()?;
104 field_attrs.alias = Some(lit.value());
105 } else if meta.path.is_ident("format") {
106 let lit = meta.value()?.parse::<LitStr>()?;
107 field_attrs.format = Some(lit.value());
108 } else if meta.path.is_ident("datetime_with") {
109 let lit = meta.value()?.parse::<LitStr>()?;
110 field_attrs.datetime_with = Some(lit.value());
111 } else if meta.path.is_ident("skip") {
112 field_attrs.skip = meta.value()?.parse::<syn::LitBool>()?.value();
113 } else if meta.path.is_ident("skip_native") {
114 field_attrs.skip_native = meta.value()?.parse::<syn::LitBool>()?.value();
115 } else if meta.path.is_ident("enum_type") {
116 let lit = meta.value()?.parse::<LitStr>()?;
117 field_attrs.enum_type = Some(lit.value());
118 } else if meta.path.is_ident("default_byte") {
119 let lit = meta.value()?.parse::<syn::LitByte>()?;
120 field_attrs.default_byte = Some(lit.value());
121 } else if meta.path.is_ident("ignore_warn") {
122 field_attrs.ignore_warn = meta.value()?.parse::<syn::LitBool>()?.value();
123 } else if meta.path.is_ident("default_func") {
124 let lit = meta.value()?.parse::<syn::LitStr>()?;
125 field_attrs.default_func = Some(lit.value());
126 }
127 Ok(())
128 });
129
130 if !field_attrs.type_name.is_empty() {
131 return Some(field_attrs);
132 }
133 }
134 }
135 None
136}
137
138fn get_origin_fields(input: &DeriveInput) -> Vec<OriginField> {
139 let fields = match &input.data {
140 Data::Struct(data) => match &data.fields {
141 Fields::Named(fields) => &fields.named,
142 _ => panic!("Only named fields are supported"),
143 },
144 _ => panic!("Only structs are supported"),
145 };
146
147 fields
148 .iter()
149 .map(|field| {
150 let name = field.ident.clone().unwrap();
151
152 let size = if let syn::Type::Array(array) = &field.ty {
154 if let syn::Expr::Lit(syn::ExprLit {
155 lit: syn::Lit::Int(ref lit_int),
156 ..
157 }) = array.len
158 {
159 lit_int
160 .base10_parse::<usize>()
161 .expect("Could not parse array length")
162 } else {
163 panic!("Field {} array length must be a literal integer", name);
164 }
165 } else {
166 panic!("Field {} must be a [u8] array", name);
167 };
168
169 OriginField {
170 name,
171 size,
172 attrs: get_field_attrs(&field.attrs),
173 }
174 })
175 .collect()
176}
177
178fn get_native_fields_and_map(origin_fields: &[OriginField]) -> (Vec<NativeField>, Vec<NativeField2OriginFieldMap>) {
179 let mut native_fields = Vec::new();
180 let mut native_field_map = Vec::new();
181 let mut processed = std::collections::HashSet::new();
182
183 for field in origin_fields {
184 if let Some(attrs) = &field.attrs {
185 if processed.contains(&field.name.to_string()) {
187 continue;
188 }
189
190 let field_name = if let Some(alias) = &attrs.alias {
191 quote::format_ident!("{}", alias)
192 } else {
193 field.name.clone()
194 };
195
196 match attrs.type_name.as_str() {
197 "date" | "time" if attrs.datetime_with.is_some() => {
198 let other_field_name = attrs.datetime_with.as_ref().unwrap();
199 let other_field = origin_fields
200 .iter()
201 .find(|f| f.name == quote::format_ident!("{}", other_field_name))
202 .expect("Could not find datetime pair field");
203
204 processed.insert(field.name.to_string());
206 processed.insert(other_field.name.to_string());
207
208 let (date_field, time_field) = if attrs.type_name == "date" {
210 (field, other_field)
211 } else {
212 (other_field, field)
213 };
214
215 let native_field = NativeField {
216 name: field_name,
217 ty: quote!(Option<chrono::NaiveDateTime>),
218 type_name: "datetime".to_string(),
219 pure_ty: quote!(chrono::NaiveDateTime),
220 origin_fields: vec![date_field.clone(), time_field.clone()],
221 is_combined_datetime: true,
222 default_func: attrs.default_func.clone(),
223 skip_native: attrs.skip_native,
224 };
225
226 native_fields.push(native_field.clone());
227 native_field_map.push(NativeField2OriginFieldMap {
228 origin_field: field.clone(),
229 native_field: Some(native_field.clone()),
230 });
231 native_field_map.push(NativeField2OriginFieldMap {
232 origin_field: other_field.clone(),
233 native_field: Some(native_field),
234 });
235 }
236 _ => {
237 let (ty, pure_ty) = match attrs.type_name.as_str() {
238 "str" => (quote!(Option<String>), quote!(String)),
239 "compact_str" => (
240 quote!(Option<compact_str::CompactString>),
241 quote!(compact_str::CompactString)
242 ),
243 "bytes" => {
248 let size = field.size;
249 (quote!([u8; #size]), quote!([u8; #size]))
250 }
251 "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" => {
252 let type_ident = quote::format_ident!("{}", attrs.type_name);
253 (quote!(Option<#type_ident>), quote!(#type_ident))
254 }
255 "decimal" => (
256 quote!(Option<rust_decimal::Decimal>),
257 quote!(rust_decimal::Decimal),
258 ),
259 "datetime" => (
260 quote!(Option<chrono::NaiveDateTime>),
261 quote!(chrono::NaiveDateTime),
262 ),
263 "date" => (quote!(Option<chrono::NaiveDate>), quote!(chrono::NaiveDate)),
264 "time" => (quote!(Option<chrono::NaiveTime>), quote!(chrono::NaiveTime)),
265 "enum" => {
266 let enum_type = attrs.enum_type.as_ref();
267 match enum_type {
268 Some(enum_type) => {
269 let enum_ident = quote::format_ident!("{}", enum_type);
270 (quote!(Option<#enum_ident>), quote!(#enum_ident))
271 }
272 None => panic!("enum_type is required for enum field"),
273 }
274 }
275 _ => continue,
276 };
277 let native_field = NativeField {
278 name: field_name,
279 ty,
280 type_name: attrs.type_name.clone(),
281 pure_ty,
282 origin_fields: vec![field.clone()],
283 is_combined_datetime: false,
284 default_func: attrs.default_func.clone(),
285 skip_native: attrs.skip_native,
286 };
287 if !attrs.skip {
288 native_fields.push(native_field.clone());
289 native_field_map.push(NativeField2OriginFieldMap {
290 origin_field: field.clone(),
291 native_field: Some(native_field),
292 });
293 } else {
294 native_field_map.push(NativeField2OriginFieldMap {
295 origin_field: field.clone(),
296 native_field: None,
297 });
298 }
299 }
300 }
301 } else {
302 native_field_map.push(NativeField2OriginFieldMap {
303 origin_field: field.clone(),
304 native_field: None,
305 });
306 }
307 }
308
309 (native_fields, native_field_map)
310}
311
312fn get_debug_fields(origin_fields: &[OriginField]) -> Vec<proc_macro2::TokenStream> {
313 origin_fields
314 .iter()
315 .map(|field| {
316 let field_name = &field.name;
317 quote! {
318 .field(
319 stringify!(#field_name),
320 &format_args!("hex: [{}], bytes: \"{}\"",
321 binary_mirror::to_hex_repr(&self.#field_name),
322 binary_mirror::to_bytes_repr(&self.#field_name)
323 )
324 )
325 }
326 })
327 .collect()
328}
329
330fn get_methods(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
331 native_fields
332 .iter()
333 .map(|field| {
334 let name = &field.name;
335 let origin_field = &field.origin_fields[0].name;
336
337 let method_with_warn_name = quote::format_ident!("{}_with_warn", name);
338
339 let debug_bytes = quote! {
340 tracing::warn!("Failed to parse {} in {:?}", stringify!(#name), self);
341 };
342
343 if field.is_combined_datetime {
344 let date_field = &field.origin_fields[0].name;
345 let time_field = &field.origin_fields[1].name;
346 let date_format = field.origin_fields[0]
347 .attrs
348 .as_ref()
349 .and_then(|attrs| attrs.format.as_ref())
350 .map(String::as_str)
351 .unwrap_or("%Y%m%d");
352 let time_format = field.origin_fields[1]
353 .attrs
354 .as_ref()
355 .and_then(|attrs| attrs.format.as_ref())
356 .map(String::as_str)
357 .unwrap_or("%H%M%S");
358
359 quote! {
360 pub fn #name(&self) -> Option<chrono::NaiveDateTime> {
361 let date = chrono::NaiveDate::parse_from_str(
362 std::str::from_utf8(&self.#date_field.trim_ascii()).ok()?,
363 #date_format
364 ).ok()?;
365 let time = chrono::NaiveTime::parse_from_str(
366 std::str::from_utf8(&self.#time_field.trim_ascii()).ok()?,
367 #time_format
368 ).ok()?;
369 Some(chrono::NaiveDateTime::new(date, time))
370 }
371
372 pub fn #method_with_warn_name(&self) -> Option<chrono::NaiveDateTime> {
373 match self.#name() {
374 Some(dt) => Some(dt),
375 None => {
376 #debug_bytes
377 return None;
378 }
379 }
380 }
381 }
382 } else {
383 let attrs = field.origin_fields[0].attrs.as_ref().unwrap();
384 match attrs.type_name.as_str() {
388 "str" => quote! {
389 pub fn #name(&self) -> Option<String> {
390 std::str::from_utf8(&self.#origin_field.trim_ascii()).ok().map(|s| s.to_string())
391 }
392
393 pub fn #method_with_warn_name(&self) -> Option<String> {
394 match self.#name() {
395 Some(s) => Some(s),
396 None => {
397 #debug_bytes
398 return None;
399 }
400 }
401 }
402 },
403 "compact_str" => {
404 quote! {
405 pub fn #name(&self) -> Option<compact_str::CompactString> {
406 compact_str::CompactString::from_utf8(&self.#origin_field.trim_ascii()).ok()
407 }
408
409 pub fn #method_with_warn_name(&self) -> Option<compact_str::CompactString> {
410 match self.#name() {
411 Some(s) => Some(s),
412 None => {
413 #debug_bytes
414 return None;
415 }
416 }
417 }
418 }
419 },
420 "bytes" => {
432 let size = field.origin_fields[0].size;
433 quote! {
434 pub fn #name(&self) -> [u8; #size] {
435 self.#origin_field
436 }
437
438 pub fn #method_with_warn_name(&self) -> [u8; #size] {
439 self.#origin_field
440 }
441 }
442 }
443 "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" => {
444 let type_ident = quote::format_ident!("{}", attrs.type_name);
445 quote! {
446 pub fn #name(&self) -> Option<#type_ident> {
447 std::str::from_utf8(&self.#origin_field.trim_ascii())
448 .ok()?
449 .parse::<#type_ident>()
450 .ok()
451 }
452
453 pub fn #method_with_warn_name(&self) -> Option<#type_ident> {
454 match self.#name() {
455 Some(val) => Some(val),
456 None => {
457 #debug_bytes
458 None
459 }
460 }
461 }
462 }
463 }
464 "decimal" => quote! {
465 pub fn #name(&self) -> Option<rust_decimal::Decimal> {
466 std::str::from_utf8(&self.#origin_field.trim_ascii())
467 .ok()?
468 .parse::<rust_decimal::Decimal>()
469 .ok()
470 .map(|d| d.normalize())
471 }
472 pub fn #method_with_warn_name(&self) -> Option<rust_decimal::Decimal> {
473 match self.#name() {
474 Some(d) => Some(d),
475 None => {
476 #debug_bytes
477 None
478 }
479 }
480 }
481
482 },
483 "datetime" => {
484 let format = attrs
485 .format
486 .as_ref()
487 .map(String::as_str)
488 .unwrap_or("%Y%m%d%H%M%S");
489 quote! {
490 pub fn #name(&self) -> Option<chrono::NaiveDateTime> {
491 chrono::NaiveDateTime::parse_from_str(
492 std::str::from_utf8(&self.#origin_field.trim_ascii()).ok()?,
493 #format
494 ).ok()
495 }
496
497 pub fn #method_with_warn_name(&self) -> Option<chrono::NaiveDateTime> {
498 match self.#name() {
499 Some(dt) => Some(dt),
500 None => {
501 #debug_bytes
502 None
503 }
504 }
505 }
506
507
508 }
509 }
510 "date" => {
511 let format = attrs
512 .format
513 .as_ref()
514 .map(String::as_str)
515 .unwrap_or("%Y%m%d");
516 quote! {
517 pub fn #name(&self) -> Option<chrono::NaiveDate> {
518 chrono::NaiveDate::parse_from_str(
519 std::str::from_utf8(&self.#origin_field.trim_ascii()).ok()?,
520 #format
521 )
522 .ok()
523 }
524 pub fn #method_with_warn_name(&self) -> Option<chrono::NaiveDate> {
525 match self.#name() {
526 Some(d) => Some(d),
527 None => {
528 #debug_bytes
529 None
530 }
531 }
532 }
533 }
534 }
535 "time" => {
536 let format = attrs
537 .format
538 .as_ref()
539 .map(String::as_str)
540 .unwrap_or("%H%M%S");
541 quote! {
542 pub fn #name(&self) -> Option<chrono::NaiveTime> {
543 chrono::NaiveTime::parse_from_str(
544 std::str::from_utf8(&self.#origin_field.trim_ascii()).ok()?,
545 #format
546 )
547 .ok()
548 }
549 pub fn #method_with_warn_name(&self) -> Option<chrono::NaiveTime> {
550 match self.#name() {
551 Some(t) => Some(t),
552 None => {
553 #debug_bytes
554 None
555 }
556 }
557 }
558 }
559 }
560 "enum" => {
561 let enum_type = attrs.enum_type.as_ref().unwrap();
562 let enum_ident = quote::format_ident!("{}", enum_type);
563 quote! {
564 pub fn #name(&self) -> Option<#enum_ident> {
565 #enum_ident::from_bytes(&self.#origin_field)
566 }
567
568 pub fn #method_with_warn_name(&self) -> Option<#enum_ident> {
569 match self.#name() {
570 Some(v) => Some(v),
571 None => {
572 #debug_bytes
573 None
574 }
575 }
576 }
577
578 }
579 }
580 _ => panic!("Unsupported type: {}", attrs.type_name),
581 }
582 }
583 })
584 .collect()
585}
586
587fn get_display_fields(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
588 native_fields
589 .iter()
590 .filter_map(|field| {
591 let name = &field.name;
592 let method_name = &field.name;
593 let attrs = &field.origin_fields[0].attrs.as_ref()?;
594 let origin_field = &field.origin_fields[0].name;
595
596 if attrs.skip && !field.is_combined_datetime {
598 return None;
599 }
600
601 Some(match attrs.type_name.as_str() {
602 "str" | "compact_str" | "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" | "decimal"
607 | "datetime" | "date" | "time" => quote! {
608 match self.#method_name() {
609 Some(val) => write!(f, "{}: {}", stringify!(#name), val)?,
610 None => write!(f, "{}: Error<bytes: \"{}\">",
611 stringify!(#name),
612 binary_mirror::to_bytes_repr(&self.#origin_field)
613 )?,
614 }
615 },
616 "enum" => quote! {
617 match self.#method_name() {
618 Some(val) => write!(f, "{}: {:?}", stringify!(#name), val)?,
619 None => write!(f, "{}: Error<bytes: \"{}\">",
620 stringify!(#name),
621 binary_mirror::to_bytes_repr(&self.#origin_field)
622 )?,
623 }
624 },
625 _ => quote! {},
626 })
627 })
628 .collect()
629}
630
631fn get_native_fields_token(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
632 native_fields
633 .iter()
634 .filter(|field| !field.skip_native)
635 .map(|field| {
636 let name = &field.name;
637 let ty = &field.ty;
638
639 quote! {
640 pub #name: #ty
641 }
642 })
643 .collect()
644}
645
646fn get_to_native_fields(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
647 native_fields
648 .iter()
649 .filter(|field| !field.skip_native)
650 .map(|field| {
651 let name = &field.name;
652 let ignore_warn = field.origin_fields[0]
653 .attrs
654 .as_ref()
655 .map(|attrs| attrs.ignore_warn)
656 .unwrap_or(false);
657
658 if ignore_warn {
659 quote! { #name: self.#name() }
660 } else {
661 let method_name = quote::format_ident!("{}_with_warn", name);
662 quote! { #name: self.#method_name() }
663 }
664 })
665 .collect()
666}
667
668fn get_from_native_fields(
669 native_field_map: &[NativeField2OriginFieldMap],
670) -> Vec<proc_macro2::TokenStream> {
671 native_field_map.iter().map(|mapping| {
672 let field_name = &mapping.origin_field.name;
673 let size = mapping.origin_field.size;
674 let default_byte = mapping.origin_field.attrs
675 .as_ref()
676 .and_then(|attrs| attrs.default_byte)
677 .unwrap_or(b' ');
678
679
680 if let Some(native_field) = &mapping.native_field {
681 let native_name = &native_field.name;
682 let attrs = mapping.origin_field.attrs.as_ref().unwrap();
683 let format = attrs.format.as_ref().map(String::as_str);
684 let skip_native = native_field.skip_native;
685 if skip_native {
686 return quote! {
687 #field_name: [#default_byte; #size]
688 };
689 }
690 match attrs.type_name.as_str() {
691 "str" | "compact_str" => quote! {
693 #field_name: {
694 let mut bytes = [#default_byte; #size]; if let Some(s) = &native.#native_name {
696 let s = s.as_bytes();
697 bytes[..s.len().min(#size)].copy_from_slice(&s[..s.len().min(#size)]);
698 }
699 bytes
700 }
701 },
702 "enum" => quote! {
703 #field_name: {
704 let mut bytes = [#default_byte; #size];
705 if let Some(enum_val) = &native.#native_name {
706 let s = enum_val.as_bytes();
707 bytes[..s.len().min(#size)].copy_from_slice(&s[..s.len().min(#size)]);
708 }
709 bytes
710 }
711 },
712 "datetime" => {
713 let format = attrs.format.as_ref()
714 .map(String::as_str)
715 .unwrap_or("%Y-%m-%d %H:%M:%S");
716 quote! {
717 #field_name: {
718 let mut bytes = [#default_byte; #size];
719 if let Some(dt) = native.#native_name {
720 let s = dt.format(#format).to_string();
721 let b = s.as_bytes();
722 bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
723 }
724 bytes
725 }
726 }
727 }
728 "date" => {
729 let format = attrs.format.as_ref()
730 .map(String::as_str)
731 .unwrap_or("%Y-%m-%d");
732 quote! {
733 #field_name: {
734 let mut bytes = [#default_byte; #size];
735 if let Some(dt) = native.#native_name {
736 let s = dt.format(#format).to_string();
737 let b = s.as_bytes();
738 bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
739 }
740 bytes
741 }
742 }
743 },
744 "time" => {
745 let format = attrs.format.as_ref()
746 .map(String::as_str)
747 .unwrap_or("%H%M%S");
748 quote! {
749 #field_name: {
750 let mut bytes = [#default_byte; #size];
751 if let Some(dt) = native.#native_name {
752 let s = dt.format(#format).to_string();
753 let b = s.as_bytes();
754 bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
755 }
756 bytes
757 }
758 }
759 },
760 "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" | "decimal" => {
761 if let Some(fmt) = format {
762 quote! {
763 #field_name: {
764 let mut bytes = [#default_byte; #size];
765 if let Some(val) = &native.#native_name {
766 let s = format!(#fmt, val);
767 let b = s.as_bytes();
768 bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
769 }
770 bytes
771 }
772 }
773 } else {
774 quote! {
775 #field_name: {
776 let mut bytes = [#default_byte; #size];
777 if let Some(val) = &native.#native_name {
778 let s = val.to_string();
779 let b = s.as_bytes();
780 bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
781 }
782 bytes
783 }
784 }
785 }
786 },
787 "bytes" => quote! {
788 #field_name: native.#native_name
789 },
790 _ => quote! {
791 #field_name: {
792 let mut bytes = [#default_byte; #size];
793 if let Some(val) = &native.#native_name {
794 let s = val.to_string();
795 let b = s.as_bytes();
796 bytes[..b.len().min(#size)].copy_from_slice(&b[..b.len().min(#size)]);
797 }
798 bytes
799 }
800 }
801 }
802 } else {
803 quote! {
805 #field_name: [#default_byte; #size]
806 }
807 }
808 }).collect()
809}
810
811fn get_native_methods(native_fields: &[NativeField]) -> Vec<proc_macro2::TokenStream> {
812 native_fields
813 .iter()
814 .filter(|field| !field.skip_native)
815 .map(|field| {
816 let name = &field.name;
817 let method_name = quote::format_ident!("with_{}", name);
818 let ty = &field.pure_ty;
819 let type_name = &field.type_name;
820
821 match type_name.as_str() {
822 "str" => quote! {
823 pub fn #method_name(mut self, value: impl Into<String>) -> Self {
824 self.#name = Some(value.into());
825 self
826 }
827 },
828 "compact_str" => quote! {
829 pub fn #method_name(mut self, value: impl Into<compact_str::CompactString>) -> Self {
830 self.#name = Some(value.into());
831 self
832 }
833 },
834 "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" | "decimal"
841 | "datetime" | "date" | "time" | "enum" => {
842 quote! {
843 pub fn #method_name(mut self, value: #ty) -> Self {
844 self.#name = Some(value);
845 self
846 }
847 }
848 }
849 _ => quote! {
850 pub fn #method_name(mut self, value: #ty) -> Self {
851 self.#name = value;
852 self
853 }
854 },
855 }
856 })
857 .collect()
858}
859
860fn get_field_spec_methods(origin_fields: &[OriginField]) -> proc_macro2::TokenStream {
861 let mut cumulative_size = 0;
862 let size_methods = origin_fields.iter().map(|field| {
863 let field_name = &field.name;
864 let field_size = field.size;
865 let offset = cumulative_size;
866 let limit = offset + field_size;
867 cumulative_size = limit;
868 let method_name = quote::format_ident!("{}_spec", field_name);
869
870 quote! {
871 pub fn #method_name() -> binary_mirror::FieldSpec {
872 binary_mirror::FieldSpec {
873 offset: #offset,
874 limit: #limit,
875 size: #field_size,
876 }
877 }
878 }
879 });
880
881 quote! {
882 #(#size_methods)*
883 }
884}
885
886fn get_native_default_impl(
887 native_fields: &[NativeField],
888 native_name: &proc_macro2::Ident,
889) -> proc_macro2::TokenStream {
890 let default_fields = native_fields.iter().filter(|field| !field.skip_native).map(|field| {
891 let name = &field.name;
892 if let Some(default) = &field.default_func {
893 let default_quote = quote::format_ident!("{}", default.as_str());
895 match field.type_name.as_str() {
896 "str"| "compact_str" | "i16" | "i32" | "i64" | "u16" | "u32" | "u64" | "f32" | "f64" | "datetime"
900 | "date" | "time" | "enum" | "decimal" => {
901 quote! {
902 #name: Some(#default_quote())
903 }
904 }
905 _ => quote! {
906 #name: Default::default()
907 },
908 }
909 } else {
910 quote! {
911 #name: Default::default()
912 }
913 }
914 });
915
916 quote! {
917 impl Default for #native_name {
918 fn default() -> Self {
919 Self {
920 #(#default_fields,)*
921 }
922 }
923 }
924 }
925}
926
927fn get_native_to_raw_impl(
928 name: &syn::Ident,
929 native_name: &proc_macro2::Ident,
930) -> proc_macro2::TokenStream {
931 quote! {
932 impl #native_name {
933 pub fn to_raw(&self) -> #name {
934 #name::from_native(self)
935 }
936 }
937 }
938}
939
940fn get_native_struct_code(
941 name: &syn::Ident,
942 native_fields: &[NativeField],
943) -> proc_macro2::TokenStream {
944 let native_name = quote::format_ident!("{}Native", name);
945 let fields_code = native_fields
946 .iter()
947 .filter(|field| !field.skip_native)
948 .map(|field| {
949 let name = &field.name;
950 let ty = &field.ty;
951 let ty_str = ty
953 .to_string()
954 .replace(" :: ", "::")
955 .replace(" < ", "<")
956 .replace(" > ", ">")
957 .replace(" >", ">");
958 format!(" pub {}: {},", name, ty_str)
959 })
960 .collect::<Vec<_>>()
961 .join("\n");
962
963 quote! {
964 impl binary_mirror::NativeStructCode for #name {
965 fn native_struct_code() -> String {
966 format!(
967 "pub struct {} {{\n{}\n}}",
968 stringify!(#native_name),
969 #fields_code
970 )
971 }
972 }
973 }
974}
975
976fn get_native_derives(struct_attrs: &StructAttrs) -> proc_macro2::TokenStream {
977 if struct_attrs.derives.is_empty() {
978 quote!(Debug, PartialEq, Serialize, Deserialize)
979 } else {
980 let native_derives = struct_attrs
981 .derives
982 .iter()
983 .map(|derive| quote!(#derive))
984 .collect::<Vec<_>>();
985 quote!(#(#native_derives),*)
986 }
987}
988
/// Expands `#[derive(BinaryMirror)]` for one struct.
///
/// Collects the struct's fields and `#[bm]` options, builds every generated
/// fragment with the `get_*` helpers, and assembles them into:
///
/// * accessor, `size()` and `<field>_spec()` methods on the raw struct;
/// * the `<Name>Native` struct with its `with_*` setters, `Default` impl and `to_raw`;
/// * `Debug` and `Display` impls for the raw struct;
/// * impls of the `binary_mirror` traits `NativeStructCode`, `FromBytes`,
///   `ToBytes`, `ToNative` and `FromNative`.
///
/// Panics raised by the helpers (unsupported input shapes or options)
/// propagate out of this function.
fn impl_binary_mirror(input: &DeriveInput) -> TokenStream {
    let name = &input.ident;
    let native_name = quote::format_ident!("{}Native", name);
    let struct_attrs = get_struct_attrs(input);

    let origin_fields = get_origin_fields(input);
    let (native_fields, native_field_map) = get_native_fields_and_map(&origin_fields);
    let debug_fields_token = get_debug_fields(&origin_fields);
    let display_fields_token = get_display_fields(&native_fields);
    let methods = get_methods(&native_fields);
    let native_fields_token = get_native_fields_token(&native_fields);
    let to_native_fields_token = get_to_native_fields(&native_fields);
    let from_native_fields_token = get_from_native_fields(&native_field_map);
    let native_methods = get_native_methods(&native_fields);
    let field_spec_methods = get_field_spec_methods(&origin_fields);
    let native_default_impl = get_native_default_impl(&native_fields, &native_name);
    let native_to_raw_impl = get_native_to_raw_impl(name, &native_name);
    let native_derives = get_native_derives(&struct_attrs);
    let native_struct_code = get_native_struct_code(name, &native_fields);

    let gen = quote! {
        impl #name {
            #(#methods)*
            pub const fn size() -> usize {
                std::mem::size_of::<Self>()
            }
            #field_spec_methods
        }

        #[derive(#native_derives)]
        pub struct #native_name {
            #(#native_fields_token,)*
        }

        impl #native_name {
            #(#native_methods)*
        }

        impl std::fmt::Debug for #name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(stringify!(#name))
                    #(#debug_fields_token)*
                    .finish()
            }
        }

        impl std::fmt::Display for #name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "{} {{ ", stringify!(#name))?;
                // `first` suppresses the separator before the first printed field.
                let mut first = true;
                #(
                    if first {
                        first = false;
                    } else {
                        write!(f, ", ")?;
                    }
                    #display_fields_token
                )*
                write!(f, " }}")
            }
        }

        #native_default_impl
        #native_to_raw_impl
        #native_struct_code

        impl binary_mirror::FromBytes for #name {
            const SIZE: usize = std::mem::size_of::<Self>();

            fn from_bytes(bytes: &[u8]) -> Result<&Self, binary_mirror::BytesSizeError> {
                let expected = Self::SIZE;
                let actual = bytes.len();
                if actual != expected {
                    // The error carries a printable rendering of the rejected input:
                    // newline, carriage return and tab are escaped, printable ASCII is
                    // kept, and every other byte becomes `\xNN`.
                    return Err(binary_mirror::BytesSizeError::new(
                        expected,
                        actual,
                        bytes.iter()
                            .map(|&b| {
                                match b {
                                    0x0A => "\\n".to_string(),
                                    0x0D => "\\r".to_string(),
                                    0x09 => "\\t".to_string(),
                                    0x20..=0x7E => (b as char).to_string(),
                                    _ => format!("\\x{:02x}", b),
                                }
                            })
                            .collect::<Vec<String>>()
                            .join("")
                    ));
                }
                // SAFETY (NOTE(review)): the length was checked against `size_of::<Self>()`
                // above. The cast additionally assumes `Self` has alignment 1 and no
                // invalid bit patterns, i.e. that every field is a `[u8; N]` array. The
                // macro only checks that fields are arrays, not their element type;
                // confirm against callers.
                Ok(unsafe { &*(bytes.as_ptr() as *const Self) })
            }

        }

        impl binary_mirror::ToBytes for #name {
            fn to_bytes(&self) -> &[u8] {
                // SAFETY (NOTE(review)): views `self` as `size_of::<Self>()` bytes for the
                // lifetime of the borrow; assumes `Self` contains no padding or
                // uninitialised bytes (presumably all fields are `[u8; N]`; verify).
                unsafe {
                    std::slice::from_raw_parts(
                        (self as *const Self) as *const u8,
                        Self::size()
                    )
                }
            }

            fn to_bytes_owned(&self) -> Vec<u8> {
                self.to_bytes().to_vec()
            }
        }

        impl binary_mirror::ToNative for #name {
            type Native = #native_name;

            fn to_native(&self) -> Self::Native {
                #native_name {
                    #(#to_native_fields_token,)*
                }
            }
        }

        impl binary_mirror::FromNative<#native_name> for #name {
            fn from_native(native: &#native_name) -> Self {
                Self {
                    #(#from_native_fields_token,)*
                }
            }
        }
    };

    gen.into()
}
1130
1131fn get_variant_value(attrs: &[syn::Attribute]) -> Option<Vec<u8>> {
1132 for attr in attrs {
1133 if attr.path().is_ident("bv") {
1134 let mut byte_value = None;
1135 let _ = attr.parse_nested_meta(|meta| {
1136 if meta.path.is_ident("value") {
1137 let lit = meta.value()?.parse::<syn::LitByteStr>()?;
1138 byte_value = Some(lit.value().to_vec());
1139 }
1140 Ok(())
1141 });
1142 return byte_value;
1143 }
1144 }
1145 None
1146}
1147
1148fn impl_binary_enum(input: &DeriveInput) -> TokenStream {
1149 let name = &input.ident;
1150
1151 let variants = match &input.data {
1152 Data::Enum(data) => &data.variants,
1153 _ => panic!("BinaryEnum can only be derived for enums"),
1154 };
1155
1156 let match_arms_from = variants.iter().map(|variant| {
1157 let variant_ident = &variant.ident;
1158 let byte_value = get_variant_value(&variant.attrs).unwrap_or_else(|| {
1159 let variant_str = variant_ident.to_string().to_uppercase();
1160 vec![variant_str.chars().next().unwrap() as u8]
1161 });
1162 let byte_len = byte_value.len();
1163
1164 quote! {
1165 if bytes.len() >= #byte_len && &bytes[..#byte_len] == &[#(#byte_value),*] {
1166 Some(Self::#variant_ident)
1167 } else
1168 }
1169 });
1170
1171 let match_arms_to = variants.iter().map(|variant| {
1172 let variant_ident = &variant.ident;
1173 let byte_value = get_variant_value(&variant.attrs).unwrap_or_else(|| {
1174 let variant_str = variant_ident.to_string().to_uppercase();
1175 vec![variant_str.chars().next().unwrap() as u8]
1176 });
1177
1178 quote! {
1179 Self::#variant_ident => &[#(#byte_value),*],
1180 }
1181 });
1182
1183 let gen = quote! {
1184 impl #name {
1185 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
1186 #(#match_arms_from)* {
1187 None
1188 }
1189 }
1190
1191 pub fn as_bytes(&self) -> &'static [u8] {
1192 match self {
1193 #(#match_arms_to)*
1194 }
1195 }
1196 }
1197 };
1198
1199 gen.into()
1200}
1201
// Unit-test module for the derive macros.
#[cfg(test)]
mod tests {
    // Placeholder: the body is empty, so this test exercises nothing and always passes.
    #[test]
    fn test_basic_derive() {
    }
}