Skip to main content

clingo_derive/
lib.rs

1extern crate proc_macro;
2
3use inflector::Inflector;
4use proc_macro::TokenStream;
5use quote::quote;
6use syn;
7use syn::Data::Enum;
8use syn::Data::Struct;
9use syn::Data::Union;
10use syn::Fields::*;
11use syn::Type::*;
12
13#[proc_macro_derive(ToSymbol)]
14pub fn derive_fact(input: TokenStream) -> TokenStream {
15    // Construct a representation of Rust code as a syntax tree
16    // that we can manipulate
17    let ast = syn::parse(input).expect("heeh");
18
19    // Build the trait implementation
20    impl_fact(&ast)
21}
22
23fn impl_fact(ast: &syn::DeriveInput) -> TokenStream {
24    let name = &ast.ident;
25    let gen = match &ast.data {
26        Struct(data) => match_fields_struct(&data.fields, name, &ast.generics),
27        Enum(data) => {
28            let mut variants = quote! {
29                _ => panic!("Unknown Variant"),
30            };
31            for variant in &data.variants {
32                let ident = &variant.ident;
33                let variant = match_fields_enum(&variant.fields, ident);
34                variants = quote! {
35                    #name::#variant
36                    #variants
37                }
38            }
39            let (impl_generics, ty_generics, where_clause) = &ast.generics.split_for_impl();
40            let gen = quote! {
41                impl #impl_generics ToSymbol for #name #ty_generics #where_clause {
42                    fn symbol(&self) -> Result<Symbol, ClingoError> {
43                        match self {
44                            #variants
45                        }
46                    }
47                }
48            };
49            gen.into()
50        }
51        Union(_) => panic!("Cannot derive ToSymbol for Unions!"),
52    };
53    // println!("EXPANDED: \n{}",gen);
54    gen
55}
56
57fn match_fields_struct(
58    fields: &syn::Fields,
59    name: &syn::Ident,
60    generics: &syn::Generics,
61) -> TokenStream {
62    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
63    match fields {
64        Named(named_fields) => {
65            let mut tokens = quote! {
66                let mut temp_vec =  vec![];
67            };
68            for field in &named_fields.named {
69                let i = field
70                    .ident
71                    .clone()
72                    .expect("Expected Some(Ident). None found!");
73                tokens = match_type_struct(&field.ty, &tokens, i);
74            }
75            let predicate_name = name.to_string().to_snake_case();
76            quote! {
77                impl #impl_generics ToSymbol for #name #ty_generics #where_clause {
78                    fn symbol(&self) -> Result<Symbol, ClingoError> {
79                        #tokens
80                        Symbol::create_function(#predicate_name,&temp_vec,true)
81                    }
82                }
83            }
84        }
85        Unnamed(unnamed_fields) => {
86            let mut tokens = quote! {
87                let mut temp_vec =  vec![];
88            };
89            let mut field_count = 0;
90            for field in &unnamed_fields.unnamed {
91                tokens =
92                    match_unamed_type_struct(&field.ty, &tokens, syn::Index::from(field_count));
93                field_count += 1;
94            }
95            let predicate_name = name.to_string().to_snake_case();
96            quote! {
97                impl #impl_generics ToSymbol for #name #ty_generics #where_clause {
98                    fn symbol(&self) -> Result<Symbol, ClingoError> {
99                        #tokens
100                        Symbol::create_function(#predicate_name,&temp_vec,true)
101                    }
102                }
103            }
104        }
105        Unit => {
106            let predicate_name = name.to_string().to_snake_case();
107            quote! {
108                impl #impl_generics ToSymbol for #name #ty_generics #where_clause {
109                    fn symbol(&self) -> Result<Symbol, ClingoError> {
110                        Symbol::create_id(#predicate_name,true)
111                    }
112                }
113            }
114        }
115    }
116    .into()
117}
118
119fn match_fields_enum(fields: &syn::Fields, ident: &syn::Ident) -> proc_macro2::TokenStream {
120    match &fields {
121        Named(named_fields) => {
122            let mut tokens = quote! {
123                let mut temp_vec =  vec![];
124            };
125            let mut field_idents = quote! {};
126            for field in &named_fields.named {
127                let field_ident = field
128                    .ident
129                    .clone()
130                    .expect("Expected Some(Ident). None found!");
131                if field_idents.is_empty() {
132                    field_idents = quote! {#field_ident};
133                } else {
134                    field_idents = quote! {#field_idents,#field_ident};
135                }
136                tokens = match_type_enum(&field.ty, &tokens, field_ident);
137            }
138            let predicate_name = ident.to_string().to_snake_case();
139            quote! {
140                #ident{#field_idents} => {
141                    #tokens
142                    Symbol::create_function(#predicate_name,&temp_vec,true)
143                },
144            }
145        }
146        Unnamed(unnamed_fields) => {
147            let mut tokens = quote! {
148                let mut temp_vec =  vec![];
149            };
150            let mut field_idents = quote! {};
151            let predicate_name = ident.to_string().to_snake_case();
152            let mut field_count = 1;
153            for field in &unnamed_fields.unnamed {
154                let field_ident: syn::Ident =
155                    syn::parse_str(&format!("x{}", field_count)).expect("Expected Ident");
156                if field_idents.is_empty() {
157                    field_idents = quote! {#field_ident};
158                } else {
159                    field_idents = quote! {#field_idents,#field_ident};
160                }
161                tokens = match_type_enum(&field.ty, &tokens, field_ident);
162                field_count += 1;
163            }
164            quote! {
165                #ident(#field_idents) => {
166                    #tokens
167                    Symbol::create_function(#predicate_name,&temp_vec,true)
168                },
169            }
170        }
171        Unit => {
172            let predicate_name = ident.to_string().to_snake_case();
173            quote! {
174                #ident => {
175                    Symbol::create_id(#predicate_name,true)
176                },
177            }
178        }
179    }
180}
181
182fn match_type_struct(
183    ty: &syn::Type,
184    tokens: &proc_macro2::TokenStream,
185    i: syn::Ident,
186) -> proc_macro2::TokenStream {
187    let gen = match &ty {
188        Tuple(_type_tuple) => {
189            quote! {
190                #tokens
191                temp_vec.push(self.#i.symbol()?);
192            }
193        }
194        Path(type_path) => {
195            let segments = &type_path.path.segments;
196            let typename = segments[0].ident.to_string();
197            match typename.as_ref() {
198                "u64" | "i64" | "u128" | "i128" => {
199                    panic!("Cannot derive_fact clingo library only support 32bit integers.")
200                }
201                _ => {
202                    quote! {
203                        #tokens
204                        temp_vec.push(self.#i.symbol()?);
205                    }
206                }
207            }
208        }
209        Reference(type_reference) => match_type_struct(&type_reference.elem, tokens, i),
210        _ => {
211            panic!("Unexpected type annotation");
212        }
213    };
214    gen
215}
216fn match_unamed_type_struct(
217    ty: &syn::Type,
218    tokens: &proc_macro2::TokenStream,
219    i: syn::Index,
220) -> proc_macro2::TokenStream {
221    let gen = match &ty {
222        Tuple(_type_tuple) => {
223            quote! {
224                #tokens
225                temp_vec.push(self.#i.symbol()?);
226            }
227        }
228        Path(type_path) => {
229            let segments = &type_path.path.segments;
230            let typename = segments[0].ident.to_string();
231            match typename.as_ref() {
232                "u64" | "i64" | "u128" | "i128" => {
233                    panic!("Cannot derive_fact clingo library only support 32bit integers.")
234                }
235                _ => {
236                    quote! {
237                        #tokens
238                        temp_vec.push(self.#i.symbol()?);
239                    }
240                }
241            }
242        }
243        Reference(type_reference) => match_unamed_type_struct(&type_reference.elem, tokens, i),
244        _ => {
245            panic!("Unexpected type annotation");
246        }
247    };
248    gen
249}
250
251fn match_type_enum(
252    ty: &syn::Type,
253    tokens: &proc_macro2::TokenStream,
254    i: syn::Ident,
255) -> proc_macro2::TokenStream {
256    let gen = match &ty {
257        Tuple(_type_tuple) => {
258            quote! {
259                #tokens
260                temp_vec.push(#i.symbol()?);
261            }
262        }
263        Path(type_path) => {
264            let segments = &type_path.path.segments;
265            let typename = segments[0].ident.to_string();
266            match typename.as_ref() {
267                "u64" | "i64" | "u128" | "i128" => {
268                    panic!("Cannot derive_fact clingo library only support 32bit integers.")
269                }
270                _ => {
271                    quote! {
272                        #tokens
273                        temp_vec.push(#i.symbol()?);
274                    }
275                }
276            }
277        }
278        Reference(type_reference) => match_type_enum(&type_reference.elem, tokens, i),
279        _ => {
280            panic!("Unexpected type annotation");
281        }
282    };
283    gen
284}