use std::collections::BTreeSet;
use antigen_fingerprint::{Constraint, Fingerprint, ItemKind};
use syn::visit::Visit;
use crate::learn::self_tolerance;
/// Generalizes a cluster of items into a single draft [`Fingerprint`].
///
/// The draft is a conjunction of constraints, pushed in this order:
///
/// 1. the item kind shared by every member;
/// 2. the trait name (last path segment) when every member is an impl of a
///    trait with that same last segment;
/// 3. one constraint per body signal that occurs in *every* member;
/// 4. a single `AnyOf` over the signals that occur in only *some* members,
///    emitted only when every member carries at least one of them (otherwise
///    the disjunction would fail to match a member of its own cluster).
///
/// Returns `None` when `cluster` is empty, or when the members do not all
/// share one item kind known to `item_kind_of`.
#[must_use]
pub fn anti_unify(cluster: &[syn::Item]) -> Option<Fingerprint> {
    let features: Vec<MemberFeatures> = cluster.iter().map(MemberFeatures::extract).collect();
    // An empty cluster has nothing to generalize.
    let (first, rest) = features.split_first()?;

    // Without a common item kind the draft would be shapeless: decline.
    let shared_kind = first
        .item_kind
        .filter(|kind| rest.iter().all(|f| f.item_kind == Some(*kind)))?;
    let mut conjuncts: Vec<Constraint> = vec![Constraint::Item(shared_kind)];

    if let Some(trait_name) = &first.impl_of_trait {
        let same_trait_everywhere = rest
            .iter()
            .all(|f| f.impl_of_trait.as_deref() == Some(trait_name.as_str()));
        if same_trait_everywhere {
            conjuncts.push(Constraint::ImplOfTrait(trait_name.clone()));
        }
    }

    // Signals present in every member: the intersection, computed by borrowing
    // from the first member instead of cloning every member's set.
    let shared_signals: BTreeSet<&BodySignal> = first
        .body_signals
        .iter()
        .filter(|sig| rest.iter().all(|f| f.body_signals.contains(*sig)))
        .collect();

    // Signals present in some members but not all: union minus intersection.
    let discriminating: BTreeSet<&BodySignal> = features
        .iter()
        .flat_map(|f| f.body_signals.iter())
        .filter(|sig| !shared_signals.contains(sig))
        .collect();

    conjuncts.extend(shared_signals.iter().map(|sig| sig.to_constraint()));

    let every_member_has_a_discriminating_signal = features
        .iter()
        .all(|f| f.body_signals.iter().any(|sig| discriminating.contains(&sig)));

    // A lone discriminating signal can never cover the whole cluster: by
    // definition at least one member lacks it. So only the two-or-more case
    // can produce a constraint, and it is always a disjunction.
    if discriminating.len() >= 2 && every_member_has_a_discriminating_signal {
        let arms: Vec<Constraint> = discriminating
            .iter()
            .map(|sig| sig.to_constraint())
            .collect();
        conjuncts.push(Constraint::AnyOf(arms));
    }

    Some(Fingerprint {
        constraints: conjuncts,
    })
}
/// Anti-unifies `cluster` into a draft and passes that draft, together with
/// `clean_corpus`, to `self_tolerance::promote_if_safe`.
///
/// Returns `None` when [`anti_unify`] declines the cluster, and otherwise
/// whatever `promote_if_safe` returns for the draft.
#[must_use]
pub fn propose(cluster: &[syn::Item], clean_corpus: &[syn::Item]) -> Option<Fingerprint> {
    anti_unify(cluster).and_then(|draft| self_tolerance::promote_if_safe(draft, clean_corpus))
}
/// A syntactic feature observed inside an item's function bodies.
///
/// Signals are stored in `BTreeSet`s, so the derived `Ord` (variant order
/// first, then the name) decides the order in which they are iterated.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum BodySignal {
/// Name of a called function (last path segment) or of a called method.
Call(String),
/// Name of an invoked macro (last path segment).
Macro(String),
}
impl BodySignal {
    /// Converts this signal into the fingerprint constraint that tests for it:
    /// `Call` becomes `BodyCalls`, `Macro` becomes `BodyContainsMacro`.
    fn to_constraint(&self) -> Constraint {
        match *self {
            Self::Call(ref callee) => Constraint::BodyCalls(callee.clone()),
            Self::Macro(ref macro_name) => Constraint::BodyContainsMacro(macro_name.clone()),
        }
    }
}
/// The features extracted from one cluster member that anti-unification compares.
struct MemberFeatures {
/// Kind of the item, or `None` when `item_kind_of` has no mapping for it.
item_kind: Option<ItemKind>,
/// Last path segment of the implemented trait, when the item is a trait impl.
impl_of_trait: Option<String>,
/// Call and macro signals collected from the item's function bodies.
body_signals: BTreeSet<BodySignal>,
}
impl MemberFeatures {
    /// Extracts the item kind, implemented-trait name and body signals of `item`.
    fn extract(item: &syn::Item) -> Self {
        let item_kind = item_kind_of(item);
        let impl_of_trait = impl_trait_last_segment(item);
        let body_signals = collect_body_signals(item);
        Self {
            item_kind,
            impl_of_trait,
            body_signals,
        }
    }
}
/// Maps a `syn::Item` variant to the corresponding [`ItemKind`].
///
/// Returns `None` for any item variant not listed in the match below.
const fn item_kind_of(item: &syn::Item) -> Option<ItemKind> {
    let kind = match item {
        syn::Item::Struct(_) => ItemKind::Struct,
        syn::Item::Enum(_) => ItemKind::Enum,
        syn::Item::Trait(_) => ItemKind::Trait,
        syn::Item::Fn(_) => ItemKind::Fn,
        syn::Item::Impl(_) => ItemKind::Impl,
        syn::Item::Type(_) => ItemKind::Type,
        syn::Item::Mod(_) => ItemKind::Mod,
        syn::Item::Const(_) => ItemKind::Const,
        syn::Item::Static(_) => ItemKind::Static,
        syn::Item::Union(_) => ItemKind::Union,
        // Every other variant has no fingerprint item kind.
        _ => return None,
    };
    Some(kind)
}
/// Returns the last path segment of the trait that `item` implements.
///
/// Yields `None` when `item` is not an `impl`, when it is an inherent impl
/// (no trait), or when the trait path has no segments.
fn impl_trait_last_segment(item: &syn::Item) -> Option<String> {
    match item {
        syn::Item::Impl(imp) => {
            let (_, trait_path, _) = imp.trait_.as_ref()?;
            trait_path
                .segments
                .last()
                .map(|segment| segment.ident.to_string())
        }
        _ => None,
    }
}
/// Collects the call and macro signals found in the function bodies of `item`.
///
/// For a free function its block is walked; for an `impl`, the block of each
/// function item inside it. Any other item yields an empty set.
fn collect_body_signals(item: &syn::Item) -> BTreeSet<BodySignal> {
    /// AST visitor that accumulates every signal it walks past.
    struct SignalCollector(BTreeSet<BodySignal>);

    impl<'ast> Visit<'ast> for SignalCollector {
        // Path-style calls such as `foo()` or `a::b::foo()` record `foo`.
        fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
            if let syn::Expr::Path(callee) = &*node.func {
                if let Some(segment) = callee.path.segments.last() {
                    self.0.insert(BodySignal::Call(segment.ident.to_string()));
                }
            }
            // Keep descending so calls nested in the callee or arguments are seen.
            syn::visit::visit_expr_call(self, node);
        }

        // Method calls such as `x.foo()` record `foo`.
        fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
            self.0.insert(BodySignal::Call(node.method.to_string()));
            syn::visit::visit_expr_method_call(self, node);
        }

        // Macro invocations such as `panic!(..)` record `panic`.
        fn visit_macro(&mut self, node: &'ast syn::Macro) {
            if let Some(segment) = node.path.segments.last() {
                self.0.insert(BodySignal::Macro(segment.ident.to_string()));
            }
            syn::visit::visit_macro(self, node);
        }
    }

    let mut collector = SignalCollector(BTreeSet::new());
    match item {
        syn::Item::Fn(func) => collector.visit_block(&func.block),
        syn::Item::Impl(imp) => {
            for member in &imp.items {
                let syn::ImplItem::Fn(method) = member else {
                    continue;
                };
                collector.visit_block(&method.block);
            }
        }
        _ => {}
    }
    collector.0
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: the raw-string fixtures below are deliberately left at column 0 so
    // that their contents stay exactly as they were.
    const DROP_FAMILY: &str = r#"
pub struct GuardA;
impl Drop for GuardA {
fn drop(&mut self) { let _ = flush(self.h).take().unwrap(); }
}
pub struct GuardB;
impl Drop for GuardB {
fn drop(&mut self) { let _ = flush(self.h).take().expect("must"); }
}
pub struct CleanGuard;
impl Drop for CleanGuard {
fn drop(&mut self) { let _ = flush(self.h).take().ok(); }
}
"#;

    /// Parses `src` as a Rust file and returns its top-level items.
    fn items(src: &str) -> Vec<syn::Item> {
        let file = syn::parse_file(src).expect("parses");
        file.items
    }

    /// True when `item` is `impl Drop for <ty>` (both compared by last path segment).
    fn is_drop_impl_for(item: &syn::Item, ty: &str) -> bool {
        let syn::Item::Impl(imp) = item else {
            return false;
        };
        let Some((_, trait_path, _)) = &imp.trait_ else {
            return false;
        };
        let syn::Type::Path(self_ty) = &*imp.self_ty else {
            return false;
        };
        let names_drop = trait_path
            .segments
            .last()
            .is_some_and(|seg| seg.ident == "Drop");
        let names_ty = self_ty
            .path
            .segments
            .last()
            .is_some_and(|seg| seg.ident == ty);
        names_drop && names_ty
    }

    /// Returns a clone of the first `impl Drop for <ty>` in `items`; panics if absent.
    fn drop_impl_for(items: &[syn::Item], ty: &str) -> syn::Item {
        items
            .iter()
            .find(|candidate| is_drop_impl_for(candidate, ty))
            .cloned()
            .expect("found")
    }

    /// The cluster most tests use: the `Drop` impls of `GuardA` and `GuardB`.
    fn guard_a_and_b(family: &[syn::Item]) -> Vec<syn::Item> {
        vec![
            drop_impl_for(family, "GuardA"),
            drop_impl_for(family, "GuardB"),
        ]
    }

    #[test]
    fn anti_unify_binds_the_cluster() {
        let family = items(DROP_FAMILY);
        let cluster = guard_a_and_b(&family);
        let draft = anti_unify(&cluster).expect("non-empty cluster anti-unifies");
        for (i, member) in cluster.iter().enumerate() {
            assert!(draft.matches(member), "draft must bind cluster member {i}");
        }
    }

    #[test]
    fn anti_unify_spares_the_clean_sibling_via_disjunction() {
        let family = items(DROP_FAMILY);
        let cluster = guard_a_and_b(&family);
        let clean = drop_impl_for(&family, "CleanGuard");
        let draft = anti_unify(&cluster).expect("anti-unifies");
        assert!(
            !draft.matches(&clean),
            "anti-unify-to-disjunction must spare the clean sibling"
        );
    }

    #[test]
    fn anti_unify_keeps_the_shared_call_as_a_conjunct_and_splits_the_rest() {
        let family = items(DROP_FAMILY);
        let cluster = guard_a_and_b(&family);
        let draft = anti_unify(&cluster).expect("anti-unifies");

        let has_take_conjunct = draft
            .constraints
            .iter()
            .any(|c| matches!(c, Constraint::BodyCalls(n) if n == "take"));
        assert!(has_take_conjunct, "shared call `take` must be a conjunct");

        let has_disjunction = draft.constraints.iter().any(|c| match c {
            Constraint::AnyOf(arms) => arms.iter().all(|a| {
                matches!(a, Constraint::BodyCalls(n) if n == "unwrap" || n == "expect")
            }),
            _ => false,
        });
        assert!(
            has_disjunction,
            "distinguishing calls `unwrap`/`expect` must anti-unify to an any_of"
        );
    }

    #[test]
    fn propose_promotes_only_through_b() {
        let family = items(DROP_FAMILY);
        let cluster = guard_a_and_b(&family);
        let clean_corpus = vec![drop_impl_for(&family, "CleanGuard")];
        let promoted = propose(&cluster, &clean_corpus).expect("a spare-clean draft promotes");
        for member in &cluster {
            assert!(
                promoted.matches(member),
                "promoted draft must bind the cluster"
            );
        }
        assert!(
            !promoted.matches(&clean_corpus[0]),
            "promoted draft must spare clean (it came through B)"
        );
    }

    #[test]
    fn propose_returns_none_when_the_draft_binds_clean() {
        let family = items(DROP_FAMILY);
        let cluster = guard_a_and_b(&family);
        // A cluster member is passed as "clean" on purpose, so the draft binds it.
        let poisoned_corpus = vec![drop_impl_for(&family, "GuardA")];
        assert!(
            propose(&cluster, &poisoned_corpus).is_none(),
            "B must refuse to promote a draft that binds a (declared-clean) corpus item"
        );
    }

    #[test]
    fn anti_unify_mixes_call_and_macro_arms_in_one_disjunction() {
        let family = items(
            r#"
struct One;
impl Drop for One { fn drop(&mut self) { teardown(); let _ = work().unwrap(); } }
struct Two;
impl Drop for Two { fn drop(&mut self) { teardown(); if !work() { panic!("boom"); } } }
struct Clean;
impl Drop for Clean { fn drop(&mut self) { teardown(); let _ = work(); } }
"#,
        );
        let cluster = vec![
            drop_impl_for(&family, "One"),
            drop_impl_for(&family, "Two"),
        ];
        let clean = drop_impl_for(&family, "Clean");
        let draft = anti_unify(&cluster).expect("mixed family anti-unifies");

        let arms = draft
            .constraints
            .iter()
            .find_map(|c| match c {
                Constraint::AnyOf(arms) => Some(arms),
                _ => None,
            })
            .expect("a mixed family produces an any_of disjunction");
        let has_call_arm = arms
            .iter()
            .any(|a| matches!(a, Constraint::BodyCalls(n) if n == "unwrap"));
        let has_macro_arm = arms
            .iter()
            .any(|a| matches!(a, Constraint::BodyContainsMacro(n) if n == "panic"));
        assert!(
            has_call_arm && has_macro_arm,
            "the disjunction must mix body_calls(unwrap) AND body_contains_macro(panic): {arms:?}"
        );

        let has_teardown_conjunct = draft
            .constraints
            .iter()
            .any(|c| matches!(c, Constraint::BodyCalls(n) if n == "teardown"));
        assert!(
            has_teardown_conjunct,
            "the shared `teardown` call must be a conjunct"
        );

        for (i, member) in cluster.iter().enumerate() {
            assert!(draft.matches(member), "mixed draft must bind member {i}");
        }
        assert!(
            !draft.matches(&clean),
            "mixed draft must spare the clean sibling (it reaches neither panic shape)"
        );
    }

    #[test]
    fn anti_unify_declines_an_empty_cluster() {
        let outcome = anti_unify(&[]);
        assert!(outcome.is_none(), "empty cluster has nothing to generalize");
    }

    #[test]
    fn anti_unify_declines_a_heterogeneous_cluster() {
        // One struct plus one impl: no item kind is common to both.
        let mixed = items("struct S; impl Drop for S { fn drop(&mut self) {} }");
        assert!(
            anti_unify(&mixed).is_none(),
            "a cluster with no common item-kind must not produce a shapeless draft"
        );
    }
}