use quote::quote;
use std::collections::BTreeSet;
use syn::{GenericParam, Generics, Type};
/// Heuristically reports whether `ty` mentions a generic type parameter, a lifetime,
/// or a const generic argument.
///
/// The struct's declared generics are not available here, so a bare, argument-free,
/// single-segment path (`T`, `Item`) is judged by its spelling via
/// [`is_likely_generic_name`]. Everything else is decided structurally:
///
/// * any explicit lifetime (`&'a str`, `Foo<'a>`, `dyn Trait + 'a`, `for<'a> fn()`)
///   or const argument (`Foo<3>`) counts;
/// * generic arguments (`Vec<T>`, `Fn(T) -> U`) and qualified-self types
///   (`<T as Trait>::Assoc`) are searched recursively;
/// * references, tuples, arrays, slices, raw pointers, bare function types,
///   parentheses, invisible groups (produced when a type is forwarded through a
///   `macro_rules!` `$t:ty` fragment), and trait-object / `impl Trait` bounds are
///   looked through.
///
/// Returns `false` for every other type form.
pub fn has_generics_or_lifetimes(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => {
            // `<T as Trait>::Assoc`: the qualified self type is stored outside `path`.
            let qself_is_generic = type_path
                .qself
                .as_ref()
                .is_some_and(|qself| has_generics_or_lifetimes(&qself.ty));
            // A bare name such as `T` carries no arguments; judge it by its spelling.
            let bare_name_is_generic = type_path.path.segments.len() == 1 && {
                let segment = &type_path.path.segments[0];
                segment.arguments.is_empty() && is_likely_generic_name(&segment.ident.to_string())
            };
            qself_is_generic
                || bare_name_is_generic
                || path_args_have_generics_or_lifetimes(&type_path.path)
        }
        Type::Reference(type_ref) => {
            type_ref.lifetime.is_some() || has_generics_or_lifetimes(&type_ref.elem)
        }
        Type::Tuple(type_tuple) => type_tuple.elems.iter().any(has_generics_or_lifetimes),
        Type::Array(type_array) => has_generics_or_lifetimes(&type_array.elem),
        Type::Slice(type_slice) => has_generics_or_lifetimes(&type_slice.elem),
        Type::Ptr(type_ptr) => has_generics_or_lifetimes(&type_ptr.elem),
        // Parentheses and invisible macro groups do not change the wrapped type.
        Type::Paren(type_paren) => has_generics_or_lifetimes(&type_paren.elem),
        Type::Group(type_group) => has_generics_or_lifetimes(&type_group.elem),
        Type::BareFn(bare_fn) => {
            bare_fn.lifetimes.is_some()
                || bare_fn.inputs.iter().any(|input| has_generics_or_lifetimes(&input.ty))
                || return_type_has_generics_or_lifetimes(&bare_fn.output)
        }
        Type::TraitObject(trait_object) => bounds_have_generics_or_lifetimes(&trait_object.bounds),
        Type::ImplTrait(impl_trait) => bounds_have_generics_or_lifetimes(&impl_trait.bounds),
        _ => false,
    }
}

/// Scans the generic arguments attached to every segment of `path`
/// (`Vec<T>`, `Fn(A) -> B`); the segment identifiers themselves are not judged.
fn path_args_have_generics_or_lifetimes(path: &syn::Path) -> bool {
    path.segments.iter().any(|segment| match &segment.arguments {
        syn::PathArguments::AngleBracketed(args) => args.args.iter().any(|arg| match arg {
            syn::GenericArgument::Type(inner_ty) => has_generics_or_lifetimes(inner_ty),
            syn::GenericArgument::Lifetime(_) | syn::GenericArgument::Const(_) => true,
            _ => false,
        }),
        syn::PathArguments::Parenthesized(args) => {
            args.inputs.iter().any(has_generics_or_lifetimes)
                || return_type_has_generics_or_lifetimes(&args.output)
        }
        // `PathArguments::None`: nothing to inspect.
        _ => false,
    })
}

/// Checks trait-object / `impl Trait` bounds: a lifetime bound counts, and a trait
/// bound counts when its path arguments do (`dyn Fn(T)`, `dyn Trait<'a>`).
fn bounds_have_generics_or_lifetimes<'b>(
    bounds: impl IntoIterator<Item = &'b syn::TypeParamBound>,
) -> bool {
    bounds.into_iter().any(|bound| {
        if let syn::TypeParamBound::Trait(trait_bound) = bound {
            path_args_have_generics_or_lifetimes(&trait_bound.path)
        } else {
            matches!(bound, syn::TypeParamBound::Lifetime(_))
        }
    })
}

/// Checks an explicit `-> Ty` return type; a default (`()`) return never counts.
fn return_type_has_generics_or_lifetimes(output: &syn::ReturnType) -> bool {
    if let syn::ReturnType::Type(_, output_ty) = output {
        has_generics_or_lifetimes(output_ty)
    } else {
        false
    }
}
/// Guesses whether a bare, argument-free type name is a generic parameter rather
/// than a concrete type.
///
/// A name qualifies when all of the following hold:
/// * it starts with an uppercase character;
/// * it is at most 8 bytes long, since parameter names are conventionally short
///   (`T`, `Key`, `Item`);
/// * it is not the `Self` keyword or one of a few well-known concrete std types.
fn is_likely_generic_name(name: &str) -> bool {
    // Only names of at most 8 bytes ever reach this list, so longer std names
    // (e.g. `AtomicBool`) need not be listed. `Self` is a keyword and can never be
    // a declared generic parameter.
    const KNOWN_CONCRETE: &[&str] = &[
        "Self", "String", "Vec", "HashMap", "HashSet", "BTreeMap", "BTreeSet", "Option",
        "Result", "Box", "Arc", "Rc", "Cell", "RefCell", "Mutex", "RwLock", "PathBuf",
        "OsString", "CString", "Duration", "Instant",
    ];
    name.len() <= 8 && name.starts_with(char::is_uppercase) && !KNOWN_CONCRETE.contains(&name)
}
/// Returns the names of the type and const parameters declared in `generics`.
///
/// Lifetime parameters are skipped. Because the result is a `BTreeSet`, the
/// names come back sorted and de-duplicated.
pub fn collect_declared_generic_names(generics: &Generics) -> BTreeSet<String> {
    let mut names = BTreeSet::new();
    for param in &generics.params {
        let ident = match param {
            GenericParam::Type(type_param) => &type_param.ident,
            GenericParam::Const(const_param) => &const_param.ident,
            GenericParam::Lifetime(_) => continue,
        };
        names.insert(ident.to_string());
    }
    names
}
pub fn transform_type_for_phantom_data(
ty: &Type,
declared_generics: &BTreeSet<String>,
) -> proc_macro2::TokenStream {
match ty {
Type::Reference(type_ref) => {
let lifetime = &type_ref.lifetime;
match type_ref.elem.as_ref() {
Type::Path(inner_path) if inner_path.path.segments.len() == 1 => {
let inner_ident = &inner_path.path.segments[0].ident;
if declared_generics.contains(&inner_ident.to_string()) {
quote! { #ty }
} else {
if let Some(lifetime) = lifetime {
quote! { &#lifetime () }
} else {
quote! { &() }
}
}
}
_ => {
if let Some(lifetime) = lifetime {
quote! { &#lifetime () }
} else {
quote! { &() }
}
}
}
}
Type::Path(type_path) => {
if type_path.path.segments.len() == 1 && type_path.path.segments[0].arguments.is_empty()
{
let ident = &type_path.path.segments[0].ident;
if declared_generics.contains(&ident.to_string()) {
quote! { #ident }
} else {
quote! { () }
}
} else {
quote! { () }
}
}
Type::Tuple(_) => {
quote! { () }
}
Type::Array(_) => {
quote! { () }
}
Type::Slice(_) => {
quote! { () }
}
_ => {
quote! { () }
}
}
}
/// Builds the `::core::marker::PhantomData<(...)>` type that mentions every
/// generic parameter of a struct.
///
/// The tuple contains, in order:
/// 1. one entry per field type, produced by [`transform_type_for_phantom_data`];
/// 2. for each declared parameter, in declaration order:
///    * a type parameter `T` is added as-is;
///    * a lifetime parameter `'a` is added as `&'a ()`.
///
/// The lifetime entries matter because a lifetime that only appears nested inside a
/// field type (e.g. `Cow<'a, str>`) is reduced to `()` by the per-field transform.
/// If this marker is the only place a generated item mentions `'a`, that item would
/// otherwise fail with E0392 (unused lifetime parameter).
///
/// Const parameters are skipped, because Rust does not require them to be used.
/// Returns an empty token stream when the tuple would be empty.
pub fn generate_phantom_data_type<'a>(
    field_types: impl Iterator<Item = &'a Type>,
    struct_generics: &Generics,
) -> proc_macro2::TokenStream {
    let declared_generics = collect_declared_generic_names(struct_generics);
    let mut phantom_types: Vec<proc_macro2::TokenStream> = field_types
        .map(|field_type| transform_type_for_phantom_data(field_type, &declared_generics))
        .collect();
    for param in &struct_generics.params {
        match param {
            GenericParam::Type(type_param) => {
                let ident = &type_param.ident;
                phantom_types.push(quote! { #ident });
            }
            GenericParam::Lifetime(lifetime_param) => {
                // `&'a ()` uses `'a` covariantly without adding any other constraint.
                let lifetime = &lifetime_param.lifetime;
                phantom_types.push(quote! { & #lifetime () });
            }
            GenericParam::Const(_) => {}
        }
    }
    if phantom_types.is_empty() {
        quote! {}
    } else {
        quote! { ::core::marker::PhantomData<( #(#phantom_types),* )> }
    }
}
pub fn needs_phantom_data<'a>(
struct_generics: &Generics,
field_types: impl Iterator<Item = &'a Type>,
) -> bool {
if !struct_generics.params.is_empty() {
return true;
}
for field_type in field_types {
if has_generics_or_lifetimes(field_type) {
return true;
}
}
false
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_has_generics_or_lifetimes() {
        let with_generics: [Type; 4] = [
            parse_quote!(Vec<T>),
            parse_quote!(&'a str),
            parse_quote!(T),
            parse_quote!(HashMap<K, V>),
        ];
        for ty in &with_generics {
            assert!(has_generics_or_lifetimes(ty));
        }

        let concrete: [Type; 2] = [parse_quote!(String), parse_quote!(i32)];
        for ty in &concrete {
            assert!(!has_generics_or_lifetimes(ty));
        }
    }

    #[test]
    fn test_is_likely_generic_name() {
        for name in ["T", "U", "K", "Item", "Key", "Value"] {
            assert!(is_likely_generic_name(name), "{name} should look generic");
        }
        for name in ["String", "Vec", "HashMap", "field", "value"] {
            assert!(!is_likely_generic_name(name), "{name} should not look generic");
        }
    }

    #[test]
    fn test_collect_declared_generic_names() {
        let names_of = |generics: Generics| collect_declared_generic_names(&generics);

        let two_types = names_of(parse_quote!(<T, U>));
        assert!(two_types.contains("T"));
        assert!(two_types.contains("U"));
        assert_eq!(two_types.len(), 2);

        let mixed = names_of(parse_quote!(<'a, T: Clone, const N: usize>));
        assert!(mixed.contains("T"));
        assert!(mixed.contains("N"));
        assert!(!mixed.contains("a"));
        assert_eq!(mixed.len(), 2);

        assert!(names_of(parse_quote!()).is_empty());
    }

    #[test]
    fn test_transform_type_for_phantom_data() {
        let generics: Generics = parse_quote!(<T>);
        let declared = collect_declared_generic_names(&generics);
        let cases: [(Type, &str); 4] = [
            (parse_quote!(T), "T"),
            (parse_quote!(String), "()"),
            (parse_quote!(&'a T), "& 'a T"),
            (parse_quote!(&'a str), "& 'a ()"),
        ];
        for (ty, expected) in &cases {
            let rendered = transform_type_for_phantom_data(ty, &declared).to_string();
            assert_eq!(rendered, *expected);
        }
    }

    #[test]
    fn test_generate_phantom_data_type() {
        let generics: Generics = parse_quote!(<T, U>);
        let field_types: [Type; 3] =
            [parse_quote!(T), parse_quote!(String), parse_quote!(Vec<U>)];
        let rendered = generate_phantom_data_type(field_types.iter(), &generics).to_string();
        for needle in ["PhantomData", "T", "U"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_needs_phantom_data() {
        let generic: Generics = parse_quote!(<T>);
        let generic_fields: [Type; 1] = [parse_quote!(T)];
        assert!(needs_phantom_data(&generic, generic_fields.iter()));

        let plain: Generics = parse_quote!();
        let concrete_fields: [Type; 1] = [parse_quote!(String)];
        assert!(!needs_phantom_data(&plain, concrete_fields.iter()));

        let no_fields: [Type; 0] = [];
        assert!(!needs_phantom_data(&plain, no_fields.iter()));
    }
}