rustfst/fst_impls/const_fst/
serializable_fst.rs

1use std::io::Write;
2use std::sync::Arc;
3
4use anyhow::Result;
5use itertools::Itertools;
6use nom::bytes::complete::take;
7use nom::multi::count;
8use nom::IResult;
9
10use crate::fst_impls::const_fst::data_structure::ConstState;
11use crate::fst_impls::const_fst::{
12    CONST_ALIGNED_FILE_VERSION, CONST_ARCH_ALIGNMENT, CONST_FILE_VERSION, CONST_MIN_FILE_VERSION,
13};
14use crate::fst_impls::ConstFst;
15use crate::fst_properties::FstProperties;
16use crate::fst_traits::{ExpandedFst, Fst, SerializableFst};
17use crate::parsers::bin_fst::fst_header::{FstFlags, FstHeader, OpenFstString, FST_MAGIC_NUMBER};
18use crate::parsers::bin_fst::utils_parsing::{
19    parse_bin_fst_tr, parse_final_weight, parse_start_state,
20};
21use crate::parsers::nom_utils::NomCustomError;
22use crate::parsers::parse_bin_i32;
23use crate::parsers::text_fst::ParsedTextFst;
24use crate::parsers::write_bin_i32;
25use crate::semirings::SerializableSemiring;
26use crate::{Tr, EPS_LABEL};
27
28impl<W: SerializableSemiring> SerializableFst<W> for ConstFst<W> {
29    fn fst_type() -> String {
30        "const".to_string()
31    }
32
33    fn load(data: &[u8]) -> Result<Self> {
34        let (_, parsed_fst) = parse_const_fst(data)
35            .map_err(|_| format_err!("Error while parsing binary ConstFst"))?;
36
37        Ok(parsed_fst)
38    }
39
40    fn store<O: Write>(&self, mut output: O) -> Result<()> {
41        let mut flags = FstFlags::empty();
42        if self.input_symbols().is_some() {
43            flags |= FstFlags::HAS_ISYMBOLS;
44        }
45        if self.output_symbols().is_some() {
46            flags |= FstFlags::HAS_OSYMBOLS;
47        }
48
49        let hdr = FstHeader {
50            magic_number: FST_MAGIC_NUMBER,
51            fst_type: OpenFstString::new(Self::fst_type()),
52            tr_type: OpenFstString::new(Tr::<W>::tr_type()),
53            version: CONST_FILE_VERSION,
54            // TODO: Set flags if the content is aligned
55            flags,
56            properties: self.properties.bits() | ConstFst::<W>::static_properties(),
57            start: self.start.map(|v| v as i64).unwrap_or(-1),
58            num_states: self.num_states() as i64,
59            num_trs: self.trs.len() as i64,
60            isymt: self.input_symbols().cloned(),
61            osymt: self.output_symbols().cloned(),
62        };
63        hdr.write(&mut output)?;
64
65        let zero = W::zero();
66        for const_state in &self.states {
67            let f_weight = const_state.final_weight.as_ref().unwrap_or(&zero);
68            f_weight.write_binary(&mut output)?;
69
70            write_bin_i32(&mut output, const_state.pos as i32)?;
71            write_bin_i32(&mut output, const_state.ntrs as i32)?;
72            write_bin_i32(&mut output, const_state.niepsilons as i32)?;
73            write_bin_i32(&mut output, const_state.noepsilons as i32)?;
74        }
75
76        for tr in &*self.trs {
77            write_bin_i32(&mut output, tr.ilabel as i32)?;
78            write_bin_i32(&mut output, tr.olabel as i32)?;
79            tr.weight.write_binary(&mut output)?;
80            write_bin_i32(&mut output, tr.nextstate as i32)?;
81        }
82
83        Ok(())
84    }
85
86    fn from_parsed_fst_text(mut parsed_fst_text: ParsedTextFst<W>) -> Result<Self> {
87        let start_state = parsed_fst_text.start();
88        let num_states = parsed_fst_text.num_states();
89        let num_trs = parsed_fst_text.transitions.len();
90
91        let mut const_states = Vec::with_capacity(num_states);
92        let mut const_trs = Vec::with_capacity(num_trs);
93
94        parsed_fst_text.transitions.sort_by_key(|v| v.state);
95        for (_state, tr_iterator) in parsed_fst_text
96            .transitions
97            .into_iter()
98            .group_by(|v| v.state)
99            .into_iter()
100        {
101            let pos = const_trs.len();
102            // Some states might not have outgoing trs.
103            const_states.resize_with(_state as usize, || ConstState {
104                final_weight: None,
105                pos,
106                ntrs: 0,
107                niepsilons: 0,
108                noepsilons: 0,
109            });
110            let mut niepsilons = 0;
111            let mut noepsilons = 0;
112            const_trs.extend(tr_iterator.map(|v| {
113                debug_assert_eq!(_state, v.state);
114                let tr = Tr {
115                    ilabel: v.ilabel,
116                    olabel: v.olabel,
117                    weight: v.weight.unwrap_or_else(W::one),
118                    nextstate: v.nextstate,
119                };
120                if tr.ilabel == EPS_LABEL {
121                    niepsilons += 1;
122                }
123                if tr.olabel == EPS_LABEL {
124                    noepsilons += 1;
125                }
126                tr
127            }));
128            let num_trs_this_state = const_trs.len() - pos;
129            const_states.push(ConstState::<W> {
130                final_weight: None,
131                pos,
132                ntrs: num_trs_this_state,
133                niepsilons,
134                noepsilons,
135            })
136        }
137        const_states.resize_with(num_states, || ConstState {
138            final_weight: None,
139            pos: const_trs.len(),
140            ntrs: 0,
141            niepsilons: 0,
142            noepsilons: 0,
143        });
144        debug_assert_eq!(num_states, const_states.len());
145        for final_state in parsed_fst_text.final_states.into_iter() {
146            let weight = final_state.weight.unwrap_or_else(W::one);
147            unsafe {
148                const_states
149                    .get_unchecked_mut(final_state.state as usize)
150                    .final_weight = Some(weight)
151            };
152        }
153
154        // Trick to compute the FstProperties. Indeed we need a fst to compute the properties
155        // and we need the properties to construct a fst...
156        let mut fst = ConstFst {
157            states: const_states,
158            trs: Arc::new(const_trs),
159            start: start_state,
160            isymt: None,
161            osymt: None,
162            properties: FstProperties::empty(),
163        };
164
165        let mut known = FstProperties::empty();
166        let properties = crate::fst_properties::compute_fst_properties(
167            &fst,
168            FstProperties::all_properties(),
169            &mut known,
170            false,
171        )?;
172        fst.properties = properties;
173
174        Ok(fst)
175    }
176}
177
178fn parse_const_state<W: SerializableSemiring>(
179    i: &[u8],
180) -> IResult<&[u8], ConstState<W>, NomCustomError<&[u8]>> {
181    let (i, final_weight) = W::parse_binary(i)?;
182    let (i, pos) = parse_bin_i32(i)?;
183    let (i, ntrs) = parse_bin_i32(i)?;
184    let (i, niepsilons) = parse_bin_i32(i)?;
185    let (i, noepsilons) = parse_bin_i32(i)?;
186
187    Ok((
188        i,
189        ConstState {
190            final_weight: parse_final_weight(final_weight),
191            pos: pos as usize,
192            ntrs: ntrs as usize,
193            niepsilons: niepsilons as usize,
194            noepsilons: noepsilons as usize,
195        },
196    ))
197}
198
199fn parse_const_fst<W: SerializableSemiring>(
200    i: &[u8],
201) -> IResult<&[u8], ConstFst<W>, NomCustomError<&[u8]>> {
202    let stream_len = i.len();
203
204    let (mut i, hdr) = FstHeader::parse(
205        i,
206        CONST_MIN_FILE_VERSION,
207        ConstFst::<W>::fst_type(),
208        Tr::<W>::tr_type(),
209    )?;
210    let aligned = hdr.version == CONST_ALIGNED_FILE_VERSION;
211    let pos = stream_len - i.len();
212
213    // Align input
214    if aligned && hdr.num_states > 0 && pos % CONST_ARCH_ALIGNMENT > 0 {
215        i = take(CONST_ARCH_ALIGNMENT - (pos % CONST_ARCH_ALIGNMENT))(i)?.0;
216    }
217    let (mut i, const_states) = count(parse_const_state, hdr.num_states as usize)(i)?;
218    let pos = stream_len - i.len();
219
220    // Align input
221    if aligned && hdr.num_trs > 0 && pos % CONST_ARCH_ALIGNMENT > 0 {
222        i = take(CONST_ARCH_ALIGNMENT - (pos % CONST_ARCH_ALIGNMENT))(i)?.0;
223    }
224    let (i, const_trs) = count(parse_bin_fst_tr, hdr.num_trs as usize)(i)?;
225
226    Ok((
227        i,
228        ConstFst {
229            start: parse_start_state(hdr.start),
230            states: const_states,
231            trs: Arc::new(const_trs),
232            isymt: hdr.isymt,
233            osymt: hdr.osymt,
234            properties: FstProperties::from_bits_truncate(hdr.properties),
235        },
236    ))
237}