use proc_macro2::TokenStream;
use quote::quote;
use std::collections::HashSet;
use syn::{Error, FnArg, Pat, Type, punctuated::Punctuated};
use super::{
types::{ArgInfo, MethodAttrs},
utils::{
ParamAttrs, collect_used_type_params, convert_to_single_lifetime, parse_return_type,
snake_case_to_pascal_case, type_contains_lifetime,
},
};
use crate::utils::is_option_type;
pub(super) fn generate_method_impl(
method: &mut syn::TraitItemFn,
interface_name: &str,
trait_generics: &syn::Generics,
method_attrs: &MethodAttrs,
crate_path: &TokenStream,
param_attrs_map: &std::collections::HashMap<String, ParamAttrs>,
) -> Result<TokenStream, Error> {
let method_name = &method.sig.ident;
let method_name_str = method_name.to_string();
let converted_name = snake_case_to_pascal_case(&method_name_str);
let actual_method_name = method_attrs.rename.as_deref().unwrap_or(&converted_name);
let method_path = format!("{interface_name}.{actual_method_name}");
let has_explicit_lifetimes = method.sig.generics.lifetimes().next().is_some();
let method_generics = method.sig.generics.clone();
let method_output = method.sig.output.clone();
let arg_infos = parse_method_arguments(method, has_explicit_lifetimes, param_attrs_map)?;
let has_any_lifetime = arg_infos.iter().any(|info| info.has_lifetime);
if method_attrs.is_streaming && method_attrs.is_oneway {
return Err(Error::new_spanned(
&method.sig,
"method cannot be both streaming (`more`) and oneway (`oneway`)",
));
}
if method_attrs.is_oneway && method_attrs.return_fds {
return Err(Error::new_spanned(
&method.sig,
"method cannot be both oneway (`oneway`) and return file descriptors (`return_fds`)",
));
}
let fds_params: Vec<_> = arg_infos.iter().filter(|info| info.is_fds).collect();
if fds_params.len() > 1 {
return Err(Error::new_spanned(
&method.sig,
"method can have at most one parameter marked with #[zlink(fds)]",
));
}
#[cfg(not(feature = "std"))]
{
if !fds_params.is_empty() || method_attrs.return_fds {
return Err(Error::new_spanned(
&method.sig,
"FD-related attributes (`#[zlink(fds)]` and `#[zlink(return_fds)]`) require the `std` feature to be enabled.",
));
}
}
let (reply_type, error_type) = if method_attrs.is_oneway {
(syn::parse_quote!(()), syn::parse_quote!(#crate_path::Error))
} else {
parse_return_type(
&method_output,
method_attrs.is_streaming,
method_attrs.return_fds,
)?
};
if method_attrs.is_streaming
&& (type_contains_lifetime(&reply_type) || type_contains_lifetime(&error_type))
{
return Err(Error::new_spanned(
&method.sig.output,
"streaming methods (`more`) require owned return types (no non-static lifetimes) \
because the internal buffer may be reused between stream iterations",
));
}
#[cfg(feature = "std")]
let (params_struct_def, params_init, fds_init) = generate_method_params(
&arg_infos,
&method_generics,
trait_generics,
has_any_lifetime,
has_explicit_lifetimes,
);
#[cfg(not(feature = "std"))]
let (params_struct_def, params_init, _fds_init) = generate_method_params(
&arg_infos,
&method_generics,
trait_generics,
has_any_lifetime,
has_explicit_lifetimes,
);
let method_call_setup = quote! {
#params_struct_def
#params_init
#[derive(::serde::Serialize, ::core::fmt::Debug)]
struct MethodCall<T> {
method: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<T>,
}
let method_call = MethodCall {
method: #method_path,
parameters: params,
};
};
let out_params_extract = match &reply_type {
Type::Tuple(tuple) if tuple.elems.is_empty() => {
quote!(Ok(Ok(())))
}
_ => {
quote!(match reply.into_parameters() {
Some(params) => Ok(Ok(params)),
None => Err(#crate_path::Error::MissingParameters),
})
}
};
let (return_type, implementation) = if method_attrs.is_oneway {
generate_oneway_method(
method_call_setup,
#[cfg(feature = "std")]
fds_init,
crate_path,
)
} else if method_attrs.is_streaming {
generate_streaming_method(
method_call_setup,
&reply_type,
&error_type,
out_params_extract,
method_attrs.return_fds,
#[cfg(feature = "std")]
fds_init,
crate_path,
)
} else {
generate_regular_method(
method_call_setup,
&reply_type,
&error_type,
out_params_extract,
method_attrs.return_fds,
#[cfg(feature = "std")]
fds_init,
crate_path,
)
};
let mut method_sig = method.sig.clone();
method_sig.output = syn::parse2(quote! { -> #return_type })?;
Ok(quote! {
#method_sig
{
#implementation
}
})
}
/// Collects an [`ArgInfo`] for every named, typed argument of `method`.
///
/// The first input is skipped (presumably the receiver — confirm against callers). Inputs that
/// are not `FnArg::Typed`, or whose pattern is not a plain identifier, are silently left out of
/// the result.
///
/// For each remaining argument:
/// - `serialized_name` and `is_fds` come from the entry in `param_attrs_map` keyed by the
///   argument's identifier; without an entry they are `None` and `false`,
/// - `is_optional` is the result of `is_option_type` on the declared type,
/// - `ty_for_params` is the declared type unchanged when `has_explicit_lifetimes` is set, and
///   the result of `convert_to_single_lifetime` otherwise,
/// - `has_lifetime` is `type_contains_lifetime` applied to `ty_for_params`.
///
/// The returned infos borrow the argument identifiers from `method`.
fn parse_method_arguments<'a>(
    method: &'a mut syn::TraitItemFn,
    has_explicit_lifetimes: bool,
    param_attrs_map: &std::collections::HashMap<String, ParamAttrs>,
) -> Result<Vec<ArgInfo<'a>>, Error> {
    method
        .sig
        .inputs
        .iter_mut()
        .skip(1)
        .filter_map(|arg| {
            let FnArg::Typed(pat_type) = arg else {
                return None;
            };
            let Pat::Ident(pat_ident) = &*pat_type.pat else {
                return None;
            };

            let name = &pat_ident.ident;
            let ty = &pat_type.ty;

            // Per-parameter attributes are keyed by the argument's identifier.
            let param_name = name.to_string();
            let param_attrs = param_attrs_map.get(&param_name);
            let serialized_name = param_attrs.and_then(|attrs| attrs.rename.clone());
            let is_fds = param_attrs.map(|attrs| attrs.is_fds).unwrap_or(false);

            let is_optional = is_option_type(ty);
            let ty_for_params = if has_explicit_lifetimes {
                (**ty).clone()
            } else {
                convert_to_single_lifetime(ty)
            };
            let has_lifetime = type_contains_lifetime(&ty_for_params);

            Some(Ok(ArgInfo {
                name,
                ty_for_params,
                is_optional,
                has_lifetime,
                serialized_name,
                is_fds,
            }))
        })
        .collect()
}
fn generate_method_params(
arg_infos: &[ArgInfo<'_>],
method_generics: &syn::Generics,
trait_generics: &syn::Generics,
has_any_lifetime: bool,
has_explicit_lifetimes: bool,
) -> (TokenStream, TokenStream, TokenStream) {
let fds_params: Vec<_> = arg_infos.iter().filter(|info| info.is_fds).collect();
let fds_init = if let Some(fd_param) = fds_params.first() {
let fd_name = fd_param.name;
quote! { #fd_name }
} else {
quote! { ::std::vec::Vec::new() }
};
let has_regular_params = arg_infos.iter().any(|info| !info.is_fds);
if has_regular_params {
let mut used_type_params = HashSet::new();
for info in arg_infos.iter().filter(|info| !info.is_fds) {
collect_used_type_params(&info.ty_for_params, &mut used_type_params);
}
let mut combined_generics: Punctuated<syn::GenericParam, syn::Token![,]> =
Punctuated::new();
for param in &trait_generics.params {
match param {
syn::GenericParam::Type(type_param) => {
if used_type_params.contains(&type_param.ident.to_string()) {
let mut clean_param = type_param.clone();
clean_param.bounds.clear();
combined_generics.push(syn::GenericParam::Type(clean_param));
}
}
other => combined_generics.push(other.clone()),
}
}
for param in &method_generics.params {
match param {
syn::GenericParam::Type(type_param) => {
if used_type_params.contains(&type_param.ident.to_string()) {
let mut clean_param = type_param.clone();
clean_param.bounds.clear();
combined_generics.push(syn::GenericParam::Type(clean_param));
}
}
other => combined_generics.push(other.clone()),
}
}
let generics_decl = if !combined_generics.is_empty() {
quote! { <#combined_generics> }
} else if has_any_lifetime && !has_explicit_lifetimes {
quote! { <'__proxy_params> }
} else {
quote! {}
};
let struct_fields: Vec<_> = arg_infos
.iter()
.filter(|info| !info.is_fds) .map(|info| {
let name = info.name;
let ty = &info.ty_for_params;
let serde_attrs = if let Some(ref renamed) = info.serialized_name {
if info.is_optional {
quote! {
#[serde(rename = #renamed, skip_serializing_if = "Option::is_none")]
}
} else {
quote! {
#[serde(rename = #renamed)]
}
}
} else if info.is_optional {
quote! {
#[serde(skip_serializing_if = "Option::is_none")]
}
} else {
quote! {}
};
quote! {
#serde_attrs
#name: #ty
}
})
.collect();
let params_where_clause = build_params_where_clause(method_generics, &used_type_params);
let struct_def = quote! {
#[derive(::serde::Serialize, ::core::fmt::Debug)]
struct Params #generics_decl
#params_where_clause
{
#(#struct_fields,)*
}
};
let regular_arg_names: Vec<_> = arg_infos
.iter()
.filter(|info| !info.is_fds)
.map(|info| info.name)
.collect();
let init = quote! {
let params = Some(Params {
#(#regular_arg_names,)*
});
};
(struct_def, init, fds_init)
} else {
(
quote! {},
quote! {
let params: Option<()> = None;
},
fds_init,
)
}
}
/// Builds the `where` clause for the generated parameters struct.
///
/// Only predicates from the method's own `where` clause are considered. A predicate is kept
/// when it is a type predicate whose bounded type is a path of exactly one segment and that
/// segment's identifier is in `used_type_params`.
///
/// Returns `None` when the method has no `where` clause or when no predicate is kept.
fn build_params_where_clause(
    method_generics: &syn::Generics,
    used_type_params: &HashSet<String>,
) -> Option<syn::WhereClause> {
    let method_where_clause = method_generics.where_clause.as_ref()?;

    let constrains_used_param = |predicate: &syn::WherePredicate| -> bool {
        let syn::WherePredicate::Type(type_predicate) = predicate else {
            return false;
        };
        let syn::Type::Path(bounded) = &type_predicate.bounded_ty else {
            return false;
        };
        let mut segments = bounded.path.segments.iter();
        match (segments.next(), segments.next()) {
            (Some(only), None) => used_type_params.contains(&only.ident.to_string()),
            _ => false,
        }
    };

    let predicates: syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]> =
        method_where_clause
            .predicates
            .iter()
            .filter(|predicate| constrains_used_param(predicate))
            .cloned()
            .collect();

    if predicates.is_empty() {
        return None;
    }
    Some(syn::WhereClause {
        where_token: syn::token::Where::default(),
        predicates,
    })
}
/// Generates the return type and body of a oneway method.
///
/// The body runs `method_call_setup`, wraps `method_call` in a call with the oneway flag set,
/// sends it with `send_call` and returns `Ok(())`. The return type is the crate's
/// `Result<()>`. With the `std` feature, `fds_init` is passed to `send_call` as a second
/// argument.
fn generate_oneway_method(
    method_call_setup: TokenStream,
    #[cfg(feature = "std")] fds_init: TokenStream,
    crate_path: &TokenStream,
) -> (TokenStream, TokenStream) {
    #[cfg(feature = "std")]
    let send_args = quote! { &call, #fds_init };
    #[cfg(not(feature = "std"))]
    let send_args = quote! { &call };

    let body = quote! {
        #method_call_setup
        let call = #crate_path::Call::new(method_call).set_oneway(true);
        self.send_call(#send_args).await?;
        Ok(())
    };

    (quote! { #crate_path::Result<()> }, body)
}
/// Generates the return type and body of a streaming (`more`) method.
///
/// The body runs `method_call_setup`, sends the call with the `more` flag set, then wraps the
/// connection's read half in a `ReplyStream` and maps each item through `out_params_extract`.
///
/// The stream's item type is the crate's `Result` around `Result<reply_type, error_type>`.
/// When `return_fds` is set, that inner result is paired with a `Vec<OwnedFd>`.
///
/// With the `std` feature each raw stream item is matched as a `(reply, fds)` pair, and the
/// descriptors are forwarded only when `return_fds` is set. Without `std` the item is matched
/// as the bare reply result.
#[allow(clippy::too_many_arguments)]
fn generate_streaming_method(
    method_call_setup: TokenStream,
    reply_type: &Type,
    error_type: &Type,
    out_params_extract: TokenStream,
    return_fds: bool,
    #[cfg(feature = "std")] fds_init: TokenStream,
    crate_path: &TokenStream,
) -> (TokenStream, TokenStream) {
    let reply_result = quote! { ::core::result::Result<#reply_type, #error_type> };
    let item_type = if return_fds {
        quote! { (#reply_result, ::std::vec::Vec<::std::os::fd::OwnedFd>) }
    } else {
        reply_result
    };
    let return_type = quote! {
        #crate_path::Result<
            impl #crate_path::futures_util::stream::Stream<
                Item = #crate_path::Result<#item_type>
            >
        >
    };

    #[cfg(feature = "std")]
    let send_args = quote! { &call, #fds_init };
    #[cfg(not(feature = "std"))]
    let send_args = quote! { &call };

    // Only the arms that handle a received reply differ between configurations; the arm for a
    // failed receive is shared and added in the template below.
    let forwards_fds = cfg!(feature = "std") && return_fds;
    let reply_arms = if forwards_fds {
        quote! {
            Ok((Ok(reply), fds)) => {
                #out_params_extract.map(|r| (r, fds))
            }
            Ok((Err(error), fds)) => Ok((Err(error), fds)),
        }
    } else if cfg!(feature = "std") {
        quote! {
            Ok((Ok(reply), _fds)) => #out_params_extract,
            Ok((Err(error), _fds)) => Ok(Err(error)),
        }
    } else {
        quote! {
            Ok(Ok(reply)) => #out_params_extract,
            Ok(Err(error)) => Ok(Err(error)),
        }
    };

    let implementation = quote! {
        #method_call_setup
        let call = #crate_path::Call::new(method_call).set_more(true);
        self.send_call(#send_args).await?;
        let stream = #crate_path::connection::chain::ReplyStream::<#reply_type, #error_type>::new(
            self.read_mut(),
            1,
        );
        use #crate_path::futures_util::stream::{Stream, StreamExt};
        Ok(stream.map(|result| {
            match result {
                #reply_arms
                Err(err) => Err(err),
            }
        }))
    };

    (return_type, implementation)
}
#[allow(clippy::too_many_arguments)]
fn generate_regular_method(
method_call_setup: TokenStream,
reply_type: &Type,
error_type: &Type,
out_params_extract: TokenStream,
return_fds: bool,
#[cfg(feature = "std")] fds_init: TokenStream,
crate_path: &TokenStream,
) -> (TokenStream, TokenStream) {
let return_type = if return_fds {
quote! {
#crate_path::Result<(::core::result::Result<#reply_type, #error_type>, ::std::vec::Vec<::std::os::fd::OwnedFd>)>
}
} else {
quote! {
#crate_path::Result<::core::result::Result<#reply_type, #error_type>>
}
};
#[cfg(feature = "std")]
let call_method_args = quote! { &call, #fds_init };
#[cfg(not(feature = "std"))]
let call_method_args = quote! { &call };
#[cfg(feature = "std")]
let (receive_result, ok_arm, err_arm) = if return_fds {
(
quote! { let (result, fds) = },
quote! { #out_params_extract.map(|r| (r, fds)) },
quote! { Ok((Err(error), fds)) },
)
} else {
(
quote! { let (result, _fds) = },
quote! { #out_params_extract },
quote! { Ok(Err(error)) },
)
};
#[cfg(not(feature = "std"))]
let (receive_result, ok_arm, err_arm) = (
quote! { let result = },
quote! { #out_params_extract },
quote! { Ok(Err(error)) },
);
let implementation = quote! {
#method_call_setup
let call = #crate_path::Call::new(method_call);
#receive_result self.call_method::<_, #reply_type, #error_type>(#call_method_args).await?;
match result {
Ok(reply) => #ok_arm,
Err(error) => #err_arm,
}
};
(return_type, implementation)
}