use serde::{Deserialize, Serialize};
/// Pruning strategy selector stored in `DiskAnnConfig::pruning_strategy`.
///
/// This file only names the variants; what each one does is decided by the
/// code that consumes the configuration. Among the presets defined in this
/// file, `speed_optimized` picks `Alpha` and the other three pick `Robust`;
/// no preset selects `Hybrid`.
///
/// NOTE(review): the derived `oxicode::Encode`/`Decode` (and serde) impls
/// presumably depend on variant names/order — confirm before renaming or
/// reordering variants, since that may break previously persisted configs.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
)]
pub enum PruningStrategy {
/// Used by the `speed_optimized` preset.
Alpha,
/// Used by the `default_config`, `memory_optimized` and `billion_scale` presets.
Robust,
/// Not selected by any preset in this file.
Hybrid,
}
/// Search mode selector stored in `DiskAnnConfig::search_mode`.
///
/// This file only names the variants; their runtime behavior lives in the
/// code that consumes the configuration. Among the presets defined in this
/// file, `speed_optimized` uses `InMemory`, `memory_optimized` and
/// `billion_scale` use `Streaming`, and `default_config` uses `Cached`.
///
/// NOTE(review): the derived `oxicode::Encode`/`Decode` (and serde) impls
/// presumably depend on variant names/order — confirm before renaming or
/// reordering variants, since that may break previously persisted configs.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
)]
pub enum SearchMode {
/// Used by the `speed_optimized` preset.
InMemory,
/// Used by the `memory_optimized` and `billion_scale` presets.
Streaming,
/// Used by the `default_config` preset.
Cached,
}
/// Tunable parameters for a DiskANN index.
///
/// Instances are normally created through one of the preset constructors
/// (`default_config`, `memory_optimized`, `speed_optimized`, `billion_scale`)
/// and checked with `validate`. The field notes below describe only what this
/// file itself does with each value; fields marked "not read in this file"
/// are interpreted by code elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct DiskAnnConfig {
/// Number of components per vector. `validate` rejects 0, and
/// `estimate_memory_usage` charges 4 bytes per component.
pub dimension: usize,
/// `validate` rejects 0; `estimate_memory_usage` charges 4 bytes per vector
/// for each of these slots (presumably the per-node neighbor cap — confirm).
pub max_degree: usize,
/// `validate` rejects 0. Otherwise not read in this file (presumably the
/// candidate-list width during index construction — confirm).
pub build_beam_width: usize,
/// `validate` rejects 0. Otherwise not read in this file (presumably the
/// candidate-list width during queries — confirm).
pub search_beam_width: usize,
/// `validate` rejects values `<= 0.0`; every preset in this file uses 1.2.
pub alpha: f32,
/// Pruning strategy selector; not read in this file.
pub pruning_strategy: PruningStrategy,
/// Search mode selector; not read in this file.
pub search_mode: SearchMode,
/// Upper bound on how many full-precision vectors `estimate_memory_usage`
/// counts as memory-resident; `None` means "all of them".
pub max_vectors_in_memory: Option<usize>,
/// When `true`, `validate` requires `pq_subvectors` and `pq_bits` to be set,
/// and `estimate_memory_usage` sizes vector storage from the PQ code size
/// instead of 4 bytes per component.
pub use_pq_compression: bool,
/// Number of PQ subvectors; when compression is enabled `validate` requires
/// it to divide `dimension` evenly.
pub pq_subvectors: Option<usize>,
/// Bits per PQ code; when compression is enabled `validate` requires a
/// value between 1 and 16 inclusive.
pub pq_bits: Option<u8>,
/// Not read in this file; only `speed_optimized` sets it to `true`.
pub enable_incremental_updates: bool,
/// `validate` rejects 0. Otherwise not read in this file.
pub num_entry_points: usize,
/// `validate` rejects 0; `estimate_memory_usage` adds this value unchanged
/// to its byte total.
pub io_buffer_size: usize,
}
impl DiskAnnConfig {
    /// Returns the general-purpose preset for vectors with `dimension` components.
    ///
    /// Robust pruning, cached search mode, no PQ compression, at most 100 000
    /// vectors counted as memory-resident, one entry point and a 1 MiB I/O buffer.
    pub fn default_config(dimension: usize) -> Self {
        Self {
            dimension,
            max_degree: 64,
            build_beam_width: 100,
            search_beam_width: 75,
            alpha: 1.2,
            pruning_strategy: PruningStrategy::Robust,
            search_mode: SearchMode::Cached,
            max_vectors_in_memory: Some(100_000),
            use_pq_compression: false,
            pq_subvectors: None,
            pq_bits: None,
            enable_incremental_updates: false,
            num_entry_points: 1,
            io_buffer_size: 1 << 20,
        }
    }

    /// Returns a preset that favors a small memory footprint.
    ///
    /// Lower degree and beam widths than [`Self::default_config`], streaming
    /// search mode, at most 10 000 memory-resident vectors, a 512 KiB I/O
    /// buffer, and PQ compression with 8-bit codes. The subvector count is
    /// `dimension / 16` whenever that is a non-zero divisor of `dimension`;
    /// otherwise the largest smaller divisor is used so that the result always
    /// satisfies the divisibility rule enforced by [`Self::validate`].
    pub fn memory_optimized(dimension: usize) -> Self {
        Self {
            dimension,
            max_degree: 32,
            build_beam_width: 75,
            search_beam_width: 50,
            alpha: 1.2,
            pruning_strategy: PruningStrategy::Robust,
            search_mode: SearchMode::Streaming,
            max_vectors_in_memory: Some(10_000),
            use_pq_compression: true,
            pq_subvectors: Some(Self::pq_subvector_count(dimension, 16)),
            pq_bits: Some(8),
            enable_incremental_updates: false,
            num_entry_points: 1,
            io_buffer_size: 512 * 1024,
        }
    }

    /// Returns a preset that favors query speed.
    ///
    /// Higher degree and beam widths than [`Self::default_config`], alpha
    /// pruning, in-memory search mode, up to 1 000 000 memory-resident
    /// vectors, no PQ compression, incremental updates enabled, four entry
    /// points and a 4 MiB I/O buffer.
    pub fn speed_optimized(dimension: usize) -> Self {
        Self {
            dimension,
            max_degree: 96,
            build_beam_width: 150,
            search_beam_width: 100,
            alpha: 1.2,
            pruning_strategy: PruningStrategy::Alpha,
            search_mode: SearchMode::InMemory,
            max_vectors_in_memory: Some(1_000_000),
            use_pq_compression: false,
            pq_subvectors: None,
            pq_bits: None,
            enable_incremental_updates: true,
            num_entry_points: 4,
            io_buffer_size: 4 << 20,
        }
    }

    /// Returns a preset aimed at very large datasets.
    ///
    /// Streaming search mode, at most 50 000 memory-resident vectors, eight
    /// entry points, a 2 MiB I/O buffer, and PQ compression with 8-bit codes.
    /// The subvector count is `dimension / 8` whenever that is a non-zero
    /// divisor of `dimension`; otherwise the largest smaller divisor is used so
    /// that the result always satisfies [`Self::validate`]'s divisibility rule.
    pub fn billion_scale(dimension: usize) -> Self {
        Self {
            dimension,
            max_degree: 64,
            build_beam_width: 100,
            search_beam_width: 64,
            alpha: 1.2,
            pruning_strategy: PruningStrategy::Robust,
            search_mode: SearchMode::Streaming,
            max_vectors_in_memory: Some(50_000),
            use_pq_compression: true,
            pq_subvectors: Some(Self::pq_subvector_count(dimension, 8)),
            pq_bits: Some(8),
            enable_incremental_updates: false,
            num_entry_points: 8,
            io_buffer_size: 2 << 20,
        }
    }

    /// Chooses a PQ subvector count for `dimension`, aiming at
    /// `dims_per_subvector` components per subvector.
    ///
    /// Returns the largest divisor of `dimension` that does not exceed
    /// `dimension / dims_per_subvector`, and never less than 1. Plain integer
    /// division alone would yield 0 for small dimensions and a non-divisor for
    /// dimensions that are not a multiple of the target size.
    /// `dims_per_subvector` must be non-zero (the presets pass 8 or 16).
    fn pq_subvector_count(dimension: usize, dims_per_subvector: usize) -> usize {
        let target = (dimension / dims_per_subvector).max(1);
        (1..=target)
            .rev()
            .find(|&count| dimension % count == 0)
            // 1 divides everything, so the search cannot actually come up empty.
            .unwrap_or(1)
    }

    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first violated rule:
    /// - `dimension`, `max_degree`, `build_beam_width`, `search_beam_width`,
    ///   `num_entry_points` and `io_buffer_size` must all be non-zero;
    /// - `alpha` must be a finite, strictly positive number;
    /// - when `use_pq_compression` is set, `pq_subvectors` and `pq_bits` must
    ///   be present, `pq_subvectors` must be non-zero and divide `dimension`,
    ///   and `pq_bits` must lie in `1..=16`.
    pub fn validate(&self) -> Result<(), String> {
        if self.dimension == 0 {
            return Err("Dimension must be greater than 0".to_string());
        }
        if self.max_degree == 0 {
            return Err("Max degree must be greater than 0".to_string());
        }
        if self.build_beam_width == 0 {
            return Err("Build beam width must be greater than 0".to_string());
        }
        if self.search_beam_width == 0 {
            return Err("Search beam width must be greater than 0".to_string());
        }
        // `NaN <= 0.0` is false, so non-finite values need their own check.
        if !self.alpha.is_finite() {
            return Err("Alpha must be a finite number".to_string());
        }
        if self.alpha <= 0.0 {
            return Err("Alpha must be positive".to_string());
        }
        if self.use_pq_compression {
            let pq_subvectors = self.pq_subvectors.ok_or_else(|| {
                "PQ subvectors must be specified when compression is enabled".to_string()
            })?;
            let pq_bits = self.pq_bits.ok_or_else(|| {
                "PQ bits must be specified when compression is enabled".to_string()
            })?;
            // Reject zero up front: the remainder below would panic on it.
            if pq_subvectors == 0 {
                return Err("PQ subvectors must be greater than 0".to_string());
            }
            if self.dimension % pq_subvectors != 0 {
                return Err(format!(
                    "Dimension {} must be divisible by PQ subvectors {}",
                    self.dimension, pq_subvectors
                ));
            }
            if pq_bits == 0 || pq_bits > 16 {
                return Err("PQ bits must be between 1 and 16".to_string());
            }
        }
        if self.num_entry_points == 0 {
            return Err("Number of entry points must be greater than 0".to_string());
        }
        if self.io_buffer_size == 0 {
            return Err("IO buffer size must be greater than 0".to_string());
        }
        Ok(())
    }

    /// Estimates the number of bytes needed to hold `num_vectors` vectors.
    ///
    /// The estimate is the sum of:
    /// - graph storage: 4 bytes per vector plus 4 bytes per `max_degree` slot;
    /// - vector storage: `pq_subvectors * ceil(pq_bits / 8)` bytes per vector
    ///   when PQ compression is enabled (falling back to `dimension / 8`
    ///   subvectors and 8 bits when unset), otherwise 4 bytes per component;
    /// - memory-resident full-precision vectors: 4 bytes per component for
    ///   `min(max_vectors_in_memory, num_vectors)` vectors;
    /// - `io_buffer_size`.
    ///
    /// All arithmetic saturates at `usize::MAX`, so oversized inputs produce a
    /// clamped estimate rather than an overflow panic or a wrapped value.
    pub fn estimate_memory_usage(&self, num_vectors: usize) -> usize {
        // Both 4-byte factors presumably reflect 32-bit ids / f32 components.
        const BYTES_PER_GRAPH_SLOT: usize = 4;
        const BYTES_PER_COMPONENT: usize = 4;

        let bytes_per_node = self
            .max_degree
            .saturating_mul(BYTES_PER_GRAPH_SLOT)
            .saturating_add(BYTES_PER_GRAPH_SLOT);
        let graph_memory = num_vectors.saturating_mul(bytes_per_node);

        let bytes_per_full_vector = self.dimension.saturating_mul(BYTES_PER_COMPONENT);

        let vector_memory = if self.use_pq_compression {
            let pq_subvectors = self.pq_subvectors.unwrap_or(self.dimension / 8);
            let pq_bits = self.pq_bits.unwrap_or(8);
            // Round the code width up to whole bytes.
            let bytes_per_code = (usize::from(pq_bits) + 7) / 8;
            num_vectors
                .saturating_mul(pq_subvectors)
                .saturating_mul(bytes_per_code)
        } else {
            num_vectors.saturating_mul(bytes_per_full_vector)
        };

        let inmem_vectors = self
            .max_vectors_in_memory
            .map_or(num_vectors, |cap| cap.min(num_vectors));
        let inmem_memory = inmem_vectors.saturating_mul(bytes_per_full_vector);

        graph_memory
            .saturating_add(vector_memory)
            .saturating_add(inmem_memory)
            .saturating_add(self.io_buffer_size)
    }
}
impl Default for DiskAnnConfig {
fn default() -> Self {
Self::default_config(128)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The general-purpose preset keeps the requested dimension and validates.
    #[test]
    fn test_default_config() {
        let cfg = DiskAnnConfig::default_config(128);

        assert_eq!(cfg.dimension, 128);
        assert_eq!(cfg.max_degree, 64);
        assert!(cfg.validate().is_ok());
    }

    /// The memory preset enables PQ compression and streaming search.
    #[test]
    fn test_memory_optimized() {
        let cfg = DiskAnnConfig::memory_optimized(256);

        assert_eq!(cfg.dimension, 256);
        assert!(cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::Streaming);
        assert!(cfg.validate().is_ok());
    }

    /// The speed preset keeps raw vectors and searches in memory.
    #[test]
    fn test_speed_optimized() {
        let cfg = DiskAnnConfig::speed_optimized(512);

        assert_eq!(cfg.dimension, 512);
        assert!(!cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::InMemory);
        assert!(cfg.validate().is_ok());
    }

    /// The large-scale preset enables PQ compression and streaming search.
    #[test]
    fn test_billion_scale() {
        let cfg = DiskAnnConfig::billion_scale(768);

        assert_eq!(cfg.dimension, 768);
        assert!(cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::Streaming);
        assert!(cfg.validate().is_ok());
    }

    /// Each broken field, applied to an otherwise valid baseline, is rejected.
    #[test]
    fn test_validation() {
        let baseline = DiskAnnConfig::default_config(128);
        assert!(baseline.validate().is_ok());

        let zero_dimension = DiskAnnConfig {
            dimension: 0,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(zero_dimension.validate().is_err());

        let zero_degree = DiskAnnConfig {
            max_degree: 0,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(zero_degree.validate().is_err());

        // Compression switched on without the PQ parameters being supplied.
        let pq_without_params = DiskAnnConfig {
            use_pq_compression: true,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(pq_without_params.validate().is_err());
    }

    /// The PQ preset is estimated to need less memory than the default one.
    #[test]
    fn test_memory_estimation() {
        let plain_bytes = DiskAnnConfig::default_config(128).estimate_memory_usage(1_000_000);
        assert!(plain_bytes > 0);

        let pq_bytes = DiskAnnConfig::memory_optimized(128).estimate_memory_usage(1_000_000);
        assert!(pq_bytes < plain_bytes);
    }

    /// PQ parameters must divide the dimension and use a supported bit width.
    #[test]
    fn test_pq_validation() {
        let with_pq = |subvectors: usize, bits: u8| DiskAnnConfig {
            use_pq_compression: true,
            pq_subvectors: Some(subvectors),
            pq_bits: Some(bits),
            ..DiskAnnConfig::default_config(128)
        };

        assert!(with_pq(16, 8).validate().is_ok());
        // 15 does not divide 128.
        assert!(with_pq(15, 8).validate().is_err());
        // Bit widths outside 1..=16 are rejected.
        assert!(with_pq(16, 0).validate().is_err());
        assert!(with_pq(16, 20).validate().is_err());
    }
}