// auto_diff_macros/lib.rs

//! Procedural macros.
2
3use proc_macro::TokenStream;
4use syn::{parse_macro_input, ItemStruct, parse, parse::Parser};
5use syn::punctuated::Punctuated;
6use syn::{Expr, Token};
7use quote::quote;
8
9
10#[proc_macro_attribute]
11pub fn add_op_handle(args: TokenStream, input: TokenStream) -> TokenStream {
12    let mut item_struct = parse_macro_input!(input as ItemStruct);
13    let _ = parse_macro_input!(args as parse::Nothing);
14
15    if let syn::Fields::Named(ref mut fields) = item_struct.fields {
16
17        fields.named.push(
18            syn::Field::parse_named
19                .parse2(quote! {
20                    #[cfg_attr(feature = "use-serde", serde(skip))]
21                    handle: OpHandle
22                })
23                .unwrap(),
24        );
25    }
26
27    return quote! {
28        #item_struct
29    }
30    .into();
31}
32
33#[proc_macro_attribute]
34pub fn extend_op_impl(args: TokenStream, input: TokenStream) -> TokenStream {
35    let mut item_struct = parse_macro_input!(input as ItemStruct);
36    let _ = parse_macro_input!(args as parse::Nothing);
37
38    if let syn::Fields::Named(ref mut fields) = item_struct.fields {
39
40        fields.named.push(
41            syn::Field::parse_named
42                .parse2(quote! {
43                    #[cfg_attr(feature = "use-serde", serde(skip))]
44                    handle: OpHandle
45                })
46                .unwrap(),
47        );
48    }
49
50    return quote! {
51        #item_struct
52    }
53    .into();
54}
55
56
57#[proc_macro]
58pub fn gen_serde_funcs(input: TokenStream) -> TokenStream {
59
60    let input_tokens = input.clone();
61    let parser = Punctuated::<Expr, Token![,]>::parse_separated_nonempty;
62    let input_result = parser.parse(input_tokens).expect("need list of ids");
63    let mut strs = vec![]; // This is the vec of op structure name in str.
64    for item in input_result {
65        match item {
66            Expr::Path(expr) => {
67                strs.push(expr.path.get_ident().expect("need a ident").to_string());
68            },
69            _ => {panic!("need a ident, expr::path.");}
70        }
71    }
72
73    // This is the vec of ident.
74    let names: Vec<_> = strs.iter().map(|x| quote::format_ident!("{}", x)).collect();
75    
76    let serialize_box = quote!{
77        pub fn serialize_box<S>(op: &Box<dyn OpTrait>, serializer: S) -> Result<S::Ok, S::Error>
78        where S: Serializer {
79            match op.get_name() {
80                #( #strs  => {
81                    let op = op.as_any().downcast_ref::<#names>().unwrap();
82                    op.serialize(serializer)
83                },  )*
84	        other => {
85		    return Err(ser::Error::custom(format!("unknown op {:?}", other)));
86	        }
87	    }
88        }
89    };
90
91    let deserialize_map = quote!{
92        pub fn deserialize_map<'de, V>(op_name: String, mut map: V) -> Result<Op, V::Error>
93        where V: MapAccess<'de>, {
94            match op_name.as_str() {
95                #( #strs => {
96                    let op_obj: #names = Some(map.next_value::<#names>()?).ok_or_else(|| de::Error::missing_field("op_obj"))?;
97                    return Ok(Op::new(Rc::new(RefCell::new(Box::new(op_obj)))));
98                }, )*
99                _ => {
100		    return Err(de::Error::missing_field("op_obj"));
101		}
102            }
103        }
104    };
105
106    let deserialize_seq = quote!{
107	pub fn deserialize_seq<'de, V>(op_name: String, mut seq: V) -> Result<Op, V::Error>
108        where V: SeqAccess<'de>, {
109            match op_name.as_str() {
110                #( #strs => {
111                    let op_obj: #names = seq.next_element()?.ok_or_else(|| de::Error::missing_field("op_obj"))?;
112                    return Ok(Op::new(Rc::new(RefCell::new(Box::new(op_obj)))));
113                }, )*
114                _ => {
115		    return Err(de::Error::missing_field("op_obj"));
116		}
117            }
118        }
119    };
120
121    let tokens = quote! {
122        #serialize_box
123        #deserialize_map
124        #deserialize_seq
125    };
126    
127    tokens.into()
128}
129
130
#[cfg(test)]
mod tests {
    /// Placeholder test with no assertions.
    ///
    /// TODO(review): add expansion tests for the macros in this crate.
    #[test]
    fn test() {}
}