kweepeer/
lib.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4use tracing::info;
5
6pub mod api;
7pub mod apidocs;
8pub mod lexer;
9pub mod modules;
10
11#[cfg(feature = "analiticcl")]
12use modules::analiticcl::{AnaliticclConfig, AnaliticclModule};
13
14#[cfg(feature = "fst")]
15use modules::fst::{FstConfig, FstModule};
16
17use modules::lookup::{LookupConfig, LookupModule};
18
19#[cfg(feature = "finalfusion")]
20use modules::finalfusion::{FinalFusionConfig, FinalFusionModule};
21
22use modules::Module;
23
24pub use lexer::Term;
25
/// Maps each query term to its expansions. Every `TermExpansion` in the vector corresponds to
/// one source (module) and may itself contain multiple expanded variants.
pub type TermExpansions = HashMap<String, Vec<TermExpansion>>;
28
29#[derive(Default)]
30pub struct QueryExpander {
31    config: Config,
32    modules: Vec<Box<dyn Module>>,
33    initialised: bool,
34}
35
/// Describes which modules `QueryExpander::load()` should instantiate.
///
/// Each field holds the configurations for one module type. Because of `#[serde(default)]`,
/// sections missing from the deserialized input default to an empty list.
#[derive(Deserialize, Default)]
#[serde(default)]
pub struct Config {
    /// Lookup modules to instantiate.
    lookup: Vec<LookupConfig>,

    /// Analiticcl modules to instantiate (only with the `analiticcl` feature).
    #[cfg(feature = "analiticcl")]
    analiticcl: Vec<AnaliticclConfig>,

    /// FST modules to instantiate (only with the `fst` feature).
    #[cfg(feature = "fst")]
    fst: Vec<FstConfig>,

    /// Finalfusion modules to instantiate (only with the `finalfusion` feature).
    #[cfg(feature = "finalfusion")]
    finalfusion: Vec<FinalFusionConfig>,
}
50
51impl QueryExpander {
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    pub fn with_config(mut self, config: Config) -> Self {
57        self.config = config;
58        self
59    }
60
61    /// Adds a new module. Only valid before call to `load()`, will panic afterwards.
62    pub fn add_module(&mut self, module: Box<dyn Module>) {
63        if self.initialised {
64            panic!("Can not add modules after load()!")
65        }
66        self.modules.push(module);
67    }
68
69    /// Adds a new module. Only valid before call to `load()`, will panic afterwards.
70    pub fn with_module(mut self, module: Box<dyn Module>) -> Self {
71        self.add_module(module);
72        self
73    }
74
75    /// Returns an iterator over all the modules
76    pub fn modules(&self) -> impl Iterator<Item = &dyn Module> {
77        self.modules.iter().map(|x| x.as_ref())
78    }
79
80    /// Initialise all modules. This should be called once after all modules are loaded. Will panic if called multiple times.
81    pub fn load(&mut self) -> Result<(), Error> {
82        if self.initialised {
83            panic!("load() can only be called once");
84        }
85        //MAYBE TODO: we could parallellize the loading for quicker startup time
86        for lookupconfig in self.config.lookup.iter() {
87            info!(
88                "Adding Lookup module {} - {}",
89                lookupconfig.id(),
90                lookupconfig.name()
91            );
92            let mut module = LookupModule::new(lookupconfig.clone());
93            module.load()?;
94            self.modules.push(Box::new(module));
95        }
96
97        #[cfg(feature = "fst")]
98        for fstconfig in self.config.fst.iter() {
99            info!(
100                "Adding Fst module {} - {}",
101                fstconfig.id(),
102                fstconfig.name()
103            );
104            let mut module = FstModule::new(fstconfig.clone());
105            module.load()?;
106            self.modules.push(Box::new(module));
107        }
108
109        #[cfg(feature = "analiticcl")]
110        for analiticclconfig in self.config.analiticcl.iter() {
111            info!(
112                "Adding Analiticcl module {} - {}",
113                analiticclconfig.id(),
114                analiticclconfig.name()
115            );
116            let mut module = AnaliticclModule::new(analiticclconfig.clone());
117            module.load()?;
118            self.modules.push(Box::new(module));
119        }
120        #[cfg(feature = "finalfusion")]
121        for finalfusionconfig in self.config.finalfusion.iter() {
122            info!(
123                "Adding Finalfusion module {} - {}",
124                finalfusionconfig.id(),
125                finalfusionconfig.name()
126            );
127            let mut module = FinalFusionModule::new(finalfusionconfig.clone());
128            module.load()?;
129            self.modules.push(Box::new(module));
130        }
131
132        info!("All modules loaded");
133        self.initialised = true;
134        Ok(())
135    }
136
137    pub fn expand_query(
138        &self,
139        terms: &Vec<Term>,
140        params: &QueryParams,
141    ) -> Result<TermExpansions, Error> {
142        let mut terms_map = TermExpansions::new();
143        self.expand_query_into(&mut terms_map, terms, params)?;
144        Ok(terms_map)
145    }
146
147    pub fn expand_query_into(
148        &self,
149        terms_map: &mut TermExpansions,
150        terms: &Vec<Term>,
151        params: &QueryParams,
152    ) -> Result<(), Error> {
153        let excludemods: Vec<_> = if let Some(mods) = params.get("", "excludemods") {
154            value_to_str_array(mods)
155        } else {
156            Vec::new()
157        };
158        let includemods: Vec<_> = if let Some(mods) = params.get("", "includemods") {
159            value_to_str_array(mods)
160        } else {
161            Vec::new()
162        };
163        for module in self.modules() {
164            if (excludemods.is_empty() || !excludemods.contains(&module.id()))
165                || (includemods.is_empty() || includemods.contains(&module.id()))
166            {
167                let mut expansion_map = module.expand_query(terms, params)?;
168                for term in terms.iter() {
169                    terms_map
170                        .entry(term.as_str().to_string())
171                        .and_modify(|expansions| {
172                            if let Some(expansions2) = expansion_map.remove(term.as_str()) {
173                                for expansion in expansions2 {
174                                    expansions.push(expansion);
175                                }
176                            }
177                        })
178                        .or_insert_with(|| {
179                            if let Some(expansions2) = expansion_map.remove(term.as_str()) {
180                                expansions2
181                            } else {
182                                vec![]
183                            }
184                        });
185                }
186            }
187        }
188        Ok(())
189    }
190
191    /// Resolve a query template by substituting the template terms by the disjunctions from query expansion
192    /// You won't really need to call this yourself.
193    pub fn resolve_query_template(
194        &self,
195        query_template: &str,
196        terms_map: &TermExpansions,
197    ) -> Result<String, Error> {
198        let mut query = String::with_capacity(query_template.len());
199        let mut termbegin = None;
200        let mut termend = None;
201        let mut prevc = None;
202        let mut expansioncache = std::collections::HashSet::<&str>::new();
203        for (i, c) in query_template.char_indices() {
204            if c == '{' && prevc == Some('{') {
205                termbegin = Some(i + 1);
206            }
207            if c == '}' && prevc == Some('}') && termbegin.is_some() {
208                if let Some(termend) = termend {
209                    query += &query_template[termend + 2..termbegin.unwrap() - 2];
210                }
211                termend = Some(i - 1);
212                let term = &query_template[termbegin.unwrap()..termend.unwrap()];
213                if let Some(termexpansions) = terms_map.get(term) {
214                    expansioncache.clear();
215                    for termexpansion in termexpansions {
216                        let mut first = true;
217                        for expansion in termexpansion.iter() {
218                            if !expansioncache.contains(expansion) {
219                                if !first {
220                                    query += "\" OR \"";
221                                } else {
222                                    if !expansioncache.is_empty() {
223                                        query += " OR ";
224                                    }
225                                    query += "(\"";
226                                }
227                                first = false;
228                                query += expansion;
229                                expansioncache.insert(expansion);
230                            }
231                        }
232                        if !first {
233                            query += "\")";
234                        }
235                    }
236                }
237                //reset
238                termbegin = None;
239            }
240            prevc = Some(c);
241        }
242        if let Some(termend) = termend {
243            query += &query_template[termend + 2..];
244        }
245        Ok(query)
246    }
247}
248
249/// convert a json array of strings to a rust Vec<&str>
250fn value_to_str_array(input: &Value) -> Vec<&str> {
251    if let Value::Array(array) = input {
252        let mut array_out = Vec::with_capacity(array.len());
253        for value in array {
254            if let Value::String(s) = value {
255                array_out.push(s.as_str());
256            }
257        }
258        array_out
259    } else {
260        Vec::new()
261    }
262}
263
/// The expansions of a single term as produced by one source (module).
#[derive(Debug, Serialize, Default, Clone)]
pub struct TermExpansion {
    /// The expanded variants.
    expansions: Vec<String>,
    /// Scores, stored separately from `expansions`; variants added via `add_variant` get no
    /// score, so this may be shorter than `expansions` (or empty).
    scores: Vec<f64>,
    /// Id of the module that produced this expansion (set by `with_source`).
    source_id: Option<String>,
    /// Name of that module (set by `with_source`).
    source_name: Option<String>,
    /// Kind of that module (`Module::kind()`, set by `with_source`); empty string by default.
    source_type: &'static str,
    /// Optional link associated with this expansion (set by `with_link`).
    link: Option<String>,
}
273
274impl TermExpansion {
275    pub fn with_source(mut self, module: &impl Module) -> Self {
276        self.source_id = Some(module.id().into());
277        self.source_name = Some(module.name().into());
278        self.source_type = module.kind();
279        self
280    }
281
282    pub fn with_link(mut self, link: impl Into<String>) -> Self {
283        self.link = Some(link.into());
284        self
285    }
286
287    pub fn with_expansions(mut self, expansions: Vec<String>) -> Self {
288        self.expansions = expansions;
289        self
290    }
291
292    pub fn with_scores(mut self, scores: Vec<f64>) -> Self {
293        self.scores = scores;
294        self
295    }
296
297    pub fn add_variant_with_score(&mut self, expansion: impl Into<String>, score: f64) {
298        self.expansions.push(expansion.into());
299        self.scores.push(score);
300    }
301
302    pub fn add_variant(&mut self, expansion: impl Into<String>) {
303        self.expansions.push(expansion.into());
304    }
305
306    pub fn expansions(&self) -> &Vec<String> {
307        &self.expansions
308    }
309
310    pub fn scores(&self) -> &Vec<f64> {
311        &self.scores
312    }
313
314    pub fn source_id(&self) -> Option<&str> {
315        self.source_id.as_deref()
316    }
317
318    pub fn source_name(&self) -> Option<&str> {
319        self.source_name.as_deref()
320    }
321
322    pub fn link(&self) -> Option<&str> {
323        self.link.as_deref()
324    }
325
326    pub fn len(&self) -> usize {
327        self.expansions.len()
328    }
329
330    pub fn iter(&self) -> impl Iterator<Item = &str> {
331        self.expansions.iter().map(|x| x.as_str())
332    }
333
334    pub fn as_vec(&self) -> &Vec<String> {
335        &self.expansions
336    }
337}
338
/// A single runtime parameter, scoped to a module (by convention, an empty `module_id` denotes
/// the global scope).
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QueryParam {
    /// Module the parameter applies to; empty for the global scope.
    module_id: String,
    /// Parameter name.
    key: String,
    /// Parameter value, as an arbitrary JSON value.
    value: Value,
}
345
346impl QueryParam {
347    pub fn module_id(&self) -> &str {
348        self.module_id.as_str()
349    }
350
351    pub fn key(&self) -> &str {
352        self.key.as_str()
353    }
354
355    pub fn value(&self) -> &Value {
356        &self.value
357    }
358}
359
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
/// Holds arbitrary parameters passed to queries at runtime when requesting expansion.
/// Each parameter is scoped to a module id; by convention an empty module id is the global scope.
// Backed by a plain Vec instead of a HashMap to avoid hashing overhead; lookups are linear scans,
// presumably acceptable because parameter lists are small.
pub struct QueryParams(Vec<QueryParam>);
364
365impl QueryParams {
366    pub fn new() -> Self {
367        Self::default()
368    }
369
370    /// Insert a new key and value (builder pattern)
371    pub fn with(
372        mut self,
373        module_id: impl Into<String>,
374        key: impl Into<String>,
375        value: Value,
376    ) -> Self {
377        self.insert(module_id, key, value);
378        self
379    }
380
381    /// Insert a new key and value
382    /// By convention, we use an empty module_id for a global scope.
383    pub fn insert(&mut self, module_id: impl Into<String>, key: impl Into<String>, value: Value) {
384        self.0.push(QueryParam {
385            module_id: module_id.into(),
386            key: key.into(),
387            value,
388        });
389    }
390
391    /// Check if a key exists. By convention, we use an empty module_id for a global scope.
392    pub fn contains(&self, module_id: &str, key: &str) -> bool {
393        for param in self.iter() {
394            if param.module_id() == module_id && param.key() == key {
395                return true;
396            }
397        }
398        false
399    }
400
401    /// Iterate over all keys and values
402    pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a QueryParam> {
403        self.0.iter()
404    }
405
406    /// Iterate over all keys and values
407    pub fn iter_for_module<'a>(
408        &'a self,
409        module_id: &'a str,
410    ) -> impl Iterator<Item = &'a QueryParam> {
411        self.0
412            .iter()
413            .filter(move |param| param.module_id() == module_id)
414    }
415
416    /// Retrieve a value by key
417    /// By convention, we use an empty module_id for a global scope.
418    pub fn get<'a>(&'a self, module_id: &str, key: &str) -> Option<&'a Value> {
419        for param in self.iter() {
420            if param.module_id() == module_id && param.key() == key {
421                return Some(param.value());
422            }
423        }
424        None
425    }
426}
427
428impl From<&HashMap<String, String>> for QueryParams {
429    fn from(map: &HashMap<String, String>) -> Self {
430        let mut result = QueryParams::new();
431        for (key, value) in map.iter() {
432            let splitkey: Vec<_> = key.splitn(2, key).collect();
433            if splitkey.len() == 1 {
434                result.insert("", key, value.to_owned().into());
435            } else {
436                result.insert(splitkey[0], splitkey[1], value.to_owned().into());
437            }
438        }
439        result
440    }
441}
442
/// Errors reported by the query expander and its modules.
#[derive(Debug, Clone)]
pub enum Error {
    /// Failure while loading modules or their data (I/O errors convert into this variant).
    LoadError(String),
    /// Failure while expanding a query.
    QueryExpandError(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LoadError(x) => write!(f, "[Load error] {}", x),
            Self::QueryExpandError(x) => write!(f, "[Query expansion error] {}", x),
        }
    }
}

// Implementing the standard error trait lets callers box this error (`Box<dyn Error>`) and use it
// with generic error-handling code.
impl std::error::Error for Error {}

// I/O failures are reported as load errors, carrying the I/O error's message.
impl From<std::io::Error> for Error {
    fn from(value: std::io::Error) -> Self {
        Self::LoadError(value.to_string())
    }
}
469
470impl Serialize for Error {
471    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
472    where
473        S: serde::Serializer,
474    {
475        match self {
476            Self::LoadError(s) | Self::QueryExpandError(s) => serializer.serialize_str(s.as_str()),
477        }
478    }
479}