1use std::collections::HashMap;
11
12use crate::fulltext::FullTextResult;
13
14type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
16
/// Tunable parameters for hybrid (vector + full-text) ranking.
#[derive(Clone, Debug)]
pub struct HybridConfig {
    /// Weight applied to the normalized vector score when the two scores are
    /// blended; the normalized text score is weighted by `1.0 - vector_weight`.
    pub vector_weight: f32,
    /// When `true`, a document is kept only if it appears in both the vector
    /// results and the text results.
    pub require_both: bool,
}
25
26impl Default for HybridConfig {
27 fn default() -> Self {
28 Self {
29 vector_weight: 0.5,
30 require_both: false,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37struct RawScore {
38 score: f32,
40 metadata: Option<serde_json::Value>,
42 vector: Option<Vec<f32>>,
43}
44
45#[derive(Debug, Clone)]
47pub struct HybridResult {
48 pub id: String,
50 pub combined_score: f32,
52 pub vector_score: f32,
54 pub text_score: f32,
56 pub metadata: Option<serde_json::Value>,
58 pub vector: Option<Vec<f32>>,
60}
61
62pub struct HybridSearcher {
64 config: HybridConfig,
65}
66
67impl HybridSearcher {
68 pub fn new(config: HybridConfig) -> Self {
69 Self { config }
70 }
71
72 pub fn with_vector_weight(mut self, weight: f32) -> Self {
73 self.config.vector_weight = weight.clamp(0.0, 1.0);
74 self
75 }
76
77 pub fn search(
87 &self,
88 vector_results: Vec<VectorResultRow>,
89 text_results: Vec<FullTextResult>,
90 top_k: usize,
91 ) -> Vec<HybridResult> {
92 let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
94 let mut text_scores: HashMap<String, f32> = HashMap::new();
95
96 let mut vector_min = f32::MAX;
98 let mut vector_max = f32::MIN;
99 let mut text_min = f32::MAX;
100 let mut text_max = f32::MIN;
101
102 for (id, score, metadata, vector) in vector_results {
104 vector_min = vector_min.min(score);
105 vector_max = vector_max.max(score);
106 vector_scores.insert(
107 id,
108 RawScore {
109 score,
110 metadata,
111 vector,
112 },
113 );
114 }
115
116 for result in text_results {
118 text_min = text_min.min(result.score);
119 text_max = text_max.max(result.score);
120 text_scores.insert(result.doc_id, result.score);
121 }
122
123 let mut all_ids: Vec<String> = vector_scores
125 .keys()
126 .chain(text_scores.keys())
127 .cloned()
128 .collect();
129 all_ids.sort();
130 all_ids.dedup();
131
132 let mut results: Vec<HybridResult> = Vec::new();
134
135 for id in all_ids {
136 let vector_raw = vector_scores.get(&id);
137 let text_raw = text_scores.get(&id);
138
139 if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
141 continue;
142 }
143
144 let vector_normalized = if let Some(raw) = vector_raw {
146 normalize_score(raw.score, vector_min, vector_max)
147 } else {
148 0.0
149 };
150
151 let text_normalized = if let Some(&score) = text_raw {
152 normalize_score(score, text_min, text_max)
153 } else {
154 0.0
155 };
156
157 let combined = self.config.vector_weight * vector_normalized
159 + (1.0 - self.config.vector_weight) * text_normalized;
160
161 let (metadata, vector) = if let Some(raw) = vector_raw {
163 (raw.metadata.clone(), raw.vector.clone())
164 } else {
165 (None, None)
166 };
167
168 results.push(HybridResult {
169 id,
170 combined_score: combined,
171 vector_score: vector_normalized,
172 text_score: text_normalized,
173 metadata,
174 vector,
175 });
176 }
177
178 results.sort_by(|a, b| {
180 b.combined_score
181 .partial_cmp(&a.combined_score)
182 .unwrap_or(std::cmp::Ordering::Equal)
183 });
184
185 results.truncate(top_k);
187 results
188 }
189}
190
191impl Default for HybridSearcher {
192 fn default() -> Self {
193 Self::new(HybridConfig::default())
194 }
195}
196
197pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
210 match kind {
211 crate::routing::QueryKind::Keyword => 0.25,
212 crate::routing::QueryKind::Hybrid => 0.50,
213 crate::routing::QueryKind::Semantic => 0.75,
214 }
215}
216
/// Min-max scales `score` relative to the `[min, max]` range.
///
/// Returns `(score - min) / (max - min)`, which is `0.0` at `min` and `1.0` at
/// `max`. When the range is narrower than `f32::EPSILON` (for example a result
/// set with a single score) the division is skipped and `1.0` is returned.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    if span.abs() < f32::EPSILON {
        return 1.0;
    }
    (score - min) / span
}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector result row with no metadata and no stored vector.
    fn row(id: &str, score: f32) -> VectorResultRow {
        (id.to_string(), score, None, None)
    }

    /// Full-text hit for `id` with the given raw score.
    fn hit(id: &str, score: f32) -> FullTextResult {
        FullTextResult {
            doc_id: id.to_string(),
            score,
        }
    }

    /// Searcher built from an explicit weight and `require_both` flag.
    fn searcher_with(vector_weight: f32, require_both: bool) -> HybridSearcher {
        HybridSearcher::new(HybridConfig {
            vector_weight,
            require_both,
        })
    }

    /// Looks up the result for `id`, panicking if it is absent.
    fn find<'a>(results: &'a [HybridResult], id: &str) -> &'a HybridResult {
        results.iter().find(|r| r.id == id).unwrap()
    }

    #[test]
    fn test_hybrid_search_basic() {
        let vector_results = vec![row("doc1", 0.9), row("doc2", 0.7), row("doc3", 0.5)];
        let text_results = vec![hit("doc1", 3.0), hit("doc2", 4.0), hit("doc4", 2.0)];

        let results = HybridSearcher::default().search(vector_results, text_results, 10);

        // doc1..doc4: the union of both result sets.
        assert_eq!(results.len(), 4);

        let doc1 = find(&results, "doc1");
        assert!(doc1.vector_score > 0.0);
        assert!(doc1.text_score >= 0.0);
        assert!(doc1.combined_score > 0.0);

        let doc2 = find(&results, "doc2");
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.text_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest raw text score, so it normalizes to 1.0.
        assert_eq!(doc2.text_score, 1.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        let vector_results = vec![row("doc1", 0.9), row("doc2", 0.5)];
        let text_results = vec![hit("doc1", 1.0)];

        let results = searcher_with(1.0, false).search(vector_results, text_results, 10);

        let top = &results[0];
        assert_eq!(top.id, "doc1");
        // With all weight on the vector side the text score contributes nothing.
        assert_eq!(top.combined_score, top.vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        let vector_results = vec![row("doc1", 0.9), row("doc2", 0.5)];
        let text_results = vec![hit("doc1", 1.0), hit("doc2", 3.0)];

        let results = searcher_with(0.0, false).search(vector_results, text_results, 10);

        let top = &results[0];
        assert_eq!(top.id, "doc2");
        // With no weight on the vector side only the text score counts.
        assert_eq!(top.combined_score, top.text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        let vector_results = vec![row("doc1", 0.9), row("doc2", 0.7)];
        let text_results = vec![hit("doc1", 2.0)];

        let results = searcher_with(0.5, true).search(vector_results, text_results, 10);

        // Only doc1 is present in both result sets.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let vector_results: Vec<VectorResultRow> = [
            ("doc1", 0.9),
            ("doc2", 0.8),
            ("doc3", 0.7),
            ("doc4", 0.6),
            ("doc5", 0.5),
        ]
        .iter()
        .map(|&(id, score)| row(id, score))
        .collect();
        let text_results: Vec<FullTextResult> = Vec::new();

        let results = HybridSearcher::default().search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];
        let text_results = vec![hit("doc1", 2.0)];

        let results = HybridSearcher::default().search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_normalize_score() {
        // Midpoint and both ends of an ordinary range.
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // A zero-width range yields 1.0 rather than dividing by zero.
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let above = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(above.config.vector_weight, 1.0);

        let below = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(below.config.vector_weight, 0.0);
    }
}