Skip to main content

lattice_embed/simd/
mod.rs

1//! SIMD-accelerated vector operations for embedding similarity.
2//!
3//! Provides optimized implementations with automatic fallback:
4//! - **x86_64 (float32)**: AVX-512F > AVX2 + FMA > scalar
5//! - **x86_64 (int8)**: AVX-512 VNNI > AVX2 > scalar
6//! - **aarch64**: ARM NEON with multiple accumulators and loop unrolling
7//! - **Other**: Scalar fallback
8//!
9//! ## Optimizations
10//!
11//! - **Multiple accumulators**: 4 parallel accumulators to break dependency chains
12//! - **Loop unrolling**: Process 16/32/64 elements per iteration depending on ISA
13//! - **AVX-512F**: Wide float32 kernels for dot, cosine, normalize, and distance
14//! - **AVX-512 VNNI**: Integer VNNI instructions when available (quantized int8 path)
15
16mod binary;
17mod cosine;
18mod distance;
19mod dot_product;
20mod int4;
21mod normalize;
22mod quantized;
23mod tier;
24
25#[cfg(test)]
26mod tests;
27
28// Re-export public API
29pub use binary::BinaryVector;
30pub use cosine::{
31    batch_cosine_one_vs_many, batch_cosine_similarity, cosine_similarity, cosine_similarity_fused,
32};
33pub use distance::{euclidean_distance, squared_euclidean_distance};
34pub use dot_product::{
35    DotBatch4Kernel, DotKernel, batch_dot_product, dot_product, dot_product_batch4,
36    resolved_dot_product_batch4_kernel, resolved_dot_product_kernel,
37};
38pub use int4::{Int4Params, Int4Vector};
39pub use normalize::normalize;
40pub use quantized::{
41    I8DotKernel, QuantizationParams, QuantizedVector, cosine_similarity_i8, dot_product_i8,
42    dot_product_i8_raw, resolved_i8_dot_kernel,
43};
44pub use tier::{
45    NormalizationHint, PreparedQuery, PreparedQueryWithMeta, QuantizationTier, QuantizedData,
46    approximate_cosine_distance, approximate_cosine_distance_prepared,
47    approximate_cosine_distance_prepared_with_meta, approximate_dot_product,
48    approximate_dot_product_prepared, approximate_int4_batch_prepared,
49    approximate_int4_batch_prepared_into, approximate_int8_batch_prepared,
50    approximate_int8_batch_prepared_into, batch_approximate_cosine_distance_prepared,
51    batch_approximate_cosine_distance_prepared_into, is_unit_norm, prepare_query,
52    prepare_query_with_norm, try_approximate_cosine_distance_prepared,
53    try_approximate_dot_product_prepared,
54};
55
56use std::sync::OnceLock;
57
/// **Unstable**: SIMD dispatch internals; fields may be added as new ISAs are supported.
///
/// SIMD configuration with runtime feature detection.
///
/// Each flag records whether a CPU feature was found on the running machine.
/// Build one with [`SimdConfig::detect`] (or [`Default::default`], which calls it);
/// the process-wide cached copy is available from [`simd_config`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdConfig {
    /// **Unstable**: AVX-512F support available (x86_64).
    pub avx512f_enabled: bool,
    /// **Unstable**: AVX2 support available (x86_64).
    pub avx2_enabled: bool,
    /// **Unstable**: FMA (Fused Multiply-Add) support available (x86_64).
    ///
    /// Detected independently of AVX2; [`SimdConfig::simd_available`] does not consult it.
    pub fma_enabled: bool,
    /// **Unstable**: AVX-512F + AVX-512BW + AVX-512VNNI support available (x86_64).
    ///
    /// [`SimdConfig::detect`] sets this only when all three features are present.
    pub avx512vnni_enabled: bool,
    /// **Unstable**: NEON support available (aarch64/ARM64).
    pub neon_enabled: bool,
    /// **Unstable**: ARM FEAT_DotProd (SDOT/UDOT instructions) available (aarch64).
    ///
    /// Mandatory on Armv8.4+; optional on Armv8.2/v8.3. Always false on non-aarch64.
    /// SDOT kernels must only be dispatched when this is `true`.
    pub dotprod_enabled: bool,
}

impl Default for SimdConfig {
    /// Equivalent to [`SimdConfig::detect`]: re-runs feature detection instead of
    /// reading the cached value returned by [`simd_config`].
    fn default() -> Self {
        Self::detect()
    }
}

impl SimdConfig {
    /// Configuration with every feature flag cleared (scalar code paths only).
    ///
    /// Used as the base for struct-update syntax so each architecture arm of
    /// [`SimdConfig::detect`] only names the flags it can actually set.
    const SCALAR: Self = Self {
        avx512f_enabled: false,
        avx2_enabled: false,
        fma_enabled: false,
        avx512vnni_enabled: false,
        neon_enabled: false,
        dotprod_enabled: false,
    };

    /// **Unstable**: feature detection details may change as ISA support expands.
    ///
    /// Probes the running CPU and returns the detected feature set. Flags that
    /// belong to an architecture other than the compile target are always `false`;
    /// on targets that are neither x86_64 nor aarch64 every flag is `false`.
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            let avx512f_enabled = is_x86_feature_detected!("avx512f");

            Self {
                avx512f_enabled,
                avx2_enabled: is_x86_feature_detected!("avx2"),
                fma_enabled: is_x86_feature_detected!("fma"),
                // Gate on AVX-512BW as well as VNNI; presumably the int8 kernels
                // behind this flag also use BW byte/word ops — confirm in `quantized`.
                avx512vnni_enabled: avx512f_enabled
                    && is_x86_feature_detected!("avx512bw")
                    && is_x86_feature_detected!("avx512vnni"),
                ..Self::SCALAR
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // NEON is mandatory on aarch64, always available.
            // FEAT_DotProd (dotprod) is optional: required on Armv8.4+,
            // optional on Armv8.2/v8.3. Detect at runtime.
            Self {
                neon_enabled: true,
                dotprod_enabled: std::arch::is_aarch64_feature_detected!("dotprod"),
                ..Self::SCALAR
            }
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self::SCALAR
        }
    }

    /// **Unstable**: check if any SIMD is available; logic may expand with new ISAs.
    ///
    /// Returns `true` when at least one of AVX-512F, AVX-512 VNNI, AVX2 or NEON is
    /// flagged. `fma_enabled` and `dotprod_enabled` are auxiliary and not consulted.
    #[inline]
    pub fn simd_available(&self) -> bool {
        self.avx512f_enabled || self.avx512vnni_enabled || self.avx2_enabled || self.neon_enabled
    }

    /// Force scalar-only mode (useful for testing): every feature flag is `false`.
    #[cfg(test)]
    pub fn scalar_only() -> Self {
        Self::SCALAR
    }
}
150
151// Process-wide SIMD configuration (detected once).
152static SIMD_CONFIG: OnceLock<SimdConfig> = OnceLock::new();
153
154/// **Unstable**: SIMD dispatch internal; shape may change as new backends are added.
155///
156/// The config is detected once per process and cached.
157#[inline]
158pub fn simd_config() -> SimdConfig {
159    *SIMD_CONFIG.get_or_init(SimdConfig::detect)
160}