Skip to main content

lash_core/
search.rs

1use std::collections::{HashMap, HashSet};
2
3use regex::Regex;
4
/// Result limit used when the request arguments carry no usable `limit` value.
const DEFAULT_LIMIT: usize = 10;
/// Upper bound for a requested result limit; larger requests are clamped down to it.
const MAX_LIMIT: usize = 100;
7
/// Strategy used to match a query against document fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Token-based matching with relevance scoring; the default mode.
    Hybrid,
    /// Case-insensitive (ASCII) substring matching.
    Literal,
    /// Case-insensitive regular-expression matching.
    Regex,
}

impl SearchMode {
    /// Parses a mode name, ignoring surrounding whitespace and ASCII case.
    ///
    /// `None` and any unrecognised name (including the empty string) select
    /// [`SearchMode::Hybrid`], so parsing never fails.
    pub fn parse(value: Option<&str>) -> Self {
        let Some(raw) = value else {
            return Self::Hybrid;
        };
        let name = raw.trim();
        // Compare in place instead of allocating a lowercased copy of the input.
        if name.eq_ignore_ascii_case("literal") {
            Self::Literal
        } else if name.eq_ignore_ascii_case("regex") {
            Self::Regex
        } else {
            Self::Hybrid
        }
    }
}
29
/// A searchable document: a set of named text fields.
#[derive(Debug, Clone, Default)]
pub struct SearchDoc {
    /// Field name mapped to that field's text. Empty strings are treated as
    /// absent by the matching code in this module.
    pub fields: HashMap<&'static str, String>,
}
34
35pub fn limit_from_args(args: &serde_json::Value) -> usize {
36    args.get("limit")
37        .and_then(|v| v.as_i64())
38        .and_then(|n| usize::try_from(n).ok())
39        .map(|n| n.clamp(1, MAX_LIMIT))
40        .unwrap_or(DEFAULT_LIMIT)
41}
42
43fn compile_regex(pattern: Option<&str>) -> Option<Regex> {
44    let raw = pattern.unwrap_or_default();
45    if raw.is_empty() {
46        return None;
47    }
48    match Regex::new(&format!("(?i){raw}")) {
49        Ok(re) => Some(re),
50        Err(_) => {
51            let escaped = regex::escape(raw);
52            Regex::new(&format!("(?i){escaped}")).ok()
53        }
54    }
55}
56
/// Splits `text` into lowercase tokens.
///
/// A token is a maximal run of ASCII letters, ASCII digits and underscores;
/// every other character (including all non-ASCII characters) is a separator.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in text.chars() {
        if ch.is_ascii_alphanumeric() || ch == '_' {
            current.push(ch.to_ascii_lowercase());
        } else if !current.is_empty() {
            // A separator closes the token collected so far.
            tokens.push(std::mem::take(&mut current));
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
63
/// Decides whether a query token and a document token count as a match.
///
/// Equal tokens always match. Otherwise both tokens must be at least three
/// bytes long and one must contain the other.
fn hybrid_token_match(query_token: &str, candidate_token: &str) -> bool {
    if query_token == candidate_token {
        return true;
    }
    // Very short tokens would match almost anything as substrings.
    if query_token.len() < 3 || candidate_token.len() < 3 {
        return false;
    }
    candidate_token.contains(query_token) || query_token.contains(candidate_token)
}
70
71fn hybrid_text_match(query_tokens: &[String], text: &str) -> bool {
72    if query_tokens.is_empty() {
73        return false;
74    }
75    let tokens = tokenize(text);
76    query_tokens.iter().any(|query| {
77        tokens
78            .iter()
79            .any(|candidate| hybrid_token_match(query, candidate))
80    })
81}
82
83fn hybrid_fallback_score(
84    query_tokens: &[String],
85    doc: &SearchDoc,
86    field_weights: &[(&'static str, f64)],
87) -> f64 {
88    let mut score = 0.0_f64;
89    for (field, weight) in field_weights {
90        if *weight <= 0.0 {
91            continue;
92        }
93        let text = doc
94            .fields
95            .get(field)
96            .map(String::as_str)
97            .unwrap_or_default();
98        if text.is_empty() {
99            continue;
100        }
101        let tokens = tokenize(text);
102        if tokens.is_empty() {
103            continue;
104        }
105        let hits = query_tokens
106            .iter()
107            .filter(|query| {
108                tokens
109                    .iter()
110                    .any(|candidate| hybrid_token_match(query, candidate))
111            })
112            .count() as f64;
113        if hits > 0.0 {
114            score += hits * *weight * 0.25;
115        }
116    }
117    score
118}
119
120fn bm25_scores(
121    query_tokens: &[String],
122    docs: &[SearchDoc],
123    field_weights: &[(&'static str, f64)],
124) -> Vec<f64> {
125    let n_docs = docs.len();
126    if n_docs == 0 {
127        return Vec::new();
128    }
129
130    let mut doc_tfs: Vec<HashMap<String, f64>> = Vec::with_capacity(n_docs);
131    let mut doc_lens: Vec<f64> = Vec::with_capacity(n_docs);
132    let mut doc_freq: HashMap<String, usize> = HashMap::new();
133
134    for doc in docs {
135        let mut tf: HashMap<String, f64> = HashMap::new();
136        let mut dlen = 0.0_f64;
137        for (field, weight) in field_weights {
138            if *weight <= 0.0 {
139                continue;
140            }
141            let text = doc
142                .fields
143                .get(field)
144                .map(String::as_str)
145                .unwrap_or_default();
146            let tokens = tokenize(text);
147            if tokens.is_empty() {
148                continue;
149            }
150            let mut counts: HashMap<String, usize> = HashMap::new();
151            for tok in &tokens {
152                *counts.entry(tok.clone()).or_insert(0) += 1;
153            }
154            for (tok, count) in counts {
155                *tf.entry(tok).or_insert(0.0) += (count as f64) * *weight;
156            }
157            dlen += (tokens.len() as f64) * *weight;
158        }
159        for tok in tf.keys() {
160            *doc_freq.entry(tok.clone()).or_insert(0) += 1;
161        }
162        doc_tfs.push(tf);
163        doc_lens.push(dlen);
164    }
165
166    let avgdl = {
167        let sum: f64 = doc_lens.iter().sum();
168        let avg = sum / (n_docs as f64);
169        if avg <= 0.0 { 1.0 } else { avg }
170    };
171
172    let mut qtf: HashMap<String, usize> = HashMap::new();
173    for tok in query_tokens {
174        *qtf.entry(tok.clone()).or_insert(0) += 1;
175    }
176
177    let k1 = 1.5_f64;
178    let b = 0.75_f64;
179    let mut scores = vec![0.0_f64; n_docs];
180    for (i, tf) in doc_tfs.iter().enumerate() {
181        let dl = doc_lens[i];
182        let norm = 1.0 - b + b * (dl / avgdl);
183        for (tok, qcount) in &qtf {
184            let freq = *tf.get(tok).unwrap_or(&0.0);
185            if freq <= 0.0 {
186                continue;
187            }
188            let df = *doc_freq.get(tok).unwrap_or(&0) as f64;
189            let idf = ((n_docs as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
190            let denom = freq + k1 * norm;
191            if denom <= 0.0 {
192                continue;
193            }
194            let term = idf * ((freq * (k1 + 1.0)) / denom);
195            scores[i] += term * (1.0 + (*qcount as f64).ln());
196        }
197    }
198
199    scores
200}
201
202fn field_hits(
203    fields: &HashMap<&'static str, String>,
204    query: &str,
205    mode: SearchMode,
206    regex_filter: Option<&Regex>,
207) -> Vec<String> {
208    let query_lower = query.to_ascii_lowercase();
209    let query_tokens: HashSet<String> = tokenize(query).into_iter().collect();
210    let mut hits = Vec::new();
211    for (field, value) in fields {
212        if value.is_empty() {
213            continue;
214        }
215        let value_lower = value.to_ascii_lowercase();
216        let mut hit = match mode {
217            SearchMode::Regex => regex_filter.is_some_and(|re| re.is_match(value)),
218            SearchMode::Literal => !query.is_empty() && value_lower.contains(&query_lower),
219            SearchMode::Hybrid => {
220                if !query_tokens.is_empty() {
221                    let tokens = tokenize(value);
222                    query_tokens.iter().any(|query| {
223                        tokens
224                            .iter()
225                            .any(|candidate| hybrid_token_match(query, candidate))
226                    })
227                } else if query.is_empty() {
228                    true
229                } else {
230                    value_lower.contains(&query_lower)
231                }
232            }
233        };
234        if hit && regex_filter.is_some() && mode != SearchMode::Regex {
235            hit = regex_filter.is_some_and(|re| re.is_match(value));
236        }
237        if hit {
238            hits.push((*field).to_string());
239        }
240    }
241    hits
242}
243
244pub fn rank_docs(
245    docs: &[SearchDoc],
246    query: &str,
247    mode: SearchMode,
248    regex: Option<&str>,
249    field_weights: &[(&'static str, f64)],
250) -> Vec<(usize, f64, Vec<String>)> {
251    let query_tokens = tokenize(query);
252    let query_lower = query.to_ascii_lowercase();
253    let mut scores = vec![0.0_f64; docs.len()];
254    if mode == SearchMode::Hybrid && !query_tokens.is_empty() {
255        scores = bm25_scores(&query_tokens, docs, field_weights);
256        for (idx, score) in scores.iter_mut().enumerate() {
257            if *score <= 0.0 {
258                *score = hybrid_fallback_score(&query_tokens, &docs[idx], field_weights);
259            }
260        }
261    }
262    let regex_filter = match mode {
263        SearchMode::Regex => compile_regex(regex.or(Some(query))),
264        _ => compile_regex(regex),
265    };
266
267    let mut indices: Vec<usize> = (0..docs.len()).collect();
268    if mode == SearchMode::Hybrid {
269        indices.sort_by(|a, b| {
270            scores[*b]
271                .partial_cmp(&scores[*a])
272                .unwrap_or(std::cmp::Ordering::Equal)
273                .then_with(|| a.cmp(b))
274        });
275    }
276
277    let mut ranked = Vec::new();
278    for idx in indices {
279        let haystack = docs[idx]
280            .fields
281            .values()
282            .filter(|value| !value.is_empty())
283            .cloned()
284            .collect::<Vec<_>>()
285            .join("\n");
286        let haystack_lower = haystack.to_ascii_lowercase();
287
288        let mut include = match mode {
289            SearchMode::Regex => regex_filter
290                .as_ref()
291                .is_some_and(|re| re.is_match(&haystack)),
292            SearchMode::Literal => !query.is_empty() && haystack_lower.contains(&query_lower),
293            SearchMode::Hybrid => {
294                if query.is_empty() {
295                    true
296                } else if !query_tokens.is_empty() {
297                    scores[idx] > 0.0
298                        || haystack_lower.contains(&query_lower)
299                        || hybrid_text_match(&query_tokens, &haystack)
300                } else {
301                    haystack_lower.contains(&query_lower)
302                }
303            }
304        };
305        if include && regex_filter.is_some() && mode != SearchMode::Regex {
306            include = regex_filter
307                .as_ref()
308                .is_some_and(|re| re.is_match(&haystack));
309        }
310        if !include {
311            continue;
312        }
313
314        let hits = field_hits(&docs[idx].fields, query, mode, regex_filter.as_ref());
315        if hits.is_empty() && !(query.is_empty() && mode != SearchMode::Regex) {
316            continue;
317        }
318        ranked.push((idx, scores[idx], hits));
319    }
320    ranked
321}