use std::rc::Rc;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{Error, FnArg, Ident, ImplItemFn, Type, parse_quote};
use crate::{
ffi::paths::{CapabilityIdent, FnName},
utils::extract_ident_from_type,
};
/// Parsed description of a capability's `register` method, produced by
/// [`NewClientFn::parse`] and consumed by the `generate_*` methods to emit
/// the capability-side method, its FFI shim, and the client-side wrapper.
#[derive(Debug, Clone)]
pub struct NewClientFn {
/// Name of the function; `parse` always sets this to `register`. Combined
/// with `class` (via `ffi_name`) to derive the exported FFI symbol name.
pub fn_name: FnName,
/// Shared identifier of the capability this function belongs to; supplies
/// the FFI symbol name and the capability state type used in generated code.
pub class: Rc<CapabilityIdent>,
/// Identifier of the client type, extracted from the referent of the
/// second parameter (`client: &ClientType`).
pub client_type: Ident,
/// Error type of the function's `Result`; `parse` always stores
/// `Some(::pyroduct::CapturedError)` here.
pub error_type: Option<Type>,
/// Clone of the user-written function body, spliced verbatim into the
/// generated `new_client` method.
pub body: syn::Block,
/// Attributes attached to the user-written function, re-emitted on the
/// generated `new_client` method.
pub attrs: Vec<syn::Attribute>,
/// Whether the user declared the function `async`; selects the
/// `Async`/`Sync` variant emitted by `generate_export`.
pub is_async: bool,
}
impl NewClientFn {
pub fn parse(f: &ImplItemFn, class: &Rc<CapabilityIdent>) -> syn::Result<Self> {
let sig = &f.sig;
if sig.ident != "register" {
return Err(Error::new_spanned(
&sig.ident,
"Expected function named 'register'",
));
}
let is_async = sig.asyncness.is_some();
match sig.inputs.first() {
Some(FnArg::Receiver(r)) => {
if r.mutability.is_some() {
return Err(Error::new_spanned(
r,
"fn register must take &self (not &mut self)",
));
}
if r.reference.is_none() {
return Err(Error::new_spanned(
r,
"fn register must take &self (not self)",
));
}
}
Some(arg) => {
return Err(Error::new_spanned(
arg,
"fn register must take &self as its first parameter",
));
}
None => {
return Err(Error::new_spanned(sig, "fn register must take &self"));
}
}
if sig.inputs.len() != 2 {
return Err(Error::new_spanned(
&sig.inputs,
"fn register must take exactly two parameters: &self and client: &ClientType",
));
}
let client_type = match sig.inputs.iter().nth(1) {
Some(FnArg::Typed(pt)) => {
let ty = &*pt.ty;
if let Type::Reference(r) = ty {
extract_ident_from_type(&r.elem)?
} else {
return Err(Error::new_spanned(
ty,
"Client parameter must be a reference: &ClientType",
));
}
}
_ => {
return Err(Error::new_spanned(
&sig.inputs,
"fn new_client must have client: &ClientType as second parameter",
));
}
};
let (ok_ty, _err_ty) = crate::ffi::paths::verify_result_return_type(&sig.output)?;
let ok_str = quote!(#ok_ty).to_string().replace(" ", "");
if ok_str != "()" {
return Err(Error::new_spanned(
&sig.output,
"fn register must return Result<(), CapturedError> or Result<()>",
));
}
let error_type = Some(parse_quote!(::pyroduct::CapturedError));
Ok(Self {
fn_name: FnName(format_ident!("register")),
class: class.clone(),
client_type,
error_type,
body: f.block.clone(),
attrs: f.attrs.clone(),
is_async,
})
}
pub fn generate_impl_method(&self) -> TokenStream {
let attrs = &self.attrs;
let body = &self.body;
let client = &self.client_type;
quote! {
#(#attrs)*
pub fn new_client(&self, client: &#client) -> Result<(), ::pyroduct::CapturedError> #body
}
}
pub fn generate_export(&self) -> TokenStream {
let init_name = self.class.ffi_name(&self.fn_name);
if self.is_async {
quote!(::pyroduct::ffi::ClientRegisterFn::Async(#init_name))
} else {
quote!(::pyroduct::ffi::ClientRegisterFn::Sync(#init_name))
}
}
pub fn generate_capability_ffi(&self) -> TokenStream {
let fn_ffi_name = self.class.ffi_name(&self.fn_name);
let client_type = &self.client_type;
let state_type = &self.class.state_tn;
quote! {
#[unsafe(no_mangle)]
pub unsafe extern "C" fn #fn_ffi_name(
capability_state_ptr: ::pyroduct::ffi::PyroRefObjectPtr,
client_state_ptr: ::pyroduct::format::PyroRefPtr,
) -> ::pyroduct::format::PyroViewPtr {
let client: #client_type = match ::pyroduct::ffi::guest::deserialize_input(client_state_ptr) {
Ok(v) => v,
Err(e) => return e.encode().view().into_ptr(),
};
let (_client_id, mux_id, _class_id, _fn_id) = ::pyroduct::format::get_ref_ids(client_state_ptr);
::pyroduct::ffi::guest::execute_safe(|| {
let state = unsafe { &*(capability_state_ptr.state as *const #state_type) };
::pyroduct::ffi::guest::serialize_result(state.new_client(&client))
}, capability_state_ptr.object_id, mux_id)
}
}
}
pub fn generate_client_wasm(&self) -> TokenStream {
quote! {
pub fn register(ptr: *const u8) -> *mut u8;
}
}
pub fn generate_wasm_call(&self, module: Option<&Ident>) -> TokenStream {
let module_prefix = if let Some(m) = module {
quote!(#m::)
} else {
quote!()
};
quote! {
#module_prefix register
}
}
pub fn generate_client_impl(&self, module: Option<&Ident>) -> TokenStream {
let client_type = &self.client_type;
let wasm_call = self.generate_wasm_call(module);
quote! {
impl #client_type {
pub fn register(self) -> Result<::pyroduct::wasm::Client<Self>, ::pyroduct::CapturedError> {
::pyroduct::wasm::Client::<Self>::__register_result(self, |ptr| unsafe { #wasm_call(ptr) })
}
}
}
}
}