rustfst/algorithms/lazy/
lazy_fst.rs

1use std::collections::VecDeque;
2use std::fmt::Debug;
3use std::iter::{repeat, Map, Repeat, Zip};
4use std::marker::PhantomData;
5use std::ops::Deref;
6use std::path::Path;
7use std::sync::Arc;
8
9use anyhow::Result;
10use itertools::izip;
11use unsafe_unwrap::UnsafeUnwrap;
12
13use crate::algorithms::lazy::cache::CacheStatus;
14use crate::algorithms::lazy::fst_op::{AccessibleOpState, FstOp, SerializableOpState};
15use crate::algorithms::lazy::{FstCache, SerializableCache};
16use crate::fst_properties::FstProperties;
17use crate::fst_traits::{
18    AllocableFst, CoreFst, Fst, FstIterData, FstIterator, MutableFst, StateIterator,
19};
20use crate::semirings::{Semiring, SerializableSemiring};
21use crate::{StateId, SymbolTable, Trs, TrsVec};
22
23#[derive(Debug, Clone)]
24pub struct LazyFst<W: Semiring, Op: FstOp<W>, Cache> {
25    cache: Cache,
26    pub(crate) op: Op,
27    w: PhantomData<W>,
28    isymt: Option<Arc<SymbolTable>>,
29    osymt: Option<Arc<SymbolTable>>,
30}
31
32impl<W: Semiring, Op: FstOp<W>, Cache: FstCache<W>> CoreFst<W> for LazyFst<W, Op, Cache> {
33    type TRS = TrsVec<W>;
34
35    fn start(&self) -> Option<StateId> {
36        match self.cache.get_start() {
37            CacheStatus::Computed(start) => start,
38            CacheStatus::NotComputed => {
39                // TODO: Need to return a Result
40                let start = self.op.compute_start().unwrap();
41                self.cache.insert_start(start);
42                start
43            }
44        }
45    }
46
47    fn final_weight(&self, state_id: StateId) -> Result<Option<W>> {
48        match self.cache.get_final_weight(state_id) {
49            CacheStatus::Computed(final_weight) => Ok(final_weight),
50            CacheStatus::NotComputed => {
51                let final_weight = self.op.compute_final_weight(state_id)?;
52                self.cache
53                    .insert_final_weight(state_id, final_weight.clone());
54                Ok(final_weight)
55            }
56        }
57    }
58
59    unsafe fn final_weight_unchecked(&self, state_id: StateId) -> Option<W> {
60        self.final_weight(state_id).unsafe_unwrap()
61    }
62
63    fn num_trs(&self, s: StateId) -> Result<usize> {
64        self.cache
65            .num_trs(s)
66            .ok_or_else(|| format_err!("State {:?} doesn't exist", s))
67    }
68
69    unsafe fn num_trs_unchecked(&self, s: StateId) -> usize {
70        self.cache.num_trs(s).unsafe_unwrap()
71    }
72
73    fn get_trs(&self, state_id: StateId) -> Result<Self::TRS> {
74        match self.cache.get_trs(state_id) {
75            CacheStatus::Computed(trs) => Ok(trs),
76            CacheStatus::NotComputed => {
77                let trs = self.op.compute_trs(state_id)?;
78                self.cache.insert_trs(state_id, trs.shallow_clone());
79                Ok(trs)
80            }
81        }
82    }
83
84    unsafe fn get_trs_unchecked(&self, state_id: StateId) -> Self::TRS {
85        self.get_trs(state_id).unsafe_unwrap()
86    }
87
88    fn properties(&self) -> FstProperties {
89        self.op.properties()
90    }
91
92    fn num_input_epsilons(&self, state: StateId) -> Result<usize> {
93        self.cache
94            .num_input_epsilons(state)
95            .ok_or_else(|| format_err!("State {:?} doesn't exist", state))
96    }
97
98    fn num_output_epsilons(&self, state: StateId) -> Result<usize> {
99        self.cache
100            .num_output_epsilons(state)
101            .ok_or_else(|| format_err!("State {:?} doesn't exist", state))
102    }
103}
104
105impl<'a, W, Op, Cache> StateIterator<'a> for LazyFst<W, Op, Cache>
106where
107    W: Semiring,
108    Op: FstOp<W> + 'a,
109    Cache: FstCache<W> + 'a,
110{
111    type Iter = StatesIteratorLazyFst<'a, Self>;
112
113    fn states_iter(&'a self) -> Self::Iter {
114        self.start();
115        StatesIteratorLazyFst { fst: self, s: 0 }
116    }
117}
118
119#[derive(Clone)]
120pub struct StatesIteratorLazyFst<'a, T> {
121    pub(crate) fst: &'a T,
122    pub(crate) s: StateId,
123}
124
125impl<'a, W, Op, Cache> Iterator for StatesIteratorLazyFst<'a, LazyFst<W, Op, Cache>>
126where
127    W: Semiring,
128    Op: FstOp<W>,
129    Cache: FstCache<W>,
130{
131    type Item = StateId;
132
133    fn next(&mut self) -> Option<Self::Item> {
134        let num_known_states = self.fst.cache.num_known_states();
135        if (self.s as usize) < num_known_states {
136            let s_cur = self.s;
137            // Force expansion of the state
138            self.fst.get_trs(self.s).unwrap();
139            self.s += 1;
140            Some(s_cur)
141        } else {
142            None
143        }
144    }
145}
146
147type ZipIter<'a, W, Op, Cache, SELF> =
148    Zip<<LazyFst<W, Op, Cache> as StateIterator<'a>>::Iter, Repeat<&'a SELF>>;
149type MapFunction<'a, W, SELF, TRS> = Box<dyn FnMut((StateId, &'a SELF)) -> FstIterData<W, TRS>>;
150type MapIter<'a, W, Op, Cache, SELF, TRS> =
151    Map<ZipIter<'a, W, Op, Cache, SELF>, MapFunction<'a, W, SELF, TRS>>;
152
153impl<'a, W, Op, Cache> FstIterator<'a, W> for LazyFst<W, Op, Cache>
154where
155    W: Semiring,
156    Op: FstOp<W> + 'a,
157    Cache: FstCache<W> + 'a,
158{
159    type FstIter = MapIter<'a, W, Op, Cache, Self, Self::TRS>;
160
161    fn fst_iter(&'a self) -> Self::FstIter {
162        let it = repeat(self);
163        izip!(self.states_iter(), it).map(Box::new(|(state_id, p): (StateId, &'a Self)| {
164            FstIterData {
165                state_id,
166                trs: unsafe { p.get_trs_unchecked(state_id) },
167                final_weight: unsafe { p.final_weight_unchecked(state_id) },
168                num_trs: unsafe { p.num_trs_unchecked(state_id) },
169            }
170        }))
171    }
172}
173
174impl<W, Op, Cache> Fst<W> for LazyFst<W, Op, Cache>
175where
176    W: Semiring,
177    Op: FstOp<W> + 'static,
178    Cache: FstCache<W> + 'static,
179{
180    fn input_symbols(&self) -> Option<&Arc<SymbolTable>> {
181        self.isymt.as_ref()
182    }
183
184    fn output_symbols(&self) -> Option<&Arc<SymbolTable>> {
185        self.osymt.as_ref()
186    }
187
188    fn set_input_symbols(&mut self, symt: Arc<SymbolTable>) {
189        self.isymt = Some(symt);
190    }
191
192    fn set_output_symbols(&mut self, symt: Arc<SymbolTable>) {
193        self.osymt = Some(symt);
194    }
195
196    fn take_input_symbols(&mut self) -> Option<Arc<SymbolTable>> {
197        self.isymt.take()
198    }
199
200    fn take_output_symbols(&mut self) -> Option<Arc<SymbolTable>> {
201        self.osymt.take()
202    }
203}
204
205impl<W, Op, Cache> LazyFst<W, Op, Cache>
206where
207    W: Semiring,
208    Op: FstOp<W>,
209    Cache: FstCache<W>,
210{
211    pub fn from_op_and_cache(
212        op: Op,
213        cache: Cache,
214        isymt: Option<Arc<SymbolTable>>,
215        osymt: Option<Arc<SymbolTable>>,
216    ) -> Self {
217        Self {
218            op,
219            cache,
220            isymt,
221            osymt,
222            w: PhantomData,
223        }
224    }
225
226    /// Turns the Lazy FST into a static one.
227    pub fn compute<F2: MutableFst<W> + AllocableFst<W>>(&self) -> Result<F2> {
228        let start_state = self.start();
229        let mut fst_out = F2::new();
230        let start_state = match start_state {
231            Some(s) => s,
232            None => return Ok(fst_out),
233        };
234        fst_out.add_states(start_state as usize + 1);
235        fst_out.set_start(start_state)?;
236        let mut queue = VecDeque::new();
237        let mut visited_states = vec![];
238        visited_states.resize(start_state as usize + 1, false);
239        visited_states[start_state as usize] = true;
240        queue.push_back(start_state);
241        while let Some(s) = queue.pop_front() {
242            let trs_owner = self.get_trs(s)?;
243            for tr in trs_owner.trs() {
244                if (tr.nextstate as usize) >= visited_states.len() {
245                    visited_states.resize(tr.nextstate as usize + 1, false);
246                }
247                if !visited_states[tr.nextstate as usize] {
248                    queue.push_back(tr.nextstate);
249                    visited_states[tr.nextstate as usize] = true;
250                }
251                let n = fst_out.num_states();
252                if (tr.nextstate as usize) >= n {
253                    fst_out.add_states(tr.nextstate as usize - n + 1)
254                }
255            }
256            unsafe { fst_out.set_trs_unchecked(s, trs_owner.trs().to_vec()) };
257            if let Some(f_w) = self.final_weight(s)? {
258                fst_out.set_final(s, f_w)?;
259            }
260        }
261        fst_out.set_properties(self.properties());
262
263        if let Some(isymt) = &self.isymt {
264            fst_out.set_input_symbols(Arc::clone(isymt));
265        }
266        if let Some(osymt) = &self.osymt {
267            fst_out.set_output_symbols(Arc::clone(osymt));
268        }
269        Ok(fst_out)
270    }
271}
272
273impl<W, Op, Cache> SerializableLazyFst for LazyFst<W, Op, Cache>
274where
275    W: SerializableSemiring,
276    Op: FstOp<W> + AccessibleOpState,
277    Op::FstOpState: SerializableOpState,
278    Cache: FstCache<W> + SerializableCache,
279{
280    /// Writes LazyFst interal states to a directory of files in binary format.
281    fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()> {
282        self.cache.write(cache_dir)?;
283        self.op.get_op_state().write(op_state_dir)?;
284        Ok(())
285    }
286}
287
288pub trait SerializableLazyFst {
289    /// Writes LazyFst interal states to a directory of files in binary format.
290    fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()>;
291}
292
293impl<C: SerializableLazyFst, CP: Deref<Target = C> + Debug> SerializableLazyFst for CP {
294    fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()> {
295        self.deref().write(cache_dir, op_state_dir)
296    }
297}