use crate::analysis::field_analysis::FieldInfo;
use crate::attributes::{parse_struct_attributes, StructAttributes};
use crate::utils::generics::needs_phantom_data;
use quote::quote;
use syn::{DeriveInput, Fields, Generics, Type};
/// Everything this module extracts from a derive input before code generation.
///
/// Values are built by [`analyze_struct`] and exposed read-only through the
/// accessor methods; the fields themselves stay private.
#[derive(Debug, Clone)]
pub struct StructAnalysis {
    /// Identifier of the analysed struct (cloned from `DeriveInput::ident`).
    struct_name: syn::Ident,
    /// Visibility of the analysed struct (cloned from `DeriveInput::vis`).
    struct_visibility: syn::Visibility,
    /// Generic parameters and where clause of the analysed struct.
    struct_generics: Generics,
    /// Struct-level attributes as returned by `parse_struct_attributes`.
    struct_attributes: StructAttributes,
    /// Fields whose `FieldInfo::is_required()` returned `true`, in declaration order.
    required_fields: Vec<FieldInfo>,
    /// All remaining fields, in declaration order.
    optional_fields: Vec<FieldInfo>,
}
impl StructAnalysis {
    /// Builds the analysis from a parsed derive input.
    ///
    /// Struct-level attributes are parsed first, then the named fields are
    /// extracted and split into required and optional groups.
    ///
    /// # Errors
    ///
    /// Propagates the first error reported by `parse_struct_attributes`,
    /// `extract_named_fields` or `parse_fields`, in that order.
    fn from_derive_input(input: &DeriveInput) -> syn::Result<Self> {
        // Attribute parsing stays ahead of field extraction so that its
        // errors are the ones reported when both would fail.
        let struct_attributes = parse_struct_attributes(&input.attrs)?;
        let named = extract_named_fields(input)?;
        let (required_fields, optional_fields) = parse_fields(named)?;

        Ok(Self {
            struct_name: input.ident.clone(),
            struct_visibility: input.vis.clone(),
            struct_generics: input.generics.clone(),
            struct_attributes,
            required_fields,
            optional_fields,
        })
    }

    /// Identifier of the analysed struct.
    pub fn struct_name(&self) -> &syn::Ident {
        &self.struct_name
    }

    /// Visibility of the analysed struct.
    pub fn struct_visibility(&self) -> &syn::Visibility {
        &self.struct_visibility
    }

    /// Generics (parameters and where clause) of the analysed struct.
    pub fn struct_generics(&self) -> &Generics {
        &self.struct_generics
    }

    /// Parsed struct-level attributes.
    pub fn struct_attributes(&self) -> &StructAttributes {
        &self.struct_attributes
    }

    /// Fields classified as required, in declaration order.
    pub fn required_fields(&self) -> &[FieldInfo] {
        &self.required_fields
    }

    /// Fields classified as optional, in declaration order.
    pub fn optional_fields(&self) -> &[FieldInfo] {
        &self.optional_fields
    }

    /// Iterates over every field: all required fields first, then all optional ones.
    pub fn all_fields(&self) -> impl Iterator<Item = &FieldInfo> {
        let required = self.required_fields.iter();
        required.chain(&self.optional_fields)
    }

    /// Returns the first required field whose attributes have `builder_method`
    /// set, or `None` when no required field has it.
    ///
    /// Optional fields are not consulted.
    pub fn builder_method_field(&self) -> Option<&FieldInfo> {
        for candidate in &self.required_fields {
            if candidate.attributes().builder_method {
                return Some(candidate);
            }
        }
        None
    }

    /// `true` when no field was classified as required.
    pub fn has_only_optional_fields(&self) -> bool {
        self.required_fields().is_empty()
    }

    /// Asks `crate::utils::generics::needs_phantom_data` whether the struct's
    /// generics, checked against the types of all fields, call for a
    /// `PhantomData` marker.
    pub fn needs_phantom_data(&self) -> bool {
        let generics = &self.struct_generics;
        let types_of_fields: Vec<&Type> = self
            .all_fields()
            .map(|field| field.field_type())
            .collect();
        needs_phantom_data(generics, types_of_fields.iter().copied())
    }

    /// Tokens for the `impl<...>` generics portion of `split_for_impl`.
    pub fn impl_generics_tokens(&self) -> proc_macro2::TokenStream {
        let for_impl = self.struct_generics.split_for_impl().0;
        quote!(#for_impl)
    }

    /// Tokens for the type-position generics portion of `split_for_impl`.
    pub fn type_generics_tokens(&self) -> proc_macro2::TokenStream {
        let for_type = self.struct_generics.split_for_impl().1;
        quote!(#for_type)
    }

    /// Tokens for the where clause; empty when the struct has none.
    pub fn where_clause_tokens(&self) -> proc_macro2::TokenStream {
        let maybe_where = self.struct_generics.split_for_impl().2;
        quote!(#maybe_where)
    }

    /// Runs `crate::analysis::validation::validate_struct_for_generation` on
    /// this analysis and returns its result.
    pub fn validate_for_generation(&self) -> syn::Result<()> {
        crate::analysis::validation::validate_struct_for_generation(self)
    }
}
/// Public entry point: analyses a derive input and returns the resulting
/// [`StructAnalysis`].
///
/// # Errors
///
/// Returns whatever error `StructAnalysis::from_derive_input` produces, for
/// example when the input is not a struct with named fields.
pub fn analyze_struct(input: &DeriveInput) -> syn::Result<StructAnalysis> {
    StructAnalysis::from_derive_input(input)
}
fn extract_named_fields(input: &DeriveInput) -> syn::Result<&syn::FieldsNamed> {
match &input.data {
syn::Data::Struct(data_struct) => match &data_struct.fields {
Fields::Named(fields_named) => Ok(fields_named),
Fields::Unnamed(_) => Err(syn::Error::new_spanned(
input,
"TypeStateBuilder only supports structs with named fields",
)),
Fields::Unit => Err(syn::Error::new_spanned(
input,
"TypeStateBuilder does not support unit structs",
)),
},
syn::Data::Enum(_) => Err(syn::Error::new_spanned(
input,
"TypeStateBuilder only supports structs, not enums",
)),
syn::Data::Union(_) => Err(syn::Error::new_spanned(
input,
"TypeStateBuilder only supports structs, not unions",
)),
}
}
fn parse_fields(fields_named: &syn::FieldsNamed) -> syn::Result<(Vec<FieldInfo>, Vec<FieldInfo>)> {
let mut required_fields = Vec::new();
let mut optional_fields = Vec::new();
for field in &fields_named.named {
let field_name = field
.ident
.as_ref()
.ok_or_else(|| syn::Error::new_spanned(field, "Field must have a name"))?
.clone();
let field_info = FieldInfo::from_syn_field(field_name, field.ty.clone(), &field.attrs)?;
if field_info.is_required() {
required_fields.push(field_info);
} else {
optional_fields.push(field_info);
}
}
Ok((required_fields, optional_fields))
}
#[cfg(test)]
mod tests {
    use super::*;
    use quote::ToTokens;
    use syn::parse_quote;

    /// Message of the error `extract_named_fields` reports for `input`, or
    /// `None` when the input is accepted.
    fn rejection_message(input: &DeriveInput) -> Option<String> {
        extract_named_fields(input).err().map(|e| e.to_string())
    }

    /// One required and one optional field are sorted into the right groups.
    #[test]
    fn test_analyze_simple_struct() {
        let input: DeriveInput = parse_quote! {
            struct Example {
                #[builder(required)]
                name: String,
                age: Option<u32>,
            }
        };
        let analysis = analyze_struct(&input).unwrap();
        assert_eq!(analysis.struct_name().to_string(), "Example");
        assert_eq!(analysis.required_fields().len(), 1);
        assert_eq!(analysis.optional_fields().len(), 1);
        assert!(analysis.struct_generics().params.is_empty());
        assert!(!analysis.has_only_optional_fields());
    }

    /// Type parameters are kept and generic field types are visible.
    #[test]
    fn test_analyze_generic_struct() {
        let input: DeriveInput = parse_quote! {
            struct Example<T, U> {
                value: T,
                #[builder(required)]
                name: String,
            }
        };
        let analysis = analyze_struct(&input).unwrap();
        assert!(!analysis.struct_generics().params.is_empty());
        assert_eq!(analysis.struct_generics().params.len(), 2);
        assert!(analysis
            .all_fields()
            .any(|f| f.field_type().to_token_stream().to_string() == "T"));
    }

    /// Lifetime parameters are kept and appear in the recorded field types.
    #[test]
    fn test_analyze_struct_with_lifetimes() {
        let input: DeriveInput = parse_quote! {
            struct Example<'a, 'b> {
                text: &'a str,
                #[builder(required)]
                name: String,
            }
        };
        let analysis = analyze_struct(&input).unwrap();
        assert!(!analysis.struct_generics().params.is_empty());
        assert!(analysis.all_fields().any(|f| f
            .field_type()
            .to_token_stream()
            .to_string()
            .contains("'a")));
    }

    /// A struct-level `build_method` attribute is reflected in the analysis.
    #[test]
    fn test_analyze_struct_with_custom_build_method() {
        let input: DeriveInput = parse_quote! {
            #[builder(build_method = "create")]
            struct Example {
                name: String,
            }
        };
        let analysis = analyze_struct(&input).unwrap();
        assert_eq!(
            analysis.struct_attributes().get_build_method_name(),
            "create"
        );
    }

    /// Fields without `#[builder(required)]` all land in the optional group.
    #[test]
    fn test_analyze_all_optional_fields() {
        let input: DeriveInput = parse_quote! {
            struct Example {
                name: Option<String>,
                age: u32,
            }
        };
        let analysis = analyze_struct(&input).unwrap();
        assert!(analysis.has_only_optional_fields());
        assert_eq!(analysis.required_fields().len(), 0);
        assert_eq!(analysis.optional_fields().len(), 2);
    }

    /// Every unsupported input shape is rejected with its own message.
    ///
    /// Checking the message (not just `is_err`) ties each case to the branch
    /// that is supposed to reject it; the union case covers `Data::Union`.
    #[test]
    fn test_extract_named_fields_errors() {
        let tuple_input: DeriveInput = parse_quote!(
            struct Example(String, i32);
        );
        assert_eq!(
            rejection_message(&tuple_input).as_deref(),
            Some("TypeStateBuilder only supports structs with named fields")
        );

        let unit_input: DeriveInput = parse_quote!(
            struct Example;
        );
        assert_eq!(
            rejection_message(&unit_input).as_deref(),
            Some("TypeStateBuilder does not support unit structs")
        );

        let enum_input: DeriveInput = parse_quote!(
            enum Example {
                A,
                B,
            }
        );
        assert_eq!(
            rejection_message(&enum_input).as_deref(),
            Some("TypeStateBuilder only supports structs, not enums")
        );

        let union_input: DeriveInput = parse_quote!(
            union Example {
                a: u32,
                b: f32,
            }
        );
        assert_eq!(
            rejection_message(&union_input).as_deref(),
            Some("TypeStateBuilder only supports structs, not unions")
        );
    }

    /// Bounded generics plus a where clause yield non-empty token streams.
    #[test]
    fn test_token_generation() {
        let input: DeriveInput = parse_quote! {
            struct Example<T: Clone>
            where
                T: Send
            {
                value: T,
            }
        };
        let analysis = analyze_struct(&input).unwrap();
        let impl_generics = analysis.impl_generics_tokens();
        let type_generics = analysis.type_generics_tokens();
        let where_clause = analysis.where_clause_tokens();
        assert!(!impl_generics.is_empty());
        assert!(!type_generics.is_empty());
        assert!(!where_clause.is_empty());
    }
}