Skip to main content

ariadnetor_core/
backend.rs

//! Pluggable compute backend trait.
//!
//! [`ComputeBackend`] unifies the numerical primitives the algorithm
//! layer needs (GEMM, transpose, SVD / QR / LQ / eigh / eig / solve)
//! behind a single trait so the algorithm layer never names a
//! concrete backend. Each backend declares its identity through the
//! `name` / `device_type` / `preferred_order` accessors, implements the
//! two required kernels (`gemm` and `transpose`), and then overrides
//! only the remaining operations it actually supports — the default
//! implementations of `svd` / `qr` / `lq` / `eigh` / `eig` / `solve`
//! return [`BackendError::NotSupported`], so a partial backend still
//! compiles. Per-call parallelism is selected by the caller through
//! [`ExecPolicy`] and shaped by the per-operation `par_for_*` hooks;
//! see the [`ExecPolicy`] docstring for how a given backend interprets
//! `Parallel(n)`.
14
15use crate::scalar::Scalar;
16use num_complex::Complex;
17
/// Kind of device a backend runs on; used for backend selection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeviceType {
    /// The host CPU.
    Cpu,
    /// An NVIDIA GPU, driven through CUDA.
    Cuda,
    /// An Apple GPU, driven through Metal.
    Metal,
}
28
/// How tensor elements are ordered in memory.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryOrder {
    /// C order (row-major): the last axis varies fastest.
    RowMajor,
    /// Fortran order (column-major): the first axis varies fastest.
    ColumnMajor,
}
37
/// Execution policy attached to a single backend operation.
///
/// `Sequential` forces single-threaded execution.
///
/// `Parallel(0)` asks for "backend auto": faer uses rayon's
/// `current_num_threads`, while HPTT resolves `0` via
/// `std::thread::available_parallelism()` before crossing the FFI
/// boundary (HPTT 0.4 rejects a literal `0`).
///
/// `Parallel(n)` with `n > 0` is a target worker count whose strictness
/// depends on the backend: HPTT spawns exactly `n` OpenMP threads, while
/// faer and the naive Rayon kernel treat `n` as a partitioning hint
/// dispatched on the global Rayon pool.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecPolicy {
    /// Single-threaded execution, unconditionally.
    Sequential,
    /// Parallel execution with the given target worker count; `0` means
    /// "backend auto" (per-backend semantics are in the type-level docs).
    Parallel(usize),
}
57
58/// GEMM operation descriptor
59///
60/// Data layout (A, B, C slices) is specified by the `order` field.
61pub struct GemmDescriptor<'a, T> {
62    /// Rows of `op(A)` and of `C`.
63    pub m: usize,
64    /// Columns of `op(B)` and of `C`.
65    pub n: usize,
66    /// Contracted dimension: columns of `op(A)` and rows of `op(B)`.
67    pub k: usize,
68    /// Scalar applied to the `op(A) * op(B)` product.
69    pub alpha: T,
70    /// Operand `A` (`m×k`, or `k×m` when `trans_a`).
71    pub a: &'a [T],
72    /// Operand `B` (`k×n`, or `n×k` when `trans_b`).
73    pub b: &'a [T],
74    /// Scalar applied to the existing `C` before accumulation.
75    pub beta: T,
76    /// Operand / output `C` (`m×n`), overwritten with the result.
77    pub c: &'a mut [T],
78    /// Whether `A` is transposed, i.e. `op(A) = Aᵀ`.
79    pub trans_a: bool,
80    /// Whether `B` is transposed, i.e. `op(B) = Bᵀ`.
81    pub trans_b: bool,
82    /// Memory layout of the `A` / `B` / `C` slices.
83    pub order: MemoryOrder,
84    /// Per-call execution policy.
85    pub policy: ExecPolicy,
86}
87
88/// Transpose operation descriptor
89pub struct TransposeDescriptor<'a, T> {
90    /// Input tensor in `shape` order.
91    pub input: &'a [T],
92    /// Output buffer receiving the permuted tensor.
93    pub output: &'a mut [T],
94    /// Shape of the input tensor.
95    pub shape: &'a [usize],
96    /// Axis permutation: output axis `i` is input axis `perm[i]`.
97    pub perm: &'a [usize],
98    /// Memory layout of the `input` / `output` slices.
99    pub order: MemoryOrder,
100    /// Apply element-wise complex conjugation during transpose.
101    /// No-op for real types.
102    pub conj: bool,
103    /// Per-call execution policy.
104    pub policy: ExecPolicy,
105}
106
107/// Thin SVD operation descriptor: A = U * diag(S) * Vt
108///
109/// Computes the thin SVD of an m×n matrix A.
110/// Data layout (A, U, Vt slices) is specified by the `order` field;
111/// a backend that does not support a given order returns
112/// [`BackendError::InvalidArgument`].
113/// Outputs: U (m×k), S (k singular values), Vt (k×n)
114/// where k = min(m, n).
115pub struct SvdDescriptor<'a, T: Scalar> {
116    /// Rows of `A`.
117    pub m: usize,
118    /// Columns of `A`.
119    pub n: usize,
120    /// Input matrix `A` (`m×n`).
121    pub a: &'a [T],
122    /// Output left singular vectors `U` (`m×k`, `k = min(m, n)`).
123    pub u: &'a mut [T],
124    /// Output singular values `S` (`k` real values).
125    pub s: &'a mut [T::Real],
126    /// Output right singular vectors `Vᴴ` (`k×n`).
127    pub vt: &'a mut [T],
128    /// Memory layout of the matrix slices.
129    pub order: MemoryOrder,
130    /// Per-call execution policy.
131    pub policy: ExecPolicy,
132}
133
134/// Thin QR decomposition descriptor: A = Q * R
135///
136/// Computes the thin QR of an m×n matrix A.
137/// Data layout (A, Q, R slices) is specified by the `order` field;
138/// a backend that does not support a given order returns
139/// [`BackendError::InvalidArgument`].
140/// Outputs: Q (m×k), R (k×n)
141/// where k = min(m, n).
142pub struct QrDescriptor<'a, T> {
143    /// Rows of `A`.
144    pub m: usize,
145    /// Columns of `A`.
146    pub n: usize,
147    /// Input matrix `A` (`m×n`).
148    pub a: &'a [T],
149    /// Output orthonormal factor `Q` (`m×k`, `k = min(m, n)`).
150    pub q: &'a mut [T],
151    /// Output upper-triangular factor `R` (`k×n`).
152    pub r: &'a mut [T],
153    /// Memory layout of the matrix slices.
154    pub order: MemoryOrder,
155    /// Per-call execution policy.
156    pub policy: ExecPolicy,
157}
158
159/// Thin LQ decomposition descriptor: A = L * Q
160///
161/// Computes the thin LQ of an m×n matrix A.
162/// Data layout (A, L, Q slices) is specified by the `order` field;
163/// a backend that does not support a given order returns
164/// [`BackendError::InvalidArgument`].
165/// Outputs: L (m×k), Q (k×n)
166/// where k = min(m, n).
167pub struct LqDescriptor<'a, T> {
168    /// Rows of `A`.
169    pub m: usize,
170    /// Columns of `A`.
171    pub n: usize,
172    /// Input matrix `A` (`m×n`).
173    pub a: &'a [T],
174    /// Output lower-triangular factor `L` (`m×k`, `k = min(m, n)`).
175    pub l: &'a mut [T],
176    /// Output orthonormal factor `Q` (`k×n`).
177    pub q: &'a mut [T],
178    /// Memory layout of the matrix slices.
179    pub order: MemoryOrder,
180    /// Per-call execution policy.
181    pub policy: ExecPolicy,
182}
183
184/// Self-adjoint eigenvalue decomposition descriptor: A = V * diag(W) * V^H
185///
186/// Computes eigenvalues and eigenvectors of an n×n self-adjoint matrix A.
187/// Data layout (A, V slices) is specified by the `order` field;
188/// a backend that does not support a given order returns
189/// [`BackendError::InvalidArgument`].
190/// Outputs: W (n real eigenvalues, ascending), V (n×n eigenvectors)
191pub struct EighDescriptor<'a, T: Scalar> {
192    /// Dimension of the square matrix `A`.
193    pub n: usize,
194    /// Input self-adjoint matrix `A` (`n×n`).
195    pub a: &'a [T],
196    /// Output eigenvalues `W` (`n` real values, ascending).
197    pub w: &'a mut [T::Real],
198    /// Output eigenvectors `V` (`n×n`).
199    pub v: &'a mut [T],
200    /// Memory layout of the matrix slices.
201    pub order: MemoryOrder,
202    /// Per-call execution policy.
203    pub policy: ExecPolicy,
204}
205
206/// General eigenvalue decomposition descriptor
207///
208/// Computes eigenvalues and right eigenvectors of an n×n matrix A.
209/// Data layout (A, V slices) is specified by the `order` field;
210/// a backend that does not support a given order returns
211/// [`BackendError::InvalidArgument`].
212/// Outputs are always complex: W (n complex eigenvalues), V (n×n eigenvectors)
213pub struct EigDescriptor<'a, T: Scalar> {
214    /// Dimension of the square matrix `A`.
215    pub n: usize,
216    /// Input matrix `A` (`n×n`).
217    pub a: &'a [T],
218    /// Output complex eigenvalues `W` (`n`).
219    pub w: &'a mut [T::Complex],
220    /// Output complex right eigenvectors `V` (`n×n`).
221    pub v: &'a mut [T::Complex],
222    /// Memory layout of the matrix slices.
223    pub order: MemoryOrder,
224    /// Per-call execution policy.
225    pub policy: ExecPolicy,
226}
227
228/// Linear solve descriptor: AX = B via LU decomposition
229///
230/// Solves the system AX = B where A is an n×n matrix and B is n×nrhs.
231/// Data layout (A, B, X slices) is specified by the `order` field;
232/// a backend that does not support a given order returns
233/// [`BackendError::InvalidArgument`].
234/// Output X is written to `x` (n×nrhs).
235pub struct SolveDescriptor<'a, T> {
236    /// Dimension of the square coefficient matrix `A`.
237    pub n: usize,
238    /// Number of right-hand-side columns.
239    pub nrhs: usize,
240    /// Coefficient matrix `A` (`n×n`).
241    pub a: &'a [T],
242    /// Right-hand side `B` (`n×nrhs`).
243    pub b: &'a [T],
244    /// Output solution `X` (`n×nrhs`).
245    pub x: &'a mut [T],
246    /// Memory layout of the matrix slices.
247    pub order: MemoryOrder,
248    /// Per-call execution policy.
249    pub policy: ExecPolicy,
250}
251
252/// One generic-descriptor backend operation, tagged by which op it is.
253///
254/// This is the unit that [`DispatchScalar::dispatch_op`] carries from a generic
255/// `T: Scalar` context down to a concrete per-type kernel. Bundling every op
256/// into one enum lets a backend expose a single typed entry point per scalar
257/// (see [`ScalarKernels`]) instead of one per `(op, type)` pair, which is what
258/// makes type-directed dispatch possible without reinterpreting a
259/// `Descriptor<T>` into a `Descriptor<concrete>` through `unsafe`.
260pub enum OpDesc<'a, T: Scalar> {
261    /// GEMM operation.
262    Gemm(GemmDescriptor<'a, T>),
263    /// Thin SVD operation.
264    Svd(SvdDescriptor<'a, T>),
265    /// Thin QR operation.
266    Qr(QrDescriptor<'a, T>),
267    /// Thin LQ operation.
268    Lq(LqDescriptor<'a, T>),
269    /// Self-adjoint eigendecomposition operation.
270    Eigh(EighDescriptor<'a, T>),
271    /// General eigendecomposition operation.
272    Eig(EigDescriptor<'a, T>),
273    /// Linear-solve operation.
274    Solve(SolveDescriptor<'a, T>),
275    /// Tensor-transpose operation.
276    Transpose(TransposeDescriptor<'a, T>),
277}
278
279/// A backend's concrete per-scalar kernels, one entry point per supported type.
280///
281/// [`DispatchScalar::dispatch_op`] resolves a generic `OpDesc<'_, T>` to exactly
282/// one of these methods, so inside each method the scalar is concrete and the op
283/// match dispatches to a monomorphic kernel directly. A backend implements this
284/// on a local kernel-set type; the four methods mirror the four sealed [`Scalar`]
285/// types.
286pub trait ScalarKernels {
287    /// Run an operation with `f64` scalars.
288    fn run_f64(&self, op: OpDesc<'_, f64>) -> Result<(), BackendError>;
289    /// Run an operation with `f32` scalars.
290    fn run_f32(&self, op: OpDesc<'_, f32>) -> Result<(), BackendError>;
291    /// Run an operation with `Complex<f64>` scalars.
292    fn run_c64(&self, op: OpDesc<'_, Complex<f64>>) -> Result<(), BackendError>;
293    /// Run an operation with `Complex<f32>` scalars.
294    fn run_c32(&self, op: OpDesc<'_, Complex<f32>>) -> Result<(), BackendError>;
295}
296
297/// Type-directed dispatch hook: reach a concrete per-type kernel from a generic
298/// `T: Scalar`.
299///
300/// A backend method bounded only by `T: Scalar` cannot name a per-type kernel
301/// directly. This supertrait of [`Scalar`] lets it call `T::dispatch_op(...)`,
302/// which forwards to the matching [`ScalarKernels`] method where the scalar is
303/// concrete — a type-level branch in place of an `unsafe`
304/// `Descriptor<T>` -> `Descriptor<concrete>` reinterpretation. It is dispatch
305/// plumbing between [`ComputeBackend`] and a backend's [`ScalarKernels`], not a
306/// user entry point.
307///
308/// It is kept separate from [`Scalar`] so that `Scalar`'s own method list carries
309/// no backend descriptor / error / kernel types; the supertrait bound still makes
310/// every `Scalar` a `DispatchScalar`. The `where Self: Scalar` bound (rather than
311/// `trait DispatchScalar: Scalar`) avoids a cycle with that supertrait while still
312/// admitting `OpDesc<'_, Self>`, which requires `Self: Scalar`. Sealed: only the
313/// four built-in scalar types implement it.
314pub trait DispatchScalar: sealed::Sealed {
315    /// Forward a generic `OpDesc<'_, Self>` to the concrete [`ScalarKernels`]
316    /// method matching `Self`, turning a type parameter into a type-level branch.
317    fn dispatch_op<K: ScalarKernels>(kernels: &K, op: OpDesc<'_, Self>) -> Result<(), BackendError>
318    where
319        Self: Scalar;
320}
321
322mod sealed {
323    pub trait Sealed {}
324    impl Sealed for f32 {}
325    impl Sealed for f64 {}
326    impl Sealed for num_complex::Complex<f32> {}
327    impl Sealed for num_complex::Complex<f64> {}
328}
329
330impl DispatchScalar for f64 {
331    #[inline]
332    fn dispatch_op<K: ScalarKernels>(
333        kernels: &K,
334        op: OpDesc<'_, Self>,
335    ) -> Result<(), BackendError> {
336        kernels.run_f64(op)
337    }
338}
339
340impl DispatchScalar for f32 {
341    #[inline]
342    fn dispatch_op<K: ScalarKernels>(
343        kernels: &K,
344        op: OpDesc<'_, Self>,
345    ) -> Result<(), BackendError> {
346        kernels.run_f32(op)
347    }
348}
349
350impl DispatchScalar for Complex<f64> {
351    #[inline]
352    fn dispatch_op<K: ScalarKernels>(
353        kernels: &K,
354        op: OpDesc<'_, Self>,
355    ) -> Result<(), BackendError> {
356        kernels.run_c64(op)
357    }
358}
359
360impl DispatchScalar for Complex<f32> {
361    #[inline]
362    fn dispatch_op<K: ScalarKernels>(
363        kernels: &K,
364        op: OpDesc<'_, Self>,
365    ) -> Result<(), BackendError> {
366        kernels.run_c32(op)
367    }
368}
369
370/// Pluggable compute backend trait
371pub trait ComputeBackend: Send + Sync {
372    /// Backend name
373    fn name(&self) -> &'static str;
374
375    /// Device type
376    fn device_type(&self) -> DeviceType;
377
378    /// Preferred memory order for this backend's data layout.
379    ///
380    /// Descriptor data (input/output slices) is expected in this order.
381    /// The linalg layer converts tensors to this order before constructing descriptors.
382    ///
383    /// This is an **implementor-facing contract**, not a user entry point:
384    /// backend implementors must report the layout their kernels assume so
385    /// the linalg / algorithm layers can normalize to it. End users never
386    /// call it — the public `Tensor` surface hides memory layout entirely.
387    fn preferred_order(&self) -> MemoryOrder;
388
389    /// Check if backend is available
390    fn is_available(&self) -> bool {
391        true
392    }
393
394    /// GEMM: C = alpha * A * B + beta * C
395    fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError>;
396
397    /// Transpose tensor
398    fn transpose<T: Scalar>(&self, desc: TransposeDescriptor<'_, T>) -> Result<(), BackendError>;
399
400    /// Thin SVD: A = U * diag(S) * Vt
401    fn svd<T: Scalar>(&self, _desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
402        Err(BackendError::NotSupported("svd".into()))
403    }
404
405    /// Thin QR: A = Q * R
406    fn qr<T: Scalar>(&self, _desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
407        Err(BackendError::NotSupported("qr".into()))
408    }
409
410    /// Thin LQ: A = L * Q
411    fn lq<T: Scalar>(&self, _desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
412        Err(BackendError::NotSupported("lq".into()))
413    }
414
415    /// Self-adjoint eigenvalue decomposition: A = V * diag(W) * V^H
416    fn eigh<T: Scalar>(&self, _desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
417        Err(BackendError::NotSupported("eigh".into()))
418    }
419
420    /// General eigenvalue decomposition
421    fn eig<T: Scalar>(&self, _desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
422        Err(BackendError::NotSupported("eig".into()))
423    }
424
425    /// Linear solve: AX = B via LU decomposition
426    fn solve<T: Scalar>(&self, _desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
427        Err(BackendError::NotSupported("solve".into()))
428    }
429
430    /// Recommended execution policy for SVD at the given problem size.
431    ///
432    /// Default returns `Sequential`; performance-oriented backends (e.g. `NativeBackend`)
433    /// override this with a hardware-aware threshold table.
434    fn par_for_svd(&self, _m: usize, _n: usize) -> ExecPolicy {
435        ExecPolicy::Sequential
436    }
437
438    /// Recommended execution policy for QR at the given problem size.
439    fn par_for_qr(&self, _m: usize, _n: usize) -> ExecPolicy {
440        ExecPolicy::Sequential
441    }
442
443    /// Recommended execution policy for LQ at the given problem size.
444    fn par_for_lq(&self, _m: usize, _n: usize) -> ExecPolicy {
445        ExecPolicy::Sequential
446    }
447
448    /// Recommended execution policy for self-adjoint eigendecomposition.
449    fn par_for_eigh(&self, _n: usize) -> ExecPolicy {
450        ExecPolicy::Sequential
451    }
452
453    /// Recommended execution policy for general eigendecomposition.
454    fn par_for_eig(&self, _n: usize) -> ExecPolicy {
455        ExecPolicy::Sequential
456    }
457
458    /// Recommended execution policy for GEMM at the given problem size.
459    fn par_for_gemm(&self, _m: usize, _n: usize, _k: usize) -> ExecPolicy {
460        ExecPolicy::Sequential
461    }
462
463    /// Recommended execution policy for linear solve.
464    fn par_for_solve(&self, _n: usize, _nrhs: usize) -> ExecPolicy {
465        ExecPolicy::Sequential
466    }
467
468    /// Recommended execution policy for tensor transpose.
469    fn par_for_transpose(&self, _shape: &[usize]) -> ExecPolicy {
470        ExecPolicy::Sequential
471    }
472}
473
474/// Error originating from a compute backend.
475///
476/// All variants represent conditions detected by or attributed to the backend.
477/// Linalg-layer validation (nrow range, square matrix checks, etc.) should use
478/// a separate error mechanism, not `BackendError`.
479///
480/// Every variant carries its full context in its own `Display` message; none
481/// wraps a structured inner error. `BackendError` is therefore a leaf in the
482/// error chain — its `source()` is always `None`.
483#[derive(Debug, thiserror::Error)]
484pub enum BackendError {
485    /// The backend does not support this operation.
486    ///
487    /// Returned when an operation is fundamentally unavailable on this backend
488    /// (e.g., a GPU backend that lacks an eigenvalue solver). Upper layers
489    /// should consider fallback strategies or alternative computation paths.
490    #[error("Not supported: {0}")]
491    NotSupported(String),
492
493    /// The descriptor passed to the backend violates its contract.
494    ///
495    /// This indicates a bug in the calling layer (typically linalg), not a user
496    /// error. For example, buffer sizes inconsistent with declared dimensions.
497    /// Callers should treat this as a panic-worthy condition in debug builds.
498    #[error("Invalid argument: {0}")]
499    InvalidArgument(String),
500
501    /// The computation failed at runtime.
502    ///
503    /// The operation was supported and the arguments were valid, but execution
504    /// failed due to numerical issues, resource exhaustion, or other runtime
505    /// conditions (e.g., a matrix factorization that fails to converge).
506    #[error("Execution failed: {0}")]
507    ExecutionFailed(String),
508}