Skip to main content

pyro_macro/ffi/lifecycle/
init.rs

1//! Init function parsing and validation
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{Error, FnArg, GenericArgument, Ident, ImplItemFn, Pat, PathArguments, ReturnType, Type};
6
7use heck::AsSnakeCase;
8
9#[derive(Debug, Clone)]
10pub struct InitFn {
11    pub is_async: bool,
12    pub config_type: Option<Type>,
13    pub body: syn::Block,
14    pub attrs: Vec<syn::Attribute>,
15    pub arg_name: Option<Ident>,
16}
17
18impl InitFn {
19    /// Parse the init function and validate it against the expected configuration.
20    pub fn parse(expected_config: Option<Type>, f: &ImplItemFn) -> syn::Result<Self> {
21        let sig = &f.sig;
22
23        // 1. Validate name
24        if sig.ident != "new" {
25            return Err(Error::new_spanned(
26                &sig.ident,
27                "Expected function named 'new'",
28            ));
29        }
30
31        // 2. Validate return type is Self
32        match &sig.output {
33            ReturnType::Type(_, ty) => {
34                let ty_str = quote!(#ty).to_string().replace(" ", "");
35                if ty_str != "Self" {
36                    return Err(Error::new_spanned(&sig.output, "fn new must return Self"));
37                }
38            }
39            ReturnType::Default => {
40                return Err(Error::new_spanned(&sig, "fn new must return Self"));
41            }
42        }
43
44        // 3. Validate no &self receiver
45        if let Some(FnArg::Receiver(r)) = sig.inputs.first() {
46            return Err(Error::new_spanned(
47                r,
48                "fn new must be a static function (no self parameter)",
49            ));
50        }
51
52        let mut user_arg_name = None;
53
54        // 4. Validate Argument consistency against Expected Config
55        match &expected_config {
56            // Case A: Attribute said `config = MyType`
57            Some(expected_ty) => {
58                if sig.inputs.len() != 1 {
59                    return Err(Error::new_spanned(
60                        &sig.inputs,
61                        format!(
62                            "Macro attribute defined 'config = {}', so fn new must take exactly one argument: 'arg: Option<{}>'",
63                            quote!(#expected_ty),
64                            quote!(#expected_ty)
65                        ),
66                    ));
67                }
68
69                let arg = sig.inputs.first().unwrap();
70                if let FnArg::Typed(pt) = arg {
71                    // Capture argument name (don't validate it)
72                    if let Pat::Ident(pi) = &*pt.pat {
73                        user_arg_name = Some(pi.ident.clone());
74                    } else {
75                        return Err(Error::new_spanned(
76                            &pt.pat,
77                            "Expected simple identifier for argument",
78                        ));
79                    }
80
81                    // Check type is Option<T>
82                    let valid_option = if let Type::Path(tp) = &*pt.ty {
83                        if let Some(segment) = tp.path.segments.last() {
84                            if segment.ident == "Option" {
85                                if let PathArguments::AngleBracketed(args) = &segment.arguments {
86                                    if let Some(GenericArgument::Type(inner_ty)) = args.args.first()
87                                    {
88                                        // Compare inner type with expected type
89                                        let inner_str =
90                                            quote!(#inner_ty).to_string().replace(" ", "");
91                                        let expected_str =
92                                            quote!(#expected_ty).to_string().replace(" ", "");
93
94                                        if inner_str == expected_str {
95                                            Some(())
96                                        } else {
97                                            return Err(Error::new_spanned(
98                                                &pt.ty,
99                                                format!(
100                                                    "Type mismatch. Expected 'Option<{}>' based on macro attribute, found 'Option<{}>'",
101                                                    expected_str, inner_str
102                                                ),
103                                            ));
104                                        }
105                                    } else {
106                                        None
107                                    }
108                                } else {
109                                    None
110                                }
111                            } else {
112                                None
113                            }
114                        } else {
115                            None
116                        }
117                    } else {
118                        None
119                    };
120
121                    if valid_option.is_none() {
122                        return Err(Error::new_spanned(
123                            &pt.ty,
124                            format!(
125                                "Config parameter must be 'Option<{}>'",
126                                quote!(#expected_ty)
127                            ),
128                        ));
129                    }
130                }
131            }
132            // Case B: No config attribute
133            None => {
134                if !sig.inputs.is_empty() {
135                    return Err(Error::new_spanned(
136                        &sig.inputs,
137                        "No 'config' attribute specified in macro, so fn new() must take 0 arguments.",
138                    ));
139                }
140            }
141        }
142
143        Ok(Self {
144            is_async: sig.asyncness.is_some(),
145            config_type: expected_config,
146            body: f.block.clone(),
147            attrs: f.attrs.clone(),
148            arg_name: user_arg_name,
149        })
150    }
151
152    /// Generate the FFI init function
153    pub fn generate_ffi(&self, server: &Ident) -> TokenStream {
154        let server_snake = AsSnakeCase(server.to_string()).to_string();
155        let init_name = format_ident!("p__{}__ffi_init", server_snake);
156
157        // Determine config type and closure body
158        // The safe_lifecycle functions expect Option<T> to be passed through
159        let (return_ty, closure) = match (&self.config_type, self.is_async) {
160            (Some(c), false) => (
161                quote!(::pyroduct::ffi::InitResult),
162                quote! {::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
163                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<#c>(config_ptr) {
164                        Ok(config) => config,
165                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
166                    };
167                    ::pyroduct::ffi::InitResult::init_ok(#server::new(config), object_id)
168                }, object_id)},
169            ),
170            (None, false) => (
171                quote!(::pyroduct::ffi::InitResult),
172                quote! {::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
173                    ::pyroduct::ffi::InitResult::init_ok(#server::new(), object_id)
174                }, object_id)},
175            ),
176            (Some(c), true) => (
177                quote!(::pyroduct::ffi::FutureInitResult),
178                quote! { ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
179                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<#c>(config_ptr) {
180                        Ok(config) => config,
181                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
182                    };
183                    ::pyroduct::ffi::InitResult::init_ok(#server::new(config).await, object_id)
184                }, object_id)},
185            ),
186            (None, true) => (
187                quote!(::pyroduct::ffi::FutureInitResult),
188                quote! { ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
189                    ::pyroduct::ffi::InitResult::init_ok(#server::new().await, object_id)
190                }, object_id)},
191            ),
192        };
193
194        quote! {
195            #[unsafe(no_mangle)]
196            pub extern "C" fn #init_name(
197                config_ptr: ::pyroduct::format::PyroRefPtr,
198                object_id: u64,
199            ) -> #return_ty {
200                #closure
201            }
202        }
203    }
204
205    /// Generate the export entry for the init function
206    pub fn generate_export(&self, server: &Ident) -> TokenStream {
207        let server_snake = AsSnakeCase(server.to_string()).to_string();
208        let init_name = format_ident!("p__{}__ffi_init", server_snake);
209
210        if self.is_async {
211            quote!(::pyroduct::ffi::ClassInitFn::Async(#init_name))
212        } else {
213            quote!(::pyroduct::ffi::ClassInitFn::Sync(#init_name))
214        }
215    }
216
217    /// Generate the impl method (preserves original)
218    pub fn generate_impl_method(&self) -> TokenStream {
219        let attrs = &self.attrs;
220        let body = &self.body;
221        let async_kw = if self.is_async {
222            quote!(async)
223        } else {
224            quote!()
225        };
226
227        let params = if let Some(config) = &self.config_type {
228            // Use the user's variable name, fallback to 'config' if something went weird
229            let name = self.arg_name.clone().unwrap_or(format_ident!("config"));
230            quote!(#name: Option<#config>)
231        } else {
232            quote!()
233        };
234
235        quote! {
236            #(#attrs)*
237            pub #async_kw fn new(#params) -> Self #body
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use crate::fmt::assert_code_eq_token;

    use super::*;
    use quote::{format_ident, quote};
    use syn::parse_quote;

    #[test]
    fn test_sync_server_init_fn() {
        // 1. Simulate the config attribute passed from the macro
        let config_type: Type = parse_quote!(GreeterConfig);

        // 2. Simulate the user's implementation (Using Option, and variable name 'cfg')
        let item: ImplItemFn = parse_quote! {
            fn new(cfg: Option<GreeterConfig>) -> Self {
                Self { count: 0 }
            }
        };

        // 3. Parse with validation
        let init_fn = InitFn::parse(Some(config_type), &item).expect("Parse failed");

        let server_ident = format_ident!("GreeterServer");
        let result = init_fn.generate_ffi(&server_ident);

        // The closure passes the deserialized value straight to new(), which takes Option<T>.
        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::InitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<GreeterConfig>(config_ptr) {
                        Ok(config) => config,
                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
                    };
                    ::pyroduct::ffi::InitResult::init_ok(GreeterServer::new(config), object_id)
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_async_server_init_fn() {
        // 1. Config attribute
        let config_type: Type = parse_quote!(GreeterConfig);

        // 2. User implementation
        let item: ImplItemFn = parse_quote! {
            async fn new(val: Option<GreeterConfig>) -> Self {
                Self { count: 0 }
            }
        };

        // 3. Parse
        let init_fn = InitFn::parse(Some(config_type), &item).expect("Parse failed");

        let server_ident = format_ident!("GreeterServer");
        let result = init_fn.generate_ffi(&server_ident);

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::FutureInitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<GreeterConfig>(config_ptr) {
                        Ok(config) => config,
                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
                    };
                    ::pyroduct::ffi::InitResult::init_ok(GreeterServer::new(config).await, object_id)
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_sync_server_init_fn_without_config() {
        let item: ImplItemFn = parse_quote! {
            fn new() -> Self {
                Self { count: 0 }
            }
        };

        let init_fn = InitFn::parse(None, &item).expect("Parse failed");
        let result = init_fn.generate_ffi(&format_ident!("GreeterServer"));

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::InitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
                    ::pyroduct::ffi::InitResult::init_ok(GreeterServer::new(), object_id)
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_async_server_init_fn_without_config() {
        let item: ImplItemFn = parse_quote! {
            async fn new() -> Self {
                Self { count: 0 }
            }
        };

        let init_fn = InitFn::parse(None, &item).expect("Parse failed");
        let result = init_fn.generate_ffi(&format_ident!("GreeterServer"));

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::FutureInitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
                    ::pyroduct::ffi::InitResult::init_ok(GreeterServer::new().await, object_id)
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_generate_export() {
        let server = format_ident!("GreeterServer");

        let sync_item: ImplItemFn = parse_quote! { fn new() -> Self { Self } };
        let sync_export = InitFn::parse(None, &sync_item)
            .expect("Parse failed")
            .generate_export(&server);
        assert_eq!(
            sync_export.to_string(),
            quote!(::pyroduct::ffi::ClassInitFn::Sync(p__greeter_server__ffi_init)).to_string()
        );

        let async_item: ImplItemFn = parse_quote! { async fn new() -> Self { Self } };
        let async_export = InitFn::parse(None, &async_item)
            .expect("Parse failed")
            .generate_export(&server);
        assert_eq!(
            async_export.to_string(),
            quote!(::pyroduct::ffi::ClassInitFn::Async(p__greeter_server__ffi_init)).to_string()
        );
    }

    #[test]
    fn test_arbitrary_arg_name() {
        let config_type: Type = parse_quote!(MyConfig);
        // User uses 'settings' instead of 'config'
        let item: ImplItemFn = parse_quote! {
            fn new(settings: Option<MyConfig>) -> Self { Self }
        };

        let init_fn =
            InitFn::parse(Some(config_type), &item).expect("Should allow arbitrary names");

        // Check if generate_impl_method preserves the name 'settings'
        let impl_code = init_fn.generate_impl_method();
        let impl_str = impl_code.to_string();
        assert!(impl_str.contains("settings : Option < MyConfig >"));
    }

    #[test]
    fn test_async_impl_method_keeps_async() {
        let item: ImplItemFn = parse_quote! { async fn new() -> Self { Self } };
        let init_fn = InitFn::parse(None, &item).expect("Parse failed");
        assert!(init_fn
            .generate_impl_method()
            .to_string()
            .contains("pub async fn new"));
    }

    #[test]
    fn test_accepted_argument_forms() {
        let config_type: Type = parse_quote!(MyConfig);

        // Path-qualified Option is accepted (only the last segment is checked).
        let item: ImplItemFn =
            parse_quote! { fn new(c: std::option::Option<MyConfig>) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_ok());

        // A `mut` binding is accepted.
        let item: ImplItemFn = parse_quote! { fn new(mut c: Option<MyConfig>) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type), &item).is_ok());
    }

    #[test]
    fn test_validation_errors() {
        let config_type: Type = parse_quote!(MyConfig);

        // Case: Not Option<T>
        let item: ImplItemFn = parse_quote! { fn new(c: MyConfig) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        // Case: Not Option<T> (Reference)
        let item: ImplItemFn = parse_quote! { fn new(c: &MyConfig) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        // Case: Option<WrongType>
        let item: ImplItemFn = parse_quote! { fn new(c: Option<WrongConfig>) -> Self { Self } };
        let err = InitFn::parse(Some(config_type.clone()), &item).unwrap_err();
        assert!(err.to_string().contains("Type mismatch"));

        // Case: argument is not a plain identifier
        let item: ImplItemFn = parse_quote! { fn new(_: Option<MyConfig>) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type), &item).is_err());
    }

    #[test]
    fn test_signature_validation_errors() {
        // Wrong name
        let item: ImplItemFn = parse_quote! { fn create() -> Self { Self } };
        assert!(InitFn::parse(None, &item).is_err());

        // Missing return type
        let item: ImplItemFn = parse_quote! { fn new() {} };
        assert!(InitFn::parse(None, &item).is_err());

        // Concrete type instead of literal `Self`
        let item: ImplItemFn = parse_quote! { fn new() -> GreeterServer { GreeterServer } };
        assert!(InitFn::parse(None, &item).is_err());

        // Receiver
        let item: ImplItemFn = parse_quote! { fn new(&self) -> Self { Self } };
        assert!(InitFn::parse(None, &item).is_err());
    }

    #[test]
    fn test_argument_count_errors() {
        let config_type: Type = parse_quote!(MyConfig);

        // Config expected but none taken
        let item: ImplItemFn = parse_quote! { fn new() -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        // Config expected but extra argument taken
        let item: ImplItemFn =
            parse_quote! { fn new(c: Option<MyConfig>, extra: u8) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type), &item).is_err());

        // No config expected but one taken
        let item: ImplItemFn = parse_quote! { fn new(c: Option<MyConfig>) -> Self { Self } };
        assert!(InitFn::parse(None, &item).is_err());
    }
}