use crate::scalar::Scalar;
use num_complex::Complex;
/// Kind of device a compute backend reports through
/// [`ComputeBackend::device_type`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeviceType {
    /// CPU device.
    Cpu,
    /// CUDA device.
    Cuda,
    /// Metal device.
    Metal,
}
/// Memory-layout tag carried by every operation descriptor and reported by
/// [`ComputeBackend::preferred_order`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryOrder {
    /// Row-major layout.
    RowMajor,
    /// Column-major layout.
    ColumnMajor,
}
/// Execution policy attached to operation descriptors and returned by the
/// `par_for_*` hints on [`ComputeBackend`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecPolicy {
    /// Run without parallelism.
    Sequential,
    /// Run in parallel.
    ///
    /// NOTE(review): the `usize` payload is presumably a worker/thread count;
    /// confirm against the backends that consume it.
    Parallel(usize),
}
/// Argument bundle for a general matrix-matrix multiply, passed to
/// [`ComputeBackend::gemm`].
///
/// NOTE(review): the field names match the usual BLAS GEMM naming
/// (`c = alpha * op(a) * op(b) + beta * c`), but this file only declares the
/// data; confirm the exact semantics against the backend implementations.
pub struct GemmDescriptor<'a, T> {
    /// Problem dimension `m` (presumably rows of the result; TODO confirm).
    pub m: usize,
    /// Problem dimension `n` (presumably columns of the result; TODO confirm).
    pub n: usize,
    /// Problem dimension `k` (presumably the contracted dimension; TODO confirm).
    pub k: usize,
    /// Scalar `alpha`.
    pub alpha: T,
    /// Read-only buffer for operand `a`.
    pub a: &'a [T],
    /// Read-only buffer for operand `b`.
    pub b: &'a [T],
    /// Scalar `beta`.
    pub beta: T,
    /// Mutable buffer for `c`.
    pub c: &'a mut [T],
    /// Transpose flag for `a`.
    pub trans_a: bool,
    /// Transpose flag for `b`.
    pub trans_b: bool,
    /// Memory layout tag for the buffers.
    pub order: MemoryOrder,
    /// Requested execution policy.
    pub policy: ExecPolicy,
}
/// Argument bundle for an axis permutation, passed to
/// [`ComputeBackend::transpose`].
pub struct TransposeDescriptor<'a, T> {
    /// Read-only source buffer.
    pub input: &'a [T],
    /// Mutable destination buffer.
    pub output: &'a mut [T],
    /// Shape slice (presumably the extents of `input`, one per axis; TODO confirm).
    pub shape: &'a [usize],
    /// Permutation slice (presumably one entry per axis of `shape`; TODO confirm).
    pub perm: &'a [usize],
    /// Memory layout tag for the buffers.
    pub order: MemoryOrder,
    /// Conjugation flag (presumably requests complex conjugation of the
    /// elements while permuting; TODO confirm).
    pub conj: bool,
    /// Requested execution policy.
    pub policy: ExecPolicy,
}
/// Argument bundle for a singular value decomposition, passed to
/// [`ComputeBackend::svd`].
///
/// NOTE(review): buffer sizes and the thin/full convention are not visible in
/// this file; confirm against the backend implementations.
pub struct SvdDescriptor<'a, T>
where
    T: Scalar,
{
    /// Problem dimension `m` (presumably rows of `a`; TODO confirm).
    pub m: usize,
    /// Problem dimension `n` (presumably columns of `a`; TODO confirm).
    pub n: usize,
    /// Read-only input buffer.
    pub a: &'a [T],
    /// Mutable output buffer `u`.
    pub u: &'a mut [T],
    /// Mutable output buffer `s`, stored in the real type associated with `T`.
    pub s: &'a mut [T::Real],
    /// Mutable output buffer `vt`.
    pub vt: &'a mut [T],
    /// Memory layout tag for the buffers.
    pub order: MemoryOrder,
    /// Requested execution policy.
    pub policy: ExecPolicy,
}
/// Argument bundle for a QR factorization, passed to [`ComputeBackend::qr`].
pub struct QrDescriptor<'a, T> {
    /// Problem dimension `m` (presumably rows of `a`; TODO confirm).
    pub m: usize,
    /// Problem dimension `n` (presumably columns of `a`; TODO confirm).
    pub n: usize,
    /// Read-only input buffer.
    pub a: &'a [T],
    /// Mutable output buffer `q`.
    pub q: &'a mut [T],
    /// Mutable output buffer `r`.
    pub r: &'a mut [T],
    /// Memory layout tag for the buffers.
    pub order: MemoryOrder,
    /// Requested execution policy.
    pub policy: ExecPolicy,
}
/// Argument bundle for an LQ factorization, passed to [`ComputeBackend::lq`].
pub struct LqDescriptor<'a, T> {
    /// Problem dimension `m` (presumably rows of `a`; TODO confirm).
    pub m: usize,
    /// Problem dimension `n` (presumably columns of `a`; TODO confirm).
    pub n: usize,
    /// Read-only input buffer.
    pub a: &'a [T],
    /// Mutable output buffer `l`.
    pub l: &'a mut [T],
    /// Mutable output buffer `q`.
    pub q: &'a mut [T],
    /// Memory layout tag for the buffers.
    pub order: MemoryOrder,
    /// Requested execution policy.
    pub policy: ExecPolicy,
}
/// Argument bundle for the `eigh` operation, passed to
/// [`ComputeBackend::eigh`].
///
/// NOTE(review): by its name this is presumably the Hermitian/symmetric
/// eigendecomposition; this file does not state it, so confirm with the
/// backends.
pub struct EighDescriptor<'a, T>
where
    T: Scalar,
{
    /// Problem dimension `n`.
    pub n: usize,
    /// Read-only input buffer.
    pub a: &'a [T],
    /// Mutable output buffer `w`, stored in the real type associated with `T`.
    pub w: &'a mut [T::Real],
    /// Mutable output buffer `v`.
    pub v: &'a mut [T],
    /// Memory layout tag for the buffers.
    pub order: MemoryOrder,
    /// Requested execution policy.
    pub policy: ExecPolicy,
}
/// Argument bundle for the `eig` operation, passed to
/// [`ComputeBackend::eig`].
///
/// Both output buffers use the complex type associated with `T`.
pub struct EigDescriptor<'a, T>
where
    T: Scalar,
{
    /// Problem dimension `n`.
    pub n: usize,
    /// Read-only input buffer.
    pub a: &'a [T],
    /// Mutable output buffer `w` (presumably the eigenvalues; TODO confirm).
    pub w: &'a mut [T::Complex],
    /// Mutable output buffer `v` (presumably the eigenvectors; TODO confirm).
    pub v: &'a mut [T::Complex],
    /// Memory layout tag for the buffers.
    pub order: MemoryOrder,
    /// Requested execution policy.
    pub policy: ExecPolicy,
}
/// Argument bundle for a linear solve, passed to [`ComputeBackend::solve`].
///
/// NOTE(review): presumably solves `a * x = b` for `x`; this file only
/// declares the data, so confirm with the backend implementations.
pub struct SolveDescriptor<'a, T> {
    /// Problem dimension `n`.
    pub n: usize,
    /// Number of right-hand sides, going by the field name (TODO confirm).
    pub nrhs: usize,
    /// Read-only buffer `a`.
    pub a: &'a [T],
    /// Read-only buffer `b`.
    pub b: &'a [T],
    /// Mutable output buffer `x`.
    pub x: &'a mut [T],
    /// Memory layout tag for the buffers.
    pub order: MemoryOrder,
    /// Requested execution policy.
    pub policy: ExecPolicy,
}
/// A single operation request: one variant per descriptor type.
///
/// This is the argument type of the [`ScalarKernels`] `run_*` methods and of
/// [`DispatchScalar::dispatch_op`].
pub enum OpDesc<'a, T>
where
    T: Scalar,
{
    /// Wraps a [`GemmDescriptor`].
    Gemm(GemmDescriptor<'a, T>),
    /// Wraps an [`SvdDescriptor`].
    Svd(SvdDescriptor<'a, T>),
    /// Wraps a [`QrDescriptor`].
    Qr(QrDescriptor<'a, T>),
    /// Wraps an [`LqDescriptor`].
    Lq(LqDescriptor<'a, T>),
    /// Wraps an [`EighDescriptor`].
    Eigh(EighDescriptor<'a, T>),
    /// Wraps an [`EigDescriptor`].
    Eig(EigDescriptor<'a, T>),
    /// Wraps a [`SolveDescriptor`].
    Solve(SolveDescriptor<'a, T>),
    /// Wraps a [`TransposeDescriptor`].
    Transpose(TransposeDescriptor<'a, T>),
}
/// Kernel entry points, one per concrete scalar type.
///
/// The [`DispatchScalar`] implementations in this file forward to exactly one
/// of these methods depending on the scalar type of the operation.
pub trait ScalarKernels {
    /// Executes an operation whose scalar type is `f64`.
    fn run_f64(&self, op: OpDesc<'_, f64>) -> Result<(), BackendError>;

    /// Executes an operation whose scalar type is `f32`.
    fn run_f32(&self, op: OpDesc<'_, f32>) -> Result<(), BackendError>;

    /// Executes an operation whose scalar type is `Complex<f64>`.
    fn run_c64(&self, op: OpDesc<'_, Complex<f64>>) -> Result<(), BackendError>;

    /// Executes an operation whose scalar type is `Complex<f32>`.
    fn run_c32(&self, op: OpDesc<'_, Complex<f32>>) -> Result<(), BackendError>;
}
/// Routes a generic [`OpDesc`] to the [`ScalarKernels`] method that matches
/// the scalar type.
///
/// The `sealed::Sealed` supertrait lives in a private module, so only the
/// scalar types listed there can implement this trait.
pub trait DispatchScalar: sealed::Sealed {
    /// Hands `op` to the method of `kernels` that handles `Self`, returning
    /// whatever that method returns.
    fn dispatch_op<K>(kernels: &K, op: OpDesc<'_, Self>) -> Result<(), BackendError>
    where
        Self: Scalar,
        K: ScalarKernels;
}
mod sealed {
pub trait Sealed {}
impl Sealed for f32 {}
impl Sealed for f64 {}
impl Sealed for num_complex::Complex<f32> {}
impl Sealed for num_complex::Complex<f64> {}
}
impl DispatchScalar for f64 {
    /// Forwards `f64` operations to [`ScalarKernels::run_f64`].
    #[inline]
    fn dispatch_op<K>(kernels: &K, op: OpDesc<'_, f64>) -> Result<(), BackendError>
    where
        K: ScalarKernels,
    {
        K::run_f64(kernels, op)
    }
}
impl DispatchScalar for f32 {
    /// Forwards `f32` operations to [`ScalarKernels::run_f32`].
    #[inline]
    fn dispatch_op<K>(kernels: &K, op: OpDesc<'_, f32>) -> Result<(), BackendError>
    where
        K: ScalarKernels,
    {
        K::run_f32(kernels, op)
    }
}
impl DispatchScalar for Complex<f64> {
    /// Forwards `Complex<f64>` operations to [`ScalarKernels::run_c64`].
    #[inline]
    fn dispatch_op<K>(kernels: &K, op: OpDesc<'_, Complex<f64>>) -> Result<(), BackendError>
    where
        K: ScalarKernels,
    {
        K::run_c64(kernels, op)
    }
}
impl DispatchScalar for Complex<f32> {
    /// Forwards `Complex<f32>` operations to [`ScalarKernels::run_c32`].
    #[inline]
    fn dispatch_op<K>(kernels: &K, op: OpDesc<'_, Complex<f32>>) -> Result<(), BackendError>
    where
        K: ScalarKernels,
    {
        K::run_c32(kernels, op)
    }
}
/// Interface implemented by every compute backend.
///
/// Implementors must be `Send + Sync`. The identification methods, `gemm` and
/// `transpose` are required. The remaining operations have default bodies that
/// return [`BackendError::NotSupported`], and every `par_for_*` hint defaults
/// to [`ExecPolicy::Sequential`].
pub trait ComputeBackend: Send + Sync {
    // --- Identification -------------------------------------------------

    /// Static name of the backend.
    fn name(&self) -> &'static str;

    /// Kind of device the backend reports.
    fn device_type(&self) -> DeviceType;

    /// Memory layout the backend prefers.
    fn preferred_order(&self) -> MemoryOrder;

    /// Whether the backend is available; the default answers `true`.
    fn is_available(&self) -> bool {
        true
    }

    // --- Required operations --------------------------------------------

    /// Runs the matrix multiply described by `desc`.
    fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError>;

    /// Runs the transpose described by `desc`.
    fn transpose<T: Scalar>(&self, desc: TransposeDescriptor<'_, T>) -> Result<(), BackendError>;

    // --- Optional operations (default: not supported) --------------------

    /// Runs the SVD described by the descriptor.
    ///
    /// # Errors
    ///
    /// The default body always returns `BackendError::NotSupported("svd")`.
    fn svd<T: Scalar>(&self, _desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
        Err(BackendError::NotSupported(String::from("svd")))
    }

    /// Runs the QR factorization described by the descriptor.
    ///
    /// # Errors
    ///
    /// The default body always returns `BackendError::NotSupported("qr")`.
    fn qr<T: Scalar>(&self, _desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
        Err(BackendError::NotSupported(String::from("qr")))
    }

    /// Runs the LQ factorization described by the descriptor.
    ///
    /// # Errors
    ///
    /// The default body always returns `BackendError::NotSupported("lq")`.
    fn lq<T: Scalar>(&self, _desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
        Err(BackendError::NotSupported(String::from("lq")))
    }

    /// Runs the `eigh` operation described by the descriptor.
    ///
    /// # Errors
    ///
    /// The default body always returns `BackendError::NotSupported("eigh")`.
    fn eigh<T: Scalar>(&self, _desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
        Err(BackendError::NotSupported(String::from("eigh")))
    }

    /// Runs the `eig` operation described by the descriptor.
    ///
    /// # Errors
    ///
    /// The default body always returns `BackendError::NotSupported("eig")`.
    fn eig<T: Scalar>(&self, _desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
        Err(BackendError::NotSupported(String::from("eig")))
    }

    /// Runs the linear solve described by the descriptor.
    ///
    /// # Errors
    ///
    /// The default body always returns `BackendError::NotSupported("solve")`.
    fn solve<T: Scalar>(&self, _desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
        Err(BackendError::NotSupported(String::from("solve")))
    }

    // --- Execution-policy hints (default: sequential) ---------------------

    /// Execution policy hint for an SVD with dimensions `m`, `n`.
    fn par_for_svd(&self, _m: usize, _n: usize) -> ExecPolicy {
        ExecPolicy::Sequential
    }

    /// Execution policy hint for a QR factorization with dimensions `m`, `n`.
    fn par_for_qr(&self, _m: usize, _n: usize) -> ExecPolicy {
        ExecPolicy::Sequential
    }

    /// Execution policy hint for an LQ factorization with dimensions `m`, `n`.
    fn par_for_lq(&self, _m: usize, _n: usize) -> ExecPolicy {
        ExecPolicy::Sequential
    }

    /// Execution policy hint for an `eigh` of dimension `n`.
    fn par_for_eigh(&self, _n: usize) -> ExecPolicy {
        ExecPolicy::Sequential
    }

    /// Execution policy hint for an `eig` of dimension `n`.
    fn par_for_eig(&self, _n: usize) -> ExecPolicy {
        ExecPolicy::Sequential
    }

    /// Execution policy hint for a GEMM with dimensions `m`, `n`, `k`.
    fn par_for_gemm(&self, _m: usize, _n: usize, _k: usize) -> ExecPolicy {
        ExecPolicy::Sequential
    }

    /// Execution policy hint for a solve with `n` and `nrhs`.
    fn par_for_solve(&self, _n: usize, _nrhs: usize) -> ExecPolicy {
        ExecPolicy::Sequential
    }

    /// Execution policy hint for a transpose of the given shape.
    fn par_for_transpose(&self, _shape: &[usize]) -> ExecPolicy {
        ExecPolicy::Sequential
    }
}
/// Error type returned by backend operations and kernel entry points.
///
/// Each variant carries a free-form message that is appended to a fixed
/// prefix when the error is displayed.
#[derive(Debug, thiserror::Error)]
pub enum BackendError {
    /// The operation is not supported. This is what the default bodies of the
    /// optional [`ComputeBackend`] operations return, with the operation name
    /// as payload.
    #[error("Not supported: {0}")]
    NotSupported(String),

    /// Displayed as `Invalid argument: <message>`.
    #[error("Invalid argument: {0}")]
    InvalidArgument(String),

    /// Displayed as `Execution failed: <message>`.
    #[error("Execution failed: {0}")]
    ExecutionFailed(String),
}