use syn::{visit::Visit, Expr, ItemFn, ReturnType as SynReturnType, Type};
/// How a function's declared return type relates to `Self`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConstructorReturnType {
    /// The return type is `Self` itself.
    OwnedSelf,
    /// The return type is a `Result` carrying `Self`.
    ResultSelf,
    /// The return type is an `Option` carrying `Self`.
    OptionSelf,
    /// The return type is a reference to `Self`, e.g. `&Self`.
    RefSelf,
    /// Any return type that fits none of the categories above.
    Other,
}
/// Classifies the declared return type of `func`.
///
/// Returns `None` when the signature has no `-> T` clause; otherwise the
/// declared type is handed to `classify_return_type` and its result is
/// returned unchanged.
pub fn extract_return_type(func: &ItemFn) -> Option<ConstructorReturnType> {
    if let SynReturnType::Type(_, declared) = &func.sig.output {
        classify_return_type(declared)
    } else {
        // `SynReturnType::Default`: nothing was written after the parameter list.
        None
    }
}
fn classify_return_type(ty: &Type) -> Option<ConstructorReturnType> {
match ty {
Type::Path(type_path) => {
let path_str = quote::quote!(#type_path).to_string();
if path_str == "Self" {
Some(ConstructorReturnType::OwnedSelf)
} else if path_str.starts_with("Result < Self") {
Some(ConstructorReturnType::ResultSelf)
} else if path_str.starts_with("Option < Self") {
Some(ConstructorReturnType::OptionSelf)
} else {
Some(ConstructorReturnType::Other)
}
}
Type::Reference(type_ref) => {
if let Type::Path(path) = &*type_ref.elem {
let path_str = quote::quote!(#path).to_string();
if path_str == "Self" {
return Some(ConstructorReturnType::RefSelf);
}
}
Some(ConstructorReturnType::Other)
}
_ => Some(ConstructorReturnType::Other),
}
}
pub struct ConstructorPatternVisitor {
pattern: BodyPattern,
}
impl ConstructorPatternVisitor {
pub fn new() -> Self {
Self {
pattern: BodyPattern::default(),
}
}
pub fn into_pattern(self) -> BodyPattern {
self.pattern
}
}
impl Default for ConstructorPatternVisitor {
fn default() -> Self {
Self::new()
}
}
impl<'ast> Visit<'ast> for ConstructorPatternVisitor {
    /// Records the kind of `expr` in the pattern, then recurses into its
    /// children through `syn::visit::visit_expr` so nested expressions are
    /// counted as well.
    fn visit_expr(&mut self, expr: &'ast Expr) {
        match expr {
            Expr::Struct(_) => self.pattern.struct_init_count += 1,
            Expr::Path(expr_path) => {
                // Compare the first segment's identifier rather than the
                // rendered token string: a textual `starts_with("Self")`
                // would also count unrelated names such as `Selfish::new`.
                let rooted_at_self = expr_path.qself.is_none()
                    && expr_path.path.leading_colon.is_none()
                    && expr_path
                        .path
                        .segments
                        .first()
                        .map_or(false, |segment| segment.ident == "Self");
                if rooted_at_self {
                    self.pattern.self_refs += 1;
                }
            }
            Expr::If(_) => self.pattern.has_if = true,
            Expr::Match(_) => self.pattern.has_match = true,
            Expr::Loop(_) | Expr::While(_) | Expr::ForLoop(_) => {
                self.pattern.has_loop = true;
            }
            Expr::Return(_) => self.pattern.early_returns += 1,
            // Both field accesses and assignments bump the same counter, so
            // `self.x = y` contributes twice (the assignment and its
            // left-hand field access) and a plain field read counts too.
            Expr::Field(_) | Expr::Assign(_) => {
                self.pattern.field_assignments += 1;
            }
            _ => {}
        }
        syn::visit::visit_expr(self, expr);
    }
}
/// Walks the body block of `func` with a fresh [`ConstructorPatternVisitor`]
/// and returns the [`BodyPattern`] it collected.
pub fn analyze_function_body(func: &ItemFn) -> BodyPattern {
    let mut walker = ConstructorPatternVisitor::new();
    let body = &*func.block;
    walker.visit_block(body);
    walker.into_pattern()
}
/// Counters and flags describing the expressions found in a function body.
#[derive(Debug, Clone, Default)]
pub struct BodyPattern {
    /// Number of struct-literal expressions seen.
    pub struct_init_count: usize,
    /// Number of path expressions recognised as referring to `Self`.
    pub self_refs: usize,
    /// Number of field-access and assignment expressions seen.
    pub field_assignments: usize,
    /// Whether an `if` expression was seen.
    pub has_if: bool,
    /// Whether a `match` expression was seen.
    pub has_match: bool,
    /// Whether a `loop`, `while` or `for` expression was seen.
    pub has_loop: bool,
    /// Number of `return` expressions seen.
    pub early_returns: usize,
}

impl BodyPattern {
    /// Returns `true` when the body looks like a constructor.
    ///
    /// A body containing any loop never qualifies. Otherwise it qualifies
    /// when it has at least one struct literal, or when it references `Self`
    /// while containing no `match` and no field access / assignment.
    pub fn is_constructor_like(&self) -> bool {
        if self.has_loop {
            return false;
        }
        if self.struct_init_count > 0 {
            return true;
        }
        self.self_refs > 0 && !self.has_match && self.field_assignments == 0
    }

    /// Returns `true` when the body looks like a builder-style setter: no
    /// loop, at least one field access / assignment, and at most one
    /// `return` expression.
    #[allow(dead_code)]
    pub fn is_builder_like(&self) -> bool {
        let touches_fields = self.field_assignments != 0;
        let few_returns = self.early_returns < 2;
        !self.has_loop && touches_fields && few_returns
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // ---- extract_return_type -------------------------------------------

    /// `-> Self` is classified as an owned `Self`.
    #[test]
    fn test_extract_return_type_owned_self() {
        let item: ItemFn = parse_quote! {
            fn new() -> Self {
                Self { field: 0 }
            }
        };
        let kind = extract_return_type(&item);
        assert_eq!(kind, Some(ConstructorReturnType::OwnedSelf));
    }

    /// `-> Result<Self, Error>` is classified as a `Result` of `Self`.
    #[test]
    fn test_extract_return_type_result_self() {
        let item: ItemFn = parse_quote! {
            fn try_new() -> Result<Self, Error> {
                Ok(Self { field: 0 })
            }
        };
        let kind = extract_return_type(&item);
        assert_eq!(kind, Some(ConstructorReturnType::ResultSelf));
    }

    /// `-> Option<Self>` is classified as an `Option` of `Self`.
    #[test]
    fn test_extract_return_type_option_self() {
        let item: ItemFn = parse_quote! {
            fn maybe_new() -> Option<Self> {
                Some(Self { field: 0 })
            }
        };
        let kind = extract_return_type(&item);
        assert_eq!(kind, Some(ConstructorReturnType::OptionSelf));
    }

    /// `-> &Self` is classified as a reference to `Self`.
    #[test]
    fn test_extract_return_type_ref_self() {
        let item: ItemFn = parse_quote! {
            fn get_self(&self) -> &Self {
                self
            }
        };
        let kind = extract_return_type(&item);
        assert_eq!(kind, Some(ConstructorReturnType::RefSelf));
    }

    /// A return type unrelated to `Self` is classified as `Other`.
    #[test]
    fn test_extract_return_type_other() {
        let item: ItemFn = parse_quote! {
            fn get_value() -> i32 {
                42
            }
        };
        let kind = extract_return_type(&item);
        assert_eq!(kind, Some(ConstructorReturnType::Other));
    }

    /// A signature without `->` yields `None`.
    #[test]
    fn test_extract_return_type_none() {
        let item: ItemFn = parse_quote! {
            fn do_something() {
                println!("Hello");
            }
        };
        let kind = extract_return_type(&item);
        assert_eq!(kind, None);
    }

    // ---- analyze_function_body -----------------------------------------

    /// A single struct literal is counted and makes the body constructor-like.
    #[test]
    fn test_analyze_function_body_struct_init() {
        let item: ItemFn = parse_quote! {
            fn new() -> Self {
                Self { field: 0 }
            }
        };
        let stats = analyze_function_body(&item);
        assert_eq!(stats.struct_init_count, 1);
        assert!(stats.is_constructor_like());
    }

    /// A `for` loop sets `has_loop` and disqualifies the body as a constructor.
    #[test]
    fn test_analyze_function_body_with_loop() {
        let item: ItemFn = parse_quote! {
            fn process_items() -> Self {
                let mut result = Self::new();
                for item in items {
                    result.add(item);
                }
                result
            }
        };
        let stats = analyze_function_body(&item);
        let saw_loop = stats.has_loop;
        let constructor_like = stats.is_constructor_like();
        assert!(saw_loop);
        assert!(!constructor_like);
    }

    /// Delegating to `Self::new()` counts as a `Self` reference.
    #[test]
    fn test_analyze_function_body_self_refs() {
        let item: ItemFn = parse_quote! {
            fn default() -> Self {
                Self::new()
            }
        };
        let stats = analyze_function_body(&item);
        assert!(stats.self_refs > 0);
        assert!(stats.is_constructor_like());
    }

    /// A setter that assigns a field and returns `self` is builder-like.
    #[test]
    fn test_body_pattern_is_builder_like() {
        let item: ItemFn = parse_quote! {
            fn set_timeout(mut self, timeout: Duration) -> Self {
                self.timeout = timeout;
                self
            }
        };
        let stats = analyze_function_body(&item);
        assert!(stats.field_assignments > 0);
        assert!(stats.is_builder_like());
    }
}