use crate::traits::tensor_data::TensorData;
use crate::types::backend::Device;
use crate::types::cpu_tensor::EinSumAST;
use core::ops::Range;
/// Abstraction over a tensor storage/compute backend.
///
/// Every operation is an associated function with no `self` receiver. A backend is selected
/// by naming the implementing type, and its tensors are values of the generic associated
/// type `Tensor<T>`. Implementors must be `Clone + Send + Sync + 'static`.
pub trait TensorBackend: Clone + Send + Sync + 'static {
    /// The concrete tensor type this backend uses to hold elements of type `T`.
    type Tensor<T>;

    /// Returns the `Device` this backend targets.
    fn device() -> Device;

    // ---------------------------------------------------------------------------------------
    // Construction
    // ---------------------------------------------------------------------------------------

    /// Builds a tensor with the given `shape` from the borrowed slice `data`.
    ///
    /// `T: Clone` is required because `data` is only borrowed. Implementations presumably
    /// copy its elements. NOTE(review): behaviour when `data.len()` differs from the product
    /// of `shape` is left to implementations — confirm.
    fn create<T: Clone>(data: &[T], shape: &[usize]) -> Self::Tensor<T>;

    /// Builds a tensor with the given `shape`, taking ownership of `data` (no `Clone` bound).
    fn create_from_vec<T>(data: Vec<T>, shape: &[usize]) -> Self::Tensor<T>;

    /// Returns a tensor of `shape` filled with zeros.
    /// Assumes `TensorData` supplies the zero value of `T` — confirm.
    fn zeros<T: TensorData>(shape: &[usize]) -> Self::Tensor<T>;

    /// Returns a tensor of `shape` filled with ones.
    /// Assumes `TensorData` supplies the one value of `T` — confirm.
    fn ones<T: TensorData>(shape: &[usize]) -> Self::Tensor<T>;

    /// Builds a tensor of `shape` whose elements are produced by `f`.
    ///
    /// `f` receives a multi-dimensional index as `&[usize]`. NOTE(review): the order in which
    /// indices are visited is implementation-defined; `f` is `FnMut`, so it may keep state.
    fn from_shape_fn<T, F>(shape: &[usize], f: F) -> Self::Tensor<T>
    where
        T: Clone,
        F: FnMut(&[usize]) -> T;

    // ---------------------------------------------------------------------------------------
    // Extraction and inspection
    // ---------------------------------------------------------------------------------------

    /// Returns the elements of `tensor` as a new `Vec`, leaving `tensor` intact.
    fn to_vec<T: Clone>(tensor: &Self::Tensor<T>) -> Vec<T>;

    /// Consumes `tensor` and returns its elements as a `Vec`.
    fn into_vec<T>(tensor: Self::Tensor<T>) -> Vec<T>;

    /// Returns the size of each dimension of `tensor`.
    fn shape<T>(tensor: &Self::Tensor<T>) -> Vec<usize>;

    /// Returns the per-dimension strides of `tensor`.
    /// Presumably measured in elements rather than bytes — confirm.
    fn strides<T>(tensor: &Self::Tensor<T>) -> Vec<usize>;

    /// Returns a copy of the element at the multi-dimensional `index`, or `None`.
    /// Presumably `None` signals an out-of-bounds or wrong-rank index — confirm.
    fn get<T: Clone>(tensor: &Self::Tensor<T>, index: &[usize]) -> Option<T>;

    // ---------------------------------------------------------------------------------------
    // Shape manipulation
    // ---------------------------------------------------------------------------------------

    /// Returns a tensor holding the elements of `tensor` arranged under the new `shape`.
    fn reshape<T: Clone>(tensor: &Self::Tensor<T>, shape: &[usize]) -> Self::Tensor<T>;

    /// Returns a tensor whose axes are reordered according to `axes`.
    fn permute<T: Clone>(tensor: &Self::Tensor<T>, axes: &[usize]) -> Self::Tensor<T>;

    /// Returns the sub-tensor selected by one half-open `Range` per axis.
    fn slice<T: Clone>(tensor: &Self::Tensor<T>, ranges: &[Range<usize>]) -> Self::Tensor<T>;

    /// Returns a flattened copy of `tensor` (presumably rank 1 — confirm).
    fn ravel<T: Clone>(tensor: &Self::Tensor<T>) -> Self::Tensor<T>;

    /// Returns a tensor derived from `tensor` relative to the linear position `flat_index`.
    ///
    /// NOTE(review): the exact semantics (e.g. whether the view wraps around) are not visible
    /// from the trait — confirm against implementations.
    fn shifted_view<T: Clone>(tensor: &Self::Tensor<T>, flat_index: usize) -> Self::Tensor<T>;

    /// Combines `tensors` along `axis` (presumably a newly inserted axis — confirm).
    ///
    /// # Errors
    ///
    /// Returns `crate::CausalTensorError` when the inputs cannot be stacked.
    fn stack<T: TensorData>(
        tensors: &[Self::Tensor<T>],
        axis: usize,
    ) -> Result<Self::Tensor<T>, crate::CausalTensorError>;

    // ---------------------------------------------------------------------------------------
    // Element-wise arithmetic
    // ---------------------------------------------------------------------------------------
    //
    // NOTE(review): these four are infallible. How they treat mismatched shapes (broadcast or
    // panic) is up to each implementation — confirm.

    /// Element-wise sum of `a` and `b`.
    fn add<T: TensorData>(a: &Self::Tensor<T>, b: &Self::Tensor<T>) -> Self::Tensor<T>;

    /// Element-wise difference `a - b`.
    fn sub<T: TensorData>(a: &Self::Tensor<T>, b: &Self::Tensor<T>) -> Self::Tensor<T>;

    /// Element-wise product of `a` and `b`.
    fn mul<T: TensorData>(a: &Self::Tensor<T>, b: &Self::Tensor<T>) -> Self::Tensor<T>;

    /// Element-wise quotient `a / b`.
    fn div<T: TensorData>(a: &Self::Tensor<T>, b: &Self::Tensor<T>) -> Self::Tensor<T>;

    /// Applies the fallible binary function `f` to paired elements of `lhs` and `rhs`.
    ///
    /// Presumably the shapes are broadcast against each other first — confirm.
    ///
    /// # Errors
    ///
    /// Returns `crate::CausalTensorError` when `f` fails. It may also fail for incompatible
    /// shapes.
    fn broadcast_op<T, F>(
        lhs: &Self::Tensor<T>,
        rhs: &Self::Tensor<T>,
        f: F,
    ) -> Result<Self::Tensor<T>, crate::CausalTensorError>
    where
        T: Clone,
        F: Fn(T, T) -> Result<T, crate::CausalTensorError>;

    // ---------------------------------------------------------------------------------------
    // Reductions
    // ---------------------------------------------------------------------------------------

    /// Sums `tensor` over the listed `axes`.
    fn sum<T: TensorData>(tensor: &Self::Tensor<T>, axes: &[usize]) -> Self::Tensor<T>;

    /// Takes the maximum of `tensor` over the listed `axes`.
    fn max<T: TensorData>(tensor: &Self::Tensor<T>, axes: &[usize]) -> Self::Tensor<T>;

    /// Averages `tensor` over the listed `axes`.
    ///
    /// The extra `From<u32>` bound presumably converts the element count into `T` for the
    /// division — confirm.
    fn mean<T>(tensor: &Self::Tensor<T>, axes: &[usize]) -> Self::Tensor<T>
    where
        T: TensorData + From<u32>;

    // ---------------------------------------------------------------------------------------
    // Ordering and contraction
    // ---------------------------------------------------------------------------------------

    /// Returns the indices that would sort the elements of `tensor`.
    ///
    /// # Errors
    ///
    /// Returns `crate::CausalTensorError` when sorting is not possible for this input.
    fn arg_sort<T: TensorData>(
        tensor: &Self::Tensor<T>,
    ) -> Result<Vec<usize>, crate::CausalTensorError>;

    /// Evaluates the Einstein-summation expression `ast`, whose operands are this backend's
    /// tensors.
    ///
    /// # Errors
    ///
    /// Returns `crate::CausalTensorError` when the expression cannot be evaluated.
    fn ein_sum<T: TensorData>(
        ast: &EinSumAST<Self::Tensor<T>>,
    ) -> Result<Self::Tensor<T>, crate::CausalTensorError>;
}