aprender/index/mod.rs

//! Indexing data structures for efficient nearest neighbor search.
//!
//! This module provides approximate nearest neighbor search algorithms
//! optimized for production ML workloads.
//!
//! # Algorithms
//!
//! - **HNSW** (Hierarchical Navigable Small World): O(log n) approximate search
//!
//! # Quick Start
//!
//! ```
//! use aprender::index::hnsw::HNSWIndex;
//! use aprender::primitives::Vector;
//!
//! // Create index with M=16 connections per node
//! let mut index = HNSWIndex::new(16, 200, 0.0);
//!
//! // Add vectors at different angles (cosine distance measures angle)
//! index.add("horizontal", Vector::from_slice(&[1.0, 0.0, 0.0]));
//! index.add("diagonal", Vector::from_slice(&[1.0, 1.0, 0.0]));
//! index.add("vertical", Vector::from_slice(&[0.0, 1.0, 0.0]));
//!
//! // Search for the 2 nearest neighbors to a nearly horizontal vector
//! let query = Vector::from_slice(&[0.9, 0.1, 0.0]);
//! let results = index.search(&query, 2);
//!
//! assert_eq!(results.len(), 2);
//! // Results are sorted by cosine distance (closest first)
//! assert!(results[0].1 <= results[1].1);
//! ```

pub mod hnsw;

pub use hnsw::HNSWIndex;

/// Cross-Encoder for reranking search results.
///
/// Takes (query, document) pairs and produces relevance scores.
/// Typically more accurate than bi-encoders, but slower at query time
/// because scores cannot be pre-computed per document.
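///
/// # Example
///
/// A minimal sketch with a hand-rolled dot-product scorer standing in for a
/// real cross-encoder model:
///
/// ```
/// use aprender::index::CrossEncoder;
///
/// let ce = CrossEncoder::new(|q: &[f32], d: &[f32]| {
///     q.iter().zip(d).map(|(&a, &b)| a * b).sum::<f32>()
/// });
///
/// let candidates = vec![("near", vec![1.0, 0.0]), ("far", vec![0.0, 1.0])];
/// let reranked = ce.rerank(&[1.0, 0.0], &candidates, 1);
/// assert_eq!(*reranked[0].0, "near");
/// ```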
#[derive(Debug, Clone)]
pub struct CrossEncoder<F> {
    score_fn: F,
}

impl<F> CrossEncoder<F>
where
    F: Fn(&[f32], &[f32]) -> f32,
{
    /// Create cross-encoder with custom scoring function.
    pub fn new(score_fn: F) -> Self {
        Self { score_fn }
    }

    /// Score a single (query, document) pair.
    pub fn score(&self, query: &[f32], document: &[f32]) -> f32 {
        (self.score_fn)(query, document)
    }

    /// Rerank candidates by cross-encoder score.
    pub fn rerank<'a, T>(
        &self,
        query: &[f32],
        candidates: &'a [(T, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(&'a T, f32)> {
        let mut scored: Vec<(&T, f32)> = candidates
            .iter()
            .map(|(id, doc)| (id, self.score(query, doc)))
            .collect();

        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(top_k);
        scored
    }
}

/// Default cross-encoder using cosine similarity.
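///
/// # Example
///
/// A quick check that parallel vectors score near 1.0:
///
/// ```
/// use aprender::index::default_cross_encoder;
///
/// let ce = default_cross_encoder();
/// // Parallel vectors have cosine similarity ~1.0.
/// assert!((ce.score(&[1.0, 0.0], &[2.0, 0.0]) - 1.0).abs() < 1e-5);
/// ```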
pub fn default_cross_encoder() -> CrossEncoder<impl Fn(&[f32], &[f32]) -> f32> {
    CrossEncoder::new(|q, d| {
        let dot: f32 = q.iter().zip(d).map(|(&a, &b)| a * b).sum();
        let nq: f32 = q.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let nd: f32 = d.iter().map(|&x| x * x).sum::<f32>().sqrt();
        dot / (nq * nd + 1e-10)
    })
}

/// Hybrid search combining dense (semantic) and sparse (lexical) retrieval.
#[derive(Debug, Clone)]
pub struct HybridSearch {
    /// Weight for dense (semantic) scores
    dense_weight: f32,
    /// Weight for sparse (lexical) scores
    sparse_weight: f32,
}

impl HybridSearch {
    /// Create hybrid search with dense/sparse weights.
    pub fn new(dense_weight: f32, sparse_weight: f32) -> Self {
        Self {
            dense_weight,
            sparse_weight,
        }
    }

    /// Fuse dense and sparse scores using a weighted linear combination.
    ///
    /// Each result list is max-normalized before weighting, so scores on
    /// different scales can be combined.
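    ///
    /// # Example
    ///
    /// A small sketch: an item at the top of both lists collects the full
    /// combined weight.
    ///
    /// ```
    /// use aprender::index::HybridSearch;
    ///
    /// let hs = HybridSearch::new(0.7, 0.3);
    /// let dense = vec![("a".to_string(), 0.9)];
    /// let sparse = vec![("a".to_string(), 2.0)];
    /// // Each list is max-normalized, so "a" receives 0.7 + 0.3 = 1.0.
    /// let fused = hs.fuse_scores(&dense, &sparse, 10);
    /// assert!((fused[0].1 - 1.0).abs() < 1e-6);
    /// ```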
    pub fn fuse_scores(
        &self,
        dense_results: &[(String, f32)],
        sparse_results: &[(String, f32)],
        top_k: usize,
    ) -> Vec<(String, f32)> {
        use std::collections::HashMap;

        let mut scores: HashMap<String, f32> = HashMap::new();

        // Normalize and add dense scores
        let dense_max = dense_results
            .iter()
            .map(|(_, s)| *s)
            .fold(0.0_f32, f32::max);
        for (id, score) in dense_results {
            let norm = if dense_max > 0.0 {
                score / dense_max
            } else {
                0.0
            };
            *scores.entry(id.clone()).or_insert(0.0) += self.dense_weight * norm;
        }

        // Normalize and add sparse scores
        let sparse_max = sparse_results
            .iter()
            .map(|(_, s)| *s)
            .fold(0.0_f32, f32::max);
        for (id, score) in sparse_results {
            let norm = if sparse_max > 0.0 {
                score / sparse_max
            } else {
                0.0
            };
            *scores.entry(id.clone()).or_insert(0.0) += self.sparse_weight * norm;
        }

        let mut results: Vec<_> = scores.into_iter().collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(top_k);
        results
    }

    /// Reciprocal Rank Fusion (RRF) for combining rankings.
    ///
    /// Each item receives `1 / (k + rank)` from every ranking it appears in
    /// (with 1-based ranks); `k` is a smoothing constant, commonly set to 60.
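    ///
    /// # Example
    ///
    /// A small sketch with two rankings that reverse each other:
    ///
    /// ```
    /// use aprender::index::HybridSearch;
    ///
    /// let hs = HybridSearch::default();
    /// let rankings = vec![
    ///     vec!["a".to_string(), "b".to_string()],
    ///     vec!["b".to_string(), "a".to_string()],
    /// ];
    /// // "a" and "b" each collect 1/61 + 1/62 and tie.
    /// let fused = hs.rrf_fuse(&rankings, 60.0, 2);
    /// assert_eq!(fused.len(), 2);
    /// ```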
    pub fn rrf_fuse(&self, rankings: &[Vec<String>], k: f32, top_n: usize) -> Vec<(String, f32)> {
        use std::collections::HashMap;

        let mut scores: HashMap<String, f32> = HashMap::new();

        for ranking in rankings {
            for (rank, id) in ranking.iter().enumerate() {
                *scores.entry(id.clone()).or_insert(0.0) += 1.0 / (k + rank as f32 + 1.0);
            }
        }

        let mut results: Vec<_> = scores.into_iter().collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(top_n);
        results
    }

    /// Weight applied to dense (semantic) scores.
    pub fn dense_weight(&self) -> f32 {
        self.dense_weight
    }

    /// Weight applied to sparse (lexical) scores.
    pub fn sparse_weight(&self) -> f32 {
        self.sparse_weight
    }
}

impl Default for HybridSearch {
    fn default() -> Self {
        Self::new(0.7, 0.3) // Default: 70% dense, 30% sparse
    }
}

/// Bi-Encoder for efficient dense retrieval.
///
/// Encodes queries and documents separately, allowing pre-computation
/// of document embeddings for fast retrieval.
///
/// # Architecture
///
/// ```text
/// Query    ─> Encoder ─> Query Embedding ─┐
///                                          ├─> Similarity Score
/// Document ─> Encoder ─> Doc Embedding   ─┘
/// ```
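///
/// # Example
///
/// A minimal sketch using an identity "encoder" as a stand-in for a real
/// embedding model:
///
/// ```
/// use aprender::index::{BiEncoder, SimilarityMetric};
///
/// let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Cosine);
///
/// let corpus = vec![("doc1", vec![1.0, 0.0]), ("doc2", vec![0.0, 1.0])];
/// let results = encoder.retrieve(&[1.0, 0.0], &corpus, 1);
/// assert_eq!(results[0].0, "doc1");
/// ```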
#[derive(Debug)]
pub struct BiEncoder<F> {
    encode_fn: F,
    similarity: SimilarityMetric,
}

/// Similarity metric for comparing embeddings.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SimilarityMetric {
    /// Cosine similarity: dot product of the vectors divided by their norms.
    Cosine,
    /// Raw dot product (sensitive to vector magnitude).
    DotProduct,
    /// Negative Euclidean distance, so that higher means more similar.
    Euclidean,
}

impl<F> BiEncoder<F>
where
    F: Fn(&[f32]) -> Vec<f32>,
{
    /// Create bi-encoder with custom encoding function.
    pub fn new(encode_fn: F, similarity: SimilarityMetric) -> Self {
        Self {
            encode_fn,
            similarity,
        }
    }

    /// Encode a single input.
    pub fn encode(&self, input: &[f32]) -> Vec<f32> {
        (self.encode_fn)(input)
    }

    /// Encode a batch of inputs.
    pub fn encode_batch(&self, inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
        inputs.iter().map(|x| self.encode(x)).collect()
    }

    /// Compute similarity between two embeddings.
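    ///
    /// For [`SimilarityMetric::Euclidean`], the distance is negated so that a
    /// higher score always means "more similar" across all three metrics. A
    /// quick sketch:
    ///
    /// ```
    /// use aprender::index::{BiEncoder, SimilarityMetric};
    ///
    /// let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Euclidean);
    /// // Distance between (0, 0) and (3, 4) is 5, returned as -5.
    /// assert!((encoder.similarity(&[0.0, 0.0], &[3.0, 4.0]) + 5.0).abs() < 1e-6);
    /// ```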
    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        match self.similarity {
            SimilarityMetric::Cosine => {
                let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
                let na: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
                let nb: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
                dot / (na * nb + 1e-10)
            }
            SimilarityMetric::DotProduct => a.iter().zip(b).map(|(&x, &y)| x * y).sum(),
            SimilarityMetric::Euclidean => {
                let dist_sq: f32 = a.iter().zip(b).map(|(&x, &y)| (x - y).powi(2)).sum();
                -dist_sq.sqrt() // Negative for sorting (higher = more similar)
            }
        }
    }

    /// Retrieve top-k most similar documents.
    pub fn retrieve<T: Clone>(
        &self,
        query: &[f32],
        corpus: &[(T, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(T, f32)> {
        let query_emb = self.encode(query);
        let mut scores: Vec<(T, f32)> = corpus
            .iter()
            .map(|(id, doc_emb)| (id.clone(), self.similarity(&query_emb, doc_emb)))
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(top_k);
        scores
    }
}

/// ColBERT-style late interaction retrieval.
///
/// Computes fine-grained token-level interactions between queries and documents,
/// using MaxSim (maximum similarity per query token).
///
/// # Architecture
///
/// ```text
/// Query Tokens  ─> [q1, q2, ..., qn]  ─┐
///                                       ├─> MaxSim aggregation ─> Score
/// Doc Tokens    ─> [d1, d2, ..., dm]  ─┘
/// ```
///
/// # Reference
///
/// Khattab, O., & Zaharia, M. (2020). ColBERT: Efficient and Effective Passage
/// Search via Contextualized Late Interaction over BERT. SIGIR.
#[derive(Debug)]
pub struct ColBERT {
    embedding_dim: usize,
}

impl ColBERT {
    /// Create ColBERT with specified embedding dimension.
    pub fn new(embedding_dim: usize) -> Self {
        Self { embedding_dim }
    }

    /// Compute MaxSim score between query and document token embeddings.
    ///
    /// For each query token, finds maximum similarity with any doc token,
    /// then sums across query tokens.
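    ///
    /// # Example
    ///
    /// A minimal sketch with one query token that exactly matches a doc token:
    ///
    /// ```
    /// use aprender::index::ColBERT;
    ///
    /// let colbert = ColBERT::new(2);
    /// let query = vec![vec![1.0, 0.0]];
    /// let doc = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    /// let score = colbert.maxsim(&query, &doc);
    /// assert!((score - 1.0).abs() < 1e-5);
    /// ```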
    pub fn maxsim(&self, query_tokens: &[Vec<f32>], doc_tokens: &[Vec<f32>]) -> f32 {
        if query_tokens.is_empty() || doc_tokens.is_empty() {
            return 0.0;
        }

        let mut total = 0.0_f32;

        for q in query_tokens {
            let max_sim = doc_tokens
                .iter()
                .map(|d| cosine_sim(q, d))
                .fold(f32::NEG_INFINITY, f32::max);
            total += max_sim;
        }

        total
    }

    /// Score a batch of documents against a query.
    pub fn score_documents(
        &self,
        query_tokens: &[Vec<f32>],
        documents: &[Vec<Vec<f32>>],
    ) -> Vec<f32> {
        documents
            .iter()
            .map(|doc| self.maxsim(query_tokens, doc))
            .collect()
    }

    /// Retrieve top-k documents using MaxSim.
    pub fn retrieve<T: Clone>(
        &self,
        query_tokens: &[Vec<f32>],
        corpus: &[(T, Vec<Vec<f32>>)],
        top_k: usize,
    ) -> Vec<(T, f32)> {
        let mut scores: Vec<(T, f32)> = corpus
            .iter()
            .map(|(id, doc_tokens)| (id.clone(), self.maxsim(query_tokens, doc_tokens)))
            .collect();

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(top_k);
        scores
    }

    /// Embedding dimension the index was configured with.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}

fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
    let na: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-10)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cross_encoder_score() {
        let ce = default_cross_encoder();
        let query = vec![1.0, 0.0];
        let doc1 = vec![1.0, 0.0];
        let doc2 = vec![0.0, 1.0];

        let score1 = ce.score(&query, &doc1);
        let score2 = ce.score(&query, &doc2);
        assert!(score1 > score2);
    }

    #[test]
    fn test_cross_encoder_rerank() {
        let ce = default_cross_encoder();
        let query = vec![1.0, 0.0];
        let candidates = vec![
            ("doc1", vec![0.0, 1.0]),
            ("doc2", vec![1.0, 0.0]),
            ("doc3", vec![0.5, 0.5]),
        ];

        let reranked = ce.rerank(&query, &candidates, 2);
        assert_eq!(reranked.len(), 2);
        assert_eq!(*reranked[0].0, "doc2");
    }

    #[test]
    fn test_hybrid_search_fuse() {
        let hs = HybridSearch::new(0.6, 0.4);
        let dense = vec![("a".to_string(), 0.9), ("b".to_string(), 0.5)];
        let sparse = vec![("b".to_string(), 1.0), ("c".to_string(), 0.7)];

        let fused = hs.fuse_scores(&dense, &sparse, 3);
        assert!(fused.len() <= 3);
    }

    #[test]
    fn test_hybrid_search_rrf() {
        let hs = HybridSearch::default();
        let rankings = vec![
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
            vec!["b".to_string(), "a".to_string(), "d".to_string()],
        ];

        let fused = hs.rrf_fuse(&rankings, 60.0, 3);
        assert_eq!(fused.len(), 3);
    }

    #[test]
    fn test_hybrid_search_default() {
        let hs = HybridSearch::default();
        assert!((hs.dense_weight() - 0.7).abs() < 1e-6);
        assert!((hs.sparse_weight() - 0.3).abs() < 1e-6);
    }

    // BiEncoder Tests

    #[test]
    fn test_bi_encoder_cosine() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Cosine);
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        let c = vec![0.0, 1.0];

        let sim_ab = encoder.similarity(&a, &b);
        let sim_ac = encoder.similarity(&a, &c);

        assert!((sim_ab - 1.0).abs() < 1e-6); // Same direction
        assert!(sim_ac.abs() < 1e-6); // Orthogonal
    }

    #[test]
    fn test_bi_encoder_dot_product() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::DotProduct);
        let a = vec![2.0, 3.0];
        let b = vec![1.0, 2.0];

        let sim = encoder.similarity(&a, &b);
        assert!((sim - 8.0).abs() < 1e-6); // 2*1 + 3*2 = 8
    }

    #[test]
    fn test_bi_encoder_euclidean() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Euclidean);
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];

        let sim = encoder.similarity(&a, &b);
        assert!((sim - (-5.0)).abs() < 1e-6); // -sqrt(9+16) = -5
    }

    #[test]
    fn test_bi_encoder_encode() {
        let encoder = BiEncoder::new(
            |x: &[f32]| x.iter().map(|&v| v * 2.0).collect(),
            SimilarityMetric::Cosine,
        );

        let input = vec![1.0, 2.0, 3.0];
        let encoded = encoder.encode(&input);

        assert_eq!(encoded, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_bi_encoder_encode_batch() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Cosine);
        let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let encoded = encoder.encode_batch(&inputs);
        assert_eq!(encoded.len(), 2);
    }

    #[test]
    fn test_bi_encoder_retrieve() {
        let encoder = BiEncoder::new(|x: &[f32]| x.to_vec(), SimilarityMetric::Cosine);
        let corpus = vec![
            ("doc1", vec![1.0, 0.0]),
            ("doc2", vec![0.0, 1.0]),
            ("doc3", vec![0.707, 0.707]),
        ];

        let query = vec![1.0, 0.0];
        let results = encoder.retrieve(&query, &corpus, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "doc1"); // Exact match
    }

493
494    #[test]
495    fn test_colbert_creation() {
496        let colbert = ColBERT::new(128);
497        assert_eq!(colbert.embedding_dim(), 128);
498    }
499
500    #[test]
501    fn test_colbert_maxsim_identical() {
502        let colbert = ColBERT::new(4);
503        let query = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
504        let doc = query.clone();
505
506        let score = colbert.maxsim(&query, &doc);
507        assert!((score - 2.0).abs() < 1e-5); // 1.0 + 1.0
508    }
509
510    #[test]
511    fn test_colbert_maxsim_different() {
512        let colbert = ColBERT::new(4);
513        let query = vec![vec![1.0, 0.0, 0.0, 0.0]];
514        let doc = vec![vec![0.0, 1.0, 0.0, 0.0]];
515
516        let score = colbert.maxsim(&query, &doc);
517        assert!(score.abs() < 1e-5); // Orthogonal
518    }
519
520    #[test]
521    fn test_colbert_maxsim_empty() {
522        let colbert = ColBERT::new(4);
523        let empty: Vec<Vec<f32>> = vec![];
524        let doc = vec![vec![1.0, 0.0, 0.0, 0.0]];
525
526        assert_eq!(colbert.maxsim(&empty, &doc), 0.0);
527        assert_eq!(colbert.maxsim(&doc, &empty), 0.0);
528    }
529
530    #[test]
531    fn test_colbert_score_documents() {
532        let colbert = ColBERT::new(2);
533        let query = vec![vec![1.0, 0.0]];
534        let docs = vec![vec![vec![1.0, 0.0]], vec![vec![0.0, 1.0]]];
535
536        let scores = colbert.score_documents(&query, &docs);
537        assert_eq!(scores.len(), 2);
538        assert!(scores[0] > scores[1]); // First doc matches better
539    }
540
541    #[test]
542    fn test_colbert_retrieve() {
543        let colbert = ColBERT::new(2);
544        let query = vec![vec![1.0, 0.0], vec![0.707, 0.707]];
545        let corpus = vec![
546            ("doc1", vec![vec![1.0, 0.0], vec![0.0, 1.0]]),
547            ("doc2", vec![vec![0.0, 1.0]]),
548        ];
549
550        let results = colbert.retrieve(&query, &corpus, 2);
551        assert_eq!(results.len(), 2);
552    }
553
554    #[test]
555    fn test_similarity_metric_equality() {
556        assert_eq!(SimilarityMetric::Cosine, SimilarityMetric::Cosine);
557        assert_ne!(SimilarityMetric::Cosine, SimilarityMetric::DotProduct);
558    }
559}