use std::collections::HashMap;
use rankops::{mmr_embeddings, MmrConfig};
use vicinity::hnsw::HNSWIndex;
/// Dimensionality of every vector stored in the index and of the query vector.
const DIM: usize = 16;
/// Number of topic direction vectors (clusters) generated in `main`.
const N_TOPICS: usize = 4;
/// Number of perturbed vectors generated per topic; ids are laid out as
/// `topic * PER_TOPIC + dup`.
const PER_TOPIC: usize = 5;
/// Demonstrates a two-stage retrieval pipeline on synthetic data:
///
/// 1. Build an HNSW index over `N_TOPICS * PER_TOPIC` unit vectors, where each
///    topic occupies its own block of four dimensions plus a small
///    deterministic perturbation, and retrieve nearest neighbours of a query
///    that is equally aligned with every topic.
/// 2. Rerank those candidates with MMR and compare how many distinct topics
///    appear in the first `N_TOPICS` picks of each ordering.
///
/// # Errors
///
/// Propagates any error returned by index construction, insertion, build or
/// search.
///
/// # Panics
///
/// Panics if a generated id does not fit in `u32`, or if the search returns an
/// id that was never inserted (both would be bugs in this example).
fn main() -> vicinity::Result<()> {
    // Ids are assigned as `topic * PER_TOPIC + dup`, so the topic of an id is
    // recovered by dividing by PER_TOPIC. Deriving it from the constant keeps
    // the printed topics correct if PER_TOPIC is ever changed.
    let per_topic = u32::try_from(PER_TOPIC).expect("PER_TOPIC fits in u32");
    let topic_of = |id: u32| id / per_topic;

    // One direction per topic: each topic owns a disjoint block of 4 dimensions.
    let topic_dirs: [[f32; DIM]; N_TOPICS] = [
        [
            1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ],
        [
            0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ],
        [
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0,
        ],
        [
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0,
        ],
    ];

    // Keep a copy of every vector keyed by id: the MMR stage needs the
    // embeddings themselves, not just the ids returned by the search.
    let mut vectors: HashMap<u32, Vec<f32>> = HashMap::new();
    // NOTE(review): the two numeric arguments are presumably HNSW graph
    // connectivity parameters — confirm against the vicinity documentation.
    let mut index = HNSWIndex::new(DIM, 8, 16)?;
    for (t, dir) in topic_dirs.iter().enumerate() {
        for dup in 0..PER_TOPIC {
            let id = u32::try_from(t * PER_TOPIC + dup).expect("generated id fits in u32");
            // Small deterministic perturbation so vectors within a topic are
            // near-duplicates rather than identical.
            let v: Vec<f32> = dir
                .iter()
                .enumerate()
                .map(|(d, &val)| val + 0.02 * ((dup * 7 + d * 5) as f32 * 1.13).sin())
                .collect();
            let v = normalize(&v);
            index.add_slice(id, &v)?;
            vectors.insert(id, v);
        }
    }
    index.build()?;

    // The all-ones query is equally aligned with every topic direction.
    let query = normalize(&[1.0; DIM]);
    // NOTE(review): arguments are presumably (k = 20, ef = 50) — confirm.
    let raw = index.search(&query, 20, 50)?;
    println!("--- Stage 1: HNSW retrieval (top 12 of 20) ---");
    for (id, dist) in raw.iter().take(12) {
        println!(" id={:2} topic={} distance={:.4}", id, topic_of(*id), dist);
    }

    // Convert distances into relevance scores (larger is better). The 0.01
    // offset keeps the farthest candidate's relevance strictly positive.
    let max_dist = raw.iter().map(|(_, d)| *d).fold(0.0_f32, f32::max);
    let candidates: Vec<(u32, f32, Vec<f32>)> = raw
        .iter()
        .map(|&(id, dist)| {
            let relevance = (max_dist - dist) + 0.01;
            // Every id returned by the search was inserted into `vectors` above.
            (id, relevance, vectors[&id].clone())
        })
        .collect();

    let config = MmrConfig::new(0.5).with_top_k(8);
    let reranked = mmr_embeddings(&candidates, config);
    println!("\n--- Stage 2: MMR reranked (top 8, lambda=0.5) ---");
    for (id, score) in &reranked {
        println!(" id={:2} topic={} mmr_score={:.4}", id, topic_of(*id), score);
    }

    // Compare topic diversity among the first N_TOPICS results of each stage.
    let n = N_TOPICS;
    let raw_topics: Vec<u32> = raw.iter().take(n).map(|(id, _)| topic_of(*id)).collect();
    let mmr_topics: Vec<u32> = reranked
        .iter()
        .take(n)
        .map(|(id, _)| topic_of(*id))
        .collect();
    let raw_uniq = unique_count(&raw_topics);
    let mmr_uniq = unique_count(&mmr_topics);
    println!("\n--- Topic coverage (first {n} picks) ---");
    println!(" HNSW : {raw_uniq} unique topics {raw_topics:?}");
    println!(" MMR : {mmr_uniq} unique topics {mmr_topics:?}");
    if mmr_uniq > raw_uniq {
        println!(" -> MMR spread results across more topics.");
    } else if mmr_uniq == raw_uniq {
        println!(" -> Equal coverage (topics were already interleaved by distance).");
    } else {
        // Previously reported as "Equal coverage", which was misleading.
        println!(" -> MMR covered fewer topics than the raw HNSW order.");
    }
    Ok(())
}
/// Returns a copy of `v` scaled to unit Euclidean length.
///
/// When the Euclidean norm is below `1e-12` (for example an all-zero or empty
/// input) the components are returned unchanged, avoiding a division by a
/// near-zero value.
fn normalize(v: &[f32]) -> Vec<f32> {
    // Euclidean length: square root of the running sum of squared components.
    let length = v.iter().fold(0.0_f32, |acc, &c| acc + c * c).sqrt();

    let mut unit = v.to_vec();
    if length < 1e-12 {
        return unit;
    }
    for component in unit.iter_mut() {
        *component /= length;
    }
    unit
}
/// Counts how many distinct values occur in `items`.
///
/// The input slice is not modified; a sorted copy is inspected instead.
fn unique_count(items: &[u32]) -> usize {
    if items.is_empty() {
        return 0;
    }
    let mut ordered: Vec<u32> = items.iter().copied().collect();
    ordered.sort_unstable();
    // After sorting, equal values are adjacent, so every position where two
    // neighbours differ starts a new distinct value.
    let boundaries = ordered
        .windows(2)
        .filter(|pair| pair[0] != pair[1])
        .count();
    boundaries + 1
}