kweepeer/modules/
fst.rs

1use serde::Deserialize;
2use std::borrow::Cow;
3use std::fs::File;
4use std::io::{prelude::*, BufReader};
5use std::path::PathBuf;
6use tracing::{debug, info};
7
8use fst::automaton::Levenshtein;
9use fst::{IntoStreamer, Set, SetBuilder};
10
11use crate::lexer::Term;
12use crate::modules::Module;
13use crate::{Error, QueryParams, TermExpansion, TermExpansions};
14
/// A lexicon module backed by a finite state transducer (an `fst::Set`).
/// Query terms are expanded to the lexicon entries that lie within a
/// given Levenshtein distance of the term.
pub struct FstModule {
    /// The configuration this module was created with
    config: FstConfig,
    /// The lexicon entries; this is the default (empty) set until the lexicon has been loaded
    set: Set<Vec<u8>>,
}
21
/// Configuration for an [`FstModule`].
#[derive(Debug, Deserialize, Clone)]
pub struct FstConfig {
    /// Short identifier
    id: String,

    /// Human readable label
    name: String,

    /// The path to the lexicon: a plain-text word list with one entry per line. Only the text
    /// before the first tab on a line is used and lines starting with `#` are skipped.
    /// The entries *MUST* be in lexicographical order if `sorted` is set.
    file: PathBuf,

    /// Default Levenshtein distance for lookups
    distance: u8,

    /// Is the lexicon already sorted lexicographically? If it is, setting this to true improves loading time/memory consumption
    #[serde(default)]
    sorted: bool,

    /// Set this if the first line is a header
    #[serde(default)]
    skipfirstline: bool,

    /// Set this to match case-sensitively. When unset (the default), lexicon entries and
    /// query terms are lowercased.
    #[serde(default)]
    casesensitive: bool,
}
47
48impl FstConfig {
49    pub fn new(
50        id: impl Into<String>,
51        name: impl Into<String>,
52        file: impl Into<PathBuf>,
53        distance: u8,
54        sorted: bool,
55    ) -> Self {
56        Self {
57            id: id.into(),
58            name: name.into(),
59            file: file.into(),
60            distance,
61            sorted,
62            skipfirstline: false,
63            casesensitive: false,
64        }
65    }
66
67    pub fn with_distance(mut self, distance: u8) -> Self {
68        self.distance = distance;
69        self
70    }
71
72    pub fn with_skipfirstline(mut self) -> Self {
73        self.skipfirstline = true;
74        self
75    }
76
77    pub fn id(&self) -> &str {
78        self.id.as_str()
79    }
80
81    pub fn name(&self) -> &str {
82        self.name.as_str()
83    }
84}
85
86impl FstModule {
87    pub fn new(config: FstConfig) -> Self {
88        Self {
89            config,
90            set: Set::default(),
91        }
92    }
93}
94
95impl Module for FstModule {
96    fn id(&self) -> &str {
97        self.config.id.as_str()
98    }
99
100    fn name(&self) -> &str {
101        self.config.name.as_str()
102    }
103
104    fn kind(&self) -> &'static str {
105        "fst"
106    }
107
108    fn load(&mut self) -> Result<(), Error> {
109        info!("Loading lexicon {}", self.config.file.as_path().display());
110        let file = File::open(self.config.file.as_path()).map_err(|e| {
111            Error::LoadError(format!(
112                "Fst Module could not open {}: {}",
113                self.config.file.as_path().display(),
114                e
115            ))
116        })?;
117        let mut buffer = String::new();
118        let mut reader = BufReader::new(file);
119        let mut firstline = true;
120        let mut builder = SetBuilder::memory();
121        let mut entries: Vec<String> = Vec::new();
122        while let Ok(bytes) = reader.read_line(&mut buffer) {
123            if bytes == 0 {
124                //EOF
125                break;
126            }
127            if firstline {
128                firstline = false;
129                if self.config.skipfirstline {
130                    buffer.clear();
131                    continue;
132                }
133            }
134            if buffer.chars().next() != Some('#') {
135                if let Some(line) = buffer.trim().splitn(2, '\t').next() {
136                    if !line.is_empty() {
137                        if self.config.sorted {
138                            if self.config.casesensitive {
139                                builder.insert(line.as_bytes())?;
140                            } else {
141                                builder.insert(line.to_lowercase().as_bytes())?;
142                            }
143                        } else {
144                            if self.config.casesensitive {
145                                entries.push(line.to_owned());
146                            } else {
147                                entries.push(line.to_lowercase().to_owned());
148                            }
149                        }
150                    }
151                }
152            }
153            buffer.clear();
154        }
155        if !entries.is_empty() {
156            entries.sort();
157            for entry in entries {
158                if self.config.casesensitive {
159                    builder.insert(entry.as_bytes())?;
160                } else {
161                    builder.insert(entry.to_lowercase().as_bytes())?;
162                }
163            }
164        }
165        info!("Building FST");
166        self.set = Set::new(builder.into_inner()?)?;
167        Ok(())
168    }
169
170    fn expand_query(
171        &self,
172        terms: &Vec<Term>,
173        params: &QueryParams,
174    ) -> Result<TermExpansions, Error> {
175        let distance = if let Some(param) = params.get(self.id(), "distance") {
176            param.as_u64().ok_or_else(|| {
177                Error::QueryExpandError("invalid value for distance parameter".into())
178            })? as u32
179        } else {
180            self.config.distance as u32
181        };
182        let mut expansions = TermExpansions::new();
183        for term in terms {
184            let term = if self.config.casesensitive {
185                Cow::Borrowed(term.as_str())
186            } else {
187                Cow::Owned(term.as_str().to_lowercase())
188            };
189            match Levenshtein::new(term.as_ref(), distance) {
190                Ok(levaut) => {
191                    debug!("Looking up {}", term);
192                    let stream = self.set.search(levaut).into_stream();
193                    if let Ok(variants) = stream.into_strs() {
194                        if !variants.is_empty() {
195                            debug!("found {} expansions", variants.len());
196                            expansions.insert(
197                                term.into_owned(),
198                                vec![TermExpansion::default()
199                                    .with_source(self)
200                                    .with_expansions(variants.to_vec())],
201                            );
202                        } else {
203                            debug!("not found");
204                        }
205                    } else {
206                        debug!("UTF-8 decoding error, no results returned");
207                    }
208                }
209                Err(e) => debug!("Can't build FST for term '{}': {}", term, e),
210            }
211        }
212        Ok(expansions)
213    }
214}
215
216impl From<fst::Error> for Error {
217    fn from(value: fst::Error) -> Self {
218        Self::LoadError(format!("{}", value))
219    }
220}
221
222impl From<fst::automaton::LevenshteinError> for Error {
223    fn from(value: fst::automaton::LevenshteinError) -> Self {
224        Self::QueryExpandError(format!("{}", value))
225    }
226}
227
// Only compiled for test builds, so the helper and imports below do not end up
// (unused) in regular builds.
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an unloaded, case-sensitive module for the unsorted test lexicon
    /// shipped in the crate's `test` directory.
    fn init_test() -> Result<FstModule, Error> {
        let mut testdir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        testdir.push("test");
        let mut lexicon_file = testdir.clone();
        lexicon_file.push("test.nofreq.lexicon");
        let config = FstConfig {
            id: "fst".into(),
            name: "fst".into(),
            file: lexicon_file,
            distance: 2,
            sorted: false,
            skipfirstline: false,
            casesensitive: true,
        };
        Ok(FstModule::new(config))
    }

    /// The test lexicon loads without errors.
    #[test]
    pub fn test001_lookup_load() -> Result<(), Error> {
        let mut module = init_test()?;
        module.load()?;
        Ok(())
    }

    /// A known term is expanded to its variants within the configured distance.
    #[test]
    pub fn test002_lookup_query() -> Result<(), Error> {
        let mut module = init_test()?;
        module.load()?;
        let terms = vec![Term::Singular("belangrijk")];
        let expansions = module.expand_query(&terms, &QueryParams::new())?;
        assert_eq!(expansions.len(), 1, "Checking number of terms returned");
        let termexpansion = expansions
            .get("belangrijk")
            .expect("term must exists")
            .first()
            .expect("term must have results");
        assert_eq!(termexpansion.source_id(), Some("fst"), "Checking source id");
        assert_eq!(
            termexpansion.source_name(),
            Some("fst"),
            "Checking source name"
        );
        assert_eq!(
            termexpansion.iter().collect::<Vec<_>>(),
            [
                "belangrijk",
                "belangrijke",
                "belangrijker",
                "belangrijks",
                "belangrijkst",
                "onbelangrijk"
            ],
            "Checking returned expansions"
        );
        Ok(())
    }

    /// A term without any entry within the configured distance yields no expansions.
    #[test]
    pub fn test003_lookup_query_nomatch() -> Result<(), Error> {
        let mut module = init_test()?;
        module.load()?;
        let terms = vec![Term::Singular("blah")];
        let expansions = module.expand_query(&terms, &QueryParams::new())?;
        assert_eq!(expansions.len(), 0, "Checking number of terms returned");
        Ok(())
    }
}