pyro_macro/ffi/lifecycle/
init.rs1use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{Error, FnArg, GenericArgument, Ident, ImplItemFn, Pat, PathArguments, Type};
6
7use heck::AsSnakeCase;
8
9#[derive(Debug, Clone)]
10pub struct InitFn {
11 pub is_async: bool,
12 pub config_type: Option<Type>,
13 pub body: syn::Block,
14 pub attrs: Vec<syn::Attribute>,
15 pub arg_name: Option<Ident>,
16}
17
18impl InitFn {
19 pub fn parse(expected_config: Option<Type>, f: &ImplItemFn) -> syn::Result<Self> {
21 let sig = &f.sig;
22
23 if sig.ident != "new" {
25 return Err(Error::new_spanned(
26 &sig.ident,
27 "Expected function named 'new'",
28 ));
29 }
30
31 let (ok_ty, _err_ty) = crate::ffi::paths::verify_result_return_type(&sig.output)?;
33 let ok_str = quote!(#ok_ty).to_string().replace(" ", "");
34 if ok_str != "Self" {
35 return Err(Error::new_spanned(
36 &sig.output,
37 "fn new must return Result<Self, CapturedError> or Result<Self>",
38 ));
39 }
40
41 if let Some(FnArg::Receiver(r)) = sig.inputs.first() {
43 return Err(Error::new_spanned(
44 r,
45 "fn new must be a static function (no self parameter)",
46 ));
47 }
48
49 let mut user_arg_name = None;
50
51 match &expected_config {
53 Some(expected_ty) => {
55 if sig.inputs.len() != 1 {
56 return Err(Error::new_spanned(
57 &sig.inputs,
58 format!(
59 "Macro attribute defined 'config = {}', so fn new must take exactly one argument: 'arg: Option<{}>'",
60 quote!(#expected_ty),
61 quote!(#expected_ty)
62 ),
63 ));
64 }
65
66 let arg = sig.inputs.first().unwrap();
67 if let FnArg::Typed(pt) = arg {
68 if let Pat::Ident(pi) = &*pt.pat {
70 user_arg_name = Some(pi.ident.clone());
71 } else {
72 return Err(Error::new_spanned(
73 &pt.pat,
74 "Expected simple identifier for argument",
75 ));
76 }
77
78 let valid_option = if let Type::Path(tp) = &*pt.ty {
80 if let Some(segment) = tp.path.segments.last() {
81 if segment.ident == "Option" {
82 if let PathArguments::AngleBracketed(args) = &segment.arguments {
83 if let Some(GenericArgument::Type(inner_ty)) = args.args.first()
84 {
85 let inner_str =
87 quote!(#inner_ty).to_string().replace(" ", "");
88 let expected_str =
89 quote!(#expected_ty).to_string().replace(" ", "");
90
91 if inner_str == expected_str {
92 Some(())
93 } else {
94 return Err(Error::new_spanned(
95 &pt.ty,
96 format!(
97 "Type mismatch. Expected 'Option<{}>' based on macro attribute, found 'Option<{}>'",
98 expected_str, inner_str
99 ),
100 ));
101 }
102 } else {
103 None
104 }
105 } else {
106 None
107 }
108 } else {
109 None
110 }
111 } else {
112 None
113 }
114 } else {
115 None
116 };
117
118 if valid_option.is_none() {
119 return Err(Error::new_spanned(
120 &pt.ty,
121 format!(
122 "Config parameter must be 'Option<{}>'",
123 quote!(#expected_ty)
124 ),
125 ));
126 }
127 }
128 }
129 None => {
131 if !sig.inputs.is_empty() {
132 return Err(Error::new_spanned(
133 &sig.inputs,
134 "No 'config' attribute specified in macro, so fn new() must take 0 arguments.",
135 ));
136 }
137 }
138 }
139
140 Ok(Self {
141 is_async: sig.asyncness.is_some(),
142 config_type: expected_config,
143 body: f.block.clone(),
144 attrs: f.attrs.clone(),
145 arg_name: user_arg_name,
146 })
147 }
148
149 pub fn generate_ffi(&self, server: &Ident) -> TokenStream {
151 let server_snake = AsSnakeCase(server.to_string()).to_string();
152 let init_name = format_ident!("p__{}__ffi_init", server_snake);
153
154 let (return_ty, closure) = match (&self.config_type, self.is_async) {
157 (Some(c), false) => (
158 quote!(::pyroduct::ffi::InitResult),
159 quote! {::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
160 let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<#c>(config_ptr) {
161 Ok(config) => config,
162 Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
163 };
164 match #server::new(config) {
165 Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
166 Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
167 }
168 }, object_id)},
169 ),
170 (None, false) => (
171 quote!(::pyroduct::ffi::InitResult),
172 quote! {::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
173 match #server::new() {
174 Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
175 Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
176 }
177 }, object_id)},
178 ),
179 (Some(c), true) => (
180 quote!(::pyroduct::ffi::FutureInitResult),
181 quote! { ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
182 let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<#c>(config_ptr) {
183 Ok(config) => config,
184 Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
185 };
186 match #server::new(config).await {
187 Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
188 Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
189 }
190 }, object_id)},
191 ),
192 (None, true) => (
193 quote!(::pyroduct::ffi::FutureInitResult),
194 quote! { ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
195 match #server::new().await {
196 Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
197 Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
198 }
199 }, object_id)},
200 ),
201 };
202
203 quote! {
204 #[unsafe(no_mangle)]
205 pub extern "C" fn #init_name(
206 config_ptr: ::pyroduct::format::PyroRefPtr,
207 object_id: u64,
208 ) -> #return_ty {
209 #closure
210 }
211 }
212 }
213
214 pub fn generate_export(&self, server: &Ident) -> TokenStream {
216 let server_snake = AsSnakeCase(server.to_string()).to_string();
217 let init_name = format_ident!("p__{}__ffi_init", server_snake);
218
219 if self.is_async {
220 quote!(::pyroduct::ffi::ClassInitFn::Async(#init_name))
221 } else {
222 quote!(::pyroduct::ffi::ClassInitFn::Sync(#init_name))
223 }
224 }
225
226 pub fn generate_impl_method(&self) -> TokenStream {
228 let attrs = &self.attrs;
229 let body = &self.body;
230 let async_kw = if self.is_async {
231 quote!(async)
232 } else {
233 quote!()
234 };
235
236 let params = if let Some(config) = &self.config_type {
237 let name = self.arg_name.clone().unwrap_or(format_ident!("config"));
239 quote!(#name: Option<#config>)
240 } else {
241 quote!()
242 };
243
244 quote! {
245 #(#attrs)*
246 pub #async_kw fn new(#params) -> Result<Self, ::pyroduct::CapturedError> #body
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use crate::fmt::assert_code_eq_token;

    use super::*;
    use quote::{format_ident, quote};
    use syn::parse_quote;

    #[test]
    fn test_sync_server_init_fn() {
        let config_type: Type = parse_quote!(GreeterConfig);

        let item: ImplItemFn = parse_quote! {
            fn new(cfg: Option<GreeterConfig>) -> Result<Self, CapturedError> {
                Ok(Self { count: 0 })
            }
        };

        let init_fn = InitFn::parse(Some(config_type), &item).expect("Parse failed");

        let server_ident = format_ident!("GreeterServer");
        let result = init_fn.generate_ffi(&server_ident);

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::InitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<GreeterConfig>(config_ptr) {
                        Ok(config) => config,
                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
                    };
                    match GreeterServer::new(config) {
                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
                    }
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_sync_server_init_fn_without_config() {
        let item: ImplItemFn = parse_quote! {
            fn new() -> Result<Self, CapturedError> {
                Ok(Self { count: 0 })
            }
        };

        let init_fn = InitFn::parse(None, &item).expect("Parse failed");
        assert!(!init_fn.is_async);
        assert!(init_fn.arg_name.is_none());

        let result = init_fn.generate_ffi(&format_ident!("GreeterServer"));

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::InitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
                    match GreeterServer::new() {
                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
                    }
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_async_server_init_fn() {
        let config_type: Type = parse_quote!(GreeterConfig);

        let item: ImplItemFn = parse_quote! {
            async fn new(val: Option<GreeterConfig>) -> Result<Self, CapturedError> {
                Ok(Self { count: 0 })
            }
        };

        let init_fn = InitFn::parse(Some(config_type), &item).expect("Parse failed");
        assert!(init_fn.is_async);

        let server_ident = format_ident!("GreeterServer");
        let result = init_fn.generate_ffi(&server_ident);

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::FutureInitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<GreeterConfig>(config_ptr) {
                        Ok(config) => config,
                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
                    };
                    match GreeterServer::new(config).await {
                        Ok(state) => ::pyroduct::ffi::InitResult::init_ok(state, object_id),
                        Err(err) => ::pyroduct::ffi::InitResult::init_err(::pyroduct::PyroError::CodePanic(err), object_id),
                    }
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_generate_export_variant() {
        let server = format_ident!("GreeterServer");

        let sync_item: ImplItemFn =
            parse_quote! { fn new() -> Result<Self, CapturedError> { Ok(Self) } };
        let sync_fn = InitFn::parse(None, &sync_item).expect("Parse failed");
        assert_eq!(
            sync_fn.generate_export(&server).to_string(),
            quote!(::pyroduct::ffi::ClassInitFn::Sync(p__greeter_server__ffi_init)).to_string()
        );

        let async_item: ImplItemFn =
            parse_quote! { async fn new() -> Result<Self, CapturedError> { Ok(Self) } };
        let async_fn = InitFn::parse(None, &async_item).expect("Parse failed");
        assert_eq!(
            async_fn.generate_export(&server).to_string(),
            quote!(::pyroduct::ffi::ClassInitFn::Async(p__greeter_server__ffi_init)).to_string()
        );
    }

    #[test]
    fn test_arbitrary_arg_name() {
        let config_type: Type = parse_quote!(MyConfig);
        let item: ImplItemFn = parse_quote! {
            fn new(settings: Option<MyConfig>) -> Result<Self, CapturedError> { Ok(Self) }
        };

        let init_fn =
            InitFn::parse(Some(config_type), &item).expect("Should allow arbitrary names");

        let impl_code = init_fn.generate_impl_method();
        let impl_str = impl_code.to_string();
        assert!(impl_str.contains("settings : Option < MyConfig >"));
    }

    #[test]
    fn test_impl_method_without_config() {
        let item: ImplItemFn =
            parse_quote! { fn new() -> Result<Self, CapturedError> { Ok(Self) } };
        let init_fn = InitFn::parse(None, &item).expect("Parse failed");

        let expected = quote! {
            pub fn new() -> Result<Self, ::pyroduct::CapturedError> { Ok(Self) }
        };
        assert_eq!(
            init_fn.generate_impl_method().to_string(),
            expected.to_string()
        );
    }

    #[test]
    fn test_validation_errors() {
        let config_type: Type = parse_quote!(MyConfig);

        // The argument must be wrapped in `Option`.
        let item: ImplItemFn =
            parse_quote! { fn new(c: MyConfig) -> Result<Self, CapturedError> { Ok(Self) } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        // A reference is not an `Option`.
        let item: ImplItemFn =
            parse_quote! { fn new(c: &MyConfig) -> Result<Self, CapturedError> { Ok(Self) } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        // The `Option` payload must match the configured type.
        let item: ImplItemFn = parse_quote! { fn new(c: Option<WrongConfig>) -> Result<Self, CapturedError> { Ok(Self) } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());
    }

    #[test]
    fn test_signature_errors() {
        let config_type: Type = parse_quote!(MyConfig);

        // Only a function literally named `new` is accepted.
        let item: ImplItemFn =
            parse_quote! { fn create() -> Result<Self, CapturedError> { Ok(Self) } };
        assert!(InitFn::parse(None, &item).is_err());

        // `new` must be an associated function, not a method.
        let item: ImplItemFn =
            parse_quote! { fn new(&self) -> Result<Self, CapturedError> { Ok(Self) } };
        assert!(InitFn::parse(None, &item).is_err());

        // The success type must be `Self`.
        let item: ImplItemFn =
            parse_quote! { fn new() -> Result<u32, CapturedError> { Ok(0) } };
        assert!(InitFn::parse(None, &item).is_err());

        // A config attribute requires exactly one argument.
        let item: ImplItemFn =
            parse_quote! { fn new() -> Result<Self, CapturedError> { Ok(Self) } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        let item: ImplItemFn = parse_quote! {
            fn new(a: Option<MyConfig>, b: Option<MyConfig>) -> Result<Self, CapturedError> { Ok(Self) }
        };
        assert!(InitFn::parse(Some(config_type), &item).is_err());

        // Without a config attribute, `new` must take no arguments.
        let item: ImplItemFn = parse_quote! {
            fn new(c: Option<MyConfig>) -> Result<Self, CapturedError> { Ok(Self) }
        };
        assert!(InitFn::parse(None, &item).is_err());
    }
}
375}