#![warn(missing_debug_implementations, missing_docs)]
use diskann::ANNError;
use thiserror::Error;
/// Bundle of arguments for pivot generation.
///
/// Every field is private: the only way to obtain a value is
/// [`GeneratePivotArguments::new`], which rejects inconsistent inputs, and the
/// only way to read one back is the accessor method carrying the field's name.
#[derive(Debug, Clone)]
pub struct GeneratePivotArguments {
/// Number of training vectors; `new` rejects values above `i32::MAX`.
num_train: usize,
/// Vector dimension; `new` rejects values above `i32::MAX`.
dim: usize,
/// Stored as given, `new` does not validate it.
/// NOTE(review): presumably the number of k-means centers — confirm against the consumer.
num_centers: usize,
/// Number of chunks; `new` requires `0 < num_pq_chunks <= dim`.
num_pq_chunks: usize,
/// Stored as given, `new` does not validate it.
/// NOTE(review): presumably an upper bound on k-means repetitions — confirm against the consumer.
max_k_means_reps: usize,
/// Stored as given, `new` does not validate it.
/// NOTE(review): presumably toggles centering of the data before training — confirm against the consumer.
translate_to_center: bool,
}
#[derive(Error, Debug, PartialEq)]
#[non_exhaustive]
#[allow(missing_docs)]
pub enum GeneratePivotArgumentsError {
#[error("number of chunks {num_pq_chunks} more than dimension {dim}")]
NumChunksMoreThanDim { num_pq_chunks: usize, dim: usize },
#[error("invalid number of chunks 0 reatively to dimension")]
NumChunksIsZero,
#[error("vector dimension {0} is greater than i32::MAX_VALUE")]
DimGreaterThanI32MaxValue(usize),
#[error("number of vectors {0} is greater than i32::MAX_VALUE")]
NumTrainGreaterThanI32MaxValue(usize),
}
impl From<GeneratePivotArgumentsError> for ANNError {
#[track_caller]
fn from(value: GeneratePivotArgumentsError) -> Self {
ANNError::log_pq_error(value)
}
}
impl GeneratePivotArguments {
    /// Validates the inputs and, if they are consistent, bundles them.
    ///
    /// `num_centers`, `max_k_means_reps` and `translate_to_center` are stored
    /// without any validation.
    ///
    /// # Errors
    ///
    /// The checks run in the order below and the first one that fails decides
    /// the returned error:
    ///
    /// 1. [`GeneratePivotArgumentsError::NumChunksMoreThanDim`] when
    ///    `num_pq_chunks > dim`;
    /// 2. [`GeneratePivotArgumentsError::NumChunksIsZero`] when
    ///    `num_pq_chunks == 0`;
    /// 3. [`GeneratePivotArgumentsError::DimGreaterThanI32MaxValue`] when
    ///    `dim > i32::MAX`;
    /// 4. [`GeneratePivotArgumentsError::NumTrainGreaterThanI32MaxValue`] when
    ///    `num_train > i32::MAX`.
    pub fn new(
        num_train: usize,
        dim: usize,
        num_centers: usize,
        num_pq_chunks: usize,
        max_k_means_reps: usize,
        translate_to_center: bool,
    ) -> Result<Self, GeneratePivotArgumentsError> {
        // Both size limits compare against the same bound; name it once.
        let i32_limit = i32::MAX as usize;

        // Ordered chain: the first violated rule is the one reported.
        let rejection = if num_pq_chunks > dim {
            Some(GeneratePivotArgumentsError::NumChunksMoreThanDim { num_pq_chunks, dim })
        } else if num_pq_chunks == 0 {
            Some(GeneratePivotArgumentsError::NumChunksIsZero)
        } else if dim > i32_limit {
            Some(GeneratePivotArgumentsError::DimGreaterThanI32MaxValue(dim))
        } else if num_train > i32_limit {
            Some(GeneratePivotArgumentsError::NumTrainGreaterThanI32MaxValue(num_train))
        } else {
            None
        };

        match rejection {
            Some(reason) => Err(reason),
            None => Ok(Self {
                num_train,
                dim,
                num_centers,
                num_pq_chunks,
                max_k_means_reps,
                translate_to_center,
            }),
        }
    }

    /// Returns the `num_train` value accepted by [`Self::new`].
    pub fn num_train(&self) -> usize {
        self.num_train
    }

    /// Returns the `dim` value accepted by [`Self::new`].
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the `num_centers` value passed to [`Self::new`].
    pub fn num_centers(&self) -> usize {
        self.num_centers
    }

    /// Returns the `num_pq_chunks` value accepted by [`Self::new`].
    pub fn num_pq_chunks(&self) -> usize {
        self.num_pq_chunks
    }

    /// Returns the `max_k_means_reps` value passed to [`Self::new`].
    pub fn max_k_means_reps(&self) -> usize {
        self.max_k_means_reps
    }

    /// Returns the `translate_to_center` flag passed to [`Self::new`].
    pub fn translate_to_center(&self) -> bool {
        self.translate_to_center
    }
}
#[cfg(test)]
mod arguments_test {
    use diskann::{ANNErrorKind, ANNResult};

    use super::*;

    // Arbitrary values for the arguments that `new` stores without checking.
    const NUM_CENTERS: usize = 2;
    const MAX_K_MEANS_REPS: usize = 10;

    /// Calls `GeneratePivotArguments::new` with the three validated arguments
    /// chosen by the test and fixed values for everything else.
    fn build(
        num_train: usize,
        dim: usize,
        num_pq_chunks: usize,
    ) -> Result<GeneratePivotArguments, GeneratePivotArgumentsError> {
        GeneratePivotArguments::new(
            num_train,
            dim,
            NUM_CENTERS,
            num_pq_chunks,
            MAX_K_MEANS_REPS,
            true,
        )
    }

    #[test]
    fn num_chunks_exceeds_dim() {
        let dim = 5;
        let num_pq_chunks = dim + 1;

        assert_eq!(
            build(10, dim, num_pq_chunks).unwrap_err(),
            GeneratePivotArgumentsError::NumChunksMoreThanDim { num_pq_chunks, dim }
        );
    }

    #[test]
    fn num_chunks_is_zero() {
        assert_eq!(
            build(10, 5, 0).unwrap_err(),
            GeneratePivotArgumentsError::NumChunksIsZero
        );
    }

    #[test]
    fn num_dim_exceeds_i32_max() {
        let dim = i32::MAX as usize + 1;

        assert_eq!(
            build(10, dim, 2).unwrap_err(),
            GeneratePivotArgumentsError::DimGreaterThanI32MaxValue(dim)
        );
    }

    #[test]
    fn num_train_exceeds_i32_max() {
        let num_train = i32::MAX as usize + 1;

        assert_eq!(
            build(num_train, 5, 2).unwrap_err(),
            GeneratePivotArgumentsError::NumTrainGreaterThanI32MaxValue(num_train)
        );
    }

    /// Success path: every accessor hands back exactly what was passed in.
    #[test]
    fn valid_arguments_are_exposed_by_accessors() {
        let args = GeneratePivotArguments::new(10, 5, 4, 3, 7, false).unwrap();

        assert_eq!(args.num_train(), 10);
        assert_eq!(args.dim(), 5);
        assert_eq!(args.num_centers(), 4);
        assert_eq!(args.num_pq_chunks(), 3);
        assert_eq!(args.max_k_means_reps(), 7);
        assert!(!args.translate_to_center());
    }

    /// All limits are inclusive: values sitting exactly on them are accepted.
    #[test]
    fn boundary_values_are_accepted() {
        let limit = i32::MAX as usize;

        // num_pq_chunks == dim is allowed.
        assert!(build(10, 5, 5).is_ok());
        // dim == i32::MAX and num_train == i32::MAX are allowed.
        assert!(build(limit, limit, 2).is_ok());
    }

    /// When several rules are violated, the chunk-count checks are reported
    /// before the `i32::MAX` checks.
    #[test]
    fn first_failing_check_wins() {
        let dim = i32::MAX as usize + 1;
        let num_pq_chunks = dim + 1;

        assert_eq!(
            build(dim, dim, num_pq_chunks).unwrap_err(),
            GeneratePivotArgumentsError::NumChunksMoreThanDim { num_pq_chunks, dim }
        );
        assert_eq!(
            build(dim, dim, 0).unwrap_err(),
            GeneratePivotArgumentsError::NumChunksIsZero
        );
    }

    #[test]
    fn compatibility_with_ann_error_test() {
        let result = compatibility_helper();

        assert_eq!(result.unwrap_err().kind(), ANNErrorKind::PQError);
    }

    /// Propagates a `GeneratePivotArgumentsError` with `?` so the test above
    /// exercises the `From` conversion into `ANNError`.
    fn compatibility_helper() -> ANNResult<()> {
        let failure: Result<(), GeneratePivotArgumentsError> =
            Err(GeneratePivotArgumentsError::NumChunksIsZero);
        failure?;
        Ok(())
    }
}