use crate::Result;
use crate::{BinaryOp, UnaryOp};
pub trait Backend: Sized + Clone + 'static + Sync + Send + std::fmt::Debug {
type Storage<T: crate::WithDType>: Sized + Sync + Send + 'static;
fn name(&self) -> String;
fn synchronize(&self) -> Result<()>;
fn storage_len<T: crate::WithDType>(storage: &Self::Storage<T>) -> usize;
fn storage_is_empty<T: crate::WithDType>(storage: &Self::Storage<T>) -> bool {
Self::storage_len::<T>(storage) == 0
}
unsafe fn alloc_uninit<T: crate::WithDType>(len: usize, dev: &Self)
-> Result<Self::Storage<T>>;
fn from_vec<T: crate::WithDType>(v: Vec<T>, dev: &Self) -> Result<Self::Storage<T>>;
fn cst<T: crate::WithDType>(v: T, len: usize, dev: &Self) -> Result<Self::Storage<T>> {
let mut res = unsafe { Self::alloc_uninit(len, dev)? };
Self::fill(&mut res, v, len)?;
Ok(res)
}
fn fill<T: crate::WithDType>(dst: &mut Self::Storage<T>, elem: T, len: usize) -> Result<()>;
fn rand_uniform(dst: &mut Self::Storage<f32>, len: usize, lo: f32, up: f32) -> Result<()>;
fn randn(dst: &mut Self::Storage<f32>, len: usize, mean: f32, std: f32) -> Result<()>;
fn copy<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
len: usize,
) -> Result<()>;
fn to_dtype<T: crate::WithDType, U: crate::WithDType>(
dst: &mut Self::Storage<U>,
src: &Self::Storage<T>,
len: usize,
) -> Result<()>;
fn data<T: crate::WithDType>(
src: &Self::Storage<T>,
len: usize,
) -> Result<std::borrow::Cow<'_, [T]>>;
fn inplace_unary<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
len: usize,
op: UnaryOp,
) -> Result<()>;
fn bin_assign<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
s: &Self::Storage<T>,
len: usize,
op: BinaryOp,
) -> Result<()>;
fn unary<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
len: usize,
op: UnaryOp,
) -> Result<()>;
fn binary<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
lhs: &Self::Storage<T>,
rhs: &Self::Storage<T>,
len: usize,
op: BinaryOp,
) -> Result<()>;
fn scale_add<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
scale: T,
add: T,
len: usize,
) -> Result<()>;
fn transpose<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
s: &Self::Storage<T>,
dim1: usize,
dim2: usize,
dims: &[usize],
) -> Result<()>;
#[allow(clippy::too_many_arguments)]
fn copy2d<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
d1: usize,
d2: usize,
dst_s: usize,
src_s: usize,
dst_o: usize,
src_o: usize,
) -> Result<()>;
#[allow(clippy::too_many_arguments)]
fn rope<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
cos: &Self::Storage<T>,
sin: &Self::Storage<T>,
b: usize,
h: usize,
t: usize,
d: usize,
pos: usize,
unbatched_rope: bool,
) -> Result<()>;
#[allow(clippy::too_many_arguments)]
fn rope_i<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
cos: &Self::Storage<T>,
sin: &Self::Storage<T>,
b: usize,
h: usize,
t: usize,
d: usize,
pos: usize,
unbatched_rope: bool,
) -> Result<()>;
#[allow(clippy::too_many_arguments)]
fn gemm<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
lhs: (&Self::Storage<T>, usize),
rhs: (&Self::Storage<T>, usize),
m: usize,
n: usize,
k: usize,
lhs_b: usize,
lhs_b_stride: usize,
rhs_b_stride: usize,
dst_strides: (usize, usize),
lhs_strides: (usize, usize),
rhs_strides: (usize, usize),
) -> Result<()>;
fn index_select<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
ids: &Self::Storage<i64>,
num_ids: usize,
dim: usize,
dims: &[usize],
) -> Result<()>;
fn apply_causality_mask<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
bh: usize,
t1: usize,
t2: usize,
offset: usize,
) -> Result<()>;
fn softmax<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
dim_m1: usize,
d: usize,
) -> Result<()>;
fn rms_norm<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
alpha: &Self::Storage<T>,
dim_m1: usize,
d: usize,
eps: f32,
) -> Result<()>;
#[allow(clippy::too_many_arguments)]
fn layer_norm<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
weight: &Self::Storage<T>,
bias: &Self::Storage<T>,
dim_m1: usize,
d: usize,
eps: f32,
remove_mean: bool,
) -> Result<()>;
fn reduce_max<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
dim_size: usize,
outer_size: usize,
inner_size: usize,
) -> Result<()>;
fn reduce_min<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
dim_size: usize,
outer_size: usize,
inner_size: usize,
) -> Result<()>;
fn reduce_argmin<T: crate::WithDTypeF>(
dst: &mut Self::Storage<i64>,
src: &Self::Storage<T>,
dim_size: usize,
outer_size: usize,
inner_size: usize,
) -> Result<()>;
fn reduce_argmax<T: crate::WithDTypeF>(
dst: &mut Self::Storage<i64>,
src: &Self::Storage<T>,
dim_size: usize,
outer_size: usize,
inner_size: usize,
) -> Result<()>;
fn reduce_sum<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
dim_size: usize,
outer_size: usize,
inner_size: usize,
) -> Result<()>;
fn copy_strided<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
src_offset: usize,
dims: &[usize],
src_strides: &[usize],
) -> Result<()>;
fn scatter_set<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
ids: &Self::Storage<i64>,
dim: usize,
dst_dims: &[usize],
src_dims: &[usize],
) -> Result<()>;
fn broadcast_binary<T: crate::WithDType>(
dst: &mut Self::Storage<T>,
lhs: &Self::Storage<T>,
rhs: &Self::Storage<T>,
dst_shape: &[usize],
lhs_strides: &[usize],
rhs_strides: &[usize],
op: BinaryOp,
) -> Result<()>;
#[allow(clippy::too_many_arguments)]
fn conv1d<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
kernel: &Self::Storage<T>,
batch: usize,
in_channels: usize,
out_channels: usize,
length: usize,
out_length: usize,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
groups: usize,
) -> Result<()>;
#[allow(clippy::too_many_arguments)]
fn conv_transpose1d<T: crate::WithDTypeF>(
dst: &mut Self::Storage<T>,
src: &Self::Storage<T>,
kernel: &Self::Storage<T>,
batch: usize,
in_channels: usize,
out_channels: usize,
length: usize,
out_length: usize,
kernel_size: usize,
stride: usize,
padding: usize,
output_padding: usize,
groups: usize,
) -> Result<()>;
}