use ordered_float::OrderedFloat;
use crate::semiring::traits::{
CommutativeTimesSemiring, DivisibleSemiring, KClosedSemiring, NonnegativeSemiring,
NumericalWeight, QuantizableSemiring, Semiring, StarSemiring, StochasticSemiring,
TotallyOrderedSemiring, ZeroSumFreeSemiring,
};
/// η used by [`PowerWeight::with_default_eta`], `From<f64>`, and the η-less
/// semiring constructors `zero()` / `one()`.
///
/// With η = 1, ⊕ reduces to ordinary addition and ⊗ to multiplication, i.e.
/// the probability semiring.
pub const DEFAULT_ETA: f64 = 1_f64;
/// A weight in the power semiring parameterised by an exponent η.
///
/// A stored value `x` corresponds to the probability `x^{1/η}` (see
/// `PowerWeight::to_probability`), with `x ⊕ y = (x^{1/η} + y^{1/η})^η` and
/// `x ⊗ y = x · y`.
#[derive(Debug, Clone, Copy)]
pub struct PowerWeight {
    /// The stored (η-powered) value; `PowerWeight::new` clamps it to be non-negative.
    value: OrderedFloat<f64>,
    /// The exponent η; `PowerWeight::new` debug-asserts that it is positive.
    eta: OrderedFloat<f64>,
}
impl PowerWeight {
#[inline]
pub fn new(value: f64, eta: f64) -> Self {
debug_assert!(eta > 0.0, "η must be positive, got {}", eta);
Self {
value: OrderedFloat(value.max(0.0)),
eta: OrderedFloat(eta),
}
}
#[inline]
pub fn with_default_eta(value: f64) -> Self {
Self::new(value, DEFAULT_ETA)
}
#[inline]
pub fn value(&self) -> f64 {
self.value.into_inner()
}
#[inline]
pub fn eta(&self) -> f64 {
self.eta.into_inner()
}
#[inline]
pub fn zero_with_eta(eta: f64) -> Self {
Self::new(0.0, eta)
}
#[inline]
pub fn one_with_eta(eta: f64) -> Self {
Self::new(1.0, eta)
}
#[inline]
pub fn infinity(eta: f64) -> Self {
Self::new(f64::INFINITY, eta)
}
#[inline]
pub fn is_zero_value(&self) -> bool {
self.value.into_inner() == 0.0
}
#[inline]
pub fn is_one_value(&self) -> bool {
(self.value.into_inner() - 1.0).abs() < f64::EPSILON
}
#[inline]
pub fn is_infinite(&self) -> bool {
self.value.is_infinite()
}
#[inline]
pub fn from_probability(prob: f64, eta: f64) -> Self {
Self::new(prob.powf(eta), eta)
}
#[inline]
pub fn to_probability(&self) -> f64 {
let eta = self.eta.into_inner();
if eta == 0.0 {
let v = self.value.into_inner();
if v < 1.0 {
0.0
} else if v == 1.0 {
1.0
} else {
f64::INFINITY
}
} else {
self.value.powf(1.0 / eta)
}
}
#[inline]
fn power_plus(&self, other: &Self) -> Self {
let eta = self.eta.into_inner();
if self.is_zero_value() {
return Self::new(other.value.into_inner(), eta);
}
if other.is_zero_value() {
return Self::new(self.value.into_inner(), eta);
}
if self.is_infinite() || other.is_infinite() {
return Self::infinity(eta);
}
let x_root = self.value.powf(1.0 / eta);
let y_root = other.value.powf(1.0 / eta);
let sum = x_root + y_root;
Self::new(sum.powf(eta), eta)
}
#[inline]
fn check_eta_compatibility(&self, other: &Self) {
debug_assert!(
(self.eta.into_inner() - other.eta.into_inner()).abs() < 1e-10,
"η values must match: {} vs {}",
self.eta,
other.eta
);
}
}
/// Exact equality on both the stored value and η.
///
/// Compares the `OrderedFloat` fields directly, so the relation is a true
/// equivalence, as `Eq` requires. A subtraction-based tolerance test would
/// report `+∞ != +∞` (since `∞ - ∞` is `NaN`) and would not be transitive.
/// It also agrees with the `Hash` and `Ord` impls below. For tolerance-based
/// comparison use `Semiring::approx_eq`.
impl PartialEq for PowerWeight {
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value && self.eta == other.eta
    }
}

impl Eq for PowerWeight {}

/// Hashes exactly the fields that `eq` compares.
///
/// It uses `OrderedFloat`'s own `Hash`, which is consistent with its `Eq`,
/// so `a == b` implies `hash(a) == hash(b)`.
impl std::hash::Hash for PowerWeight {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.value, state);
        std::hash::Hash::hash(&self.eta, state);
    }
}

impl PartialOrd for PowerWeight {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders by stored value and breaks ties on η.
///
/// As a result, `cmp` returns `Equal` exactly when `eq` returns `true`.
impl Ord for PowerWeight {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.value
            .cmp(&other.value)
            .then_with(|| self.eta.cmp(&other.eta))
    }
}
impl Default for PowerWeight {
#[inline]
fn default() -> Self {
Self::one()
}
}
impl From<f64> for PowerWeight {
fn from(value: f64) -> Self {
Self::with_default_eta(value)
}
}
impl From<PowerWeight> for f64 {
fn from(weight: PowerWeight) -> Self {
weight.value()
}
}
/// Power-semiring operations: `x ⊕ y = (x^{1/η} + y^{1/η})^η`, `x ⊗ y = x · y`.
impl Semiring for PowerWeight {
    /// Additive identity `0̄` (value `0`, η = [`DEFAULT_ETA`]).
    #[inline]
    fn zero() -> Self {
        Self::zero_with_eta(DEFAULT_ETA)
    }

    /// Multiplicative identity `1̄` (value `1`, η = [`DEFAULT_ETA`]).
    #[inline]
    fn one() -> Self {
        Self::one_with_eta(DEFAULT_ETA)
    }

    /// `⊕`: debug-checks η compatibility, then delegates to `power_plus`.
    #[inline]
    fn plus(&self, other: &Self) -> Self {
        self.check_eta_compatibility(other);
        self.power_plus(other)
    }

    /// `⊗`: ordinary multiplication of the stored values.
    ///
    /// The product normally carries `self`'s η. The values `0`, `1` and `+∞`
    /// are fixed points of `x ↦ x^η`, so they carry no real η. When only
    /// `self` is such a value, `other`'s η is kept instead. Otherwise
    /// `Self::one().times(&w)` would relabel `w` with [`DEFAULT_ETA`], and a
    /// later `⊕` with weights of `w`'s η would compute the wrong value.
    #[inline]
    fn times(&self, other: &Self) -> Self {
        self.check_eta_compatibility(other);
        let lhs = self.value.into_inner();
        let rhs = other.value.into_inner();
        let eta_invariant = |v: f64| v == 0.0 || v == 1.0 || v.is_infinite();
        let eta = if eta_invariant(lhs) && !eta_invariant(rhs) {
            other.eta
        } else {
            self.eta
        };
        Self::new(lhs * rhs, eta.into_inner())
    }

    /// Returns `true` if the stored value is exactly `0`.
    #[inline]
    fn is_zero(&self) -> bool {
        self.is_zero_value()
    }

    /// Returns `true` if the stored value is within `f64::EPSILON` of `1`.
    #[inline]
    fn is_one(&self) -> bool {
        self.is_one_value()
    }

    /// Tolerance comparison using an absolute `epsilon` on both value and η.
    ///
    /// Two zeros, or two infinities, compare equal regardless of η. A zero or
    /// infinity never matches a finite non-zero weight.
    fn approx_eq(&self, other: &Self, epsilon: f64) -> bool {
        if self.is_zero() && other.is_zero() {
            return true;
        }
        if self.is_infinite() && other.is_infinite() {
            return true;
        }
        if self.is_zero() || other.is_zero() || self.is_infinite() || other.is_infinite() {
            return false;
        }
        (self.value.into_inner() - other.value.into_inner()).abs() <= epsilon
            && (self.eta.into_inner() - other.eta.into_inner()).abs() <= epsilon
    }

    /// Natural order: compares the stored values only.
    fn natural_less(&self, other: &Self) -> Option<bool> {
        Some(self.value < other.value)
    }

    /// 16-byte encoding: value then η, each as little-endian `f64`.
    fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(16);
        bytes.extend_from_slice(&self.value.into_inner().to_le_bytes());
        bytes.extend_from_slice(&self.eta.into_inner().to_le_bytes());
        bytes
    }
}
/// Division undoes `⊗` (plain multiplication) by dividing the stored values.
impl DivisibleSemiring for PowerWeight {
    /// Returns `self ÷ other`, or `None` when `other` is `0̄`.
    ///
    /// Like `times`, the result keeps `other`'s η when only `self` holds an
    /// η-invariant value (`0`, `1` or `+∞`). This way `Self::one().divide(&w)`
    /// does not relabel the quotient with [`DEFAULT_ETA`].
    ///
    /// NOTE(review): `+∞ ÷ +∞` evaluates to `NaN`, which `PowerWeight::new`
    /// clamps to `0`. Confirm that this is the intended result.
    fn divide(&self, other: &Self) -> Option<Self> {
        self.check_eta_compatibility(other);
        if other.is_zero() {
            return None;
        }
        let numerator = self.value.into_inner();
        let denominator = other.value.into_inner();
        let eta_invariant = |v: f64| v == 0.0 || v == 1.0 || v.is_infinite();
        let eta = if eta_invariant(numerator) && !eta_invariant(denominator) {
            other.eta
        } else {
            self.eta
        };
        Some(Self::new(numerator / denominator, eta.into_inner()))
    }
}
/// Kleene closure `a* = 1̄ ⊕ a ⊕ (a ⊗ a) ⊕ …`.
impl StarSemiring for PowerWeight {
    /// Computes `a*` by mapping to probability space.
    ///
    /// With `p = value^{1/η}`, the series is geometric and converges to
    /// `1 / (1 - p)` only when `p < 1`. That result is mapped back with `^η`.
    /// Returns `None` when the series diverges (`p ≥ 1`, including `+∞`).
    fn star(&self) -> Option<Self> {
        let eta = self.eta.into_inner();
        let prob = self.value.into_inner().powf(1.0 / eta);
        // A single comparison covers both divergent cases (p == 1 and p > 1).
        if prob < 1.0 {
            let prob_star = 1.0 / (1.0 - prob);
            Some(Self::new(prob_star.powf(eta), eta))
        } else {
            None
        }
    }
}
/// Exposes the stored value as the weight's numeric value.
impl NumericalWeight for PowerWeight {
    #[inline]
    fn numerical_value(&self) -> f64 {
        self.value.into_inner()
    }
}
/// Closure-bound declaration for the power semiring.
impl KClosedSemiring for PowerWeight {
    /// Reports that no finite closure bound is declared (`None`).
    fn closure_bound() -> Option<usize> {
        Option::None
    }
}
// Marker-trait declarations; none of these impls define methods.
// `⊗` is plain multiplication of stored values, hence commutative.
impl CommutativeTimesSemiring for PowerWeight {}
// `PowerWeight::new` clamps stored values to be non-negative.
impl NonnegativeSemiring for PowerWeight {}
// `Ord` compares stored values, giving a total order.
impl TotallyOrderedSemiring for PowerWeight {}
// With non-negative values, `a ⊕ b` is zero only when both operands are zero.
impl ZeroSumFreeSemiring for PowerWeight {}
/// Bucketing of weights into integer cells of width `epsilon`.
impl QuantizableSemiring for PowerWeight {
    /// Maps the stored value to `round(value / epsilon)`.
    ///
    /// `NaN` maps to `i64::MIN` and infinities map to `i64::MAX`. For finite
    /// values, the float-to-int `as` cast saturates on overflow.
    fn quantize(&self, epsilon: f64) -> i64 {
        match self.value() {
            raw if raw.is_nan() => i64::MIN,
            raw if raw.is_infinite() => i64::MAX,
            raw => (raw / epsilon).round() as i64,
        }
    }
}
/// Probability view of a weight: `value^{1/η}`.
impl StochasticSemiring for PowerWeight {
    fn to_probability(&self) -> f64 {
        // Written as an explicit path to make clear that this forwards to the
        // inherent `PowerWeight::to_probability`, not to this trait method.
        let prob = PowerWeight::to_probability(self);
        prob
    }
}
impl std::ops::Add for PowerWeight {
type Output = Self;
#[inline]
fn add(self, other: Self) -> Self {
self.plus(&other)
}
}
impl std::ops::Mul for PowerWeight {
type Output = Self;
#[inline]
fn mul(self, other: Self) -> Self {
self.times(&other)
}
}
impl std::ops::AddAssign for PowerWeight {
#[inline]
fn add_assign(&mut self, other: Self) {
*self = self.plus(&other);
}
}
impl std::ops::MulAssign for PowerWeight {
#[inline]
fn mul_assign(&mut self, other: Self) {
*self = self.times(&other);
}
}
/// Formats as `PowerWeight(<value>, η=<eta>)`.
impl std::fmt::Display for PowerWeight {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "PowerWeight({value}, η={eta})",
            value = self.value,
            eta = self.eta
        )
    }
}
/// Serializes as a two-field struct `PowerWeight { value, eta }` of plain `f64`s.
#[cfg(feature = "serde")]
impl serde::Serialize for PowerWeight {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let mut fields = serializer.serialize_struct("PowerWeight", 2)?;
        fields.serialize_field("value", &self.value())?;
        fields.serialize_field("eta", &self.eta())?;
        fields.end()
    }
}
/// Deserializes `{ value, eta }` and builds the weight via `PowerWeight::new`.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PowerWeight {
    /// # Errors
    ///
    /// Returns a deserialization error if the input is malformed, or if `eta`
    /// is not strictly positive (`0`, negative or `NaN`). Rejecting such input
    /// here prevents untrusted data from tripping the `debug_assert!` in
    /// `PowerWeight::new` (a panic in debug builds). It also prevents building
    /// a meaningless weight in release builds.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(serde::Deserialize)]
        struct PowerWeightData {
            value: f64,
            eta: f64,
        }

        let data = PowerWeightData::deserialize(deserializer)?;
        if data.eta.is_nan() || data.eta <= 0.0 {
            return Err(serde::de::Error::custom(format!(
                "η must be positive, got {}",
                data.eta
            )));
        }
        Ok(PowerWeight::new(data.value, data.eta))
    }
}
// Unit and property tests for PowerWeight: semiring axioms, identities,
// division, star, and the probability isomorphism p ↦ p^η.
#[cfg(test)]
mod tests {
use super::*;
// Generic trait-law checkers shared across semiring implementations.
use crate::semiring::traits::tests::{
verify_commutative_times_semiring, verify_k_closed_semiring, verify_quantizable_semiring,
verify_stochastic_semiring, verify_totally_ordered_semiring, verify_zero_sum_free_semiring,
};
use proptest::prelude::*;
// η = 2: 4 ⊕ 9 = (√4 + √9)² = 25 and 4 ⊗ 9 = 36.
#[test]
fn test_basic_operations() {
let eta = 2.0;
let a = PowerWeight::new(4.0, eta);
let b = PowerWeight::new(9.0, eta);
let sum = a.plus(&b);
assert!(
(sum.value() - 25.0).abs() < 1e-10,
"Expected 25.0, got {}",
sum.value()
);
let product = a.times(&b);
assert!(
(product.value() - 36.0).abs() < 1e-10,
"Expected 36.0, got {}",
product.value()
);
}
// With η = 1, ⊕ and ⊗ reduce to ordinary + and ×.
#[test]
fn test_eta_one_is_probability_semiring() {
let a = PowerWeight::new(0.3, 1.0);
let b = PowerWeight::new(0.5, 1.0);
let sum = a.plus(&b);
assert!(
(sum.value() - 0.8).abs() < 1e-10,
"Expected 0.8, got {}",
sum.value()
);
let product = a.times(&b);
assert!(
(product.value() - 0.15).abs() < 1e-10,
"Expected 0.15, got {}",
product.value()
);
}
// a ⊕ 0̄ = a and a ⊗ 1̄ = a for a same-η zero and one.
#[test]
fn test_identities() {
let eta = 2.0;
let a = PowerWeight::new(5.0, eta);
let zero = PowerWeight::zero_with_eta(eta);
let one = PowerWeight::one_with_eta(eta);
let sum = a.plus(&zero);
assert!(
(sum.value() - 5.0).abs() < 1e-10,
"Additive identity failed"
);
let product = a.times(&one);
assert!(
(product.value() - 5.0).abs() < 1e-10,
"Multiplicative identity failed"
);
}
// a ⊗ 0̄ must be zero.
#[test]
fn test_zero_annihilation() {
let eta = 2.0;
let a = PowerWeight::new(5.0, eta);
let zero = PowerWeight::zero_with_eta(eta);
let product = a.times(&zero);
assert!(product.is_zero(), "Zero annihilation failed");
}
// (a ⊗ b) ÷ b recovers a; dividing by 0̄ yields None.
#[test]
fn test_division() {
let eta = 2.0;
let a = PowerWeight::new(10.0, eta);
let b = PowerWeight::new(2.0, eta);
let product = a.times(&b);
let quotient = product.divide(&b).expect("Division should succeed");
assert!(
(quotient.value() - 10.0).abs() < 1e-10,
"Division failed: expected 10.0, got {}",
quotient.value()
);
let zero = PowerWeight::zero_with_eta(eta);
assert!(a.divide(&zero).is_none(), "Division by zero should fail");
}
// η = 1: 0.5* = 1 / (1 - 0.5) = 2; star diverges (None) at 1 and above.
#[test]
fn test_star() {
let eta = 1.0;
let a = PowerWeight::new(0.5, eta);
let star_a = a.star().expect("Star should converge for x < 1");
assert!(
(star_a.value() - 2.0).abs() < 1e-10,
"Star failed: expected 2.0, got {}",
star_a.value()
);
let one = PowerWeight::new(1.0, eta);
assert!(one.star().is_none(), "Star should diverge for x = 1");
let big = PowerWeight::new(2.0, eta);
assert!(big.star().is_none(), "Star should diverge for x > 1");
}
// from_probability followed by to_probability returns the input probability.
#[test]
fn test_probability_conversion() {
let eta = 3.0;
let prob = 0.7;
let pw = PowerWeight::from_probability(prob, eta);
let recovered = pw.to_probability();
assert!(
(recovered - prob).abs() < 1e-10,
"Probability roundtrip failed: {} -> {} -> {}",
prob,
pw.value(),
recovered
);
}
// The map p ↦ p^η sends x + y to px ⊕ py.
#[test]
fn test_isomorphism_preserves_plus() {
let eta = 2.0;
let x = 0.3;
let y = 0.5;
let left = PowerWeight::from_probability(x + y, eta);
let px = PowerWeight::from_probability(x, eta);
let py = PowerWeight::from_probability(y, eta);
let right = px.plus(&py);
assert!(
(left.value() - right.value()).abs() < 1e-10,
"Isomorphism failed for plus: {} vs {}",
left.value(),
right.value()
);
}
// The map p ↦ p^η sends x · y to px ⊗ py.
#[test]
fn test_isomorphism_preserves_times() {
let eta = 2.0;
let x = 0.3;
let y = 0.5;
let left = PowerWeight::from_probability(x * y, eta);
let px = PowerWeight::from_probability(x, eta);
let py = PowerWeight::from_probability(y, eta);
let right = px.times(&py);
assert!(
(left.value() - right.value()).abs() < 1e-10,
"Isomorphism failed for times: {} vs {}",
left.value(),
right.value()
);
}
// With η = 100, a ⊕ b must exceed both operands.
#[test]
fn test_large_eta_behavior() {
let eta = 100.0;
let a = PowerWeight::new(0.1, eta);
let b = PowerWeight::new(0.9, eta);
let sum = a.plus(&b);
assert!(
sum.value() > a.value() && sum.value() > b.value(),
"Large η plus should produce larger value"
);
}
proptest! {
// Semiring axioms (identities, commutativity, associativity,
// distributivity, annihilation) on random same-η weights, compared with
// approx_eq at an absolute tolerance of 1e-6.
#[test]
fn proptest_semiring_axioms(
a in 0.001f64..100.0,
b in 0.001f64..100.0,
c in 0.001f64..100.0,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::new(a, eta);
let wb = PowerWeight::new(b, eta);
let wc = PowerWeight::new(c, eta);
let zero = PowerWeight::zero_with_eta(eta);
let one = PowerWeight::one_with_eta(eta);
let epsilon = 1e-6;
prop_assert!(wa.plus(&zero).approx_eq(&wa, epsilon),
"Additive identity failed: a ⊕ 0̄ ≠ a");
prop_assert!(wa.times(&one).approx_eq(&wa, epsilon),
"Multiplicative identity (right) failed: a ⊗ 1̄ ≠ a");
prop_assert!(one.times(&wa).approx_eq(&wa, epsilon),
"Multiplicative identity (left) failed: 1̄ ⊗ a ≠ a");
prop_assert!(wa.plus(&wb).approx_eq(&wb.plus(&wa), epsilon),
"Additive commutativity failed: a ⊕ b ≠ b ⊕ a");
let left = wa.plus(&wb).plus(&wc);
let right = wa.plus(&wb.plus(&wc));
prop_assert!(left.approx_eq(&right, epsilon),
"Additive associativity failed: (a ⊕ b) ⊕ c ≠ a ⊕ (b ⊕ c)");
let left = wa.times(&wb).times(&wc);
let right = wa.times(&wb.times(&wc));
prop_assert!(left.approx_eq(&right, epsilon),
"Multiplicative associativity failed: (a ⊗ b) ⊗ c ≠ a ⊗ (b ⊗ c)");
let left = wa.times(&wb.plus(&wc));
let right = wa.times(&wb).plus(&wa.times(&wc));
prop_assert!(left.approx_eq(&right, epsilon),
"Left distributivity failed: a ⊗ (b ⊕ c) ≠ (a ⊗ b) ⊕ (a ⊗ c)");
let left = wa.plus(&wb).times(&wc);
let right = wa.times(&wc).plus(&wb.times(&wc));
prop_assert!(left.approx_eq(&right, epsilon),
"Right distributivity failed: (a ⊕ b) ⊗ c ≠ (a ⊗ c) ⊕ (b ⊗ c)");
prop_assert!(zero.times(&wa).approx_eq(&zero, epsilon),
"Zero annihilation (left) failed: 0̄ ⊗ a ≠ 0̄");
prop_assert!(wa.times(&zero).approx_eq(&zero, epsilon),
"Zero annihilation (right) failed: a ⊗ 0̄ ≠ 0̄");
}
// (a ⊗ b) ÷ b ≈ a for non-zero b.
#[test]
fn proptest_divisible_semiring(
a in 0.001f64..100.0,
b in 0.001f64..100.0,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::new(a, eta);
let wb = PowerWeight::new(b, eta);
let epsilon = 1e-6;
if !wb.is_zero() {
let product = wa.times(&wb);
if let Some(quotient) = product.divide(&wb) {
prop_assert!(quotient.approx_eq(&wa, epsilon),
"Division inverse failed: (a ⊗ b) ÷ b ≠ a");
}
}
}
// Star fixed-point identity a* = 1̄ ⊕ (a ⊗ a*), checked by relative error;
// inputs where star() returns None are skipped.
#[test]
fn proptest_star_semiring(
prob in 0.01f64..0.8,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::from_probability(prob, eta);
let one = PowerWeight::one_with_eta(eta);
let star_a = match wa.star() {
Some(s) => s,
None => return Ok(()), };
let expected = one.plus(&wa.times(&star_a));
let rel_error = if expected.value() > 1e-10 {
(star_a.value() - expected.value()).abs() / expected.value()
} else {
(star_a.value() - expected.value()).abs()
};
prop_assert!(rel_error < 1e-6,
"Star axiom failed: a* ≠ 1̄ ⊕ (a ⊗ a*), rel_error = {}", rel_error);
}
// from_probability / to_probability roundtrip over (0, 1).
#[test]
fn proptest_probability_roundtrip(
prob in 0.001f64..1.0,
eta in 0.5f64..5.0
) {
let pw = PowerWeight::from_probability(prob, eta);
let recovered = pw.to_probability();
prop_assert!((recovered - prob).abs() < 1e-10,
"Roundtrip failed: {} -> {} -> {}", prob, pw.value(), recovered);
}
// The remaining property tests delegate to the shared trait-law checkers.
#[test]
fn proptest_k_closed_semiring(
a in 0.001f64..100.0,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::new(a, eta);
verify_k_closed_semiring(wa, 1e-6);
}
#[test]
fn proptest_zero_sum_free_semiring(
a in 0.0f64..100.0,
b in 0.0f64..100.0,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::new(a, eta);
let wb = PowerWeight::new(b, eta);
verify_zero_sum_free_semiring(wa, wb, 1e-6);
}
#[test]
fn proptest_commutative_times_semiring(
a in 0.001f64..100.0,
b in 0.001f64..100.0,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::new(a, eta);
let wb = PowerWeight::new(b, eta);
verify_commutative_times_semiring(wa, wb, 1e-6);
}
#[test]
fn proptest_totally_ordered_semiring(
a in 0.001f64..100.0,
b in 0.001f64..100.0,
c in 0.001f64..100.0,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::new(a, eta);
let wb = PowerWeight::new(b, eta);
let wc = PowerWeight::new(c, eta);
verify_totally_ordered_semiring(wa, wb, wc);
}
#[test]
fn proptest_quantizable_semiring(
a in 0.001f64..100.0,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::new(a, eta);
verify_quantizable_semiring(wa, 1e-10);
}
#[test]
fn proptest_stochastic_semiring(
prob in 0.001f64..1.0,
eta in 0.5f64..5.0
) {
let wa = PowerWeight::from_probability(prob, eta);
verify_stochastic_semiring(wa);
}
}
// PowerWeight declares no finite closure bound.
#[test]
fn test_k_closed_bound() {
assert_eq!(PowerWeight::closure_bound(), None);
}
}