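//! `meuron/backend/mod.rs` — the tensor compute-backend abstraction (`Backend` trait),
//! the available backend implementations, and the compile-time default backend.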

1pub mod cpu;
2pub use cpu::CPUBackend;
3
4#[cfg(feature = "gpu")]
5pub mod gpu;
6#[cfg(feature = "gpu")]
7pub use gpu::GPUBackend;
8
9#[cfg(not(feature = "gpu"))]
10pub type DefaultBackend = CPUBackend;
11
12#[cfg(feature = "gpu")]
13pub type DefaultBackend = GPUBackend;
14
15use ndarray::{Dimension, RemoveAxis};
16
17pub trait Backend: Clone + 'static {
18    type Tensor<D: Dimension>: Clone;
19
20    fn zeros<D: Dimension>(shape: D) -> Self::Tensor<D>;
21    fn random_uniform<D: Dimension>(shape: D, low: f32, high: f32) -> Self::Tensor<D>;
22    fn random_normal<D: Dimension>(shape: D, mean: f32, std: f32) -> Self::Tensor<D>;
23    fn from_array<D: Dimension>(array: ndarray::Array<f32, D>) -> Self::Tensor<D>;
24    fn to_array<D: Dimension>(tensor: &Self::Tensor<D>) -> ndarray::Array<f32, D>;
25
26    fn unary<D: Dimension>(tensor: &Self::Tensor<D>, op: u32) -> Self::Tensor<D>;
27
28    fn add<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D>;
29    fn sub<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D>;
30    fn mul<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D>;
31    fn div<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D>;
32    fn scale<D: Dimension>(tensor: &Self::Tensor<D>, scalar: f32) -> Self::Tensor<D>;
33    fn scalar_sub<D: Dimension>(scalar: f32, tensor: &Self::Tensor<D>) -> Self::Tensor<D>;
34    fn scalar_max<D: Dimension>(tensor: &Self::Tensor<D>, s: f32) -> Self::Tensor<D>;
35    fn scalar_min<D: Dimension>(tensor: &Self::Tensor<D>, s: f32) -> Self::Tensor<D>;
36    fn clamp<D: Dimension>(tensor: &Self::Tensor<D>, low: f32, high: f32) -> Self::Tensor<D> {
37        Self::scalar_min(&Self::scalar_max(tensor, low), high)
38    }
39
40    fn mean<D: Dimension>(tensor: &Self::Tensor<D>) -> Option<f32>;
41    fn sum_axis<D: Dimension + RemoveAxis>(
42        tensor: &Self::Tensor<D>,
43        axis: usize,
44    ) -> Self::Tensor<D::Smaller>;
45
46    fn matmul<D1: Dimension, D2: Dimension>(
47        a: &Self::Tensor<D1>,
48        b: &Self::Tensor<D2>,
49    ) -> Self::Tensor<D1>;
50    fn transpose<D: Dimension>(
51        tensor: &Self::Tensor<D>,
52        axis1: usize,
53        axis2: usize,
54    ) -> Self::Tensor<D>;
55    fn broadcast_add<D1: Dimension, D2: Dimension>(
56        a: &Self::Tensor<D1>,
57        b: &Self::Tensor<D2>,
58    ) -> Self::Tensor<D1>;
59
60    fn softmax<D: Dimension>(tensor: &Self::Tensor<D>) -> Self::Tensor<D>;
61    fn softmax_vjp<D: Dimension>(
62        z: &Self::Tensor<D>,
63        grad_output: &Self::Tensor<D>,
64    ) -> Self::Tensor<D>;
65
66    fn assign<D: Dimension>(dst: &mut Self::Tensor<D>, src: Self::Tensor<D>);
67    fn shape<D: Dimension>(tensor: &Self::Tensor<D>) -> Vec<usize>;
68    fn len_of<D: Dimension>(tensor: &Self::Tensor<D>, axis: usize) -> usize;
69    fn select<D: Dimension + RemoveAxis>(
70        tensor: &Self::Tensor<D>,
71        axis: usize,
72        indices: &[usize],
73    ) -> Self::Tensor<D>;
74
75    fn flush() {}
76}
77
/// Operation codes accepted by `Backend::unary`.
///
/// NOTE(review): these numeric values are presumably matched by each backend's `unary`
/// implementation (and possibly by GPU kernel code), so renumbering them is likely a
/// breaking change. Confirm in the backends before editing.
pub mod unary_ops {
    // Activation functions.

    /// Hyperbolic tangent.
    pub const TANH: u32 = 0;
    /// Logistic sigmoid.
    pub const SIGMOID: u32 = 1;
    /// Rectified linear unit.
    pub const RELU: u32 = 2;

    // Derivatives of the activations above, used for back-propagation.
    // NOTE(review): whether each derivative is evaluated on the pre-activation input or on
    // the activation's output is decided by the backend implementations; confirm there.

    /// Derivative of [`TANH`].
    pub const TANH_DERIV: u32 = 3;
    /// Derivative of [`SIGMOID`].
    pub const SIGMOID_DERIV: u32 = 4;
    /// Derivative of [`RELU`].
    pub const RELU_DERIV: u32 = 5;

    // General element-wise math.

    /// Natural exponential, `e^x`.
    pub const EXP: u32 = 6;
    /// Natural logarithm.
    pub const LN: u32 = 7;
    /// Absolute value.
    pub const ABS: u32 = 8;
    /// Negation, `-x`.
    pub const NEG: u32 = 9;
    /// Square root.
    pub const SQRT: u32 = 10;
}