use iqdb_types::{IqdbError, Result};
// Values used by `IvfConfig::default()`. The unit tests in this file pin each
// of them, so changing a value here requires updating those tests as well.

// Default for `IvfConfig::n_clusters`.
const DEFAULT_N_CLUSTERS: usize = 256;
// Default for `IvfConfig::n_probes`; must stay within `1..=DEFAULT_N_CLUSTERS`
// or the default configuration would fail `IvfConfig::validate`.
const DEFAULT_N_PROBES: usize = 8;
// Default for `IvfConfig::training_sample_size` (65_536 == 2^16).
const DEFAULT_TRAINING_SAMPLE_SIZE: usize = 65_536;
// Default for `IvfConfig::seed`. Presumably seeds the RNG used during
// training/sampling so runs are reproducible -- confirm against the consumer.
const DEFAULT_SEED: u64 = 0xDEAD_BEEF_CAFE_F00D;
// Default for `IvfConfig::pq_refine_factor`.
const DEFAULT_PQ_REFINE_FACTOR: u32 = 4;
/// Tunable parameters for an IVF index.
///
/// NOTE(review): "IVF" and "PQ" presumably stand for inverted-file index and
/// product quantization; this file only stores and validates the numbers, so
/// confirm the exact semantics against the code that consumes this config.
///
/// The type is a small `Copy` value. Start from `IvfConfig::default()`, adjust
/// fields with the `with_*` methods, and call `validate` before use: nothing
/// stops an invalid combination from being constructed, because every field
/// is public and the `with_*` methods perform no checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IvfConfig {
/// Number of clusters. `validate` rejects zero.
pub n_clusters: usize,
/// Number of probes (presumably clusters scanned per query -- confirm).
/// `validate` requires `1 <= n_probes <= n_clusters`.
pub n_probes: usize,
/// Size of the training sample. `validate` rejects zero.
pub training_sample_size: usize,
/// Enables PQ. When `true`, `validate` requires `pq_subvectors` to be
/// `Some(m)` with `m >= 1`.
pub use_pq: bool,
/// Number of PQ subvectors. Only checked by `validate` when `use_pq` is
/// `true`; any value is accepted otherwise.
pub pq_subvectors: Option<usize>,
/// PQ refine factor. Not checked by `validate`; zero is accepted.
pub pq_refine_factor: u32,
/// Seed value (presumably for the training RNG -- confirm). Not checked by
/// `validate`.
pub seed: u64,
}
impl IvfConfig {
    /// Returns a copy of this configuration with `n_clusters` replaced.
    ///
    /// No checking happens here; call [`IvfConfig::validate`] on the final value.
    #[must_use]
    pub fn with_n_clusters(self, n_clusters: usize) -> Self {
        Self { n_clusters, ..self }
    }

    /// Returns a copy of this configuration with `n_probes` replaced.
    #[must_use]
    pub fn with_n_probes(self, n_probes: usize) -> Self {
        Self { n_probes, ..self }
    }

    /// Returns a copy of this configuration with `training_sample_size` replaced.
    #[must_use]
    pub fn with_training_sample_size(self, training_sample_size: usize) -> Self {
        Self {
            training_sample_size,
            ..self
        }
    }

    /// Returns a copy of this configuration with `use_pq` replaced.
    ///
    /// Enabling PQ does not set `pq_subvectors`; supply it separately via
    /// [`IvfConfig::with_pq_subvectors`] or validation will fail.
    #[must_use]
    pub fn with_use_pq(self, use_pq: bool) -> Self {
        Self { use_pq, ..self }
    }

    /// Returns a copy of this configuration with `pq_subvectors` replaced.
    #[must_use]
    pub fn with_pq_subvectors(self, pq_subvectors: Option<usize>) -> Self {
        Self {
            pq_subvectors,
            ..self
        }
    }

    /// Returns a copy of this configuration with `pq_refine_factor` replaced.
    #[must_use]
    pub fn with_pq_refine_factor(self, pq_refine_factor: u32) -> Self {
        Self {
            pq_refine_factor,
            ..self
        }
    }

    /// Returns a copy of this configuration with `seed` replaced.
    #[must_use]
    pub fn with_seed(self, seed: u64) -> Self {
        Self { seed, ..self }
    }

    /// Checks that the fields form a usable combination.
    ///
    /// Rules are evaluated in a fixed order and only the first violation is
    /// reported. `pq_refine_factor` and `seed` are never inspected, and
    /// `pq_subvectors` is ignored while `use_pq` is `false`.
    ///
    /// # Errors
    ///
    /// Returns `IqdbError::InvalidConfig` when, in order of precedence:
    /// - `n_clusters` is zero;
    /// - `n_probes` is zero;
    /// - `n_probes` exceeds `n_clusters`;
    /// - `training_sample_size` is zero;
    /// - `use_pq` is `true` and `pq_subvectors` is `None` or `Some(0)`.
    pub fn validate(&self) -> Result<()> {
        // Find the first broken rule (if any); the error is built once below.
        let violation = if self.n_clusters == 0 {
            Some("IvfConfig.n_clusters must be greater than zero")
        } else if self.n_probes == 0 {
            Some("IvfConfig.n_probes must be greater than zero")
        } else if self.n_probes > self.n_clusters {
            Some("IvfConfig.n_probes must be <= n_clusters")
        } else if self.training_sample_size == 0 {
            Some("IvfConfig.training_sample_size must be greater than zero")
        } else if !self.use_pq {
            // The PQ fields only matter once PQ is switched on.
            None
        } else {
            match self.pq_subvectors {
                None => Some("IvfConfig.use_pq = true requires pq_subvectors = Some(_)"),
                Some(0) => Some("IvfConfig.pq_subvectors must be >= 1 when use_pq = true"),
                Some(_) => None,
            }
        };

        match violation {
            Some(reason) => Err(IqdbError::InvalidConfig { reason }),
            None => Ok(()),
        }
    }
}
impl Default for IvfConfig {
fn default() -> Self {
Self {
n_clusters: DEFAULT_N_CLUSTERS,
n_probes: DEFAULT_N_PROBES,
training_sample_size: DEFAULT_TRAINING_SAMPLE_SIZE,
use_pq: false,
pq_subvectors: None,
pq_refine_factor: DEFAULT_PQ_REFINE_FACTOR,
seed: DEFAULT_SEED,
}
}
}
#[cfg(test)]
mod tests {
    // Unwrapping is fine in tests: a panic is exactly how a failure should surface.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Validates `cfg`, expecting a failure, and returns the `reason` carried
    /// by the resulting `InvalidConfig` error. Panics on any other outcome.
    fn rejection_reason(cfg: IvfConfig) -> &'static str {
        match cfg.validate().unwrap_err() {
            IqdbError::InvalidConfig { reason } => reason,
            other => panic!("expected InvalidConfig, got {other:?}"),
        }
    }

    /// Pins the literal default values so an accidental change is caught.
    #[test]
    fn default_values_are_the_documented_operating_point() {
        let IvfConfig {
            n_clusters,
            n_probes,
            training_sample_size,
            use_pq,
            pq_subvectors,
            pq_refine_factor,
            seed,
        } = IvfConfig::default();
        assert_eq!(n_clusters, 256);
        assert_eq!(n_probes, 8);
        assert_eq!(training_sample_size, 65_536);
        assert!(!use_pq);
        assert_eq!(pq_subvectors, None);
        assert_eq!(pq_refine_factor, 4);
        assert_eq!(seed, 0xDEAD_BEEF_CAFE_F00D);
    }

    /// Chained `with_*` calls each take effect on the final value.
    #[test]
    fn with_helpers_compose() {
        let base = IvfConfig::default();
        let cfg = base
            .with_n_clusters(16)
            .with_n_probes(4)
            .with_training_sample_size(1_024)
            .with_seed(42);
        assert_eq!(cfg.n_clusters, 16);
        assert_eq!(cfg.n_probes, 4);
        assert_eq!(cfg.training_sample_size, 1_024);
        assert_eq!(cfg.seed, 42);
    }

    #[test]
    fn validate_accepts_defaults() {
        let outcome = IvfConfig::default().validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_rejects_zero_n_clusters() {
        let reason = rejection_reason(IvfConfig::default().with_n_clusters(0));
        assert!(reason.contains("n_clusters"));
    }

    #[test]
    fn validate_rejects_zero_n_probes() {
        let reason = rejection_reason(IvfConfig::default().with_n_probes(0));
        assert!(reason.contains("n_probes"));
    }

    #[test]
    fn validate_rejects_n_probes_exceeding_n_clusters() {
        let cfg = IvfConfig::default().with_n_clusters(4).with_n_probes(8);
        let reason = rejection_reason(cfg);
        assert!(reason.contains("n_probes"));
    }

    #[test]
    fn validate_rejects_zero_training_sample_size() {
        let reason = rejection_reason(IvfConfig::default().with_training_sample_size(0));
        assert!(reason.contains("training_sample_size"));
    }

    #[test]
    fn validate_rejects_use_pq_true_without_pq_subvectors() {
        let reason = rejection_reason(IvfConfig::default().with_use_pq(true));
        assert!(reason.contains("pq_subvectors"));
        assert!(reason.contains("Some"));
    }

    #[test]
    fn validate_rejects_use_pq_true_with_zero_pq_subvectors() {
        let cfg = IvfConfig::default()
            .with_use_pq(true)
            .with_pq_subvectors(Some(0));
        let reason = rejection_reason(cfg);
        assert!(reason.contains("pq_subvectors"));
        assert!(reason.contains(">= 1"));
    }

    #[test]
    fn validate_accepts_use_pq_true_with_valid_pq_subvectors() {
        let outcome = IvfConfig::default()
            .with_use_pq(true)
            .with_pq_subvectors(Some(8))
            .validate();
        assert!(outcome.is_ok());
    }

    /// A zero refine factor is deliberately allowed by `validate`.
    #[test]
    fn validate_accepts_pq_refine_factor_zero() {
        let outcome = IvfConfig::default()
            .with_use_pq(true)
            .with_pq_subvectors(Some(8))
            .with_pq_refine_factor(0)
            .validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn with_pq_refine_factor_sets_field() {
        let updated = IvfConfig::default().with_pq_refine_factor(16);
        assert_eq!(updated.pq_refine_factor, 16);
    }
}