#![deny(missing_docs)]
mod eig;
mod eigh;
mod gemm;
mod lq;
mod performance;
mod qr;
mod solve;
mod svd;
mod transpose;
use std::sync::{Arc, OnceLock};
use ariadnetor_core::Scalar;
use ariadnetor_core::backend::{
BackendError, ComputeBackend, DeviceType, EigDescriptor, EighDescriptor, ExecPolicy,
GemmDescriptor, LqDescriptor, MemoryOrder, OpDesc, QrDescriptor, ScalarKernels,
SolveDescriptor, SvdDescriptor, TransposeDescriptor,
};
use num_complex::Complex;
pub use performance::{PerformanceManager, ThresholdTable};
/// Translates a backend [`ExecPolicy`] into the corresponding `faer::Par` value.
///
/// `ExecPolicy::Sequential` becomes `faer::Par::Seq`. `ExecPolicy::Parallel(n)`
/// is handed to `faer::Par::rayon(n)` with the thread count untouched,
/// including when `n` is zero.
pub(crate) fn to_faer_par(policy: ExecPolicy) -> faer::Par {
    match policy {
        // The count is forwarded verbatim, so a zero count needs no dedicated
        // arm; how zero is interpreted is decided by faer, not here.
        ExecPolicy::Parallel(threads) => faer::Par::rayon(threads),
        ExecPolicy::Sequential => faer::Par::Seq,
    }
}
/// Compute backend that identifies itself as the `"cpu"` device.
///
/// Its `ComputeBackend` implementation reports `DeviceType::Cpu`, prefers
/// column-major storage, and consults the contained [`PerformanceManager`]
/// when choosing an execution policy for an operation.
#[derive(Debug, Clone)]
pub struct NativeBackend {
/// Source of the per-operation thresholds read by the `par_for_*` methods.
perf: PerformanceManager,
}
impl NativeBackend {
pub fn new() -> Self {
Self {
perf: PerformanceManager::new(ThresholdTable::detect()),
}
}
pub fn with_perf(perf: PerformanceManager) -> Self {
Self { perf }
}
pub fn perf(&self) -> &PerformanceManager {
&self.perf
}
pub fn shared() -> Arc<NativeBackend> {
static INSTANCE: OnceLock<Arc<NativeBackend>> = OnceLock::new();
INSTANCE
.get_or_init(|| Arc::new(NativeBackend::new()))
.clone()
}
}
impl Default for NativeBackend {
fn default() -> Self {
Self::new()
}
}
/// Checks that `order` is [`MemoryOrder::ColumnMajor`].
///
/// `op` is the operation name interpolated into the error message.
///
/// # Errors
///
/// Returns [`BackendError::InvalidArgument`] naming `op` and the rejected
/// order when `order` is anything other than column-major.
fn require_column_major(op: &str, order: MemoryOrder) -> Result<(), BackendError> {
    if order == MemoryOrder::ColumnMajor {
        return Ok(());
    }
    Err(BackendError::InvalidArgument(format!(
        "NativeBackend::{op} supports ColumnMajor only, got {order:?}"
    )))
}
/// `ComputeBackend` implementation for the CPU backend.
///
/// Every operation is wrapped in the matching `OpDesc` variant and handed to
/// `T::dispatch_op` together with `NativeKernels`. The `par_for_*` methods
/// reduce the problem dimensions to a single size and pass it, along with the
/// operation's threshold, to `PerformanceManager::policy_by_n`.
impl ComputeBackend for NativeBackend {
    fn name(&self) -> &'static str {
        "cpu"
    }

    fn device_type(&self) -> DeviceType {
        DeviceType::Cpu
    }

    fn preferred_order(&self) -> MemoryOrder {
        MemoryOrder::ColumnMajor
    }

    // gemm and transpose are dispatched without a memory-order check; the
    // remaining operations reject anything but column-major up front.
    fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError> {
        let op = OpDesc::Gemm(desc);
        T::dispatch_op(&NativeKernels, op)
    }

    fn transpose<T: Scalar>(&self, desc: TransposeDescriptor<'_, T>) -> Result<(), BackendError> {
        let op = OpDesc::Transpose(desc);
        T::dispatch_op(&NativeKernels, op)
    }

    fn svd<T: Scalar>(&self, desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
        require_column_major("svd", desc.order)?;
        let op = OpDesc::Svd(desc);
        T::dispatch_op(&NativeKernels, op)
    }

    fn qr<T: Scalar>(&self, desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
        require_column_major("qr", desc.order)?;
        let op = OpDesc::Qr(desc);
        T::dispatch_op(&NativeKernels, op)
    }

    fn lq<T: Scalar>(&self, desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
        require_column_major("lq", desc.order)?;
        let op = OpDesc::Lq(desc);
        T::dispatch_op(&NativeKernels, op)
    }

    fn eigh<T: Scalar>(&self, desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
        require_column_major("eigh", desc.order)?;
        let op = OpDesc::Eigh(desc);
        T::dispatch_op(&NativeKernels, op)
    }

    fn eig<T: Scalar>(&self, desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
        require_column_major("eig", desc.order)?;
        let op = OpDesc::Eig(desc);
        T::dispatch_op(&NativeKernels, op)
    }

    fn solve<T: Scalar>(&self, desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
        require_column_major("solve", desc.order)?;
        let op = OpDesc::Solve(desc);
        T::dispatch_op(&NativeKernels, op)
    }

    // For the factorizations below the size handed to `policy_by_n` is the
    // cube root of m * n * min(m, n), truncated to an integer by the cast.
    fn par_for_svd(&self, m: usize, n: usize) -> ExecPolicy {
        let volume = m as f64 * n as f64 * m.min(n) as f64;
        let threshold = self.perf.thresholds().svd;
        PerformanceManager::policy_by_n(threshold, volume.cbrt() as usize)
    }

    fn par_for_qr(&self, m: usize, n: usize) -> ExecPolicy {
        let volume = m as f64 * n as f64 * m.min(n) as f64;
        let threshold = self.perf.thresholds().qr;
        PerformanceManager::policy_by_n(threshold, volume.cbrt() as usize)
    }

    fn par_for_lq(&self, m: usize, n: usize) -> ExecPolicy {
        let volume = m as f64 * n as f64 * m.min(n) as f64;
        let threshold = self.perf.thresholds().lq;
        PerformanceManager::policy_by_n(threshold, volume.cbrt() as usize)
    }

    // Square problems use the matrix dimension directly.
    fn par_for_eigh(&self, n: usize) -> ExecPolicy {
        let threshold = self.perf.thresholds().eigh;
        PerformanceManager::policy_by_n(threshold, n)
    }

    fn par_for_eig(&self, n: usize) -> ExecPolicy {
        let threshold = self.perf.thresholds().eig;
        PerformanceManager::policy_by_n(threshold, n)
    }

    // Same cube-root reduction as above, over m * n * k.
    fn par_for_gemm(&self, m: usize, n: usize, k: usize) -> ExecPolicy {
        let volume = m as f64 * n as f64 * k as f64;
        let threshold = self.perf.thresholds().gemm;
        PerformanceManager::policy_by_n(threshold, volume.cbrt() as usize)
    }

    // The number of right-hand sides does not influence the decision.
    fn par_for_solve(&self, n: usize, _nrhs: usize) -> ExecPolicy {
        let threshold = self.perf.thresholds().solve;
        PerformanceManager::policy_by_n(threshold, n)
    }

    // Uses the product of all extents; saturating so a huge shape clamps to
    // `usize::MAX` instead of overflowing.
    fn par_for_transpose(&self, shape: &[usize]) -> ExecPolicy {
        let total = shape
            .iter()
            .fold(1usize, |acc, &dim| acc.saturating_mul(dim));
        let threshold = self.perf.thresholds().transpose;
        PerformanceManager::policy_by_n(threshold, total)
    }
}
/// Field-less dispatcher whose `ScalarKernels` implementation routes each
/// `OpDesc` variant to the per-scalar-type function in this crate's operation
/// modules (`gemm`, `svd`, `qr`, `lq`, `eigh`, `eig`, `solve`, `transpose`).
struct NativeKernels;
impl ScalarKernels for NativeKernels {
fn run_f64(&self, op: OpDesc<'_, f64>) -> Result<(), BackendError> {
match op {
OpDesc::Gemm(d) => gemm::gemm_f64(d),
OpDesc::Svd(d) => svd::svd_f64(d),
OpDesc::Qr(d) => qr::qr_f64(d),
OpDesc::Lq(d) => lq::lq_f64(d),
OpDesc::Eigh(d) => eigh::eigh_f64(d),
OpDesc::Eig(d) => eig::eig_f64(d),
OpDesc::Solve(d) => solve::solve_f64(d),
OpDesc::Transpose(d) => transpose::transpose_f64(d),
}
}
fn run_f32(&self, op: OpDesc<'_, f32>) -> Result<(), BackendError> {
match op {
OpDesc::Gemm(d) => gemm::gemm_f32(d),
OpDesc::Svd(d) => svd::svd_f32(d),
OpDesc::Qr(d) => qr::qr_f32(d),
OpDesc::Lq(d) => lq::lq_f32(d),
OpDesc::Eigh(d) => eigh::eigh_f32(d),
OpDesc::Eig(d) => eig::eig_f32(d),
OpDesc::Solve(d) => solve::solve_f32(d),
OpDesc::Transpose(d) => transpose::transpose_f32(d),
}
}
fn run_c64(&self, op: OpDesc<'_, Complex<f64>>) -> Result<(), BackendError> {
match op {
OpDesc::Gemm(d) => gemm::gemm_c64(d),
OpDesc::Svd(d) => svd::svd_c64(d),
OpDesc::Qr(d) => qr::qr_c64(d),
OpDesc::Lq(d) => lq::lq_c64(d),
OpDesc::Eigh(d) => eigh::eigh_c64(d),
OpDesc::Eig(d) => eig::eig_c64(d),
OpDesc::Solve(d) => solve::solve_c64(d),
OpDesc::Transpose(d) => transpose::transpose_c64(d),
}
}
fn run_c32(&self, op: OpDesc<'_, Complex<f32>>) -> Result<(), BackendError> {
match op {
OpDesc::Gemm(d) => gemm::gemm_c32(d),
OpDesc::Svd(d) => svd::svd_c32(d),
OpDesc::Qr(d) => qr::qr_c32(d),
OpDesc::Lq(d) => lq::lq_c32(d),
OpDesc::Eigh(d) => eigh::eigh_c32(d),
OpDesc::Eig(d) => eig::eig_c32(d),
OpDesc::Solve(d) => solve::solve_c32(d),
OpDesc::Transpose(d) => transpose::transpose_c32(d),
}
}
}