pyro_macro/ffi/lifecycle/
init.rs1use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{Error, FnArg, GenericArgument, Ident, ImplItemFn, Pat, PathArguments, ReturnType, Type};
6
7use heck::AsSnakeCase;
8
9#[derive(Debug, Clone)]
10pub struct InitFn {
11 pub is_async: bool,
12 pub config_type: Option<Type>,
13 pub body: syn::Block,
14 pub attrs: Vec<syn::Attribute>,
15 pub arg_name: Option<Ident>,
16}
17
18impl InitFn {
19 pub fn parse(expected_config: Option<Type>, f: &ImplItemFn) -> syn::Result<Self> {
21 let sig = &f.sig;
22
23 if sig.ident != "new" {
25 return Err(Error::new_spanned(
26 &sig.ident,
27 "Expected function named 'new'",
28 ));
29 }
30
31 match &sig.output {
33 ReturnType::Type(_, ty) => {
34 let ty_str = quote!(#ty).to_string().replace(" ", "");
35 if ty_str != "Self" {
36 return Err(Error::new_spanned(&sig.output, "fn new must return Self"));
37 }
38 }
39 ReturnType::Default => {
40 return Err(Error::new_spanned(&sig, "fn new must return Self"));
41 }
42 }
43
44 if let Some(FnArg::Receiver(r)) = sig.inputs.first() {
46 return Err(Error::new_spanned(
47 r,
48 "fn new must be a static function (no self parameter)",
49 ));
50 }
51
52 let mut user_arg_name = None;
53
54 match &expected_config {
56 Some(expected_ty) => {
58 if sig.inputs.len() != 1 {
59 return Err(Error::new_spanned(
60 &sig.inputs,
61 format!(
62 "Macro attribute defined 'config = {}', so fn new must take exactly one argument: 'arg: Option<{}>'",
63 quote!(#expected_ty),
64 quote!(#expected_ty)
65 ),
66 ));
67 }
68
69 let arg = sig.inputs.first().unwrap();
70 if let FnArg::Typed(pt) = arg {
71 if let Pat::Ident(pi) = &*pt.pat {
73 user_arg_name = Some(pi.ident.clone());
74 } else {
75 return Err(Error::new_spanned(
76 &pt.pat,
77 "Expected simple identifier for argument",
78 ));
79 }
80
81 let valid_option = if let Type::Path(tp) = &*pt.ty {
83 if let Some(segment) = tp.path.segments.last() {
84 if segment.ident == "Option" {
85 if let PathArguments::AngleBracketed(args) = &segment.arguments {
86 if let Some(GenericArgument::Type(inner_ty)) = args.args.first()
87 {
88 let inner_str =
90 quote!(#inner_ty).to_string().replace(" ", "");
91 let expected_str =
92 quote!(#expected_ty).to_string().replace(" ", "");
93
94 if inner_str == expected_str {
95 Some(())
96 } else {
97 return Err(Error::new_spanned(
98 &pt.ty,
99 format!(
100 "Type mismatch. Expected 'Option<{}>' based on macro attribute, found 'Option<{}>'",
101 expected_str, inner_str
102 ),
103 ));
104 }
105 } else {
106 None
107 }
108 } else {
109 None
110 }
111 } else {
112 None
113 }
114 } else {
115 None
116 }
117 } else {
118 None
119 };
120
121 if valid_option.is_none() {
122 return Err(Error::new_spanned(
123 &pt.ty,
124 format!(
125 "Config parameter must be 'Option<{}>'",
126 quote!(#expected_ty)
127 ),
128 ));
129 }
130 }
131 }
132 None => {
134 if !sig.inputs.is_empty() {
135 return Err(Error::new_spanned(
136 &sig.inputs,
137 "No 'config' attribute specified in macro, so fn new() must take 0 arguments.",
138 ));
139 }
140 }
141 }
142
143 Ok(Self {
144 is_async: sig.asyncness.is_some(),
145 config_type: expected_config,
146 body: f.block.clone(),
147 attrs: f.attrs.clone(),
148 arg_name: user_arg_name,
149 })
150 }
151
152 pub fn generate_ffi(&self, server: &Ident) -> TokenStream {
154 let server_snake = AsSnakeCase(server.to_string()).to_string();
155 let init_name = format_ident!("p__{}__ffi_init", server_snake);
156
157 let (return_ty, closure) = match (&self.config_type, self.is_async) {
160 (Some(c), false) => (
161 quote!(::pyroduct::ffi::InitResult),
162 quote! {::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
163 let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<#c>(config_ptr) {
164 Ok(config) => config,
165 Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
166 };
167 ::pyroduct::ffi::InitResult::init_ok(#server::new(config), object_id)
168 }, object_id)},
169 ),
170 (None, false) => (
171 quote!(::pyroduct::ffi::InitResult),
172 quote! {::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
173 ::pyroduct::ffi::InitResult::init_ok(#server::new(), object_id)
174 }, object_id)},
175 ),
176 (Some(c), true) => (
177 quote!(::pyroduct::ffi::FutureInitResult),
178 quote! { ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
179 let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<#c>(config_ptr) {
180 Ok(config) => config,
181 Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
182 };
183 ::pyroduct::ffi::InitResult::init_ok(#server::new(config).await, object_id)
184 }, object_id)},
185 ),
186 (None, true) => (
187 quote!(::pyroduct::ffi::FutureInitResult),
188 quote! { ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
189 ::pyroduct::ffi::InitResult::init_ok(#server::new().await, object_id)
190 }, object_id)},
191 ),
192 };
193
194 quote! {
195 #[unsafe(no_mangle)]
196 pub extern "C" fn #init_name(
197 config_ptr: ::pyroduct::format::PyroRefPtr,
198 object_id: u64,
199 ) -> #return_ty {
200 #closure
201 }
202 }
203 }
204
205 pub fn generate_export(&self, server: &Ident) -> TokenStream {
207 let server_snake = AsSnakeCase(server.to_string()).to_string();
208 let init_name = format_ident!("p__{}__ffi_init", server_snake);
209
210 if self.is_async {
211 quote!(::pyroduct::ffi::ClassInitFn::Async(#init_name))
212 } else {
213 quote!(::pyroduct::ffi::ClassInitFn::Sync(#init_name))
214 }
215 }
216
217 pub fn generate_impl_method(&self) -> TokenStream {
219 let attrs = &self.attrs;
220 let body = &self.body;
221 let async_kw = if self.is_async {
222 quote!(async)
223 } else {
224 quote!()
225 };
226
227 let params = if let Some(config) = &self.config_type {
228 let name = self.arg_name.clone().unwrap_or(format_ident!("config"));
230 quote!(#name: Option<#config>)
231 } else {
232 quote!()
233 };
234
235 quote! {
236 #(#attrs)*
237 pub #async_kw fn new(#params) -> Self #body
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use crate::fmt::assert_code_eq_token;

    use super::*;
    use quote::{format_ident, quote};
    use syn::parse_quote;

    #[test]
    fn test_sync_server_init_fn() {
        let config_type: Type = parse_quote!(GreeterConfig);

        let item: ImplItemFn = parse_quote! {
            fn new(cfg: Option<GreeterConfig>) -> Self {
                Self { count: 0 }
            }
        };

        let init_fn = InitFn::parse(Some(config_type), &item).expect("Parse failed");

        let server_ident = format_ident!("GreeterServer");
        let result = init_fn.generate_ffi(&server_ident);

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::InitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<GreeterConfig>(config_ptr) {
                        Ok(config) => config,
                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
                    };
                    ::pyroduct::ffi::InitResult::init_ok(GreeterServer::new(config), object_id)
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_async_server_init_fn() {
        let config_type: Type = parse_quote!(GreeterConfig);

        let item: ImplItemFn = parse_quote! {
            async fn new(val: Option<GreeterConfig>) -> Self {
                Self { count: 0 }
            }
        };

        let init_fn = InitFn::parse(Some(config_type), &item).expect("Parse failed");

        let server_ident = format_ident!("GreeterServer");
        let result = init_fn.generate_ffi(&server_ident);

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::FutureInitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
                    let config = match ::pyroduct::ffi::guest::safe_lifecycle::deserialize_config::<GreeterConfig>(config_ptr) {
                        Ok(config) => config,
                        Err(err) => return ::pyroduct::ffi::InitResult::init_err(err, object_id),
                    };
                    ::pyroduct::ffi::InitResult::init_ok(GreeterServer::new(config).await, object_id)
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    /// Without a `config` attribute the entry point must call `new()` directly, with no decode step.
    #[test]
    fn test_sync_server_init_fn_without_config() {
        let item: ImplItemFn = parse_quote! {
            fn new() -> Self {
                Self { count: 0 }
            }
        };

        let init_fn = InitFn::parse(None, &item).expect("Parse failed");
        let result = init_fn.generate_ffi(&format_ident!("GreeterServer"));

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::InitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_init(|object_id| {
                    ::pyroduct::ffi::InitResult::init_ok(GreeterServer::new(), object_id)
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    #[test]
    fn test_async_server_init_fn_without_config() {
        let item: ImplItemFn = parse_quote! {
            async fn new() -> Self {
                Self { count: 0 }
            }
        };

        let init_fn = InitFn::parse(None, &item).expect("Parse failed");
        let result = init_fn.generate_ffi(&format_ident!("GreeterServer"));

        let expected = quote! {
            #[unsafe(no_mangle)]
            pub extern "C" fn p__greeter_server__ffi_init(
                config_ptr: ::pyroduct::format::PyroRefPtr,
                object_id: u64,
            ) -> ::pyroduct::ffi::FutureInitResult {
                ::pyroduct::ffi::guest::safe_lifecycle::execute_safe_async_init(|object_id| async move {
                    ::pyroduct::ffi::InitResult::init_ok(GreeterServer::new().await, object_id)
                }, object_id)
            }
        };

        assert_code_eq_token(&result, &expected);
    }

    /// The export descriptor must name the same symbol as the FFI entry point.
    #[test]
    fn test_generate_export_variants() {
        let server = format_ident!("GreeterServer");

        let sync_item: ImplItemFn = parse_quote! { fn new() -> Self { Self } };
        let sync_export = InitFn::parse(None, &sync_item)
            .expect("Parse failed")
            .generate_export(&server);
        assert_eq!(
            sync_export.to_string(),
            quote!(::pyroduct::ffi::ClassInitFn::Sync(p__greeter_server__ffi_init)).to_string()
        );

        let async_item: ImplItemFn = parse_quote! { async fn new() -> Self { Self } };
        let async_export = InitFn::parse(None, &async_item)
            .expect("Parse failed")
            .generate_export(&server);
        assert_eq!(
            async_export.to_string(),
            quote!(::pyroduct::ffi::ClassInitFn::Async(p__greeter_server__ffi_init)).to_string()
        );
    }

    #[test]
    fn test_generate_impl_method_without_config() {
        let item: ImplItemFn = parse_quote! {
            async fn new() -> Self { Self { count: 0 } }
        };

        let init_fn = InitFn::parse(None, &item).expect("Parse failed");

        assert_eq!(
            init_fn.generate_impl_method().to_string(),
            quote! { pub async fn new() -> Self { Self { count: 0 } } }.to_string()
        );
    }

    #[test]
    fn test_arbitrary_arg_name() {
        let config_type: Type = parse_quote!(MyConfig);
        let item: ImplItemFn = parse_quote! {
            fn new(settings: Option<MyConfig>) -> Self { Self }
        };

        let init_fn =
            InitFn::parse(Some(config_type), &item).expect("Should allow arbitrary names");

        let impl_code = init_fn.generate_impl_method();
        let impl_str = impl_code.to_string();
        assert!(impl_str.contains("settings : Option < MyConfig >"));
    }

    /// Only the last path segment is checked, so a fully-qualified `Option` is accepted.
    #[test]
    fn test_accepts_fully_qualified_option() {
        let config_type: Type = parse_quote!(MyConfig);
        let item: ImplItemFn = parse_quote! {
            fn new(cfg: std::option::Option<MyConfig>) -> Self { Self }
        };

        let init_fn = InitFn::parse(Some(config_type), &item).expect("Parse failed");

        assert_eq!(init_fn.arg_name, Some(format_ident!("cfg")));
        assert!(init_fn.config_type.is_some());
        assert!(!init_fn.is_async);
    }

    #[test]
    fn test_validation_errors() {
        let config_type: Type = parse_quote!(MyConfig);

        let item: ImplItemFn = parse_quote! { fn new(c: MyConfig) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        let item: ImplItemFn = parse_quote! { fn new(c: &MyConfig) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        let item: ImplItemFn = parse_quote! { fn new(c: Option<WrongConfig>) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());
    }

    #[test]
    fn test_signature_validation_errors() {
        let config_type: Type = parse_quote!(MyConfig);

        // Wrong function name.
        let item: ImplItemFn = parse_quote! { fn build() -> Self { Self } };
        assert!(InitFn::parse(None, &item).is_err());

        // Missing return type.
        let item: ImplItemFn = parse_quote! { fn new() {} };
        assert!(InitFn::parse(None, &item).is_err());

        // Concrete type instead of `Self`.
        let item: ImplItemFn = parse_quote! { fn new() -> Greeter { Greeter } };
        assert!(InitFn::parse(None, &item).is_err());

        // Receiver parameter.
        let item: ImplItemFn = parse_quote! { fn new(&self) -> Self { Self } };
        assert!(InitFn::parse(None, &item).is_err());

        // Argument present without a `config` attribute.
        let item: ImplItemFn = parse_quote! { fn new(c: Option<MyConfig>) -> Self { Self } };
        assert!(InitFn::parse(None, &item).is_err());

        // `config` attribute present but no argument.
        let item: ImplItemFn = parse_quote! { fn new() -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());

        // Destructuring pattern instead of a plain identifier.
        let item: ImplItemFn = parse_quote! { fn new((a, b): Option<MyConfig>) -> Self { Self } };
        assert!(InitFn::parse(Some(config_type.clone()), &item).is_err());
    }
}
358}