Skip to main content

ariadnetor_native/
lib.rs

1//! CPU compute backend for ariadnetor
2//!
3//! Provides [`NativeBackend`] implementing `ComputeBackend` via:
4//! - **GEMM**: faer (f64, f32, `Complex<f64>`, `Complex<f32>`)
5//! - **SVD/QR/LQ/EIGH**: faer (f64, f32, `Complex<f64>`, `Complex<f32>`)
6//! - **Transpose**: HPTT (f64, f32, Complex) when the `hptt` feature is on, a naive kernel otherwise
7
8#![deny(missing_docs)]
9
10mod eig;
11mod eigh;
12mod gemm;
13mod lq;
14mod performance;
15mod qr;
16mod solve;
17mod svd;
18mod transpose;
19
20use std::sync::{Arc, OnceLock};
21
22use ariadnetor_core::Scalar;
23use ariadnetor_core::backend::{
24    BackendError, ComputeBackend, DeviceType, EigDescriptor, EighDescriptor, ExecPolicy,
25    GemmDescriptor, LqDescriptor, MemoryOrder, OpDesc, QrDescriptor, ScalarKernels,
26    SolveDescriptor, SvdDescriptor, TransposeDescriptor,
27};
28use num_complex::Complex;
29
30pub use performance::{PerformanceManager, ThresholdTable};
31
32/// Map an [`ExecPolicy`] to faer's per-call parallelism selector.
33///
34/// `Parallel(0)` defers to faer's Rayon default (current thread pool size).
35/// `Parallel(n)` for `n > 0` is an advisory thread-count hint passed to
36/// `faer::Par::rayon(n)`; faer dispatches on the global Rayon pool, so
37/// `n` influences work partitioning rather than guaranteeing exactly
38/// `n` OS threads. The naive transpose kernel honors `n` with the same
39/// semantics.
40pub(crate) fn to_faer_par(policy: ExecPolicy) -> faer::Par {
41    match policy {
42        ExecPolicy::Sequential => faer::Par::Seq,
43        ExecPolicy::Parallel(0) => faer::Par::rayon(0),
44        ExecPolicy::Parallel(n) => faer::Par::rayon(n),
45    }
46}
47
48/// Native backend using faer for GEMM and, with the `hptt` feature, HPTT for
49/// transpose (a naive kernel otherwise).
50///
51/// This is the sole owner of faer and hptt-rs dependencies in the workspace.
52/// Other crates access these capabilities through the `ComputeBackend` trait.
53/// Holds a `PerformanceManager` that drives the `par_for_*` dispatch
54/// decisions for each op based on a hardware-aware threshold table.
55#[derive(Debug, Clone)]
56pub struct NativeBackend {
57    perf: PerformanceManager,
58}
59
60impl NativeBackend {
61    /// Construct a `NativeBackend` with thresholds auto-detected from the
62    /// current machine via `ThresholdTable::detect()`.
63    pub fn new() -> Self {
64        Self {
65            perf: PerformanceManager::new(ThresholdTable::detect()),
66        }
67    }
68
69    /// Construct a `NativeBackend` with a user-supplied `PerformanceManager`.
70    ///
71    /// Use this to override the auto-detected threshold table, e.g. to pin
72    /// the laptop profile on a workstation for reproducible benchmarks.
73    pub fn with_perf(perf: PerformanceManager) -> Self {
74        Self { perf }
75    }
76
77    /// Borrow the `PerformanceManager` driving this backend's dispatch.
78    pub fn perf(&self) -> &PerformanceManager {
79        &self.perf
80    }
81
82    /// Get a shared singleton instance.
83    ///
84    /// All tensors using the default backend share this single Arc,
85    /// avoiding per-tensor allocation.
86    pub fn shared() -> Arc<NativeBackend> {
87        static INSTANCE: OnceLock<Arc<NativeBackend>> = OnceLock::new();
88        INSTANCE
89            .get_or_init(|| Arc::new(NativeBackend::new()))
90            .clone()
91    }
92}
93
94impl Default for NativeBackend {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100/// faer-backed decomposition / solve kernels accept column-major slices only.
101/// Reject any other order at the dispatcher boundary so per-type kernels
102/// never see a layout they cannot interpret.
103fn require_column_major(op: &str, order: MemoryOrder) -> Result<(), BackendError> {
104    if order != MemoryOrder::ColumnMajor {
105        return Err(BackendError::InvalidArgument(format!(
106            "NativeBackend::{op} supports ColumnMajor only, got {order:?}"
107        )));
108    }
109    Ok(())
110}
111
112impl ComputeBackend for NativeBackend {
113    fn name(&self) -> &'static str {
114        "cpu"
115    }
116
117    fn device_type(&self) -> DeviceType {
118        DeviceType::Cpu
119    }
120
121    fn preferred_order(&self) -> MemoryOrder {
122        MemoryOrder::ColumnMajor
123    }
124
125    /// GEMM: C = alpha * A * B + beta * C
126    ///
127    /// Dispatches to faer for f64/f32/`Complex<f64>`/`Complex<f32>`.
128    fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError> {
129        T::dispatch_op(&NativeKernels, OpDesc::Gemm(desc))
130    }
131
132    /// Transpose tensor axes according to permutation.
133    ///
134    /// Uses HPTT for f64/f32/Complex when the `hptt` feature is enabled,
135    /// otherwise a naive output-driven kernel.
136    fn transpose<T: Scalar>(&self, desc: TransposeDescriptor<'_, T>) -> Result<(), BackendError> {
137        T::dispatch_op(&NativeKernels, OpDesc::Transpose(desc))
138    }
139
140    /// Thin SVD via faer: A = U * diag(S) * Vt
141    ///
142    /// Dispatches to faer for f64/f32/`Complex<f64>`/`Complex<f32>`.
143    /// For complex types, Vt stores V^H (conjugate transpose).
144    /// faer's SVD is column-major only; descriptors with any other
145    /// order are rejected with `BackendError::InvalidArgument`.
146    fn svd<T: Scalar>(&self, desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
147        require_column_major("svd", desc.order)?;
148        T::dispatch_op(&NativeKernels, OpDesc::Svd(desc))
149    }
150
151    /// Thin QR via faer: A = Q * R
152    ///
153    /// Dispatches to faer for f64/f32/`Complex<f64>`/`Complex<f32>`.
154    /// faer's QR is column-major only; descriptors with any other
155    /// order are rejected with `BackendError::InvalidArgument`.
156    fn qr<T: Scalar>(&self, desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
157        require_column_major("qr", desc.order)?;
158        T::dispatch_op(&NativeKernels, OpDesc::Qr(desc))
159    }
160
161    /// Thin LQ via faer: A = L * Q
162    ///
163    /// Internally computes QR of A^H (adjoint), then takes conjugate transposes.
164    /// Dispatches to faer for f64/f32/`Complex<f64>`/`Complex<f32>`.
165    /// faer's QR (and hence this LQ) is column-major only; descriptors
166    /// with any other order are rejected with `BackendError::InvalidArgument`.
167    fn lq<T: Scalar>(&self, desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
168        require_column_major("lq", desc.order)?;
169        T::dispatch_op(&NativeKernels, OpDesc::Lq(desc))
170    }
171
172    /// Self-adjoint eigenvalue decomposition via faer
173    ///
174    /// Dispatches to faer for f64/f32/`Complex<f64>`/`Complex<f32>`.
175    /// faer's eigendecomposition is column-major only; descriptors with
176    /// any other order are rejected with `BackendError::InvalidArgument`.
177    fn eigh<T: Scalar>(&self, desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
178        require_column_major("eigh", desc.order)?;
179        T::dispatch_op(&NativeKernels, OpDesc::Eigh(desc))
180    }
181
182    /// General eigenvalue decomposition via faer
183    ///
184    /// Dispatches to faer for f64/f32/`Complex<f64>`/`Complex<f32>`.
185    /// faer's eigendecomposition is column-major only; descriptors with
186    /// any other order are rejected with `BackendError::InvalidArgument`.
187    fn eig<T: Scalar>(&self, desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
188        require_column_major("eig", desc.order)?;
189        T::dispatch_op(&NativeKernels, OpDesc::Eig(desc))
190    }
191
192    /// Linear solve via faer LU decomposition with partial pivoting
193    ///
194    /// Dispatches to faer for f64/f32/`Complex<f64>`/`Complex<f32>`.
195    /// faer's LU is column-major only; descriptors with any other
196    /// order are rejected with `BackendError::InvalidArgument`.
197    fn solve<T: Scalar>(&self, desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
198        require_column_major("solve", desc.order)?;
199        T::dispatch_op(&NativeKernels, OpDesc::Solve(desc))
200    }
201
202    fn par_for_svd(&self, m: usize, n: usize) -> ExecPolicy {
203        let work_proxy = (m as f64 * n as f64 * m.min(n) as f64).cbrt() as usize;
204        PerformanceManager::policy_by_n(self.perf.thresholds().svd, work_proxy)
205    }
206
207    fn par_for_qr(&self, m: usize, n: usize) -> ExecPolicy {
208        let work_proxy = (m as f64 * n as f64 * m.min(n) as f64).cbrt() as usize;
209        PerformanceManager::policy_by_n(self.perf.thresholds().qr, work_proxy)
210    }
211
212    fn par_for_lq(&self, m: usize, n: usize) -> ExecPolicy {
213        let work_proxy = (m as f64 * n as f64 * m.min(n) as f64).cbrt() as usize;
214        PerformanceManager::policy_by_n(self.perf.thresholds().lq, work_proxy)
215    }
216
217    fn par_for_eigh(&self, n: usize) -> ExecPolicy {
218        PerformanceManager::policy_by_n(self.perf.thresholds().eigh, n)
219    }
220
221    fn par_for_eig(&self, n: usize) -> ExecPolicy {
222        PerformanceManager::policy_by_n(self.perf.thresholds().eig, n)
223    }
224
225    fn par_for_gemm(&self, m: usize, n: usize, k: usize) -> ExecPolicy {
226        let work_proxy = (m as f64 * n as f64 * k as f64).cbrt() as usize;
227        PerformanceManager::policy_by_n(self.perf.thresholds().gemm, work_proxy)
228    }
229
230    fn par_for_solve(&self, n: usize, _nrhs: usize) -> ExecPolicy {
231        PerformanceManager::policy_by_n(self.perf.thresholds().solve, n)
232    }
233
234    fn par_for_transpose(&self, shape: &[usize]) -> ExecPolicy {
235        // Saturate on overflow so very large shapes don't wrap below the
236        // threshold and silently dispatch Sequential.
237        let total: usize = shape.iter().copied().fold(1usize, usize::saturating_mul);
238        PerformanceManager::policy_by_n(self.perf.thresholds().transpose, total)
239    }
240}
241
242/// faer / HPTT kernel set the call-site dispatcher routes to.
243///
244/// `DispatchScalar::dispatch_op` resolves a generic `OpDesc<'_, T>` to one of
245/// these four methods, where the scalar is concrete; each method then matches the op
246/// and calls the corresponding monomorphic kernel directly. This is what lets
247/// the generic `ComputeBackend` methods reach the per-type kernels without an
248/// `unsafe` `Descriptor<T>` -> `Descriptor<concrete>` reinterpretation.
249struct NativeKernels;
250
251impl ScalarKernels for NativeKernels {
252    fn run_f64(&self, op: OpDesc<'_, f64>) -> Result<(), BackendError> {
253        match op {
254            OpDesc::Gemm(d) => gemm::gemm_f64(d),
255            OpDesc::Svd(d) => svd::svd_f64(d),
256            OpDesc::Qr(d) => qr::qr_f64(d),
257            OpDesc::Lq(d) => lq::lq_f64(d),
258            OpDesc::Eigh(d) => eigh::eigh_f64(d),
259            OpDesc::Eig(d) => eig::eig_f64(d),
260            OpDesc::Solve(d) => solve::solve_f64(d),
261            OpDesc::Transpose(d) => transpose::transpose_f64(d),
262        }
263    }
264
265    fn run_f32(&self, op: OpDesc<'_, f32>) -> Result<(), BackendError> {
266        match op {
267            OpDesc::Gemm(d) => gemm::gemm_f32(d),
268            OpDesc::Svd(d) => svd::svd_f32(d),
269            OpDesc::Qr(d) => qr::qr_f32(d),
270            OpDesc::Lq(d) => lq::lq_f32(d),
271            OpDesc::Eigh(d) => eigh::eigh_f32(d),
272            OpDesc::Eig(d) => eig::eig_f32(d),
273            OpDesc::Solve(d) => solve::solve_f32(d),
274            OpDesc::Transpose(d) => transpose::transpose_f32(d),
275        }
276    }
277
278    fn run_c64(&self, op: OpDesc<'_, Complex<f64>>) -> Result<(), BackendError> {
279        match op {
280            OpDesc::Gemm(d) => gemm::gemm_c64(d),
281            OpDesc::Svd(d) => svd::svd_c64(d),
282            OpDesc::Qr(d) => qr::qr_c64(d),
283            OpDesc::Lq(d) => lq::lq_c64(d),
284            OpDesc::Eigh(d) => eigh::eigh_c64(d),
285            OpDesc::Eig(d) => eig::eig_c64(d),
286            OpDesc::Solve(d) => solve::solve_c64(d),
287            OpDesc::Transpose(d) => transpose::transpose_c64(d),
288        }
289    }
290
291    fn run_c32(&self, op: OpDesc<'_, Complex<f32>>) -> Result<(), BackendError> {
292        match op {
293            OpDesc::Gemm(d) => gemm::gemm_c32(d),
294            OpDesc::Svd(d) => svd::svd_c32(d),
295            OpDesc::Qr(d) => qr::qr_c32(d),
296            OpDesc::Lq(d) => lq::lq_c32(d),
297            OpDesc::Eigh(d) => eigh::eigh_c32(d),
298            OpDesc::Eig(d) => eig::eig_c32(d),
299            OpDesc::Solve(d) => solve::solve_c32(d),
300            OpDesc::Transpose(d) => transpose::transpose_c32(d),
301        }
302    }
303}