Skip to main content

nodedb_query/simd_filter/
runtime.rs

1// SPDX-License-Identifier: Apache-2.0
2
3// ---------------------------------------------------------------------------
4// Runtime dispatch
5// ---------------------------------------------------------------------------
6
7use super::scalar::{
8    CmpOp, scalar_cmp_f64, scalar_cmp_i64, scalar_eq_u32, scalar_ne_u32, scalar_range_i64,
9};
10
/// SIMD runtime for filter-to-bitmask operations.
///
/// A dispatch table of function pointers: one entry per supported comparison.
/// Each kernel takes a slice of column values plus the comparison operand(s)
/// and returns the matches as a bitmask packed into `Vec<u64>` (bit `i`
/// corresponds to `values[i]`; the exact word layout is defined by the kernel
/// implementations).
///
/// Every field is a plain `fn` pointer or a `&'static str`, so the table is
/// trivially copyable and printable; the derives make that available to
/// callers (e.g. for logging which backend was selected).
#[derive(Debug, Clone, Copy)]
pub struct FilterSimdRuntime {
    /// `values[i] == target` → bit i set.
    pub eq_u32: fn(&[u32], u32) -> Vec<u64>,
    /// `values[i] != target` → bit i set.
    pub ne_u32: fn(&[u32], u32) -> Vec<u64>,
    /// `values[i] > threshold` → bit i set.
    pub gt_f64: fn(&[f64], f64) -> Vec<u64>,
    /// `values[i] >= threshold` → bit i set.
    pub gte_f64: fn(&[f64], f64) -> Vec<u64>,
    /// `values[i] < threshold` → bit i set.
    pub lt_f64: fn(&[f64], f64) -> Vec<u64>,
    /// `values[i] <= threshold` → bit i set.
    pub lte_f64: fn(&[f64], f64) -> Vec<u64>,
    /// `values[i] > threshold` → bit i set (i64).
    pub gt_i64: fn(&[i64], i64) -> Vec<u64>,
    /// `values[i] >= threshold` → bit i set (i64).
    pub gte_i64: fn(&[i64], i64) -> Vec<u64>,
    /// `values[i] < threshold` → bit i set (i64).
    pub lt_i64: fn(&[i64], i64) -> Vec<u64>,
    /// `values[i] <= threshold` → bit i set (i64).
    pub lte_i64: fn(&[i64], i64) -> Vec<u64>,
    /// `min <= values[i] <= max` → bit i set.
    pub range_i64: fn(&[i64], i64, i64) -> Vec<u64>,
    /// Label of the kernel set backing this table. `detect()` uses
    /// `"avx512"`, `"avx2"`, `"neon"`, `"wasm-simd128"` or `"scalar"`.
    pub name: &'static str,
}
37
38impl FilterSimdRuntime {
39    pub fn detect() -> Self {
40        #[cfg(target_arch = "x86_64")]
41        {
42            if std::is_x86_feature_detected!("avx512f") {
43                return Self {
44                    eq_u32: super::avx512::avx512_eq_u32,
45                    ne_u32: super::avx512::avx512_ne_u32,
46                    gt_f64: super::avx512::avx512_gt_f64,
47                    gte_f64: super::avx512::avx512_gte_f64,
48                    lt_f64: super::avx512::avx512_lt_f64,
49                    lte_f64: super::avx512::avx512_lte_f64,
50                    gt_i64: super::avx512::avx512_gt_i64,
51                    gte_i64: super::avx512::avx512_gte_i64,
52                    lt_i64: super::avx512::avx512_lt_i64,
53                    lte_i64: super::avx512::avx512_lte_i64,
54                    range_i64: super::avx512::avx512_range_i64,
55                    name: "avx512",
56                };
57            }
58            if std::is_x86_feature_detected!("avx2") {
59                return Self {
60                    eq_u32: super::avx2::avx2_eq_u32,
61                    ne_u32: super::avx2::avx2_ne_u32,
62                    gt_f64: super::avx2::avx2_gt_f64,
63                    gte_f64: super::avx2::avx2_gte_f64,
64                    lt_f64: super::avx2::avx2_lt_f64,
65                    lte_f64: super::avx2::avx2_lte_f64,
66                    gt_i64: super::avx2::avx2_gt_i64,
67                    gte_i64: super::avx2::avx2_gte_i64,
68                    lt_i64: super::avx2::avx2_lt_i64,
69                    lte_i64: super::avx2::avx2_lte_i64,
70                    range_i64: super::avx2::avx2_range_i64,
71                    name: "avx2",
72                };
73            }
74        }
75        #[cfg(target_arch = "aarch64")]
76        {
77            return Self {
78                eq_u32: super::neon::neon_eq_u32,
79                ne_u32: super::neon::neon_ne_u32,
80                gt_f64: super::neon::neon_gt_f64,
81                gte_f64: super::neon::neon_gte_f64,
82                lt_f64: super::neon::neon_lt_f64,
83                lte_f64: super::neon::neon_lte_f64,
84                gt_i64: super::neon::neon_gt_i64,
85                gte_i64: super::neon::neon_gte_i64,
86                lt_i64: super::neon::neon_lt_i64,
87                lte_i64: super::neon::neon_lte_i64,
88                range_i64: super::neon::neon_range_i64,
89                name: "neon",
90            };
91        }
92        #[cfg(target_arch = "wasm32")]
93        {
94            return Self {
95                eq_u32: super::wasm::wasm_eq_u32,
96                ne_u32: super::wasm::wasm_ne_u32,
97                gt_f64: |v, t| scalar_cmp_f64(v, t, CmpOp::Gt),
98                gte_f64: |v, t| scalar_cmp_f64(v, t, CmpOp::Gte),
99                lt_f64: |v, t| scalar_cmp_f64(v, t, CmpOp::Lt),
100                lte_f64: |v, t| scalar_cmp_f64(v, t, CmpOp::Lte),
101                gt_i64: |v, t| scalar_cmp_i64(v, t, CmpOp::Gt),
102                gte_i64: |v, t| scalar_cmp_i64(v, t, CmpOp::Gte),
103                lt_i64: |v, t| scalar_cmp_i64(v, t, CmpOp::Lt),
104                lte_i64: |v, t| scalar_cmp_i64(v, t, CmpOp::Lte),
105                range_i64: scalar_range_i64,
106                name: "wasm-simd128",
107            };
108        }
109        #[allow(unreachable_code)]
110        Self {
111            eq_u32: scalar_eq_u32,
112            ne_u32: scalar_ne_u32,
113            gt_f64: |v, t| scalar_cmp_f64(v, t, CmpOp::Gt),
114            gte_f64: |v, t| scalar_cmp_f64(v, t, CmpOp::Gte),
115            lt_f64: |v, t| scalar_cmp_f64(v, t, CmpOp::Lt),
116            lte_f64: |v, t| scalar_cmp_f64(v, t, CmpOp::Lte),
117            gt_i64: |v, t| scalar_cmp_i64(v, t, CmpOp::Gt),
118            gte_i64: |v, t| scalar_cmp_i64(v, t, CmpOp::Gte),
119            lt_i64: |v, t| scalar_cmp_i64(v, t, CmpOp::Lt),
120            lte_i64: |v, t| scalar_cmp_i64(v, t, CmpOp::Lte),
121            range_i64: scalar_range_i64,
122            name: "scalar",
123        }
124    }
125}
126
127static FILTER_RUNTIME: std::sync::OnceLock<FilterSimdRuntime> = std::sync::OnceLock::new();
128
129/// Get the global filter SIMD runtime.
130pub fn filter_runtime() -> &'static FilterSimdRuntime {
131    FILTER_RUNTIME.get_or_init(FilterSimdRuntime::detect)
132}