Skip to main content

pyro_macro/ffi/lifecycle/
new_client.rs

1use std::rc::Rc;
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{Error, FnArg, Ident, ImplItemFn, Type, parse_quote};
6
7use crate::{
8    ffi::paths::{CapabilityIdent, FnName},
9    utils::extract_ident_from_type,
10};
11
12#[derive(Debug, Clone)]
13pub struct NewClientFn {
14    pub fn_name: FnName,
15    pub class: Rc<CapabilityIdent>,
16    pub client_type: Ident,
17    pub error_type: Option<Type>,
18    pub body: syn::Block,
19    pub attrs: Vec<syn::Attribute>,
20    pub is_async: bool,
21}
22
23impl NewClientFn {
24    pub fn parse(f: &ImplItemFn, class: &Rc<CapabilityIdent>) -> syn::Result<Self> {
25        let sig = &f.sig;
26
27        // 1. Validate name
28        if sig.ident != "register" {
29            return Err(Error::new_spanned(
30                &sig.ident,
31                "Expected function named 'register'",
32            ));
33        }
34
35        // 2. Validate not async
36        let is_async = sig.asyncness.is_some();
37
38        // 3. Validate &self as first parameter
39        match sig.inputs.first() {
40            Some(FnArg::Receiver(r)) => {
41                if r.mutability.is_some() {
42                    return Err(Error::new_spanned(
43                        r,
44                        "fn register must take &self (not &mut self)",
45                    ));
46                }
47                if r.reference.is_none() {
48                    return Err(Error::new_spanned(
49                        r,
50                        "fn register must take &self (not self)",
51                    ));
52                }
53            }
54            Some(arg) => {
55                return Err(Error::new_spanned(
56                    arg,
57                    "fn register must take &self as its first parameter",
58                ));
59            }
60            None => {
61                return Err(Error::new_spanned(sig, "fn register must take &self"));
62            }
63        }
64
65        // 4. Validate second parameter is client: &ClientType
66        if sig.inputs.len() != 2 {
67            return Err(Error::new_spanned(
68                &sig.inputs,
69                "fn register must take exactly two parameters: &self and client: &ClientType",
70            ));
71        }
72
73        let client_type = match sig.inputs.iter().nth(1) {
74            Some(FnArg::Typed(pt)) => {
75                // Extract the type (should be a reference)
76                let ty = &*pt.ty;
77                if let Type::Reference(r) = ty {
78                    extract_ident_from_type(&r.elem)?
79                } else {
80                    return Err(Error::new_spanned(
81                        ty,
82                        "Client parameter must be a reference: &ClientType",
83                    ));
84                }
85            }
86            _ => {
87                return Err(Error::new_spanned(
88                    &sig.inputs,
89                    "fn new_client must have client: &ClientType as second parameter",
90                ));
91            }
92        };
93
94        // 5. Validate return type: Result<(), CapturedError> or Result<()>
95        let (ok_ty, _err_ty) = crate::ffi::paths::verify_result_return_type(&sig.output)?;
96        let ok_str = quote!(#ok_ty).to_string().replace(" ", "");
97        if ok_str != "()" {
98            return Err(Error::new_spanned(
99                &sig.output,
100                "fn register must return Result<(), CapturedError> or Result<()>",
101            ));
102        }
103
104        let error_type = Some(parse_quote!(::pyroduct::CapturedError));
105
106        Ok(Self {
107            fn_name: FnName(format_ident!("register")),
108            class: class.clone(),
109            client_type,
110            error_type,
111            body: f.block.clone(),
112            attrs: f.attrs.clone(),
113            is_async,
114        })
115    }
116
117    /// Generate the impl method (preserves original)
118    pub fn generate_impl_method(&self) -> TokenStream {
119        let attrs = &self.attrs;
120        let body = &self.body;
121        let client = &self.client_type;
122
123        quote! {
124            #(#attrs)*
125            pub fn new_client(&self, client: &#client) -> Result<(), ::pyroduct::CapturedError> #body
126        }
127    }
128
129    /// Generate the export entry for the init function
130    pub fn generate_export(&self) -> TokenStream {
131        let init_name = self.class.ffi_name(&self.fn_name);
132
133        if self.is_async {
134            quote!(::pyroduct::ffi::ClientRegisterFn::Async(#init_name))
135        } else {
136            quote!(::pyroduct::ffi::ClientRegisterFn::Sync(#init_name))
137        }
138    }
139
140    pub fn generate_capability_ffi(&self) -> TokenStream {
141        let fn_ffi_name = self.class.ffi_name(&self.fn_name);
142        let client_type = &self.client_type;
143        let state_type = &self.class.state_tn;
144
145        quote! {
146            #[unsafe(no_mangle)]
147            pub unsafe extern "C" fn #fn_ffi_name(
148                capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
149                client_state_ptr: ::pyroduct::format::PyroRefPtr,
150            ) -> ::pyroduct::format::PyroViewPtr {
151                let client: #client_type = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
152                    Ok(v) => v,
153                    Err(e) => return e.encode().view().into_ptr(),
154                };
155
156                let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(client_state_ptr);
157                ::pyroduct::ffi::guest::execute_safe(|| {
158                    // Reconstruct state from raw pointer
159                    let state = unsafe { &*(capability_state_ptr.state as *const #state_type) };
160                    ::pyroduct::ffi::guest::serialize_result(state.new_client(&client))
161                }, capability_state_ptr.object_id, mux_id)
162            }
163        }
164    }
165
166    /// Generates the extern declaration for the WASM import.
167    /// This corresponds to `generate_client_wasm` requested in the prompt.
168    pub fn generate_client_wasm(&self) -> TokenStream {
169        //let fn_wasm_name = self.class.wasm_name(&self.fn_name);
170        quote! {
171            pub fn register(ptr: *const u8) -> *mut u8;
172        }
173    }
174
175    /// Generates the call expression used inside the client's register method.
176    /// This corresponds to `generate_wasm_call`.
177    pub fn generate_wasm_call(&self, module: Option<&Ident>) -> TokenStream {
178        // let fn_wasm_name = self.class.wasm_name(&self.fn_name);
179        let module_prefix = if let Some(m) = module {
180            quote!(#m::)
181        } else {
182            quote!()
183        };
184
185        quote! {
186            #module_prefix register
187        }
188    }
189
190    /// Generates the full client-side implementation of the register method.
191    /// The user prompt referred to this as "generate_client_wasm needs to generate impl MyClient".
192    pub fn generate_client_impl(&self, module: Option<&Ident>) -> TokenStream {
193        let client_type = &self.client_type;
194        let wasm_call = self.generate_wasm_call(module);
195
196        quote! {
197            impl #client_type {
198                pub fn register(self) -> Result<::pyroduct::wasm::Client<Self>, ::pyroduct::CapturedError> {
199                    ::pyroduct::wasm::Client::<Self>::__register_result(self, |ptr| unsafe { #wasm_call(ptr) })
200                }
201            }
202        }
203    }
204}