use crate::types::SemiringKind;
/// Errors produced by semiring operations and probability-domain validation.
#[derive(Debug, Clone, PartialEq)]
pub enum SemiringError {
    /// A probability input was outside `[0, 1]` and was rejected instead of clamped
    /// (see `validate_probability_domain`).
    DomainViolation {
        /// The offending raw input.
        value: f64,
        /// Name of the operator that received the input.
        op: &'static str,
    },
    /// The semiring `kind` does not support the operation `op`.
    NotSupported {
        /// Name of the unsupported operation.
        op: &'static str,
        /// The semiring that rejected the operation.
        kind: SemiringKind,
    },
}
impl std::fmt::Display for SemiringError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::DomainViolation { value, op } => write!(
f,
"strict_probability_domain: {op} input {value} is outside [0, 1]"
),
Self::NotSupported { op, kind } => {
write!(f, "semiring {kind:?} does not support {op}")
}
}
}
}
impl std::error::Error for SemiringError {}
/// Checks that `raw` is a probability in `[0, 1]`.
///
/// In-range values are returned unchanged. An out-of-range value is rejected with
/// [`SemiringError::DomainViolation`] when `strict` is `true`. Otherwise it is clamped
/// into `[0, 1]` and a warning naming `op` is logged.
///
/// `NaN` is never considered in range. Strict mode rejects it, and lenient mode maps it
/// to `0.0`, because `f64::clamp` passes `NaN` through unchanged and would hand callers
/// a value outside the domain this function promises.
///
/// # Errors
///
/// Returns [`SemiringError::DomainViolation`] if `strict` is `true` and `raw` is not in
/// `[0, 1]` (including `NaN`).
pub fn validate_probability_domain(
    raw: f64,
    op: &'static str,
    strict: bool,
) -> Result<f64, SemiringError> {
    if (0.0..=1.0).contains(&raw) {
        return Ok(raw);
    }
    if strict {
        return Err(SemiringError::DomainViolation { value: raw, op });
    }
    // `clamp` maps ±inf to the nearest bound but propagates NaN, so NaN needs its own
    // case. 0.0 is the disjunction identity of the probability semirings.
    let clamped = if raw.is_nan() {
        0.0
    } else {
        raw.clamp(0.0, 1.0)
    };
    tracing::warn!("{op} input {raw} outside [0,1], clamped to {clamped}");
    Ok(clamped)
}
/// Semiring settings after resolution, bundled as a small `Copy` value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ResolvedSemiringConfig {
    /// Which semiring to evaluate with.
    pub kind: SemiringKind,
    /// Strict-domain flag; presumably forwarded as `strict` to the domain check
    /// (NOTE(review): the consumer is not visible here — confirm).
    pub strict_probability_domain: bool,
    /// Epsilon handed to the probability semirings at construction.
    pub probability_epsilon: f64,
    /// BDD variable limit (consumed outside this file).
    pub max_bdd_variables: usize,
}

impl ResolvedSemiringConfig {
    /// `true` exactly when `kind` is `SemiringKind::AddMultProb`.
    pub fn is_add_mult_prob(&self) -> bool {
        let kind = self.kind;
        matches!(kind, SemiringKind::AddMultProb)
    }
}
/// Tag algebra used to combine derivations.
///
/// `plus` is the disjunction, whose neutral element is `zero_disjunction`. `times` is
/// the conjunction, whose neutral element is `one_conjunction`. Implementors must be
/// `Send + Sync + 'static`.
pub trait LocySemiring: Send + Sync + 'static {
    /// Value attached to each derivation.
    type Tag: Clone + Send + Sync;

    /// The [`SemiringKind`] this implementation represents.
    fn kind(&self) -> SemiringKind;

    /// Row-at-a-time evaluation hint; defaults to `true`.
    // NOTE(review): how the evaluator interprets this flag is not visible here.
    fn is_row_at_a_time(&self) -> bool {
        true
    }

    /// Neutral element of [`LocySemiring::plus`].
    fn zero_disjunction(&self) -> Self::Tag;

    /// Neutral element of [`LocySemiring::times`].
    fn one_conjunction(&self) -> Self::Tag;

    /// Disjunction of two tags.
    fn plus(&self, a: &Self::Tag, b: &Self::Tag) -> Self::Tag;

    /// Conjunction of two tags.
    fn times(&self, a: &Self::Tag, b: &Self::Tag) -> Self::Tag;

    /// Negation of a tag; may fail for semirings that cannot express it.
    fn negate(&self, a: &Self::Tag) -> Result<Self::Tag, SemiringError>;

    /// Projects a tag onto a scalar weight.
    fn weight(&self, a: &Self::Tag) -> f64;

    /// Validates (or, when `strict` is `false`, may repair) a raw input for operator `op`.
    fn validate_domain(
        &self,
        raw: f64,
        op: &'static str,
        strict: bool,
    ) -> Result<f64, SemiringError>;
}
/// Probability semiring with a noisy-or `plus` and a multiplicative `times`.
#[derive(Debug, Clone, Copy)]
pub struct AddMultProb {
    /// Epsilon supplied at construction (`1e-15` via `Default`).
    pub probability_epsilon: f64,
}

impl AddMultProb {
    /// Creates the semiring with the given `probability_epsilon`.
    pub fn new(probability_epsilon: f64) -> Self {
        AddMultProb { probability_epsilon }
    }
}

impl Default for AddMultProb {
    /// Equivalent to `AddMultProb::new(1e-15)`.
    fn default() -> Self {
        AddMultProb::new(1e-15)
    }
}
impl LocySemiring for AddMultProb {
    type Tag = f64;

    fn kind(&self) -> SemiringKind {
        SemiringKind::AddMultProb
    }

    /// `0.0`, the neutral element of noisy-or (`plus(0, p) == p`) and the
    /// annihilator of `times` (`times(0, p) == 0`).
    fn zero_disjunction(&self) -> f64 {
        0.0
    }

    /// `1.0`, the neutral element of the product (`times(1, p) == p`).
    fn one_conjunction(&self) -> f64 {
        1.0
    }

    /// Noisy-or: `1 - (1 - a)(1 - b)`, the probability that at least one of two
    /// independent events occurs.
    fn plus(&self, a: &f64, b: &f64) -> f64 {
        1.0 - (1.0 - *a) * (1.0 - *b)
    }

    /// Product `a * b`.
    ///
    /// A product of two values in `[0, 1]` cannot overflow. It only underflows below the
    /// smallest subnormal (~5e-324), where a log-space `exp(ln a + ln b)` underflows
    /// identically, so the product is computed directly.
    ///
    /// Inputs are deliberately not floored at `probability_epsilon`. Flooring made
    /// `times(0.0, p)` return `epsilon * p` and `times(1.0, 1e-16)` return `1e-15`,
    /// which breaks the annihilator and identity laws of the semiring.
    fn times(&self, a: &f64, b: &f64) -> f64 {
        *a * *b
    }

    /// Complement `1 - a`; never fails for this semiring.
    fn negate(&self, a: &f64) -> Result<f64, SemiringError> {
        Ok(1.0 - *a)
    }

    /// The tag is already the probability.
    fn weight(&self, a: &f64) -> f64 {
        *a
    }

    /// Delegates to `validate_probability_domain`.
    fn validate_domain(
        &self,
        raw: f64,
        op: &'static str,
        strict: bool,
    ) -> Result<f64, SemiringError> {
        validate_probability_domain(raw, op, strict)
    }
}
/// (max, min) probability semiring: `plus` keeps the larger value, `times` the smaller.
#[derive(Debug, Clone, Copy, Default)]
pub struct MaxMinProb;

impl LocySemiring for MaxMinProb {
    type Tag = f64;

    fn kind(&self) -> SemiringKind {
        SemiringKind::MaxMinProb
    }

    /// `0.0`, neutral for `max` over `[0, 1]`.
    fn zero_disjunction(&self) -> f64 {
        0.0
    }

    /// `1.0`, neutral for `min` over `[0, 1]`.
    fn one_conjunction(&self) -> f64 {
        1.0
    }

    /// Larger operand, via `f64::max`.
    fn plus(&self, a: &f64, b: &f64) -> f64 {
        f64::max(*a, *b)
    }

    /// Smaller operand, via `f64::min`.
    fn times(&self, a: &f64, b: &f64) -> f64 {
        f64::min(*a, *b)
    }

    /// Complement `1 - a`; never fails.
    fn negate(&self, a: &f64) -> Result<f64, SemiringError> {
        let complement = 1.0 - a;
        Ok(complement)
    }

    /// The tag is already the weight.
    fn weight(&self, a: &f64) -> f64 {
        *a
    }

    /// Same `[0, 1]` domain rule as the other probability semiring.
    fn validate_domain(
        &self,
        raw: f64,
        op: &'static str,
        strict: bool,
    ) -> Result<f64, SemiringError> {
        validate_probability_domain(raw, op, strict)
    }
}
/// Runtime-selected semiring, dispatched with `match` rather than a trait object.
#[derive(Debug, Clone, Copy)]
pub enum SemiringDispatch {
    /// Noisy-or / product probabilities (also selected for `BddExact`).
    AddMultProb(AddMultProb),
    /// Max / min probabilities.
    MaxMinProb(MaxMinProb),
    /// Proof-set tags. `inner` supplies the scalar `plus`/`times`, and `k` is the proof
    /// budget passed to the top-k merge.
    TopKProofs { inner: AddMultProb, k: u32 },
}
impl SemiringDispatch {
pub fn new(kind: SemiringKind, probability_epsilon: f64) -> Self {
match kind {
SemiringKind::AddMultProb | SemiringKind::BddExact => {
Self::AddMultProb(AddMultProb::new(probability_epsilon))
}
SemiringKind::MaxMinProb => Self::MaxMinProb(MaxMinProb),
SemiringKind::TopKProofs { k } => {
tracing::warn!(
"TopKProofs(k={k}) runtime tag flow pending Stage 2 — \
falling back to AddMultProb row math; library-layer \
TopKProofs<K> math is available via uni_locy::top_k_proofs"
);
Self::TopKProofs {
inner: AddMultProb::new(probability_epsilon),
k,
}
}
}
}
pub fn kind(&self) -> SemiringKind {
match self {
Self::AddMultProb(sr) => sr.kind(),
Self::MaxMinProb(sr) => sr.kind(),
Self::TopKProofs { inner, .. } => inner.kind(),
}
}
pub fn top_k(&self) -> Option<u32> {
match self {
Self::TopKProofs { k, .. } => Some(*k),
_ => None,
}
}
pub fn plus(&self, a: f64, b: f64) -> f64 {
match self {
Self::AddMultProb(sr) => sr.plus(&a, &b),
Self::MaxMinProb(sr) => sr.plus(&a, &b),
Self::TopKProofs { inner, .. } => inner.plus(&a, &b),
}
}
pub fn times(&self, a: f64, b: f64) -> f64 {
match self {
Self::AddMultProb(sr) => sr.times(&a, &b),
Self::MaxMinProb(sr) => sr.times(&a, &b),
Self::TopKProofs { inner, .. } => inner.times(&a, &b),
}
}
pub fn validate_domain(
&self,
raw: f64,
op: &'static str,
strict: bool,
) -> Result<f64, SemiringError> {
match self {
Self::AddMultProb(sr) => sr.validate_domain(raw, op, strict),
Self::MaxMinProb(sr) => sr.validate_domain(raw, op, strict),
Self::TopKProofs { inner, .. } => inner.validate_domain(raw, op, strict),
}
}
pub fn plus_tag(
&self,
a: &AggregatorValue,
b: &AggregatorValue,
) -> (AggregatorValue, Option<crate::top_k_proofs::PruneNotice>) {
match (self, a, b) {
(Self::AddMultProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
(AggregatorValue::F64(sr.plus(x, y)), None)
}
(Self::MaxMinProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
(AggregatorValue::F64(sr.plus(x, y)), None)
}
(Self::TopKProofs { k, .. }, AggregatorValue::TopK(ta), AggregatorValue::TopK(tb)) => {
let (proofs, notice) = merge_top_k_dispatch(ta, tb, *k as usize);
(
AggregatorValue::TopK(crate::top_k_proofs::TopKTag { proofs }),
Some(notice),
)
}
_ => unreachable!(
"SemiringDispatch::plus_tag: type mismatch — dispatch {:?} vs ({:?}, {:?})",
self.kind(),
std::mem::discriminant(a),
std::mem::discriminant(b),
),
}
}
pub fn times_tag(
&self,
a: &AggregatorValue,
b: &AggregatorValue,
) -> (AggregatorValue, Option<crate::top_k_proofs::PruneNotice>) {
match (self, a, b) {
(Self::AddMultProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
(AggregatorValue::F64(sr.times(x, y)), None)
}
(Self::MaxMinProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
(AggregatorValue::F64(sr.times(x, y)), None)
}
(Self::TopKProofs { k, .. }, AggregatorValue::TopK(ta), AggregatorValue::TopK(tb)) => {
if ta.proofs.is_empty() || tb.proofs.is_empty() {
return (
AggregatorValue::TopK(crate::top_k_proofs::TopKTag::zero()),
None,
);
}
let mut cart: Vec<crate::top_k_proofs::Proof> =
Vec::with_capacity(ta.proofs.len() * tb.proofs.len());
for pa in &ta.proofs {
for pb in &tb.proofs {
let mut nc = pa.neural_calls.clone();
let existing: std::collections::HashSet<u32> =
pa.neural_calls.iter().map(|c| c.0).collect();
for c in &pb.neural_calls {
if !existing.contains(&c.0) {
nc.push(*c);
}
}
cart.push(crate::top_k_proofs::Proof {
weight: pa.weight * pb.weight,
base_rvs: crate::dependency_dnf::BaseRvSet::union(
&pa.base_rvs,
&pb.base_rvs,
),
neural_calls: nc,
});
}
}
let (proofs, notice) = merge_top_k_dispatch_owned(Vec::new(), cart, *k as usize);
(
AggregatorValue::TopK(crate::top_k_proofs::TopKTag { proofs }),
Some(notice),
)
}
_ => unreachable!(
"SemiringDispatch::times_tag: type mismatch — dispatch {:?} vs ({:?}, {:?})",
self.kind(),
std::mem::discriminant(a),
std::mem::discriminant(b),
),
}
}
pub fn zero_tag(&self) -> AggregatorValue {
match self {
Self::AddMultProb(_) | Self::MaxMinProb(_) => AggregatorValue::F64(0.0),
Self::TopKProofs { .. } => AggregatorValue::TopK(crate::top_k_proofs::TopKTag::zero()),
}
}
pub fn singleton_tag(
&self,
weight: f64,
base_rvs: crate::dependency_dnf::BaseRvSet,
neural_calls: Vec<crate::top_k_proofs::NeuralCallId>,
) -> AggregatorValue {
match self {
Self::AddMultProb(_) | Self::MaxMinProb(_) => AggregatorValue::F64(weight),
Self::TopKProofs { .. } => AggregatorValue::TopK(crate::top_k_proofs::TopKTag {
proofs: vec![crate::top_k_proofs::Proof {
weight,
base_rvs,
neural_calls,
}],
}),
}
}
pub fn weight_of(&self, value: &AggregatorValue) -> f64 {
match (self, value) {
(Self::AddMultProb(_) | Self::MaxMinProb(_), AggregatorValue::F64(v)) => *v,
(Self::TopKProofs { .. }, AggregatorValue::TopK(t)) => {
let mut complement = 1.0;
for p in &t.proofs {
complement *= 1.0 - p.weight;
}
(1.0 - complement).clamp(0.0, 1.0)
}
_ => unreachable!(
"SemiringDispatch::weight_of: type mismatch — dispatch {:?} vs {:?}",
self.kind(),
std::mem::discriminant(value),
),
}
}
}
/// A semiring tag as carried at runtime: a bare scalar for the scalar semirings, or a
/// proof-set tag for `TopKProofs`.
#[derive(Debug, Clone)]
pub enum AggregatorValue {
    /// Scalar tag.
    F64(f64),
    /// Proof-set tag.
    TopK(crate::top_k_proofs::TopKTag),
}

impl AggregatorValue {
    /// Shorthand constructor for [`AggregatorValue::F64`].
    pub fn f64(v: f64) -> Self {
        Self::F64(v)
    }
}
/// Borrowing front end to [`merge_top_k_dispatch_owned`].
///
/// Copies the proof lists of `a` and `b`, then merges them with proof budget `k`.
fn merge_top_k_dispatch(
    a: &crate::top_k_proofs::TopKTag,
    b: &crate::top_k_proofs::TopKTag,
    k: usize,
) -> (
    Vec<crate::top_k_proofs::Proof>,
    crate::top_k_proofs::PruneNotice,
) {
    let left = a.proofs.to_vec();
    let right = b.proofs.to_vec();
    merge_top_k_dispatch_owned(left, right, k)
}
pub fn merge_top_k_dispatch_owned(
base: Vec<crate::top_k_proofs::Proof>,
additional: Vec<crate::top_k_proofs::Proof>,
k: usize,
) -> (
Vec<crate::top_k_proofs::Proof>,
crate::top_k_proofs::PruneNotice,
) {
crate::top_k_proofs::merge_top_k_with(base, additional, k)
}
impl Default for SemiringDispatch {
fn default() -> Self {
Self::AddMultProb(AddMultProb::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Folding noisy-or over the fixture must reproduce the pre-refactor value.
    #[test]
    fn add_mult_prob_matches_pre_refactor_noisy_or() {
        let sr = AddMultProb::default();
        let acc = [0.72, 0.54, 0.56, 0.42]
            .iter()
            .fold(sr.zero_disjunction(), |total, p| sr.plus(&total, p));
        assert!((acc - 0.967_130_24).abs() < 1e-9, "got {acc}");
    }

    /// Products of very small probabilities must stay finite and non-negative.
    #[test]
    fn add_mult_prob_product_underflow_safe() {
        let product = AddMultProb::new(1e-12).times(&1e-20, &1e-20);
        assert!(product.is_finite());
        assert!(product >= 0.0);
    }

    #[test]
    fn max_min_prob_viterbi() {
        let sr = MaxMinProb;
        assert_eq!(sr.plus(&0.3, &0.7), 0.7);
        assert_eq!(sr.times(&0.3, &0.7), 0.3);
    }

    #[test]
    fn strict_domain_violation() {
        let sr = AddMultProb::default();
        let strict = sr.validate_domain(1.5, "MNOR", true);
        assert!(matches!(strict, Err(SemiringError::DomainViolation { .. })));
        let lenient = sr.validate_domain(1.5, "MNOR", false);
        assert_eq!(lenient.unwrap(), 1.0);
    }

    #[test]
    fn max_min_prob_strict_domain_violation() {
        let sr = MaxMinProb;
        assert!(matches!(
            sr.validate_domain(-0.1, "MPROD", true),
            Err(SemiringError::DomainViolation { .. })
        ));
        for (raw, op, expected) in [(-0.1, "MPROD", 0.0), (2.0, "MNOR", 1.0)] {
            assert_eq!(sr.validate_domain(raw, op, false).unwrap(), expected);
        }
    }

    #[test]
    fn identities_are_correct() {
        let add = AddMultProb::default();
        let max = MaxMinProb;
        assert_eq!((add.zero_disjunction(), add.one_conjunction()), (0.0, 1.0));
        assert_eq!((max.zero_disjunction(), max.one_conjunction()), (0.0, 1.0));
    }

    #[test]
    fn dispatch_routes_to_correct_impl() {
        let add = SemiringDispatch::new(SemiringKind::AddMultProb, 1e-15);
        let max = SemiringDispatch::new(SemiringKind::MaxMinProb, 1e-15);
        assert_eq!(add.plus(0.3, 0.5), 1.0 - 0.7 * 0.5);
        assert_eq!(max.plus(0.3, 0.5), 0.5);
        assert_eq!(add.times(0.3, 0.5), 0.15);
        assert_eq!(max.times(0.3, 0.5), 0.3);

        // BddExact is served by the AddMultProb row math.
        let bdd = SemiringDispatch::new(SemiringKind::BddExact, 1e-15);
        assert_eq!(bdd.kind(), SemiringKind::AddMultProb);
        assert_eq!(bdd.plus(0.3, 0.5), 1.0 - 0.7 * 0.5);
    }
}