use ordered_float::OrderedFloat;
use crate::semiring::traits::{
CommutativeTimesSemiring, DivisibleSemiring, KClosedSemiring, QuantizableSemiring, Semiring,
StarSemiring, TotallyOrderedSemiring, WeaklyLeftDivisibleSemiring, ZeroSumFreeSemiring,
};
/// A weight in the expectation semiring: the pair `(value, expectation)`.
///
/// `plus` adds the pairs component-wise and `times` applies the product rule
/// `(x1, y1) ⊗ (x2, y2) = (x1·x2, x1·y2 + x2·y1)`. Multiplying per-arc weights
/// built with [`ExpectationWeight::from_probability_and_cost`] therefore yields
/// `(P, P·C)` for a path with probability `P` and additive cost `C`, and summing
/// paths yields `(Σ P, Σ P·C)`.
///
/// Both components are stored as [`OrderedFloat`] so that the type can derive
/// `Eq` and `Hash`, which plain `f64` cannot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExpectationWeight {
    /// First component (the probability mass in the `from_probability*` constructors).
    value: OrderedFloat<f64>,
    /// Second component (probability times cost in `from_probability_and_cost`).
    expectation: OrderedFloat<f64>,
}

impl ExpectationWeight {
    /// Builds a weight directly from its raw `(value, expectation)` components.
    #[inline]
    pub fn new(value: f64, expectation: f64) -> Self {
        Self {
            value: OrderedFloat(value),
            expectation: OrderedFloat(expectation),
        }
    }

    /// Weight with value `prob` and an expectation of exactly `0.0`.
    #[inline]
    pub fn from_probability(prob: f64) -> Self {
        Self::new(prob, 0.0)
    }

    /// Weight `(prob, prob * cost)`: probability `prob` carrying cost `cost`.
    #[inline]
    pub fn from_probability_and_cost(prob: f64, cost: f64) -> Self {
        let weighted_cost = prob * cost;
        Self::new(prob, weighted_cost)
    }

    /// Like [`Self::from_probability_and_cost`], but the probability is given
    /// as a negative log (`prob = exp(-neg_log_prob)`).
    #[inline]
    pub fn from_log_probability_and_cost(neg_log_prob: f64, cost: f64) -> Self {
        Self::from_probability_and_cost(f64::exp(-neg_log_prob), cost)
    }

    /// The first (value) component.
    #[inline]
    pub fn value(&self) -> f64 {
        self.value.into_inner()
    }

    /// The second (expectation) component.
    #[inline]
    pub fn expectation(&self) -> f64 {
        self.expectation.into_inner()
    }

    /// `expectation / value` — the mean cost when the weight was accumulated from
    /// `(p, p·c)` pairs. Returns `None` when `value` is exactly zero.
    #[inline]
    pub fn expected_value(&self) -> Option<f64> {
        let (mass, weighted) = self.components();
        (mass != 0.0).then(|| weighted / mass)
    }

    /// Both components as a plain `(value, expectation)` tuple.
    #[inline]
    pub fn components(&self) -> (f64, f64) {
        (self.value(), self.expectation())
    }
}
impl Default for ExpectationWeight {
#[inline]
fn default() -> Self {
Self::one()
}
}
impl From<(f64, f64)> for ExpectationWeight {
fn from((value, expectation): (f64, f64)) -> Self {
ExpectationWeight::new(value, expectation)
}
}
impl From<ExpectationWeight> for (f64, f64) {
fn from(w: ExpectationWeight) -> Self {
w.components()
}
}
impl Semiring for ExpectationWeight {
    /// Additive identity `(0, 0)`.
    #[inline]
    fn zero() -> Self {
        Self::new(0.0, 0.0)
    }

    /// Multiplicative identity `(1, 0)`.
    #[inline]
    fn one() -> Self {
        Self::new(1.0, 0.0)
    }

    /// Component-wise sum `(x1 + x2, y1 + y2)`.
    #[inline]
    fn plus(&self, other: &Self) -> Self {
        let (x1, y1) = self.components();
        let (x2, y2) = other.components();
        Self::new(x1 + x2, y1 + y2)
    }

    /// Product rule `(x1·x2, x1·y2 + x2·y1)`.
    #[inline]
    fn times(&self, other: &Self) -> Self {
        let (x1, y1) = self.components();
        let (x2, y2) = other.components();
        Self::new(x1 * x2, x1 * y2 + x2 * y1)
    }

    /// Exact (not tolerance-based) test for `(0, 0)`.
    #[inline]
    fn is_zero(&self) -> bool {
        self.components() == (0.0, 0.0)
    }

    /// Exact (not tolerance-based) test for `(1, 0)`.
    #[inline]
    fn is_one(&self) -> bool {
        self.components() == (1.0, 0.0)
    }

    /// True when each component differs from `other`'s by at most `epsilon`
    /// (absolute difference).
    fn approx_eq(&self, other: &Self, epsilon: f64) -> bool {
        let close = |a: f64, b: f64| (a - b).abs() <= epsilon;
        close(self.value(), other.value()) && close(self.expectation(), other.expectation())
    }

    /// A larger value is "less"; when the values are within `1e-10` of each
    /// other, the smaller expectation is "less". The tie-break uses the
    /// `OrderedFloat` comparison of the stored components.
    fn natural_less(&self, other: &Self) -> Option<bool> {
        let values_tie = (self.value() - other.value()).abs() < 1e-10;
        let less = if values_tie {
            self.expectation < other.expectation
        } else {
            self.value > other.value
        };
        Some(less)
    }

    /// 16 bytes: the little-endian `value` followed by the little-endian `expectation`.
    fn to_bytes(&self) -> Vec<u8> {
        let (value, expectation) = self.components();
        let mut out = Vec::with_capacity(16);
        out.extend_from_slice(&value.to_le_bytes());
        out.extend_from_slice(&expectation.to_le_bytes());
        out
    }
}
impl DivisibleSemiring for ExpectationWeight {
    /// Returns the quotient `q` satisfying `q ⊗ other == self`.
    ///
    /// For `self = (x1, y1)` and `other = (x2, y2)`, the quotient is
    /// `(x1 / x2, (y1 − (x1 / x2)·y2) / x2)`. This is algebraically identical to
    /// `(x1 / x2, (y1·x2 − x1·y2) / x2²)`, but it never squares `x2`.
    ///
    /// Returns `None` when `other.value()` is exactly zero, because no quotient
    /// exists then.
    fn divide(&self, other: &Self) -> Option<Self> {
        let (x1, y1) = self.components();
        let (x2, y2) = other.components();
        if x2 == 0.0 {
            return None;
        }
        let a = x1 / x2;
        // Dividing by x2 twice instead of by x2² avoids intermediate underflow or
        // overflow. For example, x2 = 1e-200 would make x2² == 0.0 and turn a
        // well-defined quotient into NaN/∞; path probabilities can easily be
        // that small.
        let b = (y1 - a * y2) / x2;
        Some(ExpectationWeight::new(a, b))
    }
}
impl StarSemiring for ExpectationWeight {
    /// Kleene star in closed form: `(x, y)* = (1 / (1 − x), y / (1 − x)²)`.
    ///
    /// Wherever the geometric series `Σₙ (x, y)ⁿ` with `(x, y)ⁿ = (xⁿ, n·xⁿ⁻¹·y)`
    /// converges, this is its sum. The closed form also satisfies
    /// `w* = 1 ⊕ w ⊗ w*`. Returns `None` when `x >= 1`.
    fn star(&self) -> Option<Self> {
        let (x, y) = self.components();
        // Only `x >= 1` is rejected here (a NaN `x` falls through unchanged).
        if x >= 1.0 {
            return None;
        }
        let gap = 1.0 - x;
        let star_mass = 1.0 / gap;
        Some(Self::new(star_mass, y / (gap * gap)))
    }
}
impl KClosedSemiring for ExpectationWeight {
fn closure_bound() -> Option<usize> {
None
}
}
impl ZeroSumFreeSemiring for ExpectationWeight {}
impl WeaklyLeftDivisibleSemiring for ExpectationWeight {
fn left_divide(&self, divisor: &Self) -> Option<Self> {
self.divide(divisor)
}
}
impl CommutativeTimesSemiring for ExpectationWeight {}
impl TotallyOrderedSemiring for ExpectationWeight {}
impl QuantizableSemiring for ExpectationWeight {
    /// Maps the weight to an integer key by rounding each component to the
    /// nearest multiple of `epsilon`.
    ///
    /// Both components are bucketed at full 64-bit resolution and then mixed.
    /// Weights whose components fall into the same buckets always share a key.
    /// Distinct buckets can only collide through the unavoidable fold of two
    /// 64-bit buckets into one 64-bit key. `NaN` and `±∞` components get
    /// dedicated sentinel buckets.
    fn quantize(&self, epsilon: f64) -> i64 {
        /// Rounds one component to its `epsilon` bucket. Non-finite inputs map to
        /// sentinels; `as` saturates finite results outside the `i64` range.
        fn bucket(x: f64, epsilon: f64) -> i64 {
            if x.is_nan() {
                i64::MIN
            } else if x.is_infinite() {
                if x > 0.0 {
                    i64::MAX / 2
                } else {
                    i64::MIN / 2
                }
            } else {
                (x / epsilon).round() as i64
            }
        }

        // Odd constant (⌊2^64 / φ⌋). Multiplying by an odd number is a bijection
        // mod 2^64, so the value bucket is spread over all 64 bits without
        // losing any of them. Narrowing either bucket has serious side effects:
        // a `shl(32)` drops the high half of the value bucket (and maps the NaN
        // sentinel to 0), and an `as i32` saturates the expectation bucket
        // (|e| > ~0.21 at epsilon = 1e-10). Either way, distinct weights would
        // collide systematically.
        const MIX: u64 = 0x9E37_79B9_7F4A_7C15;

        let qv = bucket(self.value(), epsilon);
        let qe = bucket(self.expectation(), epsilon);
        ((qv as u64).wrapping_mul(MIX) as i64) ^ qe
    }
}
impl std::ops::Add for ExpectationWeight {
type Output = Self;
#[inline]
fn add(self, other: Self) -> Self {
self.plus(&other)
}
}
impl std::ops::Mul for ExpectationWeight {
type Output = Self;
#[inline]
fn mul(self, other: Self) -> Self {
self.times(&other)
}
}
impl std::ops::AddAssign for ExpectationWeight {
#[inline]
fn add_assign(&mut self, other: Self) {
*self = self.plus(&other);
}
}
impl std::ops::MulAssign for ExpectationWeight {
#[inline]
fn mul_assign(&mut self, other: Self) {
*self = self.times(&other);
}
}
/// Always `Some`, delegating to the total order in [`Ord`].
impl PartialOrd for ExpectationWeight {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Lexicographic order: first by `value`, then by `expectation`, both compared
/// as `OrderedFloat`s.
impl Ord for ExpectationWeight {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        (self.value, self.expectation).cmp(&(other.value, other.expectation))
    }
}
/// Serializes as the tuple `(value, expectation)`.
#[cfg(feature = "serde")]
impl serde::Serialize for ExpectationWeight {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Fully qualified call: `serde::Serialize` is not imported into this
        // module, and trait methods only resolve via method-call syntax when the
        // trait is in scope.
        serde::Serialize::serialize(&self.components(), serializer)
    }
}

/// Deserializes from the tuple `(value, expectation)`.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for ExpectationWeight {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Fully qualified for the same reason as in `Serialize` above: the trait
        // is not imported, so `<(f64, f64)>::deserialize` would not resolve.
        let (value, expectation) =
            <(f64, f64) as serde::Deserialize<'de>>::deserialize(deserializer)?;
        Ok(ExpectationWeight::new(value, expectation))
    }
}
/// Unit tests and property tests for [`ExpectationWeight`]. The property tests
/// delegate to the shared `verify_*` checkers in `crate::semiring::traits::tests`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::traits::tests::{
        verify_commutative_times_semiring, verify_divisible_semiring, verify_k_closed_semiring,
        verify_quantizable_semiring, verify_semiring_axioms, verify_star_semiring,
        verify_totally_ordered_semiring, verify_weakly_left_divisible_semiring,
        verify_zero_sum_free_semiring,
    };
    use proptest::prelude::*;

    #[test]
    fn test_basic_operations() {
        let a = ExpectationWeight::new(0.3, 0.6);
        let b = ExpectationWeight::new(0.5, 0.5);
        let sum = a.plus(&b);
        assert!((sum.value() - 0.8).abs() < 1e-10);
        assert!((sum.expectation() - 1.1).abs() < 1e-10);
        let prod = a.times(&b);
        assert!((prod.value() - 0.15).abs() < 1e-10);
        assert!((prod.expectation() - 0.45).abs() < 1e-10);
    }

    #[test]
    fn test_identities() {
        let a = ExpectationWeight::new(0.5, 0.3);
        let sum = a.plus(&ExpectationWeight::zero());
        assert!(a.approx_eq(&sum, 1e-10));
        let prod = a.times(&ExpectationWeight::one());
        assert!(a.approx_eq(&prod, 1e-10));
        let prod2 = ExpectationWeight::one().times(&a);
        assert!(a.approx_eq(&prod2, 1e-10));
    }

    #[test]
    fn test_annihilation() {
        let a = ExpectationWeight::new(0.5, 0.3);
        let prod = a.times(&ExpectationWeight::zero());
        assert!(prod.is_zero());
        let prod2 = ExpectationWeight::zero().times(&a);
        assert!(prod2.is_zero());
    }

    #[test]
    fn test_expected_value() {
        let path1 = ExpectationWeight::from_probability_and_cost(0.3, 2.0);
        assert!((path1.value() - 0.3).abs() < 1e-10);
        assert!((path1.expectation() - 0.6).abs() < 1e-10);
        let path2 = ExpectationWeight::from_probability_and_cost(0.5, 1.0);
        let total = path1.plus(&path2);
        let expected = total.expected_value().expect("Non-zero total");
        assert!(
            (expected - 1.375).abs() < 1e-10,
            "Expected 1.375, got {}",
            expected
        );
    }

    #[test]
    fn test_division() {
        let a = ExpectationWeight::new(0.3, 0.6);
        let b = ExpectationWeight::new(0.5, 0.4);
        let product = a.times(&b);
        let quotient = product.divide(&b).expect("Division should succeed");
        assert!(
            a.approx_eq(&quotient, 1e-10),
            "Division inverse failed: ({}, {}) * ({}, {}) / ({}, {}) = ({}, {}), expected ({}, {})",
            a.value(),
            a.expectation(),
            b.value(),
            b.expectation(),
            b.value(),
            b.expectation(),
            quotient.value(),
            quotient.expectation(),
            a.value(),
            a.expectation()
        );
        assert!(a.divide(&ExpectationWeight::zero()).is_none());
    }

    #[test]
    fn test_star() {
        let half = ExpectationWeight::new(0.5, 0.2);
        let star = half.star().expect("Star should converge for x < 1");
        assert!((star.value() - 2.0).abs() < 1e-10);
        assert!((star.expectation() - 0.8).abs() < 1e-10);
        let one_plus_w_star = ExpectationWeight::one().plus(&half.times(&star));
        assert!(
            star.approx_eq(&one_plus_w_star, 1e-6),
            "Star axiom failed: ({}, {}) ≠ 1 ⊕ (w ⊗ star) = ({}, {})",
            star.value(),
            star.expectation(),
            one_plus_w_star.value(),
            one_plus_w_star.expectation()
        );
        assert!(ExpectationWeight::one().star().is_none());
        assert!(ExpectationWeight::new(1.5, 0.1).star().is_none());
    }

    #[test]
    fn test_multiplicative_identity_property() {
        let one = ExpectationWeight::one();
        let a = ExpectationWeight::new(0.3, 0.5);
        let prod1 = one.times(&a);
        assert!(prod1.approx_eq(&a, 1e-10));
        let prod2 = a.times(&one);
        assert!(prod2.approx_eq(&a, 1e-10));
    }

    #[test]
    fn test_sequential_costs() {
        let e1 = ExpectationWeight::from_probability_and_cost(0.5, 2.0);
        let e2 = ExpectationWeight::from_probability_and_cost(0.4, 3.0);
        let path = e1.times(&e2);
        assert!((path.value() - 0.2).abs() < 1e-10);
        assert!((path.expectation() - 1.0).abs() < 1e-10);
        let expected = path.expected_value().expect("Non-zero path");
        assert!(
            (expected - 5.0).abs() < 1e-10,
            "Expected cost 5, got {}",
            expected
        );
    }

    proptest! {
        #[test]
        fn proptest_semiring_axioms(
            v1 in 0.0f64..10.0,
            e1 in -10.0f64..10.0,
            v2 in 0.0f64..10.0,
            e2 in -10.0f64..10.0,
            v3 in 0.0f64..10.0,
            e3 in -10.0f64..10.0
        ) {
            let wa = ExpectationWeight::new(v1, e1);
            let wb = ExpectationWeight::new(v2, e2);
            let wc = ExpectationWeight::new(v3, e3);
            verify_semiring_axioms(wa, wb, wc, 1e-6);
        }

        #[test]
        fn proptest_divisible_semiring(
            v1 in 0.0f64..10.0,
            e1 in -10.0f64..10.0,
            v2 in 0.001f64..10.0,
            e2 in -10.0f64..10.0
        ) {
            let wa = ExpectationWeight::new(v1, e1);
            let wb = ExpectationWeight::new(v2, e2);
            verify_divisible_semiring(wa, wb, 1e-6);
        }

        #[test]
        fn proptest_star_semiring(
            v in 0.001f64..0.999,
            e in -10.0f64..10.0
        ) {
            let w = ExpectationWeight::new(v, e);
            verify_star_semiring(w, 1e-4);
        }

        #[test]
        fn proptest_k_closed_semiring(
            v in 0.0f64..10.0,
            e in -10.0f64..10.0
        ) {
            let w = ExpectationWeight::new(v, e);
            verify_k_closed_semiring(w, 1e-6);
        }

        #[test]
        fn proptest_zero_sum_free_semiring(
            v1 in 0.0f64..10.0,
            e1 in 0.0f64..10.0,
            v2 in 0.0f64..10.0,
            e2 in 0.0f64..10.0
        ) {
            let wa = ExpectationWeight::new(v1, e1);
            let wb = ExpectationWeight::new(v2, e2);
            verify_zero_sum_free_semiring(wa, wb, 1e-6);
        }

        #[test]
        fn proptest_weakly_left_divisible_semiring(
            v1 in 0.0f64..10.0,
            e1 in -10.0f64..10.0,
            v2 in 0.0f64..10.0,
            e2 in -10.0f64..10.0
        ) {
            let wa = ExpectationWeight::new(v1, e1);
            let wb = ExpectationWeight::new(v2, e2);
            verify_weakly_left_divisible_semiring(wa, wb, 1e-6);
        }

        #[test]
        fn proptest_commutative_times_semiring(
            v1 in 0.0f64..10.0,
            e1 in -10.0f64..10.0,
            v2 in 0.0f64..10.0,
            e2 in -10.0f64..10.0
        ) {
            let wa = ExpectationWeight::new(v1, e1);
            let wb = ExpectationWeight::new(v2, e2);
            verify_commutative_times_semiring(wa, wb, 1e-6);
        }

        #[test]
        fn proptest_totally_ordered_semiring(
            v1 in 0.0f64..10.0,
            e1 in -10.0f64..10.0,
            v2 in 0.0f64..10.0,
            e2 in -10.0f64..10.0,
            v3 in 0.0f64..10.0,
            e3 in -10.0f64..10.0
        ) {
            let wa = ExpectationWeight::new(v1, e1);
            let wb = ExpectationWeight::new(v2, e2);
            let wc = ExpectationWeight::new(v3, e3);
            verify_totally_ordered_semiring(wa, wb, wc);
        }

        #[test]
        fn proptest_quantizable_semiring(
            v in 0.0f64..10.0,
            e in -10.0f64..10.0
        ) {
            let wa = ExpectationWeight::new(v, e);
            verify_quantizable_semiring(wa, 1e-10);
        }
    }

    #[test]
    fn test_k_closed_bound() {
        assert_eq!(ExpectationWeight::closure_bound(), None);
    }
}