arbor_graph/
search_index.rs1use crate::graph::NodeId;
7use std::collections::{HashMap, HashSet};
8
9const MIN_NGRAM_LEN: usize = 2;
11
12const MAX_NGRAM_LEN: usize = 4;
14
15#[derive(Debug, Default, Clone)]
21pub struct SearchIndex {
22 exact_index: HashMap<String, Vec<NodeId>>,
24 ngram_index: HashMap<String, HashSet<NodeId>>,
26}
27
28impl SearchIndex {
29 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn insert(&mut self, name: &str, id: NodeId) {
36 let lower = name.to_lowercase();
37
38 self.exact_index.entry(lower.clone()).or_default().push(id);
40
41 for ngram in self.generate_ngrams(&lower) {
43 self.ngram_index.entry(ngram).or_default().insert(id);
44 }
45 }
46
47 pub fn remove(&mut self, name: &str, id: NodeId) {
49 let lower = name.to_lowercase();
50
51 if let Some(ids) = self.exact_index.get_mut(&lower) {
53 ids.retain(|&x| x != id);
54 if ids.is_empty() {
55 self.exact_index.remove(&lower);
56 }
57 }
58
59 for ngram in self.generate_ngrams(&lower) {
61 if let Some(ids) = self.ngram_index.get_mut(&ngram) {
62 ids.remove(&id);
63 if ids.is_empty() {
64 self.ngram_index.remove(&ngram);
65 }
66 }
67 }
68 }
69
70 pub fn search(&self, query: &str) -> Vec<NodeId> {
74 let query_lower = query.to_lowercase();
75
76 if query_lower.len() < MIN_NGRAM_LEN {
78 let mut results: Vec<NodeId> = self
79 .exact_index
80 .iter()
81 .filter(|(name, _)| name.starts_with(&query_lower))
82 .flat_map(|(_, ids)| ids.iter().copied())
83 .collect();
84 results.sort();
85 results.dedup();
86 return results;
87 }
88
89 let query_ngrams: Vec<String> = self.generate_ngrams(&query_lower);
91
92 if query_ngrams.is_empty() {
93 return Vec::new();
94 }
95
96 let mut candidates: Option<HashSet<NodeId>> = None;
98
99 for ngram in &query_ngrams {
100 if let Some(ids) = self.ngram_index.get(ngram) {
101 match &mut candidates {
102 None => candidates = Some(ids.clone()),
103 Some(c) => {
104 c.retain(|id| ids.contains(id));
105 }
106 }
107 } else {
108 return Vec::new();
110 }
111 }
112
113 let mut results: Vec<NodeId> = candidates
116 .unwrap_or_default()
117 .into_iter()
118 .filter(|id| {
119 self.exact_index
120 .iter()
121 .any(|(name, ids)| ids.contains(id) && name.contains(&query_lower))
122 })
123 .collect();
124
125 results.sort();
126 results
127 }
128
129 fn generate_ngrams(&self, s: &str) -> Vec<String> {
131 let chars: Vec<char> = s.chars().collect();
132 let mut ngrams = Vec::new();
133
134 for n in MIN_NGRAM_LEN..=MAX_NGRAM_LEN {
135 if chars.len() >= n {
136 for i in 0..=(chars.len() - n) {
137 ngrams.push(chars[i..i + n].iter().collect());
138 }
139 }
140 }
141
142 ngrams
143 }
144
145 pub fn len(&self) -> usize {
147 self.exact_index.len()
148 }
149
150 pub fn is_empty(&self) -> bool {
152 self.exact_index.is_empty()
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::graph::NodeIndex;

    /// Shorthand for building a `NodeId` from a small integer.
    fn node_id(n: u32) -> NodeId {
        NodeIndex::new(n as usize)
    }

    /// Builds an index pre-populated with `(name, id)` pairs, inserted in order.
    fn index_of(entries: &[(&str, u32)]) -> SearchIndex {
        let mut index = SearchIndex::new();
        for &(name, n) in entries {
            index.insert(name, node_id(n));
        }
        index
    }

    /// Fixture of three function names shared by the exact/substring tests.
    fn validation_fixture() -> SearchIndex {
        index_of(&[
            ("validate_user", 0),
            ("validate_email", 1),
            ("send_email", 2),
        ])
    }

    #[test]
    fn test_insert_and_search_exact() {
        let hits = validation_fixture().search("validate_user");
        assert_eq!(hits, vec![node_id(0)]);
    }

    #[test]
    fn test_search_substring() {
        let hits = validation_fixture().search("validate");
        assert!(hits.contains(&node_id(0)));
        assert!(hits.contains(&node_id(1)));
        assert!(!hits.contains(&node_id(2)));
    }

    #[test]
    fn test_search_case_insensitive() {
        let index = index_of(&[("ValidateUser", 0)]);
        for query in &["validateuser", "VALIDATEUSER"] {
            assert_eq!(index.search(query), vec![node_id(0)], "query {:?}", query);
        }
    }

    #[test]
    fn test_search_middle_substring() {
        let index = index_of(&[("get_user_profile", 0)]);
        for query in &["user", "_user_"] {
            assert_eq!(index.search(query), vec![node_id(0)], "query {:?}", query);
        }
    }

    #[test]
    fn test_remove_from_index() {
        let mut index = index_of(&[("foo", 0), ("foobar", 1)]);
        index.remove("foo", node_id(0));

        let hits = index.search("foo");
        assert!(!hits.contains(&node_id(0)));
        assert!(hits.contains(&node_id(1)));
    }

    #[test]
    fn test_search_no_match() {
        let index = index_of(&[("hello", 0)]);
        assert!(index.search("world").is_empty());
    }

    #[test]
    fn test_short_query() {
        let index = index_of(&[("ab", 0), ("abc", 1), ("xyz", 2)]);

        // A single-char query is below the n-gram length, so it is answered by
        // a prefix scan over the indexed names.
        let hits = index.search("a");
        assert!(hits.contains(&node_id(0)));
        assert!(hits.contains(&node_id(1)));
        assert!(!hits.contains(&node_id(2)));
    }
}