rls_analysis/
symbol_query.rs1use fst::{self, Streamer};
2
/// A query over symbol names: the text to look for together with the options
/// that control how keys are matched and how many results are collected.
#[derive(Debug)]
pub struct SymbolQuery {
    /// Text to match; the public constructors store it lowercased.
    query_string: String,
    /// Matching strategy given to the automaton created in `build_stream`.
    mode: Mode,
    /// Result count at which `search_stream` stops reading further keys.
    limit: usize,
    /// Lowercased key passed to the stream builder's `gt` bound in `build_stream`.
    greater_than: String,
}
20
/// How the bytes of the query are compared against the bytes of a key.
#[derive(Debug, Clone, Copy)]
enum Mode {
    /// The key must start with the query: the first differing byte rejects the key.
    Prefix,
    /// The query bytes must occur in the key in order; differing bytes in
    /// between are skipped without losing the current match position.
    Subsequence,
}
26
27impl SymbolQuery {
28 fn new(query_string: String, mode: Mode) -> SymbolQuery {
29 SymbolQuery { query_string, mode, limit: usize::max_value(), greater_than: String::new() }
30 }
31
32 pub fn subsequence(query_string: &str) -> SymbolQuery {
33 SymbolQuery::new(query_string.to_lowercase(), Mode::Subsequence)
34 }
35
36 pub fn prefix(query_string: &str) -> SymbolQuery {
37 SymbolQuery::new(query_string.to_lowercase(), Mode::Prefix)
38 }
39
40 pub fn limit(self, limit: usize) -> SymbolQuery {
41 SymbolQuery { limit, ..self }
42 }
43
44 pub fn greater_than(self, greater_than: &str) -> SymbolQuery {
45 SymbolQuery { greater_than: greater_than.to_lowercase(), ..self }
46 }
47
48 pub(crate) fn build_stream<'a, I>(&'a self, fsts: I) -> fst::map::Union<'a>
49 where
50 I: Iterator<Item = &'a fst::Map<Vec<u8>>>,
51 {
52 let mut stream = fst::map::OpBuilder::new();
53 let automaton = QueryAutomaton { query: &self.query_string, mode: self.mode };
54 for fst in fsts {
55 stream = stream.add(fst.search(automaton).gt(&self.greater_than));
56 }
57 stream.union()
58 }
59
60 pub(crate) fn search_stream<F, T>(&self, mut stream: fst::map::Union<'_>, f: F) -> Vec<T>
61 where
62 F: Fn(&mut Vec<T>, &fst::map::IndexedValue),
63 {
64 let mut res = Vec::new();
65 while let Some((_, entries)) = stream.next() {
66 for e in entries {
67 f(&mut res, e);
68 }
69 if res.len() >= self.limit {
70 break;
71 }
72 }
73 res
74 }
75}
76
/// Automaton passed to `fst` searches to decide which keys match a query.
///
/// Its state is the number of query bytes matched so far, or `NO_MATCH`.
#[derive(Clone, Copy)]
struct QueryAutomaton<'a> {
    /// Query text whose bytes are compared against key bytes.
    query: &'a str,
    /// Decides what happens when a key byte differs from the expected query byte.
    mode: Mode,
}
91
/// Sentinel automaton state for a key that can no longer match: all bits set,
/// i.e. the largest `usize`.
const NO_MATCH: usize = usize::max_value();
93
94impl<'a> fst::Automaton for QueryAutomaton<'a> {
95 type State = usize;
96
97 fn start(&self) -> usize {
98 0
99 }
100
101 fn is_match(&self, &state: &usize) -> bool {
102 state == self.query.len()
103 }
104
105 fn accept(&self, &state: &usize, byte: u8) -> usize {
106 if state == NO_MATCH {
107 return state;
108 }
109 if state == self.query.len() {
110 return state;
111 }
112 if byte == self.query.as_bytes()[state] {
113 return state + 1;
114 }
115 match self.mode {
116 Mode::Prefix => NO_MATCH,
117 Mode::Subsequence => state,
118 }
119 }
120
121 fn can_match(&self, &state: &usize) -> bool {
122 state != NO_MATCH
123 }
124
125 fn will_always_match(&self, &state: &usize) -> bool {
126 state == self.query.len()
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::iter;

    /// Fixture keys, listed in lexicographic order; each key's map value is
    /// its index in this slice.
    const STARS: &[&str] = &[
        "agena", "agreetor", "algerib", "anektor", "antares", "arcturus", "canopus", "capella",
        "duendin", "golubin", "lalandry", "spica", "vega",
    ];

    /// Runs `q` against a single map built from `STARS` and asserts that the
    /// matched names equal `expected`, in order.
    fn check(q: SymbolQuery, expected: &[&str]) {
        let entries = STARS.iter().enumerate().map(|(idx, &name)| (name, idx as u64));
        let map = fst::Map::from_iter(entries).unwrap();
        let stream = q.build_stream(iter::once(&map));
        // Map each indexed value back to the name it was created from.
        let found: Vec<&str> =
            q.search_stream(stream, |names, entry| names.push(STARS[entry.value as usize]));
        assert_eq!(expected, found.as_slice());
    }

    #[test]
    fn test_automaton() {
        // Prefix mode: only keys starting with "an".
        check(SymbolQuery::prefix("an"), &["anektor", "antares"]);

        // Subsequence mode: an 'a' followed, anywhere later, by an 'n'.
        check(
            SymbolQuery::subsequence("an"),
            &["agena", "anektor", "antares", "canopus", "lalandry"],
        );

        // The limit cuts the result short.
        check(SymbolQuery::subsequence("an").limit(2), &["agena", "anektor"]);

        // `greater_than` resumes strictly after the given key.
        check(
            SymbolQuery::subsequence("an").limit(2).greater_than("anektor"),
            &["antares", "canopus"],
        );
        check(SymbolQuery::subsequence("an").limit(2).greater_than("canopus"), &["lalandry"]);
    }
}