pub mod hnsw;
pub use hnsw::HNSWIndex;
#[derive(Debug, Clone)]
pub struct CrossEncoder<F> {
score_fn: F,
}
impl<F> CrossEncoder<F>
where
F: Fn(&[f32], &[f32]) -> f32,
{
pub fn new(score_fn: F) -> Self {
Self { score_fn }
}
pub fn score(&self, query: &[f32], document: &[f32]) -> f32 {
(self.score_fn)(query, document)
}
pub fn rerank<'a, T>(
&self,
query: &[f32],
candidates: &'a [(T, Vec<f32>)],
top_k: usize,
) -> Vec<(&'a T, f32)> {
let mut scored: Vec<(&T, f32)> = candidates
.iter()
.map(|(id, doc)| (id, self.score(query, doc)))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
scored
}
}
#[must_use]
pub fn default_cross_encoder() -> CrossEncoder<impl Fn(&[f32], &[f32]) -> f32> {
CrossEncoder::new(|q, d| {
let dot: f32 = q.iter().zip(d).map(|(&a, &b)| a * b).sum();
let nq: f32 = q.iter().map(|&x| x * x).sum::<f32>().sqrt();
let nd: f32 = d.iter().map(|&x| x * x).sum::<f32>().sqrt();
dot / (nq * nd + 1e-10)
})
}
/// Fuses scores from dense (vector) and sparse (lexical) retrieval.
#[derive(Debug, Clone)]
pub struct HybridSearch {
    dense_weight: f32,
    sparse_weight: f32,
}
impl HybridSearch {
    /// Creates a fuser with the given weights for dense and sparse scores.
    #[must_use]
    pub fn new(dense_weight: f32, sparse_weight: f32) -> Self {
        Self {
            dense_weight,
            sparse_weight,
        }
    }
    /// Adds `weight * (score / max_score)` into `acc` for every result.
    /// When the list's max is not positive, entries still get created but
    /// contribute 0.0 (matches max-normalization of non-positive lists).
    fn accumulate(
        acc: &mut std::collections::HashMap<String, f32>,
        results: &[(String, f32)],
        weight: f32,
    ) {
        let max_score = results.iter().map(|&(_, s)| s).fold(0.0_f32, f32::max);
        for (id, raw) in results {
            let normalized = if max_score > 0.0 { raw / max_score } else { 0.0 };
            *acc.entry(id.clone()).or_insert(0.0) += weight * normalized;
        }
    }
    /// Max-normalizes each result list to [0, 1], combines the normalized
    /// scores with the configured weights, and returns the `top_k` ids by
    /// fused score, best first. NaN scores compare as equal.
    pub fn fuse_scores(
        &self,
        dense_results: &[(String, f32)],
        sparse_results: &[(String, f32)],
        top_k: usize,
    ) -> Vec<(String, f32)> {
        let mut fused = std::collections::HashMap::new();
        Self::accumulate(&mut fused, dense_results, self.dense_weight);
        Self::accumulate(&mut fused, sparse_results, self.sparse_weight);
        let mut ranked: Vec<_> = fused.into_iter().collect();
        ranked.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(top_k);
        ranked
    }
    /// Reciprocal-rank fusion: each id gains `1 / (k + rank + 1)` for every
    /// ranking it appears in (rank is zero-based), then the `top_n` ids are
    /// returned by total score, best first.
    #[must_use]
    pub fn rrf_fuse(&self, rankings: &[Vec<String>], k: f32, top_n: usize) -> Vec<(String, f32)> {
        let mut fused: std::collections::HashMap<String, f32> = std::collections::HashMap::new();
        for ranking in rankings {
            for (position, id) in ranking.iter().enumerate() {
                let contribution = (k + position as f32 + 1.0).recip();
                *fused.entry(id.clone()).or_insert(0.0) += contribution;
            }
        }
        let mut ranked: Vec<_> = fused.into_iter().collect();
        ranked.sort_by(|x, y| y.1.partial_cmp(&x.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(top_n);
        ranked
    }
    /// Weight applied to normalized dense scores.
    #[must_use]
    pub fn dense_weight(&self) -> f32 {
        self.dense_weight
    }
    /// Weight applied to normalized sparse scores.
    #[must_use]
    pub fn sparse_weight(&self) -> f32 {
        self.sparse_weight
    }
}
impl Default for HybridSearch {
    /// 70% dense / 30% sparse weighting.
    fn default() -> Self {
        Self::new(0.7, 0.3)
    }
}
/// Encodes inputs into embeddings independently, then compares them with a
/// configurable similarity metric.
///
/// `Clone` is derived (conditional on `F: Clone`) for consistency with
/// [`CrossEncoder`].
#[derive(Debug, Clone)]
pub struct BiEncoder<F> {
    encode_fn: F,
    similarity: SimilarityMetric,
}
/// Similarity measure used when comparing two embeddings.
///
/// Fieldless `Copy` enum, so `Eq` and `Hash` are derived alongside
/// `PartialEq` to allow use as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimilarityMetric {
    /// Angle-based similarity; magnitude-invariant.
    Cosine,
    /// Raw inner product; sensitive to vector magnitude.
    DotProduct,
    /// Negated Euclidean distance, so larger still means "more similar".
    Euclidean,
}
impl<F> BiEncoder<F>
where
    F: Fn(&[f32]) -> Vec<f32>,
{
    /// Creates a bi-encoder from an encoding function and a metric.
    pub fn new(encode_fn: F, similarity: SimilarityMetric) -> Self {
        Self {
            encode_fn,
            similarity,
        }
    }
    /// Encodes a single input vector into an embedding.
    pub fn encode(&self, input: &[f32]) -> Vec<f32> {
        (self.encode_fn)(input)
    }
    /// Encodes a batch of inputs, one embedding per input.
    pub fn encode_batch(&self, inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
        inputs.iter().map(|x| self.encode(x)).collect()
    }
    /// Compares two embeddings with the configured metric. Cosine uses a
    /// 1e-10 epsilon in the denominator to avoid division by zero on
    /// zero-norm vectors.
    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        match self.similarity {
            SimilarityMetric::Cosine => {
                let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
                let na: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
                let nb: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
                dot / (na * nb + 1e-10)
            }
            SimilarityMetric::DotProduct => a.iter().zip(b).map(|(&x, &y)| x * y).sum(),
            SimilarityMetric::Euclidean => {
                let dist_sq: f32 = a.iter().zip(b).map(|(&x, &y)| (x - y).powi(2)).sum();
                -dist_sq.sqrt()
            }
        }
    }
    /// Encodes `query`, scores it against every pre-embedded corpus entry,
    /// and returns the `top_k` matches, best first. NaN scores compare as
    /// equal.
    pub fn retrieve<T: Clone>(
        &self,
        query: &[f32],
        corpus: &[(T, Vec<f32>)],
        top_k: usize,
    ) -> Vec<(T, f32)> {
        let query_emb = self.encode(query);
        let mut scores: Vec<(T, f32)> = corpus
            .iter()
            .map(|(id, doc_emb)| (id.clone(), self.similarity(&query_emb, doc_emb)))
            .collect();
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(top_k);
        scores
    }
}
/// ColBERT-style late-interaction scorer over per-token embeddings.
#[derive(Debug)]
pub struct ColBERT {
    embedding_dim: usize,
}
impl ColBERT {
    /// Creates a scorer configured for `embedding_dim`-dimensional tokens.
    #[must_use]
    pub fn new(embedding_dim: usize) -> Self {
        Self { embedding_dim }
    }
    /// MaxSim score: for each query token, take its best cosine similarity
    /// over all document tokens, and sum those maxima. Returns 0.0 when
    /// either side has no tokens.
    pub fn maxsim(&self, query_tokens: &[Vec<f32>], doc_tokens: &[Vec<f32>]) -> f32 {
        if query_tokens.is_empty() || doc_tokens.is_empty() {
            return 0.0;
        }
        query_tokens
            .iter()
            .map(|q| {
                doc_tokens
                    .iter()
                    .map(|d| cosine_sim(q, d))
                    .fold(f32::NEG_INFINITY, f32::max)
            })
            .sum()
    }
    /// Scores each document's token matrix against the query tokens,
    /// preserving document order.
    #[must_use]
    pub fn score_documents(
        &self,
        query_tokens: &[Vec<f32>],
        documents: &[Vec<Vec<f32>>],
    ) -> Vec<f32> {
        let mut out = Vec::with_capacity(documents.len());
        for doc in documents {
            out.push(self.maxsim(query_tokens, doc));
        }
        out
    }
    /// Scores the whole corpus and returns the `top_k` entries by MaxSim,
    /// best first. NaN scores compare as equal.
    pub fn retrieve<T: Clone>(
        &self,
        query_tokens: &[Vec<f32>],
        corpus: &[(T, Vec<Vec<f32>>)],
        top_k: usize,
    ) -> Vec<(T, f32)> {
        let mut ranked = Vec::with_capacity(corpus.len());
        for (id, doc_tokens) in corpus {
            ranked.push((id.clone(), self.maxsim(query_tokens, doc_tokens)));
        }
        ranked.sort_by(|lhs, rhs| {
            rhs.1
                .partial_cmp(&lhs.1)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        ranked.truncate(top_k);
        ranked
    }
    /// Dimensionality this scorer was configured with.
    #[must_use]
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}
/// Cosine similarity with a 1e-10 epsilon in the denominator to avoid
/// division by zero on zero-norm inputs.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0_f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
    }
    let norm_a = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b + 1e-10)
}
#[cfg(test)]
#[path = "index_tests.rs"]
mod tests;