use std::io;
use std::path::Path;
use crate::Builder;
use crate::distance::Distance;
use crate::hnsw::{Config, Hnsw};
use crate::persist;
/// One hit returned by a paired search.
///
/// Carries the hit's `id`, the distance reported by the side that was
/// searched, and borrowed views of both embeddings stored under that id.
/// The type is a handful of scalars and slice references, so it is cheap
/// to copy and can be printed and compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PairedResult<'a> {
    /// Id of the hit; used to look up both the side-A and side-B embedding.
    pub id: usize,
    /// Distance from the query, as reported by the searched side.
    pub distance: f32,
    /// Side-A embedding stored under `id`.
    pub emb_a: &'a [f32],
    /// Side-B embedding stored under `id`.
    pub emb_b: &'a [f32],
}
/// Two HNSW indexes used side by side: the item with id `i` has its side-A
/// embedding in `index_a` and its side-B embedding in `index_b`.
///
/// Both fields are public. Inserting into only one of them directly will
/// presumably desynchronise the ids that the paired operations rely on.
pub struct PairedIndex<A, B>
where
    A: Distance,
    B: Distance,
{
    /// Index over side-A embeddings, built with metric `A`.
    pub index_a: Hnsw<A>,
    /// Index over side-B embeddings, built with metric `B`.
    pub index_b: Hnsw<B>,
}
impl<A: Distance, B: Distance> PairedIndex<A, B> {
pub fn new(
config_a: Config, metric_a: A,
config_b: Config, metric_b: B,
) -> Self {
Self {
index_a: Hnsw::new(config_a, metric_a),
index_b: Hnsw::new(config_b, metric_b),
}
}
pub fn from_builder(builder: Builder, metric_a: A, metric_b: B) -> Self {
let cfg = builder.into_config();
Self {
index_a: Hnsw::new(cfg.clone(), metric_a),
index_b: Hnsw::new(cfg, metric_b),
}
}
pub fn insert(&mut self, emb_a: Vec<f32>, emb_b: Vec<f32>) -> usize {
let id_a = self.index_a.insert(emb_a);
let id_b = self.index_b.insert(emb_b);
debug_assert_eq!(id_a, id_b, "PairedIndex: side-A and side-B id mismatch");
id_a
}
pub fn search_by_a<'a>(
&'a self,
query: &[f32],
k: usize,
ef: usize,
) -> Vec<PairedResult<'a>> {
self.index_a
.search(query, k, ef)
.into_iter()
.map(|sr| PairedResult {
id: sr.id,
distance: sr.distance,
emb_a: self.index_a.get_vector(sr.id),
emb_b: self.index_b.get_vector(sr.id),
})
.collect()
}
pub fn search_by_b<'a>(
&'a self,
query: &[f32],
k: usize,
ef: usize,
) -> Vec<PairedResult<'a>> {
self.index_b
.search(query, k, ef)
.into_iter()
.map(|sr| PairedResult {
id: sr.id,
distance: sr.distance,
emb_a: self.index_a.get_vector(sr.id),
emb_b: self.index_b.get_vector(sr.id),
})
.collect()
}
pub fn get_emb_a(&self, id: usize) -> &[f32] { self.index_a.get_vector(id) }
pub fn get_emb_b(&self, id: usize) -> &[f32] { self.index_b.get_vector(id) }
pub fn len(&self) -> usize { self.index_a.len() }
pub fn is_empty(&self) -> bool { self.index_a.is_empty() }
pub fn save(&self, base_path: impl AsRef<Path>) -> io::Result<()> {
let base = base_path.as_ref();
let path_a = side_path(base, 'a');
let path_b = side_path(base, 'b');
persist::save(&self.index_a, &path_a)?;
persist::save(&self.index_b, &path_b)?;
Ok(())
}
pub fn load(
base_path: impl AsRef<Path>,
metric_a: A,
metric_b: B,
) -> io::Result<Self> {
let base = base_path.as_ref();
let index_a = persist::load(side_path(base, 'a'), metric_a)?;
let index_b = persist::load(side_path(base, 'b'), metric_b)?;
if index_a.len() != index_b.len() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"side-A has {} vectors but side-B has {} — mismatched files?",
index_a.len(), index_b.len()
),
));
}
Ok(Self { index_a, index_b })
}
pub fn load_mmap(
base_path: impl AsRef<Path>,
metric_a: A,
metric_b: B,
) -> io::Result<Self> {
let base = base_path.as_ref();
let index_a = persist::load_mmap(side_path(base, 'a'), metric_a)?;
let index_b = persist::load_mmap(side_path(base, 'b'), metric_b)?;
Ok(Self { index_a, index_b })
}
}
impl Builder {
    /// Consumes this builder and produces an empty [`PairedIndex`].
    ///
    /// The builder's configuration is used for both sides; `metric_a` and
    /// `metric_b` choose the distance used on each side.
    pub fn build_paired<A, B>(self, metric_a: A, metric_b: B) -> PairedIndex<A, B>
    where
        A: Distance,
        B: Distance,
    {
        PairedIndex::<A, B>::from_builder(self, metric_a, metric_b)
    }
}
/// Derives the file path for one side of a paired index.
///
/// Appends `_{side}.hnsw` directly to `base`, with no separator, so `"idx"`
/// with side `'a'` becomes `"idx_a.hnsw"`. The work is done on the raw OS
/// string, so non-UTF-8 base paths are carried through unchanged.
fn side_path(base: &Path, side: char) -> std::path::PathBuf {
    let suffix = format!("_{side}.hnsw");
    let mut raw = base.as_os_str().to_os_string();
    raw.push(suffix);
    raw.into()
}