use std::collections::HashSet;
use crate::dependency_dnf::{BaseRvSet, DependencyDnf};
use crate::semiring::{LocySemiring, SemiringError};
use crate::types::SemiringKind;
/// Newtype wrapper around the `u32` identifier of a neural call.
///
/// Values of this type are stored in `Proof::neural_calls`. The derived
/// `Ord`/`Hash`/`Eq` implementations all compare the wrapped integer.
// NOTE(review): presumably one id per neural-model invocation — confirm with
// whatever component allocates these ids.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct NeuralCallId(pub u32);
/// One derivation tracked by the top-k proofs semiring.
///
/// * `weight` – numeric weight of the derivation; conjunction multiplies the
///   weights of the proofs it combines.
/// * `base_rvs` – set of base random variables the derivation depends on.
/// * `neural_calls` – ids of the neural calls the derivation depends on.
#[derive(Debug, Clone, PartialEq)]
pub struct Proof {
    pub weight: f64,
    pub base_rvs: BaseRvSet,
    pub neural_calls: Vec<NeuralCallId>,
}

impl Proof {
    /// Returns the trivially true proof: weight `1.0`, an empty set of base
    /// random variables and no neural calls.
    pub fn tautology() -> Self {
        let base_rvs = BaseRvSet::empty();
        let neural_calls = Vec::new();
        Self {
            weight: 1.0,
            base_rvs,
            neural_calls,
        }
    }

    /// Builds the canonical key used to recognise proofs with identical
    /// dependencies.
    ///
    /// The key is the pair (base random variable ids, neural call ids), each
    /// sorted ascending so that the order in which dependencies were recorded
    /// does not affect equality. Duplicated ids are kept as-is.
    fn dependency_key(&self) -> (Vec<u32>, Vec<u32>) {
        // Sorting helper: takes ownership, sorts in place, hands the vector back.
        let sorted = |mut ids: Vec<u32>| {
            ids.sort_unstable();
            ids
        };
        let rv_ids = sorted(self.base_rvs.iter().map(|rv| rv.0).collect());
        let call_ids = sorted(self.neural_calls.iter().map(|call| call.0).collect());
        (rv_ids, call_ids)
    }
}
/// Semiring tag for the top-k proofs semiring: the list of proofs currently
/// attached to a fact.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TopKTag {
    pub proofs: Vec<Proof>,
}

impl TopKTag {
    /// Tag holding no proofs at all.
    pub fn zero() -> Self {
        Self::from_proofs(Vec::new())
    }

    /// Tag holding exactly one proof, [`Proof::tautology`].
    pub fn one() -> Self {
        Self::from_proofs(vec![Proof::tautology()])
    }

    /// Wraps `proofs` in a tag without reordering, de-duplicating or
    /// truncating them.
    pub fn from_proofs(proofs: Vec<Proof>) -> Self {
        TopKTag { proofs }
    }

    /// `true` when the tag carries no proofs.
    pub fn is_empty(&self) -> bool {
        self.proofs.is_empty()
    }

    /// Projects the tag onto a [`DependencyDnf`] with one clause per proof.
    ///
    /// Each clause is a clone of the proof's `base_rvs`; proof weights and
    /// neural calls are not carried over.
    pub fn to_dnf(&self) -> DependencyDnf {
        let clauses = self
            .proofs
            .iter()
            .map(|proof| proof.base_rvs.clone())
            .collect();
        DependencyDnf { clauses }
    }
}
/// Zero-sized marker type for the top-`K` proofs semiring; the retention
/// limit is carried in the const parameter `K`.
pub struct TopKProofs<const K: usize>;

impl<const K: usize> TopKProofs<K> {
    /// Returns the retention limit, i.e. the const parameter `K`.
    pub const fn capacity() -> usize {
        K
    }

    /// Merges `base` and `additional` and keeps at most `K` proofs.
    ///
    /// This is [`merge_top_k_with`] called with the limit fixed to `K`; see
    /// that function for the de-duplication, ordering and notice rules.
    pub fn merge_top_k(base: Vec<Proof>, additional: Vec<Proof>) -> (Vec<Proof>, PruneNotice) {
        let limit = Self::capacity();
        merge_top_k_with(base, additional, limit)
    }
}
/// Merges two proof lists, collapses duplicates and keeps the `k` heaviest.
///
/// Processing order:
/// 1. Proofs are visited `base` first, then `additional`. Proofs with the
///    same dependency key (same base random variable ids and same neural call
///    ids, regardless of recording order) are collapsed into one slot that
///    holds the heaviest of them; the slot keeps the position of the first
///    proof seen with that key.
/// 2. The survivors are stably sorted by descending weight, so equal weights
///    keep their first-seen order.
/// 3. If more than `k` proofs remain, everything after the first `k` is
///    dropped.
///
/// A `NaN` weight is ranked below every real weight, both when choosing the
/// representative of a duplicate group and when sorting. This keeps the sort
/// comparator a total order (`slice::sort_by` may panic otherwise) and stops a
/// `NaN` proof from shadowing a valid proof with the same dependencies.
///
/// # Returns
/// The retained proofs together with a [`PruneNotice`]:
/// * `PruneNotice::None` – nothing was dropped.
/// * `PruneNotice::CrossedDependency` – something was dropped and
///   `BaseRvSet::intersect_any` reported an overlap between the base random
///   variables of a dropped proof and those of a retained proof.
/// * `PruneNotice::Pruned` – something was dropped and no such overlap exists.
///
/// With `k == 0` and a non-empty input the result is an empty list with
/// `PruneNotice::Pruned`.
pub fn merge_top_k_with(
    base: Vec<Proof>,
    additional: Vec<Proof>,
    k: usize,
) -> (Vec<Proof>, PruneNotice) {
    /// `true` when `candidate` should replace `incumbent` within one
    /// duplicate group: strictly heavier, or the incumbent is `NaN` while the
    /// candidate is a real number.
    fn outranks(candidate: f64, incumbent: f64) -> bool {
        candidate > incumbent || (incumbent.is_nan() && !candidate.is_nan())
    }

    /// Descending-weight comparator with `NaN` placed after every real
    /// weight. Unlike `partial_cmp(..).unwrap_or(Equal)` this is transitive
    /// in the presence of `NaN`.
    fn heaviest_first(a: &Proof, b: &Proof) -> std::cmp::Ordering {
        match (a.weight.is_nan(), b.weight.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Both operands are real numbers here, so `partial_cmp` is `Some`.
            (false, false) => b
                .weight
                .partial_cmp(&a.weight)
                .unwrap_or(std::cmp::Ordering::Equal),
        }
    }

    let mut keep: Vec<Proof> = Vec::with_capacity(base.len() + additional.len());
    // Dependency key -> index of that key's representative inside `keep`.
    let mut slot_of: std::collections::HashMap<(Vec<u32>, Vec<u32>), usize> =
        std::collections::HashMap::new();

    for p in base.into_iter().chain(additional) {
        match slot_of.entry(p.dependency_key()) {
            std::collections::hash_map::Entry::Occupied(slot) => {
                let incumbent = &mut keep[*slot.get()];
                if outranks(p.weight, incumbent.weight) {
                    *incumbent = p;
                }
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(keep.len());
                keep.push(p);
            }
        }
    }

    // Stable sort: ties stay in first-seen order.
    keep.sort_by(heaviest_first);

    if keep.len() <= k {
        return (keep, PruneNotice::None);
    }

    // Inspect the tail that is about to be discarded before truncating, so
    // the retained proofs never have to be cloned.
    let crossed = {
        let (retained, dropped) = keep.split_at(k);
        dropped.iter().any(|d| {
            retained
                .iter()
                .any(|r| BaseRvSet::intersect_any(&r.base_rvs, &d.base_rvs))
        })
    };
    keep.truncate(k);

    let notice = if crossed {
        PruneNotice::CrossedDependency
    } else {
        PruneNotice::Pruned
    };
    (keep, notice)
}
/// Side-channel result of `merge_top_k_with` describing whether truncating to
/// the top-k limit discarded any proofs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PruneNotice {
/// The merged list fit within the limit; no proof was discarded.
None,
/// Proofs were discarded, and `BaseRvSet::intersect_any` found no overlap
/// between any discarded proof and any retained proof.
Pruned,
/// Proofs were discarded, and `BaseRvSet::intersect_any` reported an overlap
/// between the base random variables of at least one discarded proof and
/// those of a retained proof.
CrossedDependency,
}
/// [`LocySemiring`] implementation whose tags are proof lists pruned to at
/// most `K` entries after every `plus` and `times`.
impl<const K: usize> LocySemiring for TopKProofs<K> {
    type Tag = TopKTag;

    /// Reports this semiring as `SemiringKind::TopKProofs` carrying `K`.
    fn kind(&self) -> SemiringKind {
        // NOTE: `as` silently truncates when `K` does not fit in a `u32`.
        let k = K as u32;
        SemiringKind::TopKProofs { k }
    }

    /// Identity of `plus`: the tag without proofs.
    fn zero_disjunction(&self) -> TopKTag {
        TopKTag::zero()
    }

    /// Identity of `times`: the tag holding only the tautology proof.
    fn one_conjunction(&self) -> TopKTag {
        TopKTag::one()
    }

    /// Disjunction: pools the proofs of both tags and keeps the top `K`.
    ///
    /// The prune notice produced by the merge is discarded.
    fn plus(&self, a: &TopKTag, b: &TopKTag) -> TopKTag {
        let (left, right) = (a.proofs.clone(), b.proofs.clone());
        let (proofs, _notice) = Self::merge_top_k(left, right);
        TopKTag { proofs }
    }

    /// Conjunction: pairs every proof of `a` with every proof of `b`.
    ///
    /// Each pair yields a proof whose weight is the product of the two
    /// weights, whose base random variables are the union of both sets, and
    /// whose neural calls are `a`'s calls followed by `b`'s with repeated ids
    /// removed (first occurrence wins). The pairs are then de-duplicated and
    /// pruned to the top `K`; the prune notice is discarded. If either tag is
    /// empty the result is the empty tag.
    fn times(&self, a: &TopKTag, b: &TopKTag) -> TopKTag {
        if a.proofs.is_empty() || b.proofs.is_empty() {
            return TopKTag::zero();
        }

        let products: Vec<Proof> = a
            .proofs
            .iter()
            .flat_map(|lhs| {
                b.proofs.iter().map(move |rhs| {
                    // Ids already emitted for this particular pair.
                    let mut emitted: HashSet<u32> = HashSet::new();
                    let neural_calls = lhs
                        .neural_calls
                        .iter()
                        .chain(rhs.neural_calls.iter())
                        .copied()
                        .filter(|call| emitted.insert(call.0))
                        .collect();
                    Proof {
                        weight: lhs.weight * rhs.weight,
                        base_rvs: BaseRvSet::union(&lhs.base_rvs, &rhs.base_rvs),
                        neural_calls,
                    }
                })
            })
            .collect();

        let (proofs, _notice) = Self::merge_top_k(Vec::new(), products);
        TopKTag { proofs }
    }

    /// Negation: collapses the tag into a single dependency-free proof.
    ///
    /// The resulting proof has weight `1 - weight(a)` clamped to `[0, 1]`,
    /// an empty set of base random variables and no neural calls. This
    /// implementation always returns `Ok`.
    fn negate(&self, a: &TopKTag) -> Result<TopKTag, SemiringError> {
        let weight = (1.0 - self.weight(a)).clamp(0.0, 1.0);
        let collapsed = Proof {
            weight,
            base_rvs: BaseRvSet::empty(),
            neural_calls: Vec::new(),
        };
        Ok(TopKTag {
            proofs: vec![collapsed],
        })
    }

    /// Scalar weight of a tag: `1 - Π (1 - wᵢ)` over its proofs, clamped to
    /// `[0, 1]`.
    ///
    /// Only the proof weights are used; overlap between the proofs' base
    /// random variables is not taken into account here.
    fn weight(&self, a: &TopKTag) -> f64 {
        let all_fail = a
            .proofs
            .iter()
            .fold(1.0_f64, |acc, proof| acc * (1.0 - proof.weight));
        (1.0 - all_fail).clamp(0.0, 1.0)
    }

    /// Delegates to `crate::semiring::validate_probability_domain` with the
    /// arguments unchanged.
    fn validate_domain(
        &self,
        raw: f64,
        op: &'static str,
        strict: bool,
    ) -> Result<f64, SemiringError> {
        crate::semiring::validate_probability_domain(raw, op, strict)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dependency_dnf::BaseRv;
    use std::collections::HashMap;

    /// Absolute tolerance for floating-point comparisons in this module.
    const EPS: f64 = 1e-12;

    /// Fails the calling test unless `actual` is within `EPS` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPS);
    }

    /// Builds a proof with the given weight and base random variable ids and
    /// no neural calls.
    fn proof(weight: f64, rvs: &[u32]) -> Proof {
        let mut base_rvs = BaseRvSet::empty();
        rvs.iter().for_each(|&id| {
            base_rvs.insert(BaseRv(id));
        });
        Proof {
            weight,
            base_rvs,
            neural_calls: Vec::new(),
        }
    }

    /// Shorthand for wrapping proofs in a tag.
    fn tag(proofs: Vec<Proof>) -> TopKTag {
        TopKTag::from_proofs(proofs)
    }

    /// `zero ⊕ t == t == t ⊕ zero`.
    #[test]
    fn empty_tag_is_additive_identity() {
        let semiring = TopKProofs::<4>;
        let zero = semiring.zero_disjunction();
        let sample = tag(vec![proof(0.5, &[1])]);
        assert_eq!(semiring.plus(&zero, &sample), sample);
        assert_eq!(semiring.plus(&sample, &zero), sample);
    }

    /// `one ⊗ t == t == t ⊗ one`.
    #[test]
    fn one_tag_is_multiplicative_identity() {
        let semiring = TopKProofs::<4>;
        let unit = semiring.one_conjunction();
        let sample = tag(vec![proof(0.5, &[1])]);
        assert_eq!(semiring.times(&unit, &sample), sample);
        assert_eq!(semiring.times(&sample, &unit), sample);
    }

    /// The empty tag weighs exactly zero.
    #[test]
    fn weight_of_zero_is_zero() {
        let semiring = TopKProofs::<4>;
        let zero = semiring.zero_disjunction();
        assert_eq!(semiring.weight(&zero), 0.0);
    }

    /// The tautology tag weighs exactly one.
    #[test]
    fn weight_of_one_is_one() {
        let semiring = TopKProofs::<4>;
        let unit = semiring.one_conjunction();
        assert_eq!(semiring.weight(&unit), 1.0);
    }

    /// A single proof contributes its own weight.
    #[test]
    fn weight_single_proof() {
        let semiring = TopKProofs::<4>;
        let sample = tag(vec![proof(0.3, &[1])]);
        assert_close(semiring.weight(&sample), 0.3);
    }

    /// Two proofs over distinct variables combine as a noisy-or.
    #[test]
    fn weight_independent_proofs_match_noisy_or() {
        let semiring = TopKProofs::<4>;
        let sample = tag(vec![proof(0.3, &[1]), proof(0.5, &[2])]);
        let noisy_or = 1.0 - (1.0 - 0.3) * (1.0 - 0.5);
        assert_close(semiring.weight(&sample), noisy_or);
    }

    /// The DNF view evaluates proofs that share variable 1 to 0.38.
    #[test]
    fn dnf_view_corrects_for_shared_rv() {
        let sample = tag(vec![
            proof(0.5 * 0.4, &[1, 2]),
            proof(0.5 * 0.6, &[1, 3]),
        ]);
        let weights: HashMap<BaseRv, f64> =
            HashMap::from([(BaseRv(1), 0.5), (BaseRv(2), 0.4), (BaseRv(3), 0.6)]);
        let dnf = sample.to_dnf();
        assert_close(dnf.weight(&weights), 0.38);
    }

    /// Proofs with the same dependencies collapse to the heavier one.
    #[test]
    fn plus_dedups_identical_dependency_proofs() {
        let semiring = TopKProofs::<4>;
        let lighter = tag(vec![proof(0.4, &[1])]);
        let heavier = tag(vec![proof(0.7, &[1])]);
        let merged = semiring.plus(&lighter, &heavier);
        assert_eq!(merged.proofs.len(), 1);
        assert_eq!(merged.proofs[0].weight, 0.7);
    }

    /// With `K = 2` only the two heaviest proofs survive, heaviest first.
    #[test]
    fn plus_retains_top_k_by_weight() {
        let incoming = vec![proof(0.1, &[1]), proof(0.9, &[2]), proof(0.5, &[3])];
        let (kept, notice) = TopKProofs::<2>::merge_top_k(Vec::new(), incoming);
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].weight, 0.9);
        assert_eq!(kept[1].weight, 0.5);
        assert_eq!(notice, PruneNotice::Pruned);
    }

    /// Dropping a proof that shares variable 1 with a kept proof is reported.
    #[test]
    fn plus_emits_crossed_dependency_when_pruning_drops_shared_rv() {
        let incoming = vec![
            proof(0.9, &[1, 2]),
            proof(0.5, &[3, 4]),
            proof(0.3, &[1, 5]),
        ];
        let (kept, notice) = TopKProofs::<2>::merge_top_k(Vec::new(), incoming);
        assert_eq!(kept.len(), 2);
        assert_eq!(notice, PruneNotice::CrossedDependency);
    }

    /// `times` pairs every proof of the left tag with every proof of the right.
    #[test]
    fn times_cartesian_products_proofs() {
        let semiring = TopKProofs::<4>;
        let left = tag(vec![proof(0.5, &[1]), proof(0.6, &[2])]);
        let right = tag(vec![proof(0.4, &[3])]);
        let product = semiring.times(&left, &right);
        assert_eq!(product.proofs.len(), 2);
        assert_close(product.proofs[0].weight, 0.24);
        assert_close(product.proofs[1].weight, 0.20);
        let heaviest = &product.proofs[0];
        assert!(heaviest.base_rvs.contains(BaseRv(2)));
        assert!(heaviest.base_rvs.contains(BaseRv(3)));
    }

    /// The empty tag annihilates under `times` on either side.
    #[test]
    fn times_with_zero_is_zero() {
        let semiring = TopKProofs::<4>;
        let zero = semiring.zero_disjunction();
        let sample = tag(vec![proof(0.5, &[1])]);
        assert_eq!(semiring.times(&zero, &sample), TopKTag::zero());
        assert_eq!(semiring.times(&sample, &zero), TopKTag::zero());
    }

    /// `negate` yields one dependency-free proof carrying the complement.
    #[test]
    fn negate_collapses_to_degenerate_tag() {
        let semiring = TopKProofs::<4>;
        let sample = tag(vec![proof(0.5, &[1])]);
        let negated = semiring.negate(&sample).unwrap();
        assert_eq!(negated.proofs.len(), 1);
        let only = &negated.proofs[0];
        assert_close(only.weight, 0.5);
        assert_eq!(only.base_rvs.iter().count(), 0);
        assert!(only.neural_calls.is_empty());
    }

    /// Complement of the noisy-or of 0.7 and 0.5 is 0.15.
    #[test]
    fn negate_of_two_independent_proofs() {
        let semiring = TopKProofs::<4>;
        let sample = tag(vec![proof(0.7, &[1]), proof(0.5, &[2])]);
        let negated = semiring.negate(&sample).unwrap();
        assert_eq!(negated.proofs.len(), 1);
        assert_close(negated.proofs[0].weight, 0.15);
    }

    /// `kind` carries the const parameter through.
    #[test]
    fn kind_reports_k() {
        assert_eq!(TopKProofs::<4>.kind(), SemiringKind::TopKProofs { k: 4 });
        assert_eq!(TopKProofs::<16>.kind(), SemiringKind::TopKProofs { k: 16 });
    }

    /// Products that end up with identical dependencies are de-duplicated.
    #[test]
    fn dedup_after_cartesian_in_times() {
        let semiring = TopKProofs::<4>;
        let left = tag(vec![proof(0.3, &[1]), proof(0.6, &[1])]);
        let right = tag(vec![proof(0.5, &[2])]);
        let product = semiring.times(&left, &right);
        assert_eq!(product.proofs.len(), 1);
        assert_close(product.proofs[0].weight, 0.30);
    }
}