use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use ndarray::{Array2, ArrayView1};
use serde::{Deserialize, Serialize};
use crate::backend::{self, BackendConfig, BackendIndex};
use crate::hnsw::search::SearchParams;
use crate::index::DistanceMetric;
/// Per-token label stored in the `labels.json` sidecar.
///
/// Entry `i` of the label list describes row `i` of the flattened embedding
/// matrix written by `MultiVectorBuilder::build`; searches assume the backend
/// reports row `i` as id `i` (NOTE(review): confirm against the backend).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenLabel {
/// Caller-supplied id of the document this token belongs to.
pub doc_id: u32,
/// Position of this token within its document's embedding matrix (0-based).
pub seq_id: u32,
/// Document-level metadata, duplicated onto every token of the document.
/// A missing `metadata` key deserializes to an empty map.
#[serde(default)]
pub metadata: HashMap<String, serde_json::Value>,
}
/// A document queued by `MultiVectorBuilder::insert` and not yet written to disk.
struct PendingDoc {
/// Caller-supplied document id.
doc_id: u32,
/// One row per token; the column count was checked against the builder's `dim` on insert.
embeddings: Array2<f32>,
/// Metadata copied onto every token label at build time.
metadata: HashMap<String, serde_json::Value>,
}
/// Accumulates multi-vector documents in memory and, on `build`, writes
/// `<path>.index`, `<path>.labels.json` and `<path>.emb.npy`.
pub struct MultiVectorBuilder {
/// Required width of every inserted token embedding.
dim: usize,
/// Documents inserted so far, in insertion order (this order fixes the row layout).
pending: Vec<PendingDoc>,
/// Configuration handed to `backend::build_backend`.
backend_config: BackendConfig,
}
impl MultiVectorBuilder {
pub fn new(dim: usize) -> Self {
let mut config = BackendConfig::hnsw_default();
config.set_distance_metric(DistanceMetric::Mips);
config.set_recompute(false);
config.set_compact(false);
Self {
dim,
pending: Vec::new(),
backend_config: config,
}
}
pub fn set_m(&mut self, m: usize) -> &mut Self {
self.backend_config.set_m(m);
self
}
pub fn set_ef_construction(&mut self, ef: usize) -> &mut Self {
self.backend_config.set_ef_construction(ef);
self
}
pub fn insert(
&mut self,
doc_id: u32,
embeddings: Array2<f32>,
metadata: HashMap<String, serde_json::Value>,
) -> &mut Self {
assert_eq!(
embeddings.ncols(),
self.dim,
"embedding dim {} != expected {}",
embeddings.ncols(),
self.dim
);
self.pending.push(PendingDoc {
doc_id,
embeddings,
metadata,
});
self
}
pub fn build(&self, index_path: &Path) -> Result<()> {
anyhow::ensure!(!self.pending.is_empty(), "no documents inserted");
let total_tokens: usize = self.pending.iter().map(|d| d.embeddings.nrows()).sum();
let mut flat = Array2::<f32>::zeros((total_tokens, self.dim));
let mut labels = Vec::with_capacity(total_tokens);
let mut row = 0;
for doc in &self.pending {
for seq_id in 0..doc.embeddings.nrows() {
flat.row_mut(row).assign(&doc.embeddings.row(seq_id));
labels.push(TokenLabel {
doc_id: doc.doc_id,
seq_id: seq_id as u32,
metadata: doc.metadata.clone(),
});
row += 1;
}
}
let index_file = with_ext(index_path, "index");
backend::build_backend(&self.backend_config, &flat, &index_file, None)?;
let labels_file = with_ext(index_path, "labels.json");
let labels_json = serde_json::to_string(&labels)?;
fs::write(&labels_file, labels_json)
.with_context(|| format!("writing {}", labels_file.display()))?;
let npy_file = with_ext(index_path, "emb.npy");
write_npy(&flat, &npy_file)?;
Ok(())
}
}
/// Read-only handle over an index produced by `MultiVectorBuilder::build`.
pub struct MultiVectorSearcher {
/// ANN index over every token embedding, loaded from `<path>.index`.
index: BackendIndex,
/// Per-row token labels loaded from `<path>.labels.json`.
labels: Vec<TokenLabel>,
/// Document id -> row indices (into `labels` and the npy matrix) of its tokens.
doc_to_rows: HashMap<u32, Vec<usize>>,
/// Memory-mapped bytes of `<path>.emb.npy`, header included.
#[cfg(feature = "multi-vector")]
emb_mmap: memmap2::Mmap,
/// Fully-read bytes of `<path>.emb.npy`, header included (used when mmap support is compiled out).
#[cfg(not(feature = "multi-vector"))]
emb_data: Vec<u8>,
/// Embedding width as reported by the backend index; also used as the npy row stride.
dim: usize,
/// Number of token labels (rows) in the index.
total_tokens: usize,
}
impl MultiVectorSearcher {
pub fn open(index_path: &Path) -> Result<Self> {
let index_file = with_ext(index_path, "index");
let index = backend::read_backend_index("hnsw", &index_file)?;
let labels_file = with_ext(index_path, "labels.json");
let labels_data = fs::read_to_string(&labels_file)
.with_context(|| format!("reading {}", labels_file.display()))?;
let labels: Vec<TokenLabel> = serde_json::from_str(&labels_data)?;
let mut doc_to_rows: HashMap<u32, Vec<usize>> = HashMap::new();
for (i, label) in labels.iter().enumerate() {
doc_to_rows.entry(label.doc_id).or_default().push(i);
}
let dim = index.dimensions();
let total_tokens = labels.len();
let npy_file = with_ext(index_path, "emb.npy");
#[cfg(feature = "multi-vector")]
let emb_mmap = {
let file = fs::File::open(&npy_file)
.with_context(|| format!("opening {}", npy_file.display()))?;
unsafe { memmap2::Mmap::map(&file)? }
};
Ok(Self {
index,
labels,
doc_to_rows,
#[cfg(feature = "multi-vector")]
emb_mmap,
#[cfg(not(feature = "multi-vector"))]
emb_data: fs::read(&npy_file)?,
dim,
total_tokens,
})
}
pub fn num_docs(&self) -> usize {
self.doc_to_rows.len()
}
pub fn num_tokens(&self) -> usize {
self.total_tokens
}
pub fn search(
&self,
query_tokens: &Array2<f32>,
top_k: usize,
) -> Result<Vec<MultiVectorResult>> {
self.search_with_params(query_tokens, top_k, 50)
}
pub fn search_with_params(
&self,
query_tokens: &Array2<f32>,
top_k: usize,
per_token_k: usize,
) -> Result<Vec<MultiVectorResult>> {
let params = SearchParams::default();
let mut doc_scores: HashMap<u32, f32> = HashMap::new();
for qi in 0..query_tokens.nrows() {
let query_vec = query_tokens.row(qi);
let query_slice = query_vec.as_slice().unwrap();
let (labels_idx, distances) =
backend::search_backend(&self.index, query_slice, per_token_k, ¶ms);
let mut best_per_doc: HashMap<u32, f32> = HashMap::new();
for (idx, dist) in labels_idx.into_iter().zip(distances) {
if idx >= self.labels.len() {
continue;
}
let doc_id = self.labels[idx].doc_id;
let sim = -dist; let entry = best_per_doc.entry(doc_id).or_insert(f32::NEG_INFINITY);
if sim > *entry {
*entry = sim;
}
}
for (doc_id, score) in best_per_doc {
*doc_scores.entry(doc_id).or_insert(0.0) += score;
}
}
Ok(top_k_results(
&doc_scores,
top_k,
&self.doc_to_rows,
&self.labels,
))
}
pub fn search_exact(
&self,
query_tokens: &Array2<f32>,
top_k: usize,
first_stage_k: usize,
) -> Result<Vec<MultiVectorResult>> {
let approx = self.search_with_params(query_tokens, first_stage_k, 50)?;
let candidate_docs: Vec<u32> = approx.iter().map(|r| r.doc_id).collect();
if candidate_docs.is_empty() {
return Ok(Vec::new());
}
let emb_bytes = self.emb_bytes();
let (header_len, _rows, _cols) = parse_npy_header(emb_bytes)?;
let data_start = header_len;
let float_data = &emb_bytes[data_start..];
let mut doc_scores: HashMap<u32, f32> = HashMap::new();
for &doc_id in &candidate_docs {
if let Some(row_indices) = self.doc_to_rows.get(&doc_id) {
let score = exact_max_sim(query_tokens, float_data, row_indices, self.dim);
doc_scores.insert(doc_id, score);
}
}
Ok(top_k_results(
&doc_scores,
top_k,
&self.doc_to_rows,
&self.labels,
))
}
fn emb_bytes(&self) -> &[u8] {
#[cfg(feature = "multi-vector")]
{
&self.emb_mmap
}
#[cfg(not(feature = "multi-vector"))]
{
&self.emb_data
}
}
}
/// One ranked document returned by a search.
#[derive(Debug, Clone)]
pub struct MultiVectorResult {
/// Document id as passed to `MultiVectorBuilder::insert`.
pub doc_id: u32,
/// Sum over query tokens of the best token similarity found for this document
/// (approximate for ANN-only search, exact after re-ranking in `search_exact`).
pub score: f32,
/// Metadata of the document's first token label; empty if the document has no rows.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Exact late-interaction MaxSim score of one document.
///
/// For every query token, takes the largest dot product against the document's
/// token rows and sums those maxima. `float_data` is the npy payload read as
/// consecutive rows of `dim` little-endian `f32` values. Rows that would extend
/// past the end of the payload are skipped. A query token with no readable row
/// contributes nothing.
fn exact_max_sim(
    query_tokens: &Array2<f32>,
    float_data: &[u8],
    doc_row_indices: &[usize],
    dim: usize,
) -> f32 {
    let row_bytes = dim * 4;
    query_tokens.outer_iter().fold(0.0f32, |acc, q| {
        let best = doc_row_indices
            .iter()
            .filter_map(|&row_idx| {
                let start = row_idx * row_bytes;
                float_data.get(start..start + row_bytes)
            })
            .map(|doc_row| dot_product_bytes(q, doc_row))
            .fold(f32::NEG_INFINITY, |best, dot| if dot > best { dot } else { best });
        if best > f32::NEG_INFINITY {
            acc + best
        } else {
            acc
        }
    })
}
/// Dot product of `a` with the first `a.len()` little-endian `f32` values packed in `b_bytes`.
///
/// # Panics
/// Panics if `b_bytes` holds fewer than `4 * a.len()` bytes.
#[inline]
fn dot_product_bytes(a: ArrayView1<f32>, b_bytes: &[u8]) -> f32 {
    a.iter().enumerate().fold(0.0f32, |acc, (i, &ai)| {
        let start = i * 4;
        let word: [u8; 4] = b_bytes[start..start + 4].try_into().unwrap();
        acc + ai * f32::from_le_bytes(word)
    })
}
/// Ranks `doc_scores` by descending score and returns at most `top_k` results.
///
/// Equal scores are ordered by ascending `doc_id`, so the output does not depend
/// on `HashMap` iteration order. NaN scores rank last (they are reported
/// unchanged). Each result carries the metadata of its document's first token
/// label, or an empty map if the document has no rows.
fn top_k_results(
    doc_scores: &HashMap<u32, f32>,
    top_k: usize,
    doc_to_rows: &HashMap<u32, Vec<usize>>,
    labels: &[TokenLabel],
) -> Vec<MultiVectorResult> {
    // NaN would make the comparator inconsistent, and `sort` requires a total
    // order, so NaN is ranked as the worst possible score.
    fn rank_key(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    }

    let mut entries: Vec<(u32, f32)> = doc_scores.iter().map(|(&d, &s)| (d, s)).collect();
    entries.sort_unstable_by(|a, b| {
        rank_key(b.1)
            .partial_cmp(&rank_key(a.1))
            // Unreachable after `rank_key`: neither side can be NaN.
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });
    entries.truncate(top_k);

    entries
        .into_iter()
        .map(|(doc_id, score)| {
            let metadata = doc_to_rows
                .get(&doc_id)
                .and_then(|rows| rows.first())
                .map(|&idx| labels[idx].metadata.clone())
                .unwrap_or_default();
            MultiVectorResult {
                doc_id,
                score,
                metadata,
            }
        })
        .collect()
}
fn write_npy(arr: &Array2<f32>, path: &Path) -> Result<()> {
let (rows, cols) = arr.dim();
let header = format!(
"{{'descr': '<f4', 'fortran_order': False, 'shape': ({}, {}), }}",
rows, cols
);
let prefix_len = 10; let total_unpadded = prefix_len + header.len() + 1; let padding = (64 - (total_unpadded % 64)) % 64;
let header_content_len = header.len() + padding + 1;
let mut file = fs::File::create(path)?;
file.write_all(&[0x93, b'N', b'U', b'M', b'P', b'Y'])?;
file.write_all(&[1, 0])?;
file.write_all(&(header_content_len as u16).to_le_bytes())?;
file.write_all(header.as_bytes())?;
for _ in 0..padding {
file.write_all(b" ")?;
}
file.write_all(b"\n")?;
for val in arr.iter() {
file.write_all(&val.to_le_bytes())?;
}
Ok(())
}
fn parse_npy_header(data: &[u8]) -> Result<(usize, usize, usize)> {
anyhow::ensure!(data.len() >= 10, "npy file too small");
anyhow::ensure!(&data[0..6] == b"\x93NUMPY", "invalid npy magic");
let header_len = u16::from_le_bytes([data[8], data[9]]) as usize;
let header_end = 10 + header_len;
anyhow::ensure!(data.len() >= header_end, "npy header truncated");
let header_str = std::str::from_utf8(&data[10..header_end])?;
let shape_start = header_str
.find("'shape': (")
.context("no shape in npy header")?
+ "'shape': (".len();
let shape_end = header_str[shape_start..]
.find(')')
.context("unclosed shape tuple")?
+ shape_start;
let shape_str = &header_str[shape_start..shape_end];
let dims: Vec<usize> = shape_str
.split(',')
.filter_map(|s| s.trim().parse().ok())
.collect();
anyhow::ensure!(dims.len() == 2, "expected 2D shape, got {:?}", dims);
Ok((header_end, dims[0], dims[1]))
}
/// Returns `base` with `.{ext}` appended to its final component. For example,
/// `dir/name` with `"labels.json"` becomes `dir/name.labels.json`.
///
/// Unlike `Path::with_extension`, any existing extension is kept. The name is
/// built as an `OsString`, so non-UTF-8 file names are preserved byte-for-byte
/// instead of being lossily replaced with U+FFFD.
fn with_ext(base: &Path, ext: &str) -> PathBuf {
    let mut name = base.file_name().unwrap_or_default().to_os_string();
    name.push(".");
    name.push(ext);
    base.with_file_name(name)
}
// Integration-style tests: most build a small index in a temporary directory,
// reopen it with `MultiVectorSearcher::open`, and check counts, rankings, scores
// or sidecar contents.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// Two 4-d documents plus a two-token query: doc0's three tokens lie near axis 0,
// doc1's two tokens near axis 1, and the query has one token on each axis.
fn make_test_data() -> (Array2<f32>, Array2<f32>, Array2<f32>) {
let doc0 = array![
[1.0, 0.0, 0.0, 0.0],
[0.9, 0.1, 0.0, 0.0],
[0.8, 0.2, 0.0, 0.0],
];
let doc1 = array![[0.0, 1.0, 0.0, 0.0], [0.1, 0.9, 0.0, 0.0],];
let query = array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],];
(doc0, doc1, query)
}
// End-to-end: build writes all three sidecar files, open reports 2 docs / 5
// tokens, and both the approximate and exact paths return both documents.
#[test]
fn test_build_and_search() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_mv");
let (doc0, doc1, query) = make_test_data();
let mut builder = MultiVectorBuilder::new(4);
builder.insert(0, doc0, HashMap::new());
builder.insert(1, doc1, HashMap::new());
builder.build(&index_path).unwrap();
assert!(with_ext(&index_path, "index").exists());
assert!(with_ext(&index_path, "labels.json").exists());
assert!(with_ext(&index_path, "emb.npy").exists());
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
assert_eq!(searcher.num_docs(), 2);
assert_eq!(searcher.num_tokens(), 5);
let results = searcher.search(&query, 2).unwrap();
assert_eq!(results.len(), 2);
let exact_results = searcher.search_exact(&query, 2, 10).unwrap();
assert_eq!(exact_results.len(), 2);
}
// Exact re-ranking of two orthogonal single-token docs: the matching doc scores
// exactly 1.0 and the orthogonal one 0.0.
#[test]
fn test_max_sim_scoring() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_scoring");
let doc0 = array![[1.0, 0.0, 0.0, 0.0]];
let doc1 = array![[0.0, 1.0, 0.0, 0.0]];
let query = array![[1.0, 0.0, 0.0, 0.0]];
let mut builder = MultiVectorBuilder::new(4);
builder.insert(0, doc0, HashMap::new());
builder.insert(1, doc1, HashMap::new());
builder.build(&index_path).unwrap();
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
let results = searcher.search_exact(&query, 2, 10).unwrap();
assert_eq!(results[0].doc_id, 0);
assert!(results[0].score > results[1].score);
assert!((results[0].score - 1.0).abs() < 1e-5);
assert!((results[1].score - 0.0).abs() < 1e-5);
}
// write_npy output is readable by parse_npy_header: shape round-trips and the
// payload is exactly rows * cols * 4 bytes of little-endian f32.
#[test]
fn test_npy_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("test.npy");
let arr = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
write_npy(&arr, &path).unwrap();
let data = fs::read(&path).unwrap();
let (header_len, rows, cols) = parse_npy_header(&data).unwrap();
assert_eq!(rows, 2);
assert_eq!(cols, 3);
let float_data = &data[header_len..];
assert_eq!(float_data.len(), 2 * 3 * 4);
let first = f32::from_le_bytes(float_data[0..4].try_into().unwrap());
assert!((first - 1.0).abs() < 1e-6);
}
// Document metadata given to insert is returned on search results.
#[test]
fn test_metadata_propagation() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_meta");
let doc0 = array![[1.0, 0.0]];
let mut meta = HashMap::new();
meta.insert("filepath".to_string(), serde_json::json!("/tmp/page1.png"));
let mut builder = MultiVectorBuilder::new(2);
builder.insert(42, doc0, meta);
builder.build(&index_path).unwrap();
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
let query = array![[1.0, 0.0]];
let results = searcher.search(&query, 1).unwrap();
assert_eq!(results[0].doc_id, 42);
assert_eq!(results[0].metadata["filepath"], "/tmp/page1.png");
}
// Ten 3-token docs, each with weight 1.0 on axis `doc_id` (plus a small
// component on the next axis); a query on axis 5 must rank doc 5 first.
#[test]
fn test_many_docs_ranking() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_many");
let dim = 16;
let mut builder = MultiVectorBuilder::new(dim);
for doc_id in 0..10u32 {
let mut tokens = Array2::<f32>::zeros((3, dim));
for t in 0..3 {
tokens[[t, doc_id as usize]] = 1.0;
tokens[[t, (doc_id as usize + 1) % dim]] = 0.1 * (t as f32);
}
builder.insert(doc_id, tokens, HashMap::new());
}
builder.build(&index_path).unwrap();
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
assert_eq!(searcher.num_docs(), 10);
let mut query = Array2::<f32>::zeros((1, dim));
query[[0, 5]] = 1.0;
let results = searcher.search_exact(&query, 3, 30).unwrap();
assert_eq!(results[0].doc_id, 5);
}
// The score sums per-query-token maxima: doc0 matches both query tokens
// exactly (score 2.0), while doc1 is almost orthogonal to both.
#[test]
fn test_multi_token_query_aggregation() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_agg");
let doc0 = array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],];
let doc1 = array![[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.9, 0.1],];
let mut builder = MultiVectorBuilder::new(4);
builder.insert(0, doc0, HashMap::new());
builder.insert(1, doc1, HashMap::new());
builder.build(&index_path).unwrap();
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
let query = array![[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],];
let results = searcher.search_exact(&query, 2, 10).unwrap();
assert_eq!(results[0].doc_id, 0);
assert!((results[0].score - 2.0).abs() < 1e-5);
assert!(results[1].score < 0.2);
}
// Smallest index: one unit-length token ([0.6, 0.8]). The approximate search
// must report its self inner product, 0.36 + 0.64 = 1.0.
#[test]
fn test_single_doc_single_token() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_single");
let doc = array![[0.6, 0.8]];
let mut builder = MultiVectorBuilder::new(2);
builder.insert(0, doc, HashMap::new());
builder.build(&index_path).unwrap();
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
assert_eq!(searcher.num_docs(), 1);
assert_eq!(searcher.num_tokens(), 1);
let query = array![[0.6, 0.8]];
let results = searcher.search(&query, 1).unwrap();
assert_eq!(results.len(), 1);
assert!((results[0].score - 1.0).abs() < 1e-5);
}
// top_k truncates the result list, and asking for more than exist returns all 5.
#[test]
fn test_top_k_limits_results() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_topk");
let mut builder = MultiVectorBuilder::new(4);
for i in 0..5u32 {
let doc = array![[1.0, 0.0, 0.0, 0.0]];
builder.insert(i, doc, HashMap::new());
}
builder.build(&index_path).unwrap();
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
let query = array![[1.0, 0.0, 0.0, 0.0]];
let results = searcher.search(&query, 3).unwrap();
assert_eq!(results.len(), 3);
let results_all = searcher.search(&query, 10).unwrap();
assert_eq!(results_all.len(), 5);
}
// Documents with 1, 3 and 2 tokens: token counts add up to 6, and a query on
// axis 1 ranks doc1 (which contains [0.0, 1.0]) first.
#[test]
fn test_variable_token_counts() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_vartok");
let doc0 = array![[1.0, 0.0]]; let doc1 = array![[0.0, 1.0], [0.5, 0.5], [0.3, 0.7]]; let doc2 = array![[0.7, 0.7], [0.8, 0.6]];
let mut builder = MultiVectorBuilder::new(2);
builder.insert(0, doc0, HashMap::new());
builder.insert(1, doc1, HashMap::new());
builder.insert(2, doc2, HashMap::new());
builder.build(&index_path).unwrap();
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
assert_eq!(searcher.num_docs(), 3);
assert_eq!(searcher.num_tokens(), 6);
let query = array![[0.0, 1.0]];
let results = searcher.search_exact(&query, 3, 10).unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].doc_id, 1);
}
// labels.json has one entry per token, in insertion order, with seq_id
// restarting at 0 per document and metadata copied onto each of its tokens.
#[test]
fn test_labels_sidecar_format() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_labels");
let doc0 = array![[1.0, 0.0], [0.0, 1.0]];
let doc1 = array![[0.5, 0.5]];
let mut meta0 = HashMap::new();
meta0.insert("page".to_string(), serde_json::json!(1));
let mut builder = MultiVectorBuilder::new(2);
builder.insert(10, doc0, meta0);
builder.insert(20, doc1, HashMap::new());
builder.build(&index_path).unwrap();
let labels_path = with_ext(&index_path, "labels.json");
let data = fs::read_to_string(&labels_path).unwrap();
let labels: Vec<TokenLabel> = serde_json::from_str(&data).unwrap();
assert_eq!(labels.len(), 3);
assert_eq!(labels[0].doc_id, 10);
assert_eq!(labels[0].seq_id, 0);
assert_eq!(labels[0].metadata["page"], 1);
assert_eq!(labels[1].doc_id, 10);
assert_eq!(labels[1].seq_id, 1);
assert_eq!(labels[2].doc_id, 20);
assert_eq!(labels[2].seq_id, 0);
assert!(labels[2].metadata.is_empty());
}
// Eight one-hot docs: exact and approximate search agree that the doc on the
// query's axis (2) ranks first, and the exact score is 1.0.
#[test]
fn test_exact_vs_approximate_consistency() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_consistency");
let dim = 8;
let mut builder = MultiVectorBuilder::new(dim);
for i in 0..8u32 {
let mut emb = Array2::<f32>::zeros((1, dim));
emb[[0, i as usize]] = 1.0;
builder.insert(i, emb, HashMap::new());
}
builder.build(&index_path).unwrap();
let searcher = MultiVectorSearcher::open(&index_path).unwrap();
let mut query = Array2::<f32>::zeros((1, dim));
query[[0, 2]] = 1.0;
let exact = searcher.search_exact(&query, 1, 10).unwrap();
assert_eq!(exact[0].doc_id, 2);
assert!((exact[0].score - 1.0).abs() < 1e-5);
let approx = searcher.search(&query, 1).unwrap();
assert_eq!(approx[0].doc_id, 2);
}
// build on an empty builder returns an error; unwrap turns it into this panic.
#[test]
#[should_panic(expected = "no documents inserted")]
fn test_build_empty_panics() {
let dir = tempfile::tempdir().unwrap();
let index_path = dir.path().join("test_empty");
let builder = MultiVectorBuilder::new(4);
builder.build(&index_path).unwrap();
}
// insert asserts that every embedding row has the builder's width.
#[test]
#[should_panic(expected = "embedding dim 3 != expected 4")]
fn test_dimension_mismatch_panics() {
let mut builder = MultiVectorBuilder::new(4);
builder.insert(0, array![[1.0, 2.0, 3.0]], HashMap::new());
}
}