1pub mod hnsw;
34
35pub use hnsw::HNSWIndex;
36
/// Re-scores (query, document) pairs with a user-supplied joint scoring
/// function.
///
/// Cross-encoders see both inputs together, so they are typically used to
/// re-rank a small candidate set produced by a cheaper first-stage retriever.
#[derive(Debug, Clone)]
pub struct CrossEncoder<F> {
    // Joint scoring function over (query, document) feature vectors.
    score_fn: F,
}

impl<F> CrossEncoder<F>
where
    F: Fn(&[f32], &[f32]) -> f32,
{
    /// Creates a cross-encoder from the given pairwise scoring function.
    pub fn new(score_fn: F) -> Self {
        Self { score_fn }
    }

    /// Scores a single (query, document) pair. Higher means more relevant.
    pub fn score(&self, query: &[f32], document: &[f32]) -> f32 {
        (self.score_fn)(query, document)
    }

    /// Scores every candidate against `query` and returns the `top_k`
    /// highest-scoring ones, best first.
    ///
    /// If `top_k` exceeds the candidate count, all candidates are returned.
    pub fn rerank<'a, T>(
        &self,
        query: &[f32],
        candidates: &'a [(T, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(&'a T, f32)> {
        let mut scored: Vec<(&T, f32)> = candidates
            .iter()
            .map(|(id, doc)| (id, self.score(query, doc)))
            .collect();

        // `total_cmp` is a true total order on f32 (unlike `partial_cmp`,
        // which is not antisymmetric under NaN and can destabilize the sort).
        scored.sort_by(|a, b| b.1.total_cmp(&a.1));
        scored.truncate(top_k);
        scored
    }
}

/// Returns a cross-encoder that scores pairs by cosine similarity.
///
/// The `1e-10` term guards against division by zero for all-zero vectors.
pub fn default_cross_encoder() -> CrossEncoder<impl Fn(&[f32], &[f32]) -> f32> {
    CrossEncoder::new(|q, d| {
        let dot: f32 = q.iter().zip(d).map(|(&a, &b)| a * b).sum();
        let nq: f32 = q.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let nd: f32 = d.iter().map(|&x| x * x).sum::<f32>().sqrt();
        dot / (nq * nd + 1e-10)
    })
}
87
/// Fuses dense (vector) and sparse (lexical) retrieval results into one
/// ranked list, either by weighted score fusion or reciprocal-rank fusion.
#[derive(Debug, Clone)]
pub struct HybridSearch {
    // Weight applied to max-normalized dense scores in `fuse_scores`.
    dense_weight: f32,
    // Weight applied to max-normalized sparse scores in `fuse_scores`.
    sparse_weight: f32,
}

impl HybridSearch {
    /// Creates a fuser with the given dense/sparse weights.
    pub fn new(dense_weight: f32, sparse_weight: f32) -> Self {
        Self {
            dense_weight,
            sparse_weight,
        }
    }

    /// Weighted-sum fusion after per-list max normalization.
    ///
    /// Each result list is normalized by its own maximum score, weighted,
    /// and summed per document id; the `top_k` highest fused scores are
    /// returned, best first.
    pub fn fuse_scores(
        &self,
        dense_results: &[(String, f32)],
        sparse_results: &[(String, f32)],
        top_k: usize,
    ) -> Vec<(String, f32)> {
        use std::collections::HashMap;

        let mut scores: HashMap<String, f32> = HashMap::new();
        Self::accumulate_normalized(&mut scores, dense_results, self.dense_weight);
        Self::accumulate_normalized(&mut scores, sparse_results, self.sparse_weight);

        let mut results: Vec<_> = scores.into_iter().collect();
        // `total_cmp` gives a deterministic total order even if a NaN
        // score slips in (unlike `partial_cmp` + `unwrap_or(Equal)`).
        results.sort_by(|a, b| b.1.total_cmp(&a.1));
        results.truncate(top_k);
        results
    }

    /// Max-normalizes `results` and folds `weight * normalized` into `scores`.
    ///
    /// A non-positive maximum yields a zero contribution for every entry;
    /// this preserves the original handling of empty or all-negative lists.
    fn accumulate_normalized(
        scores: &mut std::collections::HashMap<String, f32>,
        results: &[(String, f32)],
        weight: f32,
    ) {
        let max = results.iter().map(|(_, s)| *s).fold(0.0_f32, f32::max);
        for (id, score) in results {
            let norm = if max > 0.0 { score / max } else { 0.0 };
            *scores.entry(id.clone()).or_insert(0.0) += weight * norm;
        }
    }

    /// Reciprocal-rank fusion: each document accrues `1 / (k + rank + 1)`
    /// from every ranking it appears in (`rank` is 0-based). Returns the
    /// `top_n` documents by fused score, best first.
    ///
    /// `k` damps the influence of top ranks; 60 is the value suggested in
    /// the original RRF paper.
    pub fn rrf_fuse(&self, rankings: &[Vec<String>], k: f32, top_n: usize) -> Vec<(String, f32)> {
        use std::collections::HashMap;

        let mut scores: HashMap<String, f32> = HashMap::new();

        for ranking in rankings {
            for (rank, id) in ranking.iter().enumerate() {
                *scores.entry(id.clone()).or_insert(0.0) += 1.0 / (k + rank as f32 + 1.0);
            }
        }

        let mut results: Vec<_> = scores.into_iter().collect();
        results.sort_by(|a, b| b.1.total_cmp(&a.1));
        results.truncate(top_n);
        results
    }

    /// Weight applied to dense scores.
    pub fn dense_weight(&self) -> f32 {
        self.dense_weight
    }

    /// Weight applied to sparse scores.
    pub fn sparse_weight(&self) -> f32 {
        self.sparse_weight
    }
}

impl Default for HybridSearch {
    /// A 70/30 dense/sparse split.
    fn default() -> Self {
        Self::new(0.7, 0.3)
    }
}
182
/// Encodes queries and documents independently with `encode_fn`, then
/// compares embeddings with a fixed similarity metric.
#[derive(Debug)]
pub struct BiEncoder<F> {
    // Maps a raw feature vector to its embedding.
    encode_fn: F,
    // Metric used to compare two embeddings.
    similarity: SimilarityMetric,
}

/// How two embeddings are compared. Higher always means "more similar"
/// (Euclidean distance is negated to keep that orientation).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimilarityMetric {
    Cosine,
    DotProduct,
    Euclidean,
}

impl<F> BiEncoder<F>
where
    F: Fn(&[f32]) -> Vec<f32>,
{
    /// Creates a bi-encoder from an encoding function and a metric.
    pub fn new(encode_fn: F, similarity: SimilarityMetric) -> Self {
        Self {
            encode_fn,
            similarity,
        }
    }

    /// Encodes a single input into its embedding.
    pub fn encode(&self, input: &[f32]) -> Vec<f32> {
        (self.encode_fn)(input)
    }

    /// Encodes each input in order.
    pub fn encode_batch(&self, inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
        inputs.iter().map(|x| self.encode(x)).collect()
    }

    /// Similarity between two embeddings under the configured metric.
    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        match self.similarity {
            SimilarityMetric::Cosine => {
                let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
                let na: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
                let nb: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
                // 1e-10 guards against division by zero for all-zero vectors.
                dot / (na * nb + 1e-10)
            }
            SimilarityMetric::DotProduct => a.iter().zip(b).map(|(&x, &y)| x * y).sum(),
            SimilarityMetric::Euclidean => {
                let dist_sq: f32 = a.iter().zip(b).map(|(&x, &y)| (x - y).powi(2)).sum();
                // Negated so that larger still means more similar.
                -dist_sq.sqrt()
            }
        }
    }

    /// Encodes `query` and returns the `top_k` most similar corpus entries,
    /// best first.
    pub fn retrieve<T: Clone>(
        &self,
        query: &[f32],
        corpus: &[(T, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(T, f32)> {
        let query_emb = self.encode(query);
        let mut scores: Vec<(T, f32)> = corpus
            .iter()
            .map(|(id, doc_emb)| (id.clone(), self.similarity(&query_emb, doc_emb)))
            .collect();

        // NaN-safe total ordering; `partial_cmp` + `unwrap_or(Equal)` is not
        // a valid comparator in the presence of NaN.
        scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        scores.truncate(top_k);
        scores
    }
}
266
/// Late-interaction scorer in the style of ColBERT: queries and documents
/// are represented as one embedding per token, and relevance is the sum of
/// each query token's best cosine match over the document tokens (MaxSim).
#[derive(Debug)]
pub struct ColBERT {
    // Expected per-token embedding width, exposed via `embedding_dim()`.
    // NOTE(review): `maxsim` never validates token lengths against this
    // value — confirm whether dimension checks are intended.
    embedding_dim: usize,
}

impl ColBERT {
    /// Creates a scorer for token embeddings of width `embedding_dim`.
    pub fn new(embedding_dim: usize) -> Self {
        Self { embedding_dim }
    }

    /// MaxSim score: for each query token, take its maximum cosine
    /// similarity over all document tokens, then sum those maxima.
    ///
    /// Returns 0.0 when either side has no tokens.
    pub fn maxsim(&self, query_tokens: &[Vec<f32>], doc_tokens: &[Vec<f32>]) -> f32 {
        if query_tokens.is_empty() || doc_tokens.is_empty() {
            return 0.0;
        }

        query_tokens
            .iter()
            .map(|q| {
                doc_tokens
                    .iter()
                    .map(|d| cosine_sim(q, d))
                    // doc_tokens is non-empty here, so the fold always
                    // replaces NEG_INFINITY with a real similarity.
                    .fold(f32::NEG_INFINITY, f32::max)
            })
            .sum()
    }

    /// Scores every document against the query, preserving input order.
    pub fn score_documents(
        &self,
        query_tokens: &[Vec<f32>],
        documents: &[Vec<Vec<f32>>],
    ) -> Vec<f32> {
        documents
            .iter()
            .map(|doc| self.maxsim(query_tokens, doc))
            .collect()
    }

    /// Returns the `top_k` corpus entries by MaxSim score, best first.
    pub fn retrieve<T: Clone>(
        &self,
        query_tokens: &[Vec<f32>],
        corpus: &[(T, Vec<Vec<f32>>)],
        top_k: usize,
    ) -> Vec<(T, f32)> {
        let mut scores: Vec<(T, f32)> = corpus
            .iter()
            .map(|(id, doc_tokens)| (id.clone(), self.maxsim(query_tokens, doc_tokens)))
            .collect();

        // NaN-safe deterministic ordering via `f32::total_cmp`.
        scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        scores.truncate(top_k);
        scores
    }

    /// Embedding width this scorer was configured with.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}

/// Cosine similarity between two vectors; `1e-10` guards against division
/// by zero for all-zero vectors.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
    let na: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-10)
}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cross_encoder_score() {
        let ce = default_cross_encoder();
        let query = vec![1.0, 0.0];
        let doc1 = vec![1.0, 0.0];
        let doc2 = vec![0.0, 1.0];

        let score1 = ce.score(&query, &doc1);
        let score2 = ce.score(&query, &doc2);
        assert!(score1 > score2);
    }

    #[test]
    fn test_cross_encoder_rerank() {
        let ce = default_cross_encoder();
        let query = vec![1.0, 0.0];
        let candidates = vec![
            ("doc1", vec![0.0, 1.0]),
            ("doc2", vec![1.0, 0.0]),
            ("doc3", vec![0.5, 0.5]),
        ];

        let reranked = ce.rerank(&query, &candidates, 2);
        assert_eq!(reranked.len(), 2);
        assert_eq!(*reranked[0].0, "doc2");
    }

    #[test]
    fn test_hybrid_search_fuse() {
        let hs = HybridSearch::new(0.6, 0.4);
        let dense = vec![("a".to_string(), 0.9), ("b".to_string(), 0.5)];
        let sparse = vec![("b".to_string(), 1.0), ("c".to_string(), 0.7)];

        let fused = hs.fuse_scores(&dense, &sparse, 3);
        assert!(fused.len() <= 3);
    }

    #[test]
    fn test_hybrid_search_rrf() {
        let hs = HybridSearch::default();
        let rankings = vec![
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
            vec!["b".to_string(), "a".to_string(), "d".to_string()],
        ];

        let fused = hs.rrf_fuse(&rankings, 60.0, 3);
        assert_eq!(fused.len(), 3);
    }

    #[test]
    fn test_hybrid_search_default() {
        let hs = HybridSearch::default();
        assert!((hs.dense_weight() - 0.7).abs() < 1e-6);
        assert!((hs.sparse_weight() - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_bi_encoder_cosine() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Cosine);
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        let c = vec![0.0, 1.0];

        let sim_ab = encoder.similarity(&a, &b);
        let sim_ac = encoder.similarity(&a, &c);

        assert!((sim_ab - 1.0).abs() < 1e-6);
        assert!(sim_ac.abs() < 1e-6);
    }

    #[test]
    fn test_bi_encoder_dot_product() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::DotProduct);
        let a = vec![2.0, 3.0];
        let b = vec![1.0, 2.0];

        let sim = encoder.similarity(&a, &b);
        assert!((sim - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_bi_encoder_euclidean() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Euclidean);
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];

        let sim = encoder.similarity(&a, &b);
        assert!((sim - (-5.0)).abs() < 1e-6);
    }

    #[test]
    fn test_bi_encoder_encode() {
        let encoder = BiEncoder::new(
            |x: &[f32]| x.iter().map(|&v| v * 2.0).collect(),
            SimilarityMetric::Cosine,
        );

        let input = vec![1.0, 2.0, 3.0];
        let encoded = encoder.encode(&input);

        assert_eq!(encoded, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_bi_encoder_encode_batch() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Cosine);
        let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let encoded = encoder.encode_batch(&inputs);
        assert_eq!(encoded.len(), 2);
    }

    #[test]
    fn test_bi_encoder_retrieve() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Cosine);
        let corpus = vec![
            ("doc1", vec![1.0, 0.0]),
            ("doc2", vec![0.0, 1.0]),
            ("doc3", vec![0.707, 0.707]),
        ];

        let query = vec![1.0, 0.0];
        let results = encoder.retrieve(&query, &corpus, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_colbert_creation() {
        let colbert = ColBERT::new(128);
        assert_eq!(colbert.embedding_dim(), 128);
    }

    #[test]
    fn test_colbert_maxsim_identical() {
        let colbert = ColBERT::new(4);
        let query = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
        let doc = query.clone();

        let score = colbert.maxsim(&query, &doc);
        assert!((score - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_colbert_maxsim_different() {
        let colbert = ColBERT::new(4);
        let query = vec![vec![1.0, 0.0, 0.0, 0.0]];
        let doc = vec![vec![0.0, 1.0, 0.0, 0.0]];

        let score = colbert.maxsim(&query, &doc);
        assert!(score.abs() < 1e-5);
    }

    #[test]
    fn test_colbert_maxsim_empty() {
        let colbert = ColBERT::new(4);
        let empty: Vec<Vec<f32>> = vec![];
        let doc = vec![vec![1.0, 0.0, 0.0, 0.0]];

        assert_eq!(colbert.maxsim(&empty, &doc), 0.0);
        assert_eq!(colbert.maxsim(&doc, &empty), 0.0);
    }

    #[test]
    fn test_colbert_score_documents() {
        let colbert = ColBERT::new(2);
        let query = vec![vec![1.0, 0.0]];
        let docs = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];

        let scores = colbert.score_documents(&query, &docs);
        assert_eq!(scores.len(), 2);
        assert!(scores[0] > scores[1]);
    }

    #[test]
    fn test_colbert_retrieve() {
        let colbert = ColBERT::new(2);
        let query = vec![vec![1.0, 0.0], vec![0.707, 0.707]];
        let corpus = vec![
            ("doc1", vec![vec![1.0, 0.0], vec![0.0, 1.0]]),
            ("doc2", vec![vec![0.0, 1.0]]),
        ];

        let results = colbert.retrieve(&query, &corpus, 2);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_similarity_metric_equality() {
        assert_eq!(SimilarityMetric::Cosine, SimilarityMetric::Cosine);
        assert_ne!(SimilarityMetric::Cosine, SimilarityMetric::DotProduct);
    }
}