use std::collections::HashMap;
/// Identifier of a base random variable (a newtype over its `u32` index).
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct BaseRv(pub u32);

/// A set of [`BaseRv`]s.
///
/// Small sets whose members all fit in a 64-wide window are stored inline as a bitmap;
/// wider sets are stored as a sorted, deduplicated `Vec`.
///
/// Sets built only through this type's methods stay in a canonical form, which keeps the
/// derived `PartialEq` meaningful (equal sets compare equal):
/// - `Inline` with `base` equal to the smallest member (or `bits == 0, base == 0` when empty)
///   whenever the members span fewer than 64 values;
/// - `Vec` otherwise.
///
/// Constructing the `Vec` variant directly with an unsorted, duplicated or narrow vector
/// bypasses that invariant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BaseRvSet {
    /// Member `base + i` is present iff bit `i` of `bits` is set.
    Inline { bits: u64, base: u32 },
    /// Members sorted ascending without duplicates (`contains` relies on `binary_search`).
    Vec(Vec<u32>),
}

impl Default for BaseRvSet {
    /// The empty set.
    fn default() -> Self {
        Self::empty()
    }
}

impl BaseRvSet {
    /// Returns the empty set.
    pub fn empty() -> Self {
        Self::Inline { bits: 0, base: 0 }
    }

    /// Returns the set containing only `rv`.
    pub fn single(rv: BaseRv) -> Self {
        Self::Inline {
            bits: 1,
            base: rv.0,
        }
    }

    /// Returns `true` if the set has no members.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Inline { bits, .. } => *bits == 0,
            Self::Vec(v) => v.is_empty(),
        }
    }

    /// Returns the number of members.
    pub fn len(&self) -> usize {
        match self {
            Self::Inline { bits, .. } => bits.count_ones() as usize,
            Self::Vec(v) => v.len(),
        }
    }

    /// Returns `true` if `rv` is a member.
    pub fn contains(&self, rv: BaseRv) -> bool {
        match self {
            Self::Inline { bits, base } => {
                // Wrapping makes members below `base` produce a huge offset, failing `< 64`.
                let offset = rv.0.wrapping_sub(*base);
                offset < 64 && (*bits >> offset) & 1 == 1
            }
            Self::Vec(v) => v.binary_search(&rv.0).is_ok(),
        }
    }

    /// Inserts `rv`; does nothing if it is already a member.
    ///
    /// An inline set is re-based downward when `rv` lies below `base` but the enlarged span
    /// still fits in 64 bits. It spills to the sorted `Vec` form only when the span no longer
    /// fits, which keeps the representation independent of insertion order.
    pub fn insert(&mut self, rv: BaseRv) {
        if self.contains(rv) {
            return;
        }
        match self {
            Self::Inline { bits, base } => {
                if *bits == 0 {
                    *base = rv.0;
                    *bits = 1;
                    return;
                }
                if rv.0 >= *base {
                    let offset = rv.0 - *base;
                    if offset < 64 {
                        *bits |= 1u64 << offset;
                        return;
                    }
                } else {
                    // `bits != 0`, so the highest set bit locates the current maximum.
                    let top = 63 - bits.leading_zeros();
                    let shift = *base - rv.0;
                    // The new span is `top + shift`; widen to u64 so the sum cannot overflow.
                    if u64::from(top) + u64::from(shift) < 64 {
                        *bits = (*bits << shift) | 1;
                        *base = rv.0;
                        return;
                    }
                }
                // Span would reach 64 or more: spill. `iter` yields ascending members and
                // `rv` is known absent, so a single sorted insertion keeps the Vec invariant.
                let mut members: Vec<u32> = self.iter().map(|r| r.0).collect();
                let pos = members.partition_point(|&x| x < rv.0);
                members.insert(pos, rv.0);
                *self = Self::Vec(members);
            }
            Self::Vec(v) => {
                let pos = v.partition_point(|&x| x < rv.0);
                v.insert(pos, rv.0);
            }
        }
    }

    /// Iterates over the members in ascending order.
    pub fn iter(&self) -> Box<dyn Iterator<Item = BaseRv> + '_> {
        match self {
            Self::Inline { bits, base } => {
                let base = *base;
                let mut remaining = *bits;
                // Visit only the set bits, lowest first, clearing each as it is yielded.
                Box::new(std::iter::from_fn(move || {
                    if remaining == 0 {
                        return None;
                    }
                    let i = remaining.trailing_zeros();
                    remaining &= remaining - 1;
                    Some(BaseRv(base + i))
                }))
            }
            Self::Vec(v) => Box::new(v.iter().map(|&x| BaseRv(x))),
        }
    }

    /// Returns the union of `a` and `b`.
    pub fn union(a: &Self, b: &Self) -> Self {
        // Clone the larger operand so fewer members go through `insert`; the canonical form
        // makes the result independent of which operand is cloned.
        let (larger, smaller) = if a.len() >= b.len() { (a, b) } else { (b, a) };
        let mut out = larger.clone();
        for rv in smaller.iter() {
            out.insert(rv);
        }
        out
    }

    /// Returns `true` if `a` and `b` share at least one member.
    pub fn intersect_any(a: &Self, b: &Self) -> bool {
        match (a, b) {
            (Self::Inline { bits: ba, base: ea }, Self::Inline { bits: bb, base: eb }) => {
                // Shift the lower window onto the higher one. Windows 64 or more apart cannot
                // share a member.
                let (lo_bits, lo_base, hi_bits, hi_base) = if ea <= eb {
                    (*ba, *ea, *bb, *eb)
                } else {
                    (*bb, *eb, *ba, *ea)
                };
                let gap = hi_base - lo_base;
                gap < 64 && ((lo_bits >> gap) & hi_bits) != 0
            }
            _ => {
                let (small, large) = if a.len() <= b.len() { (a, b) } else { (b, a) };
                small.iter().any(|rv| large.contains(rv))
            }
        }
    }

    /// Returns `true` if every member of `a` is also a member of `b`.
    pub fn is_subset_of(a: &Self, b: &Self) -> bool {
        a.len() <= b.len() && a.iter().all(|rv| b.contains(rv))
    }
}
/// A monotone Boolean formula over base RVs in disjunctive normal form.
///
/// The formula holds when at least one clause has all of its base RVs hold. A formula with
/// no clauses is `false`; an empty clause is `true`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DependencyDnf {
    /// Disjuncts; each clause is the conjunction of the base RVs it contains.
    pub clauses: Vec<BaseRvSet>,
}

impl DependencyDnf {
    /// Largest number of clauses, after dropping duplicate/absorbed ones, for which `weight`
    /// runs exact inclusion–exclusion (which visits `2^n - 1` clause subsets).
    const MAX_EXACT_CLAUSES: usize = 24;

    /// Returns the formula with no clauses (`false`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the formula consisting of the single conjunction `clause`.
    pub fn single_clause(clause: BaseRvSet) -> Self {
        Self {
            clauses: vec![clause],
        }
    }

    /// Disjoins `other` into `self` by appending its clauses.
    pub fn or_with(&mut self, other: &Self) {
        self.clauses.extend(other.clauses.iter().cloned());
    }

    /// Conjoins `other` into `self` by distribution: every pair (clause of `self`, clause of
    /// `other`) becomes the union of the two.
    ///
    /// The clause count multiplies, and redundant clauses are kept here (`weight` discards
    /// them). Conjoining with a formula that has no clauses yields no clauses (`false`).
    pub fn and_with(&mut self, other: &Self) {
        let mut new_clauses = Vec::with_capacity(self.clauses.len() * other.clauses.len());
        for a in &self.clauses {
            for b in &other.clauses {
                new_clauses.push(BaseRvSet::union(a, b));
            }
        }
        self.clauses = new_clauses;
    }

    /// Probability that the formula holds.
    ///
    /// Each base RV is an independent event whose probability is looked up in `base_weights`;
    /// RVs missing from the map count as certain (weight `1.0`).
    ///
    /// Clauses that duplicate or are supersets of another clause are dropped first. This does
    /// not change the formula (`A ∨ (A ∧ B) = A`) but can shrink the exponential work.
    ///
    /// - With at most `MAX_EXACT_CLAUSES` remaining clauses, the result is exact, by
    ///   inclusion–exclusion.
    /// - Otherwise it returns the noisy-or value `1 - Π(1 - P(clause))`, which ignores RVs
    ///   shared between clauses. For weights in `[0, 1]` this is an upper bound on the exact
    ///   probability (Harris/FKG inequality).
    ///
    /// The result is clamped to `[0, 1]`.
    pub fn weight(&self, base_weights: &HashMap<BaseRv, f64>) -> f64 {
        let clauses = Self::minimal_clauses(&self.clauses);
        if clauses.is_empty() {
            return 0.0;
        }
        let p = if clauses.len() <= Self::MAX_EXACT_CLAUSES {
            Self::inclusion_exclusion(&clauses, base_weights, 0, &BaseRvSet::empty(), 1.0, 1.0)
        } else {
            let none_hold: f64 = clauses
                .iter()
                .map(|c| 1.0 - clause_weight(c, base_weights))
                .product();
            1.0 - none_hold
        };
        p.clamp(0.0, 1.0)
    }

    /// Returns the clauses that no other clause absorbs.
    ///
    /// Clauses are visited smallest first; a clause is kept only if no already-kept clause is
    /// a subset of it, so duplicates and supersets are dropped.
    fn minimal_clauses(clauses: &[BaseRvSet]) -> Vec<&BaseRvSet> {
        let mut by_size: Vec<&BaseRvSet> = clauses.iter().collect();
        by_size.sort_by_key(|c| c.len());
        let mut kept: Vec<&BaseRvSet> = Vec::with_capacity(by_size.len());
        for clause in by_size {
            if !kept.iter().any(|k| BaseRvSet::is_subset_of(k, clause)) {
                kept.push(clause);
            }
        }
        kept
    }

    /// Sums `sign * P(all RVs in union ∪ ⋃S hold)` for every non-empty subset `S` of
    /// `clauses[start..]`. The sign flips with each additional clause.
    ///
    /// `union_weight` must be the product of the weights of `union`'s RVs. Each child's union
    /// and weight are grown incrementally from its parent rather than recomputed.
    fn inclusion_exclusion(
        clauses: &[&BaseRvSet],
        base_weights: &HashMap<BaseRv, f64>,
        start: usize,
        union: &BaseRvSet,
        union_weight: f64,
        sign: f64,
    ) -> f64 {
        let mut acc = 0.0;
        for (i, clause) in clauses.iter().enumerate().skip(start) {
            let mut grown = union.clone();
            let mut grown_weight = union_weight;
            for rv in clause.iter() {
                // Only RVs new to the union change its probability.
                if !grown.contains(rv) {
                    grown.insert(rv);
                    grown_weight *= Self::rv_weight(base_weights, rv);
                }
            }
            acc += sign * grown_weight;
            // With finite weights, every superset of a zero-weight intersection is also zero.
            if grown_weight != 0.0 {
                acc += Self::inclusion_exclusion(
                    clauses,
                    base_weights,
                    i + 1,
                    &grown,
                    grown_weight,
                    -sign,
                );
            }
        }
        acc
    }

    /// Weight of one base RV; RVs missing from `base_weights` count as certain (`1.0`).
    fn rv_weight(base_weights: &HashMap<BaseRv, f64>, rv: BaseRv) -> f64 {
        base_weights.get(&rv).copied().unwrap_or(1.0)
    }
}
/// Product of the weights of every base RV in `clause`.
///
/// RVs absent from `base_weights` contribute `1.0`. An empty clause therefore has weight
/// `1.0`, the empty product.
fn clause_weight(clause: &BaseRvSet, base_weights: &HashMap<BaseRv, f64>) -> f64 {
    clause
        .iter()
        .map(|member| base_weights.get(&member).copied().unwrap_or(1.0))
        .product()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a set by inserting `rvs` in the given order.
    fn rvset(rvs: &[u32]) -> BaseRvSet {
        let mut s = BaseRvSet::empty();
        for r in rvs {
            s.insert(BaseRv(*r));
        }
        s
    }

    /// Builds a weight map from `(rv, weight)` pairs.
    fn weights(pairs: &[(u32, f64)]) -> HashMap<BaseRv, f64> {
        pairs.iter().map(|(r, w)| (BaseRv(*r), *w)).collect()
    }

    #[test]
    fn rvset_empty_and_single() {
        let e = BaseRvSet::empty();
        assert!(e.is_empty());
        assert_eq!(e.len(), 0);
        let s = BaseRvSet::single(BaseRv(7));
        assert!(!s.is_empty());
        assert_eq!(s.len(), 1);
        assert!(s.contains(BaseRv(7)));
        assert!(!s.contains(BaseRv(8)));
    }

    #[test]
    fn rvset_insert_within_window() {
        let mut s = BaseRvSet::empty();
        s.insert(BaseRv(3));
        s.insert(BaseRv(5));
        // Re-inserting an existing member is a no-op.
        s.insert(BaseRv(3));
        assert_eq!(s.len(), 2);
        assert!(s.contains(BaseRv(3)));
        assert!(s.contains(BaseRv(5)));
        let collected: Vec<u32> = s.iter().map(|r| r.0).collect();
        assert_eq!(collected, vec![3, 5]);
    }

    #[test]
    fn rvset_insert_below_base_keeps_members_sorted() {
        let s = rvset(&[10, 7, 12]);
        assert_eq!(s.len(), 3);
        let collected: Vec<u32> = s.iter().map(|r| r.0).collect();
        assert_eq!(collected, vec![7, 10, 12]);
        assert!(!s.contains(BaseRv(8)));
    }

    #[test]
    fn rvset_insert_outside_window_spills_to_vec() {
        let mut s = BaseRvSet::empty();
        s.insert(BaseRv(0));
        s.insert(BaseRv(100));
        assert!(matches!(s, BaseRvSet::Vec(_)));
        assert_eq!(s.len(), 2);
        assert!(s.contains(BaseRv(0)));
        assert!(s.contains(BaseRv(100)));
    }

    #[test]
    fn rvset_union_and_intersect() {
        let a = rvset(&[1, 3, 5]);
        let b = rvset(&[3, 5, 7]);
        let u = BaseRvSet::union(&a, &b);
        assert_eq!(u.len(), 4);
        for r in [1, 3, 5, 7] {
            assert!(u.contains(BaseRv(r)));
        }
        assert!(BaseRvSet::intersect_any(&a, &b));
        let c = rvset(&[2, 4]);
        assert!(!BaseRvSet::intersect_any(&a, &c));
    }

    #[test]
    fn rvset_intersect_with_different_bases_and_representations() {
        assert!(BaseRvSet::intersect_any(&rvset(&[10, 20]), &rvset(&[20, 30])));
        assert!(!BaseRvSet::intersect_any(&rvset(&[10, 20]), &rvset(&[11, 21])));
        assert!(!BaseRvSet::intersect_any(&rvset(&[0]), &rvset(&[100])));
        assert!(BaseRvSet::intersect_any(&rvset(&[0, 100]), &rvset(&[100])));
        assert!(!BaseRvSet::intersect_any(&BaseRvSet::empty(), &rvset(&[1])));
    }

    #[test]
    fn rvset_subset() {
        let a = rvset(&[1, 3]);
        let b = rvset(&[1, 3, 5]);
        assert!(BaseRvSet::is_subset_of(&a, &b));
        assert!(!BaseRvSet::is_subset_of(&b, &a));
    }

    #[test]
    fn rvset_empty_is_subset_of_everything() {
        assert!(BaseRvSet::is_subset_of(&BaseRvSet::empty(), &rvset(&[4])));
        assert!(BaseRvSet::is_subset_of(&BaseRvSet::empty(), &BaseRvSet::empty()));
    }

    #[test]
    fn dnf_empty_is_zero() {
        let d = DependencyDnf::new();
        assert_eq!(d.weight(&HashMap::new()), 0.0);
    }

    #[test]
    fn dnf_single_empty_clause_is_one() {
        let d = DependencyDnf::single_clause(BaseRvSet::empty());
        assert_eq!(d.weight(&HashMap::new()), 1.0);
    }

    #[test]
    fn dnf_single_rv_clause_returns_rv_weight() {
        let d = DependencyDnf::single_clause(rvset(&[1]));
        let w = weights(&[(1, 0.42)]);
        assert!((d.weight(&w) - 0.42).abs() < 1e-12);
    }

    #[test]
    fn dnf_missing_weight_defaults_to_one() {
        let d = DependencyDnf::single_clause(rvset(&[9]));
        assert_eq!(d.weight(&HashMap::new()), 1.0);
    }

    #[test]
    fn dnf_independent_clauses_match_noisy_or() {
        let d = DependencyDnf {
            clauses: vec![rvset(&[1]), rvset(&[2])],
        };
        let w = weights(&[(1, 0.3), (2, 0.5)]);
        let expected = 1.0 - (1.0 - 0.3) * (1.0 - 0.5);
        assert!(
            (d.weight(&w) - expected).abs() < 1e-12,
            "got {}",
            d.weight(&w)
        );
    }

    #[test]
    fn dnf_shared_rv_corrects_for_overlap() {
        let d = DependencyDnf {
            clauses: vec![rvset(&[1, 2]), rvset(&[1, 3])],
        };
        let w = weights(&[(1, 0.5), (2, 0.4), (3, 0.6)]);
        let p_a = 0.5;
        let p_b = 0.4;
        let p_c = 0.6;
        let expected = p_a * p_b + p_a * p_c - p_a * p_b * p_c;
        assert!((d.weight(&w) - expected).abs() < 1e-12);
        assert!((d.weight(&w) - 0.38).abs() < 1e-12);
    }

    #[test]
    fn dnf_duplicate_and_absorbed_clauses_do_not_double_count() {
        let w = weights(&[(1, 0.3), (2, 0.8)]);
        let duplicated = DependencyDnf {
            clauses: vec![rvset(&[1]), rvset(&[1])],
        };
        assert!((duplicated.weight(&w) - 0.3).abs() < 1e-12);
        let absorbed = DependencyDnf {
            clauses: vec![rvset(&[1]), rvset(&[1, 2])],
        };
        assert!((absorbed.weight(&w) - 0.3).abs() < 1e-12);
    }

    #[test]
    fn dnf_and_with_distributes() {
        let mut a = DependencyDnf {
            clauses: vec![rvset(&[1]), rvset(&[2])],
        };
        let b = DependencyDnf {
            clauses: vec![rvset(&[3]), rvset(&[4])],
        };
        a.and_with(&b);
        assert_eq!(a.clauses.len(), 4);
    }

    #[test]
    fn dnf_and_with_empty_formula_is_false() {
        let mut a = DependencyDnf::single_clause(rvset(&[1]));
        a.and_with(&DependencyDnf::new());
        assert!(a.clauses.is_empty());
        assert_eq!(a.weight(&HashMap::new()), 0.0);
    }

    #[test]
    fn dnf_or_with_concatenates() {
        let mut a = DependencyDnf::single_clause(rvset(&[1]));
        let b = DependencyDnf::single_clause(rvset(&[2]));
        a.or_with(&b);
        assert_eq!(a.clauses.len(), 2);
    }
}