ariadnetor_core/backend.rs
1//! Pluggable compute backend trait.
2//!
3//! [`ComputeBackend`] unifies the numerical primitives the algorithm
4//! layer needs (GEMM, transpose, SVD / QR / LQ / eigh / eig / solve)
5//! behind a single trait so the algorithm layer never names a
6//! concrete backend. Each backend declares its identity through the
//! `name` / `device_type` / `preferred_order` accessors, implements the
//! two required kernels (`gemm` and `transpose`), and then overrides only
//! the optional operations it actually supports — their default
//! implementations return [`BackendError::NotSupported`], so a partial
//! backend still compiles. Per-call parallelism is selected by the caller
//! through [`ExecPolicy`] and shaped by the per-operation `par_for_*`
//! hooks; see the [`ExecPolicy`] docs for how a given backend interprets
//! `Parallel(n)`.
14
15use crate::scalar::Scalar;
16use num_complex::Complex;
17
/// Device type for backend selection.
///
/// `Hash` is derived alongside `Eq` so a device type can key a map or set
/// (e.g. a registry of backends per device).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU.
    Cpu,
    /// NVIDIA GPU via CUDA.
    Cuda,
    /// Apple GPU via Metal.
    Metal,
}
28
/// Memory layout order for tensor data.
///
/// `Hash` is derived alongside `Eq` so a layout can key a map or set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryOrder {
    /// Row-major (C order): last axis varies fastest.
    RowMajor,
    /// Column-major (Fortran order): first axis varies fastest.
    ColumnMajor,
}
37
/// Per-call execution policy for a compute backend operation.
///
/// `Parallel(0)` means "backend auto" — faer uses rayon's
/// `current_num_threads`, while HPTT resolves `0` via
/// `std::thread::available_parallelism()` before crossing the FFI
/// boundary (HPTT 0.4 rejects a literal `0`). `Parallel(n)` with
/// `n > 0` is a target worker count whose strictness depends on the
/// backend: HPTT spawns exactly `n` OpenMP threads, while faer and
/// the naive Rayon kernel treat `n` as a partitioning hint dispatched
/// on the global Rayon pool. `Sequential` forces single-threaded
/// execution.
///
/// `Hash` is derived alongside `Eq` so policies can key a map or set
/// (e.g. a per-policy tuning or result cache).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExecPolicy {
    /// Force single-threaded execution.
    Sequential,
    /// Run in parallel with the given target worker count; `0` means
    /// "backend auto" (see the type-level note for per-backend semantics).
    Parallel(usize),
}
57
58/// GEMM operation descriptor
59///
60/// Data layout (A, B, C slices) is specified by the `order` field.
61pub struct GemmDescriptor<'a, T> {
62 /// Rows of `op(A)` and of `C`.
63 pub m: usize,
64 /// Columns of `op(B)` and of `C`.
65 pub n: usize,
66 /// Contracted dimension: columns of `op(A)` and rows of `op(B)`.
67 pub k: usize,
68 /// Scalar applied to the `op(A) * op(B)` product.
69 pub alpha: T,
70 /// Operand `A` (`m×k`, or `k×m` when `trans_a`).
71 pub a: &'a [T],
72 /// Operand `B` (`k×n`, or `n×k` when `trans_b`).
73 pub b: &'a [T],
74 /// Scalar applied to the existing `C` before accumulation.
75 pub beta: T,
76 /// Operand / output `C` (`m×n`), overwritten with the result.
77 pub c: &'a mut [T],
78 /// Whether `A` is transposed, i.e. `op(A) = Aᵀ`.
79 pub trans_a: bool,
80 /// Whether `B` is transposed, i.e. `op(B) = Bᵀ`.
81 pub trans_b: bool,
82 /// Memory layout of the `A` / `B` / `C` slices.
83 pub order: MemoryOrder,
84 /// Per-call execution policy.
85 pub policy: ExecPolicy,
86}
87
88/// Transpose operation descriptor
89pub struct TransposeDescriptor<'a, T> {
90 /// Input tensor in `shape` order.
91 pub input: &'a [T],
92 /// Output buffer receiving the permuted tensor.
93 pub output: &'a mut [T],
94 /// Shape of the input tensor.
95 pub shape: &'a [usize],
96 /// Axis permutation: output axis `i` is input axis `perm[i]`.
97 pub perm: &'a [usize],
98 /// Memory layout of the `input` / `output` slices.
99 pub order: MemoryOrder,
100 /// Apply element-wise complex conjugation during transpose.
101 /// No-op for real types.
102 pub conj: bool,
103 /// Per-call execution policy.
104 pub policy: ExecPolicy,
105}
106
107/// Thin SVD operation descriptor: A = U * diag(S) * Vt
108///
109/// Computes the thin SVD of an m×n matrix A.
110/// Data layout (A, U, Vt slices) is specified by the `order` field;
111/// a backend that does not support a given order returns
112/// [`BackendError::InvalidArgument`].
113/// Outputs: U (m×k), S (k singular values), Vt (k×n)
114/// where k = min(m, n).
115pub struct SvdDescriptor<'a, T: Scalar> {
116 /// Rows of `A`.
117 pub m: usize,
118 /// Columns of `A`.
119 pub n: usize,
120 /// Input matrix `A` (`m×n`).
121 pub a: &'a [T],
122 /// Output left singular vectors `U` (`m×k`, `k = min(m, n)`).
123 pub u: &'a mut [T],
124 /// Output singular values `S` (`k` real values).
125 pub s: &'a mut [T::Real],
126 /// Output right singular vectors `Vᴴ` (`k×n`).
127 pub vt: &'a mut [T],
128 /// Memory layout of the matrix slices.
129 pub order: MemoryOrder,
130 /// Per-call execution policy.
131 pub policy: ExecPolicy,
132}
133
134/// Thin QR decomposition descriptor: A = Q * R
135///
136/// Computes the thin QR of an m×n matrix A.
137/// Data layout (A, Q, R slices) is specified by the `order` field;
138/// a backend that does not support a given order returns
139/// [`BackendError::InvalidArgument`].
140/// Outputs: Q (m×k), R (k×n)
141/// where k = min(m, n).
142pub struct QrDescriptor<'a, T> {
143 /// Rows of `A`.
144 pub m: usize,
145 /// Columns of `A`.
146 pub n: usize,
147 /// Input matrix `A` (`m×n`).
148 pub a: &'a [T],
149 /// Output orthonormal factor `Q` (`m×k`, `k = min(m, n)`).
150 pub q: &'a mut [T],
151 /// Output upper-triangular factor `R` (`k×n`).
152 pub r: &'a mut [T],
153 /// Memory layout of the matrix slices.
154 pub order: MemoryOrder,
155 /// Per-call execution policy.
156 pub policy: ExecPolicy,
157}
158
159/// Thin LQ decomposition descriptor: A = L * Q
160///
161/// Computes the thin LQ of an m×n matrix A.
162/// Data layout (A, L, Q slices) is specified by the `order` field;
163/// a backend that does not support a given order returns
164/// [`BackendError::InvalidArgument`].
165/// Outputs: L (m×k), Q (k×n)
166/// where k = min(m, n).
167pub struct LqDescriptor<'a, T> {
168 /// Rows of `A`.
169 pub m: usize,
170 /// Columns of `A`.
171 pub n: usize,
172 /// Input matrix `A` (`m×n`).
173 pub a: &'a [T],
174 /// Output lower-triangular factor `L` (`m×k`, `k = min(m, n)`).
175 pub l: &'a mut [T],
176 /// Output orthonormal factor `Q` (`k×n`).
177 pub q: &'a mut [T],
178 /// Memory layout of the matrix slices.
179 pub order: MemoryOrder,
180 /// Per-call execution policy.
181 pub policy: ExecPolicy,
182}
183
184/// Self-adjoint eigenvalue decomposition descriptor: A = V * diag(W) * V^H
185///
186/// Computes eigenvalues and eigenvectors of an n×n self-adjoint matrix A.
187/// Data layout (A, V slices) is specified by the `order` field;
188/// a backend that does not support a given order returns
189/// [`BackendError::InvalidArgument`].
190/// Outputs: W (n real eigenvalues, ascending), V (n×n eigenvectors)
191pub struct EighDescriptor<'a, T: Scalar> {
192 /// Dimension of the square matrix `A`.
193 pub n: usize,
194 /// Input self-adjoint matrix `A` (`n×n`).
195 pub a: &'a [T],
196 /// Output eigenvalues `W` (`n` real values, ascending).
197 pub w: &'a mut [T::Real],
198 /// Output eigenvectors `V` (`n×n`).
199 pub v: &'a mut [T],
200 /// Memory layout of the matrix slices.
201 pub order: MemoryOrder,
202 /// Per-call execution policy.
203 pub policy: ExecPolicy,
204}
205
206/// General eigenvalue decomposition descriptor
207///
208/// Computes eigenvalues and right eigenvectors of an n×n matrix A.
209/// Data layout (A, V slices) is specified by the `order` field;
210/// a backend that does not support a given order returns
211/// [`BackendError::InvalidArgument`].
212/// Outputs are always complex: W (n complex eigenvalues), V (n×n eigenvectors)
213pub struct EigDescriptor<'a, T: Scalar> {
214 /// Dimension of the square matrix `A`.
215 pub n: usize,
216 /// Input matrix `A` (`n×n`).
217 pub a: &'a [T],
218 /// Output complex eigenvalues `W` (`n`).
219 pub w: &'a mut [T::Complex],
220 /// Output complex right eigenvectors `V` (`n×n`).
221 pub v: &'a mut [T::Complex],
222 /// Memory layout of the matrix slices.
223 pub order: MemoryOrder,
224 /// Per-call execution policy.
225 pub policy: ExecPolicy,
226}
227
228/// Linear solve descriptor: AX = B via LU decomposition
229///
230/// Solves the system AX = B where A is an n×n matrix and B is n×nrhs.
231/// Data layout (A, B, X slices) is specified by the `order` field;
232/// a backend that does not support a given order returns
233/// [`BackendError::InvalidArgument`].
234/// Output X is written to `x` (n×nrhs).
235pub struct SolveDescriptor<'a, T> {
236 /// Dimension of the square coefficient matrix `A`.
237 pub n: usize,
238 /// Number of right-hand-side columns.
239 pub nrhs: usize,
240 /// Coefficient matrix `A` (`n×n`).
241 pub a: &'a [T],
242 /// Right-hand side `B` (`n×nrhs`).
243 pub b: &'a [T],
244 /// Output solution `X` (`n×nrhs`).
245 pub x: &'a mut [T],
246 /// Memory layout of the matrix slices.
247 pub order: MemoryOrder,
248 /// Per-call execution policy.
249 pub policy: ExecPolicy,
250}
251
252/// One generic-descriptor backend operation, tagged by which op it is.
253///
254/// This is the unit that [`DispatchScalar::dispatch_op`] carries from a generic
255/// `T: Scalar` context down to a concrete per-type kernel. Bundling every op
256/// into one enum lets a backend expose a single typed entry point per scalar
257/// (see [`ScalarKernels`]) instead of one per `(op, type)` pair, which is what
258/// makes type-directed dispatch possible without reinterpreting a
259/// `Descriptor<T>` into a `Descriptor<concrete>` through `unsafe`.
260pub enum OpDesc<'a, T: Scalar> {
261 /// GEMM operation.
262 Gemm(GemmDescriptor<'a, T>),
263 /// Thin SVD operation.
264 Svd(SvdDescriptor<'a, T>),
265 /// Thin QR operation.
266 Qr(QrDescriptor<'a, T>),
267 /// Thin LQ operation.
268 Lq(LqDescriptor<'a, T>),
269 /// Self-adjoint eigendecomposition operation.
270 Eigh(EighDescriptor<'a, T>),
271 /// General eigendecomposition operation.
272 Eig(EigDescriptor<'a, T>),
273 /// Linear-solve operation.
274 Solve(SolveDescriptor<'a, T>),
275 /// Tensor-transpose operation.
276 Transpose(TransposeDescriptor<'a, T>),
277}
278
279/// A backend's concrete per-scalar kernels, one entry point per supported type.
280///
281/// [`DispatchScalar::dispatch_op`] resolves a generic `OpDesc<'_, T>` to exactly
282/// one of these methods, so inside each method the scalar is concrete and the op
283/// match dispatches to a monomorphic kernel directly. A backend implements this
284/// on a local kernel-set type; the four methods mirror the four sealed [`Scalar`]
285/// types.
286pub trait ScalarKernels {
287 /// Run an operation with `f64` scalars.
288 fn run_f64(&self, op: OpDesc<'_, f64>) -> Result<(), BackendError>;
289 /// Run an operation with `f32` scalars.
290 fn run_f32(&self, op: OpDesc<'_, f32>) -> Result<(), BackendError>;
291 /// Run an operation with `Complex<f64>` scalars.
292 fn run_c64(&self, op: OpDesc<'_, Complex<f64>>) -> Result<(), BackendError>;
293 /// Run an operation with `Complex<f32>` scalars.
294 fn run_c32(&self, op: OpDesc<'_, Complex<f32>>) -> Result<(), BackendError>;
295}
296
297/// Type-directed dispatch hook: reach a concrete per-type kernel from a generic
298/// `T: Scalar`.
299///
300/// A backend method bounded only by `T: Scalar` cannot name a per-type kernel
301/// directly. This supertrait of [`Scalar`] lets it call `T::dispatch_op(...)`,
302/// which forwards to the matching [`ScalarKernels`] method where the scalar is
303/// concrete — a type-level branch in place of an `unsafe`
304/// `Descriptor<T>` -> `Descriptor<concrete>` reinterpretation. It is dispatch
305/// plumbing between [`ComputeBackend`] and a backend's [`ScalarKernels`], not a
306/// user entry point.
307///
308/// It is kept separate from [`Scalar`] so that `Scalar`'s own method list carries
309/// no backend descriptor / error / kernel types; the supertrait bound still makes
310/// every `Scalar` a `DispatchScalar`. The `where Self: Scalar` bound (rather than
311/// `trait DispatchScalar: Scalar`) avoids a cycle with that supertrait while still
312/// admitting `OpDesc<'_, Self>`, which requires `Self: Scalar`. Sealed: only the
313/// four built-in scalar types implement it.
314pub trait DispatchScalar: sealed::Sealed {
315 /// Forward a generic `OpDesc<'_, Self>` to the concrete [`ScalarKernels`]
316 /// method matching `Self`, turning a type parameter into a type-level branch.
317 fn dispatch_op<K: ScalarKernels>(kernels: &K, op: OpDesc<'_, Self>) -> Result<(), BackendError>
318 where
319 Self: Scalar;
320}
321
322mod sealed {
323 pub trait Sealed {}
324 impl Sealed for f32 {}
325 impl Sealed for f64 {}
326 impl Sealed for num_complex::Complex<f32> {}
327 impl Sealed for num_complex::Complex<f64> {}
328}
329
330impl DispatchScalar for f64 {
331 #[inline]
332 fn dispatch_op<K: ScalarKernels>(
333 kernels: &K,
334 op: OpDesc<'_, Self>,
335 ) -> Result<(), BackendError> {
336 kernels.run_f64(op)
337 }
338}
339
340impl DispatchScalar for f32 {
341 #[inline]
342 fn dispatch_op<K: ScalarKernels>(
343 kernels: &K,
344 op: OpDesc<'_, Self>,
345 ) -> Result<(), BackendError> {
346 kernels.run_f32(op)
347 }
348}
349
350impl DispatchScalar for Complex<f64> {
351 #[inline]
352 fn dispatch_op<K: ScalarKernels>(
353 kernels: &K,
354 op: OpDesc<'_, Self>,
355 ) -> Result<(), BackendError> {
356 kernels.run_c64(op)
357 }
358}
359
360impl DispatchScalar for Complex<f32> {
361 #[inline]
362 fn dispatch_op<K: ScalarKernels>(
363 kernels: &K,
364 op: OpDesc<'_, Self>,
365 ) -> Result<(), BackendError> {
366 kernels.run_c32(op)
367 }
368}
369
370/// Pluggable compute backend trait
371pub trait ComputeBackend: Send + Sync {
372 /// Backend name
373 fn name(&self) -> &'static str;
374
375 /// Device type
376 fn device_type(&self) -> DeviceType;
377
378 /// Preferred memory order for this backend's data layout.
379 ///
380 /// Descriptor data (input/output slices) is expected in this order.
381 /// The linalg layer converts tensors to this order before constructing descriptors.
382 ///
383 /// This is an **implementor-facing contract**, not a user entry point:
384 /// backend implementors must report the layout their kernels assume so
385 /// the linalg / algorithm layers can normalize to it. End users never
386 /// call it — the public `Tensor` surface hides memory layout entirely.
387 fn preferred_order(&self) -> MemoryOrder;
388
389 /// Check if backend is available
390 fn is_available(&self) -> bool {
391 true
392 }
393
394 /// GEMM: C = alpha * A * B + beta * C
395 fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError>;
396
397 /// Transpose tensor
398 fn transpose<T: Scalar>(&self, desc: TransposeDescriptor<'_, T>) -> Result<(), BackendError>;
399
400 /// Thin SVD: A = U * diag(S) * Vt
401 fn svd<T: Scalar>(&self, _desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
402 Err(BackendError::NotSupported("svd".into()))
403 }
404
405 /// Thin QR: A = Q * R
406 fn qr<T: Scalar>(&self, _desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
407 Err(BackendError::NotSupported("qr".into()))
408 }
409
410 /// Thin LQ: A = L * Q
411 fn lq<T: Scalar>(&self, _desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
412 Err(BackendError::NotSupported("lq".into()))
413 }
414
415 /// Self-adjoint eigenvalue decomposition: A = V * diag(W) * V^H
416 fn eigh<T: Scalar>(&self, _desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
417 Err(BackendError::NotSupported("eigh".into()))
418 }
419
420 /// General eigenvalue decomposition
421 fn eig<T: Scalar>(&self, _desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
422 Err(BackendError::NotSupported("eig".into()))
423 }
424
425 /// Linear solve: AX = B via LU decomposition
426 fn solve<T: Scalar>(&self, _desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
427 Err(BackendError::NotSupported("solve".into()))
428 }
429
430 /// Recommended execution policy for SVD at the given problem size.
431 ///
432 /// Default returns `Sequential`; performance-oriented backends (e.g. `NativeBackend`)
433 /// override this with a hardware-aware threshold table.
434 fn par_for_svd(&self, _m: usize, _n: usize) -> ExecPolicy {
435 ExecPolicy::Sequential
436 }
437
438 /// Recommended execution policy for QR at the given problem size.
439 fn par_for_qr(&self, _m: usize, _n: usize) -> ExecPolicy {
440 ExecPolicy::Sequential
441 }
442
443 /// Recommended execution policy for LQ at the given problem size.
444 fn par_for_lq(&self, _m: usize, _n: usize) -> ExecPolicy {
445 ExecPolicy::Sequential
446 }
447
448 /// Recommended execution policy for self-adjoint eigendecomposition.
449 fn par_for_eigh(&self, _n: usize) -> ExecPolicy {
450 ExecPolicy::Sequential
451 }
452
453 /// Recommended execution policy for general eigendecomposition.
454 fn par_for_eig(&self, _n: usize) -> ExecPolicy {
455 ExecPolicy::Sequential
456 }
457
458 /// Recommended execution policy for GEMM at the given problem size.
459 fn par_for_gemm(&self, _m: usize, _n: usize, _k: usize) -> ExecPolicy {
460 ExecPolicy::Sequential
461 }
462
463 /// Recommended execution policy for linear solve.
464 fn par_for_solve(&self, _n: usize, _nrhs: usize) -> ExecPolicy {
465 ExecPolicy::Sequential
466 }
467
468 /// Recommended execution policy for tensor transpose.
469 fn par_for_transpose(&self, _shape: &[usize]) -> ExecPolicy {
470 ExecPolicy::Sequential
471 }
472}
473
474/// Error originating from a compute backend.
475///
476/// All variants represent conditions detected by or attributed to the backend.
477/// Linalg-layer validation (nrow range, square matrix checks, etc.) should use
478/// a separate error mechanism, not `BackendError`.
479///
480/// Every variant carries its full context in its own `Display` message; none
481/// wraps a structured inner error. `BackendError` is therefore a leaf in the
482/// error chain — its `source()` is always `None`.
483#[derive(Debug, thiserror::Error)]
484pub enum BackendError {
485 /// The backend does not support this operation.
486 ///
487 /// Returned when an operation is fundamentally unavailable on this backend
488 /// (e.g., a GPU backend that lacks an eigenvalue solver). Upper layers
489 /// should consider fallback strategies or alternative computation paths.
490 #[error("Not supported: {0}")]
491 NotSupported(String),
492
493 /// The descriptor passed to the backend violates its contract.
494 ///
495 /// This indicates a bug in the calling layer (typically linalg), not a user
496 /// error. For example, buffer sizes inconsistent with declared dimensions.
497 /// Callers should treat this as a panic-worthy condition in debug builds.
498 #[error("Invalid argument: {0}")]
499 InvalidArgument(String),
500
501 /// The computation failed at runtime.
502 ///
503 /// The operation was supported and the arguments were valid, but execution
504 /// failed due to numerical issues, resource exhaustion, or other runtime
505 /// conditions (e.g., a matrix factorization that fails to converge).
506 #[error("Execution failed: {0}")]
507 ExecutionFailed(String),
508}