rustfst/algorithms/determinize/
divisors.rs1use std::fmt::Debug;
2
3use anyhow::Result;
4
5use crate::semirings::{
6 GallicWeight, GallicWeightLeft, GallicWeightMin, GallicWeightRestrict, StringWeightLeft,
7 StringWeightRestrict,
8};
9use crate::Semiring;
10
11pub trait CommonDivisor<W: Semiring>: PartialEq + Debug + Sync {
12 fn common_divisor(w1: &W, w2: &W) -> Result<W>;
13}
14
15#[derive(PartialEq, Debug)]
16pub struct DefaultCommonDivisor {}
17
18impl<W: Semiring> CommonDivisor<W> for DefaultCommonDivisor {
19 fn common_divisor(w1: &W, w2: &W) -> Result<W> {
20 w1.plus(w2)
21 }
22}
23
24#[derive(PartialEq, Debug)]
25pub struct LabelCommonDivisor {}
26
27macro_rules! impl_label_common_divisor {
28 ($string_semiring: ident) => {
29 impl CommonDivisor<$string_semiring> for LabelCommonDivisor {
30 fn common_divisor(
31 w1: &$string_semiring,
32 w2: &$string_semiring,
33 ) -> Result<$string_semiring> {
34 let mut iter1 = w1.iter();
35 let mut iter2 = w2.iter();
36 if w1.value.is_empty_list() || w2.value.is_empty_list() {
37 Ok($string_semiring::one())
38 } else if w1.value.is_infinity() {
39 Ok(iter2.next().unwrap().into())
40 } else if w2.value.is_infinity() {
41 Ok(iter1.next().unwrap().into())
42 } else {
43 let v1 = iter1.next().unwrap();
44 let v2 = iter2.next().unwrap();
45 if v1 == v2 {
46 Ok(v1.into())
47 } else {
48 Ok($string_semiring::one())
49 }
50 }
51 }
52 }
53 };
54}
55
56impl_label_common_divisor!(StringWeightLeft);
57impl_label_common_divisor!(StringWeightRestrict);
58
59#[derive(Debug, PartialEq)]
60pub struct GallicCommonDivisor {}
61
62macro_rules! impl_gallic_common_divisor {
63 ($gallic: ident) => {
64 impl<W: Semiring> CommonDivisor<$gallic<W>> for GallicCommonDivisor {
65 fn common_divisor(w1: &$gallic<W>, w2: &$gallic<W>) -> Result<$gallic<W>> {
66 let v1 = LabelCommonDivisor::common_divisor(w1.value1(), w2.value1())?;
67 let v2 = DefaultCommonDivisor::common_divisor(w1.value2(), w2.value2())?;
68 Ok((v1, v2).into())
69 }
70 }
71 };
72}
73
74impl_gallic_common_divisor!(GallicWeightLeft);
75impl_gallic_common_divisor!(GallicWeightRestrict);
76impl_gallic_common_divisor!(GallicWeightMin);
77
78impl<W: Semiring> CommonDivisor<GallicWeight<W>> for GallicCommonDivisor {
79 fn common_divisor(w1: &GallicWeight<W>, w2: &GallicWeight<W>) -> Result<GallicWeight<W>> {
80 let mut weight = GallicWeightRestrict::zero();
81 for w in w1.iter().chain(w2.iter()) {
82 weight = GallicCommonDivisor::common_divisor(&weight, w)?;
83 }
84 if weight.is_zero() {
85 Ok(GallicWeight::zero())
86 } else {
87 Ok(weight.into())
88 }
89 }
90}