use std::collections::HashMap;
use std::mem::size_of;
use std::sync::Arc;
use error_forge::ForgeError;
use iqdb_index::{Index, IndexCore, IndexStats};
use iqdb_quantize::{PqCode, ProductQuantizer, Quantizer};
use iqdb_types::{DistanceMetric, Hit, IqdbError, Metadata, Result, SearchParams, VectorId};
use crate::assign::assign_to_cluster;
use crate::config::IvfConfig;
use crate::pq_variant;
use crate::rng::SplitMix64;
use crate::search;
use crate::stats::IvfClusterStats;
use crate::train::{subsample_refs, train_kmeans};
/// One stored vector inside an inverted list, together with its bookkeeping.
#[derive(Debug)]
pub(crate) struct InvertedListEntry {
/// Identifier of the stored vector; deletion locates the entry by this id.
id: VectorId,
/// The vector exactly as it was handed to `insert`.
vector: Arc<[f32]>,
/// Product-quantized code for `vector`. Filled in only when the index holds
/// a product quantizer at insert or retrain time; `None` otherwise.
pq_code: Option<PqCode>,
/// Optional caller-supplied metadata.
metadata: Option<Metadata>,
/// Insertion sequence number taken from the index's counter. `retrain`
/// sorts entries by this value before rebuilding.
seq: u64,
}
impl InvertedListEntry {
    /// Borrows the identifier of this entry.
    pub(crate) fn id(&self) -> &VectorId {
        &self.id
    }

    /// Borrows the stored vector as a plain slice.
    pub(crate) fn vector_slice(&self) -> &[f32] {
        &*self.vector
    }

    /// Borrows the metadata attached to this entry, if any.
    pub(crate) fn metadata(&self) -> Option<&Metadata> {
        self.metadata.as_ref()
    }

    /// Returns the insertion sequence number of this entry.
    pub(crate) fn seq(&self) -> u64 {
        self.seq
    }

    /// Borrows the PQ code of this entry.
    ///
    /// # Errors
    ///
    /// Returns `IqdbError::InvalidConfig` when the entry carries no code.
    pub(crate) fn pq_code_or_err(&self) -> Result<&PqCode> {
        match &self.pq_code {
            Some(code) => Ok(code),
            None => Err(IqdbError::InvalidConfig {
                reason: "IvfIndex entry is missing its PqCode in IVF-PQ mode (index invariant violated)",
            }),
        }
    }
}
/// Inverted-file (IVF) vector index with an optional product quantizer.
///
/// Vectors are grouped into `cfg.n_clusters` inverted lists. The index must
/// be trained (see `train`) before vectors can be inserted or searched.
#[derive(Debug)]
pub struct IvfIndex {
/// Required length of every inserted vector; always non-zero.
dim: usize,
/// Distance metric supplied at construction.
metric: DistanceMetric,
/// Index configuration; `n_probes` and `pq_refine_factor` can be changed
/// after construction through the setters.
cfg: IvfConfig,
/// `true` once `train` has completed successfully.
trained: bool,
/// Cluster centroids produced by k-means; empty until the index is trained.
centroids: Vec<Vec<f32>>,
/// One list of entries per cluster, indexed by cluster number.
inverted_lists: Vec<Vec<InvertedListEntry>>,
/// Maps each stored id to the index of the inverted list that holds it.
id_to_cluster: HashMap<VectorId, usize>,
/// Sequence number that the next successful insert will receive.
next_seq: u64,
/// Number of entries currently stored across all inverted lists.
live_count: usize,
/// Product quantizer; populated by training when `cfg.use_pq` is set.
pq: Option<ProductQuantizer>,
}
impl IvfIndex {
/// Builds an empty, untrained index after validating its parameters.
///
/// # Errors
///
/// * `IqdbError::InvalidConfig` if `dim` is zero, or if PQ is enabled and
///   `cfg.pq_subvectors` is unset, zero, or does not divide `dim`.
/// * `IqdbError::InvalidMetric` if PQ is enabled with a metric other than
///   Euclidean, DotProduct or Manhattan.
/// * Whatever `cfg.validate()` reports.
pub fn new_unconfigured(dim: usize, metric: DistanceMetric, cfg: IvfConfig) -> Result<Self> {
    if dim == 0 {
        return Err(IqdbError::InvalidConfig {
            reason: "IvfIndex dim must be greater than zero",
        });
    }
    cfg.validate()?;

    if cfg.use_pq {
        let metric_supported = matches!(
            metric,
            DistanceMetric::Euclidean | DistanceMetric::DotProduct | DistanceMetric::Manhattan
        );
        if !metric_supported {
            return Err(IqdbError::InvalidMetric);
        }

        // A missing sub-vector count is treated like zero and rejected.
        let m = cfg.pq_subvectors.unwrap_or(0);
        let splits_evenly = m != 0 && dim.is_multiple_of(m);
        if !splits_evenly {
            return Err(IqdbError::InvalidConfig {
                reason: "IvfConfig.pq_subvectors must divide IvfIndex dim",
            });
        }
    }

    // One (initially empty) inverted list per cluster.
    let inverted_lists: Vec<Vec<InvertedListEntry>> =
        (0..cfg.n_clusters).map(|_| Vec::new()).collect();

    Ok(Self {
        dim,
        metric,
        cfg,
        trained: false,
        centroids: Vec::new(),
        inverted_lists,
        id_to_cluster: HashMap::new(),
        next_seq: 0,
        live_count: 0,
        pq: None,
    })
}
/// Returns the vector dimensionality this index was created with.
#[must_use]
pub fn dim(&self) -> usize {
self.dim
}
/// Returns the distance metric this index was created with.
#[must_use]
pub fn metric(&self) -> DistanceMetric {
self.metric
}
/// Returns the number of vectors currently stored in the index.
#[must_use]
pub fn len(&self) -> usize {
self.live_count
}
/// Returns `true` when the index currently stores no vectors.
#[must_use]
pub fn is_empty(&self) -> bool {
self.live_count == 0
}
/// Returns `true` once `train` has completed successfully.
#[must_use]
pub fn is_trained(&self) -> bool {
self.trained
}
/// Returns the currently configured `n_probes` value.
// NOTE(review): presumably the number of clusters a search visits; the search
// code is not in this file — confirm against `search::ivf_search`.
#[must_use]
pub fn n_probes(&self) -> usize {
self.cfg.n_probes
}
/// Updates the configured `n_probes` value.
///
/// # Errors
///
/// Returns `IqdbError::InvalidConfig` when `n` is zero or exceeds the
/// configured number of clusters; the stored value is left untouched.
pub fn set_n_probes(&mut self, n: usize) -> Result<()> {
    let rejection = match n {
        0 => Some("IvfIndex::set_n_probes requires n >= 1"),
        _ if n > self.cfg.n_clusters => Some("IvfIndex::set_n_probes requires n <= n_clusters"),
        _ => None,
    };
    if let Some(reason) = rejection {
        return Err(IqdbError::InvalidConfig { reason });
    }
    self.cfg.n_probes = n;
    Ok(())
}
/// Returns the currently configured `pq_refine_factor` value.
#[must_use]
pub fn pq_refine_factor(&self) -> u32 {
self.cfg.pq_refine_factor
}
/// Overwrites the configured `pq_refine_factor`.
///
/// No validation is performed here; any `u32` is stored as given.
pub fn set_pq_refine_factor(&mut self, factor: u32) {
self.cfg.pq_refine_factor = factor;
}
/// Summarises the sizes of the inverted lists.
///
/// Before training an empty size list is handed to
/// `IvfClusterStats::from_sizes`; afterwards one size per inverted list.
#[must_use]
pub fn cluster_stats(&self) -> IvfClusterStats {
    let mut sizes: Vec<usize> = Vec::new();
    if self.trained {
        sizes = self.inverted_lists.iter().map(Vec::len).collect();
    }
    IvfClusterStats::from_sizes(self.cfg.n_clusters, sizes)
}
/// Trains the index from `sample`: k-means centroids and, when
/// `cfg.use_pq` is set, a product quantizer.
///
/// All fallible work is done before any field of `self` is written, so a
/// failed call leaves the index exactly as it was (still untrained, with
/// its previous centroids and quantizer).
///
/// # Errors
///
/// * `IqdbError::InvalidConfig` if the index is already trained.
/// * Any error reported by `train_kmeans` or `pq_variant::train_pq`.
pub fn train(&mut self, sample: &[&[f32]]) -> Result<()> {
    if self.trained {
        return Err(IqdbError::InvalidConfig {
            reason: "IvfIndex is already trained; use retrain() to rebuild centroids",
        });
    }

    let centroids = train_kmeans(
        self.dim,
        self.cfg.n_clusters,
        self.cfg.seed,
        sample,
        self.cfg.training_sample_size,
    )?;
    debug_assert_eq!(centroids.len(), self.cfg.n_clusters);

    // Train the quantizer before committing anything: if this fails the
    // index must not be left holding centroids from an aborted training run.
    let trained_pq = if self.cfg.use_pq {
        Some(pq_variant::train_pq(&self.cfg, sample)?)
    } else {
        None
    };

    // Commit phase: nothing below can fail.
    self.centroids = centroids;
    if let Some(pq) = trained_pq {
        self.pq = Some(pq);
    }
    self.trained = true;
    Ok(())
}
pub fn retrain(&mut self) -> Result<()> {
if !self.trained {
return Err(IqdbError::InvalidConfig {
reason: "IvfIndex must be trained before retrain()",
});
}
if self.live_count == 0 {
return Ok(());
}
let mut snapshot: Vec<InvertedListEntry> = Vec::with_capacity(self.live_count);
for list in self.inverted_lists.iter_mut() {
let taken = std::mem::take(list);
for entry in taken {
snapshot.push(entry);
}
}
debug_assert!(self.inverted_lists.iter().all(|l| l.is_empty()));
self.id_to_cluster.clear();
self.live_count = 0;
snapshot.sort_by_key(|e| e.seq);
let all_refs: Vec<&[f32]> = snapshot.iter().map(|e| e.vector_slice()).collect();
let target_len = self.cfg.training_sample_size.min(all_refs.len());
let mut rng = SplitMix64::new(self.cfg.seed);
let capped: Vec<&[f32]> = subsample_refs(&all_refs, target_len, &mut rng);
let centroids = train_kmeans(
self.dim,
self.cfg.n_clusters,
self.cfg.seed,
&capped,
capped.len(),
)?;
debug_assert_eq!(centroids.len(), self.cfg.n_clusters);
self.centroids = centroids;
if self.cfg.use_pq {
let pq = pq_variant::train_pq(&self.cfg, &capped)?;
self.pq = Some(pq);
}
for mut entry in snapshot {
let cluster = assign_to_cluster(&self.centroids, &entry.vector);
if let Some(pq) = self.pq.as_ref() {
entry.pq_code = Some(pq.quantize(&entry.vector)?);
}
let id = entry.id.clone();
let _prev = self.id_to_cluster.insert(id, cluster);
self.inverted_lists[cluster].push(entry);
self.live_count += 1;
}
Ok(())
}
/// Suggests how many of the largest clusters are needed to hold at least
/// the `coverage` fraction of the stored vectors.
///
/// Returns `1` for an untrained or empty index, for a single-cluster
/// configuration, and for `coverage == 0.0`; returns `n_clusters` for
/// `coverage == 1.0`.
///
/// # Errors
///
/// Returns `IqdbError::InvalidConfig` when `coverage` is not a finite value
/// in `[0.0, 1.0]`.
pub fn suggest_n_probes(&self, coverage: f32) -> Result<usize> {
    let coverage_is_valid = coverage.is_finite() && (0.0..=1.0).contains(&coverage);
    if !coverage_is_valid {
        return Err(IqdbError::InvalidConfig {
            reason: "IvfIndex::suggest_n_probes requires coverage in [0.0, 1.0]",
        });
    }

    let n_clusters = self.cfg.n_clusters;
    let nothing_to_rank = !self.trained || self.live_count == 0 || n_clusters == 1;
    if nothing_to_rank || coverage == 0.0 {
        return Ok(1);
    }
    if coverage == 1.0 {
        return Ok(n_clusters);
    }

    // Largest clusters first.
    let mut sizes: Vec<usize> = self.inverted_lists.iter().map(Vec::len).collect();
    sizes.sort_unstable_by(|a, b| b.cmp(a));

    let target = ((self.live_count as f64) * f64::from(coverage)).ceil() as usize;

    // Index of the first cluster at which the running total reaches `target`.
    let mut covered: usize = 0;
    let reached_at = sizes.iter().position(|&size| {
        covered = covered.saturating_add(size);
        covered >= target
    });

    Ok(match reached_at {
        Some(i) => (i + 1).clamp(1, n_clusters),
        None => n_clusters,
    })
}
/// Borrows the trained centroids; the slice is empty before training.
pub(crate) fn centroids_slice(&self) -> &[Vec<f32>] {
&self.centroids
}
/// Borrows the entries of the inverted list for `cluster`.
///
/// # Panics
///
/// Panics if `cluster` is not a valid index into the inverted lists.
pub(crate) fn inverted_list(&self, cluster: usize) -> &[InvertedListEntry] {
&self.inverted_lists[cluster]
}
/// Borrows the index configuration.
pub(crate) fn cfg(&self) -> &IvfConfig {
&self.cfg
}
/// Borrows the product quantizer, or `None` if the index holds none.
pub(crate) fn pq(&self) -> Option<&ProductQuantizer> {
self.pq.as_ref()
}
/// Verifies that `vector_len` equals the index dimensionality.
///
/// # Errors
///
/// Returns `IqdbError::DimensionMismatch` carrying the expected and the
/// observed length when they differ.
fn check_dim(&self, vector_len: usize) -> Result<()> {
    if vector_len == self.dim {
        Ok(())
    } else {
        Err(IqdbError::DimensionMismatch {
            expected: self.dim,
            found: vector_len,
        })
    }
}
/// Guard used by operations that need a trained index.
///
/// # Errors
///
/// Returns `IqdbError::InvalidConfig` while the index is untrained.
fn require_trained(&self) -> Result<()> {
    match self.trained {
        true => Ok(()),
        false => Err(IqdbError::InvalidConfig {
            reason: "IvfIndex must be trained before use",
        }),
    }
}
/// Rough estimate of the heap memory held by the index, in bytes.
///
/// Counted: centroid storage, the entry slots of every inverted list, each
/// entry's vector payload plus two words of `Arc` bookkeeping, and the
/// `id_to_cluster` table by capacity. Entry metadata, PQ codes and any heap
/// data owned by ids are not included.
fn approximate_memory_bytes(&self) -> usize {
    // Two machine words per `Arc` allocation, on top of the f32 payload.
    let arc_overhead = 2 * size_of::<usize>();

    let centroid_payload: usize = self
        .centroids
        .iter()
        .map(|centroid| centroid.capacity() * size_of::<f32>())
        .sum();
    let centroid_bytes = centroid_payload + self.centroids.capacity() * size_of::<Vec<f32>>();

    let list_bytes: usize = self
        .inverted_lists
        .iter()
        .map(|list| {
            let slot_bytes = list.capacity() * size_of::<InvertedListEntry>();
            let vector_bytes: usize = list
                .iter()
                .map(|entry| entry.vector.len() * size_of::<f32>() + arc_overhead)
                .sum();
            slot_bytes + vector_bytes
        })
        .sum();

    let map_slot_bytes = size_of::<VectorId>() + size_of::<usize>();
    let id_to_cluster_bytes = self.id_to_cluster.capacity() * map_slot_bytes;

    centroid_bytes + list_bytes + id_to_cluster_bytes
}
/// Stores `vector` under `id` in the list of the cluster it is assigned to.
///
/// The entry receives the next insertion sequence number and, when the
/// index holds a product quantizer, a PQ code.
///
/// # Errors
///
/// * `IqdbError::InvalidConfig` if the index is untrained or the sequence
///   counter cannot be advanced.
/// * `IqdbError::DimensionMismatch` if `vector` has the wrong length.
/// * `IqdbError::Duplicate` if `id` is already present.
/// * Any error reported by the quantizer.
fn insert_inner(
    &mut self,
    id: VectorId,
    vector: Arc<[f32]>,
    metadata: Option<Metadata>,
) -> Result<()> {
    self.require_trained()?;
    self.check_dim(vector.len())?;
    if self.id_to_cluster.contains_key(&id) {
        return Err(IqdbError::Duplicate);
    }

    // Reserve this entry's sequence number and advance the counter.
    let seq = self.next_seq;
    let Some(following) = seq.checked_add(1) else {
        return Err(IqdbError::InvalidConfig {
            reason: "IvfIndex insertion sequence counter overflowed u64",
        });
    };
    self.next_seq = following;

    let cluster = assign_to_cluster(&self.centroids, &vector);
    let pq_code = self
        .pq
        .as_ref()
        .map(|pq| pq.quantize(&vector))
        .transpose()?;

    let entry = InvertedListEntry {
        id: id.clone(),
        vector,
        pq_code,
        metadata,
        seq,
    };
    let _prev = self.id_to_cluster.insert(id, cluster);
    self.inverted_lists[cluster].push(entry);
    self.live_count += 1;
    Ok(())
}
/// Removes the entry stored under `id`.
///
/// The id is dropped from the id map first, then the entry is located in
/// its cluster's list and removed with `swap_remove`, so the order of the
/// remaining entries in that list is not preserved.
///
/// # Errors
///
/// Returns `IqdbError::NotFound` if `id` is not in the id map or the entry
/// is missing from the list the map points at.
fn delete_inner(&mut self, id: &VectorId) -> Result<()> {
    let Some(cluster) = self.id_to_cluster.remove(id) else {
        return Err(IqdbError::NotFound);
    };

    let bucket = &mut self.inverted_lists[cluster];
    match bucket.iter().position(|entry| entry.id == *id) {
        Some(pos) => {
            let _entry = bucket.swap_remove(pos);
            self.live_count -= 1;
            Ok(())
        }
        None => Err(IqdbError::NotFound),
    }
}
/// Runs a search after checking that the index has been trained.
///
/// # Errors
///
/// Returns `IqdbError::InvalidConfig` for an untrained index, otherwise
/// whatever `search::ivf_search` reports.
fn search_inner(&self, query: &[f32], params: &SearchParams) -> Result<Vec<Hit>> {
    self.require_trained()
        .and_then(|()| search::ivf_search(self, query, params))
}
}
impl IndexCore for IvfIndex {
    /// Inserts a vector; failures are logged at error level and returned.
    #[tracing::instrument(
        level = "debug",
        skip_all,
        fields(vector_id = %id, n = self.live_count, dim = self.dim),
    )]
    fn insert(
        &mut self,
        id: VectorId,
        vector: Arc<[f32]>,
        metadata: Option<Metadata>,
    ) -> Result<()> {
        let outcome = self.insert_inner(id, vector, metadata);
        if let Err(err) = &outcome {
            tracing::error!(
                error.kind = err.kind(),
                error.reason = err.caption(),
                "ivf insert failed",
            );
        }
        outcome
    }

    /// Deletes a vector by id; failures are logged at error level and returned.
    #[tracing::instrument(
        level = "debug",
        skip_all,
        fields(vector_id = %id, n = self.live_count),
    )]
    fn delete(&mut self, id: &VectorId) -> Result<()> {
        let outcome = self.delete_inner(id);
        if let Err(err) = &outcome {
            tracing::error!(
                error.kind = err.kind(),
                error.reason = err.caption(),
                "ivf delete failed",
            );
        }
        outcome
    }

    /// Searches the index; failures are logged at error level and returned.
    #[tracing::instrument(
        level = "debug",
        skip_all,
        fields(
            k = params.k,
            dim = self.dim,
            n = self.live_count,
            filter = params.filter.is_some(),
            metric = ?params.metric,
        ),
    )]
    fn search(&self, query: &[f32], params: &SearchParams) -> Result<Vec<Hit>> {
        let outcome = self.search_inner(query, params);
        if let Err(err) = &outcome {
            tracing::error!(
                error.kind = err.kind(),
                error.reason = err.caption(),
                "ivf search failed",
            );
        }
        outcome
    }

    /// Number of stored vectors; forwards to the inherent method.
    fn len(&self) -> usize {
        IvfIndex::len(self)
    }

    /// Whether the index stores no vectors; forwards to the inherent method.
    fn is_empty(&self) -> bool {
        IvfIndex::is_empty(self)
    }

    /// Vector dimensionality; forwards to the inherent method.
    fn dim(&self) -> usize {
        IvfIndex::dim(self)
    }

    /// Distance metric; forwards to the inherent method.
    fn metric(&self) -> DistanceMetric {
        IvfIndex::metric(self)
    }

    /// No-op: always returns `Ok(())`.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }

    /// Reports the vector count and an approximate memory footprint.
    /// `disk_bytes` and `extra` are always `None`.
    fn stats(&self) -> IndexStats {
        let n_vectors = self.live_count;
        let memory_bytes = self.approximate_memory_bytes();
        IndexStats {
            n_vectors,
            memory_bytes,
            disk_bytes: None,
            index_type: "ivf",
            extra: None,
        }
    }
}
/// `Index` implementation: construction is delegated to
/// `IvfIndex::new_unconfigured`, with `IvfConfig` as the configuration type.
impl Index for IvfIndex {
type Config = IvfConfig;
/// Creates an empty, untrained index; see `new_unconfigured` for the
/// validation that is applied and the errors it can return.
fn new(dim: usize, metric: DistanceMetric, config: Self::Config) -> Result<Self> {
Self::new_unconfigured(dim, metric, config)
}
}