use crate::storage::StorageReadProvider;
use diskann::{ANNError, ANNResult, error::IntoANNResult, utils::VectorRepr};
use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
use rand::Rng;
use rand_distr::{Distribution, StandardUniform};
use tracing::info;
use crate::utils::{
VectorDataIterator, load_metadata_from_file,
sampling::{SampleVectorReader, SamplingDensity},
};
/// Cap on the number of points to sample when estimating a medoid (per its name).
///
/// Not referenced in this chunk; presumably passed by callers as `max_sample_size`
/// to `find_medoid_with_sampling` — TODO confirm.
pub const MAX_MEDOID_SAMPLE_SIZE: usize = 50_000;
/// Computes the element-wise mean (centroid) of the vectors yielded by `iter`.
///
/// Each vector is decoded to `f32` with [`VectorRepr::as_f32`] before being
/// accumulated, so packed/quantized representations are averaged in decoded form.
///
/// * `dimension` - required length of every decoded `f32` vector; also the length
///   of the returned centroid.
/// * `num_points` - divisor for the mean; callers pass the number of vectors the
///   iterator yields. When it is `0`, the running sum (all zeros for an empty
///   iterator) is returned as-is instead of dividing by zero.
///
/// # Errors
///
/// Returns an error if a vector cannot be decoded to `f32`, or if a decoded
/// vector's length differs from `dimension`.
fn calculate_centroid<T, Iter>(
    iter: Iter,
    dimension: usize,
    num_points: usize,
) -> ANNResult<Vec<f32>>
where
    T: VectorRepr,
    Iter: Iterator<Item = (Box<[T]>, ())>,
{
    let mut sum = vec![0.0_f32; dimension];
    for (v, _) in iter {
        let vector = T::as_f32(&v).into_ann_result()?;
        if vector.len() != dimension {
            return Err(ANNError::log_index_error(
                "Vector f32 dimension doesn't match input dim.",
            ));
        }
        for (acc, x) in sum.iter_mut().zip(vector.iter()) {
            *acc += *x;
        }
    }

    if num_points == 0 {
        return Ok(sum);
    }
    let divisor = num_points as f32;
    Ok(sum.into_iter().map(|s| s / divisor).collect())
}
/// Estimates the centroid of the dataset stored at `path` from a random sample.
///
/// Every point index is kept independently with probability `sampling_rate`
/// (drawn from `rng`). The kept vectors are read through a [`SampleVectorReader`],
/// decoded to `f32`, and averaged element-wise.
///
/// The centroid length comes from [`VectorRepr::full_dimension`] of the first
/// sampled vector rather than the file header. For packed/quantized element
/// types the header dimension may count stored `T` elements rather than decoded
/// values.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the file cannot be opened or read;
/// * a vector cannot be decoded;
/// * a sampled vector decodes to a different length than the first one;
/// * no vector was sampled at all ("Trying to compute centroid on zero vectors").
fn calculate_centroid_with_sampling<T, Reader>(
    path: &str,
    reader: &Reader,
    sampling_rate: f64,
    rng: &mut impl Rng,
) -> ANNResult<Vec<f32>>
where
    T: VectorRepr,
    Reader: StorageReadProvider,
{
    let mut sample_reader = SampleVectorReader::<T, _>::new(
        path,
        SamplingDensity::from_sample_rate(sampling_rate),
        reader,
    )?;
    // Only the point count is needed; the accumulator is sized from decoded data.
    let (npts, _) = sample_reader.get_dataset_headers();

    // Bernoulli-sample point indices: each is kept with probability `sampling_rate`.
    let distribution = StandardUniform;
    let indices = (0..npts).filter(|_| {
        let p: f64 = distribution.sample(rng);
        p < sampling_rate
    });

    // Sized lazily on the first sampled vector; `vectors_processed == 0` means
    // the accumulator has not been sized yet.
    let mut sum: Vec<f32> = Vec::new();
    let mut vectors_processed: usize = 0;
    sample_reader.read_vectors(indices, |vector| {
        if vectors_processed == 0 {
            let full_dim = T::full_dimension(vector).into_ann_result()?;
            sum = vec![0.0f32; full_dim];
        }
        let f32_vector = T::as_f32(vector).into_ann_result()?;
        // Guard against out-of-bounds panics (longer) or silent truncation (shorter).
        if f32_vector.len() != sum.len() {
            return Err(ANNError::log_index_error(
                "Vector f32 dimension doesn't match input dim.",
            ));
        }
        for (acc, x) in sum.iter_mut().zip(f32_vector.iter()) {
            *acc += *x;
        }
        vectors_processed += 1;
        Ok(())
    })?;

    if vectors_processed == 0 {
        return Err(ANNError::log_index_error(
            "Trying to compute centroid on zero vectors",
        ));
    }
    let divisor = vectors_processed as f32;
    sum.iter_mut().for_each(|value| *value /= divisor);
    Ok(sum)
}
/// Returns the vector from `iter` closest to `centroid` in squared L2 distance,
/// together with its zero-based position in iteration order.
///
/// Vectors are decoded to `f32` with [`VectorRepr::as_f32`] before comparison.
/// Ties keep the earliest vector. The first vector is always a candidate, so a
/// non-empty iterator yields `Some` even when every distance is `+inf` (e.g. the
/// squared distance overflowed `f32`). A NaN distance never beats a real one.
///
/// Returns `Ok(None)` only when `iter` is empty.
///
/// # Errors
///
/// Returns an error if a vector cannot be decoded to `f32`.
pub fn find_nearest_vector_with_id<T, Iter>(
    iter: Iter,
    centroid: &[f32],
) -> ANNResult<Option<(Box<[T]>, usize)>>
where
    T: VectorRepr,
    Iter: Iterator<Item = (Box<[T]>, ())>,
{
    // Best candidate so far: (distance, vector, position).
    let mut best: Option<(f32, Box<[T]>, usize)> = None;
    for (id, (v, _)) in iter.enumerate() {
        // Scope the decoded view so `v` is no longer borrowed and can be moved below.
        let dist: f32 = {
            let vf32 = T::as_f32(&v).into_ann_result()?;
            SquaredL2::evaluate(centroid, vf32.as_ref())
        };
        let is_closer = match &best {
            None => true,
            Some((best_dist, _, _)) => {
                dist < *best_dist || (best_dist.is_nan() && !dist.is_nan())
            }
        };
        if is_closer {
            // `v` is owned here, so move it rather than cloning.
            best = Some((dist, v, id));
        }
    }
    Ok(best.map(|(_, v, id)| (v, id)))
}
/// Finds the medoid of the dataset stored at `path` using two full passes.
///
/// The first pass averages every vector (decoded to `f32`) into a centroid. The
/// second pass re-reads the file and returns the stored vector closest to that
/// centroid in squared L2 distance, together with its zero-based index.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the file cannot be read or decoded;
/// * the file holds no vectors ("Medoid not calculable on zero length iterator");
/// * no nearest vector is found.
pub fn find_medoid_from_file<T, Reader>(path: &str, reader: &Reader) -> ANNResult<(Vec<T>, usize)>
where
    T: VectorRepr,
    Reader: StorageReadProvider,
{
    let iter = VectorDataIterator::<Reader, T>::new(path, None, reader)?;
    let num_points = iter.get_num_points();
    let mut iter = iter.peekable();

    // The decoded dimension is taken from the first vector, since packed/quantized
    // element types may decode to a different length than they store.
    let full_dimension = match iter.peek() {
        Some((first, _)) => T::full_dimension(first).into_ann_result()?,
        None => {
            return Err(ANNError::log_index_error(
                "Medoid not calculable on zero length iterator",
            ));
        }
    };
    let centroid = calculate_centroid(iter, full_dimension, num_points)?;

    // Second pass over the same file with a fresh iterator.
    let iter = VectorDataIterator::<Reader, T>::new(path, None, reader)?;
    let (medoid, medoid_id) = find_nearest_vector_with_id(iter, &centroid)?
        .ok_or_else(|| ANNError::log_index_error("medoid not found"))?;
    // `into_vec` reuses the boxed allocation instead of copying it.
    Ok((medoid.into_vec(), medoid_id))
}
/// Finds an approximate medoid of the dataset at `path`, estimating the centroid
/// from a random sample.
///
/// The centroid is averaged over points kept with probability
/// `max_sample_size / npoints`, so about `max_sample_size` points are used. Every
/// point is used when `max_sample_size` is `0` or at least the number of points.
/// The medoid is the stored vector closest to that centroid in squared L2
/// distance, found by a full scan of the file. It is returned together with its
/// zero-based index.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the file metadata or vectors cannot be read or decoded;
/// * the random sample is empty;
/// * no nearest vector is found.
pub fn find_medoid_with_sampling<T, Reader>(
    path: &str,
    reader: &Reader,
    max_sample_size: usize,
    rng: &mut impl Rng,
) -> ANNResult<(Vec<T>, usize)>
where
    T: VectorRepr,
    Reader: StorageReadProvider,
{
    let metadata = load_metadata_from_file(reader, path)?;
    let npoints = metadata.npoints();
    // A zero cap means "no cap"; the rate never exceeds 1.
    let sampling_rate = if max_sample_size == 0 || max_sample_size >= npoints {
        1.0
    } else {
        max_sample_size as f64 / npoints as f64
    };
    info!(
        "Finding medoid from {} points with max_sample_size: {}, sampling_rate: {:.2}",
        npoints, max_sample_size, sampling_rate
    );

    let centroid = calculate_centroid_with_sampling::<T, _>(path, reader, sampling_rate, rng)?;
    let iter = VectorDataIterator::<Reader, T>::new(path, None, reader)?;
    let (medoid, medoid_id) = find_nearest_vector_with_id(iter, &centroid)?
        .ok_or_else(|| ANNError::log_index_error("medoid not found"))?;
    // `into_vec` reuses the boxed allocation instead of copying it.
    Ok((medoid.into_vec(), medoid_id))
}
#[cfg(test)]
mod tests {
    use std::{io::Write, num::NonZeroUsize};

    use crate::storage::VirtualStorageProvider;
    use diskann::utils::VectorRepr;
    use diskann_quantization::{
        CompressInto,
        algorithms::{Transform, transforms::NullTransform},
        minmax::{DataMutRef, MinMaxQuantizer},
        num::Positive,
    };
    use diskann_utils::{ReborrowMut, io::Metadata};
    use rand::{SeedableRng, rngs::StdRng};
    use vfs::{FileSystem, MemoryFS};

    use super::*;
    use crate::common::MinMaxElement;

    /// Writes `vectors` to `path` as a `Metadata` header (point count, dimension of
    /// the first vector) followed by the raw bytes of each vector.
    fn create_test_vector_file<T: VectorRepr>(
        filesystem: &MemoryFS,
        path: &str,
        vectors: &[Vec<T>],
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut file = filesystem.create_file(path)?;
        let num_points = vectors.len();
        let dimension = if vectors.is_empty() {
            0
        } else {
            vectors[0].len()
        };
        Metadata::new(num_points, dimension)
            .unwrap()
            .write(&mut file)?;
        for vector in vectors {
            let bytes = bytemuck::cast_slice(vector);
            file.write_all(bytes)?;
        }
        Ok(())
    }

    /// Four 3-dimensional vectors whose mean is `[3.5, 4.5, 5.5]`.
    fn create_f32_test_vectors() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
            vec![2.0, 3.0, 4.0],
        ]
    }

    /// The vectors from [`create_f32_test_vectors`], compressed with an 8-bit
    /// min-max quantizer.
    fn create_minmax_test_vectors() -> Result<Vec<Vec<MinMaxElement<8>>>, Box<dyn std::error::Error>>
    {
        let f32_vectors = create_f32_test_vectors();
        let mut minmax_vectors = Vec::new();
        for vector in f32_vectors {
            let transform =
                Transform::Null(NullTransform::new(NonZeroUsize::new(vector.len()).unwrap()));
            let quantizer = MinMaxQuantizer::new(transform, Positive::new(1.0).unwrap());
            let mut bytes =
                vec![
                    0_u8;
                    diskann_quantization::minmax::DataRef::<8>::canonical_bytes(vector.len())
                ];
            let mut compressed =
                DataMutRef::<8>::from_canonical_front_mut(&mut bytes, vector.len()).unwrap();
            quantizer
                .compress_into(vector.as_slice(), compressed.reborrow_mut())
                .unwrap();
            let minmax_vector: Vec<MinMaxElement<8>> = bytemuck::cast_slice(&bytes).to_vec();
            minmax_vectors.push(minmax_vector);
        }
        Ok(minmax_vectors)
    }

    #[test]
    fn test_calculate_centroid_basic() {
        let vectors = vec![
            (vec![1.0f32, 2.0, 3.0].into_boxed_slice(), ()),
            (vec![4.0f32, 5.0, 6.0].into_boxed_slice(), ()),
            (vec![7.0f32, 8.0, 9.0].into_boxed_slice(), ()),
        ];
        let centroid = calculate_centroid(vectors.into_iter(), 3, 3).unwrap();
        assert_eq!(centroid, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_calculate_centroid_empty_iterator() {
        let vectors: Vec<(Box<[f32]>, ())> = vec![];
        let centroid = calculate_centroid(vectors.into_iter(), 3, 0).unwrap();
        assert_eq!(centroid, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_calculate_centroid_with_minmax_success() {
        let minmax_vectors = create_minmax_test_vectors().unwrap();
        let vectors: Vec<(Box<[MinMaxElement<8>]>, ())> = minmax_vectors
            .into_iter()
            .map(|v| (v.into_boxed_slice(), ()))
            .collect();
        let dimension = MinMaxElement::full_dimension(&vectors[0].0).unwrap();
        let centroid = calculate_centroid(vectors.into_iter(), dimension, 4).unwrap();
        assert_eq!(centroid.len(), 3);
        assert!((centroid[0] - 3.5).abs() < 1e-2);
        assert!((centroid[1] - 4.5).abs() < 1e-2);
        assert!((centroid[2] - 5.5).abs() < 1e-2);
    }

    #[test]
    fn test_find_nearest_vector_with_id_basic() {
        let vectors = vec![
            (vec![1.0f32, 2.0, 3.0].into_boxed_slice(), ()),
            (vec![4.0f32, 5.0, 6.0].into_boxed_slice(), ()),
            (vec![7.0f32, 8.0, 9.0].into_boxed_slice(), ()),
        ];
        let centroid = vec![4.5, 5.5, 6.5];
        let result = find_nearest_vector_with_id(vectors.into_iter(), &centroid).unwrap();
        assert!(result.is_some());
        let (nearest_vector, nearest_id) = result.unwrap();
        assert_eq!(nearest_id, 1);
        assert_eq!(nearest_vector.as_ref(), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_find_nearest_vector_with_id_empty_iterator() {
        let vectors: Vec<(Box<[f32]>, ())> = vec![];
        let centroid = vec![1.0, 2.0, 3.0];
        let result = find_nearest_vector_with_id(vectors.into_iter(), &centroid).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_find_nearest_vector_with_minmax_success() {
        let minmax_vectors = create_minmax_test_vectors().unwrap();
        let vectors: Vec<(Box<[MinMaxElement<8>]>, ())> = minmax_vectors
            .into_iter()
            .map(|v| (v.into_boxed_slice(), ()))
            .collect();
        let centroid = vec![3.5, 4.5, 5.5];
        let result = find_nearest_vector_with_id(vectors.into_iter(), &centroid).unwrap();
        assert!(result.is_some());
        let (_, nearest_id) = result.unwrap();
        assert_eq!(nearest_id, 1);
    }

    #[test]
    fn test_find_medoid_from_file_basic() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let vectors = create_f32_test_vectors();
        create_test_vector_file(storage_provider.filesystem(), "/test_vectors.bin", &vectors)
            .unwrap();
        let result = find_medoid_from_file::<f32, _>("/test_vectors.bin", &storage_provider);
        assert!(result.is_ok());
        let (medoid, medoid_id) = result.unwrap();
        assert_eq!(medoid.len(), 3);
        assert_eq!(medoid_id, 1);
    }

    #[test]
    fn test_find_medoid_from_file_empty_file() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let vectors: Vec<Vec<f32>> = vec![];
        create_test_vector_file(
            storage_provider.filesystem(),
            "/empty_vectors.bin",
            &vectors,
        )
        .unwrap();
        let result = find_medoid_from_file::<f32, _>("/empty_vectors.bin", &storage_provider);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("zero length iterator")
        );
    }

    #[test]
    fn test_find_medoid_with_sampling_basic() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let vectors = create_f32_test_vectors();
        create_test_vector_file(storage_provider.filesystem(), "/test_vectors.bin", &vectors)
            .unwrap();
        let mut rng = StdRng::seed_from_u64(12345);
        let result = find_medoid_with_sampling::<f32, _>(
            "/test_vectors.bin",
            &storage_provider,
            2,
            &mut rng,
        );
        assert!(result.is_ok());
        let (medoid, medoid_id) = result.unwrap();
        assert_eq!(medoid.len(), 3);
        assert!(medoid_id < 4);
    }

    #[test]
    fn test_find_medoid_with_sampling_no_sampling() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let vectors = create_f32_test_vectors();
        create_test_vector_file(storage_provider.filesystem(), "/test_vectors.bin", &vectors)
            .unwrap();
        let mut rng = StdRng::seed_from_u64(12345);
        // A zero cap disables sampling, so the result must match the exact medoid.
        let result = find_medoid_with_sampling::<f32, _>(
            "/test_vectors.bin",
            &storage_provider,
            0,
            &mut rng,
        );
        assert!(result.is_ok());
        let (medoid, medoid_id) = result.unwrap();
        assert_eq!(medoid.len(), 3);
        assert_eq!(medoid_id, 1);
    }

    #[test]
    fn test_calculate_centroid_with_sampling_empty_vectors() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let vectors: Vec<Vec<f32>> = vec![];
        create_test_vector_file(
            storage_provider.filesystem(),
            "/empty_vectors.bin",
            &vectors,
        )
        .unwrap();
        let mut rng = StdRng::seed_from_u64(12345);
        let result = calculate_centroid_with_sampling::<f32, _>(
            "/empty_vectors.bin",
            &storage_provider,
            1.0,
            &mut rng,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("zero vectors"));
    }

    #[test]
    fn test_calculate_centroid_with_sampling_basic() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let vectors = create_f32_test_vectors();
        create_test_vector_file(storage_provider.filesystem(), "/test_vectors.bin", &vectors)
            .unwrap();
        let mut rng = StdRng::seed_from_u64(12345);
        let result = calculate_centroid_with_sampling::<f32, _>(
            "/test_vectors.bin",
            &storage_provider,
            1.0,
            &mut rng,
        );
        assert!(result.is_ok());
        let centroid = result.unwrap();
        assert_eq!(centroid.len(), 3);
        assert!(centroid.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_find_medoid_with_minmax_vectors() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let minmax_vectors = create_minmax_test_vectors().unwrap();
        create_test_vector_file(
            storage_provider.filesystem(),
            "/minmax_vectors.bin",
            &minmax_vectors,
        )
        .unwrap();
        let result =
            find_medoid_from_file::<MinMaxElement<8>, _>("/minmax_vectors.bin", &storage_provider);
        assert!(result.is_ok());
        let (medoid, medoid_id) = result.unwrap();
        assert!(!medoid.is_empty());
        assert_eq!(medoid_id, 1);
    }

    #[test]
    fn test_calculate_centroid_with_sampling_zero_sampling_rate() {
        let storage_provider = VirtualStorageProvider::new_memory();
        let vectors = create_f32_test_vectors();
        create_test_vector_file(storage_provider.filesystem(), "/test_vectors.bin", &vectors)
            .unwrap();
        let mut rng = StdRng::seed_from_u64(12345);
        let result = calculate_centroid_with_sampling::<f32, _>(
            "/test_vectors.bin",
            &storage_provider,
            0.001,
            &mut rng,
        );
        // With such a low rate the sample is almost surely empty; a lucky
        // non-empty sample is tolerated, but an error must be the empty-sample one.
        if let Err(err) = result {
            assert!(err.to_string().contains("zero vectors"));
        }
    }

    #[test]
    fn test_error_handling_with_corrupted_minmax_data() {
        let corrupted_vectors = vec![vec![MinMaxElement::<8>::default(); 2]];
        let storage_provider = VirtualStorageProvider::new_memory();
        create_test_vector_file(
            storage_provider.filesystem(),
            "/corrupted.bin",
            &corrupted_vectors,
        )
        .unwrap();
        let result =
            find_medoid_from_file::<MinMaxElement<8>, _>("/corrupted.bin", &storage_provider);
        assert!(result.is_err());
    }
}