Skip to main content

pyro_macro/ffi/
capability.rs

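//! Parsing and code generation for the `#[pyroduct::capability]` attribute.
//!
//! The attribute is placed on an inherent impl block, for example:
//!
//! ```ignore
//! #[pyroduct::capability]
//! impl StatefulServer {
//!     type Config = MyConfig;     // Optional
//!     type Client = SimpleClient; // Required
//!     type Error = MyError;       // Optional
//!
//!     fn new(config: &MyConfig) -> Self { Self }
//!     fn reset(&mut self) {}
//!     fn register(&self, _client: &SimpleClient) {}
//!     fn call(&self, _client: &SimpleClient) -> f32 { 42.0 }
//! }
//! ```
//!
//! `new`, `reset` and `register` are the required lifecycle functions; every other
//! `fn` (here `call`) becomes an exported capability method.
//!
//! `CapabilityImpl::expand_capability` emits the capability-side code (the server
//! impl, the FFI shims and the `pyro_capability_manifest` export).
//! `CapabilityImpl::expand_module` emits the client-side code (a `register`
//! constructor, a `<Client>Methods` trait and the `wasm` import module).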
15
16use std::rc::Rc;
17
18use heck::AsSnakeCase;
19use proc_macro2::TokenStream;
20use quote::{format_ident, quote};
21use syn::{Error, Ident, ImplItem, ItemImpl, Type, parse_quote};
22
23use crate::{
24    ffi::{
25        lifecycle::{InitFn, NewClientFn, ResetFn},
26        methods::ImplMethod,
27        paths::CapabilityIdent,
28    },
29    utils::extract_ident_from_type,
30};
31
/// A `#[capability]` impl block after parsing and validation.
///
/// Produced by `CapabilityImpl::new`. `expand_capability` and `expand_module`
/// generate the capability-side and client-side code from it.
#[derive(Debug)]
pub struct CapabilityImpl {
    // Identity storage
    /// Package name/version plus the server (`state_tn`), client, optional config
    /// and optional error type names of the capability.
    pub ident: Rc<CapabilityIdent>,

    // Lifecycle
    /// Parsed `fn new` (required).
    pub init_fn: InitFn,
    /// Parsed `fn reset` (required).
    pub reset_fn: ResetFn,
    /// Parsed `fn register` (required).
    pub register_fn: NewClientFn,

    // Methods
    /// Every other `fn` of the block, in source order, exported as a capability method.
    pub methods: Vec<ImplMethod>,

    // Other items (consts, etc.) - excluding type aliases
    /// Remaining items that are neither `fn` nor `type` (consts, macro invocations, ...),
    /// re-emitted verbatim inside the generated server impl.
    pub other_items: Vec<ImplItem>,
    /// Attributes of the original impl block.
    // NOTE(review): not read by any generator in this file; confirm the caller
    // re-emits them, otherwise impl-level attributes are dropped from the output.
    pub attrs: Vec<syn::Attribute>,
}
50
51impl CapabilityImpl {
52    pub fn new(
53        input: ItemImpl,
54        required_docs: bool,
55        cap_name: &str,
56        cap_semver: &str,
57    ) -> syn::Result<Self> {
58        // 1. Extract state/server type name
59        let state_tn =
60            match &*input.self_ty {
61                Type::Path(tp) => tp.path.get_ident().cloned().ok_or_else(|| {
62                    Error::new_spanned(&input.self_ty, "Expected simple type name")
63                })?,
64                _ => {
65                    return Err(Error::new_spanned(
66                        &input.self_ty,
67                        "Expected simple type name",
68                    ));
69                }
70            };
71
72        // 2. Ensure no trait impl
73        if input.trait_.is_some() {
74            return Err(Error::new_spanned(
75                &input,
76                "#[capability] cannot be used on trait implementations",
77            ));
78        }
79        let attrs = input.attrs.clone();
80
81        // 3. First pass: collect types
82        let mut client_tn: Option<Ident> = None;
83        let mut config_tn: Option<Ident> = None;
84        let mut error_tn: Option<Type> = None;
85
86        let mut init_fn: Option<InitFn> = None;
87        let mut reset_fn: Option<ResetFn> = None;
88        let mut register_fn: Option<NewClientFn> = None;
89        let mut method_fns = Vec::new();
90        let mut other_items = Vec::new();
91
92        for item in &input.items {
93            match item {
94                ImplItem::Type(ty) => {
95                    if ty.ident == "Client" {
96                        client_tn = Some(extract_ident_from_type(&ty.ty)?);
97                    } else if ty.ident == "Config" {
98                        config_tn = Some(extract_ident_from_type(&ty.ty)?);
99                    } else if ty.ident == "Error" {
100                        error_tn = Some(ty.ty.clone());
101                    }
102                    // Note: We intentionally do NOT add type aliases to other_items
103                    // because inherent associated types are unstable in Rust
104                }
105                _ => {}
106            }
107        }
108
109        let client_tn = client_tn
110            .ok_or_else(|| Error::new_spanned(&state_tn, "Missing `type Client = ...;`"))?;
111
112        // Build identifiers
113        let ident = Rc::new(CapabilityIdent {
114            pkg_name: cap_name.to_string(),
115            pkg_version: cap_semver.to_string(),
116            state_tn,
117            client_tn,
118            config_tn,
119            error_tn,
120        });
121
122        for item in &input.items {
123            match item {
124                ImplItem::Fn(f) => {
125                    let name = f.sig.ident.to_string();
126                    match name.as_str() {
127                        "new" => {
128                            let conf = ident.config_tn.clone().map(|t| parse_quote! { #t });
129                            init_fn = Some(InitFn::parse(conf, f)?);
130                        }
131                        "reset" => {
132                            reset_fn = Some(ResetFn::parse(f)?);
133                        }
134                        "register" => {
135                            register_fn = Some(NewClientFn::parse(f, &ident)?);
136                        }
137                        _ => {
138                            // Defer method parsing until we have the class
139                            method_fns.push(f.clone());
140                        }
141                    }
142                }
143                ImplItem::Type(_) => {
144                    // Skip type aliases - they were already processed above
145                    // and we don't want them in the output impl block
146                }
147                other => other_items.push(other.clone()),
148            }
149        }
150
151        let register_fn = register_fn.ok_or_else(|| {
152            Error::new_spanned(
153                &ident.state_tn,
154                "Missing `fn register(&self, client: &Client)`",
155            )
156        })?;
157        let init_fn = init_fn.ok_or_else(|| {
158            Error::new_spanned(
159                &ident.state_tn,
160                "Missing `fn new() -> Self` or `fn new(config: &Config) -> Self`",
161            )
162        })?;
163        let reset_fn = reset_fn
164            .ok_or_else(|| Error::new_spanned(&ident.state_tn, "Missing `fn reset(&mut self)`"))?;
165
166        // 5. Second pass: parse methods with class context
167        let methods: Result<Vec<_>, _> = method_fns
168            .iter()
169            .map(|f| ImplMethod::parse(f, &ident, required_docs))
170            .collect();
171        let methods = methods?;
172
173        Ok(Self {
174            ident,
175            init_fn,
176            reset_fn,
177            register_fn,
178            methods,
179            other_items,
180            attrs,
181        })
182    }
183
184    /// Generate all output code
185    pub fn expand_capability(&self) -> TokenStream {
186        let server_impl = self.generate_server_impl();
187        let lifecycle_ffi = self.generate_lifecycle_ffi();
188        let method_ffis = self.generate_method_ffis();
189        let export_table = self.generate_export_table();
190
191        quote! {
192            #server_impl
193            #lifecycle_ffi
194            #method_ffis
195            #export_table
196        }
197    }
198
199    /// Generate all output code
200    pub fn expand_module(&self) -> TokenStream {
201        let wasm_imports = self.generate_wasm_imports();
202        let client_impl = self.generate_client_impl();
203
204        quote! {
205            #client_impl
206            #wasm_imports
207        }
208    }
209
210    fn generate_server_impl(&self) -> TokenStream {
211        let server = &self.ident.state_tn;
212        let init_method = self.init_fn.generate_impl_method();
213        let reset_method = self.reset_fn.generate_impl_method();
214        let new_client_method = self.register_fn.generate_impl_method();
215        let other_items = &self.other_items;
216
217        let methods: Vec<_> = self
218            .methods
219            .iter()
220            .map(|m| m.generate_server_method())
221            .collect();
222
223        quote! {
224            impl #server {
225                #init_method
226                #reset_method
227                #new_client_method
228                #(#other_items)*
229                #(#methods)*
230            }
231        }
232    }
233
234    fn generate_client_impl(&self) -> TokenStream {
235        let client = &self.ident.client_tn;
236        let module = format_ident!("wasm");
237
238        // 1. Generate the Register Method (on the user struct)
239        let client_impl = self.register_fn.generate_client_impl(Some(&module));
240
241        // 2. Generate the trait with method signatures
242        let trait_name = format_ident!("{}Methods", client);
243
244        let trait_methods: Vec<_> = self
245            .methods
246            .iter()
247            .map(|m| {
248                let name = &m.name.0;
249                let output = &m.output.to_return_type();
250                let args: Vec<_> = m.inputs.iter().map(|(n, t)| quote!(#n: #t)).collect();
251                let docs = m.doc_attrs();
252
253                quote! {
254                    #(#docs)*
255                    fn #name(&self, #(#args),*) #output;
256                }
257            })
258            .collect();
259
260        let trait_def = quote! {
261            pub trait #trait_name {
262                #(#trait_methods)*
263            }
264        };
265
266        // 3. Generate the trait implementation for Client<T>
267        let method_impls: Vec<_> = self
268            .methods
269            .iter()
270            .map(|m| m.generate_client_method(Some(&module)))
271            .collect();
272
273        let trait_impl = quote! {
274            impl #trait_name for ::pyroduct::wasm::Client<#client> {
275                #(#method_impls)*
276            }
277        };
278
279        quote! {
280            #client_impl
281            #trait_def
282            #trait_impl
283        }
284    }
285
286    fn generate_lifecycle_ffi(&self) -> TokenStream {
287        let server = &self.ident.state_tn;
288
289        let init_ffi = self.init_fn.generate_ffi(server);
290        let reset_ffi = self.reset_fn.generate_ffi(server);
291        let register_ffi = self.register_fn.generate_capability_ffi();
292
293        quote! {
294            #init_ffi
295            #reset_ffi
296            #register_ffi
297        }
298    }
299
300    fn generate_method_ffis(&self) -> TokenStream {
301        let method_ffis: Vec<_> = self
302            .methods
303            .iter()
304            .map(|m| m.generate_server_ffi())
305            .collect();
306
307        quote! {
308            #(#method_ffis)*
309        }
310    }
311
312    fn generate_export_table(&self) -> TokenStream {
313        let cap_id = self.ident.cap_id();
314
315        let server = &self.ident.state_tn;
316        let server_snake = AsSnakeCase(server.to_string()).to_string();
317        let server_upper = server_snake.to_uppercase();
318
319        let class_name_static = format_ident!("p__{}", server_upper);
320        let class_name_string = format!("p__{}", server_snake);
321
322        let static_strs: Vec<_> = self
323            .methods
324            .iter()
325            .map(|m| {
326                let trace_name = self.ident.wasm_name(&m.name).to_string();
327                let static_name = self.ident.trace_name_static(&m.name);
328                quote! { const #static_name: &'static str = #trace_name; }
329            })
330            .collect();
331
332        let exports: Vec<_> = self
333            .methods
334            .iter()
335            .map(|ffi| ffi.generate_vtable_entry())
336            .collect();
337
338        let num_exports = exports.len();
339        let exports_array_name = format_ident!("{}__METHODS", class_name_static);
340
341        let init_export = self.init_fn.generate_export(server);
342        let reset_export = self.reset_fn.generate_export(server);
343        let register_export = self.register_fn.generate_export();
344
345        let capability_manifest_fn = quote! {
346            #[unsafe(no_mangle)]
347            pub extern "C" fn pyro_capability_manifest(
348                id: i64,
349                log_callback: ::pyroduct::ffi::LogCallback,
350            ) -> ::pyroduct::ffi::ClassExport {
351                ::pyroduct::ffi::guest::logger::init_logging(id, log_callback);
352
353                ::pyroduct::ffi::ClassExport {
354                    name: CAPABILITY_NAME_VERSION.as_ptr(),
355                    name_len: CAPABILITY_NAME_VERSION.len(),
356                    len: #exports_array_name.len(),
357                    ptr: #exports_array_name.as_ptr() as *mut _,
358                    init: #init_export,
359                    reset: #reset_export,
360                    register: #register_export,
361                }
362            }
363        };
364
365        quote! {
366            const CAPABILITY_NAME_VERSION: &'static str = #cap_id;
367            const #class_name_static: &'static str = #class_name_string;
368            #(#static_strs)*
369
370            const #exports_array_name: [::pyroduct::ffi::MethodExport; #num_exports] = [
371                #(#exports),*
372            ];
373
374            #capability_manifest_fn
375        }
376    }
377
378    fn generate_wasm_imports(&self) -> TokenStream {
379        let cap_id = self.ident.cap_id();
380        let new_client_decl = self.register_fn.generate_client_wasm();
381
382        let method_decls: Vec<_> = self
383            .methods
384            .iter()
385            .map(|m| m.generate_client_wasm())
386            .collect();
387
388        quote! {
389            mod wasm {
390                use super::*;
391                #[link(wasm_import_module = #cap_id)]
392                unsafe extern "C" {
393                    #new_client_decl
394                    #(#method_decls)*
395                }
396            }
397        }
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse2;

    #[test]
    fn test_basic_capability_impl() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;

                fn new() -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, _client: &SimpleClient) {}
                fn call(&self, _client: &SimpleClient) -> f32 { 42.0 }
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        assert_eq!(cap.ident.state_tn.to_string(), "StatefulServer");
        assert_eq!(cap.ident.client_tn.to_string(), "SimpleClient");
        assert_eq!(cap.methods.len(), 1);
        assert_eq!(cap.methods[0].name.to_string(), "call");
        assert!(!cap.init_fn.is_async);
        assert!(cap.init_fn.config_type.is_none());
        assert!(cap.ident.config_tn.is_none());
    }

    #[test]
    fn test_with_config() {
        let code = quote! {
            impl StatefulServer {
                type Config = MyConfig;
                type Client = SimpleClient;

                fn new(config: Option<MyConfig>) -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &SimpleClient) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        assert!(cap.init_fn.config_type.is_some());
        assert!(cap.ident.config_tn.is_some());

        let cfg = cap.ident.config_tn.as_ref().unwrap();
        assert_eq!(quote!(#cfg).to_string(), "MyConfig");
    }

    #[test]
    fn test_config_mismatch() {
        let code = quote! {
            impl StatefulServer {
                type Config = MyConfig;
                type Client = SimpleClient;

                fn new(config: Option<OtherConfig>) -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &SimpleClient) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let err = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap_err();
        assert!(err.to_string().contains("Type mismatch. Expected 'Option<MyConfig>' based on macro attribute, found 'Option<OtherConfig>'"));
    }

    #[test]
    fn test_async_lifecycle() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;

                async fn new() -> Self { Self }
                async fn reset(&mut self) {}
                fn register(&self, client: &SimpleClient) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        assert!(cap.init_fn.is_async);
        assert!(cap.reset_fn.is_async);
    }

    #[test]
    fn test_with_error_type() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;
                type Error = MyError;

                fn new() -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &SimpleClient) -> Result<(), MyError> { Ok(()) }
                fn fallible(&self, _client: &SimpleClient) -> Result<u32, MyError> { Ok(42) }
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        assert!(cap.ident.error_tn.is_some());
        assert!(cap.register_fn.error_type.is_some());
        assert_eq!(cap.methods.len(), 1);
    }

    #[test]
    fn test_rejects_trait_impl() {
        let code = quote! {
            impl SomeTrait for StatefulServer {
                type Client = SimpleClient;

                fn new() -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, _client: &SimpleClient) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let err = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap_err();
        assert!(err
            .to_string()
            .contains("#[capability] cannot be used on trait implementations"));
    }

    #[test]
    fn test_missing_client_type() {
        let code = quote! {
            impl StatefulServer {
                fn new() -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let err = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap_err();
        assert!(err.to_string().contains("Missing `type Client = ...;`"));
    }

    #[test]
    fn test_missing_register() {
        let code = quote! {
            impl StatefulServer {
                type Client = SimpleClient;

                fn new() -> Self { Self }
                fn reset(&mut self) {}
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let err = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap_err();
        assert!(err
            .to_string()
            .contains("Missing `fn register(&self, client: &Client)`"));
    }

    #[test]
    fn test_generate_export_table() {
        let code = quote! {
            impl TestServer {
                type Client = TestClient;

                fn new() -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &TestClient) {}
                fn get_value(&self, client: &TestClient) -> u32 { 0 }
            }
        };

        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        let output = cap.generate_export_table();

        let expected = quote! {
            const CAPABILITY_NAME_VERSION: &'static str = "cap_name";
            const p__TEST_SERVER: &'static str = "p__test_server";
            const p__TEST_SERVER__GET_VALUE: &'static str = "p__test_server__get_value__wasm";

            const p__TEST_SERVER__METHODS: [::pyroduct::ffi::MethodExport; 1usize] = [
                ::pyroduct::ffi::MethodExport {
                    name: p__TEST_SERVER__GET_VALUE.as_ptr(),
                    name_len: p__TEST_SERVER__GET_VALUE.len(),
                    func: ::pyroduct::ffi::Function::Sync(p__test_server__get_value__ffi),
                }
            ];

            #[unsafe(no_mangle)]
            pub extern "C" fn pyro_capability_manifest(
                id: i64,
                log_callback: ::pyroduct::ffi::LogCallback,
            ) -> ::pyroduct::ffi::ClassExport {
                ::pyroduct::ffi::guest::logger::init_logging(id, log_callback);

                ::pyroduct::ffi::ClassExport {
                    name: CAPABILITY_NAME_VERSION.as_ptr(),
                    name_len: CAPABILITY_NAME_VERSION.len(),
                    len: p__TEST_SERVER__METHODS.len(),
                    ptr: p__TEST_SERVER__METHODS.as_ptr() as *mut _,
                    init: ::pyroduct::ffi::ClassInitFn::Sync(p__test_server__ffi_init),
                    reset: ::pyroduct::ffi::ClassResetFn::Sync(p__test_server__ffi_reset),
                    register: ::pyroduct::ffi::ClientRegisterFn::Sync(p__test_server__register__ffi),
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_generate_client_impl_integration() {
        // 1. Define Input
        let code = quote! {
            impl MyState {
                type Client = MyClient;
                type Config = MyConfig;

                fn new(config: Option<MyConfig>) -> Self { Self }
                fn reset(&mut self) {}
                fn register(&self, client: &MyClient) {}
                fn get_info(&self, client: &MyClient) -> u32 { 0 }
                fn get_other_info(&self, client: &MyClient, data: f32) -> u32 { 0 }
            }
        };

        // 2. Parse
        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        // 3. Generate Output
        let output = cap.generate_client_impl();

        // 4. Define Expected Output
        // Checks the `register` constructor, the `MyClientMethods` trait, and normal
        // method generation with zero and one argument.
        let expected = quote! {
            impl MyClient {
                pub fn register(self) -> ::pyroduct::wasm::Client<Self> {
                    ::pyroduct::wasm::Client::<Self>::__register(self, |ptr| unsafe { wasm::register(ptr) })
                }
            }
            pub trait MyClientMethods {
                fn get_info(&self) -> u32;
                fn get_other_info(&self, data: f32) -> u32;
            }
            impl MyClientMethods for ::pyroduct::wasm::Client<MyClient> {
                fn get_info(&self) -> u32 {
                    self.__call_from_wasm::<(), u32, _>(None,
                        |client_state_ptr: *const u8,
                            input_ptr: *const u8| {
                            unsafe {
                                wasm::p__my_state__get_info__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        })
                }

                fn get_other_info(&self, data: f32) -> u32 {
                    self.__call_from_wasm::<
                        f32,
                        u32,
                        _,
                    >(Some(&data),
                        |client_state_ptr: *const u8,
                            input_ptr: *const u8| {
                            unsafe {
                                wasm::p__my_state__get_other_info__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        },
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }

    #[test]
    fn test_generate_client_impl_with_error_and_input_structs() {
        // 1. Define Input: Complex case with Errors and Arguments
        let code = quote! {
            impl AdvancedStruct {
                type Client = AdvancedClient;
                type Error = MyError;

                fn new() -> Self { Self }
                fn reset(&mut self) {}

                fn register(&self, client: &AdvancedClient) -> Result<(), MyError> {
                    Ok(())
                }

                async fn process(&self, client: &AdvancedClient, val: u32, flag: bool) -> Result<u32, MyError> {
                    Ok(val)
                }
            }
        };

        // 2. Parse
        let input: ItemImpl = parse2(code).unwrap();
        let cap = CapabilityImpl::new(input, false, "cap_name", "0.1.0").unwrap();

        // 3. Generate Output
        let output = cap.generate_client_impl();

        // 4. Define Expected Output
        // - `register` should return `Result<::pyroduct::wasm::Client<Self>, MyError>`.
        // - `process` should define an input struct and return `Result<u32, MyError>`.
        let expected = quote! {
            impl AdvancedClient {
                pub fn register(self) -> Result<::pyroduct::wasm::Client<Self>, MyError> {
                    ::pyroduct::wasm::Client::<Self>::__register_result::<MyError, _>(self, |ptr| unsafe { wasm::register(ptr) })
                }
            }
            pub trait AdvancedClientMethods {
                fn process(&self, val: u32, flag: bool) -> Result<u32, MyError>;
            }
            impl AdvancedClientMethods for ::pyroduct::wasm::Client<AdvancedClient> {
                fn process(&self, val: u32, flag: bool) -> Result<u32, MyError> {
                    #[::pyroduct::magma]
                    struct p__AdvancedStruct__Process__Input {
                        pub val: u32,
                        pub flag: bool
                    }

                    self.__call_result_from_wasm::<
                        p__AdvancedStruct__Process__Input,
                        u32,
                        MyError,
                        _
                    >(
                        Some(&p__AdvancedStruct__Process__Input { val, flag }),
                        |client_state_ptr: *const u8,
                         input_ptr: *const u8| {
                            unsafe {
                                wasm::p__advanced_struct__process__wasm(
                                    client_state_ptr,
                                    input_ptr,
                                )
                            }
                        }
                    )
                }
            }
        };

        crate::fmt::assert_code_eq_token(&output, &expected);
    }
}