Skip to main content

lattice_embed/simd/
mod.rs

1//! SIMD-accelerated vector operations for embedding similarity.
2//!
3//! Provides optimized implementations with automatic fallback:
4//! - **x86_64 (float32)**: AVX-512F > AVX2 + FMA > scalar
5//! - **x86_64 (int8)**: AVX-512 VNNI > AVX2 > scalar
6//! - **aarch64**: ARM NEON with multiple accumulators and loop unrolling
7//! - **Other**: Scalar fallback
8//!
9//! ## Optimizations
10//!
11//! - **Multiple accumulators**: 4 parallel accumulators to break dependency chains
12//! - **Loop unrolling**: Process 16/32/64 elements per iteration depending on ISA
13//! - **AVX-512F**: Wide float32 kernels for dot, cosine, normalize, and distance
14//! - **AVX-512 VNNI**: Integer VNNI instructions when available (quantized int8 path)
15
16mod binary;
17mod cosine;
18mod distance;
19mod dot_product;
20mod int4;
21mod normalize;
22mod quantized;
23mod tier;
24
25#[cfg(test)]
26mod tests;
27
28// Re-export public API
29pub use binary::BinaryVector;
30pub use cosine::{batch_cosine_similarity, cosine_similarity};
31pub use distance::{euclidean_distance, squared_euclidean_distance};
32pub use dot_product::{
33    DotBatch4Kernel, DotKernel, batch_dot_product, dot_product, dot_product_batch4,
34    resolved_dot_product_batch4_kernel, resolved_dot_product_kernel,
35};
36pub use int4::{Int4Params, Int4Vector};
37pub use normalize::normalize;
38pub use quantized::{
39    I8DotKernel, QuantizationParams, QuantizedVector, cosine_similarity_i8, dot_product_i8,
40    dot_product_i8_raw, resolved_i8_dot_kernel,
41};
42pub use tier::{
43    NormalizationHint, PreparedQuery, PreparedQueryWithMeta, QuantizationTier, QuantizedData,
44    approximate_cosine_distance, approximate_cosine_distance_prepared,
45    approximate_cosine_distance_prepared_with_meta, approximate_dot_product,
46    approximate_dot_product_prepared, batch_approximate_cosine_distance_prepared,
47    batch_approximate_cosine_distance_prepared_into, is_unit_norm, prepare_query,
48    prepare_query_with_norm, try_approximate_cosine_distance_prepared,
49    try_approximate_dot_product_prepared,
50};
51
52use std::sync::OnceLock;
53
/// **Unstable**: SIMD dispatch internals; fields may be added as new ISAs are supported.
///
/// Snapshot of which SIMD instruction-set extensions may be used by the kernels.
///
/// Values are normally produced by [`SimdConfig::detect`]. The type is a small
/// bundle of `bool` flags, so it is `Copy` and compares by value
/// (`PartialEq`/`Eq`), which makes it easy to assert on in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdConfig {
    /// **Unstable**: AVX-512F support available (x86_64).
    pub avx512f_enabled: bool,
    /// **Unstable**: AVX2 support available (x86_64).
    pub avx2_enabled: bool,
    /// **Unstable**: FMA (Fused Multiply-Add) support available (x86_64).
    pub fma_enabled: bool,
    /// **Unstable**: AVX-512F + AVX-512VNNI support available (x86_64).
    ///
    /// When populated by [`SimdConfig::detect`], this is only set if AVX-512BW
    /// is reported as well.
    pub avx512vnni_enabled: bool,
    /// **Unstable**: NEON support available (aarch64/ARM64).
    pub neon_enabled: bool,
}
70
71impl Default for SimdConfig {
72    fn default() -> Self {
73        Self::detect()
74    }
75}
76
77impl SimdConfig {
78    /// **Unstable**: feature detection details may change as ISA support expands.
79    pub fn detect() -> Self {
80        #[cfg(target_arch = "x86_64")]
81        {
82            let avx512f_enabled = is_x86_feature_detected!("avx512f");
83
84            Self {
85                avx512f_enabled,
86                avx2_enabled: is_x86_feature_detected!("avx2"),
87                fma_enabled: is_x86_feature_detected!("fma"),
88                avx512vnni_enabled: avx512f_enabled
89                    && is_x86_feature_detected!("avx512bw")
90                    && is_x86_feature_detected!("avx512vnni"),
91                neon_enabled: false,
92            }
93        }
94        #[cfg(target_arch = "aarch64")]
95        {
96            // NEON is mandatory on aarch64, always available.
97            Self {
98                avx512f_enabled: false,
99                avx2_enabled: false,
100                fma_enabled: false,
101                avx512vnni_enabled: false,
102                neon_enabled: true,
103            }
104        }
105        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
106        {
107            Self {
108                avx512f_enabled: false,
109                avx2_enabled: false,
110                fma_enabled: false,
111                avx512vnni_enabled: false,
112                neon_enabled: false,
113            }
114        }
115    }
116
117    /// **Unstable**: check if any SIMD is available; logic may expand with new ISAs.
118    #[inline]
119    pub fn simd_available(&self) -> bool {
120        self.avx512f_enabled || self.avx512vnni_enabled || self.avx2_enabled || self.neon_enabled
121    }
122
123    /// Force scalar-only mode (useful for testing).
124    #[cfg(test)]
125    pub fn scalar_only() -> Self {
126        Self {
127            avx512f_enabled: false,
128            avx2_enabled: false,
129            fma_enabled: false,
130            avx512vnni_enabled: false,
131            neon_enabled: false,
132        }
133    }
134}
135
136// Process-wide SIMD configuration (detected once).
137static SIMD_CONFIG: OnceLock<SimdConfig> = OnceLock::new();
138
139/// **Unstable**: SIMD dispatch internal; shape may change as new backends are added.
140///
141/// The config is detected once per process and cached.
142#[inline]
143pub fn simd_config() -> SimdConfig {
144    *SIMD_CONFIG.get_or_init(SimdConfig::detect)
145}