use std::collections::HashMap;
use vicinity::hnsw::HNSWIndex;
/// One corpus entry: static text metadata plus the dense vector used for
/// HNSW similarity search. Embeddings are unit-normalized by `build_corpus`.
struct Document {
    // Short human-readable label, printed in result listings.
    title: &'static str,
    // Full document text; this is what BM25 tokenizes and scores.
    text: &'static str,
    // Dense embedding (16-dimensional in this demo), normalized to unit length.
    embedding: Vec<f32>,
}
/// Assemble the demo corpus: twelve short documents, each paired with a
/// hand-crafted 16-dimensional embedding that is unit-normalized before
/// being stored in the returned `Document`s.
fn build_corpus() -> Vec<Document> {
    // (title, text, raw embedding) triples; the raw vectors are normalized below.
    let entries: [(&'static str, &'static str, [f32; 16]); 12] = [
        (
            "Rust basics",
            "Rust is a systems programming language focused on safety and performance",
            [1.0, 0.8, 0.3, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "C++ overview",
            "C++ is a fast systems language used for operating systems and game engines",
            [0.9, 0.7, 0.2, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "Python intro",
            "Python is a high-level language popular for scripting and data science",
            [0.1, 0.1, 0.0, 0.0, 0.9, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "Go concurrency",
            "Go provides goroutines and channels for concurrent systems programming",
            [0.7, 0.5, 0.4, 0.2, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "JavaScript web",
            "JavaScript powers interactive web applications and runs in the browser",
            [0.0, 0.0, 0.0, 0.0, 0.2, 0.3, 0.0, 0.0, 0.9, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "Rust async",
            "Async Rust uses futures and tokio for high-performance network services",
            [0.9, 0.7, 0.5, 0.3, 0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "Database systems",
            "Relational databases use SQL for querying structured data with ACID guarantees",
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.9, 0.3, 0.0, 0.0, 0.0],
        ),
        (
            "Machine learning",
            "Neural networks learn patterns from data using gradient descent optimization",
            [0.0, 0.0, 0.0, 0.0, 0.3, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.9, 0.2, 0.0],
        ),
        (
            "Operating systems",
            "An OS kernel manages hardware resources and provides fast system call interfaces",
            [0.6, 0.5, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.8],
        ),
        (
            "Rust memory",
            "Rust ownership and borrowing prevent memory bugs without garbage collection",
            [0.9, 0.9, 0.4, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0],
        ),
        (
            "WebAssembly",
            "WebAssembly compiles systems languages to run fast portable code in browsers",
            [0.5, 0.3, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "Garbage collection",
            "Tracing garbage collectors reclaim memory automatically but add latency pauses",
            [0.2, 0.2, 0.0, 0.0, 0.3, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.6],
        ),
    ];
    let mut corpus = Vec::with_capacity(entries.len());
    for (title, text, raw) in entries {
        corpus.push(Document {
            title,
            text,
            embedding: normalize(&raw),
        });
    }
    corpus
}
/// Rank `docs` against `query` using Okapi BM25 (k1 = 1.2, b = 0.75).
///
/// Returns `(doc_index, score)` pairs for every document with a positive
/// score, sorted by descending score. Uses the "+1" idf variant
/// (ln(((N - df + 0.5) / (df + 0.5)) + 1)), which keeps idf positive even
/// for terms that appear in most documents.
fn bm25_rank(docs: &[Document], query: &str) -> Vec<(usize, f64)> {
    const K1: f64 = 1.2;
    const B: f64 = 0.75;
    // Guard empty corpus: otherwise n == 0 makes avg_dl NaN below.
    if docs.is_empty() {
        return Vec::new();
    }
    let n = docs.len() as f64;
    // Deduplicate query terms: BM25 sums one contribution per unique term,
    // so a repeated query word must not be counted twice.
    let mut query_terms: Vec<String> = tokenize(query);
    query_terms.sort();
    query_terms.dedup();
    let doc_tokens: Vec<Vec<String>> = docs.iter().map(|d| tokenize(d.text)).collect();
    // Average document length, used for length normalization.
    let avg_dl: f64 = doc_tokens.iter().map(|t| t.len() as f64).sum::<f64>() / n;
    // Document frequency: how many documents contain each query term.
    let mut df: HashMap<&str, usize> = HashMap::new();
    for qt in &query_terms {
        let count = doc_tokens
            .iter()
            .filter(|tokens| tokens.iter().any(|t| t == qt))
            .count();
        df.insert(qt.as_str(), count);
    }
    let mut scores: Vec<(usize, f64)> = docs
        .iter()
        .enumerate()
        .map(|(i, _)| {
            let dl = doc_tokens[i].len() as f64;
            let mut score = 0.0_f64;
            for qt in &query_terms {
                let tf = doc_tokens[i].iter().filter(|t| *t == qt).count() as f64;
                let doc_freq = *df.get(qt.as_str()).unwrap_or(&0) as f64;
                // A term absent from the whole corpus contributes nothing.
                if doc_freq == 0.0 {
                    continue;
                }
                let idf = ((n - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
                // Saturating, length-normalized term frequency.
                let tf_norm = (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * dl / avg_dl));
                score += idf * tf_norm;
            }
            (i, score)
        })
        .filter(|(_, s)| *s > 0.0)
        .collect();
    // Descending by score; NaN cannot occur here but Equal is a safe fallback.
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scores
}
/// Lowercase `text` and break it into maximal runs of alphanumeric
/// characters; punctuation and whitespace act as separators and empty
/// pieces are dropped.
fn tokenize(text: &str) -> Vec<String> {
    // Lowercase the whole string first so context-sensitive case mappings
    // behave exactly as `str::to_lowercase` defines them.
    let lowered = text.to_lowercase();
    let mut tokens: Vec<String> = Vec::new();
    let mut word = String::new();
    for ch in lowered.chars() {
        if ch.is_alphanumeric() {
            word.push(ch);
        } else if !word.is_empty() {
            tokens.push(std::mem::take(&mut word));
        }
    }
    if !word.is_empty() {
        tokens.push(word);
    }
    tokens
}
/// Fuse several ranked lists with Reciprocal Rank Fusion.
///
/// Every appearance of a document at 1-based position `p` in a list adds
/// `1 / (k + p)` to its fused score; the input scores themselves are
/// ignored. The result is sorted by descending fused score.
fn reciprocal_rank_fusion(lists: &[&[(usize, f64)]], k: f64) -> Vec<(usize, f64)> {
    let mut totals: HashMap<usize, f64> = HashMap::new();
    for ranking in lists.iter() {
        let mut position = 1.0_f64;
        for &(doc_id, _ignored_score) in ranking.iter() {
            *totals.entry(doc_id).or_insert(0.0) += (k + position).recip();
            position += 1.0;
        }
    }
    let mut fused: Vec<(usize, f64)> = totals.into_iter().collect();
    fused.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    fused
}
/// Scale `v` to unit Euclidean length. Vectors whose norm is (near) zero
/// are returned unchanged so we never divide by zero.
fn normalize(v: &[f32]) -> Vec<f32> {
    let mut sum_sq = 0.0_f32;
    for &x in v {
        sum_sq += x * x;
    }
    let norm = sum_sq.sqrt();
    if norm < 1e-12 {
        return v.to_vec();
    }
    let mut scaled = Vec::with_capacity(v.len());
    for &x in v {
        scaled.push(x / norm);
    }
    scaled
}
/// Demo driver: run BM25 lexical search and HNSW vector search over the
/// same corpus, print both rankings, then fuse them with Reciprocal Rank
/// Fusion (RRF, k = 60) and print the hybrid ranking.
fn main() -> vicinity::Result<()> {
    let corpus = build_corpus();
    // Derive the dimensionality from the data instead of hard-coding 16,
    // so the index always matches the embeddings.
    let dim = corpus[0].embedding.len();
    // NOTE(review): assuming the trailing arguments are (M, ef_construction)
    // per the usual HNSW parameterization -- confirm against vicinity docs.
    let mut index = HNSWIndex::new(dim, 8, 16)?;
    for (i, doc) in corpus.iter().enumerate() {
        index.add_slice(i as u32, &doc.embedding)?;
    }
    index.build()?;
    let text_query = "fast systems programming language";
    let vector_query = normalize(&[
        0.9, 0.7, 0.3, 0.1, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.1,
    ]);
    let bm25_results = bm25_rank(&corpus, text_query);
    println!("=== BM25 results for \"{text_query}\" ===");
    for (doc_id, score) in &bm25_results {
        println!(" [{:2}] {:.4} {}", doc_id, score, corpus[*doc_id].title);
    }
    // Retrieve every document so the fusion step sees a complete vector ranking.
    let k_neighbors = corpus.len();
    let hnsw_raw = index.search(&vector_query, k_neighbors, 50)?;
    // Convert returned distances to similarities via 1 - dist.
    // NOTE(review): this assumes a cosine-style distance -- verify the
    // index's distance metric.
    let hnsw_results: Vec<(usize, f64)> = hnsw_raw
        .iter()
        .map(|&(id, dist)| (id as usize, (1.0 - dist) as f64))
        .collect();
    println!("\n=== HNSW vector results ===");
    for (doc_id, sim) in &hnsw_results {
        println!(" [{:2}] {:.4} {}", doc_id, sim, corpus[*doc_id].title);
    }
    // Pass slices of the existing result vectors directly; the previous
    // full clones of both vectors were unnecessary allocations.
    let fused = reciprocal_rank_fusion(&[bm25_results.as_slice(), hnsw_results.as_slice()], 60.0);
    println!("\n=== Hybrid results (RRF, k=60) ===");
    for (doc_id, rrf_score) in &fused {
        // Recover each document's 1-based position in the source rankings
        // for display; "-" marks a list the document did not appear in.
        let bm25_rank = bm25_results
            .iter()
            .position(|(id, _)| id == doc_id)
            .map(|r| format!("#{}", r + 1));
        let hnsw_rank = hnsw_results
            .iter()
            .position(|(id, _)| id == doc_id)
            .map(|r| format!("#{}", r + 1));
        println!(
            " [{:2}] {:.6} {:20} bm25={} vec={}",
            doc_id,
            rrf_score,
            corpus[*doc_id].title,
            bm25_rank.unwrap_or_else(|| "-".into()),
            hnsw_rank.unwrap_or_else(|| "-".into()),
        );
    }
    println!("\nRRF promotes documents that rank well in BOTH text and vector search.");
    println!(
        "Top result: \"{}\" -- strong match on both lexical and semantic axes.",
        corpus[fused[0].0].title
    );
    Ok(())
}