use ariadnetor_core::Scalar;
use ariadnetor_core::backend::ComputeBackend;
use ariadnetor_native::NativeBackend;
use crate::{BlockSparseStorage, DenseStorage};
/// Marker trait relating a [`ComputeBackend`] to a storage type `St`.
///
/// It declares no methods of its own; an `impl OpsFor<St> for B` only records
/// that backend `B` is paired with storage `St` (presumably consumed as a trait
/// bound elsewhere in the crate — confirm at the use sites).
pub trait OpsFor<St>: ComputeBackend {}

// The native backend is declared for both storage flavours, for every scalar type.
impl<T> OpsFor<DenseStorage<T>> for NativeBackend where T: Scalar {}
impl<T> OpsFor<BlockSparseStorage<T>> for NativeBackend where T: Scalar {}
/// Host backend selected at compile time.
///
/// With `pluggability-litmus` enabled this is the counting wrapper
/// `alt_host::AltHostBackend`.
#[cfg(feature = "pluggability-litmus")]
pub type Host = alt_host::AltHostBackend;
/// Host backend selected at compile time.
///
/// Without the `pluggability-litmus` feature this is plain [`NativeBackend`].
#[cfg(not(feature = "pluggability-litmus"))]
pub type Host = NativeBackend;
#[cfg(feature = "pluggability-litmus")]
mod alt_host {
    //! Alternative host backend, compiled only with the `pluggability-litmus` feature.
    //!
    //! [`AltHostBackend`] forwards every call to an inner [`NativeBackend`] and keeps
    //! a running count of the kernel calls that pass through it.

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, OnceLock};

    use ariadnetor_core::Scalar;
    use ariadnetor_core::backend::{
        BackendError, ComputeBackend, DeviceType, EigDescriptor, EighDescriptor, ExecPolicy,
        GemmDescriptor, LqDescriptor, MemoryOrder, QrDescriptor, SolveDescriptor, SvdDescriptor,
        TransposeDescriptor,
    };
    use ariadnetor_native::NativeBackend;

    use crate::{BlockSparseStorage, DenseStorage, OpsFor};

    /// [`ComputeBackend`] wrapper that delegates to [`NativeBackend`] and counts
    /// kernel invocations.
    pub struct AltHostBackend {
        /// Backend that actually executes every request.
        native: NativeBackend,
        /// Running total of kernel calls (`gemm`, `transpose`, `svd`, `qr`, `lq`,
        /// `eigh`, `eig`, `solve`); the `par_for_*` queries are not counted.
        calls: AtomicUsize,
    }

    impl AltHostBackend {
        /// Builds a wrapper around a fresh `NativeBackend` with the counter at zero.
        fn new() -> Self {
            let native = NativeBackend::new();
            AltHostBackend {
                native,
                calls: AtomicUsize::new(0),
            }
        }

        /// Returns the process-wide instance, creating it on the first call.
        ///
        /// Every call hands back a clone of the same `Arc`, so all callers
        /// observe a single shared counter.
        pub fn shared() -> Arc<AltHostBackend> {
            static SHARED: OnceLock<Arc<AltHostBackend>> = OnceLock::new();
            Arc::clone(SHARED.get_or_init(|| Arc::new(Self::new())))
        }

        /// Total number of kernel calls routed through this instance so far.
        pub fn count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }

        /// Increments the kernel-call counter, then runs `kernel` on the inner backend.
        ///
        /// The increment happens before the kernel runs, so failed calls are
        /// counted too.
        fn counted<R>(&self, kernel: impl FnOnce(&NativeBackend) -> R) -> R {
            self.calls.fetch_add(1, Ordering::SeqCst);
            kernel(&self.native)
        }
    }

    impl ComputeBackend for AltHostBackend {
        fn name(&self) -> &'static str {
            "alt-host"
        }

        fn device_type(&self) -> DeviceType {
            self.native.device_type()
        }

        fn preferred_order(&self) -> MemoryOrder {
            self.native.preferred_order()
        }

        // Kernels: each one is counted, then delegated unchanged.

        fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted(|native| native.gemm(desc))
        }

        fn transpose<T: Scalar>(
            &self,
            desc: TransposeDescriptor<'_, T>,
        ) -> Result<(), BackendError> {
            self.counted(|native| native.transpose(desc))
        }

        fn svd<T: Scalar>(&self, desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted(|native| native.svd(desc))
        }

        fn qr<T: Scalar>(&self, desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted(|native| native.qr(desc))
        }

        fn lq<T: Scalar>(&self, desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted(|native| native.lq(desc))
        }

        fn eigh<T: Scalar>(&self, desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted(|native| native.eigh(desc))
        }

        fn eig<T: Scalar>(&self, desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted(|native| native.eig(desc))
        }

        fn solve<T: Scalar>(&self, desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted(|native| native.solve(desc))
        }

        // Execution-policy queries are forwarded without touching the counter.

        fn par_for_svd(&self, m: usize, n: usize) -> ExecPolicy {
            self.native.par_for_svd(m, n)
        }

        fn par_for_qr(&self, m: usize, n: usize) -> ExecPolicy {
            self.native.par_for_qr(m, n)
        }

        fn par_for_lq(&self, m: usize, n: usize) -> ExecPolicy {
            self.native.par_for_lq(m, n)
        }

        fn par_for_eigh(&self, n: usize) -> ExecPolicy {
            self.native.par_for_eigh(n)
        }

        fn par_for_eig(&self, n: usize) -> ExecPolicy {
            self.native.par_for_eig(n)
        }

        fn par_for_gemm(&self, m: usize, n: usize, k: usize) -> ExecPolicy {
            self.native.par_for_gemm(m, n, k)
        }

        fn par_for_solve(&self, n: usize, nrhs: usize) -> ExecPolicy {
            self.native.par_for_solve(n, nrhs)
        }

        fn par_for_transpose(&self, shape: &[usize]) -> ExecPolicy {
            self.native.par_for_transpose(shape)
        }
    }

    // Mirror the native backend's storage declarations.
    impl<T> OpsFor<DenseStorage<T>> for AltHostBackend where T: Scalar {}
    impl<T> OpsFor<BlockSparseStorage<T>> for AltHostBackend where T: Scalar {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if backend `B` implements `OpsFor<St>`; does nothing at runtime.
    fn assert_ops_for<St, B>()
    where
        B: OpsFor<St>,
    {
    }

    #[test]
    fn native_and_host_declare_ops_for_both_storage_flavors() {
        // The native backend, for both storage flavours.
        assert_ops_for::<DenseStorage<f64>, NativeBackend>();
        assert_ops_for::<BlockSparseStorage<f64>, NativeBackend>();

        // Whichever backend `Host` resolves to under the active feature set.
        assert_ops_for::<DenseStorage<f64>, Host>();
        assert_ops_for::<BlockSparseStorage<f64>, Host>();
    }
}