use syn::{
AngleBracketedGenericArguments, GenericArgument, PathArguments, Type, TypePath, parse_str,
};
/// Parses `expected_type` as Rust type syntax and structurally compares it with `syn_type`.
///
/// A string that does not parse as a type is never compatible, so this returns `false`
/// instead of propagating the parse error.
fn is_syn_type_compatible(expected_type: &str, syn_type: &Type) -> bool {
    parse_str::<Type>(expected_type).is_ok_and(|parsed| compare_types(&parsed, syn_type))
}
/// Test helper: checks `syn_type` against the compiler-reported name of `T`.
///
/// The name comes from `std::any::type_name`, whose exact format std does not guarantee
/// (it is typically fully qualified, e.g. `alloc::vec::Vec<u32>`).
#[cfg(test)]
fn is_type_compatible_ty<T>(syn_type: &Type) -> bool {
    let rust_name = std::any::type_name::<T>();
    is_syn_type_compatible(rust_name, syn_type)
}
/// Test helper: parses `type_str` as a type and checks it against `T`.
///
/// Panics if `type_str` is not valid type syntax.
#[cfg(test)]
fn is_type_compatible<T>(type_str: &str) -> bool {
    is_type_compatible_ty::<T>(&parse_str::<Type>(type_str).unwrap())
}
/// Something that describes an expected type and can be checked against a parsed `syn` type.
pub(crate) trait IsTypeCompatible {
    /// Returns `true` when `syn_type` is compatible with the type described by `self`.
    fn is_type_compatible(&self, syn_type: &Type) -> bool;
}
/// The string is read as Rust type syntax (e.g. `"Vec<u32>"`); a string that does not
/// parse as a type is never compatible.
impl IsTypeCompatible for str {
    fn is_type_compatible(&self, syn_type: &Type) -> bool {
        let expected: &str = self;
        is_syn_type_compatible(expected, syn_type)
    }
}
/// Structurally compares two types. Paths match when their final segments (and any generic
/// arguments) match, so `alloc::vec::Vec<u32>` is compatible with `Vec<u32>`.
///
/// Invisible-delimiter groups (`Type::Group`, e.g. a type passed through a `macro_rules!`
/// `$t:ty` fragment) and parentheses (`Type::Paren`) are ignored on both sides.
/// References and raw pointers must agree on mutability; reference lifetimes are ignored.
/// Slices and tuples are compared element-wise. Every other type form (arrays, trait objects,
/// `impl Trait`, fn pointers, ...) is reported as incompatible.
fn compare_types(a: &Type, b: &Type) -> bool {
    match (strip_type_grouping(a), strip_type_grouping(b)) {
        (Type::Path(a_path), Type::Path(b_path)) => compare_paths(a_path, b_path),
        (Type::Reference(a_ref), Type::Reference(b_ref)) => {
            // `&T` and `&mut T` are distinct types; only the lifetime may differ.
            a_ref.mutability.is_some() == b_ref.mutability.is_some()
                && compare_types(&a_ref.elem, &b_ref.elem)
        }
        (Type::Ptr(a_ptr), Type::Ptr(b_ptr)) => {
            // `*const T` and `*mut T` are distinct types.
            a_ptr.mutability.is_some() == b_ptr.mutability.is_some()
                && compare_types(&a_ptr.elem, &b_ptr.elem)
        }
        (Type::Slice(a_slice), Type::Slice(b_slice)) => {
            compare_types(&a_slice.elem, &b_slice.elem)
        }
        (Type::Tuple(a_tuple), Type::Tuple(b_tuple)) => {
            a_tuple.elems.len() == b_tuple.elems.len()
                && a_tuple
                    .elems
                    .iter()
                    .zip(b_tuple.elems.iter())
                    .all(|(a_elem, b_elem)| compare_types(a_elem, b_elem))
        }
        _ => false,
    }
}

/// Peels off any number of invisible-group (`Type::Group`) and parenthesized (`Type::Paren`)
/// wrappers, returning the innermost type.
fn strip_type_grouping(ty: &Type) -> &Type {
    match ty {
        Type::Group(group) => strip_type_grouping(&group.elem),
        Type::Paren(paren) => strip_type_grouping(&paren.elem),
        other => other,
    }
}
/// Compares two type paths by their last segment only, so module prefixes are ignored.
///
/// The segment identifiers must be equal, and either both segments have no arguments or
/// both have angle-bracketed arguments that `compare_generic_arguments` accepts.
/// Parenthesized arguments (`Fn(A) -> B`) or a missing segment make the paths incompatible.
fn compare_paths(a: &TypePath, b: &TypePath) -> bool {
    let (Some(lhs), Some(rhs)) = (a.path.segments.last(), b.path.segments.last()) else {
        return false;
    };
    if lhs.ident != rhs.ident {
        return false;
    }
    match (&lhs.arguments, &rhs.arguments) {
        (PathArguments::None, PathArguments::None) => true,
        (PathArguments::AngleBracketed(lhs_args), PathArguments::AngleBracketed(rhs_args)) => {
            compare_generic_arguments(lhs_args, rhs_args)
        }
        _ => false,
    }
}
fn compare_generic_arguments(
a: &AngleBracketedGenericArguments,
b: &AngleBracketedGenericArguments,
) -> bool {
if a.args.len() != b.args.len() {
return false;
}
for (a_arg, b_arg) in a.args.iter().zip(b.args.iter()) {
match (a_arg, b_arg) {
(GenericArgument::Type(a_ty), GenericArgument::Type(b_ty)) => {
if !compare_types(a_ty, b_ty) {
return false;
}
}
_ => continue, }
}
true
}
/// Returns `true` when `ty` is a reference whose referent is the `str` type
/// (`&str`, `&'a str`, `&mut str`).
///
/// The referent is identified by the *last* path segment, like the path comparison used for
/// type compatibility, so fully qualified spellings such as `&core::primitive::str` are
/// recognized as well. Non-reference types (including bare `str`) return `false`.
pub(crate) fn is_str_ref(ty: &Type) -> bool {
    if let Type::Reference(type_ref) = ty
        && let Type::Path(type_path) = &*type_ref.elem
        && let Some(segment) = type_path.path.segments.last()
    {
        return segment.ident == "str";
    }
    false
}
#[cfg(test)]
mod test {
    use super::{is_str_ref, is_type_compatible};
    use chrono::{DateTime, Utc};
    use syn::Type;

    /// Parses a type literal used by these tests; a parse failure is a bug in the test itself.
    fn parse(src: &str) -> Type {
        syn::parse_str::<Type>(src).expect("test type literal should parse")
    }

    #[test]
    fn test_is_str_ref() {
        assert!(is_str_ref(&parse("&str")));
        assert!(is_str_ref(&parse("&'a str")));
        assert!(!is_str_ref(&parse("str")));
        assert!(!is_str_ref(&parse("String")));
        assert!(!is_str_ref(&parse("&String")));
    }

    #[test]
    fn test_is_type_compatible() {
        assert!(is_type_compatible::<u32>("u32"));
        assert!(is_type_compatible::<Vec<u32>>("Vec<u32>"));
        assert!(is_type_compatible::<&Vec<&u32>>("&Vec<&u32>"));
        assert!(is_type_compatible::<Vec<u32>>("alloc::vec::Vec<u32>"));
        assert!(is_type_compatible::<&str>("&str"));
        assert!(is_type_compatible::<DateTime<Utc>>("DateTime<Utc>"));
        assert!(is_type_compatible::<DateTime<Utc>>(
            "chrono::datetime::DateTime<chrono::offset::utc::Utc>"
        ));
        assert!(is_type_compatible::<DateTime<Utc>>(
            "chrono::DateTime<chrono::Utc>"
        ));
    }

    #[test]
    fn test_is_type_incompatible() {
        assert!(!is_type_compatible::<u32>("u64"));
        assert!(!is_type_compatible::<u32>("&u32"));
        assert!(!is_type_compatible::<&u32>("u32"));
        assert!(!is_type_compatible::<Vec<u32>>("Vec<u64>"));
        assert!(!is_type_compatible::<Vec<u32>>("Vec"));
        assert!(!is_type_compatible::<Vec<u32>>("Option<u32>"));
    }
}