use alloc::vec::Vec;
use quote::{ToTokens, quote};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::{Comma, Mut};
use syn::{
Error, FnArg, GenericArgument, Ident, ItemTrait, PathArguments, PathSegment, ReturnType,
Signature, TraitItem, TraitItemFn, TraitItemType, Type, TypePath, parse_quote, parse2,
};
use crate::derive_getter::getter_field::GetterField;
use crate::derive_getter::{FieldMode, ReceiverMode};
use crate::replace_self::replace_self_type;
pub fn parse_getter_fields(
context_type: &Ident,
consumer_trait: &ItemTrait,
) -> syn::Result<(Vec<GetterField>, Option<TraitItemType>)> {
let mut fields = Vec::new();
let mut field_assoc_type: Option<TraitItemType> = None;
for item in consumer_trait.items.iter() {
if let TraitItem::Type(item_type) = item {
if field_assoc_type.is_some() {
return Err(Error::new(
item_type.span(),
"at most one associated type is allowed in getter trait",
));
}
if !item_type.generics.params.is_empty() {
return Err(Error::new(
item_type.generics.params.span(),
"associated type in getter trait must not contain generic params",
));
}
field_assoc_type = Some(item_type.clone());
}
}
for item in consumer_trait.items.iter() {
match item {
TraitItem::Fn(method) => {
let getter_spec = parse_getter_method(
context_type,
method,
&field_assoc_type.as_ref().map(|item| item.ident.clone()),
)?;
fields.push(getter_spec);
}
TraitItem::Type(_) => {
}
_ => {
return Err(Error::new(
item.span(),
"getter trait can only contain getter methods",
));
}
}
}
match (&field_assoc_type, fields.first(), fields.len()) {
(None, _, _) => {}
(Some(field_assoc_type), Some(field), 1) => {
let field_assoc_type_ident = &field_assoc_type.ident;
let field_type = &field.field_type;
if field_type != &parse_quote! { Self :: #field_assoc_type_ident }
&& field_type != &parse_quote! { #context_type :: #field_assoc_type_ident }
{
return Err(Error::new(
field.field_type.span(),
"getter method return type must match the associated type",
));
}
}
_ => {
return Err(Error::new(
consumer_trait.span(),
"if associated type is defined, exactly one getter method must be defined",
));
}
}
Ok((fields, field_assoc_type))
}
pub fn parse_getter_method(
context_type: &Ident,
method: &TraitItemFn,
field_assoc_type: &Option<Ident>,
) -> syn::Result<GetterField> {
let signature = &method.sig;
validate_getter_method_signature(signature)?;
let field_name = signature.ident.clone();
let (arg, phantom) = parse_method_args(&signature.inputs)?;
let (receiver_mode, receiver_mut) = parse_receiver(context_type, arg)?;
let return_type = parse_return_type(context_type, &signature.output, field_assoc_type)?;
let (field_type, field_mode) = parse_field_type(&return_type, &receiver_mut)?;
Ok(GetterField {
receiver_mode,
field_name,
field_type,
return_type,
receiver_mut,
phantom_arg_type: phantom,
field_mode,
})
}
pub fn validate_getter_method_signature(signature: &Signature) -> syn::Result<()> {
if signature.constness.is_some() {
return Err(Error::new(
signature.constness.span(),
"getter method must not be const fn",
));
}
if signature.asyncness.is_some() {
return Err(Error::new(
signature.asyncness.span(),
"getter method must not be async fn",
));
}
if signature.unsafety.is_some() {
return Err(Error::new(
signature.unsafety.span(),
"getter method must not be unsafe fn",
));
}
if !signature.generics.params.is_empty() {
return Err(Error::new(
signature.generics.params.span(),
"getter method must not contain generic param",
));
}
if signature.generics.where_clause.is_some() {
return Err(Error::new(
signature.generics.where_clause.span(),
"getter method must not contain where clause",
));
}
Ok(())
}
fn parse_method_args(args: &Punctuated<FnArg, Comma>) -> syn::Result<(&FnArg, Option<Type>)> {
let args_count = args.len();
if args_count == 1 {
let [arg] = parse_fixed_size_args::<1>(args)?;
Ok((arg, None))
} else if args_count == 2 {
let [arg, phantom_arg] = parse_fixed_size_args::<2>(args)?;
let phantom_arg_type = parse_phantom_arg_type(phantom_arg)?;
Ok((arg, Some(phantom_arg_type)))
} else {
Err(Error::new(
args.span(),
"getter method must contain exactly one `&self` argument",
))
}
}
fn parse_fixed_size_args<const I: usize>(
args: &Punctuated<FnArg, Comma>,
) -> syn::Result<[&FnArg; I]> {
args.iter()
.collect::<Vec<&FnArg>>()
.try_into()
.map_err(|_| Error::new(args.span(), "expect getter method to contain {I} arguments"))
}
fn parse_phantom_arg_type(phantom_arg: &FnArg) -> syn::Result<Type> {
match phantom_arg {
FnArg::Typed(phantom_type) => match phantom_type.ty.as_ref() {
Type::Path(type_path) => {
let segment = parse_single_segment_type_path(type_path)?;
try_parse_phantom_arg_type_path(segment).ok_or_else(|| {
Error::new(
phantom_type.span(),
"only PhantomData is allowed as second argument",
)
})
}
_ => Err(Error::new(
phantom_type.span(),
"only PhantomData is allowed as second argument",
)),
},
_ => Err(Error::new(
phantom_arg.span(),
"optional second argument in getter must be PhantomData",
)),
}
}
fn parse_receiver(context_ident: &Ident, arg: &FnArg) -> syn::Result<(ReceiverMode, Option<Mut>)> {
match arg {
FnArg::Receiver(receiver) => {
if receiver.reference.is_none() {
Err(Error::new(
receiver.span(),
"first argument to getter method must be a reference to self, i.e. `&self`",
))
} else {
Ok((ReceiverMode::SelfReceiver, receiver.mutability))
}
}
FnArg::Typed(arg) => match arg.ty.as_ref() {
Type::Reference(ty) => {
let receiver = parse2(replace_self_type(
ty.elem.to_token_stream(),
context_ident.to_token_stream(),
&Vec::new(),
))?;
Ok((ReceiverMode::Type(receiver), ty.mutability))
}
_ => Err(Error::new(
arg.span(),
"first argument to getter method must be a reference",
)),
},
}
}
fn parse_return_type(
context_type: &Ident,
return_type: &ReturnType,
field_assoc_type: &Option<Ident>,
) -> syn::Result<Type> {
match return_type {
ReturnType::Type(_, ty) => parse2(replace_self_type(
ty.to_token_stream(),
context_type.to_token_stream(),
&field_assoc_type.iter().cloned().collect::<Vec<_>>(),
)),
_ => Err(Error::new(
return_type.span(),
"return type must be specified",
)),
}
}
pub fn parse_field_type(
return_type: &Type,
receiver_mut: &Option<Mut>,
) -> syn::Result<(Type, FieldMode)> {
match &return_type {
Type::Reference(type_ref) => {
if type_ref.mutability.is_some() && receiver_mut.is_none() {
return Err(Error::new(
type_ref.span(),
format!(
"&mut self is required for mutable field reference `{}`",
type_ref.to_token_stream()
),
));
}
if type_ref.elem.as_ref() == &parse_quote! { str } {
let field_type: Type = parse_quote! { String };
Ok((field_type, FieldMode::Str))
} else if let (Type::Slice(slice), None) = (type_ref.elem.as_ref(), receiver_mut) {
let field_type = slice.elem.as_ref().clone();
Ok((field_type, FieldMode::Slice))
} else {
let field_type = type_ref.elem.as_ref().clone();
Ok((field_type, FieldMode::Reference))
}
}
Type::Path(type_path) => {
if let Some(field_type) = try_parse_option_ref(type_path) {
Ok((
parse2(quote! { Option< #field_type > })?,
FieldMode::OptionRef,
))
} else if let (Some(field_type), None) = (try_parse_mref(type_path), receiver_mut) {
Ok((field_type.clone(), FieldMode::MRef))
} else {
Ok((return_type.clone(), FieldMode::Copy))
}
}
_ => Err(Error::new(
return_type.span(),
"return type must be a reference",
)),
}
}
fn parse_single_segment_type_path(type_path: &TypePath) -> syn::Result<&PathSegment> {
let [segment]: [&PathSegment; 1] = type_path
.path
.segments
.iter()
.collect::<Vec<_>>()
.try_into()
.map_err(|_| {
Error::new(
type_path.span(),
"type path must contain exactly one path segment",
)
})?;
Ok(segment)
}
fn try_parse_phantom_arg_type_path(segment: &PathSegment) -> Option<Type> {
if segment.ident == "PhantomData"
&& let PathArguments::AngleBracketed(args) = &segment.arguments
&& let Some(GenericArgument::Type(ty)) = args.args.first()
{
return Some(ty.clone());
}
None
}
fn try_parse_option_ref(type_path: &TypePath) -> Option<&Type> {
let segment = parse_single_segment_type_path(type_path).ok()?;
if segment.ident == "Option"
&& let PathArguments::AngleBracketed(args) = &segment.arguments
{
let [arg] = Vec::from_iter(args.args.iter()).try_into().ok()?;
if let GenericArgument::Type(Type::Reference(type_ref)) = arg {
return Some(type_ref.elem.as_ref());
}
}
None
}
fn try_parse_mref(type_path: &TypePath) -> Option<&Type> {
let segment = parse_single_segment_type_path(type_path).ok()?;
if segment.ident == "MRef"
&& let PathArguments::AngleBracketed(args) = &segment.arguments
{
let [arg1, arg2] = Vec::from_iter(args.args.iter()).try_into().ok()?;
if let (GenericArgument::Lifetime(_), GenericArgument::Type(ty)) = (arg1, arg2) {
return Some(ty);
}
}
None
}