polywrap_plugin_creator/lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::quote;
4
5use syn::{parse, parse_macro_input, ItemImpl };
6
7
/// Converts a `snake_case` identifier into `camelCase`.
///
/// The first `_`-separated segment is copied unchanged; every following
/// segment has its first character upper-cased and the rest copied as-is.
/// Empty segments (from leading, trailing or repeated underscores) contribute
/// nothing instead of panicking. The first character is handled per `char`,
/// so multi-byte characters are safe.
///
/// # Examples
///
/// `"get_data"` becomes `"getData"`; `"foo__bar"` becomes `"fooBar"`.
fn snake_case_to_camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    let mut camel = String::with_capacity(s.len());

    // `split` always yields at least one segment (possibly ""), which is kept verbatim.
    if let Some(first) = segments.next() {
        camel.push_str(first);
    }

    for segment in segments {
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            // `to_uppercase` can expand to several chars (e.g. 'ß' -> "SS").
            camel.extend(head.to_uppercase());
            camel.push_str(chars.as_str());
        }
    }

    camel
}
20
21#[proc_macro_attribute]
22pub fn plugin_impl(args: TokenStream, input: TokenStream) -> TokenStream {
23    let item_impl = parse_macro_input!(input as ItemImpl);
24    let _ = parse_macro_input!(args as parse::Nothing);
25
26    let struct_ident = item_impl.clone().self_ty;
27
28    let mut method_idents: Vec<(Ident, String, bool, Option<bool>)> = vec![];
29
30    for item in item_impl.clone().items {
31        match item {
32            syn::ImplItem::Method(method) => {
33              let function_ident = &method.sig.ident;
34              let env_is_option = if &method.sig.inputs.len() > &3 {
35                let env = &method.sig.inputs[3];
36                let env_str = quote! { #env }.to_string();
37                
38                Some(env_str.contains("Option <"))
39              } else {
40                None
41              };
42              
43              let output_type = match &method.sig.output {
44                syn::ReturnType::Default => quote! { () },
45                syn::ReturnType::Type(_, ty) => quote! { #ty },
46              };
47              let output_type = quote! { #output_type }.to_string();
48              let function_ident_str =
49                  snake_case_to_camel_case(&function_ident.to_string());
50              let output_is_option = output_type.contains("Option <");
51
52              method_idents.push((
53                  function_ident.clone(),
54                  function_ident_str.clone(),
55                  output_is_option,
56                  env_is_option
57              ));
58            }
59            _ => panic!("Wrong function signature"),
60        }
61    }
62
63    let supported_methods =
64        method_idents
65            .clone()
66            .into_iter()
67            .enumerate()
68            .map(|(_, (_, ident_str, _, _))| {
69                quote! {
70                  #ident_str
71                }
72            });
73
74    let methods = method_idents
75        .into_iter()
76        .enumerate()
77        .map(|(_, (ident, ident_str, output_is_option, env_is_option))| {
78            let args = if let Some(env_is_option) = env_is_option {
79              let env = if env_is_option { 
80                quote! {
81                  if let Some(e) = env {
82                    Some(serde_json::from_value(e).unwrap())
83                  } else {
84                    None
85                  }
86                }
87              } else {
88                quote! {
89                  if let Some(e) = env {
90                    serde_json::from_value(e).unwrap()
91                  } else {
92                    panic!("Env must be defined for method '{}'", #ident_str)
93                  }
94                }
95              };
96
97              quote! {
98                &polywrap_msgpack::decode(&params).unwrap(),
99                invoker,
100                #env
101              }
102            } else {
103              quote! {
104                &polywrap_msgpack::decode(&params).unwrap(),
105                invoker
106              }
107            };
108
109            let output = if output_is_option {
110              quote! {
111                if let Some(r) = result {
112                  Ok(polywrap_msgpack::serialize(r)?)
113                } else {
114                  Ok(vec![])
115                }
116              }
117            } else {
118              quote! {
119                Ok(polywrap_msgpack::serialize(result)?)
120              }
121            };
122          
123            quote! {
124                #ident_str => {
125                  let result = self.#ident(
126                    #args
127                  )?;
128
129                  #output
130                }
131              }
132        });
133
134    let module_impl = quote! {
135        impl polywrap_plugin::module::PluginModule for #struct_ident {
136          fn _wrap_invoke(
137            &mut self,
138            method_name: &str,
139            params: &[u8],
140            env: Option<serde_json::Value>,
141            invoker: std::sync::Arc<dyn polywrap_core::invoke::Invoker>,
142        ) -> Result<Vec<u8>, polywrap_plugin::error::PluginError> {
143                let supported_methods = vec![#(#supported_methods),*];
144                match method_name {
145                    #(#methods)*
146                    _ => panic!("Method '{}' not found. Supported methods: {:#?}", method_name, supported_methods),
147                }
148            }
149        }
150    };
151
152    let from_impls = quote! {
153      impl From<#struct_ident> for polywrap_plugin::package::PluginPackage {
154        fn from(plugin: #struct_ident) -> polywrap_plugin::package::PluginPackage {
155            let plugin_module = Arc::new(std::sync::Mutex::new(Box::new(plugin) as Box<dyn polywrap_plugin::module::PluginModule>));
156            polywrap_plugin::package::PluginPackage::new(plugin_module, get_manifest())
157        }
158      }
159
160      impl From<#struct_ident> for polywrap_plugin::wrapper::PluginWrapper {
161        fn from(plugin: #struct_ident) -> polywrap_plugin::wrapper::PluginWrapper {
162            let plugin_module = Arc::new(std::sync::Mutex::new(Box::new(plugin) as Box<dyn polywrap_plugin::module::PluginModule>));
163            polywrap_plugin::wrapper::PluginWrapper::new(plugin_module)
164        }
165      }
166    };
167
168    quote! {
169        #item_impl
170
171        #module_impl
172
173        #from_impls
174    }
175    .into()
176}