1use std::collections::HashMap;
11
12use crate::fulltext::FullTextResult;
13
14type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
16
/// Options controlling how [`HybridSearcher`] fuses vector and full-text scores.
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Share of the combined score taken from the normalized vector score; the
    /// normalized text score receives the remaining `1.0 - vector_weight`.
    /// [`HybridSearcher::with_vector_weight`] keeps this within `[0.0, 1.0]`.
    pub vector_weight: f32,
    /// If `true`, documents that appear in only one of the two result sets are
    /// discarded instead of being scored with `0.0` for the missing side.
    pub require_both: bool,
}

impl Default for HybridConfig {
    /// Weights both sources equally (`0.5`) and keeps single-source hits.
    fn default() -> Self {
        HybridConfig {
            require_both: false,
            vector_weight: 0.5,
        }
    }
}
34
35#[derive(Debug, Clone)]
37struct RawScore {
38 score: f32,
40 metadata: Option<serde_json::Value>,
42 vector: Option<Vec<f32>>,
43}
44
45#[derive(Debug, Clone)]
47pub struct HybridResult {
48 pub id: String,
50 pub combined_score: f32,
52 pub vector_score: f32,
54 pub text_score: f32,
56 pub metadata: Option<serde_json::Value>,
58 pub vector: Option<Vec<f32>>,
60}
61
62pub struct HybridSearcher {
64 config: HybridConfig,
65}
66
67impl HybridSearcher {
68 pub fn new(config: HybridConfig) -> Self {
69 Self { config }
70 }
71
72 pub fn with_vector_weight(mut self, weight: f32) -> Self {
73 self.config.vector_weight = weight.clamp(0.0, 1.0);
74 self
75 }
76
77 pub fn search(
87 &self,
88 vector_results: Vec<VectorResultRow>,
89 text_results: Vec<FullTextResult>,
90 top_k: usize,
91 ) -> Vec<HybridResult> {
92 let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
94 let mut text_scores: HashMap<String, f32> = HashMap::new();
95
96 let mut vector_min = f32::MAX;
98 let mut vector_max = f32::MIN;
99 let mut text_min = f32::MAX;
100 let mut text_max = f32::MIN;
101
102 for (id, score, metadata, vector) in vector_results {
104 vector_min = vector_min.min(score);
105 vector_max = vector_max.max(score);
106 vector_scores.insert(
107 id,
108 RawScore {
109 score,
110 metadata,
111 vector,
112 },
113 );
114 }
115
116 for result in text_results {
118 text_min = text_min.min(result.score);
119 text_max = text_max.max(result.score);
120 text_scores.insert(result.doc_id, result.score);
121 }
122
123 let mut all_ids: Vec<String> = vector_scores
125 .keys()
126 .chain(text_scores.keys())
127 .cloned()
128 .collect();
129 all_ids.sort();
130 all_ids.dedup();
131
132 let mut results: Vec<HybridResult> = Vec::new();
134
135 for id in all_ids {
136 let vector_raw = vector_scores.get(&id);
137 let text_raw = text_scores.get(&id);
138
139 if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
141 continue;
142 }
143
144 let vector_normalized = if let Some(raw) = vector_raw {
146 normalize_score(raw.score, vector_min, vector_max)
147 } else {
148 0.0
149 };
150
151 let text_normalized = if let Some(&score) = text_raw {
152 normalize_score(score, text_min, text_max)
153 } else {
154 0.0
155 };
156
157 let combined = self.config.vector_weight * vector_normalized
159 + (1.0 - self.config.vector_weight) * text_normalized;
160
161 let (metadata, vector) = if let Some(raw) = vector_raw {
163 (raw.metadata.clone(), raw.vector.clone())
164 } else {
165 (None, None)
166 };
167
168 results.push(HybridResult {
169 id,
170 combined_score: combined,
171 vector_score: vector_normalized,
172 text_score: text_normalized,
173 metadata,
174 vector,
175 });
176 }
177
178 results.sort_by(|a, b| {
180 b.combined_score
181 .partial_cmp(&a.combined_score)
182 .unwrap_or(std::cmp::Ordering::Equal)
183 });
184
185 results.truncate(top_k);
187 results
188 }
189}
190
191impl Default for HybridSearcher {
192 fn default() -> Self {
193 Self::new(HybridConfig::default())
194 }
195}
196
/// Min-max normalizes `score` against the `[min, max]` range of its batch.
///
/// Returns `(score - min) / (max - min)`, which is `0.0` at `min` and `1.0` at
/// `max`. When the range is degenerate (`|max - min| < f32::EPSILON`, including
/// the single-score case), every score maps to `1.0` rather than dividing by
/// (nearly) zero.
///
/// The arithmetic is done in `f64`. In `f32`, `max - min` overflows to infinity
/// for very wide ranges (e.g. `-f32::MAX..f32::MAX`), which would collapse every
/// score to `0.0` and turn the top score into `inf / inf = NaN`.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let (score, min, max) = (f64::from(score), f64::from(min), f64::from(max));
    let range = max - min;
    if range.abs() < f64::from(f32::EPSILON) {
        return 1.0;
    }
    ((score - min) / range) as f32
}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for normalized scores that pass through `f32` rounding
    /// (e.g. `(0.7 - 0.5) / (0.9 - 0.5)` is not exactly `0.5` in `f32`).
    const TOLERANCE: f32 = 1e-6;

    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    fn text_hit(doc_id: &str, score: f32) -> FullTextResult {
        FullTextResult {
            doc_id: doc_id.to_string(),
            score,
        }
    }

    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];
        let text_results = vec![
            text_hit("doc1", 3.0),
            text_hit("doc2", 4.0),
            text_hit("doc4", 2.0),
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // Union of both sources: doc1..doc3 from vectors, doc4 from text only.
        assert_eq!(results.len(), 4);

        // Vector range is [0.5, 0.9]; text range is [2.0, 4.0]; weights are 0.5 / 0.5.
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert_close(doc1.vector_score, 1.0);
        assert_close(doc1.text_score, 0.5);
        assert_close(doc1.combined_score, 0.75);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert_close(doc2.vector_score, 0.5);
        assert_eq!(doc2.text_score, 1.0);
        assert_close(doc2.combined_score, 0.75);

        // doc4 has no vector hit and holds the minimum text score.
        let doc4 = results.iter().find(|r| r.id == "doc4").unwrap();
        assert_eq!(doc4.vector_score, 0.0);
        assert_eq!(doc4.text_score, 0.0);
        assert!(doc4.metadata.is_none());
        assert!(doc4.vector.is_none());

        // doc3 and doc4 both combine to 0.0, so they come last, ordered by id.
        assert_eq!(results[2].id, "doc3");
        assert_eq!(results[3].id, "doc4");
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];
        let text_results = vec![text_hit("doc1", 1.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[1].id, "doc2");
        // With full vector weight, the text score contributes nothing.
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];
        let text_results = vec![text_hit("doc1", 1.0), text_hit("doc2", 3.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results[0].id, "doc2");
        // With zero vector weight, the vector score contributes nothing.
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];
        let text_results = vec![text_hit("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);

        // doc2 has no text hit and is dropped.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
        assert_close(results[0].combined_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let results = searcher.search(vector_results, vec![], 3);

        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, vec!["doc1", "doc2", "doc3"]);
    }

    #[test]
    fn test_hybrid_search_top_k_zero() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![("doc1".to_string(), 0.9, None, None)];
        let results = searcher.search(vector_results, vec![text_hit("doc1", 1.0)], 0);

        assert!(results.is_empty());
    }

    #[test]
    fn test_hybrid_search_empty_inputs() {
        let searcher = HybridSearcher::default();

        let results = searcher.search(Vec::new(), Vec::new(), 10);

        assert!(results.is_empty());
    }

    #[test]
    fn test_hybrid_search_ties_break_by_id() {
        let searcher = HybridSearcher::default();

        // Identical raw scores normalize identically, so only the id decides order.
        let vector_results = vec![
            ("doc_b".to_string(), 0.5, None, None),
            ("doc_a".to_string(), 0.5, None, None),
        ];

        let results = searcher.search(vector_results, vec![], 10);

        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, vec!["doc_a", "doc_b"]);
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];
        let text_results = vec![text_hit("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // Degenerate range: every score maps to 1.0.
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }
}