//! hashtree_index/search.rs
//!
//! Keyword search index layered on the crate's `BTree`.

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use hashtree_core::Cid;
5
6use crate::{BTree, BTreeError, BTreeOptions};
7
/// Construction-time settings for `SearchIndex`.
///
/// Every field is optional; `SearchIndex::new` substitutes a fallback for an
/// unset stop-word list and an unset minimum keyword length.
#[derive(Debug, Clone, Default)]
pub struct SearchIndexOptions {
    /// Passed through unchanged as `BTreeOptions::order` for the backing tree.
    pub order: Option<usize>,
    /// Words that `parse_keywords` discards; `None` selects the built-in list.
    pub stop_words: Option<HashSet<String>>,
    /// Minimum keyword length, counted in characters; `None` selects
    /// `DEFAULT_MIN_KEYWORD_LENGTH`.
    pub min_keyword_length: Option<usize>,
}
14
/// Per-query settings accepted by `SearchIndex::search` and
/// `SearchIndex::search_links`.
#[derive(Debug, Clone, Default)]
pub struct SearchOptions {
    /// Maximum number of results returned. `None` is treated as 20 and
    /// `Some(0)` produces an empty result list.
    pub limit: Option<usize>,
    /// When `true` the lookup prefix is terminated with `:` so that only whole
    /// terms match; when `false` the keyword is used as a prefix of the term.
    pub full_match: bool,
}
20
/// One ranked hit produced by `SearchIndex::search`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SearchResult {
    /// The id portion decoded from the matching index key (`{prefix}{term}:{id}`).
    pub id: String,
    /// The value stored with the first matching entry seen for this id.
    pub value: String,
    /// Number of matching index entries counted for this id.
    pub score: usize,
}
27
/// One ranked hit produced by `SearchIndex::search_links`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchLinkResult {
    /// The id portion decoded from the matching index key (`{prefix}{term}:{id}`).
    pub id: String,
    /// The link target stored with the first matching entry seen for this id.
    pub cid: Cid,
    /// Number of matching index entries counted for this id.
    pub score: usize,
}
34
/// Errors returned by `SearchIndex` operations.
#[derive(Debug, thiserror::Error)]
pub enum SearchError {
    /// A failure reported by the underlying B-tree; converted automatically
    /// from `BTreeError` by `?`.
    #[error("btree error: {0}")]
    BTree(#[from] BTreeError),
    /// The request itself was invalid; the payload is the full message.
    #[error("{0}")]
    Validation(String),
}
42
/// Minimum keyword length (in characters) used when
/// `SearchIndexOptions::min_keyword_length` is `None`.
const DEFAULT_MIN_KEYWORD_LENGTH: usize = 2;
44
/// Keyword search index that stores its entries in a `BTree` under keys of
/// the form `{prefix}{term}:{id}`.
pub struct SearchIndex<S: hashtree_core::Store> {
    /// Tree that holds every index entry.
    btree: BTree<S>,
    /// Words that `parse_keywords` never emits.
    stop_words: HashSet<String>,
    /// Keywords with fewer characters than this are dropped by `parse_keywords`.
    min_keyword_length: usize,
}
50
51impl<S: hashtree_core::Store> SearchIndex<S> {
52    pub fn new(store: Arc<S>, options: SearchIndexOptions) -> Self {
53        Self {
54            btree: BTree::new(
55                store,
56                BTreeOptions {
57                    order: options.order,
58                },
59            ),
60            stop_words: options.stop_words.unwrap_or_else(default_stop_words),
61            min_keyword_length: options
62                .min_keyword_length
63                .unwrap_or(DEFAULT_MIN_KEYWORD_LENGTH),
64        }
65    }
66
67    pub fn parse_keywords(&self, text: &str) -> Vec<String> {
68        if text.is_empty() {
69            return Vec::new();
70        }
71
72        let mut keywords = Vec::new();
73        let mut seen = HashSet::new();
74
75        for raw_word in text
76            .split(|character: char| !character.is_alphanumeric())
77            .filter(|token| !token.is_empty())
78        {
79            for word in expand_keyword_variants(raw_word) {
80                if word.chars().count() < self.min_keyword_length
81                    || self.stop_words.contains(&word)
82                    || is_pure_number(&word)
83                    || !seen.insert(word.clone())
84                {
85                    continue;
86                }
87                keywords.push(word);
88            }
89        }
90
91        keywords
92    }
93
94    pub async fn index(
95        &self,
96        root: Option<&Cid>,
97        prefix: &str,
98        terms: &[String],
99        id: &str,
100        value: &str,
101    ) -> Result<Cid, SearchError> {
102        let mut new_root = root.cloned();
103        for term in terms {
104            new_root = Some(
105                self.btree
106                    .insert(new_root.as_ref(), &format!("{prefix}{term}:{id}"), value)
107                    .await?,
108            );
109        }
110
111        new_root.ok_or_else(|| {
112            SearchError::Validation("search index requires at least one term".to_string())
113        })
114    }
115
116    pub async fn remove(
117        &self,
118        root: &Cid,
119        prefix: &str,
120        terms: &[String],
121        id: &str,
122    ) -> Result<Option<Cid>, SearchError> {
123        let mut new_root = Some(root.clone());
124        for term in terms {
125            let Some(active_root) = new_root.as_ref() else {
126                break;
127            };
128            new_root = self
129                .btree
130                .delete(active_root, &format!("{prefix}{term}:{id}"))
131                .await?;
132        }
133        Ok(new_root)
134    }
135
136    pub async fn search(
137        &self,
138        root: Option<&Cid>,
139        prefix: &str,
140        query: &str,
141        options: SearchOptions,
142    ) -> Result<Vec<SearchResult>, SearchError> {
143        let Some(root) = root else {
144            return Ok(Vec::new());
145        };
146
147        let limit = options.limit.unwrap_or(20);
148        if limit == 0 {
149            return Ok(Vec::new());
150        }
151
152        let keywords = self.parse_keywords(query);
153        if keywords.is_empty() {
154            return Ok(Vec::new());
155        }
156
157        #[derive(Debug)]
158        struct Aggregate {
159            value: String,
160            score: usize,
161            exact_matches: usize,
162            prefix_distance: usize,
163        }
164
165        let mut results = HashMap::<String, Aggregate>::new();
166        for keyword in keywords {
167            let search_prefix = if options.full_match {
168                format!("{prefix}{keyword}:")
169            } else {
170                format!("{prefix}{keyword}")
171            };
172            for (count, (key, value)) in self
173                .btree
174                .prefix(root, &search_prefix)
175                .await?
176                .into_iter()
177                .enumerate()
178            {
179                if count >= limit.saturating_mul(2) {
180                    break;
181                }
182
183                let Some((term, id)) = decode_search_key(prefix, &key) else {
184                    continue;
185                };
186                let aggregate = results.entry(id).or_insert_with(|| Aggregate {
187                    value,
188                    score: 0,
189                    exact_matches: 0,
190                    prefix_distance: 0,
191                });
192                aggregate.score += 1;
193                if term == keyword {
194                    aggregate.exact_matches += 1;
195                }
196                aggregate.prefix_distance += term.len().saturating_sub(keyword.len());
197            }
198        }
199
200        let mut sorted = results.into_iter().collect::<Vec<_>>();
201        sorted.sort_by(|left, right| {
202            let left_data = &left.1;
203            let right_data = &right.1;
204            right_data
205                .score
206                .cmp(&left_data.score)
207                .then(right_data.exact_matches.cmp(&left_data.exact_matches))
208                .then(left_data.prefix_distance.cmp(&right_data.prefix_distance))
209                .then(left.0.cmp(&right.0))
210        });
211        sorted.truncate(limit);
212        Ok(sorted
213            .into_iter()
214            .map(|(id, aggregate)| SearchResult {
215                id,
216                value: aggregate.value,
217                score: aggregate.score,
218            })
219            .collect())
220    }
221
222    pub async fn merge(
223        &self,
224        base: Option<&Cid>,
225        other: Option<&Cid>,
226        prefer_other: bool,
227    ) -> Result<Option<Cid>, SearchError> {
228        Ok(self.btree.merge(base, other, prefer_other).await?)
229    }
230
231    pub async fn build_links<I>(&self, items: I) -> Result<Option<Cid>, SearchError>
232    where
233        I: IntoIterator<Item = (String, Cid)>,
234    {
235        Ok(self.btree.build_links(items).await?)
236    }
237
238    pub async fn index_link(
239        &self,
240        root: Option<&Cid>,
241        prefix: &str,
242        terms: &[String],
243        id: &str,
244        target_cid: &Cid,
245    ) -> Result<Cid, SearchError> {
246        let mut new_root = root.cloned();
247        for term in terms {
248            new_root = Some(
249                self.btree
250                    .insert_link(
251                        new_root.as_ref(),
252                        &format!("{prefix}{term}:{id}"),
253                        target_cid,
254                    )
255                    .await?,
256            );
257        }
258
259        new_root.ok_or_else(|| {
260            SearchError::Validation("search index requires at least one term".to_string())
261        })
262    }
263
264    pub async fn remove_link(
265        &self,
266        root: &Cid,
267        prefix: &str,
268        terms: &[String],
269        id: &str,
270    ) -> Result<Option<Cid>, SearchError> {
271        let mut new_root = Some(root.clone());
272        for term in terms {
273            let Some(active_root) = new_root.as_ref() else {
274                break;
275            };
276            new_root = self
277                .btree
278                .delete(active_root, &format!("{prefix}{term}:{id}"))
279                .await?;
280        }
281        Ok(new_root)
282    }
283
284    pub async fn search_links(
285        &self,
286        root: Option<&Cid>,
287        prefix: &str,
288        query: &str,
289        options: SearchOptions,
290    ) -> Result<Vec<SearchLinkResult>, SearchError> {
291        let Some(root) = root else {
292            return Ok(Vec::new());
293        };
294
295        let limit = options.limit.unwrap_or(20);
296        if limit == 0 {
297            return Ok(Vec::new());
298        }
299
300        let keywords = self.parse_keywords(query);
301        if keywords.is_empty() {
302            return Ok(Vec::new());
303        }
304
305        #[derive(Debug)]
306        struct Aggregate {
307            cid: Cid,
308            score: usize,
309            exact_matches: usize,
310            prefix_distance: usize,
311        }
312
313        let mut results = HashMap::<String, Aggregate>::new();
314        for keyword in keywords {
315            let search_prefix = if options.full_match {
316                format!("{prefix}{keyword}:")
317            } else {
318                format!("{prefix}{keyword}")
319            };
320            for (count, (key, cid)) in self
321                .btree
322                .prefix_links(root, &search_prefix)
323                .await?
324                .into_iter()
325                .enumerate()
326            {
327                if count >= limit.saturating_mul(2) {
328                    break;
329                }
330
331                let Some((term, id)) = decode_search_key(prefix, &key) else {
332                    continue;
333                };
334                let aggregate = results.entry(id).or_insert_with(|| Aggregate {
335                    cid,
336                    score: 0,
337                    exact_matches: 0,
338                    prefix_distance: 0,
339                });
340                aggregate.score += 1;
341                if term == keyword {
342                    aggregate.exact_matches += 1;
343                }
344                aggregate.prefix_distance += term.len().saturating_sub(keyword.len());
345            }
346        }
347
348        let mut sorted = results.into_iter().collect::<Vec<_>>();
349        sorted.sort_by(|left, right| {
350            let left_data = &left.1;
351            let right_data = &right.1;
352            right_data
353                .score
354                .cmp(&left_data.score)
355                .then(right_data.exact_matches.cmp(&left_data.exact_matches))
356                .then(left_data.prefix_distance.cmp(&right_data.prefix_distance))
357                .then(left.0.cmp(&right.0))
358        });
359        sorted.truncate(limit);
360        Ok(sorted
361            .into_iter()
362            .map(|(id, aggregate)| SearchLinkResult {
363                id,
364                cid: aggregate.cid,
365                score: aggregate.score,
366            })
367            .collect())
368    }
369
370    pub async fn merge_links(
371        &self,
372        base: Option<&Cid>,
373        other: Option<&Cid>,
374        prefer_other: bool,
375    ) -> Result<Option<Cid>, SearchError> {
376        Ok(self.btree.merge_links(base, other, prefer_other).await?)
377    }
378}
379
/// Splits an index key of the form `{prefix}{term}:{id}` into `(term, id)`.
///
/// Returns `None` when `key` does not start with `prefix` or when no `:`
/// follows the prefix. The split happens at the first `:` after the prefix,
/// so any later colons stay in the id.
fn decode_search_key(prefix: &str, key: &str) -> Option<(String, String)> {
    let (term, id) = key.strip_prefix(prefix)?.split_once(':')?;
    Some((term.to_owned(), id.to_owned()))
}
391
392fn expand_keyword_variants(raw_word: &str) -> Vec<String> {
393    let mut variants = Vec::new();
394    let normalized = raw_word.to_lowercase();
395    if !normalized.is_empty() {
396        variants.push(normalized);
397    }
398
399    for segment in split_keyword_segments(raw_word) {
400        let normalized_segment = segment.to_lowercase();
401        if normalized_segment.is_empty()
402            || variants
403                .iter()
404                .any(|existing| existing == &normalized_segment)
405        {
406            continue;
407        }
408        variants.push(normalized_segment);
409    }
410
411    variants
412}
413
414fn split_keyword_segments(raw_word: &str) -> Vec<String> {
415    let chars = raw_word.chars().collect::<Vec<_>>();
416    if chars.is_empty() {
417        return Vec::new();
418    }
419
420    let mut parts = Vec::new();
421    let mut start = 0usize;
422    for index in 1..chars.len() {
423        let previous = chars[index - 1];
424        let current = chars[index];
425        let next = chars.get(index + 1).copied();
426        if is_keyword_boundary(previous, current, next) {
427            parts.push(chars[start..index].iter().collect::<String>());
428            start = index;
429        }
430    }
431    parts.push(chars[start..].iter().collect::<String>());
432    parts
433}
434
/// Reports whether a new segment starts at `current`, given its neighbours.
///
/// A boundary is detected for:
/// - lowercase followed by uppercase (`fooBar` → `foo|Bar`),
/// - a letter followed by a digit, or a digit followed by a letter,
/// - two uppercase letters where the one after `current` is lowercase
///   (`HTTPServer` → `HTTP|Server`).
fn is_keyword_boundary(previous: char, current: char, next: Option<char>) -> bool {
    if previous.is_lowercase() && current.is_uppercase() {
        return true;
    }
    if previous.is_alphabetic() && current.is_numeric() {
        return true;
    }
    if previous.is_numeric() && current.is_alphabetic() {
        return true;
    }
    previous.is_uppercase()
        && current.is_uppercase()
        && matches!(next, Some(following) if following.is_lowercase())
}
443
/// Reports whether `word` is a run of ASCII digits that should be excluded
/// from the keyword list.
///
/// Four-digit values starting with `19` or `20` are treated as years and are
/// not considered pure numbers, so they remain searchable. The empty string
/// contains no digits and is therefore not a number.
fn is_pure_number(word: &str) -> bool {
    // `all` is vacuously true on an empty iterator, so reject "" explicitly.
    if word.is_empty() || !word.bytes().all(|byte| byte.is_ascii_digit()) {
        return false;
    }
    let looks_like_year = word.len() == 4 && (word.starts_with("19") || word.starts_with("20"));
    !looks_like_year
}
450
/// Returns the built-in English stop-word set used when
/// `SearchIndexOptions::stop_words` is not supplied.
fn default_stop_words() -> HashSet<String> {
    const WORDS: &[&str] = &[
        "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
        "from", "is", "it", "as", "be", "was", "are", "this", "that", "these", "those", "i", "you",
        "he", "she", "we", "they", "my", "your", "his", "her", "its", "our", "their", "what",
        "which", "who", "whom", "how", "when", "where", "why", "will", "would", "could", "should",
        "can", "may", "might", "must", "have", "has", "had", "do", "does", "did", "been", "being",
        "get", "got", "just", "now", "then", "so", "if", "not", "no", "yes", "all", "any", "some",
        "more", "most", "other", "into", "over", "after", "before", "about", "up", "down", "out",
        "off", "through", "during", "under", "again", "further", "once",
    ];

    WORDS.iter().map(|word| (*word).to_owned()).collect()
}