// pyro_macro/ffi/lifecycle/new_client.rs
use std::rc::Rc;

use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{Error, FnArg, Ident, ImplItemFn, Pat, Type, parse_quote};

use crate::{
    ffi::paths::{CapabilityIdent, FnName},
    utils::extract_ident_from_type,
};

12#[derive(Debug, Clone)]
13pub struct NewClientFn {
14 pub fn_name: FnName,
15 pub class: Rc<CapabilityIdent>,
16 pub client_type: Ident,
17 pub error_type: Option<Type>,
18 pub body: syn::Block,
19 pub attrs: Vec<syn::Attribute>,
20 pub is_async: bool,
21}
22
23impl NewClientFn {
24 pub fn parse(f: &ImplItemFn, class: &Rc<CapabilityIdent>) -> syn::Result<Self> {
25 let sig = &f.sig;
26
27 if sig.ident != "register" {
29 return Err(Error::new_spanned(
30 &sig.ident,
31 "Expected function named 'register'",
32 ));
33 }
34
35 let is_async = sig.asyncness.is_some();
37
38 match sig.inputs.first() {
40 Some(FnArg::Receiver(r)) => {
41 if r.mutability.is_some() {
42 return Err(Error::new_spanned(
43 r,
44 "fn register must take &self (not &mut self)",
45 ));
46 }
47 if r.reference.is_none() {
48 return Err(Error::new_spanned(
49 r,
50 "fn register must take &self (not self)",
51 ));
52 }
53 }
54 Some(arg) => {
55 return Err(Error::new_spanned(
56 arg,
57 "fn register must take &self as its first parameter",
58 ));
59 }
60 None => {
61 return Err(Error::new_spanned(sig, "fn register must take &self"));
62 }
63 }
64
65 if sig.inputs.len() != 2 {
67 return Err(Error::new_spanned(
68 &sig.inputs,
69 "fn register must take exactly two parameters: &self and client: &ClientType",
70 ));
71 }
72
73 let client_type = match sig.inputs.iter().nth(1) {
74 Some(FnArg::Typed(pt)) => {
75 let ty = &*pt.ty;
77 if let Type::Reference(r) = ty {
78 extract_ident_from_type(&r.elem)?
79 } else {
80 return Err(Error::new_spanned(
81 ty,
82 "Client parameter must be a reference: &ClientType",
83 ));
84 }
85 }
86 _ => {
87 return Err(Error::new_spanned(
88 &sig.inputs,
89 "fn new_client must have client: &ClientType as second parameter",
90 ));
91 }
92 };
93
94 let (ok_ty, _err_ty) = crate::ffi::paths::verify_result_return_type(&sig.output)?;
96 let ok_str = quote!(#ok_ty).to_string().replace(" ", "");
97 if ok_str != "()" {
98 return Err(Error::new_spanned(
99 &sig.output,
100 "fn register must return Result<(), CapturedError> or Result<()>",
101 ));
102 }
103
104 let error_type = Some(parse_quote!(::pyroduct::CapturedError));
105
106 Ok(Self {
107 fn_name: FnName(format_ident!("register")),
108 class: class.clone(),
109 client_type,
110 error_type,
111 body: f.block.clone(),
112 attrs: f.attrs.clone(),
113 is_async,
114 })
115 }
116
117 pub fn generate_impl_method(&self) -> TokenStream {
119 let attrs = &self.attrs;
120 let body = &self.body;
121 let client = &self.client_type;
122
123 quote! {
124 #(#attrs)*
125 pub fn new_client(&self, client: &#client) -> Result<(), ::pyroduct::CapturedError> #body
126 }
127 }
128
129 pub fn generate_export(&self) -> TokenStream {
131 let init_name = self.class.ffi_name(&self.fn_name);
132
133 if self.is_async {
134 quote!(::pyroduct::ffi::ClientRegisterFn::Async(#init_name))
135 } else {
136 quote!(::pyroduct::ffi::ClientRegisterFn::Sync(#init_name))
137 }
138 }
139
140 pub fn generate_capability_ffi(&self) -> TokenStream {
141 let fn_ffi_name = self.class.ffi_name(&self.fn_name);
142 let client_type = &self.client_type;
143 let state_type = &self.class.state_tn;
144
145 quote! {
146 #[unsafe(no_mangle)]
147 pub unsafe extern "C" fn #fn_ffi_name(
148 capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
149 client_state_ptr: ::pyroduct::format::PyroRefPtr,
150 ) -> ::pyroduct::format::PyroViewPtr {
151 let client: #client_type = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
152 Ok(v) => v,
153 Err(e) => return e.encode().view().into_ptr(),
154 };
155
156 let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(client_state_ptr);
157 ::pyroduct::ffi::guest::execute_safe(|| {
158 let state = unsafe { &*(capability_state_ptr.state as *const #state_type) };
160 ::pyroduct::ffi::guest::serialize_result(state.new_client(&client))
161 }, capability_state_ptr.object_id, mux_id)
162 }
163 }
164 }
165
166 pub fn generate_client_wasm(&self) -> TokenStream {
169 quote! {
171 pub fn register(ptr: *const u8) -> *mut u8;
172 }
173 }
174
175 pub fn generate_wasm_call(&self, module: Option<&Ident>) -> TokenStream {
178 let module_prefix = if let Some(m) = module {
180 quote!(#m::)
181 } else {
182 quote!()
183 };
184
185 quote! {
186 #module_prefix register
187 }
188 }
189
190 pub fn generate_client_impl(&self, module: Option<&Ident>) -> TokenStream {
193 let client_type = &self.client_type;
194 let wasm_call = self.generate_wasm_call(module);
195
196 quote! {
197 impl #client_type {
198 pub fn register(self) -> Result<::pyroduct::wasm::Client<Self>, ::pyroduct::CapturedError> {
199 ::pyroduct::wasm::Client::<Self>::__register_result(self, |ptr| unsafe { #wasm_call(ptr) })
200 }
201 }
202 }
203 }
204}