Skip to main content

derive_knet/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{parse_macro_input, DeriveInput};
5
6#[proc_macro_derive(DeriveKnet)]
7pub fn derive(input: TokenStream) -> TokenStream {
8    let ast = parse_macro_input!(input as DeriveInput);
9    
10    let ienum = &ast.ident;
11    let ienum_string = ienum.to_string();
12
13    //Iterator enum
14    let variants = if let syn::Data::Enum(syn::DataEnum { variants, .. }) = &ast.data {
15        variants
16    } else {
17        unimplemented!("Derive Knet only support Enum")
18    };
19
20    //Max len of identifier
21    let max_ident = variants
22        .iter()
23        .map(|v| v.ident.to_string().len())
24        .max_by(|v1, v2| v1.cmp(&v2));
25
26    //match for serialize()
27    let serialize_variants = variants.iter().map(|v| {
28        let name = &v.ident;
29        let sname = name.to_string();
30        let ty = inner_ty(&v.fields);
31        quote! {
32            #ienum::#name(data) => {
33                let size = ::std::mem::size_of::<#ty>();
34                let mut v  = vec![0u8; #max_ident + size];
35                let src_ptr = ::std::boxed::Box::new(*data);
36                let dst_ptr = v.as_mut_ptr();
37                /// SAFETY : TODO
38                unsafe {
39                    std::ptr::copy_nonoverlapping(#sname.as_ptr(), dst_ptr, #sname.len());
40                    let dst_ptr = dst_ptr.offset(#max_ident as isize);
41                    let src_ptr = ::std::boxed::Box::into_raw(src_ptr).cast::<u8>();
42                    std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, size);
43                };
44                let _ = ::std::boxed::Box::from_raw(src_ptr);
45
46                v
47            }
48        }
49    });
50
51    //function serialize()
52    let serialize = quote! {
53        fn serialize(&self) -> ::std::vec::Vec<u8> {
54            match self {
55                #(#serialize_variants)*
56            }
57        }
58    };
59
60    //match for deserialize()
61    let deserialize_variants = variants.iter().map(|v| {
62        let name = &v.ident;
63        let sname = name.to_string();
64        let ty = inner_ty(&v.fields).unwrap();
65        quote!{
66            #sname => {
67                let mut data = #ty::default();
68                let src_ptr = v[#max_ident..].as_ptr();
69                let dst_ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(data)).cast::<u8>();
70                ///SAFETY : TODO
71                unsafe {
72                    std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, ::std::mem::size_of::<#ty>());
73                    data = *dst_ptr.cast::<#ty>();
74                    ::std::boxed::Box::from_raw(dst_ptr);
75                };
76                *self = #ienum::#name(data);
77            }
78        }
79    });
80
81    //deserialize function
82    let deserialize = quote! {
83        fn deserialize(&mut self, v: &[u8]) {
84            let name = std::string::String::from_utf8(v[0..#max_ident].to_vec())
85            .unwrap();
86            let name = name.trim_end_matches(|c| c as u8 == 0u8);
87
88            match name {
89                #(#deserialize_variants),*
90                _ => { panic!("`{}` is not a part of {}", name, #ienum_string) }
91            }
92        }
93    };
94
95    //match for from_raw
96    let from_raw_variants = variants.iter().map(|v| {
97        let name = &v.ident;
98        let sname = name.to_string();
99        let ty = inner_ty(&v.fields).unwrap();
100        quote! {
101            #sname => {
102                let mut data = #ty::default();
103                let src_ptr = v[#max_ident..].as_ptr();
104                let dst_ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(data)).cast::<u8>();
105                ///SAFETY : TODO
106                unsafe {
107                    std::ptr::copy_nonoverlapping(src_ptr, dst_ptr, ::std::mem::size_of::<#ty>());
108                    data = *dst_ptr.cast::<#ty>();
109                    ::std::boxed::Box::from_raw(dst_ptr);
110                };
111                ///SAFETY : TODO
112                #ienum::#name(data)
113            }
114        }
115    });
116
117    //function from_raw
118    let from_raw = quote! {
119        fn from_raw(v: &[u8]) -> Self {
120            let name = std::string::String::from_utf8(v[0..#max_ident].to_vec())
121                .unwrap();
122            let name = name.trim_end_matches(|c| c as u8 == 0u8);
123
124            match name {
125                #(#from_raw_variants),*
126                _ => { panic!("`{}` is not a part of {}", name, #ienum_string) }
127            }
128        }
129
130    };
131
132    
133    let get_size_variants = variants.iter().map(|v| {
134        let name = &v.ident;
135        let sname = name.to_string();
136        let ty = inner_ty(&v.fields).unwrap();
137        quote! {
138            #sname => {
139                ::std::mem::size_of::<#ty>()
140            }
141        }
142    });
143
144    let get_size_of_data = quote! {
145        fn get_size_of_data(v: &[u8]) -> usize {
146            let name = std::string::String::from_utf8(v.to_vec())
147                .unwrap();
148            let name = name.trim_end_matches(|c| c as u8 == 0u8);
149            match name {
150                #(#get_size_variants),*
151                _ => { panic!("`{}` is not a part of {}", name, #ienum_string) }
152
153            }
154        }
155    };
156
157    let get_size_of_payload = quote! {
158        fn get_size_of_payload() -> usize {
159            #max_ident
160        }
161    };
162    let expanded = quote! {
163        impl knet::KnetTransform for #ienum {
164            #serialize
165            #deserialize
166            #from_raw
167            #get_size_of_data
168            #get_size_of_payload
169        }
170    };
171    expanded.into()
172}
173
174fn inner_ty(field: &syn::Fields) -> Option<&syn::Ident> {
175    if let syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) = field {
176        if let Some(first) = unnamed.first() {
177            if let syn::Type::Path(syn::TypePath { path, .. }) = &first.into_value().ty {
178                if path.segments.len() == 1 {
179                    return Some(&path.segments[0].ident);
180                }
181            }
182        }
183    }
184    None
185}