Skip to main content

hashtree_index/
search.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use hashtree_core::Cid;
5
6use crate::{BTree, BTreeError, BTreeOptions};
7
/// Construction-time settings for a `SearchIndex`.
///
/// Every field is optional; `None` makes `SearchIndex::new` fall back to its
/// own default for that setting.
#[derive(Debug, Clone, Default)]
pub struct SearchIndexOptions {
    /// Forwarded unchanged as `BTreeOptions::order` for the backing tree.
    pub order: Option<usize>,
    /// Words that keyword parsing drops; `None` selects the built-in
    /// stop-word list.
    pub stop_words: Option<HashSet<String>>,
    /// Minimum keyword length, counted in characters; `None` selects
    /// `DEFAULT_MIN_KEYWORD_LENGTH`.
    pub min_keyword_length: Option<usize>,
}
14
/// Per-query settings for `SearchIndex::search` and `SearchIndex::search_links`.
#[derive(Debug, Clone, Default)]
pub struct SearchOptions {
    /// Maximum number of results to return. `None` means 20; `Some(0)` yields
    /// an empty result without touching the index.
    pub limit: Option<usize>,
    /// When `true`, only indexed terms equal to a query keyword match.
    /// When `false`, any indexed term that starts with the keyword matches.
    pub full_match: bool,
}
20
/// One ranked hit produced by `SearchIndex::search`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SearchResult {
    /// Identifier decoded from the part of the index key after the first `:`.
    pub id: String,
    /// Value stored with the first matching index entry seen for this id.
    pub value: String,
    /// Number of matching index entries that were counted for this id.
    pub score: usize,
}
27
28#[derive(Debug, Clone, PartialEq)]
29pub struct SearchLinkResult {
30    pub id: String,
31    pub cid: Cid,
32    pub score: usize,
33}
34
35#[derive(Debug, thiserror::Error)]
36pub enum SearchError {
37    #[error("btree error: {0}")]
38    BTree(#[from] BTreeError),
39    #[error("{0}")]
40    Validation(String),
41}
42
/// Minimum keyword length (in characters) used when
/// `SearchIndexOptions::min_keyword_length` is `None`.
const DEFAULT_MIN_KEYWORD_LENGTH: usize = 2;
44
45pub struct SearchIndex<S: hashtree_core::Store> {
46    btree: BTree<S>,
47    stop_words: HashSet<String>,
48    min_keyword_length: usize,
49}
50
51impl<S: hashtree_core::Store> SearchIndex<S> {
52    pub fn new(store: Arc<S>, options: SearchIndexOptions) -> Self {
53        Self {
54            btree: BTree::new(
55                store,
56                BTreeOptions {
57                    order: options.order,
58                },
59            ),
60            stop_words: options.stop_words.unwrap_or_else(default_stop_words),
61            min_keyword_length: options
62                .min_keyword_length
63                .unwrap_or(DEFAULT_MIN_KEYWORD_LENGTH),
64        }
65    }
66
67    pub fn parse_keywords(&self, text: &str) -> Vec<String> {
68        if text.is_empty() {
69            return Vec::new();
70        }
71
72        let mut keywords = Vec::new();
73        let mut seen = HashSet::new();
74
75        for raw_word in text
76            .split(|character: char| !character.is_alphanumeric())
77            .filter(|token| !token.is_empty())
78        {
79            for word in expand_keyword_variants(raw_word) {
80                if word.chars().count() < self.min_keyword_length
81                    || self.stop_words.contains(&word)
82                    || is_pure_number(&word)
83                    || !seen.insert(word.clone())
84                {
85                    continue;
86                }
87                keywords.push(word);
88            }
89        }
90
91        keywords
92    }
93
94    pub async fn index(
95        &self,
96        root: Option<&Cid>,
97        prefix: &str,
98        terms: &[String],
99        id: &str,
100        value: &str,
101    ) -> Result<Cid, SearchError> {
102        let mut new_root = root.cloned();
103        for term in terms {
104            new_root = Some(
105                self.btree
106                    .insert(new_root.as_ref(), &format!("{prefix}{term}:{id}"), value)
107                    .await?,
108            );
109        }
110
111        new_root.ok_or_else(|| {
112            SearchError::Validation("search index requires at least one term".to_string())
113        })
114    }
115
116    pub async fn remove(
117        &self,
118        root: &Cid,
119        prefix: &str,
120        terms: &[String],
121        id: &str,
122    ) -> Result<Option<Cid>, SearchError> {
123        let mut new_root = Some(root.clone());
124        for term in terms {
125            let Some(active_root) = new_root.as_ref() else {
126                break;
127            };
128            new_root = self
129                .btree
130                .delete(active_root, &format!("{prefix}{term}:{id}"))
131                .await?;
132        }
133        Ok(new_root)
134    }
135
136    pub async fn search(
137        &self,
138        root: Option<&Cid>,
139        prefix: &str,
140        query: &str,
141        options: SearchOptions,
142    ) -> Result<Vec<SearchResult>, SearchError> {
143        let Some(root) = root else {
144            return Ok(Vec::new());
145        };
146
147        let limit = options.limit.unwrap_or(20);
148        if limit == 0 {
149            return Ok(Vec::new());
150        }
151
152        let keywords = self.parse_keywords(query);
153        if keywords.is_empty() {
154            return Ok(Vec::new());
155        }
156
157        #[derive(Debug)]
158        struct Aggregate {
159            value: String,
160            score: usize,
161            exact_matches: usize,
162            prefix_distance: usize,
163        }
164
165        let mut results = HashMap::<String, Aggregate>::new();
166        for keyword in keywords {
167            let search_prefix = if options.full_match {
168                format!("{prefix}{keyword}:")
169            } else {
170                format!("{prefix}{keyword}")
171            };
172            let mut count = 0usize;
173            for (key, value) in self.btree.prefix(root, &search_prefix).await? {
174                if count >= limit.saturating_mul(2) {
175                    break;
176                }
177                count += 1;
178
179                let Some((term, id)) = decode_search_key(prefix, &key) else {
180                    continue;
181                };
182                let aggregate = results.entry(id).or_insert_with(|| Aggregate {
183                    value,
184                    score: 0,
185                    exact_matches: 0,
186                    prefix_distance: 0,
187                });
188                aggregate.score += 1;
189                if term == keyword {
190                    aggregate.exact_matches += 1;
191                }
192                aggregate.prefix_distance += term.len().saturating_sub(keyword.len());
193            }
194        }
195
196        let mut sorted = results.into_iter().collect::<Vec<_>>();
197        sorted.sort_by(|left, right| {
198            let left_data = &left.1;
199            let right_data = &right.1;
200            right_data
201                .score
202                .cmp(&left_data.score)
203                .then(right_data.exact_matches.cmp(&left_data.exact_matches))
204                .then(left_data.prefix_distance.cmp(&right_data.prefix_distance))
205                .then(left.0.cmp(&right.0))
206        });
207        sorted.truncate(limit);
208        Ok(sorted
209            .into_iter()
210            .map(|(id, aggregate)| SearchResult {
211                id,
212                value: aggregate.value,
213                score: aggregate.score,
214            })
215            .collect())
216    }
217
218    pub async fn merge(
219        &self,
220        base: Option<&Cid>,
221        other: Option<&Cid>,
222        prefer_other: bool,
223    ) -> Result<Option<Cid>, SearchError> {
224        Ok(self.btree.merge(base, other, prefer_other).await?)
225    }
226
227    pub async fn index_link(
228        &self,
229        root: Option<&Cid>,
230        prefix: &str,
231        terms: &[String],
232        id: &str,
233        target_cid: &Cid,
234    ) -> Result<Cid, SearchError> {
235        let mut new_root = root.cloned();
236        for term in terms {
237            new_root = Some(
238                self.btree
239                    .insert_link(
240                        new_root.as_ref(),
241                        &format!("{prefix}{term}:{id}"),
242                        target_cid,
243                    )
244                    .await?,
245            );
246        }
247
248        new_root.ok_or_else(|| {
249            SearchError::Validation("search index requires at least one term".to_string())
250        })
251    }
252
253    pub async fn remove_link(
254        &self,
255        root: &Cid,
256        prefix: &str,
257        terms: &[String],
258        id: &str,
259    ) -> Result<Option<Cid>, SearchError> {
260        let mut new_root = Some(root.clone());
261        for term in terms {
262            let Some(active_root) = new_root.as_ref() else {
263                break;
264            };
265            new_root = self
266                .btree
267                .delete(active_root, &format!("{prefix}{term}:{id}"))
268                .await?;
269        }
270        Ok(new_root)
271    }
272
273    pub async fn search_links(
274        &self,
275        root: Option<&Cid>,
276        prefix: &str,
277        query: &str,
278        options: SearchOptions,
279    ) -> Result<Vec<SearchLinkResult>, SearchError> {
280        let Some(root) = root else {
281            return Ok(Vec::new());
282        };
283
284        let limit = options.limit.unwrap_or(20);
285        if limit == 0 {
286            return Ok(Vec::new());
287        }
288
289        let keywords = self.parse_keywords(query);
290        if keywords.is_empty() {
291            return Ok(Vec::new());
292        }
293
294        #[derive(Debug)]
295        struct Aggregate {
296            cid: Cid,
297            score: usize,
298            exact_matches: usize,
299            prefix_distance: usize,
300        }
301
302        let mut results = HashMap::<String, Aggregate>::new();
303        for keyword in keywords {
304            let search_prefix = if options.full_match {
305                format!("{prefix}{keyword}:")
306            } else {
307                format!("{prefix}{keyword}")
308            };
309            let mut count = 0usize;
310            for (key, cid) in self.btree.prefix_links(root, &search_prefix).await? {
311                if count >= limit.saturating_mul(2) {
312                    break;
313                }
314                count += 1;
315
316                let Some((term, id)) = decode_search_key(prefix, &key) else {
317                    continue;
318                };
319                let aggregate = results.entry(id).or_insert_with(|| Aggregate {
320                    cid,
321                    score: 0,
322                    exact_matches: 0,
323                    prefix_distance: 0,
324                });
325                aggregate.score += 1;
326                if term == keyword {
327                    aggregate.exact_matches += 1;
328                }
329                aggregate.prefix_distance += term.len().saturating_sub(keyword.len());
330            }
331        }
332
333        let mut sorted = results.into_iter().collect::<Vec<_>>();
334        sorted.sort_by(|left, right| {
335            let left_data = &left.1;
336            let right_data = &right.1;
337            right_data
338                .score
339                .cmp(&left_data.score)
340                .then(right_data.exact_matches.cmp(&left_data.exact_matches))
341                .then(left_data.prefix_distance.cmp(&right_data.prefix_distance))
342                .then(left.0.cmp(&right.0))
343        });
344        sorted.truncate(limit);
345        Ok(sorted
346            .into_iter()
347            .map(|(id, aggregate)| SearchLinkResult {
348                id,
349                cid: aggregate.cid,
350                score: aggregate.score,
351            })
352            .collect())
353    }
354
355    pub async fn merge_links(
356        &self,
357        base: Option<&Cid>,
358        other: Option<&Cid>,
359        prefer_other: bool,
360    ) -> Result<Option<Cid>, SearchError> {
361        Ok(self.btree.merge_links(base, other, prefer_other).await?)
362    }
363}
364
/// Splits an index key of the form `{prefix}{term}:{id}` into `(term, id)`.
///
/// Returns `None` when `key` does not start with `prefix` or when no `:`
/// follows the prefix. The split happens at the first `:` after the prefix,
/// so any later colons stay part of the id.
fn decode_search_key(prefix: &str, key: &str) -> Option<(String, String)> {
    let remainder = key.strip_prefix(prefix)?;
    let (term, id) = remainder.split_once(':')?;
    Some((term.to_owned(), id.to_owned()))
}
376
377fn expand_keyword_variants(raw_word: &str) -> Vec<String> {
378    let mut variants = Vec::new();
379    let normalized = raw_word.to_lowercase();
380    if !normalized.is_empty() {
381        variants.push(normalized);
382    }
383
384    for segment in split_keyword_segments(raw_word) {
385        let normalized_segment = segment.to_lowercase();
386        if normalized_segment.is_empty()
387            || variants
388                .iter()
389                .any(|existing| existing == &normalized_segment)
390        {
391            continue;
392        }
393        variants.push(normalized_segment);
394    }
395
396    variants
397}
398
399fn split_keyword_segments(raw_word: &str) -> Vec<String> {
400    let chars = raw_word.chars().collect::<Vec<_>>();
401    if chars.is_empty() {
402        return Vec::new();
403    }
404
405    let mut parts = Vec::new();
406    let mut start = 0usize;
407    for index in 1..chars.len() {
408        let previous = chars[index - 1];
409        let current = chars[index];
410        let next = chars.get(index + 1).copied();
411        if is_keyword_boundary(previous, current, next) {
412            parts.push(chars[start..index].iter().collect::<String>());
413            start = index;
414        }
415    }
416    parts.push(chars[start..].iter().collect::<String>());
417    parts
418}
419
/// Decides whether a new segment starts at `current`.
///
/// A boundary is reported for any of these transitions:
/// - lowercase followed by uppercase (`fooBar` → `foo|Bar`);
/// - alphabetic followed by numeric (`abc1` → `abc|1`);
/// - numeric followed by alphabetic (`1abc` → `1|abc`);
/// - uppercase followed by uppercase when the character after `current` is
///   lowercase (`HTMLParser` → `HTML|Parser`).
fn is_keyword_boundary(previous: char, current: char, next: Option<char>) -> bool {
    let lower_to_upper = previous.is_lowercase() && current.is_uppercase();
    let letter_to_digit = previous.is_alphabetic() && current.is_numeric();
    let digit_to_letter = previous.is_numeric() && current.is_alphabetic();
    let acronym_end = previous.is_uppercase()
        && current.is_uppercase()
        && next.is_some_and(char::is_lowercase);

    lower_to_upper || letter_to_digit || digit_to_letter || acronym_end
}
428
/// Reports whether `word` is a number that should not become a keyword.
///
/// A word counts as a pure number when every byte is an ASCII digit, except
/// for four-digit words starting with `19` or `20`, which are kept as
/// keywords. Note that the empty string also satisfies the all-digits test.
fn is_pure_number(word: &str) -> bool {
    let all_digits = word.bytes().all(|byte| byte.is_ascii_digit());
    let year_like = word.len() == 4 && (word.starts_with("19") || word.starts_with("20"));

    all_digits && !year_like
}
435
/// Builds the stop-word set used when `SearchIndexOptions::stop_words` is
/// `None`.
///
/// The list consists of lowercase English function words (articles,
/// pronouns, auxiliaries, prepositions and similar).
fn default_stop_words() -> HashSet<String> {
    const WORDS: &[&str] = &[
        "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
        "from", "is", "it", "as", "be", "was", "are", "this", "that", "these", "those", "i", "you",
        "he", "she", "we", "they", "my", "your", "his", "her", "its", "our", "their", "what",
        "which", "who", "whom", "how", "when", "where", "why", "will", "would", "could", "should",
        "can", "may", "might", "must", "have", "has", "had", "do", "does", "did", "been", "being",
        "get", "got", "just", "now", "then", "so", "if", "not", "no", "yes", "all", "any", "some",
        "more", "most", "other", "into", "over", "after", "before", "about", "up", "down", "out",
        "off", "through", "during", "under", "again", "further", "once",
    ];

    let mut words = HashSet::with_capacity(WORDS.len());
    for word in WORDS {
        words.insert((*word).to_owned());
    }
    words
}