rustfst/semirings/
semiring.rs1use std::borrow::Borrow;
2use std::fmt::Debug;
3use std::fmt::Display;
4use std::hash::Hash;
5
6use bitflags::bitflags;
7
8use crate::parsers::nom_utils::NomCustomError;
9use anyhow::Result;
10use nom::IResult;
11use std::io::Write;
12
13bitflags! {
14 pub struct SemiringProperties: u32 {
16 const LEFT_SEMIRING = 0b00001;
18 const RIGHT_SEMIRING = 0b00010;
20 const COMMUTATIVE = 0b00100;
22 const IDEMPOTENT = 0b01000;
24 const PATH = 0b10000;
26 const SEMIRING = Self::LEFT_SEMIRING.bits() | Self::RIGHT_SEMIRING.bits();
27 }
28}
29
30pub trait Semiring: Clone + PartialEq + PartialOrd + Debug + Hash + Eq + Sync + 'static {
38 type Type: Clone + Debug;
39 type ReverseWeight: Semiring + ReverseBack<Self>;
40
41 fn zero() -> Self;
42 fn one() -> Self;
43
44 fn new(value: Self::Type) -> Self;
45
46 fn plus<P: Borrow<Self>>(&self, rhs: P) -> Result<Self> {
47 let mut w = self.clone();
48 w.plus_assign(rhs)?;
49 Ok(w)
50 }
51 fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()>;
52
53 fn times<P: Borrow<Self>>(&self, rhs: P) -> Result<Self> {
54 let mut w = self.clone();
55 w.times_assign(rhs)?;
56 Ok(w)
57 }
58 fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()>;
59
60 fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool;
61
62 fn value(&self) -> &Self::Type;
64 fn take_value(self) -> Self::Type;
66 fn set_value(&mut self, value: Self::Type);
67 fn is_one(&self) -> bool {
68 *self == Self::one()
69 }
70 fn is_zero(&self) -> bool {
71 *self == Self::zero()
72 }
73 fn reverse(&self) -> Result<Self::ReverseWeight>;
74 fn properties() -> SemiringProperties;
75}
76
77pub trait ReverseBack<W> {
78 fn reverse_back(&self) -> Result<W>;
79}
80
/// Side on which a weak division is performed; passed to
/// [`WeaklyDivisibleSemiring::divide`] / `divide_assign`.
///
/// `Debug`, `Eq` and `Hash` are derived so the value can be printed, used in
/// `assert_eq!`, and stored in hashed collections.
#[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Eq, Hash)]
pub enum DivideType {
    /// Division from the left.
    DivideLeft,
    /// Division from the right.
    DivideRight,
    /// Division where the side is unspecified
    /// (NOTE(review): presumably for commutative semirings — confirm).
    DivideAny,
}
91
92pub trait WeaklyDivisibleSemiring: Semiring {
99 fn divide_assign(&mut self, rhs: &Self, divide_type: DivideType) -> Result<()>;
100 fn divide(&self, rhs: &Self, divide_type: DivideType) -> Result<Self> {
101 let mut w = self.clone();
102 w.divide_assign(rhs, divide_type)?;
103 Ok(w)
104 }
105}
106
107pub trait CompleteSemiring: Semiring {}
114
115pub trait StarSemiring: Semiring {
120 fn closure(&self) -> Self;
121}
122
123pub trait WeightQuantize: Semiring {
124 fn quantize_assign(&mut self, delta: f32) -> Result<()>;
125 fn quantize(&self, delta: f32) -> Result<Self> {
126 let mut w = self.clone();
127 w.quantize_assign(delta)?;
128 Ok(w)
129 }
130}
131
/// Implements `WeightQuantize` for an `f32`-valued semiring `$semiring`.
///
/// The value is rounded to the nearest multiple of `delta`, computed as
/// `floor(v / delta + 0.5) * delta`, so exact ties go towards +∞.
/// Infinite values are left untouched. NaN is not special-cased.
///
/// The expansion names `WeightQuantize` and `Result` unqualified, so both
/// must be in scope where the macro is invoked. `$semiring` must provide
/// `value()` (returning `&f32`) and `set_value(f32)`.
macro_rules! impl_quantize_f32 {
    ($semiring: ident) => {
        impl WeightQuantize for $semiring {
            fn quantize_assign(&mut self, delta: f32) -> Result<()> {
                let current = *self.value();
                if !current.is_infinite() {
                    let steps = ((current / delta) + 0.5).floor();
                    self.set_value(steps * delta);
                }
                Ok(())
            }
        }
    };
}
146
/// Implements `Display` for `$semiring` by printing `self.value()` with a
/// plain `{}` placeholder.
///
/// The expansion begins with `use std::fmt;` at the invocation site. The
/// macro can therefore be invoked at most once per module, and that module
/// must not already bring a `fmt` name into scope.
///
/// The value is re-formatted with a fresh `{}`, so width/precision flags
/// supplied by the caller (e.g. `{:>8}`) are not applied to the output.
macro_rules! display_semiring {
    ($semiring:tt) => {
        use std::fmt;
        impl fmt::Display for $semiring {
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                // `write!` already yields `fmt::Result`; return it directly.
                write!(formatter, "{}", self.value())
            }
        }
    };
}
158
/// Implements `PartialEq` and `Hash` for an `f32`-valued semiring `$semiring`.
///
/// Two weights compare equal when their values lie within `KDELTA` of each
/// other.
///
/// Requirements at the invocation site:
/// - `KDELTA` (an `f32` constant) must be in scope.
/// - `Hash` and `Hasher` must be in scope.
/// - `$semiring` must have a `value()` method returning `&f32`.
/// - `$semiring` must have a hashable `value` field.
///
/// NOTE(review): `eq` is tolerance-based but `hash` hashes the exact stored
/// `value` field. Two weights that compare equal can therefore hash
/// differently, which breaks the `Hash`/`Eq` contract
/// (`a == b` ⇒ `hash(a) == hash(b)`). Tolerance equality is also not
/// transitive. Do not rely on hash-based collections to merge nearly-equal
/// weights. This is kept as-is because callers depend on the current
/// semantics.
macro_rules! partial_eq_and_hash_f32 {
    ($semiring:tt) => {
        impl PartialEq for $semiring {
            fn eq(&self, other: &Self) -> bool {
                let lhs = *self.value();
                let rhs = *other.value();
                // Two one-sided bounds instead of `(lhs - rhs).abs()` so that
                // equal infinities compare equal (`inf - inf` would be NaN).
                let lhs_within = lhs <= (rhs + KDELTA);
                lhs_within && rhs <= (lhs + KDELTA)
            }
        }

        impl Hash for $semiring {
            fn hash<H>(&self, state: &mut H)
            where
                H: Hasher,
            {
                self.value.hash(state)
            }
        }
    };
}
177
178pub trait SerializableSemiring: Semiring + Display {
179 fn weight_type() -> String;
180 fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>>;
181 fn write_binary<F: Write>(&self, file: &mut F) -> Result<()>;
182
183 fn parse_text(i: &str) -> IResult<&str, Self>;
184 fn write_text<F: Write>(&self, file: &mut F) -> Result<()> {
185 write!(file, "{}", self)?;
187 Ok(())
188 }
189}