rustfst/fst_impls/const_fst/
serializable_fst.rs1use std::io::Write;
2use std::sync::Arc;
3
4use anyhow::Result;
5use itertools::Itertools;
6use nom::bytes::complete::take;
7use nom::multi::count;
8use nom::IResult;
9
10use crate::fst_impls::const_fst::data_structure::ConstState;
11use crate::fst_impls::const_fst::{
12 CONST_ALIGNED_FILE_VERSION, CONST_ARCH_ALIGNMENT, CONST_FILE_VERSION, CONST_MIN_FILE_VERSION,
13};
14use crate::fst_impls::ConstFst;
15use crate::fst_properties::FstProperties;
16use crate::fst_traits::{ExpandedFst, Fst, SerializableFst};
17use crate::parsers::bin_fst::fst_header::{FstFlags, FstHeader, OpenFstString, FST_MAGIC_NUMBER};
18use crate::parsers::bin_fst::utils_parsing::{
19 parse_bin_fst_tr, parse_final_weight, parse_start_state,
20};
21use crate::parsers::nom_utils::NomCustomError;
22use crate::parsers::parse_bin_i32;
23use crate::parsers::text_fst::ParsedTextFst;
24use crate::parsers::write_bin_i32;
25use crate::semirings::SerializableSemiring;
26use crate::{Tr, EPS_LABEL};
27
28impl<W: SerializableSemiring> SerializableFst<W> for ConstFst<W> {
29 fn fst_type() -> String {
30 "const".to_string()
31 }
32
33 fn load(data: &[u8]) -> Result<Self> {
34 let (_, parsed_fst) = parse_const_fst(data)
35 .map_err(|_| format_err!("Error while parsing binary ConstFst"))?;
36
37 Ok(parsed_fst)
38 }
39
40 fn store<O: Write>(&self, mut output: O) -> Result<()> {
41 let mut flags = FstFlags::empty();
42 if self.input_symbols().is_some() {
43 flags |= FstFlags::HAS_ISYMBOLS;
44 }
45 if self.output_symbols().is_some() {
46 flags |= FstFlags::HAS_OSYMBOLS;
47 }
48
49 let hdr = FstHeader {
50 magic_number: FST_MAGIC_NUMBER,
51 fst_type: OpenFstString::new(Self::fst_type()),
52 tr_type: OpenFstString::new(Tr::<W>::tr_type()),
53 version: CONST_FILE_VERSION,
54 flags,
56 properties: self.properties.bits() | ConstFst::<W>::static_properties(),
57 start: self.start.map(|v| v as i64).unwrap_or(-1),
58 num_states: self.num_states() as i64,
59 num_trs: self.trs.len() as i64,
60 isymt: self.input_symbols().cloned(),
61 osymt: self.output_symbols().cloned(),
62 };
63 hdr.write(&mut output)?;
64
65 let zero = W::zero();
66 for const_state in &self.states {
67 let f_weight = const_state.final_weight.as_ref().unwrap_or(&zero);
68 f_weight.write_binary(&mut output)?;
69
70 write_bin_i32(&mut output, const_state.pos as i32)?;
71 write_bin_i32(&mut output, const_state.ntrs as i32)?;
72 write_bin_i32(&mut output, const_state.niepsilons as i32)?;
73 write_bin_i32(&mut output, const_state.noepsilons as i32)?;
74 }
75
76 for tr in &*self.trs {
77 write_bin_i32(&mut output, tr.ilabel as i32)?;
78 write_bin_i32(&mut output, tr.olabel as i32)?;
79 tr.weight.write_binary(&mut output)?;
80 write_bin_i32(&mut output, tr.nextstate as i32)?;
81 }
82
83 Ok(())
84 }
85
86 fn from_parsed_fst_text(mut parsed_fst_text: ParsedTextFst<W>) -> Result<Self> {
87 let start_state = parsed_fst_text.start();
88 let num_states = parsed_fst_text.num_states();
89 let num_trs = parsed_fst_text.transitions.len();
90
91 let mut const_states = Vec::with_capacity(num_states);
92 let mut const_trs = Vec::with_capacity(num_trs);
93
94 parsed_fst_text.transitions.sort_by_key(|v| v.state);
95 for (_state, tr_iterator) in parsed_fst_text
96 .transitions
97 .into_iter()
98 .group_by(|v| v.state)
99 .into_iter()
100 {
101 let pos = const_trs.len();
102 const_states.resize_with(_state as usize, || ConstState {
104 final_weight: None,
105 pos,
106 ntrs: 0,
107 niepsilons: 0,
108 noepsilons: 0,
109 });
110 let mut niepsilons = 0;
111 let mut noepsilons = 0;
112 const_trs.extend(tr_iterator.map(|v| {
113 debug_assert_eq!(_state, v.state);
114 let tr = Tr {
115 ilabel: v.ilabel,
116 olabel: v.olabel,
117 weight: v.weight.unwrap_or_else(W::one),
118 nextstate: v.nextstate,
119 };
120 if tr.ilabel == EPS_LABEL {
121 niepsilons += 1;
122 }
123 if tr.olabel == EPS_LABEL {
124 noepsilons += 1;
125 }
126 tr
127 }));
128 let num_trs_this_state = const_trs.len() - pos;
129 const_states.push(ConstState::<W> {
130 final_weight: None,
131 pos,
132 ntrs: num_trs_this_state,
133 niepsilons,
134 noepsilons,
135 })
136 }
137 const_states.resize_with(num_states, || ConstState {
138 final_weight: None,
139 pos: const_trs.len(),
140 ntrs: 0,
141 niepsilons: 0,
142 noepsilons: 0,
143 });
144 debug_assert_eq!(num_states, const_states.len());
145 for final_state in parsed_fst_text.final_states.into_iter() {
146 let weight = final_state.weight.unwrap_or_else(W::one);
147 unsafe {
148 const_states
149 .get_unchecked_mut(final_state.state as usize)
150 .final_weight = Some(weight)
151 };
152 }
153
154 let mut fst = ConstFst {
157 states: const_states,
158 trs: Arc::new(const_trs),
159 start: start_state,
160 isymt: None,
161 osymt: None,
162 properties: FstProperties::empty(),
163 };
164
165 let mut known = FstProperties::empty();
166 let properties = crate::fst_properties::compute_fst_properties(
167 &fst,
168 FstProperties::all_properties(),
169 &mut known,
170 false,
171 )?;
172 fst.properties = properties;
173
174 Ok(fst)
175 }
176}
177
178fn parse_const_state<W: SerializableSemiring>(
179 i: &[u8],
180) -> IResult<&[u8], ConstState<W>, NomCustomError<&[u8]>> {
181 let (i, final_weight) = W::parse_binary(i)?;
182 let (i, pos) = parse_bin_i32(i)?;
183 let (i, ntrs) = parse_bin_i32(i)?;
184 let (i, niepsilons) = parse_bin_i32(i)?;
185 let (i, noepsilons) = parse_bin_i32(i)?;
186
187 Ok((
188 i,
189 ConstState {
190 final_weight: parse_final_weight(final_weight),
191 pos: pos as usize,
192 ntrs: ntrs as usize,
193 niepsilons: niepsilons as usize,
194 noepsilons: noepsilons as usize,
195 },
196 ))
197}
198
199fn parse_const_fst<W: SerializableSemiring>(
200 i: &[u8],
201) -> IResult<&[u8], ConstFst<W>, NomCustomError<&[u8]>> {
202 let stream_len = i.len();
203
204 let (mut i, hdr) = FstHeader::parse(
205 i,
206 CONST_MIN_FILE_VERSION,
207 ConstFst::<W>::fst_type(),
208 Tr::<W>::tr_type(),
209 )?;
210 let aligned = hdr.version == CONST_ALIGNED_FILE_VERSION;
211 let pos = stream_len - i.len();
212
213 if aligned && hdr.num_states > 0 && pos % CONST_ARCH_ALIGNMENT > 0 {
215 i = take(CONST_ARCH_ALIGNMENT - (pos % CONST_ARCH_ALIGNMENT))(i)?.0;
216 }
217 let (mut i, const_states) = count(parse_const_state, hdr.num_states as usize)(i)?;
218 let pos = stream_len - i.len();
219
220 if aligned && hdr.num_trs > 0 && pos % CONST_ARCH_ALIGNMENT > 0 {
222 i = take(CONST_ARCH_ALIGNMENT - (pos % CONST_ARCH_ALIGNMENT))(i)?.0;
223 }
224 let (i, const_trs) = count(parse_bin_fst_tr, hdr.num_trs as usize)(i)?;
225
226 Ok((
227 i,
228 ConstFst {
229 start: parse_start_state(hdr.start),
230 states: const_states,
231 trs: Arc::new(const_trs),
232 isymt: hdr.isymt,
233 osymt: hdr.osymt,
234 properties: FstProperties::from_bits_truncate(hdr.properties),
235 },
236 ))
237}