rustfst/algorithms/
tr_map.rs1use anyhow::Result;
2use std::ops::Deref;
3
4use crate::fst_properties::FstProperties;
5use crate::fst_traits::MutableFst;
6use crate::semirings::Semiring;
7use crate::Tr;
8use crate::{Label, StateId, EPS_LABEL};
9
10#[derive(Clone, Debug)]
17pub struct FinalTr<W: Semiring> {
18 pub ilabel: Label,
20 pub olabel: Label,
22 pub weight: W,
24}
25
/// Policy telling [`tr_map`] how to write back a mapped final weight.
///
/// Final weights are handed to the mapper as a [`FinalTr`] with epsilon
/// labels; the variants differ in what is done with the mapped result.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MapFinalAction {
    /// The mapped final transition must keep epsilon labels and its weight
    /// becomes the state's new final weight; non-epsilon labels are an error.
    MapNoSuperfinal,
    /// If the mapped final transition has a non-epsilon label it becomes a real
    /// transition to a superfinal state created on first need; otherwise its
    /// weight becomes the state's new final weight.
    MapAllowSuperfinal,
    /// A superfinal state (final weight `One`) is always created up front and
    /// the final weights of the other states are redirected onto transitions
    /// leading to it.
    MapRequireSuperfinal,
}
41
42pub trait TrMapper<S: Semiring> {
46 fn tr_map(&self, tr: &mut Tr<S>) -> Result<()>;
48
49 fn final_tr_map(&self, final_tr: &mut FinalTr<S>) -> Result<()>;
52
53 fn final_action(&self) -> MapFinalAction;
55
56 fn properties(&self, inprops: FstProperties) -> FstProperties;
57}
58
59impl<S: Semiring, T: TrMapper<S>, TP: Deref<Target = T>> TrMapper<S> for TP {
60 fn tr_map(&self, tr: &mut Tr<S>) -> Result<()> {
61 self.deref().tr_map(tr)
62 }
63
64 fn final_tr_map(&self, final_tr: &mut FinalTr<S>) -> Result<()> {
65 self.deref().final_tr_map(final_tr)
66 }
67
68 fn final_action(&self) -> MapFinalAction {
69 self.deref().final_action()
70 }
71
72 fn properties(&self, inprops: FstProperties) -> FstProperties {
73 self.deref().properties(inprops)
74 }
75}
76
77pub fn tr_map<W, F, M>(ifst: &mut F, mapper: &M) -> Result<()>
79where
80 W: Semiring,
81 F: MutableFst<W>,
82 M: TrMapper<W>,
83{
84 if ifst.start().is_none() {
85 return Ok(());
86 }
87
88 let inprops = ifst.properties();
89
90 let final_action = mapper.final_action();
91 let mut superfinal: Option<StateId> = None;
92
93 if final_action == MapFinalAction::MapRequireSuperfinal {
94 let superfinal_id = ifst.add_state();
95 superfinal = Some(superfinal_id);
96 ifst.set_final(superfinal_id, W::one()).unwrap();
97 }
98
99 for state in 0..(ifst.num_states() as StateId) {
100 unsafe {
101 let mut it_tr = ifst.tr_iter_unchecked_mut(state);
102 for idx_tr in 0..it_tr.len() {
103 let mut tr = it_tr.get_unchecked(idx_tr).clone();
104 mapper.tr_map(&mut tr)?;
105 it_tr.set_tr_unchecked(idx_tr, tr);
106 }
107 }
108
109 if let Some(w) = unsafe { ifst.final_weight_unchecked(state) } {
110 let mut final_tr = FinalTr {
111 ilabel: EPS_LABEL,
112 olabel: EPS_LABEL,
113 weight: w,
114 };
115 mapper.final_tr_map(&mut final_tr)?;
116 match final_action {
117 MapFinalAction::MapNoSuperfinal => {
118 if final_tr.ilabel != EPS_LABEL || final_tr.olabel != EPS_LABEL {
119 bail!("TrMap: Non-zero tr labels for superfinal tr")
120 }
121 unsafe {
122 ifst.set_final_unchecked(state, final_tr.weight);
123 }
124 }
125 MapFinalAction::MapAllowSuperfinal => {
126 if Some(state) != superfinal {
127 if final_tr.ilabel != EPS_LABEL || final_tr.olabel != EPS_LABEL {
128 if superfinal.is_none() {
129 let superfinal_id = ifst.add_state();
130 superfinal = Some(superfinal_id);
131 unsafe {
132 ifst.set_final_unchecked(superfinal_id, W::one());
134 }
135 }
136 unsafe {
137 ifst.add_tr_unchecked(
139 state,
140 Tr::new(
141 final_tr.ilabel,
142 final_tr.olabel,
143 final_tr.weight,
144 superfinal.unwrap(), ),
146 );
147 ifst.delete_final_weight_unchecked(state);
148 }
149 } else {
150 unsafe {
151 ifst.set_final_unchecked(state, final_tr.weight);
153 }
154 }
155 }
156 }
157 MapFinalAction::MapRequireSuperfinal => {
158 if Some(state) != superfinal
159 && (final_tr.ilabel != EPS_LABEL
160 || final_tr.olabel != EPS_LABEL
161 || !final_tr.weight.is_zero())
162 {
163 unsafe {
164 ifst.add_tr_unchecked(
166 state,
167 Tr::new(
168 final_tr.ilabel,
169 final_tr.olabel,
170 final_tr.weight,
171 superfinal.unwrap(),
172 ),
173 );
174 ifst.delete_final_weight_unchecked(state);
175 }
176 }
177 }
178 };
179 }
180 }
181
182 ifst.set_properties_with_mask(mapper.properties(inprops), FstProperties::all_properties());
183
184 Ok(())
185}