Skip to main content

diskann_vector/
lib.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
//! # vector
//!
//! This crate contains SIMD-accelerated functions for operating on vector data. Note that the name
//! 'vector' does not exclusively mean embedding vectors, but any array of data appropriate for
//! SIMD. Therefore, aside from fast implementations of distance for real vectors, this crate also
//! includes things like a SIMD-accelerated `contains` for slices.
11#![cfg_attr(
12    not(test),
13    warn(
14        clippy::panic,
15        clippy::unwrap_used,
16        clippy::expect_used,
17        clippy::undocumented_unsafe_blocks
18    )
19)]
20
21mod half;
22pub use half::Half;
23
24mod traits;
25pub use traits::{
26    DistanceFunction, DistanceFunctionMut, Norm, PreprocessedDistanceFunction, PureDistanceFunction,
27};
28
29mod value;
30pub use value::{MathematicalValue, SimilarityScore};
31
32pub mod contains;
33pub mod conversion;
34pub mod distance;
35pub mod norm;
36
37cfg_if::cfg_if! {
38    if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
39        const CACHE_LINE_SIZE: usize = 64;
40
41        #[inline(always)]
42        unsafe fn prefetch_exactly<const N: usize>(ptr: *const i8) {
43            use std::arch::x86_64::*;
44            for i in 0..N {
45                _mm_prefetch(ptr.add(i * CACHE_LINE_SIZE), _MM_HINT_T0);
46            }
47        }
48
49        #[inline(always)]
50        unsafe fn prefetch_at_most<const N: usize>(ptr: *const i8, bytes: usize) {
51            use std::arch::x86_64::*;
52            for i in 0..N {
53                if CACHE_LINE_SIZE * i >= bytes {
54                    break;
55                }
56                _mm_prefetch(ptr.add(i * CACHE_LINE_SIZE), _MM_HINT_T0);
57            }
58        }
59
60        /// Prefetch the given vector in chunks of 64 bytes, which is a cache line size.
61        /// Only the first `MAX_BLOCKS` chunks will be prefetched.
62        #[inline]
63        pub fn prefetch_hint_max<const MAX_CACHE_LINES: usize, T>(vec: &[T]) {
64            let vecsize = std::mem::size_of_val(vec);
65            if vecsize >= MAX_CACHE_LINES * 64 {
66                // SAFETY: Pointer is in-bounds and use of the intrinsic is cfg gated.
67                unsafe { prefetch_exactly::<MAX_CACHE_LINES>(vec.as_ptr().cast()) }
68            } else {
69                // SAFETY: Pointer is in-bounds and use of the intrinsic is cfg gated.
70                unsafe { prefetch_at_most::<MAX_CACHE_LINES>(vec.as_ptr().cast(), vecsize) }
71            }
72        }
73
74        /// Prefetch the given vector in chunks of 64 bytes, which is a cache line size.
75        /// The entire vector will be prefetched.
76        #[inline]
77        pub fn prefetch_hint_all<T>(vec: &[T]) {
78            use std::arch::x86_64::*;
79
80            let vecsize = std::mem::size_of_val(vec);
81            let num_prefetch_blocks = vecsize.div_ceil(64);
82            let vec_ptr = vec.as_ptr() as *const i8;
83            for d in 0..num_prefetch_blocks {
84                // SAFETY: Pointer is in-bounds and use of the intrinsic is gated by the
85                // `cfg`-guard on this function.
86                unsafe {
87                    std::arch::x86_64::_mm_prefetch(vec_ptr.add(d * CACHE_LINE_SIZE), _MM_HINT_T0);
88                }
89            }        }
90    } else {
91        pub fn prefetch_hint_max<const MAX_CACHE_LINES: usize, T>(_vec: &[T]) {}
92        pub fn prefetch_hint_all<T>(_vec: &[T]) {}
93    }
94}
95
96#[cfg(test)]
97mod test_util;