1#![deny(missing_docs)]
9
10mod eig;
11mod eigh;
12mod gemm;
13mod lq;
14mod performance;
15mod qr;
16mod solve;
17mod svd;
18mod transpose;
19
20use std::sync::{Arc, OnceLock};
21
22use ariadnetor_core::Scalar;
23use ariadnetor_core::backend::{
24 BackendError, ComputeBackend, DeviceType, EigDescriptor, EighDescriptor, ExecPolicy,
25 GemmDescriptor, LqDescriptor, MemoryOrder, OpDesc, QrDescriptor, ScalarKernels,
26 SolveDescriptor, SvdDescriptor, TransposeDescriptor,
27};
28use num_complex::Complex;
29
30pub use performance::{PerformanceManager, ThresholdTable};
31
32pub(crate) fn to_faer_par(policy: ExecPolicy) -> faer::Par {
41 match policy {
42 ExecPolicy::Sequential => faer::Par::Seq,
43 ExecPolicy::Parallel(0) => faer::Par::rayon(0),
44 ExecPolicy::Parallel(n) => faer::Par::rayon(n),
45 }
46}
47
48#[derive(Debug, Clone)]
56pub struct NativeBackend {
57 perf: PerformanceManager,
58}
59
60impl NativeBackend {
61 pub fn new() -> Self {
64 Self {
65 perf: PerformanceManager::new(ThresholdTable::detect()),
66 }
67 }
68
69 pub fn with_perf(perf: PerformanceManager) -> Self {
74 Self { perf }
75 }
76
77 pub fn perf(&self) -> &PerformanceManager {
79 &self.perf
80 }
81
82 pub fn shared() -> Arc<NativeBackend> {
87 static INSTANCE: OnceLock<Arc<NativeBackend>> = OnceLock::new();
88 INSTANCE
89 .get_or_init(|| Arc::new(NativeBackend::new()))
90 .clone()
91 }
92}
93
94impl Default for NativeBackend {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100fn require_column_major(op: &str, order: MemoryOrder) -> Result<(), BackendError> {
104 if order != MemoryOrder::ColumnMajor {
105 return Err(BackendError::InvalidArgument(format!(
106 "NativeBackend::{op} supports ColumnMajor only, got {order:?}"
107 )));
108 }
109 Ok(())
110}
111
112impl ComputeBackend for NativeBackend {
113 fn name(&self) -> &'static str {
114 "cpu"
115 }
116
117 fn device_type(&self) -> DeviceType {
118 DeviceType::Cpu
119 }
120
121 fn preferred_order(&self) -> MemoryOrder {
122 MemoryOrder::ColumnMajor
123 }
124
125 fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError> {
129 T::dispatch_op(&NativeKernels, OpDesc::Gemm(desc))
130 }
131
132 fn transpose<T: Scalar>(&self, desc: TransposeDescriptor<'_, T>) -> Result<(), BackendError> {
137 T::dispatch_op(&NativeKernels, OpDesc::Transpose(desc))
138 }
139
140 fn svd<T: Scalar>(&self, desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
147 require_column_major("svd", desc.order)?;
148 T::dispatch_op(&NativeKernels, OpDesc::Svd(desc))
149 }
150
151 fn qr<T: Scalar>(&self, desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
157 require_column_major("qr", desc.order)?;
158 T::dispatch_op(&NativeKernels, OpDesc::Qr(desc))
159 }
160
161 fn lq<T: Scalar>(&self, desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
168 require_column_major("lq", desc.order)?;
169 T::dispatch_op(&NativeKernels, OpDesc::Lq(desc))
170 }
171
172 fn eigh<T: Scalar>(&self, desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
178 require_column_major("eigh", desc.order)?;
179 T::dispatch_op(&NativeKernels, OpDesc::Eigh(desc))
180 }
181
182 fn eig<T: Scalar>(&self, desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
188 require_column_major("eig", desc.order)?;
189 T::dispatch_op(&NativeKernels, OpDesc::Eig(desc))
190 }
191
192 fn solve<T: Scalar>(&self, desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
198 require_column_major("solve", desc.order)?;
199 T::dispatch_op(&NativeKernels, OpDesc::Solve(desc))
200 }
201
202 fn par_for_svd(&self, m: usize, n: usize) -> ExecPolicy {
203 let work_proxy = (m as f64 * n as f64 * m.min(n) as f64).cbrt() as usize;
204 PerformanceManager::policy_by_n(self.perf.thresholds().svd, work_proxy)
205 }
206
207 fn par_for_qr(&self, m: usize, n: usize) -> ExecPolicy {
208 let work_proxy = (m as f64 * n as f64 * m.min(n) as f64).cbrt() as usize;
209 PerformanceManager::policy_by_n(self.perf.thresholds().qr, work_proxy)
210 }
211
212 fn par_for_lq(&self, m: usize, n: usize) -> ExecPolicy {
213 let work_proxy = (m as f64 * n as f64 * m.min(n) as f64).cbrt() as usize;
214 PerformanceManager::policy_by_n(self.perf.thresholds().lq, work_proxy)
215 }
216
217 fn par_for_eigh(&self, n: usize) -> ExecPolicy {
218 PerformanceManager::policy_by_n(self.perf.thresholds().eigh, n)
219 }
220
221 fn par_for_eig(&self, n: usize) -> ExecPolicy {
222 PerformanceManager::policy_by_n(self.perf.thresholds().eig, n)
223 }
224
225 fn par_for_gemm(&self, m: usize, n: usize, k: usize) -> ExecPolicy {
226 let work_proxy = (m as f64 * n as f64 * k as f64).cbrt() as usize;
227 PerformanceManager::policy_by_n(self.perf.thresholds().gemm, work_proxy)
228 }
229
230 fn par_for_solve(&self, n: usize, _nrhs: usize) -> ExecPolicy {
231 PerformanceManager::policy_by_n(self.perf.thresholds().solve, n)
232 }
233
234 fn par_for_transpose(&self, shape: &[usize]) -> ExecPolicy {
235 let total: usize = shape.iter().copied().fold(1usize, usize::saturating_mul);
238 PerformanceManager::policy_by_n(self.perf.thresholds().transpose, total)
239 }
240}
241
242struct NativeKernels;
250
251impl ScalarKernels for NativeKernels {
252 fn run_f64(&self, op: OpDesc<'_, f64>) -> Result<(), BackendError> {
253 match op {
254 OpDesc::Gemm(d) => gemm::gemm_f64(d),
255 OpDesc::Svd(d) => svd::svd_f64(d),
256 OpDesc::Qr(d) => qr::qr_f64(d),
257 OpDesc::Lq(d) => lq::lq_f64(d),
258 OpDesc::Eigh(d) => eigh::eigh_f64(d),
259 OpDesc::Eig(d) => eig::eig_f64(d),
260 OpDesc::Solve(d) => solve::solve_f64(d),
261 OpDesc::Transpose(d) => transpose::transpose_f64(d),
262 }
263 }
264
265 fn run_f32(&self, op: OpDesc<'_, f32>) -> Result<(), BackendError> {
266 match op {
267 OpDesc::Gemm(d) => gemm::gemm_f32(d),
268 OpDesc::Svd(d) => svd::svd_f32(d),
269 OpDesc::Qr(d) => qr::qr_f32(d),
270 OpDesc::Lq(d) => lq::lq_f32(d),
271 OpDesc::Eigh(d) => eigh::eigh_f32(d),
272 OpDesc::Eig(d) => eig::eig_f32(d),
273 OpDesc::Solve(d) => solve::solve_f32(d),
274 OpDesc::Transpose(d) => transpose::transpose_f32(d),
275 }
276 }
277
278 fn run_c64(&self, op: OpDesc<'_, Complex<f64>>) -> Result<(), BackendError> {
279 match op {
280 OpDesc::Gemm(d) => gemm::gemm_c64(d),
281 OpDesc::Svd(d) => svd::svd_c64(d),
282 OpDesc::Qr(d) => qr::qr_c64(d),
283 OpDesc::Lq(d) => lq::lq_c64(d),
284 OpDesc::Eigh(d) => eigh::eigh_c64(d),
285 OpDesc::Eig(d) => eig::eig_c64(d),
286 OpDesc::Solve(d) => solve::solve_c64(d),
287 OpDesc::Transpose(d) => transpose::transpose_c64(d),
288 }
289 }
290
291 fn run_c32(&self, op: OpDesc<'_, Complex<f32>>) -> Result<(), BackendError> {
292 match op {
293 OpDesc::Gemm(d) => gemm::gemm_c32(d),
294 OpDesc::Svd(d) => svd::svd_c32(d),
295 OpDesc::Qr(d) => qr::qr_c32(d),
296 OpDesc::Lq(d) => lq::lq_c32(d),
297 OpDesc::Eigh(d) => eigh::eigh_c32(d),
298 OpDesc::Eig(d) => eig::eig_c32(d),
299 OpDesc::Solve(d) => solve::solve_c32(d),
300 OpDesc::Transpose(d) => transpose::transpose_c32(d),
301 }
302 }
303}