use super::types::*;
use crate::category::core::{Dtype, Shape};
use std::fmt::Debug;
pub mod shape_only;
#[cfg(feature = "ndarray-backend")]
pub mod ndarray;
#[cfg(feature = "candle-backend")]
pub mod candle;
/// A compute backend that owns tensor storage and implements the primitive tensor operations.
///
/// Most operations consume and return type-tagged tensors ([`TaggedTensor`]).
/// NOTE(review): presumably a `TaggedTensor` carries its element [`Dtype`] at runtime so that
/// operations can dispatch on it — confirm against `super::types`.
///
/// Binary operations take both operands packed into a single [`TaggedTensorTuple`] of
/// length 2 (presumably ordered `[lhs, rhs]` — confirm against `super::types`).
///
/// The `Send + Sync + Clone + Debug` supertraits allow a backend handle to be cloned,
/// debug-printed and shared between threads.
pub trait Backend: Send + Sync + Clone + Debug {
    /// The concrete n-dimensional array type this backend uses for element type `D`.
    ///
    /// The `Backend = Self` constraint ties each array type back to the backend that owns it.
    type NdArray<D: HasDtype>: NdArray<D, Backend = Self>;

    /// Creates an array holding the single value `d`.
    ///
    /// NOTE(review): presumably a rank-0 (scalar-shaped) array — confirm in implementations.
    fn scalar<D: HasDtype>(&self, d: D) -> Self::NdArray<D>;

    /// Creates an array with the given `shape`.
    ///
    /// NOTE(review): presumably every element is `D::default()` (the reason for the `Default`
    /// bound), i.e. zero for numeric dtypes — confirm in implementations.
    fn zeros<D: HasDtype + Default>(&self, shape: Shape) -> Self::NdArray<D>;

    /// Builds an array with the given `shape` from the elements in `data`.
    ///
    /// # Errors
    ///
    /// Returns a [`BackendError`] if the array cannot be built. NOTE(review): presumably
    /// [`BackendError::ShapeError`] when `data.len()` does not match the element count of
    /// `shape` — confirm in implementations.
    fn ndarray_from_slice<D: HasDtype>(
        &self,
        data: &[D],
        shape: Shape,
    ) -> Result<Self::NdArray<D>, BackendError>;

    /// Converts `x` to the element type `target_dtype`.
    fn cast(&self, x: TaggedTensor<Self>, target_dtype: Dtype) -> TaggedTensor<Self>;

    /// Matrix product of the two tensors in `operands`.
    fn matmul(&self, operands: TaggedTensorTuple<Self, 2>) -> TaggedTensor<Self>;

    /// Elementwise sum of the two tensors in `operands`.
    fn add(&self, operands: TaggedTensorTuple<Self, 2>) -> TaggedTensor<Self>;

    /// Elementwise difference of the two tensors in `operands` (presumably `lhs - rhs`).
    fn sub(&self, operands: TaggedTensorTuple<Self, 2>) -> TaggedTensor<Self>;

    /// Elementwise product of the two tensors in `operands`.
    fn mul(&self, operands: TaggedTensorTuple<Self, 2>) -> TaggedTensor<Self>;

    /// Elementwise quotient of the two tensors in `operands` (presumably `lhs / rhs`).
    fn div(&self, operands: TaggedTensorTuple<Self, 2>) -> TaggedTensor<Self>;

    /// Elementwise power of the two tensors in `operands` (presumably `lhs` raised to `rhs`).
    fn pow(&self, operands: TaggedTensorTuple<Self, 2>) -> TaggedTensor<Self>;

    /// Elementwise less-than comparison of the two tensors in `operands` (presumably
    /// `lhs < rhs`).
    ///
    /// NOTE(review): the dtype used to encode the boolean result is backend-defined — confirm.
    fn lt(&self, operands: TaggedTensorTuple<Self, 2>) -> TaggedTensor<Self>;

    /// Elementwise equality comparison of the two tensors in `operands`, returning a tensor.
    ///
    /// This shares its name with `PartialEq::eq`. If an implementor also implements
    /// `PartialEq`, method-call syntax is ambiguous, so call it as `Backend::eq(&b, operands)`.
    /// For a single yes/no answer use [`Backend::compare`].
    fn eq(&self, operands: TaggedTensorTuple<Self, 2>) -> TaggedTensor<Self>;

    /// Elementwise sine of `x`.
    fn sin(&self, x: TaggedTensor<Self>) -> TaggedTensor<Self>;

    /// Elementwise cosine of `x`.
    fn cos(&self, x: TaggedTensor<Self>) -> TaggedTensor<Self>;

    /// Elementwise negation of `x`.
    fn neg(&self, x: TaggedTensor<Self>) -> TaggedTensor<Self>;

    /// Broadcasts `x` to `shape`.
    fn broadcast(&self, x: TaggedTensor<Self>, shape: Shape) -> TaggedTensor<Self>;

    /// Reinterprets `x` with `new_shape`.
    ///
    /// NOTE(review): presumably the element count must be unchanged — confirm how
    /// implementations report a mismatch (this signature cannot return an error).
    fn reshape(&self, x: TaggedTensor<Self>, new_shape: Shape) -> TaggedTensor<Self>;

    /// Swaps dimensions `dim0` and `dim1` of `x`.
    fn transpose(&self, x: TaggedTensor<Self>, dim0: usize, dim1: usize) -> TaggedTensor<Self>;

    /// Max-reduction of `x`.
    ///
    /// NOTE(review): the reduced axis (all axes vs. a specific one) is not fixed by this
    /// signature — confirm in implementations.
    fn max(&self, x: TaggedTensor<Self>) -> TaggedTensor<Self>;

    /// Sum-reduction of `x`.
    ///
    /// NOTE(review): the reduced axis is not fixed by this signature — confirm in
    /// implementations.
    fn sum(&self, x: TaggedTensor<Self>) -> TaggedTensor<Self>;

    /// Index of the maximum of `x`.
    ///
    /// NOTE(review): the reduced axis and the dtype of the returned indices are not fixed by
    /// this signature — confirm in implementations.
    fn argmax(&self, x: TaggedTensor<Self>) -> TaggedTensor<Self>;

    /// Returns `true` when the two tensors in `operands` compare equal as a whole.
    ///
    /// Unlike [`Backend::eq`], this produces a single `bool` rather than a tensor.
    /// NOTE(review): presumably compares dtype, shape and values; whether floating-point
    /// values are compared exactly or approximately is backend-defined — confirm.
    fn compare(&self, operands: TaggedTensorTuple<Self, 2>) -> bool;

    /// Concatenates `x` and `y` along dimension `dim`.
    fn concat(
        &self,
        x: TaggedTensor<Self>,
        y: TaggedTensor<Self>,
        dim: usize,
    ) -> TaggedTensor<Self>;

    /// Selects entries of `x` along dimension `dim` at the positions held in `indices`.
    fn index(
        &self,
        x: TaggedTensor<Self>,
        dim: usize,
        indices: TaggedTensor<Self>,
    ) -> TaggedTensor<Self>;

    /// Takes `len` consecutive entries of `x` along dimension `dim`, starting at `start`.
    fn slice(
        &self,
        x: TaggedTensor<Self>,
        dim: usize,
        start: usize,
        len: usize,
    ) -> TaggedTensor<Self>;

    /// Creates a 1-D tensor counting up to `end`.
    ///
    /// NOTE(review): presumably `0..end` (exclusive) with an integer dtype — confirm in
    /// implementations.
    fn arange(&self, end: usize) -> TaggedTensor<Self>;
}
/// An n-dimensional array whose elements have type `D`, created by some [`Backend`].
///
/// Arrays are `Send + Sync + Clone + Debug` so they can be copied, debug-printed and moved
/// between threads together with their owning backend.
pub trait NdArray<D>
where
    D: HasDtype,
    Self: Send + Sync + Clone + Debug,
{
    /// The backend that owns this array type.
    type Backend: Backend;

    /// Returns the shape of this array.
    fn shape(&self) -> Shape;
}
/// Errors reported by fallible backend operations (e.g. `Backend::ndarray_from_slice`).
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be printed for
/// users and propagated with `?` into `Box<dyn std::error::Error>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendError {
    /// A shape was invalid for the requested operation.
    ///
    /// NOTE(review): presumably returned when the supplied data does not fit the requested
    /// shape (e.g. wrong element count) — confirm in backend implementations.
    ShapeError,
}

impl std::fmt::Display for BackendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive match: adding a variant forces a message to be written for it.
        match self {
            BackendError::ShapeError => f.write_str("shape error"),
        }
    }
}

// No variant wraps an underlying cause, so the default `source()` (returning `None`) is correct.
impl std::error::Error for BackendError {}