1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{parse_macro_input, DeriveInput};
5
6#[proc_macro_derive(DeriveKnet)]
7pub fn derive(input: TokenStream) -> TokenStream {
8 let ast = parse_macro_input!(input as DeriveInput);
9
10 let ienum = &ast.ident;
11 let ienum_string = ienum.to_string();
12
13 let variants = if let syn::Data::Enum(syn::DataEnum { variants, .. }) = &ast.data {
15 variants
16 } else {
17 unimplemented!("Derive Knet only support Enum")
18 };
19
20 let max_ident = variants
22 .iter()
23 .map(|v| v.ident.to_string().len())
24 .max_by(|v1, v2| v1.cmp(&v2));
25
26 let serialize_variants = variants.iter().map(|v| {
28 let name = &v.ident;
29 let sname = name.to_string();
30 let ty = inner_ty(&v.fields);
31 quote! {
32 #ienum::#name(data) => {
33 let size = ::std::mem::size_of::<#ty>();
34 let mut v = vec![0u8; #max_ident + size];
35 let src_ptr = ::std::boxed::Box::new(*data);
36 let dst_ptr = v.as_mut_ptr();
37 unsafe {
39 std::ptr::copy_nonoverlapping(#sname.as_ptr(), dst_ptr, #sname.len());
40 let dst_ptr = dst_ptr.offset(#max_ident as isize);
41 let src_ptr = ::std::boxed::Box::into_raw(src_ptr).cast::<u8>();
42 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, size);
43 };
44 let _ = ::std::boxed::Box::from_raw(src_ptr);
45
46 v
47 }
48 }
49 });
50
51 let serialize = quote! {
53 fn serialize(&self) -> ::std::vec::Vec<u8> {
54 match self {
55 #(#serialize_variants)*
56 }
57 }
58 };
59
60 let deserialize_variants = variants.iter().map(|v| {
62 let name = &v.ident;
63 let sname = name.to_string();
64 let ty = inner_ty(&v.fields).unwrap();
65 quote!{
66 #sname => {
67 let mut data = #ty::default();
68 let src_ptr = v[#max_ident..].as_ptr();
69 let dst_ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(data)).cast::<u8>();
70 unsafe {
72 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, ::std::mem::size_of::<#ty>());
73 data = *dst_ptr.cast::<#ty>();
74 ::std::boxed::Box::from_raw(dst_ptr);
75 };
76 *self = #ienum::#name(data);
77 }
78 }
79 });
80
81 let deserialize = quote! {
83 fn deserialize(&mut self, v: &[u8]) {
84 let name = std::string::String::from_utf8(v[0..#max_ident].to_vec())
85 .unwrap();
86 let name = name.trim_end_matches(|c| c as u8 == 0u8);
87
88 match name {
89 #(#deserialize_variants),*
90 _ => { panic!("`{}` is not a part of {}", name, #ienum_string) }
91 }
92 }
93 };
94
95 let from_raw_variants = variants.iter().map(|v| {
97 let name = &v.ident;
98 let sname = name.to_string();
99 let ty = inner_ty(&v.fields).unwrap();
100 quote! {
101 #sname => {
102 let mut data = #ty::default();
103 let src_ptr = v[#max_ident..].as_ptr();
104 let dst_ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(data)).cast::<u8>();
105 unsafe {
107 std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, ::std::mem::size_of::<#ty>());
108 data = *dst_ptr.cast::<#ty>();
109 ::std::boxed::Box::from_raw(dst_ptr);
110 };
111 #ienum::#name(data)
113 }
114 }
115 });
116
117 let from_raw = quote! {
119 fn from_raw(v: &[u8]) -> Self {
120 let name = std::string::String::from_utf8(v[0..#max_ident].to_vec())
121 .unwrap();
122 let name = name.trim_end_matches(|c| c as u8 == 0u8);
123
124 match name {
125 #(#from_raw_variants),*
126 _ => { panic!("`{}` is not a part of {}", name, #ienum_string) }
127 }
128 }
129
130 };
131
132
133 let get_size_variants = variants.iter().map(|v| {
134 let name = &v.ident;
135 let sname = name.to_string();
136 let ty = inner_ty(&v.fields).unwrap();
137 quote! {
138 #sname => {
139 ::std::mem::size_of::<#ty>()
140 }
141 }
142 });
143
144 let get_size_of_data = quote! {
145 fn get_size_of_data(v: &[u8]) -> usize {
146 let name = std::string::String::from_utf8(v.to_vec())
147 .unwrap();
148 let name = name.trim_end_matches(|c| c as u8 == 0u8);
149 match name {
150 #(#get_size_variants),*
151 _ => { panic!("`{}` is not a part of {}", name, #ienum_string) }
152
153 }
154 }
155 };
156
157 let get_size_of_payload = quote! {
158 fn get_size_of_payload() -> usize {
159 #max_ident
160 }
161 };
162 let expanded = quote! {
163 impl knet::KnetTransform for #ienum {
164 #serialize
165 #deserialize
166 #from_raw
167 #get_size_of_data
168 #get_size_of_payload
169 }
170 };
171 expanded.into()
172}
173
174fn inner_ty(field: &syn::Fields) -> Option<&syn::Ident> {
175 if let syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) = field {
176 if let Some(first) = unnamed.first() {
177 if let syn::Type::Path(syn::TypePath { path, .. }) = &first.into_value().ty {
178 if path.segments.len() == 1 {
179 return Some(&path.segments[0].ident);
180 }
181 }
182 }
183 }
184 None
185}