1use anyhow::Result;
2
3use bitflags::bitflags;
4
5use crate::algorithms::factor_weight::factor_iterators::{GallicFactorLeft, GallicFactorRight};
6use crate::algorithms::factor_weight::{factor_weight, FactorWeightOptions, FactorWeightType};
7use crate::algorithms::fst_convert::fst_convert_from_ref;
8use crate::algorithms::tr_mappers::RmWeightMapper;
9use crate::algorithms::weight_converters::{FromGallicConverter, ToGallicConverter};
10use crate::algorithms::{
11 reweight, shortest_distance_with_config, tr_map, weight_convert, ReweightType,
12 ShortestDistanceConfig,
13};
14use crate::fst_impls::VectorFst;
15use crate::fst_traits::{AllocableFst, ExpandedFst, MutableFst};
16use crate::semirings::{DivideType, Semiring};
17use crate::semirings::{
18 GallicWeightLeft, GallicWeightRight, StringWeightLeft, StringWeightRight,
19 WeaklyDivisibleSemiring, WeightQuantize,
20};
21use crate::{StateId, KDELTA};
22
23bitflags! {
24 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
26 pub struct PushType: u32 {
27 const PUSH_WEIGHTS = 0b01;
28 const PUSH_LABELS = 0b10;
29 const REMOVE_TOTAL_WEIGHT = 0b100;
30 const REMOVE_COMMON_AFFIX = 0b1000;
31 }
32}
33
34#[derive(Clone, Debug, Copy, PartialOrd, PartialEq)]
36pub struct PushWeightsConfig {
37 delta: f32,
38 remove_total_weight: bool,
39}
40
41impl Default for PushWeightsConfig {
42 fn default() -> Self {
43 Self {
44 delta: KDELTA,
45 remove_total_weight: false,
46 }
47 }
48}
49
50impl PushWeightsConfig {
51 pub fn new(delta: f32, remove_total_weight: bool) -> Self {
52 Self {
53 delta,
54 remove_total_weight,
55 }
56 }
57
58 pub fn with_delta(self, delta: f32) -> Self {
59 Self { delta, ..self }
60 }
61
62 pub fn with_remove_total_weight(self, remove_total_weight: bool) -> Self {
63 Self {
64 remove_total_weight,
65 ..self
66 }
67 }
68}
69
70pub fn push_weights<W, F>(fst: &mut F, reweight_type: ReweightType) -> Result<()>
77where
78 F: MutableFst<W>,
79 W: WeaklyDivisibleSemiring,
80{
81 push_weights_with_config(fst, reweight_type, PushWeightsConfig::default())
82}
83
84pub fn push_weights_with_config<W, F>(
91 fst: &mut F,
92 reweight_type: ReweightType,
93 config: PushWeightsConfig,
94) -> Result<()>
95where
96 F: MutableFst<W>,
97 W: WeaklyDivisibleSemiring,
98{
99 let remove_total_weight = config.remove_total_weight;
100 let delta = config.delta;
101 let dist = shortest_distance_with_config(
102 fst,
103 reweight_type == ReweightType::ReweightToInitial,
104 ShortestDistanceConfig::new(delta),
105 )?;
106
107 if remove_total_weight {
108 let total_weight =
109 compute_total_weight(fst, &dist, reweight_type == ReweightType::ReweightToInitial)?;
110 reweight(fst, &dist, reweight_type)?;
111 remove_weight(
112 fst,
113 total_weight,
114 reweight_type == ReweightType::ReweightToFinal,
115 )?;
116 } else {
117 reweight(fst, &dist, reweight_type)?;
118 }
119 Ok(())
120}
121
122fn compute_total_weight<W, F>(fst: &F, dist: &[W], reverse: bool) -> Result<W>
123where
124 W: Semiring,
125 F: ExpandedFst<W>,
126{
127 if reverse {
128 Ok(fst
129 .start()
130 .and_then(|start| dist.get(start as usize))
131 .cloned()
132 .unwrap_or_else(W::zero))
133 } else {
134 let mut sum = W::zero();
135 for (s, dist_s) in dist.iter().enumerate() {
136 sum.plus_assign(dist_s.times(
137 unsafe { fst.final_weight_unchecked(s as StateId) }.unwrap_or_else(W::zero),
138 )?)?;
139 }
140 Ok(sum)
141 }
142}
143
144fn remove_weight<W, F>(fst: &mut F, weight: W, at_final: bool) -> Result<()>
145where
146 F: MutableFst<W>,
147 W: WeaklyDivisibleSemiring,
148{
149 if weight.is_one() || weight.is_zero() {
150 return Ok(());
151 }
152 if at_final {
153 unsafe {
154 for s in fst.states_range() {
155 if let Some(mut final_weight) = fst.final_weight_unchecked(s) {
156 final_weight.divide_assign(&weight, DivideType::DivideRight)?;
157 fst.set_final_unchecked(s, final_weight);
158 }
159 }
160 }
161 } else if let Some(start) = fst.start() {
162 unsafe {
163 let mut it_tr = fst.tr_iter_unchecked_mut(start);
164 for idx_tr in 0..it_tr.len() {
165 let tr = it_tr.get_unchecked(idx_tr);
166 let weight = tr.weight.divide(&weight, DivideType::DivideLeft)?;
167 it_tr.set_weight_unchecked(idx_tr, weight);
168 }
169 if let Some(mut final_weight) = fst.final_weight_unchecked(start) {
170 final_weight.divide_assign(&weight, DivideType::DivideLeft)?;
171 fst.set_final_unchecked(start, final_weight);
172 }
173 }
174 }
175 Ok(())
176}
177
/// Label (and optionally weight) pushing, expanded once per Gallic orientation.
///
/// The expansion:
/// 1. Converts `$ifst` to a Gallic FST over `$gallic_weight` with `ToGallicConverter`.
/// 2. Computes shortest distances on that Gallic FST when `PUSH_WEIGHTS` is set. Otherwise
///    it computes them on a copy of `$ifst` whose weights were removed with `RmWeightMapper`
///    before the Gallic conversion.
/// 3. If `REMOVE_COMMON_AFFIX` or `REMOVE_TOTAL_WEIGHT` is set, computes the total Gallic
///    weight. Its string part is reset to `$string_weight::one()` unless `REMOVE_COMMON_AFFIX`
///    is set. Its weight part is reset to `W::one()` unless `REMOVE_TOTAL_WEIGHT` is set.
///    After reweighting, that weight is removed via `remove_weight` (from final weights for
///    `ReweightToFinal`, from the start state otherwise).
/// 4. Factors the Gallic weights (`$gallic_factor`) on both final and transition weights.
///    It then converts back with `FromGallicConverter` using superfinal label 0.
///
/// Evaluates to the `Result` of that last conversion. It uses `?` internally, so it must
/// be expanded inside a function returning a compatible `Result`.
/// NOTE: the body names `W` directly (it is not a macro parameter), so a type parameter
/// called `W` must be in scope at the expansion site.
macro_rules! m_labels_pushing {
    ($ifst: ident, $reweight_type: ident, $push_type: ident, $delta: ident, $gallic_weight: ty, $string_weight: ident, $gallic_factor: ty) => {{
        let mut mapper = ToGallicConverter {};
        let mut gfst: VectorFst<$gallic_weight> = weight_convert($ifst, &mut mapper)?;
        // Potentials used for reweighting: computed on the weighted Gallic FST when weights
        // are pushed too, otherwise on an unweighted copy so only the string part moves.
        let gdistance = if $push_type.intersects(PushType::PUSH_WEIGHTS) {
            shortest_distance_with_config(
                &gfst,
                $reweight_type == ReweightType::ReweightToInitial,
                ShortestDistanceConfig::new($delta),
            )?
        } else {
            let rm_weight_mapper = RmWeightMapper {};
            let mut uwfst: VectorFst<_> = fst_convert_from_ref($ifst);
            tr_map(&mut uwfst, &rm_weight_mapper)?;
            let guwfst: VectorFst<$gallic_weight> = weight_convert(&uwfst, &mut mapper)?;
            shortest_distance_with_config(
                &guwfst,
                $reweight_type == ReweightType::ReweightToInitial,
                ShortestDistanceConfig::new($delta),
            )?
        };
        if $push_type.intersects(PushType::REMOVE_COMMON_AFFIX | PushType::REMOVE_TOTAL_WEIGHT) {
            // Measured before reweighting; components that should not be removed are
            // neutralised to One so that dividing by them has no effect.
            let mut total_weight = compute_total_weight(
                &gfst,
                &gdistance,
                $reweight_type == ReweightType::ReweightToInitial,
            )?;
            if !$push_type.intersects(PushType::REMOVE_COMMON_AFFIX) {
                total_weight.set_value1($string_weight::one());
            }
            if !$push_type.intersects(PushType::REMOVE_TOTAL_WEIGHT) {
                total_weight.set_value2(W::one());
            }
            reweight(&mut gfst, gdistance.as_slice(), $reweight_type)?;
            remove_weight(
                &mut gfst,
                total_weight,
                $reweight_type == ReweightType::ReweightToFinal,
            )?;
        } else {
            reweight(&mut gfst, gdistance.as_slice(), $reweight_type)?;
        }
        // Split the Gallic weights back into per-transition pieces so the result can be
        // mapped back to the original weight type.
        let fwfst: VectorFst<$gallic_weight> =
            factor_weight::<_, VectorFst<$gallic_weight>, _, _, $gallic_factor>(
                &gfst,
                FactorWeightOptions::new(
                    FactorWeightType::FACTOR_FINAL_WEIGHTS | FactorWeightType::FACTOR_ARC_WEIGHTS,
                ),
            )?;
        let mut mapper_from_gallic = FromGallicConverter {
            superfinal_label: 0,
        };
        weight_convert(&fwfst, &mut mapper_from_gallic)
    }};
}
234
235#[derive(Clone, Copy, Debug, PartialOrd, PartialEq)]
237pub struct PushConfig {
238 delta: f32,
239}
240
241impl Default for PushConfig {
242 fn default() -> Self {
243 Self { delta: KDELTA }
244 }
245}
246
247impl PushConfig {
248 pub fn new(delta: f32) -> Self {
249 Self { delta }
250 }
251
252 pub fn with_delta(self, delta: f32) -> Self {
253 Self { delta }
254 }
255}
256
257pub fn push<W, F1, F2>(ifst: &F1, reweight_type: ReweightType, push_type: PushType) -> Result<F2>
260where
261 F1: ExpandedFst<W>,
262 F2: ExpandedFst<W> + MutableFst<W> + AllocableFst<W>,
263 W: WeaklyDivisibleSemiring + WeightQuantize,
264 <W as Semiring>::ReverseWeight: 'static,
265{
266 push_with_config(ifst, reweight_type, push_type, PushConfig::default())
267}
268
269pub fn push_with_config<W, F1, F2>(
272 ifst: &F1,
273 reweight_type: ReweightType,
274 push_type: PushType,
275 config: PushConfig,
276) -> Result<F2>
277where
278 F1: ExpandedFst<W>,
279 F2: ExpandedFst<W> + MutableFst<W> + AllocableFst<W>,
280 W: WeaklyDivisibleSemiring + WeightQuantize,
281 <W as Semiring>::ReverseWeight: 'static,
282{
283 let delta = config.delta;
284 if push_type.intersects(PushType::PUSH_WEIGHTS) && !push_type.intersects(PushType::PUSH_LABELS)
285 {
286 let mut ofst = fst_convert_from_ref(ifst);
288 let push_weights_config =
289 PushWeightsConfig::new(delta, push_type.intersects(PushType::REMOVE_TOTAL_WEIGHT));
290 push_weights_with_config(&mut ofst, reweight_type, push_weights_config)?;
291 Ok(ofst)
292 } else if push_type.intersects(PushType::PUSH_LABELS) {
293 match reweight_type {
294 ReweightType::ReweightToInitial => m_labels_pushing!(
295 ifst,
296 reweight_type,
297 push_type,
298 delta,
299 GallicWeightLeft<W>,
300 StringWeightLeft,
301 GallicFactorLeft<W>
302 ),
303 ReweightType::ReweightToFinal => m_labels_pushing!(
304 ifst,
305 reweight_type,
306 push_type,
307 delta,
308 GallicWeightRight<W>,
309 StringWeightRight,
310 GallicFactorRight<W>
311 ),
312 }
313 } else {
314 Ok(fst_convert_from_ref(ifst))
316 }
317}