Skip to main content

ariadnetor_native/
performance.rs

1//! Hardware-aware parallelism threshold tables for `NativeBackend`.
2//!
3//! A `ThresholdTable` stores the minimum problem-size key at which each
4//! linear-algebra op is worth running in parallel on this machine. The
5//! sentinel `usize::MAX` means "no finite parallel threshold" — either
6//! the op is unmeasured on this profile, or calibration showed no
7//! regime where parallel beats sequential (e.g. `ThresholdTable::laptop().transpose`).
8//! `PerformanceManager` pairs a table with the comparison logic that
9//! `NativeBackend::par_for_*` methods call.
10
11use ariadnetor_core::backend::ExecPolicy;
12
/// Per-op parallelism thresholds.
///
/// Every field holds the smallest problem-size key at which the
/// corresponding op should dispatch as `ExecPolicy::Parallel(0)`. The key
/// is op-specific and is computed by the matching
/// `NativeBackend::par_for_*` method:
///
/// * `svd` / `qr` / `lq` — `cbrt(m*n*min(m,n))`
/// * `gemm` — `cbrt(m*n*k)`
/// * `eigh` / `eig` / `solve` — the dimension `n`
/// * `transpose` — total element count
///
/// `usize::MAX` is a sentinel meaning "no finite parallel threshold":
/// the op is either unmeasured on this profile, or calibration decided
/// that parallel never wins (e.g. `ThresholdTable::laptop().transpose`).
/// `policy_by_n` maps the sentinel to `ExecPolicy::Sequential` in both
/// cases.
///
/// The table is plain data, so it derives `PartialEq`/`Eq` in addition to
/// `Clone`/`Debug`; two tables are equal when every threshold matches,
/// which lets callers and tests compare whole profiles directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ThresholdTable {
    /// SVD threshold; key is `cbrt(m*n*min(m,n))`.
    pub svd: usize,
    /// QR threshold; key is `cbrt(m*n*min(m,n))`.
    pub qr: usize,
    /// LQ threshold; key is `cbrt(m*n*min(m,n))`.
    pub lq: usize,
    /// Hermitian-eigendecomposition threshold; key is the dimension `n`.
    pub eigh: usize,
    /// General-eigendecomposition threshold; key is the dimension `n`.
    pub eig: usize,
    /// GEMM threshold; key is `cbrt(m*n*k)`.
    pub gemm: usize,
    /// Linear-solve threshold; key is the dimension `n`.
    pub solve: usize,
    /// Transpose threshold; key is the total element count.
    pub transpose: usize,
}
45
46impl ThresholdTable {
47    /// Thresholds calibrated for laptop-class CPUs (Apple M2 8-core).
48    ///
49    /// Values come from `crates/ariadnetor-linalg/benches/sweep_{decomp,
50    /// decomp_rect,gemm,solve,transpose}_par.rs` run in a single session.
51    ///
52    /// `transpose` is calibrated per backend at compile time. Under the
53    /// `hptt` feature the sweep showed no regime where Rayon-style
54    /// parallel can beat HPTT's tiled sequential on laptop, so the
55    /// sentinel `usize::MAX` is retained. Without it (the default build),
56    /// the naive fallback's simpler sequential loses to the parallel
57    /// kernel above ~65k total elements.
58    pub fn laptop() -> Self {
59        Self {
60            svd: 384,
61            qr: 384,
62            lq: 512,
63            eigh: 256,
64            eig: 256,
65            gemm: 192,
66            solve: 768,
67            transpose: if cfg!(feature = "hptt") {
68                usize::MAX
69            } else {
70                65_536
71            },
72        }
73    }
74
75    /// Thresholds calibrated for workstation-class CPUs (Xeon NUMA, 112 cores).
76    ///
77    /// Calibrated with the same five sweeps listed for `laptop()`. Most
78    /// ops carry the `usize::MAX` sentinel: at workstation scale parallel
79    /// sync cost is high enough that `svd`/`qr`/`lq`/`eigh`/`eig`/`solve`
80    /// never beat sequential at any `n ≤ 1024` tested. Only large GEMMs
81    /// (`cbrt(m*n*k) ≥ 768`) and transposes benefit from parallel
82    /// dispatch.
83    ///
84    /// `transpose` is calibrated per backend at compile time. Under
85    /// `hptt` the tiled kernel only crosses over at total element count
86    /// ≥ 4_194_304. Without it (the default build), the naive fallback
87    /// crosses over much earlier — its parallel kernel beats its own
88    /// sequential above ~262_144 total elements. Calibration was
89    /// performed on 2D `[n, n]` inputs; the dispatch key is total
90    /// elements for any rank.
91    pub fn workstation() -> Self {
92        Self {
93            svd: usize::MAX,
94            qr: usize::MAX,
95            lq: usize::MAX,
96            eigh: usize::MAX,
97            eig: usize::MAX,
98            gemm: 768,
99            solve: usize::MAX,
100            transpose: if cfg!(feature = "hptt") {
101                4_194_304
102            } else {
103                262_144
104            },
105        }
106    }
107
108    /// Pick a profile based on `std::thread::available_parallelism()`.
109    ///
110    /// Reads the logical-core count (falling back to the conservative `1`
111    /// when the query fails) and delegates the profile choice to
112    /// [`Self::profile_for_parallelism`].
113    pub fn detect() -> Self {
114        let n = std::thread::available_parallelism()
115            .map(|v| v.get())
116            .unwrap_or(1);
117        Self::profile_for_parallelism(n)
118    }
119
120    /// Map a logical-core count to a profile: `> 16` cores → `workstation`,
121    /// otherwise `laptop`. Kept pure (no environment read) so the boundary
122    /// is testable independently of the host's actual core count.
123    fn profile_for_parallelism(n: usize) -> Self {
124        if n > 16 {
125            Self::workstation()
126        } else {
127            Self::laptop()
128        }
129    }
130}
131
/// Pairs a `ThresholdTable` with the comparison logic used by
/// `NativeBackend::par_for_*` to translate a problem-size key into an
/// `ExecPolicy`.
///
/// The table is stored privately: it is supplied through `new` and read
/// back through `thresholds()`.
#[derive(Clone, Debug)]
pub struct PerformanceManager {
    /// Per-op thresholds this manager was constructed with.
    thresholds: ThresholdTable,
}
139
140impl PerformanceManager {
141    /// Wrap a calibrated threshold table in a performance manager.
142    pub fn new(thresholds: ThresholdTable) -> Self {
143        Self { thresholds }
144    }
145
146    /// Borrow the underlying per-op threshold table.
147    pub fn thresholds(&self) -> &ThresholdTable {
148        &self.thresholds
149    }
150
151    /// Map a problem-size key to an `ExecPolicy`.
152    ///
153    /// Returns `Parallel(0)` iff the threshold is non-sentinel
154    /// (`!= usize::MAX`) and the key meets or exceeds it; otherwise
155    /// `Sequential`. The explicit `usize::MAX` check covers both
156    /// "unmeasured" thresholds and calibrated-no-win sentinels (see
157    /// the crate-level note on `usize::MAX` semantics) and prevents
158    /// either from ever tripping Parallel, even if `n` were also
159    /// `usize::MAX`.
160    pub(crate) fn policy_by_n(threshold: usize, n: usize) -> ExecPolicy {
161        if threshold != usize::MAX && n >= threshold {
162            ExecPolicy::Parallel(0)
163        } else {
164            ExecPolicy::Sequential
165        }
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Check every `(op, actual, expected)` row of a profile, naming the
    /// profile and op in the failure message.
    fn assert_pinned(profile: &str, rows: &[(&str, usize, usize)]) {
        for &(op, actual, expected) in rows {
            assert_eq!(
                actual, expected,
                "{}.{} drifted from its pinned value",
                profile, op
            );
        }
    }

    #[test]
    fn laptop_constants_pinned() {
        let table = ThresholdTable::laptop();
        // The transpose threshold is selected by the `hptt` feature.
        let transpose = if cfg!(feature = "hptt") {
            usize::MAX
        } else {
            65_536
        };
        assert_pinned(
            "laptop",
            &[
                ("svd", table.svd, 384),
                ("qr", table.qr, 384),
                ("lq", table.lq, 512),
                ("eigh", table.eigh, 256),
                ("eig", table.eig, 256),
                ("gemm", table.gemm, 192),
                ("solve", table.solve, 768),
                ("transpose", table.transpose, transpose),
            ],
        );
    }

    #[test]
    fn workstation_constants_pinned() {
        let table = ThresholdTable::workstation();
        // The transpose threshold is selected by the `hptt` feature.
        let transpose = if cfg!(feature = "hptt") {
            4_194_304
        } else {
            262_144
        };
        assert_pinned(
            "workstation",
            &[
                ("svd", table.svd, usize::MAX),
                ("qr", table.qr, usize::MAX),
                ("lq", table.lq, usize::MAX),
                ("eigh", table.eigh, usize::MAX),
                ("eig", table.eig, usize::MAX),
                ("gemm", table.gemm, 768),
                ("solve", table.solve, usize::MAX),
                ("transpose", table.transpose, transpose),
            ],
        );
    }

    #[test]
    fn policy_by_n_below_threshold_is_sequential() {
        let policy = PerformanceManager::policy_by_n(256, 255);
        assert_eq!(policy, ExecPolicy::Sequential);
    }

    #[test]
    fn policy_by_n_at_threshold_is_parallel() {
        let policy = PerformanceManager::policy_by_n(256, 256);
        assert_eq!(policy, ExecPolicy::Parallel(0));
    }

    #[test]
    fn policy_by_n_above_threshold_is_parallel() {
        let policy = PerformanceManager::policy_by_n(256, 1024);
        assert_eq!(policy, ExecPolicy::Parallel(0));
    }

    #[test]
    fn profile_for_parallelism_pins_core_count_boundary() {
        // The switch happens at `n > 16`: 16 cores still map to `laptop`,
        // 17 map to `workstation`. The two profiles disagree on `gemm`,
        // so that single field reveals which branch ran. Comparing with
        // the named profiles (not literals) keeps this test about the
        // boundary instead of repeating the calibration constants.
        let at_boundary = ThresholdTable::profile_for_parallelism(16);
        let past_boundary = ThresholdTable::profile_for_parallelism(17);

        assert_eq!(at_boundary.gemm, ThresholdTable::laptop().gemm);
        assert_eq!(past_boundary.gemm, ThresholdTable::workstation().gemm);
    }

    #[test]
    fn policy_by_n_sentinel_is_always_sequential() {
        // Smallest and largest possible keys against the sentinel.
        for key in [0, usize::MAX] {
            assert_eq!(
                PerformanceManager::policy_by_n(usize::MAX, key),
                ExecPolicy::Sequential
            );
        }
    }
}