Skip to main content

cowlang_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, ImplItem, ItemImpl};
6
7#[proc_macro_attribute]
8pub fn cow_module(_attr: TokenStream, item: TokenStream) -> TokenStream {
9    let ast = parse_macro_input!(item as ItemImpl);
10
11    let name = if let syn::Type::Path(p) = *ast.self_ty {
12        p.path.get_ident().unwrap().clone()
13    } else {
14        panic!("self_ty is not a path!");
15    };
16
17    if let Some((_, ref path, _)) = ast.trait_ {
18        syn::Error::new_spanned(
19            path,
20            "#[cow_module] can not be used with a trait impl block",
21        )
22        .to_compile_error();
23    } else if ast.generics != Default::default() {
24        syn::Error::new_spanned(
25            ast.generics.clone(),
26            "#[cow_module] can not be used with lifetime parameters or generics",
27        )
28        .to_compile_error();
29    }
30
31    let mut method_names = Vec::new();
32    let mut method_outputs = Vec::new();
33    let mut internal_method_names = Vec::new();
34    let mut method_return_conversions = Vec::new();
35    let mut method_args = Vec::new();
36    let mut method_attrs = Vec::new();
37    let mut method_structs = Vec::new();
38    let mut method_blocks = Vec::new();
39
40    let mut constant_names = Vec::new();
41    let mut constant_literals = Vec::new();
42    let mut constant_expressions = Vec::new();
43
44    for item in ast.items.iter() {
45        match item {
46            ImplItem::Fn(meth) => {
47                let ident = &meth.sig.ident;
48                let mut args = Vec::new();
49
50                let mut returns_object = false;
51                let mut attrs_out = Vec::new();
52
53                for attr in meth.attrs.iter() {
54                    if attr.path().segments.len() == 1
55                        && &attr.path().segments[0].ident == "returns_object"
56                    {
57                        returns_object = true;
58                    } else {
59                        attrs_out.push(attr.clone());
60                    }
61                }
62
63                method_names.push(ident.to_string());
64                internal_method_names.push(format_ident!("_internal_{}", ident));
65
66                method_attrs.push(attrs_out);
67                method_blocks.push(meth.block.clone());
68                method_outputs.push(meth.sig.output.clone());
69
70                let return_conversion = if meth.sig.output == syn::ReturnType::Default {
71                    quote! {
72                        cowlang::interpreter::Handle::wrap_value( cowlang::Value::None )
73                    }
74                } else if returns_object {
75                    quote! {
76                        cowlang::interpreter::Handle::Object( std::rc::Rc::new( result ) )
77                    }
78                } else {
79                    quote! {
80                        cowlang::interpreter::Handle::wrap_value( result.into() )
81                    }
82                };
83
84                method_return_conversions.push(return_conversion);
85
86                for arg in meth.sig.inputs.iter() {
87                    // ignore self values etc
88                    if let syn::FnArg::Typed(typed) = arg {
89                        if let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &*typed.pat {
90                            args.push(ident.clone());
91                        } else {
92                            panic!("Unsupported pattern");
93                        }
94                    }
95                }
96
97                method_args.push(args);
98
99                method_structs.push(format_ident!("MethodCall_{}_{}", name, ident));
100            }
101            ImplItem::Const(constant) => {
102                if let syn::Expr::Lit(lit) = &constant.expr {
103                    constant_names.push(constant.ident.to_string());
104                    constant_literals.push(lit.lit.clone());
105                } else {
106                    panic!("Unsupported expression: {:?}", constant.expr);
107                }
108
109                constant_expressions.push(constant.clone());
110            }
111            _ => {
112                panic!("ImplItem type not supported: {:?}", item);
113            }
114        }
115    }
116
117    let method_struct_defs = method_structs.iter();
118    let method_impl_names = method_structs.iter();
119    let method_struct_defs2 = method_structs.iter();
120
121    let mut arg_strings1 = Vec::new();
122    let mut arg_strings2 = Vec::new();
123
124    let name_string = name.to_string();
125    let mut arg_conversions = Vec::new();
126    let mut arg_lens = Vec::new();
127
128    let name_iter = std::iter::repeat(name.clone());
129    let name_iter2 = std::iter::repeat(name.clone());
130
131    for args in method_args {
132        let conv = quote! {
133            let mut _internal_argv = _internal_args.drain(..);
134
135            #(
136            let #args = _internal_argv.next().unwrap();
137            )*
138        };
139
140        let arg_str1 = quote! {
141            ( #( #args),*)
142        };
143
144        let arg_str2 = quote! {
145            (&self, #( #args: cowlang::Value),*)
146        };
147
148        arg_strings1.push(arg_str1);
149        arg_strings2.push(arg_str2);
150
151        arg_lens.push(args.len());
152        arg_conversions.push(conv);
153    }
154
155    let expanded = quote! {
156        #(
157
158        #( #method_attrs )*
159        #[allow(non_camel_case_types)]
160        struct #method_struct_defs {
161            self_ref: std::rc::Rc<dyn cowlang::Module>
162        }
163
164        #( #method_attrs )*
165        impl cowlang::interpreter::Callable for #method_impl_names {
166            fn call(&self, mut _internal_args: Vec<cowlang::Value>) -> cowlang::interpreter::Handle {
167                //FIXME find a way to do this without raw pointers
168
169                let self_rc = self.self_ref.clone();
170
171                let self_ptr = std::rc::Rc::into_raw(self_rc);
172                let self_ref = unsafe{ &*(self_ptr as *const #name_iter) };
173
174                if _internal_args.len() != #arg_lens {
175                    panic!("Invalid number of arguments!");
176                }
177
178                #arg_conversions
179
180                let result = self_ref.#internal_method_names #arg_strings1;
181
182                // Avoid memory leak
183                unsafe {
184                    std::rc::Rc::from_raw(self_ptr);
185                };
186
187                #method_return_conversions
188            }
189        }
190        )*
191
192        impl #name {
193            #( #constant_expressions )*
194        }
195
196        #(
197        impl #name_iter2 {
198
199            #[inline]
200            #( #method_attrs )*
201            fn #internal_method_names #arg_strings2 #method_outputs #method_blocks
202        }
203        )*
204
205        impl cowlang::Module for #name {
206            fn get_member(&self, self_ref: &std::rc::Rc<dyn cowlang::Module>, member_name: &str) -> cowlang::interpreter::Handle {
207                #(
208                if member_name == #method_names {
209
210                    #( #method_attrs )*
211                    {
212                        return cowlang::interpreter::Handle::Callable( Box::new(
213                                #method_struct_defs2{ self_ref: self_ref.clone() }
214                        ));
215                    }
216                }
217                )*
218                #(
219                if member_name == #constant_names {
220                    return cowlang::interpreter::Handle::wrap_value( #constant_literals.into() );
221                }
222                )*
223
224                panic!("No such member {}::{}", #name_string, member_name);
225            }
226        }
227    };
228
229    TokenStream::from(expanded)
230}