use proc_macro2::TokenStream;
use syn::{
Attribute, Error, Expr, GenericArgument, Ident, ImplItemFn, Lit, Meta, PathArguments,
ReturnType, Type,
};
use super::types::ParamInfo;
/// Everything the code generator needs to know about one method of the annotated `impl`
/// block, as produced by [`MethodInfo::extract`].
pub(super) struct MethodInfo {
    /// The Rust method name.
    pub name: Ident,
    /// Varlink method name: the `rename` attribute value, or `name` converted to PascalCase.
    pub varlink_name: String,
    /// Interface this method belongs to: its own `interface` attribute, or the value carried
    /// in by the caller through `current_interface` (presumably set by an earlier method).
    pub interface: Option<String>,
    /// Parameters following the first input (normally the `self` receiver), with their
    /// parameter-level `#[zlink(...)]` settings applied.
    pub params: Vec<ParamInfo>,
    /// Reply data type of a non-streaming method. `None` for streaming methods, for methods
    /// without a declared return type, and when the data part is `()` inside a `Result` or
    /// inside a `return_fds` tuple.
    pub return_type: Option<Type>,
    /// Whether the non-streaming return type (or the first element of the `return_fds`
    /// tuple) is a `Result<T, E>`.
    pub returns_result: bool,
    /// The `E` of that `Result`, if any.
    pub error_type: Option<Type>,
    /// The method's body block, as tokens.
    pub body: TokenStream,
    /// Set by `#[zlink(more)]`.
    pub is_streaming: bool,
    /// For streaming methods, the `T` of the stream's `Reply<T>` items.
    pub stream_item_type: Option<Type>,
    /// For streaming methods whose items are `Result<Reply<T>, E>`, the `E`.
    pub stream_error_type: Option<Type>,
    /// For streaming methods, the full declared return type.
    pub stream_return_type: Option<Type>,
    /// Whether the streaming return type is written as `impl Trait`.
    pub stream_uses_impl_trait: bool,
    /// Set by `#[zlink(return_fds)]`.
    pub return_fds: bool,
}
impl MethodInfo {
    /// Analyzes one method of the annotated `impl` block and builds its [`MethodInfo`].
    ///
    /// The method's own `#[zlink(...)]` attributes are parsed and removed from
    /// `method.attrs`. When the method has `#[zlink(interface = "...")]`, `current_interface`
    /// is overwritten with that value. In all cases the resulting value of
    /// `current_interface` is what is stored in [`MethodInfo::interface`].
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the method attributes are malformed;
    /// - a `more` (streaming) method does not take a `bool` as its first parameter after the
    ///   receiver;
    /// - more than one parameter is marked `#[zlink(fds)]`;
    /// - FD attributes are used without the `std` feature;
    /// - the return type does not have the shape required by `more` and/or `return_fds`.
    pub(super) fn extract(
        method: &mut ImplItemFn,
        current_interface: &mut Option<String>,
    ) -> Result<Self, Error> {
        let name = method.sig.ident.clone();
        let method_attrs = MethodAttrs::extract(&mut method.attrs)?;
        // An explicit `interface` replaces the shared value, so it also applies to whatever
        // the caller processes next with the same `current_interface`.
        if let Some(ref iface) = method_attrs.interface {
            *current_interface = Some(iface.clone());
        }
        // `rename` overrides the name derived from the Rust identifier.
        let varlink_name = method_attrs
            .rename
            .unwrap_or_else(|| snake_case_to_pascal_case(&name.to_string()));
        let is_streaming = method_attrs.is_streaming;
        // Skip the first input, assumed to be the `self` receiver.
        // NOTE(review): a method without a receiver would silently lose its first real
        // parameter here.
        let params: Vec<ParamInfo> = method
            .sig
            .inputs
            .iter()
            .skip(1)
            .enumerate()
            .filter_map(|(idx, arg)| {
                let mut param_info = ParamInfo::from_fn_arg(arg)?;
                if let syn::FnArg::Typed(pat_type) = arg {
                    let param_attrs = extract_param_attrs(&pat_type.attrs);
                    param_info.serialized_name = param_attrs.rename;
                    param_info.is_connection = param_attrs.is_connection;
                    param_info.is_fds = param_attrs.is_fds;
                }
                // In a streaming method the first parameter after `self` is the `more` flag.
                // It is excluded from `serialized_params()`.
                if is_streaming && idx == 0 {
                    param_info.is_more = true;
                }
                Some(param_info)
            })
            .collect();
        // Streaming methods must take `more: bool` first. Only the type is checked, not the
        // parameter name.
        if is_streaming {
            let first_param = params.first().ok_or_else(|| {
                Error::new_spanned(
                    &method.sig,
                    "streaming methods must have `more: bool` as the first parameter after `self`",
                )
            })?;
            if !is_bool_type(&first_param.ty) {
                return Err(Error::new_spanned(
                    &method.sig.inputs,
                    "streaming methods must have `more: bool` as the first parameter after `self`",
                ));
            }
        }
        let return_fds = method_attrs.return_fds;
        // At most one parameter may be marked as receiving file descriptors.
        let fds_params: Vec<_> = params.iter().filter(|p| p.is_fds).collect();
        if fds_params.len() > 1 {
            return Err(Error::new_spanned(
                &method.sig,
                "at most one `#[zlink(fds)]` parameter is allowed per method",
            ));
        }
        // Without the `std` feature, any FD-related attribute is rejected outright.
        #[cfg(not(feature = "std"))]
        if !fds_params.is_empty() || return_fds {
            return Err(Error::new_spanned(
                &method.sig,
                "FD-related attributes (`#[zlink(fds)]` and `#[zlink(return_fds)]`) \
                 require the `std` feature to be enabled",
            ));
        }
        // Classify the return type. The fields of the tuple are, in order:
        // 1. the plain reply type;
        // 2. whether that reply is wrapped in `Result`;
        // 3. the `Result` error type;
        // 4. (streaming only) the `T` of `Reply<T>` items;
        // 5. (streaming only) the item error type;
        // 6. (streaming only) the full stream type;
        // 7. (streaming only) whether that type is written as `impl Trait`.
        let (
            return_type,
            returns_result,
            error_type,
            stream_item_type,
            stream_error_type,
            stream_return_type,
            stream_uses_impl_trait,
        ) = if is_streaming && return_fds {
            // Streaming with FDs: items are `(Reply<T> | Result<Reply<T>, E>, Vec<OwnedFd>)`.
            // Only the first tuple element of the item is inspected.
            match &method.sig.output {
                ReturnType::Default => {
                    return Err(Error::new_spanned(
                        &method.sig,
                        "streaming methods with return_fds must return a Stream<Item = \
                         (Reply<T>, Vec<OwnedFd>)> or Stream<Item = (Result<Reply<T>, E>, \
                         Vec<OwnedFd>)>",
                    ));
                }
                ReturnType::Type(_, ty) => {
                    let stream_item = extract_stream_item_type(ty).ok_or_else(|| {
                        Error::new_spanned(
                            ty,
                            "streaming methods with return_fds must return a Stream<Item = \
                             (Reply<T>, Vec<OwnedFd>)> or Stream<Item = (Result<Reply<T>, E>, \
                             Vec<OwnedFd>)> (could not extract Stream's Item type)",
                        )
                    })?;
                    let first = extract_first_tuple_element(&stream_item).ok_or_else(|| {
                        Error::new_spanned(
                            ty,
                            "streaming methods with return_fds must return a Stream<Item = \
                             (Reply<T>, Vec<OwnedFd>)> or Stream<Item = (Result<Reply<T>, E>, \
                             Vec<OwnedFd>)> (stream item must be a tuple)",
                        )
                    })?;
                    let (inner_type, err_type) =
                        extract_reply_or_result_reply(&first).ok_or_else(|| {
                            Error::new_spanned(
                                ty,
                                "streaming methods with return_fds must return a Stream<Item = \
                                 (Reply<T>, Vec<OwnedFd>)> or Stream<Item = (Result<Reply<T>, \
                                 E>, Vec<OwnedFd>)> (first tuple element must be Reply<T> or \
                                 Result<Reply<T>, E>)",
                            )
                        })?;
                    let uses_impl_trait = matches!(**ty, Type::ImplTrait(_));
                    (
                        None,
                        false,
                        None,
                        Some(inner_type),
                        err_type,
                        Some((**ty).clone()),
                        uses_impl_trait,
                    )
                }
            }
        } else if is_streaming {
            // Streaming without FDs: items are `Reply<T>` or `Result<Reply<T>, E>`.
            match &method.sig.output {
                ReturnType::Default => {
                    return Err(Error::new_spanned(
                        &method.sig,
                        "streaming methods must return a Stream<Item = Reply<T>> or \
                         Stream<Item = Result<Reply<T>, E>>",
                    ));
                }
                ReturnType::Type(_, ty) => {
                    let stream_item = extract_stream_item_type(ty).ok_or_else(|| {
                        Error::new_spanned(
                            ty,
                            "streaming methods must return a Stream<Item = Reply<T>> or \
                             Stream<Item = Result<Reply<T>, E>> (could not extract Stream's Item \
                             type)",
                        )
                    })?;
                    let (inner_type, err_type) = extract_reply_or_result_reply(&stream_item)
                        .ok_or_else(|| {
                            Error::new_spanned(
                                ty,
                                "streaming methods must return a Stream<Item = Reply<T>> or \
                                 Stream<Item = Result<Reply<T>, E>> (stream item must be \
                                 Reply<T> or Result<Reply<T>, E>)",
                            )
                        })?;
                    let uses_impl_trait = matches!(**ty, Type::ImplTrait(_));
                    (
                        None,
                        false,
                        None,
                        Some(inner_type),
                        err_type,
                        Some((**ty).clone()),
                        uses_impl_trait,
                    )
                }
            }
        } else if return_fds {
            // Single reply with FDs: `(T | Result<T, E>, Vec<OwnedFd>)`. Only the first tuple
            // element is inspected, and a unit data part yields no reply type.
            match &method.sig.output {
                ReturnType::Default => {
                    return Err(Error::new_spanned(
                        &method.sig,
                        "`return_fds` methods must have a return type",
                    ));
                }
                ReturnType::Type(_, ty) => {
                    let first = extract_first_tuple_element(ty).ok_or_else(|| {
                        Error::new_spanned(
                            ty,
                            "`return_fds` methods must return \
                             `(T, Vec<OwnedFd>)` or `(Result<T, E>, Vec<OwnedFd>)`",
                        )
                    })?;
                    if let Some((inner_ty, err_ty)) = extract_result_types(&first) {
                        (inner_ty, true, Some(err_ty), None, None, None, false)
                    } else {
                        let data_ty = if is_unit_type(&first) {
                            None
                        } else {
                            Some(first)
                        };
                        (data_ty, false, None, None, None, None, false)
                    }
                }
            }
        } else {
            // Plain method: a `Result<T, E>` is split into its parts; any other declared type
            // (including an explicit `()`) is used as-is.
            match &method.sig.output {
                ReturnType::Default => (None, false, None, None, None, None, false),
                ReturnType::Type(_, ty) => {
                    if let Some((inner_ty, err_ty)) = extract_result_types(ty) {
                        (inner_ty, true, Some(err_ty), None, None, None, false)
                    } else {
                        (Some((**ty).clone()), false, None, None, None, None, false)
                    }
                }
            }
        };
        let block = &method.block;
        let body = quote::quote! { #block };
        Ok(Self {
            name,
            varlink_name,
            interface: current_interface.clone(),
            params,
            return_type,
            returns_result,
            error_type,
            body,
            is_streaming,
            stream_item_type,
            stream_error_type,
            stream_return_type,
            stream_uses_impl_trait,
            return_fds,
        })
    }

    /// Returns the fully qualified Varlink name `interface.MethodName`, or `None` when no
    /// interface applies to this method.
    pub(super) fn full_method_path(&self) -> Option<String> {
        self.interface
            .as_ref()
            .map(|iface| format!("{}.{}", iface, self.varlink_name))
    }

    /// Whether any parameter is marked `#[zlink(connection)]`.
    pub(super) fn has_connection_param(&self) -> bool {
        self.params.iter().any(|p| p.is_connection)
    }

    /// Parameters that come from the call's serialized parameters. This excludes the
    /// connection parameter, the streaming `more` flag and the FD parameter.
    pub(super) fn serialized_params(&self) -> impl Iterator<Item = &ParamInfo> {
        self.params
            .iter()
            .filter(|p| !p.is_connection && !p.is_more && !p.is_fds)
    }
}
/// Settings parsed from a method's `#[zlink(...)]` attributes.
#[derive(Default)]
struct MethodAttrs {
    /// `interface = "..."`: the interface this method (and, via the caller's
    /// `current_interface`, subsequent methods) belongs to.
    interface: Option<String>,
    /// `rename = "..."`: explicit Varlink method name.
    rename: Option<String>,
    /// `more`: the method is a streaming method.
    is_streaming: bool,
    /// `return_fds`: the return value is paired with a `Vec<OwnedFd>`.
    return_fds: bool,
}
impl MethodAttrs {
fn extract(attrs: &mut Vec<Attribute>) -> Result<Self, Error> {
let mut result = Self::default();
let mut indices_to_remove = Vec::new();
for (i, attr) in attrs.iter().enumerate() {
if !attr.path().is_ident("zlink") {
continue;
}
indices_to_remove.push(i);
let Meta::List(list) = &attr.meta else {
continue;
};
if list.tokens.is_empty() {
continue;
}
let nested = list.parse_args_with(
syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
)?;
for meta in nested {
match &meta {
Meta::NameValue(nv) if nv.path.is_ident("interface") => {
let Expr::Lit(expr_lit) = &nv.value else {
return Err(Error::new_spanned(
&nv.value,
"interface value must be a string literal",
));
};
let Lit::Str(lit_str) = &expr_lit.lit else {
return Err(Error::new_spanned(
&nv.value,
"interface value must be a string literal",
));
};
result.interface = Some(lit_str.value());
}
Meta::NameValue(nv) if nv.path.is_ident("rename") => {
let Expr::Lit(expr_lit) = &nv.value else {
return Err(Error::new_spanned(
&nv.value,
"rename value must be a string literal",
));
};
let Lit::Str(lit_str) = &expr_lit.lit else {
return Err(Error::new_spanned(
&nv.value,
"rename value must be a string literal",
));
};
result.rename = Some(lit_str.value());
}
Meta::Path(path) if path.is_ident("more") => {
if result.is_streaming {
return Err(Error::new_spanned(&meta, "duplicate `more` attribute"));
}
result.is_streaming = true;
}
Meta::Path(path) if path.is_ident("return_fds") => {
if result.return_fds {
return Err(Error::new_spanned(
&meta,
"duplicate `return_fds` attribute",
));
}
result.return_fds = true;
}
_ => {
return Err(Error::new_spanned(&meta, "unknown zlink attribute"));
}
}
}
}
for &index in indices_to_remove.iter().rev() {
attrs.remove(index);
}
Ok(result)
}
}
/// Settings parsed from a method parameter's `#[zlink(...)]` attributes.
#[derive(Default)]
struct ParamAttrs {
    /// `rename = "..."`: stored as the parameter's `serialized_name`.
    rename: Option<String>,
    /// `connection`: marks the connection parameter (excluded from serialized parameters).
    is_connection: bool,
    /// `fds`: marks the file-descriptor parameter (excluded from serialized parameters; at
    /// most one per method).
    is_fds: bool,
}
/// Collects the `#[zlink(...)]` settings attached to one method parameter.
///
/// Recognized keys are `rename = "..."`, `connection` and `fds`. Unlike
/// `MethodAttrs::extract`, this never fails. Attributes whose arguments do not parse, unknown
/// keys and non-string `rename` values are all ignored, and the attributes are left in place.
/// NOTE(review): this means a typo such as `#[zlink(conection)]` goes unreported.
fn extract_param_attrs(attrs: &[Attribute]) -> ParamAttrs {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("zlink"))
        .filter_map(|attr| match &attr.meta {
            Meta::List(list) => list
                .parse_args_with(
                    syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
                )
                .ok(),
            _ => None,
        })
        .flatten()
        .fold(ParamAttrs::default(), |mut acc, meta| {
            match &meta {
                Meta::NameValue(nv) if nv.path.is_ident("rename") => {
                    if let Expr::Lit(syn::ExprLit {
                        lit: Lit::Str(lit_str),
                        ..
                    }) = &nv.value
                    {
                        acc.rename = Some(lit_str.value());
                    }
                }
                Meta::Path(path) if path.is_ident("connection") => acc.is_connection = true,
                Meta::Path(path) if path.is_ident("fds") => acc.is_fds = true,
                _ => {}
            }
            acc
        })
}
/// Converts a Rust `snake_case` identifier to the PascalCase form used for Varlink method
/// names, e.g. `get_user_info` -> `GetUserInfo`.
///
/// Each `_`-separated word gets its first character upper-cased and the rest lower-cased.
/// Empty words (from leading, trailing or repeated underscores) contribute nothing.
///
/// A leading `r#` is stripped first. `Ident::to_string()` renders raw identifiers such as
/// `fn r#type` as `"r#type"`, and `#` is not valid in a Varlink method name.
fn snake_case_to_pascal_case(input: &str) -> String {
    let input = input.strip_prefix("r#").unwrap_or(input);
    input
        .split('_')
        .map(|word| {
            let mut chars = word.chars();
            let Some(first) = chars.next() else {
                return String::new();
            };
            first.to_uppercase().collect::<String>() + &chars.as_str().to_lowercase()
        })
        .collect()
}
/// Splits a `Result<T, E>` type (matched on the last path segment's name) into its parts.
///
/// Returns `Some((ok, err))`, where `ok` is `None` when `T` is the unit type `()`. Returns
/// `None` in either of these cases:
/// - the type is not a path ending in `Result`;
/// - the first two generic arguments are not both types (e.g. a one-argument
///   `Result<T>` alias).
fn extract_result_types(ty: &Type) -> Option<(Option<Type>, Type)> {
    let Type::Path(type_path) = ty else {
        return None;
    };
    let segment = type_path.path.segments.last()?;
    if segment.ident != "Result" {
        return None;
    }
    let PathArguments::AngleBracketed(generics) = &segment.arguments else {
        return None;
    };
    let mut type_args = generics.args.iter();
    let Some(GenericArgument::Type(ok_type)) = type_args.next() else {
        return None;
    };
    let Some(GenericArgument::Type(err_type)) = type_args.next() else {
        return None;
    };
    let ok_is_unit = matches!(ok_type, Type::Tuple(tuple) if tuple.elems.is_empty());
    let ok_type = (!ok_is_unit).then(|| ok_type.clone());
    Some((ok_type, err_type.clone()))
}
/// Determines the item type of a streaming method's declared return type.
///
/// - `impl Stream<Item = X> + ...` / `dyn Stream<Item = X> + ...`: returns `X`. The first
///   `Stream` bound naming an `Item` wins.
/// - `Pin<...>` / `Box<...>` wrapping a stream, e.g. `Pin<Box<dyn Stream<Item = X> + Send>>`:
///   the wrappers are looked through and `X` is returned.
/// - Any other path type with a leading type argument, `S<T, ...>`: assumed to be a stream
///   yielding `Reply<T>`, so `Reply<T>` is returned.
/// - Parenthesized and invisible-group types are unwrapped first.
///
/// Returns `None` when no item type can be determined.
pub(super) fn extract_stream_item_type(ty: &Type) -> Option<Type> {
    match ty {
        Type::ImplTrait(impl_trait) => stream_item_from_bounds(&impl_trait.bounds),
        Type::TraitObject(trait_object) => stream_item_from_bounds(&trait_object.bounds),
        Type::Paren(paren) => extract_stream_item_type(&paren.elem),
        Type::Group(group) => extract_stream_item_type(&group.elem),
        Type::Path(type_path) => {
            let last_segment = type_path.path.segments.last()?;
            let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
                return None;
            };
            let GenericArgument::Type(item_type) = args.args.first()? else {
                return None;
            };
            // `Pin`/`Box` are wrappers around a stream, not streams parameterized by their
            // reply type. Without this, `Pin<Box<dyn Stream<..>>>` would be misread as a
            // stream of `Reply<Box<dyn Stream<..>>>`. Fall back to the generic rule when no
            // stream is found inside.
            if last_segment.ident == "Pin" || last_segment.ident == "Box" {
                if let Some(inner_item) = extract_stream_item_type(item_type) {
                    return Some(inner_item);
                }
            }
            Some(syn::parse_quote!(Reply<#item_type>))
        }
        _ => None,
    }
}

/// Returns the `Item` type of the first `Stream<Item = ...>` trait bound in `bounds`.
fn stream_item_from_bounds(
    bounds: &syn::punctuated::Punctuated<syn::TypeParamBound, syn::Token![+]>,
) -> Option<Type> {
    bounds.iter().find_map(|bound| match bound {
        syn::TypeParamBound::Trait(trait_bound) => {
            extract_stream_item_from_trait_bound(trait_bound)
        }
        _ => None,
    })
}
/// For a trait bound whose last path segment is `Stream`, returns the type bound to its
/// `Item` associated type. Returns `None` for any other trait or when `Item` is not
/// specified.
fn extract_stream_item_from_trait_bound(trait_bound: &syn::TraitBound) -> Option<Type> {
    let segment = trait_bound.path.segments.last()?;
    if segment.ident != "Stream" {
        return None;
    }
    let PathArguments::AngleBracketed(generics) = &segment.arguments else {
        return None;
    };
    generics.args.iter().find_map(|generic| match generic {
        GenericArgument::AssocType(assoc) if assoc.ident == "Item" => Some(assoc.ty.clone()),
        _ => None,
    })
}
/// Matches `Reply<T>` or `Result<Reply<T>, E>`.
///
/// Returns `(T, None)` for the former and `(T, Some(E))` for the latter. Returns `None` for
/// any other shape, including `Result<(), E>`.
pub(super) fn extract_reply_or_result_reply(ty: &Type) -> Option<(Type, Option<Type>)> {
    match extract_reply_inner_type(ty) {
        Some(inner) => Some((inner, None)),
        None => {
            let (ok_ty, err_ty) = extract_result_types(ty)?;
            let reply_ty = ok_ty?;
            extract_reply_inner_type(&reply_ty).map(|inner| (inner, Some(err_ty)))
        }
    }
}
/// Returns `T` for a path type whose last segment is `Reply<T, ...>`. Returns `None`
/// otherwise.
fn extract_reply_inner_type(ty: &Type) -> Option<Type> {
    match ty {
        Type::Path(type_path) => {
            let segment = type_path
                .path
                .segments
                .last()
                .filter(|segment| segment.ident == "Reply")?;
            match &segment.arguments {
                PathArguments::AngleBracketed(generics) => match generics.args.first()? {
                    GenericArgument::Type(inner) => Some(inner.clone()),
                    _ => None,
                },
                _ => None,
            }
        }
        _ => None,
    }
}
/// Whether `ty` is written as the bare identifier `bool`. Qualified paths such as
/// `core::primitive::bool` are not recognized.
fn is_bool_type(ty: &Type) -> bool {
    matches!(ty, Type::Path(type_path) if type_path.path.is_ident("bool"))
}
/// Whether `ty` is the unit type `()`.
fn is_unit_type(ty: &Type) -> bool {
    matches!(ty, Type::Tuple(tuple) if tuple.elems.is_empty())
}
/// Returns the first element type of a tuple type. Returns `None` for non-tuple types and
/// for `()`. The tuple's length and remaining elements are not checked.
pub(super) fn extract_first_tuple_element(ty: &Type) -> Option<Type> {
    match ty {
        Type::Tuple(tuple) => tuple.elems.iter().next().cloned(),
        _ => None,
    }
}