use crate::semiring::traits::{
CommutativeTimesSemiring, DivisibleSemiring, IdempotentSemiring, KClosedSemiring,
NonnegativeSemiring, QuantizableSemiring, Semiring, StarSemiring, StochasticSemiring,
TotallyOrderedSemiring, WeaklyLeftDivisibleSemiring, ZeroSumFreeSemiring,
};
/// A pair of semiring weights ordered lexicographically.
///
/// `⊕` keeps whichever operand is smaller: the first component decides and the second only
/// breaks ties. `⊗` multiplies the two components independently.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct LexicographicWeight<S1: Semiring + Ord, S2: Semiring + Ord>(pub S1, pub S2);
impl<S1, S2> LexicographicWeight<S1, S2>
where
    S1: Semiring + Ord,
    S2: Semiring + Ord,
{
    /// Builds a weight from its primary (`first`) and tie-breaking (`second`) components.
    #[inline]
    pub const fn new(first: S1, second: S2) -> Self {
        Self(first, second)
    }

    /// Returns a copy of the primary component.
    #[inline]
    pub fn first(&self) -> S1 {
        let Self(primary, _) = self;
        *primary
    }

    /// Returns a copy of the tie-breaking component.
    #[inline]
    pub fn second(&self) -> S2 {
        let Self(_, secondary) = self;
        *secondary
    }

    /// Applies `f` to the primary component, leaving the second untouched.
    #[inline]
    pub fn map_first<F>(self, f: F) -> Self
    where
        F: FnOnce(S1) -> S1,
    {
        let Self(primary, secondary) = self;
        Self(f(primary), secondary)
    }

    /// Applies `f` to the tie-breaking component, leaving the first untouched.
    #[inline]
    pub fn map_second<F>(self, f: F) -> Self
    where
        F: FnOnce(S2) -> S2,
    {
        let Self(primary, secondary) = self;
        Self(primary, f(secondary))
    }
}
/// The default weight is the multiplicative identity `(1̄, 1̄)`, not zero.
impl<S1, S2> Default for LexicographicWeight<S1, S2>
where
    S1: Semiring + Ord,
    S2: Semiring + Ord,
{
    #[inline]
    fn default() -> Self {
        <Self as Semiring>::one()
    }
}
/// Builds a weight from a `(first, second)` tuple.
impl<S1, S2> From<(S1, S2)> for LexicographicWeight<S1, S2>
where
    S1: Semiring + Ord,
    S2: Semiring + Ord,
{
    #[inline]
    fn from(pair: (S1, S2)) -> Self {
        let (primary, secondary) = pair;
        Self::new(primary, secondary)
    }
}
impl<S1, S2> From<LexicographicWeight<S1, S2>> for (S1, S2)
where
S1: Semiring + Ord,
S2: Semiring + Ord,
{
#[inline]
fn from(weight: LexicographicWeight<S1, S2>) -> Self {
(weight.0, weight.1)
}
}
impl<S1, S2> Semiring for LexicographicWeight<S1, S2>
where
    S1: Semiring + Ord,
    S2: Semiring + Ord,
{
    /// Additive identity: each component at its own zero.
    #[inline]
    fn zero() -> Self {
        Self(S1::zero(), S2::zero())
    }

    /// Multiplicative identity: each component at its own one.
    #[inline]
    fn one() -> Self {
        Self(S1::one(), S2::one())
    }

    /// Lexicographic minimum.
    ///
    /// The operand with the smaller first component wins. If the first components compare
    /// equal, `self` wins when its second component is `<=` the other's. Exact ties therefore
    /// keep `self`.
    #[inline]
    fn plus(&self, other: &Self) -> Self {
        use std::cmp::Ordering;
        let keep_self = match self.0.cmp(&other.0) {
            Ordering::Less => true,
            Ordering::Greater => false,
            Ordering::Equal => self.1 <= other.1,
        };
        if keep_self {
            *self
        } else {
            *other
        }
    }

    /// Component-wise product.
    #[inline]
    fn times(&self, other: &Self) -> Self {
        let primary = self.0.times(&other.0);
        let secondary = self.1.times(&other.1);
        Self(primary, secondary)
    }

    /// True only when both components are their own zero.
    #[inline]
    fn is_zero(&self) -> bool {
        self.0.is_zero() && self.1.is_zero()
    }

    /// True only when both components are their own one.
    #[inline]
    fn is_one(&self) -> bool {
        self.0.is_one() && self.1.is_one()
    }

    /// Both components must be approximately equal under the same `epsilon`.
    fn approx_eq(&self, other: &Self, epsilon: f64) -> bool {
        self.0.approx_eq(&other.0, epsilon) && self.1.approx_eq(&other.1, epsilon)
    }

    /// Strict lexicographic "less than": the first components decide; on a tie, the second
    /// components are compared with `<`. The answer is always defined (`Some`).
    fn natural_less(&self, other: &Self) -> Option<bool> {
        use std::cmp::Ordering;
        let is_less = match self.0.cmp(&other.0) {
            Ordering::Equal => self.1 < other.1,
            decided => decided == Ordering::Less,
        };
        Some(is_less)
    }

    /// Concatenation of the first component's bytes followed by the second's.
    fn to_bytes(&self) -> Vec<u8> {
        [self.0.to_bytes(), self.1.to_bytes()].concat()
    }
}
impl<S1, S2> DivisibleSemiring for LexicographicWeight<S1, S2>
where
    S1: DivisibleSemiring + Ord,
    S2: DivisibleSemiring + Ord,
{
    /// Component-wise division; `None` if either component's division is undefined.
    fn divide(&self, other: &Self) -> Option<Self> {
        // Both quotients are computed before combining, as before.
        let primary = self.0.divide(&other.0);
        let secondary = self.1.divide(&other.1);
        primary.zip(secondary).map(|(p, s)| Self(p, s))
    }
}
impl<S1, S2> StarSemiring for LexicographicWeight<S1, S2>
where
S1: StarSemiring + Ord,
S2: StarSemiring + Ord,
{
fn star(&self) -> Option<Self> {
match (self.0.star(), self.1.star()) {
(Some(first), Some(second)) => Some(LexicographicWeight(first, second)),
_ => None,
}
}
}
/// Marker impl: `⊕` returns one of its operands (`a ⊕ a` keeps `a`). Both components are
/// required to be idempotent as well.
impl<S1: IdempotentSemiring + Ord, S2: IdempotentSemiring + Ord> IdempotentSemiring
    for LexicographicWeight<S1, S2>
{
}
impl<S1, S2> KClosedSemiring for LexicographicWeight<S1, S2>
where
    S1: KClosedSemiring + Ord,
    S2: KClosedSemiring + Ord,
{
    /// The larger of the two component bounds; `None` if either component reports no bound.
    fn closure_bound() -> Option<usize> {
        let (left, right) = (S1::closure_bound(), S2::closure_bound());
        left.zip(right).map(|(k1, k2)| std::cmp::max(k1, k2))
    }
}
/// Marker impl; requires both components to be zero-sum-free.
impl<S1: ZeroSumFreeSemiring + Ord, S2: ZeroSumFreeSemiring + Ord> ZeroSumFreeSemiring
    for LexicographicWeight<S1, S2>
{
}
impl<S1, S2> WeaklyLeftDivisibleSemiring for LexicographicWeight<S1, S2>
where
    S1: WeaklyLeftDivisibleSemiring + Ord,
    S2: WeaklyLeftDivisibleSemiring + Ord,
{
    /// Component-wise left division; `None` if either component cannot be divided.
    fn left_divide(&self, divisor: &Self) -> Option<Self> {
        let primary = self.0.left_divide(&divisor.0);
        let secondary = self.1.left_divide(&divisor.1);
        primary.zip(secondary).map(|(p, s)| Self(p, s))
    }
}
/// Marker impl: `⊗` is component-wise, so it commutes whenever both components' `⊗` does.
impl<S1, S2> CommutativeTimesSemiring for LexicographicWeight<S1, S2>
where
    S1: CommutativeTimesSemiring + Ord,
    S2: CommutativeTimesSemiring + Ord,
{}
/// Marker impl: `Ord` compares the first components, then the second. The result is a total
/// order whenever both component orders are total.
impl<S1: TotallyOrderedSemiring, S2: TotallyOrderedSemiring> TotallyOrderedSemiring
    for LexicographicWeight<S1, S2>
{
}
/// Marker impl; requires both components to be nonnegative semirings.
impl<S1: NonnegativeSemiring + Ord, S2: NonnegativeSemiring + Ord> NonnegativeSemiring
    for LexicographicWeight<S1, S2>
{
}
impl<S1, S2> QuantizableSemiring for LexicographicWeight<S1, S2>
where
    S1: QuantizableSemiring + Ord,
    S2: QuantizableSemiring + Ord,
{
    /// Combines the two component quantizations into a single `i64` key.
    ///
    /// The first key is multiplied by an odd constant and XOR-ed with the second. Under
    /// wrapping arithmetic that multiplication is a bijection on 64-bit integers, so every bit
    /// of both keys feeds the result. Two weights whose quantizations differ in only one
    /// component therefore never share a key.
    fn quantize(&self, epsilon: f64) -> i64 {
        // Shifting the first key by 32 and masking the second to 32 bits would discard the high
        // halves of both. Keys a multiple of 2^32 apart would then collide; if quantize is
        // roughly value / epsilon, that is only ~0.43 in value at epsilon = 1e-10.
        // ⌊2^64 / φ⌋ (Fibonacci hashing constant): odd, hence invertible modulo 2^64.
        const MIX: i64 = 0x9E37_79B9_7F4A_7C15_u64 as i64;
        let primary = self.0.quantize(epsilon);
        let secondary = self.1.quantize(epsilon);
        primary.wrapping_mul(MIX) ^ secondary
    }
}
impl<S1, S2> StochasticSemiring for LexicographicWeight<S1, S2>
where
    S1: StochasticSemiring + Ord,
    S2: StochasticSemiring + Ord,
{
    /// Probability driven by the first component and scaled by `1 + 1e-10 * p2`. Presumably
    /// this keeps differences in the second component faintly visible.
    ///
    /// NOTE(review): with `p1 == 1.0` and `p2 > 0` the result slightly exceeds 1.0. Confirm
    /// whether `StochasticSemiring` requires values in `[0, 1]`.
    fn to_probability(&self) -> f64 {
        let primary = self.0.to_probability();
        let secondary = self.1.to_probability();
        let tie_break_scale = 1.0 + 1e-10 * secondary;
        primary * tie_break_scale
    }
}
impl<S1, S2> std::ops::Add for LexicographicWeight<S1, S2>
where
S1: Semiring + Ord,
S2: Semiring + Ord,
{
type Output = Self;
#[inline]
fn add(self, other: Self) -> Self {
self.plus(&other)
}
}
impl<S1, S2> std::ops::Mul for LexicographicWeight<S1, S2>
where
S1: Semiring + Ord,
S2: Semiring + Ord,
{
type Output = Self;
#[inline]
fn mul(self, other: Self) -> Self {
self.times(&other)
}
}
impl<S1, S2> std::ops::AddAssign for LexicographicWeight<S1, S2>
where
S1: Semiring + Ord,
S2: Semiring + Ord,
{
#[inline]
fn add_assign(&mut self, other: Self) {
*self = self.plus(&other);
}
}
impl<S1, S2> std::ops::MulAssign for LexicographicWeight<S1, S2>
where
S1: Semiring + Ord,
S2: Semiring + Ord,
{
#[inline]
fn mul_assign(&mut self, other: Self) {
*self = self.times(&other);
}
}
/// Delegates to the total lexicographic `Ord`, so the comparison is always defined.
impl<S1, S2> PartialOrd for LexicographicWeight<S1, S2>
where
    S1: Semiring + Ord,
    S2: Semiring + Ord,
{
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Lexicographic order: compare the first components, and only on a tie the second.
impl<S1, S2> Ord for LexicographicWeight<S1, S2>
where
    S1: Semiring + Ord,
    S2: Semiring + Ord,
{
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.cmp(&other.0).then_with(|| self.1.cmp(&other.1))
    }
}
#[cfg(feature = "serde")]
impl<S1, S2> serde::Serialize for LexicographicWeight<S1, S2>
where
    S1: Semiring + Ord + serde::Serialize,
    S2: Semiring + Ord + serde::Serialize,
{
    /// Serializes as the two-element tuple `(first, second)`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // serde serializes a 2-tuple with `serialize_tuple(2)` plus one `serialize_element`
        // per field, in order; references serialize as their referent.
        serde::Serialize::serialize(&(&self.0, &self.1), serializer)
    }
}
#[cfg(feature = "serde")]
impl<'de, S1, S2> serde::Deserialize<'de> for LexicographicWeight<S1, S2>
where
    S1: Semiring + Ord + serde::Deserialize<'de>,
    S2: Semiring + Ord + serde::Deserialize<'de>,
{
    /// Reads a `(first, second)` tuple and wraps it as a weight.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        <(S1, S2) as serde::Deserialize<'de>>::deserialize(deserializer)
            .map(|(primary, secondary)| Self::new(primary, secondary))
    }
}
/// Three-level lexicographic weight: `S1` decides first, then `S2`, then `S3`.
pub type Lexicographic3<S1, S2, S3> = LexicographicWeight<S1, LexicographicWeight<S2, S3>>;
/// Four-level lexicographic weight: `S1` over a [`Lexicographic3`] of the remaining levels.
pub type Lexicographic4<S1, S2, S3, S4> = LexicographicWeight<S1, Lexicographic3<S2, S3, S4>>;
/// Builds a three-level weight, nested as `(first, (second, third))`.
pub fn lexicographic3<S1, S2, S3>(first: S1, second: S2, third: S3) -> Lexicographic3<S1, S2, S3>
where
    S1: Semiring + Ord,
    S2: Semiring + Ord,
    S3: Semiring + Ord,
{
    let tail = LexicographicWeight::new(second, third);
    LexicographicWeight::new(first, tail)
}
pub fn lexicographic4<S1, S2, S3, S4>(
first: S1,
second: S2,
third: S3,
fourth: S4,
) -> Lexicographic4<S1, S2, S3, S4>
where
S1: Semiring + Ord,
S2: Semiring + Ord,
S3: Semiring + Ord,
S4: Semiring + Ord,
{
LexicographicWeight::new(
first,
LexicographicWeight::new(second, LexicographicWeight::new(third, fourth)),
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::traits::tests::{
        verify_commutative_times_semiring, verify_divisible_semiring, verify_idempotent_semiring,
        verify_k_closed_semiring, verify_quantizable_semiring, verify_semiring_axioms,
        verify_star_semiring, verify_stochastic_semiring, verify_totally_ordered_semiring,
        verify_weakly_left_divisible_semiring, verify_zero_sum_free_semiring,
    };
    use crate::semiring::TropicalWeight;
    use proptest::prelude::*;

    /// Two tropical components: the primary cost and a tie-breaking cost.
    type LexTrop = LexicographicWeight<TropicalWeight, TropicalWeight>;

    /// Shorthand for a `LexTrop` built from two raw tropical values.
    fn lex(first: f64, second: f64) -> LexTrop {
        LexTrop::new(TropicalWeight::new(first), TropicalWeight::new(second))
    }

    #[test]
    fn test_lexicographic_plus() {
        // The primary component decides, however large the secondary cost is.
        let result = lex(2.0, 100.0).plus(&lex(1.0, 1000.0));
        assert_eq!(result.first().value(), 1.0);
        assert_eq!(result.second().value(), 1000.0);

        // Equal primaries: the secondary component breaks the tie.
        let result = lex(5.0, 10.0).plus(&lex(5.0, 20.0));
        assert_eq!(result.first().value(), 5.0);
        assert_eq!(result.second().value(), 10.0);
    }

    #[test]
    fn test_lexicographic_times() {
        // Tropical times adds, independently in each component.
        let result = lex(2.0, 3.0).times(&lex(4.0, 1.0));
        assert_eq!(result.first().value(), 6.0);
        assert_eq!(result.second().value(), 4.0);
    }

    #[test]
    fn test_identities() {
        let a = lex(5.0, 3.0);
        assert!(a.approx_eq(&a.plus(&LexTrop::zero()), 1e-10));
        assert!(a.approx_eq(&a.times(&LexTrop::one()), 1e-10));
    }

    #[test]
    fn test_annihilation() {
        let prod = lex(5.0, 3.0).times(&LexTrop::zero());
        assert!(prod.is_zero());
    }

    #[test]
    fn test_division() {
        let a = lex(5.0, 3.0);
        let b = lex(2.0, 1.0);
        let recovered = a.times(&b).divide(&b).expect("Division should succeed");
        assert!(a.approx_eq(&recovered, 1e-10));
    }

    #[test]
    fn test_star() {
        let positive = lex(1.0, 2.0);
        let star = positive.star().expect("Star should converge");
        assert!(star.is_one());

        let negative = lex(-1.0, 2.0);
        assert!(negative.star().is_none());
    }

    #[test]
    fn test_three_level_priority() {
        type Lex3 = Lexicographic3<TropicalWeight, TropicalWeight, TropicalWeight>;
        let a = lexicographic3(
            TropicalWeight::new(1.0),
            TropicalWeight::new(5.0),
            TropicalWeight::new(10.0),
        );
        let b = lexicographic3(
            TropicalWeight::new(0.0),
            TropicalWeight::new(100.0),
            TropicalWeight::new(200.0),
        );
        // `b` wins on the first level, so its lower levels come along unchanged.
        let best: Lex3 = a.plus(&b);
        assert_eq!(best.first().value(), 0.0);
        assert_eq!(best.second().first().value(), 100.0);
        assert_eq!(best.second().second().value(), 200.0);
    }

    #[test]
    fn test_error_correction_scenario() {
        type CorrectionWeight = LexicographicWeight<TropicalWeight, TropicalWeight>;
        let candidate_a = CorrectionWeight::new(TropicalWeight::new(1.0), TropicalWeight::new(2.0));

        // Fewer edits wins even with a worse secondary score.
        let candidate_b = CorrectionWeight::new(TropicalWeight::new(2.0), TropicalWeight::new(0.5));
        let best = candidate_a.plus(&candidate_b);
        assert_eq!(best.first().value(), 1.0);
        assert_eq!(best.second().value(), 2.0);

        // Same edit count: the better secondary score wins.
        let candidate_c = CorrectionWeight::new(TropicalWeight::new(1.0), TropicalWeight::new(3.0));
        let best = candidate_a.plus(&candidate_c);
        assert_eq!(best.first().value(), 1.0);
        assert_eq!(best.second().value(), 2.0);
    }

    #[test]
    fn test_ordering() {
        let a = lex(1.0, 5.0);
        let b = lex(2.0, 1.0);
        let c = lex(1.0, 3.0);
        assert!(a < b);
        assert!(c < a);
        assert!(c < b);
    }

    #[test]
    fn test_natural_less() {
        let worse_tie = lex(1.0, 5.0);
        let better_tie = lex(1.0, 3.0);
        assert_eq!(better_tie.natural_less(&worse_tie), Some(true));
        assert_eq!(worse_tie.natural_less(&better_tie), Some(false));
        assert_eq!(better_tie.natural_less(&better_tie), Some(false));
        assert_eq!(lex(0.0, 9.0).natural_less(&better_tie), Some(true));
    }

    #[test]
    fn test_default_is_one() {
        assert!(LexTrop::default().is_one());
    }

    #[test]
    fn test_tuple_conversions() {
        let weight: LexTrop = (TropicalWeight::new(1.0), TropicalWeight::new(2.0)).into();
        assert_eq!(weight.first().value(), 1.0);
        assert_eq!(weight.second().value(), 2.0);

        let (first, second): (TropicalWeight, TropicalWeight) = weight.into();
        assert_eq!(first.value(), 1.0);
        assert_eq!(second.value(), 2.0);
    }

    #[test]
    fn test_map_components() {
        let weight = lex(1.0, 2.0);

        let shifted_first = weight.map_first(|w| w.times(&TropicalWeight::new(3.0)));
        assert_eq!(shifted_first.first().value(), 4.0);
        assert_eq!(shifted_first.second().value(), 2.0);

        let shifted_second = weight.map_second(|w| w.times(&TropicalWeight::new(0.5)));
        assert_eq!(shifted_second.first().value(), 1.0);
        assert_eq!(shifted_second.second().value(), 2.5);
    }

    proptest! {
        #[test]
        fn proptest_semiring_axioms(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0,
            b1 in 0.0f64..100.0,
            b2 in 0.0f64..100.0,
            c1 in 0.0f64..100.0,
            c2 in 0.0f64..100.0
        ) {
            verify_semiring_axioms(lex(a1, a2), lex(b1, b2), lex(c1, c2), 1e-10);
        }

        #[test]
        fn proptest_divisible_semiring(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0,
            b1 in 0.001f64..100.0,
            b2 in 0.001f64..100.0
        ) {
            verify_divisible_semiring(lex(a1, a2), lex(b1, b2), 1e-10);
        }

        #[test]
        fn proptest_star_semiring(
            a1 in 0.001f64..100.0,
            a2 in 0.001f64..100.0
        ) {
            verify_star_semiring(lex(a1, a2), 1e-10);
        }

        #[test]
        fn proptest_idempotent_semiring(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0
        ) {
            verify_idempotent_semiring(lex(a1, a2), 1e-10);
        }

        #[test]
        fn proptest_k_closed_semiring(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0
        ) {
            verify_k_closed_semiring(lex(a1, a2), 1e-10);
        }

        #[test]
        fn proptest_zero_sum_free_semiring(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0,
            b1 in 0.0f64..100.0,
            b2 in 0.0f64..100.0
        ) {
            verify_zero_sum_free_semiring(lex(a1, a2), lex(b1, b2), 1e-10);
        }

        #[test]
        fn proptest_weakly_left_divisible_semiring(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0,
            b1 in 0.0f64..100.0,
            b2 in 0.0f64..100.0
        ) {
            let wa = lex(a1, a2);
            let divisor = wa.plus(&lex(b1, b2));
            verify_weakly_left_divisible_semiring(wa, divisor, 1e-10);
        }

        #[test]
        fn proptest_commutative_times_semiring(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0,
            b1 in 0.0f64..100.0,
            b2 in 0.0f64..100.0
        ) {
            verify_commutative_times_semiring(lex(a1, a2), lex(b1, b2), 1e-10);
        }

        #[test]
        fn proptest_totally_ordered_semiring(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0,
            b1 in 0.0f64..100.0,
            b2 in 0.0f64..100.0,
            c1 in 0.0f64..100.0,
            c2 in 0.0f64..100.0
        ) {
            verify_totally_ordered_semiring(lex(a1, a2), lex(b1, b2), lex(c1, c2));
        }

        #[test]
        fn proptest_quantizable_semiring(
            a1 in 0.0f64..100.0,
            a2 in 0.0f64..100.0
        ) {
            verify_quantizable_semiring(lex(a1, a2), 1e-10);
        }

        #[test]
        fn proptest_stochastic_semiring(
            a1 in 0.0f64..50.0,
            a2 in 0.0f64..50.0
        ) {
            verify_stochastic_semiring(lex(a1, a2));
        }
    }

    #[test]
    fn test_k_closed_bound() {
        assert_eq!(LexTrop::closure_bound(), Some(0));
    }
}