use std::{cell::UnsafeCell, mem, num::NonZeroUsize, ops::Deref, slice, sync::Arc};
use crate::storage::{StorageReadProvider, StorageWriteProvider};
use arc_swap::Guard;
use diskann::{ANNError, ANNResult, always_escalate, utils::IntoUsize};
use diskann_utils::future::AsyncFriendly;
use diskann_vector::distance::Metric;
use crate::{
model::graph::provider::async_::{TableDeleteProviderAsync, postprocess},
storage::{AsyncIndexMetadata, AsyncQuantLoadContext, LoadWith, SaveWith},
};
/// Contiguous id range `[start, end)` reserved for frozen/start points,
/// placed immediately after the valid points.
pub struct StartPoints {
    // First frozen-point id (== number of valid points).
    start: u32,
    // One past the last frozen-point id.
    end: u32,
}
impl StartPoints {
    /// Builds the range of `frozen_points` ids beginning at `valid_points`.
    ///
    /// # Errors
    /// Returns an index error when `valid_points + frozen_points` does not
    /// fit in `u32`.
    pub fn new(valid_points: u32, frozen_points: NonZeroUsize) -> ANNResult<Self> {
        let overflow = || {
            ANNError::log_index_error("Sum of valid points and frozen points exceeds u32::MAX")
        };
        // `u32::try_from` instead of `as u32`: the cast silently truncates
        // a `frozen_points` above u32::MAX on 64-bit targets, which would
        // let an out-of-range sum slip past the `checked_add` guard below.
        let frozen = u32::try_from(frozen_points.get()).map_err(|_| overflow())?;
        let end = valid_points.checked_add(frozen).ok_or_else(overflow)?;
        Ok(Self {
            start: valid_points,
            end,
        })
    }
    /// The frozen-point ids as a half-open range.
    pub fn range(&self) -> std::ops::Range<u32> {
        self.start..self.end
    }
    /// Number of frozen points in the range.
    pub fn len(&self) -> usize {
        // `end >= start` is guaranteed by construction in `new`.
        (self.end - self.start).into_usize()
    }
    /// True when the range holds no points (never the case for a value
    /// built by `new`, since `frozen_points` is non-zero).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// First frozen-point id.
    pub fn start(&self) -> u32 {
        self.start
    }
    /// One past the last frozen-point id.
    pub fn end(&self) -> u32 {
        self.end
    }
}
/// Read guard over an `arc_swap` snapshot of a vector, exposing the
/// contents as a plain `&[T]` via `Deref`.
pub struct VectorGuard<T> {
    guard: Guard<Arc<Vec<T>>>,
}

impl<T> VectorGuard<T> {
    /// Wraps a raw `arc_swap` guard.
    pub(crate) fn from_guard(guard: Guard<Arc<Vec<T>>>) -> Self {
        VectorGuard { guard }
    }
}

impl<T> Deref for VectorGuard<T> {
    type Target = [T];

    fn deref(&self) -> &[T] {
        // Guard -> Arc -> Vec, then borrow the elements as a slice.
        self.guard.as_slice()
    }
}
/// Flat vector store whose per-vector rows start on 64-byte boundaries.
///
/// Backed by one `Vec<T>` allocated with enough slack that `start_index`
/// can skip to the first 64-byte-aligned element; each row then spans
/// `padded_vector_dim` elements so every subsequent row stays aligned.
pub struct AlignedMemoryVectorStore<T: bytemuck::Pod> {
    // Interior-mutable flat buffer; `UnsafeCell` lets `get_mut_slice`
    // hand out `&mut [T]` from `&self` (callers must ensure exclusivity).
    store: UnsafeCell<Vec<T>>,
    // Maximum number of vectors this store can hold.
    max_vectors: usize,
    // Element offset of the first 64-byte-aligned position in `store`.
    start_index: usize,
    // Logical (unpadded) dimensionality of each vector.
    dim: usize,
    // `dim` rounded up so each row's byte size is a multiple of 64.
    padded_vector_dim: usize,
}
// SAFETY: `UnsafeCell` suppresses the auto `Sync`/`Send` impls. Sharing is
// claimed sound here because all mutation goes through the `unsafe fn`
// `get_mut_slice`, whose callers must guarantee exclusive, non-aliasing
// access to each row — NOTE(review): that caller-side discipline is the
// entire safety argument; confirm call sites uphold it.
unsafe impl<T: bytemuck::Pod + Sync> Sync for AlignedMemoryVectorStore<T> {}
// SAFETY: the store owns its `Vec<T>` outright, so moving it across
// threads is sound whenever `T: Send`.
unsafe impl<T: bytemuck::Pod + Send> Send for AlignedMemoryVectorStore<T> {}
impl<T: bytemuck::Pod> AlignedMemoryVectorStore<T> {
    /// Allocates a zeroed store for `max_vectors` vectors of `dim`
    /// elements, padding each row so it begins on a 64-byte boundary.
    ///
    /// # Panics
    /// Panics if `size_of::<T>()` does not divide 64; the padding math
    /// below relies on whole elements tiling the 64-byte alignment unit.
    pub fn with_capacity(max_vectors: usize, dim: usize) -> Self {
        let elem_size = mem::size_of::<T>();
        // Padding is computed in whole elements, so the element size must
        // divide the 64-byte alignment target evenly.
        assert!(64 % elem_size == 0);
        let vector_size = elem_size * dim;
        let extra_size = vector_size % 64;
        let padded_vector_dim = if extra_size == 0 {
            dim
        } else {
            // Round the row's byte size up to the next multiple of 64.
            let padding_needed_size = 64 - extra_size;
            assert!(padding_needed_size.is_multiple_of(elem_size));
            let extra_elems = padding_needed_size / elem_size;
            dim + extra_elems
        };
        assert!((padded_vector_dim * elem_size).is_multiple_of(64));
        // Extra trailing slack so the first row can be shifted forward to
        // a 64-byte-aligned address inside the allocation; at most
        // 64/elem_size - 1 elements are needed to reach alignment.
        let last_elems: usize = 64 / elem_size - 1;
        let count = max_vectors * padded_vector_dim + last_elems;
        let mut store: UnsafeCell<Vec<T>> =
            UnsafeCell::new(vec![<T as bytemuck::Zeroable>::zeroed(); count]);
        // Element offset of the first 64-byte-aligned position; the slack
        // above guarantees `max_vectors` full rows still fit after it.
        let start_index = store.get_mut().as_ptr().align_offset(64);
        Self {
            store,
            max_vectors,
            start_index,
            dim,
            padded_vector_dim,
        }
    }
    /// Maximum number of vectors this store can hold.
    pub fn max_vectors(&self) -> usize {
        self.max_vectors
    }
    /// Logical (unpadded) dimensionality of each stored vector.
    pub fn dim(&self) -> usize {
        self.dim
    }
    /// Returns the `dim`-element vector at `index` as a shared slice.
    ///
    /// # Panics
    /// Panics if `index >= max_vectors`.
    ///
    /// # Safety
    /// The caller must guarantee that no `get_mut_slice` borrow aliasing
    /// this row is live while the returned slice is in use.
    #[inline(always)]
    pub unsafe fn get_slice(&self, index: usize) -> &[T] {
        assert!(
            index < self.max_vectors,
            "index ({}) exceeded max_vectors ({})",
            index,
            self.max_vectors
        );
        // Row start = aligned base offset + row number * padded row width.
        let index = index * self.padded_vector_dim + self.start_index;
        let buf = unsafe { (*self.store.get()).as_ptr() };
        // Only the logical `dim` elements are exposed; padding stays hidden.
        unsafe { slice::from_raw_parts(buf.add(index), self.dim) }
    }
    /// Returns the `dim`-element vector at `index` as a mutable slice.
    ///
    /// # Panics
    /// Panics if `index >= max_vectors`.
    ///
    /// # Safety
    /// Takes `&self` but yields `&mut [T]`: the caller must guarantee this
    /// row is not concurrently read or written through any other slice
    /// obtained from this store while the returned borrow is live.
    #[allow(clippy::mut_from_ref)]
    pub unsafe fn get_mut_slice(&self, index: usize) -> &mut [T] {
        assert!(index < self.max_vectors);
        let index = index * self.padded_vector_dim + self.start_index;
        unsafe {
            let buf = (*self.store.get()).as_mut_ptr();
            slice::from_raw_parts_mut(buf.add(index), self.dim)
        }
    }
}
/// Prefetch depth expressed as a number of cache lines, with `All`
/// covering everything — NOTE(review): semantics inferred from the variant
/// names; confirm where this level is consumed.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum PrefetchCacheLineLevel {
    CacheLine4,
    CacheLine8,
    // Default prefetch depth.
    #[default]
    CacheLine16,
    All,
}
/// Marker type for a not-yet-seeded state — NOTE(review): meaning inferred
/// from the name; confirm against usage sites outside this chunk.
#[derive(Debug, Clone, Copy)]
pub struct Unseeded;
/// Writes vector data into a slot identified by id.
pub trait SetElementHelper<T> {
    /// Stores `element` at `index`; implementations may be no-ops
    /// (see the `NoStore` impl below).
    fn set_element(&self, index: &u32, element: &[T]) -> ANNResult<()>;
}
/// Factory that builds a concrete [`VectorStore`].
pub trait CreateVectorStore {
    /// Store type produced by [`Self::create`].
    type Target: VectorStore;
    /// Consumes the factory and builds a store sized for `max_points`.
    /// `prefetch_lookahead` is an optional tuning knob — NOTE(review):
    /// exact semantics inferred from the name; confirm in implementations.
    fn create(
        self,
        max_points: usize,
        metric: Metric,
        prefetch_lookahead: Option<usize>,
    ) -> Self::Target;
}
/// Factory that builds a deletion-tracking provider.
pub trait CreateDeleteProvider {
    /// Provider type produced by [`Self::create`].
    type Target;
    /// Consumes the factory and builds a provider covering `total_points`.
    fn create(self, total_points: usize) -> Self::Target;
}
/// Minimal size-reporting interface over a vector store.
pub trait VectorStore: AsyncFriendly {
    /// Total number of vectors held.
    fn total(&self) -> usize;
    /// Number of vectors retrievable via a `get_vector`-style call —
    /// NOTE(review): relationship to `total` inferred from names; confirm
    /// at call sites.
    fn count_for_get_vector(&self) -> usize;
}
/// Uninhabited error type: it has no variants, so a `Panics` value can
/// never exist and any code path typed with it is statically unreachable.
#[derive(Debug)]
pub enum Panics {}
impl std::fmt::Display for Panics {
    /// Unreachable in practice (`Panics` is uninhabited); provided only to
    /// satisfy the `Error` supertrait requirement below.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("panics")
    }
}
impl std::error::Error for Panics {}
impl From<Panics> for ANNError {
    /// Conversion required by trait bounds; `Panics` is uninhabited, so
    /// this body can never actually execute (hence `#[cold]`).
    #[cold]
    fn from(_: Panics) -> ANNError {
        ANNError::log_async_error("unreachable")
    }
}
// NOTE(review): presumably marks `Panics` as an always-escalating error
// kind via diskann's macro — confirm against `always_escalate!`'s
// definition in the `diskann` crate.
always_escalate!(Panics);
/// No-op vector store: every operation succeeds and reports zero contents.
#[derive(Debug, Clone, Copy)]
pub struct NoStore;
impl CreateVectorStore for NoStore {
    type Target = NoStore;
    /// All sizing/metric parameters are irrelevant for a no-op store;
    /// the factory simply yields itself.
    fn create(
        self,
        _max_points: usize,
        _metric: Metric,
        _prefetch_lookahead: Option<usize>,
    ) -> Self::Target {
        self
    }
}
impl VectorStore for NoStore {
    /// A no-op store holds nothing.
    fn total(&self) -> usize {
        0
    }
    /// Nothing is retrievable either.
    fn count_for_get_vector(&self) -> usize {
        0
    }
}
impl LoadWith<AsyncQuantLoadContext> for NoStore {
    type Error = ANNError;
    /// Loading a no-op store reads nothing and always succeeds.
    async fn load_with<P>(_: &P, _: &AsyncQuantLoadContext) -> ANNResult<Self>
    where
        P: StorageReadProvider,
    {
        Ok(Self)
    }
}
impl SaveWith<AsyncIndexMetadata> for NoStore {
    type Ok = usize;
    type Error = ANNError;
    /// Saving a no-op store writes nothing and reports a count of 0 —
    /// NOTE(review): the unit of the returned count is inferred; confirm
    /// against other `SaveWith` implementations.
    async fn save_with<P>(&self, _provider: &P, _auxiliary: &AsyncIndexMetadata) -> ANNResult<usize>
    where
        P: StorageWriteProvider,
    {
        Ok(0)
    }
}
impl<T> SetElementHelper<T> for NoStore {
    /// Discards the element; a no-op store has nowhere to put it.
    fn set_element(&self, _index: &u32, _element: &[T]) -> ANNResult<()> {
        Ok(())
    }
}
/// Deletion provider that tracks nothing: no point is ever deleted.
#[derive(Debug, Clone, Copy)]
pub struct NoDeletes;
impl postprocess::DeletionCheck for NoDeletes {
    /// No point is ever considered deleted.
    #[inline(always)]
    fn deletion_check(&self, _: u32) -> bool {
        false
    }
}
impl CreateDeleteProvider for NoDeletes {
    type Target = Self;
    /// The point count is irrelevant when deletions are not tracked.
    fn create(self, _: usize) -> Self {
        Self
    }
}
/// Factory marker selecting the table-backed deletion provider.
#[derive(Debug, Clone, Copy)]
pub struct TableBasedDeletes;
impl CreateDeleteProvider for TableBasedDeletes {
    type Target = TableDeleteProviderAsync;
    /// Builds the async table-based provider sized for `total_points`.
    fn create(self, total_points: usize) -> Self::Target {
        TableDeleteProviderAsync::new(total_points)
    }
}
/// Marker selecting full-precision (unquantized) vectors.
#[derive(Debug, Clone, Copy)]
pub struct FullPrecision;
/// Marker selecting quantized vectors.
#[derive(Debug, Clone, Copy)]
pub struct Quantized;
/// Marker selecting a hybrid strategy that mixes quantized and
/// full-precision vectors.
#[derive(Debug, Clone, Copy)]
pub struct Hybrid {
    /// Optional cap on full-precision vectors used per prune —
    /// NOTE(review): "cap" inferred from the field name; confirm at use.
    pub max_fp_vecs_per_prune: Option<usize>,
}

impl Hybrid {
    /// Builds a `Hybrid` configuration with the given prune cap.
    pub fn new(max_fp_vecs_per_prune: Option<usize>) -> Self {
        Hybrid { max_fp_vecs_per_prune }
    }
}
/// Test-build call counter backed by a relaxed atomic.
#[cfg(test)]
pub struct TestCallCount {
    // Number of `increment` calls observed so far.
    calls: std::sync::atomic::AtomicUsize,
}
#[cfg(test)]
impl TestCallCount {
    /// Starts counting from zero.
    pub fn new() -> Self {
        TestCallCount {
            calls: std::sync::atomic::AtomicUsize::new(0),
        }
    }
    /// Counting is active in test builds.
    pub fn enabled() -> bool {
        true
    }
    /// Records one call; relaxed ordering suffices for a plain counter.
    pub fn increment(&self) {
        use std::sync::atomic::Ordering;
        self.calls.fetch_add(1, Ordering::Relaxed);
    }
    /// Current call count.
    pub fn get(&self) -> usize {
        use std::sync::atomic::Ordering;
        self.calls.load(Ordering::Relaxed)
    }
}
/// Release-build stub of the call counter: holds no state and counts
/// nothing, so production code pays no cost for instrumentation hooks.
#[cfg(not(test))]
pub struct TestCallCount {}
#[cfg(not(test))]
impl TestCallCount {
    /// Builds the stateless stub.
    pub fn new() -> Self {
        TestCallCount {}
    }
    /// Counting is disabled outside test builds.
    pub fn enabled() -> bool {
        false
    }
    /// No-op outside test builds.
    pub fn increment(&self) {}
    /// Always zero outside test builds.
    pub fn get(&self) -> usize {
        0
    }
}
impl Default for TestCallCount {
    /// Equivalent to [`TestCallCount::new`] for either cfg variant.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;
    use super::*;

    #[test]
    fn new_creates_correct_range() {
        let sp = StartPoints::new(10, NonZeroUsize::new(5).unwrap())
            .expect("should construct without overflow");
        let r = sp.range().collect::<Vec<_>>();
        assert_eq!(r, vec![10, 11, 12, 13, 14]);
        assert_eq!(sp.end(), 15);
    }

    // New: cover the accessors (`start`, `len`, `is_empty`) that the
    // range test above does not exercise.
    #[test]
    fn accessors_report_len_and_start() {
        let sp = StartPoints::new(7, NonZeroUsize::new(3).unwrap())
            .expect("should construct without overflow");
        assert_eq!(sp.start(), 7);
        assert_eq!(sp.len(), 3);
        assert!(!sp.is_empty());
        assert_eq!(sp.range().len(), 3);
    }

    #[test]
    fn new_returns_error_on_overflow() {
        let max = u32::MAX;
        let res = StartPoints::new(max, NonZeroUsize::new(1).unwrap());
        assert!(res.is_err(), "expected an error when sum exceeds u32::MAX");
        if let Err(err) = res {
            let msg = err.to_string();
            assert!(
                msg.contains("Sum of valid points and frozen points exceeds u32::MAX"),
                "unexpected error message: {}",
                msg
            );
        }
    }
}