1use std::collections::{HashMap, HashSet};
2
3use regex::Regex;
4
/// Largest result count `limit_from_args` will ever return.
const MAX_LIMIT: usize = 100;
/// Result count `limit_from_args` falls back to when no usable `limit` is given.
const DEFAULT_LIMIT: usize = 10;
7
/// How `rank_docs` interprets a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SearchMode {
    /// BM25-ranked token search; a query token and a field token also match
    /// when one contains the other and both are at least three bytes long.
    Hybrid,
    /// ASCII-case-insensitive substring match of the whole query.
    Literal,
    /// Case-insensitive regular-expression match.
    Regex,
}

impl SearchMode {
    /// Parses a user-supplied mode name.
    ///
    /// Surrounding whitespace is ignored and the comparison is ASCII
    /// case-insensitive. `"literal"` and `"regex"` select those modes.
    /// Anything else falls back to [`SearchMode::Hybrid`], including `None`
    /// and unrecognised names.
    pub fn parse(value: Option<&str>) -> Self {
        // `eq_ignore_ascii_case` compares without allocating a lowercased copy.
        let name = value.map_or("hybrid", str::trim);
        if name.eq_ignore_ascii_case("literal") {
            Self::Literal
        } else if name.eq_ignore_ascii_case("regex") {
            Self::Regex
        } else {
            Self::Hybrid
        }
    }
}
29
/// A searchable document: a set of named text fields.
#[derive(Debug, Clone, Default)]
pub struct SearchDoc {
    /// Field name -> field text.
    ///
    /// Only fields named in the `field_weights` passed to `rank_docs` feed the
    /// BM25/fallback scores. Every non-empty field is checked when deciding
    /// which fields matched.
    pub fields: HashMap<&'static str, String>,
}
34
35pub fn limit_from_args(args: &serde_json::Value) -> usize {
36 args.get("limit")
37 .and_then(|v| v.as_i64())
38 .and_then(|n| usize::try_from(n).ok())
39 .map(|n| n.clamp(1, MAX_LIMIT))
40 .unwrap_or(DEFAULT_LIMIT)
41}
42
43fn compile_regex(pattern: Option<&str>) -> Option<Regex> {
44 let raw = pattern.unwrap_or_default();
45 if raw.is_empty() {
46 return None;
47 }
48 match Regex::new(&format!("(?i){raw}")) {
49 Ok(re) => Some(re),
50 Err(_) => {
51 let escaped = regex::escape(raw);
52 Regex::new(&format!("(?i){escaped}")).ok()
53 }
54 }
55}
56
/// Splits `text` into lowercase ASCII word tokens.
///
/// A token is a maximal run of ASCII letters, digits, or `_`. Every other
/// character, including all non-ASCII characters, acts as a separator, and
/// empty pieces are discarded.
fn tokenize(text: &str) -> Vec<String> {
    let is_word_char = |c: char| c == '_' || c.is_ascii_alphanumeric();
    let mut tokens = Vec::new();
    for piece in text.split(|c: char| !is_word_char(c)) {
        if !piece.is_empty() {
            tokens.push(piece.to_ascii_lowercase());
        }
    }
    tokens
}
63
/// Decides whether a query token matches a document token.
///
/// Exact equality always matches. Otherwise both tokens must be at least
/// three bytes long, and one must contain the other. For example, `"config"`
/// matches `"configuration"` in either direction, but `"ab"` never matches `"abc"`.
fn hybrid_token_match(query_token: &str, candidate_token: &str) -> bool {
    if query_token == candidate_token {
        return true;
    }
    let long_enough = query_token.len() >= 3 && candidate_token.len() >= 3;
    long_enough
        && (candidate_token.contains(query_token) || query_token.contains(candidate_token))
}
70
71fn hybrid_text_match(query_tokens: &[String], text: &str) -> bool {
72 if query_tokens.is_empty() {
73 return false;
74 }
75 let tokens = tokenize(text);
76 query_tokens.iter().any(|query| {
77 tokens
78 .iter()
79 .any(|candidate| hybrid_token_match(query, candidate))
80 })
81}
82
83fn hybrid_fallback_score(
84 query_tokens: &[String],
85 doc: &SearchDoc,
86 field_weights: &[(&'static str, f64)],
87) -> f64 {
88 let mut score = 0.0_f64;
89 for (field, weight) in field_weights {
90 if *weight <= 0.0 {
91 continue;
92 }
93 let text = doc
94 .fields
95 .get(field)
96 .map(String::as_str)
97 .unwrap_or_default();
98 if text.is_empty() {
99 continue;
100 }
101 let tokens = tokenize(text);
102 if tokens.is_empty() {
103 continue;
104 }
105 let hits = query_tokens
106 .iter()
107 .filter(|query| {
108 tokens
109 .iter()
110 .any(|candidate| hybrid_token_match(query, candidate))
111 })
112 .count() as f64;
113 if hits > 0.0 {
114 score += hits * *weight * 0.25;
115 }
116 }
117 score
118}
119
120fn bm25_scores(
121 query_tokens: &[String],
122 docs: &[SearchDoc],
123 field_weights: &[(&'static str, f64)],
124) -> Vec<f64> {
125 let n_docs = docs.len();
126 if n_docs == 0 {
127 return Vec::new();
128 }
129
130 let mut doc_tfs: Vec<HashMap<String, f64>> = Vec::with_capacity(n_docs);
131 let mut doc_lens: Vec<f64> = Vec::with_capacity(n_docs);
132 let mut doc_freq: HashMap<String, usize> = HashMap::new();
133
134 for doc in docs {
135 let mut tf: HashMap<String, f64> = HashMap::new();
136 let mut dlen = 0.0_f64;
137 for (field, weight) in field_weights {
138 if *weight <= 0.0 {
139 continue;
140 }
141 let text = doc
142 .fields
143 .get(field)
144 .map(String::as_str)
145 .unwrap_or_default();
146 let tokens = tokenize(text);
147 if tokens.is_empty() {
148 continue;
149 }
150 let mut counts: HashMap<String, usize> = HashMap::new();
151 for tok in &tokens {
152 *counts.entry(tok.clone()).or_insert(0) += 1;
153 }
154 for (tok, count) in counts {
155 *tf.entry(tok).or_insert(0.0) += (count as f64) * *weight;
156 }
157 dlen += (tokens.len() as f64) * *weight;
158 }
159 for tok in tf.keys() {
160 *doc_freq.entry(tok.clone()).or_insert(0) += 1;
161 }
162 doc_tfs.push(tf);
163 doc_lens.push(dlen);
164 }
165
166 let avgdl = {
167 let sum: f64 = doc_lens.iter().sum();
168 let avg = sum / (n_docs as f64);
169 if avg <= 0.0 { 1.0 } else { avg }
170 };
171
172 let mut qtf: HashMap<String, usize> = HashMap::new();
173 for tok in query_tokens {
174 *qtf.entry(tok.clone()).or_insert(0) += 1;
175 }
176
177 let k1 = 1.5_f64;
178 let b = 0.75_f64;
179 let mut scores = vec![0.0_f64; n_docs];
180 for (i, tf) in doc_tfs.iter().enumerate() {
181 let dl = doc_lens[i];
182 let norm = 1.0 - b + b * (dl / avgdl);
183 for (tok, qcount) in &qtf {
184 let freq = *tf.get(tok).unwrap_or(&0.0);
185 if freq <= 0.0 {
186 continue;
187 }
188 let df = *doc_freq.get(tok).unwrap_or(&0) as f64;
189 let idf = ((n_docs as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
190 let denom = freq + k1 * norm;
191 if denom <= 0.0 {
192 continue;
193 }
194 let term = idf * ((freq * (k1 + 1.0)) / denom);
195 scores[i] += term * (1.0 + (*qcount as f64).ln());
196 }
197 }
198
199 scores
200}
201
202fn field_hits(
203 fields: &HashMap<&'static str, String>,
204 query: &str,
205 mode: SearchMode,
206 regex_filter: Option<&Regex>,
207) -> Vec<String> {
208 let query_lower = query.to_ascii_lowercase();
209 let query_tokens: HashSet<String> = tokenize(query).into_iter().collect();
210 let mut hits = Vec::new();
211 for (field, value) in fields {
212 if value.is_empty() {
213 continue;
214 }
215 let value_lower = value.to_ascii_lowercase();
216 let mut hit = match mode {
217 SearchMode::Regex => regex_filter.is_some_and(|re| re.is_match(value)),
218 SearchMode::Literal => !query.is_empty() && value_lower.contains(&query_lower),
219 SearchMode::Hybrid => {
220 if !query_tokens.is_empty() {
221 let tokens = tokenize(value);
222 query_tokens.iter().any(|query| {
223 tokens
224 .iter()
225 .any(|candidate| hybrid_token_match(query, candidate))
226 })
227 } else if query.is_empty() {
228 true
229 } else {
230 value_lower.contains(&query_lower)
231 }
232 }
233 };
234 if hit && regex_filter.is_some() && mode != SearchMode::Regex {
235 hit = regex_filter.is_some_and(|re| re.is_match(value));
236 }
237 if hit {
238 hits.push((*field).to_string());
239 }
240 }
241 hits
242}
243
244pub fn rank_docs(
245 docs: &[SearchDoc],
246 query: &str,
247 mode: SearchMode,
248 regex: Option<&str>,
249 field_weights: &[(&'static str, f64)],
250) -> Vec<(usize, f64, Vec<String>)> {
251 let query_tokens = tokenize(query);
252 let query_lower = query.to_ascii_lowercase();
253 let mut scores = vec![0.0_f64; docs.len()];
254 if mode == SearchMode::Hybrid && !query_tokens.is_empty() {
255 scores = bm25_scores(&query_tokens, docs, field_weights);
256 for (idx, score) in scores.iter_mut().enumerate() {
257 if *score <= 0.0 {
258 *score = hybrid_fallback_score(&query_tokens, &docs[idx], field_weights);
259 }
260 }
261 }
262 let regex_filter = match mode {
263 SearchMode::Regex => compile_regex(regex.or(Some(query))),
264 _ => compile_regex(regex),
265 };
266
267 let mut indices: Vec<usize> = (0..docs.len()).collect();
268 if mode == SearchMode::Hybrid {
269 indices.sort_by(|a, b| {
270 scores[*b]
271 .partial_cmp(&scores[*a])
272 .unwrap_or(std::cmp::Ordering::Equal)
273 .then_with(|| a.cmp(b))
274 });
275 }
276
277 let mut ranked = Vec::new();
278 for idx in indices {
279 let haystack = docs[idx]
280 .fields
281 .values()
282 .filter(|value| !value.is_empty())
283 .cloned()
284 .collect::<Vec<_>>()
285 .join("\n");
286 let haystack_lower = haystack.to_ascii_lowercase();
287
288 let mut include = match mode {
289 SearchMode::Regex => regex_filter
290 .as_ref()
291 .is_some_and(|re| re.is_match(&haystack)),
292 SearchMode::Literal => !query.is_empty() && haystack_lower.contains(&query_lower),
293 SearchMode::Hybrid => {
294 if query.is_empty() {
295 true
296 } else if !query_tokens.is_empty() {
297 scores[idx] > 0.0
298 || haystack_lower.contains(&query_lower)
299 || hybrid_text_match(&query_tokens, &haystack)
300 } else {
301 haystack_lower.contains(&query_lower)
302 }
303 }
304 };
305 if include && regex_filter.is_some() && mode != SearchMode::Regex {
306 include = regex_filter
307 .as_ref()
308 .is_some_and(|re| re.is_match(&haystack));
309 }
310 if !include {
311 continue;
312 }
313
314 let hits = field_hits(&docs[idx].fields, query, mode, regex_filter.as_ref());
315 if hits.is_empty() && !(query.is_empty() && mode != SearchMode::Regex) {
316 continue;
317 }
318 ranked.push((idx, scores[idx], hits));
319 }
320 ranked
321}