Skip to main content

pyro_macro/ffi/lifecycle/
init.rs

1//! Init function parsing and validation
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{Error, FnArg, GenericArgument, Ident, ImplItemFn, Pat, PathArguments, Type};
6
7use heck::AsSnakeCase;
8
9#[derive(Debug, Clone)]
10pub struct InitFn {
11    pub is_async: bool,
12    pub config_type: Option<Type>,
13    pub body: syn::Block,
14    pub attrs: Vec<syn::Attribute>,
15    pub arg_name: Option<Ident>,
16}
17
18impl InitFn {
19    /// Parse the init function and validate it against the expected configuration.
20    pub fn parse(expected_config: Option<Type>, f: &ImplItemFn) -> syn::Result<Self> {
21        let sig = &f.sig;
22
23        // 1. Validate name
24        if sig.ident != "new" {
25            return Err(Error::new_spanned(
26                &sig.ident,
27                "Expected function named 'new'",
28            ));
29        }
30
31        // 2. Validate return type is Result<Self, CapturedError> or Result<Self>
32        let (ok_ty, _err_ty) = crate::ffi::paths::verify_result_return_type(&sig.output)?;
33        let ok_str = quote!(#ok_ty).to_string().replace(" ", "");
34        if ok_str != "Self" {
35            return Err(Error::new_spanned(
36                &sig.output,
37                "fn new must return Result<Self, CapturedError> or Result<Self>",
38            ));
39        }
40
41        // 3. Validate no &self receiver
42        if let Some(FnArg::Receiver(r)) = sig.inputs.first() {
43            return Err(Error::new_spanned(
44                r,
45                "fn new must be a static function (no self parameter)",
46            ));
47        }
48
49        let mut user_arg_name = None;
50
51        // 4. Validate Argument consistency against Expected Config
52        match &expected_config {
53            // Case A: Attribute said `config = MyType`
54            Some(expected_ty) => {
55                if sig.inputs.len() != 1 {
56                    return Err(Error::new_spanned(
57                        &sig.inputs,
58                        format!(
59                            "Macro attribute defined 'config = {}', so fn new must take exactly one argument: 'arg: Option<{}>'",
60                            quote!(#expected_ty),
61                            quote!(#expected_ty)
62                        ),
63                    ));
64                }
65
66                let arg = sig.inputs.first().unwrap();
67                if let FnArg::Typed(pt) = arg {
68                    // Capture argument name (don't validate it)
69                    if let Pat::Ident(pi) = &*pt.pat {
70                        user_arg_name = Some(pi.ident.clone());
71                    } else {
72                        return Err(Error::new_spanned(
73                            &pt.pat,
74                            "Expected simple identifier for argument",
75                        ));
76                    }
77
78                    // Check type is Option<T>
79                    let valid_option = if let Type::Path(tp) = &*pt.ty {
80                        if let Some(segment) = tp.path.segments.last() {
81                            if segment.ident == "Option" {
82                                if let PathArguments::AngleBracketed(args) = &segment.arguments {
83                                    if let Some(GenericArgument::Type(inner_ty)) = args.args.first()
84                                    {
85                                        // Compare inner type with expected type
86                                        let inner_str =
87                                            quote!(#inner_ty).to_string().replace(" ", "");
88                                        let expected_str =
89                                            quote!(#expected_ty).to_string().replace(" ", "");
90
91                                        if inner_str == expected_str {
92                                            Some(())
93                                        } else {
94                                            return Err(Error::new_spanned(
95                                                &pt.ty,
96                                                format!(
97                                                    "Type mismatch. Expected 'Option<{}>' based on macro attribute, found 'Option<{}>'",
98                                                    expected_str, inner_str
99                                                ),
100                                            ));
101                                        }
102                                    } else {
103                                        None
104                                    }
105                                } else {
106                                    None
107                                }
108                            } else {
109                                None
110                            }
111                        } else {
112                            None
113                        }
114                    } else {
115                        None
116                    };
117
118                    if valid_option.is_none() {
119                        return Err(Error::new_spanned(
120                            &pt.ty,
121                            format!(
122                                "Config parameter must be 'Option<{}>'",
123                                quote!(#expected_ty)
124                            ),
125                        ));
126                    }
127                }
128            }
129            // Case B: No config attribute
130            None => {
131                if !sig.inputs.is_empty() {
132                    return Err(Error::new_spanned(
133                        &sig.inputs,
134                        "No 'config' attribute specified in macro, so fn new() must take 0 arguments.",
135                    ));
136                }
137            }
138        }
139
140        Ok(Self {
141            is_async: sig.asyncness.is_some(),
142            config_type: expected_config,
143            body: f.block.clone(),
144            attrs: f.attrs.clone(),
145            arg_name: user_arg_name,
146        })
147    }
148
149    /// Generate the FFI init function
150    pub fn generate_ffi(&self, server: &Ident) -> TokenStream {
151        let server_snake = AsSnakeCase(server.to_string()).to_string();
152        let init_name = format_ident!("p__{}__ffi_init", server_snake);
153
154        // Determine config type and closure body
155        // The safe_lifecycle functions expect Option<T> to be passed through
156        let (return_ty, closure) = match (&self.config_type, self.is_async) {
157            (Some(c), false) => (
158                quote!(::pyroduct::ffi::InitResult),
159                quote! {::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
160                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<#c>(config_ptr) {
161                        Ok(config) => config,
162                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
163                    };
164                    match #server::new(config) {
165                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
166                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
167                    }
168                }, object_id)},
169            ),
170            (None, false) => (
171                quote!(::pyroduct::ffi::InitResult),
172                quote! {::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
173                    match #server::new() {
174                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
175                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
176                    }
177                }, object_id)},
178            ),
179            (Some(c), true) => (
180                quote!(::pyroduct::ffi::FutureInitResult),
181                quote! { ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
182                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<#c>(config_ptr) {
183                        Ok(config) => config,
184                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
185                    };
186                    match #server::new(config).await {
187                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
188                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
189                    }
190                }, object_id)},
191            ),
192            (None, true) => (
193                quote!(::pyroduct::ffi::FutureInitResult),
194                quote! { ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
195                    match #server::new().await {
196                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
197                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
198                    }
199                }, object_id)},
200            ),
201        };
202
203        quote! {
204            #[unsafe(no_mangle)]
205            pub extern "C" fn #init_name(
206                config_ptr: ::pyroduct::format::PyroRefPtr,
207                object_id: u64,
208            ) -> #return_ty {
209                #closure
210            }
211        }
212    }
213
214    /// Generate the export entry for the init function
215    pub fn generate_export(&self, server: &Ident) -> TokenStream {
216        let server_snake = AsSnakeCase(server.to_string()).to_string();
217        let init_name = format_ident!("p__{}__ffi_init", server_snake);
218
219        if self.is_async {
220            quote!(::pyroduct::ffi::ClassInitFn::Async(#init_name))
221        } else {
222            quote!(::pyroduct::ffi::ClassInitFn::Sync(#init_name))
223        }
224    }
225
226    /// Generate the impl method (preserves original)
227    pub fn generate_impl_method(&self) -> TokenStream {
228        let attrs = &self.attrs;
229        let body = &self.body;
230        let async_kw = if self.is_async {
231            quote!(async)
232        } else {
233            quote!()
234        };
235
236        let params = if let Some(config) = &self.config_type {
237            // Use the user's variable name, fallback to 'config' if something went weird
238            let name = self.arg_name.clone().unwrap_or(format_ident!("config"));
239            quote!(#name: Option<#config>)
240        } else {
241            quote!()
242        };
243
244        quote! {
245            #(#attrs)*
246            pub #async_kw fn new(#params) -> Result<Self, ::pyroduct::CapturedError> #body
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use crate::fmt::assert_code_eq_token;

    use super::*;
    use quote::{format_ident, quote};
    use syn::parse_quote;

    #[test]
    fn test_sync_server_init_fn() {
        // The macro attribute declared `config = GreeterConfig`; the user takes it as an
        // Option and names it `cfg`.
        let declared: Type = parse_quote!(GreeterConfig);
        let user_fn: ImplItemFn = parse_quote! {
            fn new(cfg: Option<GreeterConfig>) -> Result<Self, CapturedError> {
                Ok(Self { count: 0 })
            }
        };

        let generated = InitFn::parse(Some(declared), &user_fn)
            .expect("Parse failed")
            .generate_ffi(&format_ident!("GreeterServer"));

        // The deserialized Option<T> is handed to `new` unchanged.
        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::InitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<GreeterConfig>(config_ptr) {
                        Ok(config) => config,
                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
                    };
                    match GreeterServer::new(config) {
                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
                    }
                }, object_id)
            }
        };

        assert_code_eq_token(&generated, &expected);
    }

    #[test]
    fn test_async_server_init_fn() {
        // Same config attribute; the user's constructor is async and binds `val`.
        let declared: Type = parse_quote!(GreeterConfig);
        let user_fn: ImplItemFn = parse_quote! {
            async fn new(val: Option<GreeterConfig>) -> Result<Self, CapturedError> {
                Ok(Self { count: 0 })
            }
        };

        let generated = InitFn::parse(Some(declared), &user_fn)
            .expect("Parse failed")
            .generate_ffi(&format_ident!("GreeterServer"));

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::FutureInitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<GreeterConfig>(config_ptr) {
                        Ok(config) => config,
                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
                    };
                    match GreeterServer::new(config).await {
                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
                    }
                }, object_id)
            }
        };

        assert_code_eq_token(&generated, &expected);
    }

    #[test]
    fn test_arbitrary_arg_name() {
        // The argument is called `settings` rather than `config`; it must be kept as written.
        let declared: Type = parse_quote!(MyConfig);
        let user_fn: ImplItemFn = parse_quote! {
            fn new(settings: Option<MyConfig>) -> Result<Self, CapturedError> { Ok(Self) }
        };

        let rendered = InitFn::parse(Some(declared), &user_fn)
            .expect("Should allow arbitrary names")
            .generate_impl_method()
            .to_string();

        assert!(rendered.contains("settings : Option < MyConfig >"));
    }

    #[test]
    fn test_validation_errors() {
        let declared: Type = parse_quote!(MyConfig);

        // Every signature below must be rejected when `config = MyConfig` is declared.
        let rejected: [ImplItemFn; 3] = [
            // Not Option<T>
            parse_quote! { fn new(c: MyConfig) -> Result<Self, CapturedError> { Ok(Self) } },
            // Not Option<T> (reference)
            parse_quote! { fn new(c: &MyConfig) -> Result<Self, CapturedError> { Ok(Self) } },
            // Option<WrongType>
            parse_quote! { fn new(c: Option<WrongConfig>) -> Result<Self, CapturedError> { Ok(Self) } },
        ];

        for user_fn in &rejected {
            assert!(InitFn::parse(Some(declared.clone()), user_fn).is_err());
        }
    }
}