rustfst/semirings/
tropical_weight.rs

1use std::borrow::Borrow;
2use std::f32;
3use std::hash::{Hash, Hasher};
4use std::io::Write;
5
6use anyhow::Result;
7use nom::branch::alt;
8use nom::bytes::complete::tag_no_case;
9use nom::combinator::map;
10use nom::number::complete::float;
11use nom::IResult;
12use ordered_float::OrderedFloat;
13
14use crate::parsers::nom_utils::NomCustomError;
15use crate::parsers::parse_bin_f32;
16use crate::parsers::write_bin_f32;
17use crate::semirings::semiring::SerializableSemiring;
18use crate::semirings::utils_float::float_approx_equal;
19use crate::semirings::{
20    CompleteSemiring, DivideType, ReverseBack, Semiring, SemiringProperties, StarSemiring,
21    WeaklyDivisibleSemiring, WeightQuantize,
22};
23use crate::KDELTA;
24
25/// Tropical semiring: (min, +, inf, 0).
26#[derive(Clone, Debug, PartialOrd, Default, Copy, Eq)]
27pub struct TropicalWeight {
28    value: OrderedFloat<f32>,
29}
30
31impl Semiring for TropicalWeight {
32    type Type = f32;
33    type ReverseWeight = TropicalWeight;
34
35    fn zero() -> Self {
36        Self {
37            value: OrderedFloat(f32::INFINITY),
38        }
39    }
40
41    fn one() -> Self {
42        Self {
43            value: OrderedFloat(0.0),
44        }
45    }
46
47    fn new(value: <Self as Semiring>::Type) -> Self {
48        TropicalWeight {
49            value: OrderedFloat(value),
50        }
51    }
52
53    fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
54        if rhs.borrow().value < self.value {
55            self.value = rhs.borrow().value;
56        }
57        Ok(())
58    }
59
60    fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
61        let f1 = self.value();
62        let f2 = rhs.borrow().value();
63        if f1.eq(&f32::INFINITY) {
64        } else if f2.eq(&f32::INFINITY) {
65            self.value.0 = *f2;
66        } else {
67            self.value.0 += f2;
68        }
69        Ok(())
70    }
71
72    fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool {
73        float_approx_equal(self.value.0, rhs.borrow().value.0, delta)
74    }
75
76    fn value(&self) -> &Self::Type {
77        &self.value.0
78    }
79
80    fn take_value(self) -> Self::Type {
81        self.value.0
82    }
83
84    fn set_value(&mut self, value: <Self as Semiring>::Type) {
85        self.value.0 = value
86    }
87
88    fn reverse(&self) -> Result<Self::ReverseWeight> {
89        Ok(*self)
90    }
91
92    fn properties() -> SemiringProperties {
93        SemiringProperties::LEFT_SEMIRING
94            | SemiringProperties::RIGHT_SEMIRING
95            | SemiringProperties::COMMUTATIVE
96            | SemiringProperties::PATH
97            | SemiringProperties::IDEMPOTENT
98    }
99}
100
101impl ReverseBack<TropicalWeight> for TropicalWeight {
102    fn reverse_back(&self) -> Result<TropicalWeight> {
103        Ok(*self)
104    }
105}
106
107impl AsRef<TropicalWeight> for TropicalWeight {
108    fn as_ref(&self) -> &TropicalWeight {
109        self
110    }
111}
112
113display_semiring!(TropicalWeight);
114
115impl CompleteSemiring for TropicalWeight {}
116
117impl StarSemiring for TropicalWeight {
118    fn closure(&self) -> Self {
119        if self.value.is_sign_positive() && self.value.is_finite() {
120            Self::new(0.0)
121        } else {
122            Self::new(f32::NEG_INFINITY)
123        }
124    }
125}
126
127impl WeaklyDivisibleSemiring for TropicalWeight {
128    fn divide_assign(&mut self, rhs: &Self, _divide_type: DivideType) -> Result<()> {
129        self.value.0 -= rhs.value.0;
130        Ok(())
131    }
132}
133
134impl_quantize_f32!(TropicalWeight);
135
136partial_eq_and_hash_f32!(TropicalWeight);
137
138impl SerializableSemiring for TropicalWeight {
139    fn weight_type() -> String {
140        "tropical".to_string()
141    }
142
143    fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
144        let (i, weight) = parse_bin_f32(i)?;
145        Ok((i, Self::new(weight)))
146    }
147
148    fn write_binary<F: Write>(&self, file: &mut F) -> Result<()> {
149        write_bin_f32(file, *self.value())
150    }
151
152    fn parse_text(i: &str) -> IResult<&str, Self> {
153        // FIXME: nom 7 does not fully parse "infinity", therefore it is done manually
154        // even after https://github.com/rust-bakery/nom/pull/1673 wass merged this issue persisted
155        // https://github.com/Garvys/rustfst/pull/253#discussion_r1494208294
156        let (i, f) = alt((map(tag_no_case("infinity"), |_| f32::INFINITY), float))(i)?;
157        Ok((i, Self::new(f)))
158    }
159}
160
161test_semiring_serializable!(
162    tests_tropical_weight_serializable,
163    TropicalWeight,
164    TropicalWeight::one() TropicalWeight::zero() TropicalWeight::new(0.3) TropicalWeight::new(0.5) TropicalWeight::new(0.0) TropicalWeight::new(-1.2)
165);
166
167impl From<f32> for TropicalWeight {
168    fn from(f: f32) -> Self {
169        Self::new(f)
170    }
171}