use anyhow::Result;
use crate::core::schema::{DistanceMetric, VectorIndexType};
use crate::muvera::DEFAULT_FDE_SEED;
/// Optional settings from which [`build_vector_index_type`] derives a [`VectorIndexType`].
///
/// Every field is optional; fields left as `None` fall back to defaults chosen by the
/// builder (documented per field below).
#[derive(Debug, Default, Clone)]
pub struct VectorIndexOpts<'a> {
    /// Index type name. `"muvera"` selects a MUVERA index wrapping [`Self::inner`];
    /// recognised ANN names are `flat`, `ivf_flat`, `ivf_sq`, `ivf_rq`, `hnsw_flat`,
    /// `hnsw` / `hnsw_sq`, and `hnsw_pq`. `None` or an unrecognised name selects IVF-PQ.
    pub type_name: Option<&'a str>,
    /// Partition count. IVF variants default to 256; HNSW variants receive it unchanged
    /// (it may stay `None`).
    pub partitions: Option<u32>,
    /// HNSW `m` parameter; defaults to 16.
    pub m: Option<u32>,
    /// HNSW `ef_construction` parameter; defaults to 200.
    pub ef_construction: Option<u32>,
    /// Number of PQ sub-vectors for IVF-PQ and `hnsw_pq`; defaults to 16.
    pub sub_vectors: Option<u32>,
    /// Bits per sub-vector for IVF-PQ (defaults to 8); passed through unchanged for `ivf_rq`.
    pub num_bits: Option<u8>,
    /// MUVERA `k_sim` parameter; defaults to 4.
    pub k_sim: Option<u32>,
    /// MUVERA `reps` parameter; defaults to 20.
    pub reps: Option<u32>,
    /// MUVERA `d_proj` parameter; defaults to 16.
    pub d_proj: Option<u32>,
    /// MUVERA seed; defaults to `DEFAULT_FDE_SEED`.
    pub seed: Option<u64>,
    /// ANN index type name used inside a MUVERA index. It is resolved like a non-MUVERA
    /// `type_name`, and the remaining tuning fields configure that inner index.
    pub inner: Option<&'a str>,
}
/// Resolves a non-MUVERA index type name into a concrete [`VectorIndexType`].
///
/// Tuning parameters left unset in `o` fall back to the defaults declared below.
/// Name matching is ASCII case-insensitive, consistent with [`parse_vector_metric`].
/// For example, `"HNSW"` selects the same index as `"hnsw"` instead of silently
/// degrading to IVF-PQ.
///
/// `"hnsw"` is an alias for `"hnsw_sq"`. `None` and any unrecognised name (including
/// `"muvera"`, which is not nestable here) yield `IvfPq`.
fn ann_type(o: &VectorIndexOpts, t: Option<&str>) -> VectorIndexType {
    const DEFAULT_PARTITIONS: u32 = 256;
    const DEFAULT_M: u32 = 16;
    const DEFAULT_EF_CONSTRUCTION: u32 = 200;
    const DEFAULT_SUB_VECTORS: u32 = 16;
    const DEFAULT_PQ_BITS: u8 = 8;

    // Normalise case so user-supplied names such as "Flat" or "HNSW" are not
    // silently routed to the catch-all default below.
    let name = t.map(|n| n.to_ascii_lowercase());

    match name.as_deref() {
        Some("flat") => VectorIndexType::Flat,
        Some("ivf_flat") => VectorIndexType::IvfFlat {
            num_partitions: o.partitions.unwrap_or(DEFAULT_PARTITIONS),
        },
        Some("ivf_sq") => VectorIndexType::IvfSq {
            num_partitions: o.partitions.unwrap_or(DEFAULT_PARTITIONS),
        },
        Some("ivf_rq") => VectorIndexType::IvfRq {
            num_partitions: o.partitions.unwrap_or(DEFAULT_PARTITIONS),
            // Passed through as-is; may remain `None`.
            num_bits: o.num_bits,
        },
        Some("hnsw_flat") => VectorIndexType::HnswFlat {
            m: o.m.unwrap_or(DEFAULT_M),
            ef_construction: o.ef_construction.unwrap_or(DEFAULT_EF_CONSTRUCTION),
            num_partitions: o.partitions,
        },
        Some("hnsw") | Some("hnsw_sq") => VectorIndexType::HnswSq {
            m: o.m.unwrap_or(DEFAULT_M),
            ef_construction: o.ef_construction.unwrap_or(DEFAULT_EF_CONSTRUCTION),
            num_partitions: o.partitions,
        },
        Some("hnsw_pq") => VectorIndexType::HnswPq {
            m: o.m.unwrap_or(DEFAULT_M),
            ef_construction: o.ef_construction.unwrap_or(DEFAULT_EF_CONSTRUCTION),
            num_sub_vectors: o.sub_vectors.unwrap_or(DEFAULT_SUB_VECTORS),
            num_partitions: o.partitions,
        },
        // Default for `None` and unknown names (pinned by the module's tests).
        _ => VectorIndexType::IvfPq {
            num_partitions: o.partitions.unwrap_or(DEFAULT_PARTITIONS),
            num_sub_vectors: o.sub_vectors.unwrap_or(DEFAULT_SUB_VECTORS),
            bits_per_subvector: o.num_bits.unwrap_or(DEFAULT_PQ_BITS),
        },
    }
}
/// Builds the [`VectorIndexType`] described by `o`.
///
/// If `o.type_name` is `"muvera"` (ASCII case-insensitive, consistent with
/// [`parse_vector_metric`]), the result is a MUVERA index. Its parameters default to
/// `k_sim = 4`, `reps = 20`, `d_proj = 16` and `seed = DEFAULT_FDE_SEED`. It wraps an
/// inner ANN index named by `o.inner`, configured from the remaining tuning fields of `o`.
///
/// Any other name, including `None`, is resolved directly as an ANN index type.
pub fn build_vector_index_type(o: &VectorIndexOpts) -> VectorIndexType {
    match o.type_name {
        Some(name) if name.eq_ignore_ascii_case("muvera") => VectorIndexType::Muvera {
            k_sim: o.k_sim.unwrap_or(4),
            reps: o.reps.unwrap_or(20),
            d_proj: o.d_proj.unwrap_or(16),
            seed: o.seed.unwrap_or(DEFAULT_FDE_SEED),
            inner: Box::new(ann_type(o, o.inner)),
        },
        other => ann_type(o, other),
    }
}
/// Parses an optional, case-insensitive metric name into a [`DistanceMetric`].
///
/// A missing name means cosine. Accepted names are `cosine`, `l2` (alias `euclidean`)
/// and `dot`.
///
/// # Errors
///
/// Returns an error naming the (lowercased) input when it is not one of the accepted names.
pub fn parse_vector_metric(s: Option<&str>) -> Result<DistanceMetric> {
    let raw = match s {
        Some(raw) => raw,
        None => return Ok(DistanceMetric::Cosine),
    };

    let lowered = raw.to_ascii_lowercase();
    let metric = match lowered.as_str() {
        "cosine" => DistanceMetric::Cosine,
        "dot" => DistanceMetric::Dot,
        "l2" | "euclidean" => DistanceMetric::L2,
        other => {
            return Err(anyhow::anyhow!(
                "Unknown vector index metric '{other}' (expected cosine, l2, or dot)"
            ))
        }
    };
    Ok(metric)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Options with only `type_name` set; every tuning field left at its default.
    fn named(type_name: Option<&str>) -> VectorIndexOpts<'_> {
        VectorIndexOpts {
            type_name,
            ..VectorIndexOpts::default()
        }
    }

    #[test]
    fn default_is_ivf_pq_for_both_paths() {
        // An absent name and an unrecognised name both fall back to IVF-PQ.
        for name in [None, Some("nonsense")] {
            let built = build_vector_index_type(&named(name));
            assert!(matches!(built, VectorIndexType::IvfPq { .. }));
        }
    }

    #[test]
    fn named_types_map() {
        let flat = build_vector_index_type(&named(Some("flat")));
        assert!(matches!(flat, VectorIndexType::Flat));

        // "hnsw" is an alias for the scalar-quantised HNSW variant.
        let hnsw = build_vector_index_type(&named(Some("hnsw")));
        assert!(matches!(hnsw, VectorIndexType::HnswSq { .. }));
    }

    #[test]
    fn muvera_defaults_and_inner() {
        let explicit_inner = VectorIndexOpts {
            type_name: Some("muvera"),
            inner: Some("flat"),
            ..VectorIndexOpts::default()
        };
        match build_vector_index_type(&explicit_inner) {
            VectorIndexType::Muvera {
                k_sim,
                reps,
                d_proj,
                seed,
                inner,
            } => {
                assert_eq!((k_sim, reps, d_proj), (4, 20, 16));
                assert_eq!(seed, DEFAULT_FDE_SEED);
                assert!(matches!(*inner, VectorIndexType::Flat));
            }
            other => panic!("expected Muvera, got {other:?}"),
        }

        // Without an explicit inner type, MUVERA wraps the IVF-PQ default.
        let inner_is_ivf_pq = match build_vector_index_type(&named(Some("muvera"))) {
            VectorIndexType::Muvera { inner, .. } => {
                matches!(*inner, VectorIndexType::IvfPq { .. })
            }
            _ => false,
        };
        assert!(inner_is_ivf_pq);
    }

    #[test]
    fn metric_parsing() {
        let accepted = [
            (None, DistanceMetric::Cosine),
            (Some("L2"), DistanceMetric::L2),
            (Some("dot"), DistanceMetric::Dot),
        ];
        for (input, expected) in accepted {
            assert_eq!(parse_vector_metric(input).unwrap(), expected);
        }
        assert!(parse_vector_metric(Some("hamming")).is_err());
    }
}