ariadnetor_native/performance.rs
1//! Hardware-aware parallelism threshold tables for `NativeBackend`.
2//!
3//! A `ThresholdTable` stores the minimum problem-size key at which each
4//! linear-algebra op is worth running in parallel on this machine. The
5//! sentinel `usize::MAX` means "no finite parallel threshold" — either
//! the op is unmeasured on this profile, or calibration showed no
//! regime where parallel beats sequential (e.g.
//! `ThresholdTable::laptop().transpose` when built with the `hptt` feature).
8//! `PerformanceManager` pairs a table with the comparison logic that
9//! `NativeBackend::par_for_*` methods call.
10
11use ariadnetor_core::backend::ExecPolicy;
12
/// Per-op parallelism thresholds.
///
/// Each field holds the smallest problem-size key at which the op should
/// dispatch as `ExecPolicy::Parallel(0)`. Keys are op-specific and are
/// produced by the corresponding `NativeBackend::par_for_*` method:
///
/// * `svd` / `qr` / `lq`: `cbrt(m*n*min(m,n))`
/// * `gemm`: `cbrt(m*n*k)`
/// * `eigh` / `eig` / `solve`: the dimension `n`
/// * `transpose`: total element count
///
/// `usize::MAX` is a sentinel for "no finite parallel threshold": the op is
/// either unmeasured on this profile, or calibration decided that parallel
/// never wins (e.g. `ThresholdTable::laptop().transpose` when built with the
/// `hptt` feature). `PerformanceManager::policy_by_n` maps the sentinel to
/// `ExecPolicy::Sequential` in both cases.
///
/// The table is plain data, so it compares by value (`PartialEq`/`Eq`):
/// two tables are equal exactly when every threshold matches.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ThresholdTable {
    /// SVD threshold; key is `cbrt(m*n*min(m,n))`.
    pub svd: usize,
    /// QR threshold; key is `cbrt(m*n*min(m,n))`.
    pub qr: usize,
    /// LQ threshold; key is `cbrt(m*n*min(m,n))`.
    pub lq: usize,
    /// Hermitian-eigendecomposition threshold; key is the dimension `n`.
    pub eigh: usize,
    /// General-eigendecomposition threshold; key is the dimension `n`.
    pub eig: usize,
    /// GEMM threshold; key is `cbrt(m*n*k)`.
    pub gemm: usize,
    /// Linear-solve threshold; key is the dimension `n`.
    pub solve: usize,
    /// Transpose threshold; key is the total element count.
    pub transpose: usize,
}
45
46impl ThresholdTable {
47 /// Thresholds calibrated for laptop-class CPUs (Apple M2 8-core).
48 ///
49 /// Values come from `crates/ariadnetor-linalg/benches/sweep_{decomp,
50 /// decomp_rect,gemm,solve,transpose}_par.rs` run in a single session.
51 ///
52 /// `transpose` is calibrated per backend at compile time. Under the
53 /// `hptt` feature the sweep showed no regime where Rayon-style
54 /// parallel can beat HPTT's tiled sequential on laptop, so the
55 /// sentinel `usize::MAX` is retained. Without it (the default build),
56 /// the naive fallback's simpler sequential loses to the parallel
57 /// kernel above ~65k total elements.
58 pub fn laptop() -> Self {
59 Self {
60 svd: 384,
61 qr: 384,
62 lq: 512,
63 eigh: 256,
64 eig: 256,
65 gemm: 192,
66 solve: 768,
67 transpose: if cfg!(feature = "hptt") {
68 usize::MAX
69 } else {
70 65_536
71 },
72 }
73 }
74
75 /// Thresholds calibrated for workstation-class CPUs (Xeon NUMA, 112 cores).
76 ///
77 /// Calibrated with the same five sweeps listed for `laptop()`. Most
78 /// ops carry the `usize::MAX` sentinel: at workstation scale parallel
79 /// sync cost is high enough that `svd`/`qr`/`lq`/`eigh`/`eig`/`solve`
80 /// never beat sequential at any `n ≤ 1024` tested. Only large GEMMs
81 /// (`cbrt(m*n*k) ≥ 768`) and transposes benefit from parallel
82 /// dispatch.
83 ///
84 /// `transpose` is calibrated per backend at compile time. Under
85 /// `hptt` the tiled kernel only crosses over at total element count
86 /// ≥ 4_194_304. Without it (the default build), the naive fallback
87 /// crosses over much earlier — its parallel kernel beats its own
88 /// sequential above ~262_144 total elements. Calibration was
89 /// performed on 2D `[n, n]` inputs; the dispatch key is total
90 /// elements for any rank.
91 pub fn workstation() -> Self {
92 Self {
93 svd: usize::MAX,
94 qr: usize::MAX,
95 lq: usize::MAX,
96 eigh: usize::MAX,
97 eig: usize::MAX,
98 gemm: 768,
99 solve: usize::MAX,
100 transpose: if cfg!(feature = "hptt") {
101 4_194_304
102 } else {
103 262_144
104 },
105 }
106 }
107
108 /// Pick a profile based on `std::thread::available_parallelism()`.
109 ///
110 /// Reads the logical-core count (falling back to the conservative `1`
111 /// when the query fails) and delegates the profile choice to
112 /// [`Self::profile_for_parallelism`].
113 pub fn detect() -> Self {
114 let n = std::thread::available_parallelism()
115 .map(|v| v.get())
116 .unwrap_or(1);
117 Self::profile_for_parallelism(n)
118 }
119
120 /// Map a logical-core count to a profile: `> 16` cores → `workstation`,
121 /// otherwise `laptop`. Kept pure (no environment read) so the boundary
122 /// is testable independently of the host's actual core count.
123 fn profile_for_parallelism(n: usize) -> Self {
124 if n > 16 {
125 Self::workstation()
126 } else {
127 Self::laptop()
128 }
129 }
130}
131
132/// Pairs a `ThresholdTable` with the comparison logic used by
133/// `NativeBackend::par_for_*` to translate a problem-size key into an
134/// `ExecPolicy`.
135#[derive(Clone, Debug)]
136pub struct PerformanceManager {
137 thresholds: ThresholdTable,
138}
139
140impl PerformanceManager {
141 /// Wrap a calibrated threshold table in a performance manager.
142 pub fn new(thresholds: ThresholdTable) -> Self {
143 Self { thresholds }
144 }
145
146 /// Borrow the underlying per-op threshold table.
147 pub fn thresholds(&self) -> &ThresholdTable {
148 &self.thresholds
149 }
150
151 /// Map a problem-size key to an `ExecPolicy`.
152 ///
153 /// Returns `Parallel(0)` iff the threshold is non-sentinel
154 /// (`!= usize::MAX`) and the key meets or exceeds it; otherwise
155 /// `Sequential`. The explicit `usize::MAX` check covers both
156 /// "unmeasured" thresholds and calibrated-no-win sentinels (see
157 /// the crate-level note on `usize::MAX` semantics) and prevents
158 /// either from ever tripping Parallel, even if `n` were also
159 /// `usize::MAX`.
160 pub(crate) fn policy_by_n(threshold: usize, n: usize) -> ExecPolicy {
161 if threshold != usize::MAX && n >= threshold {
162 ExecPolicy::Parallel(0)
163 } else {
164 ExecPolicy::Sequential
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    // Pins every calibrated laptop constant so an accidental edit to the
    // profile shows up as a test failure.
    #[test]
    fn laptop_constants_pinned() {
        let table = ThresholdTable::laptop();
        assert_eq!(
            (
                table.svd,
                table.qr,
                table.lq,
                table.eigh,
                table.eig,
                table.gemm,
                table.solve
            ),
            (384, 384, 512, 256, 256, 192, 768)
        );

        // The transpose entry is selected per backend at compile time.
        let expected_transpose = if cfg!(feature = "hptt") {
            usize::MAX
        } else {
            65_536
        };
        assert_eq!(table.transpose, expected_transpose);
    }

    // Pins every calibrated workstation constant, sentinels included.
    #[test]
    fn workstation_constants_pinned() {
        let table = ThresholdTable::workstation();
        let sentinel = usize::MAX;
        assert_eq!(
            (
                table.svd,
                table.qr,
                table.lq,
                table.eigh,
                table.eig,
                table.gemm,
                table.solve
            ),
            (sentinel, sentinel, sentinel, sentinel, sentinel, 768, sentinel)
        );

        // The transpose entry is selected per backend at compile time.
        let expected_transpose = if cfg!(feature = "hptt") {
            4_194_304
        } else {
            262_144
        };
        assert_eq!(table.transpose, expected_transpose);
    }

    #[test]
    fn policy_by_n_below_threshold_is_sequential() {
        let policy = PerformanceManager::policy_by_n(256, 255);
        assert_eq!(policy, ExecPolicy::Sequential);
    }

    #[test]
    fn policy_by_n_at_threshold_is_parallel() {
        let policy = PerformanceManager::policy_by_n(256, 256);
        assert_eq!(policy, ExecPolicy::Parallel(0));
    }

    #[test]
    fn policy_by_n_above_threshold_is_parallel() {
        let policy = PerformanceManager::policy_by_n(256, 1024);
        assert_eq!(policy, ExecPolicy::Parallel(0));
    }

    #[test]
    fn policy_by_n_sentinel_is_always_sequential() {
        // Smallest and largest possible keys: neither may trip Parallel
        // when the threshold is the sentinel.
        for key in [0, usize::MAX] {
            assert_eq!(
                PerformanceManager::policy_by_n(usize::MAX, key),
                ExecPolicy::Sequential
            );
        }
    }

    #[test]
    fn profile_for_parallelism_pins_core_count_boundary() {
        // The cut-over is `n > 16`: 16 cores still selects `laptop`, 17
        // selects `workstation`. The two profiles disagree on `gemm`, so
        // that single field identifies the chosen branch. Comparing against
        // the named constructors keeps this test about the boundary rather
        // than about the calibration values.
        let cases = [
            (16, ThresholdTable::laptop()),
            (17, ThresholdTable::workstation()),
        ];
        for (cores, expected) in cases {
            assert_eq!(
                ThresholdTable::profile_for_parallelism(cores).gemm,
                expected.gemm
            );
        }
    }
}
257}