Skip to main content

hashtree_index/
search.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use hashtree_core::Cid;
5
6use crate::{BTree, BTreeError, BTreeOptions};
7
/// Construction-time settings for a [`SearchIndex`].
///
/// Every field is optional; `None` selects the built-in default.
#[derive(Default, Clone, Debug)]
pub struct SearchIndexOptions {
    /// Passed through unchanged as the underlying B-tree's `BTreeOptions::order`.
    pub order: Option<usize>,
    /// Words never emitted as keywords. `None` uses the built-in English list.
    /// Candidates are lowercased before the lookup, so entries should be lowercase.
    pub stop_words: Option<HashSet<String>>,
    /// Shortest keyword (counted in `char`s) that is kept. `None` means 2.
    pub min_keyword_length: Option<usize>,
}
14
/// Per-query settings for [`SearchIndex::search`] and [`SearchIndex::search_links`].
#[derive(Default, Clone, Debug)]
pub struct SearchOptions {
    /// Maximum number of results. `None` means 20, and `Some(0)` returns nothing.
    /// At most twice this many index entries are scanned per query keyword.
    pub limit: Option<usize>,
    /// When `true`, a keyword only matches identical terms.
    /// Otherwise it matches every term that starts with it.
    pub full_match: bool,
}
20
/// One ranked hit returned by [`SearchIndex::search`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SearchResult {
    /// Document id decoded from the matching index key.
    pub id: String,
    /// Value stored with the first matching key encountered for this id.
    pub value: String,
    /// Number of (keyword, term) hits aggregated for this id.
    pub score: usize,
}
27
28#[derive(Debug, Clone, PartialEq)]
29pub struct SearchLinkResult {
30    pub id: String,
31    pub cid: Cid,
32    pub score: usize,
33}
34
35#[derive(Debug, thiserror::Error)]
36pub enum SearchError {
37    #[error("btree error: {0}")]
38    BTree(#[from] BTreeError),
39    #[error("{0}")]
40    Validation(String),
41}
42
/// Shortest keyword (in `char`s) kept when `SearchIndexOptions::min_keyword_length` is `None`.
const DEFAULT_MIN_KEYWORD_LENGTH: usize = 2;
44
45pub struct SearchIndex<S: hashtree_core::Store> {
46    btree: BTree<S>,
47    stop_words: HashSet<String>,
48    min_keyword_length: usize,
49}
50
51impl<S: hashtree_core::Store> SearchIndex<S> {
52    pub fn new(store: Arc<S>, options: SearchIndexOptions) -> Self {
53        Self {
54            btree: BTree::new(
55                store,
56                BTreeOptions {
57                    order: options.order,
58                },
59            ),
60            stop_words: options.stop_words.unwrap_or_else(default_stop_words),
61            min_keyword_length: options
62                .min_keyword_length
63                .unwrap_or(DEFAULT_MIN_KEYWORD_LENGTH),
64        }
65    }
66
67    pub fn parse_keywords(&self, text: &str) -> Vec<String> {
68        if text.is_empty() {
69            return Vec::new();
70        }
71
72        let mut keywords = Vec::new();
73        let mut seen = HashSet::new();
74
75        for raw_word in text
76            .split(|character: char| !character.is_alphanumeric())
77            .filter(|token| !token.is_empty())
78        {
79            for word in expand_keyword_variants(raw_word) {
80                if word.chars().count() < self.min_keyword_length
81                    || self.stop_words.contains(&word)
82                    || is_pure_number(&word)
83                    || !seen.insert(word.clone())
84                {
85                    continue;
86                }
87                keywords.push(word);
88            }
89        }
90
91        keywords
92    }
93
94    pub async fn index(
95        &self,
96        root: Option<&Cid>,
97        prefix: &str,
98        terms: &[String],
99        id: &str,
100        value: &str,
101    ) -> Result<Cid, SearchError> {
102        let mut new_root = root.cloned();
103        for term in terms {
104            new_root = Some(
105                self.btree
106                    .insert(new_root.as_ref(), &format!("{prefix}{term}:{id}"), value)
107                    .await?,
108            );
109        }
110
111        new_root.ok_or_else(|| {
112            SearchError::Validation("search index requires at least one term".to_string())
113        })
114    }
115
116    pub async fn remove(
117        &self,
118        root: &Cid,
119        prefix: &str,
120        terms: &[String],
121        id: &str,
122    ) -> Result<Option<Cid>, SearchError> {
123        let mut new_root = Some(root.clone());
124        for term in terms {
125            let Some(active_root) = new_root.as_ref() else {
126                break;
127            };
128            new_root = self
129                .btree
130                .delete(active_root, &format!("{prefix}{term}:{id}"))
131                .await?;
132        }
133        Ok(new_root)
134    }
135
136    pub async fn search(
137        &self,
138        root: Option<&Cid>,
139        prefix: &str,
140        query: &str,
141        options: SearchOptions,
142    ) -> Result<Vec<SearchResult>, SearchError> {
143        let Some(root) = root else {
144            return Ok(Vec::new());
145        };
146
147        let limit = options.limit.unwrap_or(20);
148        if limit == 0 {
149            return Ok(Vec::new());
150        }
151
152        let keywords = self.parse_keywords(query);
153        if keywords.is_empty() {
154            return Ok(Vec::new());
155        }
156
157        #[derive(Debug)]
158        struct Aggregate {
159            value: String,
160            score: usize,
161            exact_matches: usize,
162            prefix_distance: usize,
163        }
164
165        let mut results = HashMap::<String, Aggregate>::new();
166        for keyword in keywords {
167            let search_prefix = if options.full_match {
168                format!("{prefix}{keyword}:")
169            } else {
170                format!("{prefix}{keyword}")
171            };
172            let mut count = 0usize;
173            for (key, value) in self.btree.prefix(root, &search_prefix).await? {
174                if count >= limit.saturating_mul(2) {
175                    break;
176                }
177                count += 1;
178
179                let Some((term, id)) = decode_search_key(prefix, &key) else {
180                    continue;
181                };
182                let aggregate = results.entry(id).or_insert_with(|| Aggregate {
183                    value,
184                    score: 0,
185                    exact_matches: 0,
186                    prefix_distance: 0,
187                });
188                aggregate.score += 1;
189                if term == keyword {
190                    aggregate.exact_matches += 1;
191                }
192                aggregate.prefix_distance += term.len().saturating_sub(keyword.len());
193            }
194        }
195
196        let mut sorted = results.into_iter().collect::<Vec<_>>();
197        sorted.sort_by(|left, right| {
198            let left_data = &left.1;
199            let right_data = &right.1;
200            right_data
201                .score
202                .cmp(&left_data.score)
203                .then(right_data.exact_matches.cmp(&left_data.exact_matches))
204                .then(left_data.prefix_distance.cmp(&right_data.prefix_distance))
205                .then(left.0.cmp(&right.0))
206        });
207        sorted.truncate(limit);
208        Ok(sorted
209            .into_iter()
210            .map(|(id, aggregate)| SearchResult {
211                id,
212                value: aggregate.value,
213                score: aggregate.score,
214            })
215            .collect())
216    }
217
218    pub async fn merge(
219        &self,
220        base: Option<&Cid>,
221        other: Option<&Cid>,
222        prefer_other: bool,
223    ) -> Result<Option<Cid>, SearchError> {
224        Ok(self.btree.merge(base, other, prefer_other).await?)
225    }
226
227    pub async fn build_links<I>(&self, items: I) -> Result<Option<Cid>, SearchError>
228    where
229        I: IntoIterator<Item = (String, Cid)>,
230    {
231        Ok(self.btree.build_links(items).await?)
232    }
233
234    pub async fn index_link(
235        &self,
236        root: Option<&Cid>,
237        prefix: &str,
238        terms: &[String],
239        id: &str,
240        target_cid: &Cid,
241    ) -> Result<Cid, SearchError> {
242        let mut new_root = root.cloned();
243        for term in terms {
244            new_root = Some(
245                self.btree
246                    .insert_link(
247                        new_root.as_ref(),
248                        &format!("{prefix}{term}:{id}"),
249                        target_cid,
250                    )
251                    .await?,
252            );
253        }
254
255        new_root.ok_or_else(|| {
256            SearchError::Validation("search index requires at least one term".to_string())
257        })
258    }
259
260    pub async fn remove_link(
261        &self,
262        root: &Cid,
263        prefix: &str,
264        terms: &[String],
265        id: &str,
266    ) -> Result<Option<Cid>, SearchError> {
267        let mut new_root = Some(root.clone());
268        for term in terms {
269            let Some(active_root) = new_root.as_ref() else {
270                break;
271            };
272            new_root = self
273                .btree
274                .delete(active_root, &format!("{prefix}{term}:{id}"))
275                .await?;
276        }
277        Ok(new_root)
278    }
279
280    pub async fn search_links(
281        &self,
282        root: Option<&Cid>,
283        prefix: &str,
284        query: &str,
285        options: SearchOptions,
286    ) -> Result<Vec<SearchLinkResult>, SearchError> {
287        let Some(root) = root else {
288            return Ok(Vec::new());
289        };
290
291        let limit = options.limit.unwrap_or(20);
292        if limit == 0 {
293            return Ok(Vec::new());
294        }
295
296        let keywords = self.parse_keywords(query);
297        if keywords.is_empty() {
298            return Ok(Vec::new());
299        }
300
301        #[derive(Debug)]
302        struct Aggregate {
303            cid: Cid,
304            score: usize,
305            exact_matches: usize,
306            prefix_distance: usize,
307        }
308
309        let mut results = HashMap::<String, Aggregate>::new();
310        for keyword in keywords {
311            let search_prefix = if options.full_match {
312                format!("{prefix}{keyword}:")
313            } else {
314                format!("{prefix}{keyword}")
315            };
316            let mut count = 0usize;
317            for (key, cid) in self.btree.prefix_links(root, &search_prefix).await? {
318                if count >= limit.saturating_mul(2) {
319                    break;
320                }
321                count += 1;
322
323                let Some((term, id)) = decode_search_key(prefix, &key) else {
324                    continue;
325                };
326                let aggregate = results.entry(id).or_insert_with(|| Aggregate {
327                    cid,
328                    score: 0,
329                    exact_matches: 0,
330                    prefix_distance: 0,
331                });
332                aggregate.score += 1;
333                if term == keyword {
334                    aggregate.exact_matches += 1;
335                }
336                aggregate.prefix_distance += term.len().saturating_sub(keyword.len());
337            }
338        }
339
340        let mut sorted = results.into_iter().collect::<Vec<_>>();
341        sorted.sort_by(|left, right| {
342            let left_data = &left.1;
343            let right_data = &right.1;
344            right_data
345                .score
346                .cmp(&left_data.score)
347                .then(right_data.exact_matches.cmp(&left_data.exact_matches))
348                .then(left_data.prefix_distance.cmp(&right_data.prefix_distance))
349                .then(left.0.cmp(&right.0))
350        });
351        sorted.truncate(limit);
352        Ok(sorted
353            .into_iter()
354            .map(|(id, aggregate)| SearchLinkResult {
355                id,
356                cid: aggregate.cid,
357                score: aggregate.score,
358            })
359            .collect())
360    }
361
362    pub async fn merge_links(
363        &self,
364        base: Option<&Cid>,
365        other: Option<&Cid>,
366        prefer_other: bool,
367    ) -> Result<Option<Cid>, SearchError> {
368        Ok(self.btree.merge_links(base, other, prefer_other).await?)
369    }
370}
371
372fn decode_search_key(prefix: &str, key: &str) -> Option<(String, String)> {
373    if !key.starts_with(prefix) {
374        return None;
375    }
376    let after_prefix = &key[prefix.len()..];
377    let colon_index = after_prefix.find(':')?;
378    Some((
379        after_prefix[..colon_index].to_string(),
380        after_prefix[colon_index + 1..].to_string(),
381    ))
382}
383
384fn expand_keyword_variants(raw_word: &str) -> Vec<String> {
385    let mut variants = Vec::new();
386    let normalized = raw_word.to_lowercase();
387    if !normalized.is_empty() {
388        variants.push(normalized);
389    }
390
391    for segment in split_keyword_segments(raw_word) {
392        let normalized_segment = segment.to_lowercase();
393        if normalized_segment.is_empty()
394            || variants
395                .iter()
396                .any(|existing| existing == &normalized_segment)
397        {
398            continue;
399        }
400        variants.push(normalized_segment);
401    }
402
403    variants
404}
405
406fn split_keyword_segments(raw_word: &str) -> Vec<String> {
407    let chars = raw_word.chars().collect::<Vec<_>>();
408    if chars.is_empty() {
409        return Vec::new();
410    }
411
412    let mut parts = Vec::new();
413    let mut start = 0usize;
414    for index in 1..chars.len() {
415        let previous = chars[index - 1];
416        let current = chars[index];
417        let next = chars.get(index + 1).copied();
418        if is_keyword_boundary(previous, current, next) {
419            parts.push(chars[start..index].iter().collect::<String>());
420            start = index;
421        }
422    }
423    parts.push(chars[start..].iter().collect::<String>());
424    parts
425}
426
427fn is_keyword_boundary(previous: char, current: char, next: Option<char>) -> bool {
428    (previous.is_lowercase() && current.is_uppercase())
429        || (previous.is_alphabetic() && current.is_numeric())
430        || (previous.is_numeric() && current.is_alphabetic())
431        || (previous.is_uppercase()
432            && current.is_uppercase()
433            && next.is_some_and(|next| next.is_lowercase()))
434}
435
436fn is_pure_number(word: &str) -> bool {
437    if !word.bytes().all(|byte| byte.is_ascii_digit()) {
438        return false;
439    }
440    !(word.len() == 4 && (word.starts_with("19") || word.starts_with("20")))
441}
442
443fn default_stop_words() -> HashSet<String> {
444    [
445        "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
446        "from", "is", "it", "as", "be", "was", "are", "this", "that", "these", "those", "i", "you",
447        "he", "she", "we", "they", "my", "your", "his", "her", "its", "our", "their", "what",
448        "which", "who", "whom", "how", "when", "where", "why", "will", "would", "could", "should",
449        "can", "may", "might", "must", "have", "has", "had", "do", "does", "did", "been", "being",
450        "get", "got", "just", "now", "then", "so", "if", "not", "no", "yes", "all", "any", "some",
451        "more", "most", "other", "into", "over", "after", "before", "about", "up", "down", "out",
452        "off", "through", "during", "under", "again", "further", "once",
453    ]
454    .into_iter()
455    .map(str::to_string)
456    .collect()
457}