// The CPU backend is always compiled in and serves as the fallback default.
pub mod cpu;
pub use cpu::CPUBackend;

// The GPU backend only exists when the crate is built with the `gpu` feature.
#[cfg(feature = "gpu")]
pub mod gpu;
#[cfg(feature = "gpu")]
pub use gpu::GPUBackend;

// `DefaultBackend` is resolved at compile time by the `gpu` feature flag.
// The two `cfg` arms are mutually exclusive, so exactly one alias is defined.
#[cfg(not(feature = "gpu"))]
pub type DefaultBackend = CPUBackend;

#[cfg(feature = "gpu")]
pub type DefaultBackend = GPUBackend;
14
15use ndarray::{Dimension, RemoveAxis};
16
/// Compute-backend abstraction: every tensor operation the rest of the crate
/// needs, expressed over a backend-chosen storage type.
///
/// `CPUBackend` implements this unconditionally; `GPUBackend` does behind the
/// `gpu` feature. All methods are associated functions (no `self`), so a
/// backend is selected purely by type — see `DefaultBackend`.
pub trait Backend: Clone + 'static {
    /// Backend-specific `f32` tensor storage, parameterized by an `ndarray`
    /// dimension type.
    type Tensor<D: Dimension>: Clone;

    /// Zero-filled tensor of the given shape.
    fn zeros<D: Dimension>(shape: D) -> Self::Tensor<D>;
    /// Tensor with elements drawn uniformly between `low` and `high`.
    /// NOTE(review): interval openness/closedness is backend-defined — confirm
    /// against each implementation.
    fn random_uniform<D: Dimension>(shape: D, low: f32, high: f32) -> Self::Tensor<D>;
    /// Tensor with elements drawn from a normal distribution with the given
    /// `mean` and standard deviation `std`.
    fn random_normal<D: Dimension>(shape: D, mean: f32, std: f32) -> Self::Tensor<D>;
    /// Wrap an owned `ndarray` array into backend storage (non-CPU backends
    /// presumably copy/upload here — confirm per impl).
    fn from_array<D: Dimension>(array: ndarray::Array<f32, D>) -> Self::Tensor<D>;
    /// Materialize backend storage back into an `ndarray` array.
    fn to_array<D: Dimension>(tensor: &Self::Tensor<D>) -> ndarray::Array<f32, D>;

    /// Elementwise unary operation selected by a numeric op code; the known
    /// codes are the constants in the `unary_ops` module (TANH, SIGMOID, ...).
    fn unary<D: Dimension>(tensor: &Self::Tensor<D>, op: u32) -> Self::Tensor<D>;

    // Elementwise binary arithmetic between tensors of the same dimension type.
    fn add<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D>;
    fn sub<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D>;
    fn mul<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D>;
    fn div<D: Dimension>(a: &Self::Tensor<D>, b: &Self::Tensor<D>) -> Self::Tensor<D>;
    /// Multiply every element by `scalar`.
    fn scale<D: Dimension>(tensor: &Self::Tensor<D>, scalar: f32) -> Self::Tensor<D>;
    /// Elementwise `scalar - t` — note the argument order: scalar on the left.
    fn scalar_sub<D: Dimension>(scalar: f32, tensor: &Self::Tensor<D>) -> Self::Tensor<D>;
    /// Elementwise `max(t, s)` (clips values from below at `s`).
    fn scalar_max<D: Dimension>(tensor: &Self::Tensor<D>, s: f32) -> Self::Tensor<D>;
    /// Elementwise `min(t, s)` (clips values from above at `s`).
    fn scalar_min<D: Dimension>(tensor: &Self::Tensor<D>, s: f32) -> Self::Tensor<D>;
    /// Clamp every element into `[low, high]`. Default impl composes the two
    /// scalar-bound ops, so backends get it for free.
    fn clamp<D: Dimension>(tensor: &Self::Tensor<D>, low: f32, high: f32) -> Self::Tensor<D> {
        Self::scalar_min(&Self::scalar_max(tensor, low), high)
    }

    /// Mean over all elements. NOTE(review): presumably `None` for an empty
    /// tensor — confirm against the backend implementations.
    fn mean<D: Dimension>(tensor: &Self::Tensor<D>) -> Option<f32>;
    /// Sum along `axis`, reducing the dimensionality by one (`D::Smaller`).
    fn sum_axis<D: Dimension + RemoveAxis>(
        tensor: &Self::Tensor<D>,
        axis: usize,
    ) -> Self::Tensor<D::Smaller>;

    /// Matrix multiplication. The result reuses `a`'s dimension type `D1`;
    /// which (D1, D2) rank combinations are supported is backend-defined —
    /// confirm per implementation.
    fn matmul<D1: Dimension, D2: Dimension>(
        a: &Self::Tensor<D1>,
        b: &Self::Tensor<D2>,
    ) -> Self::Tensor<D1>;
    /// Swap two axes of the tensor; rank is unchanged.
    fn transpose<D: Dimension>(
        tensor: &Self::Tensor<D>,
        axis1: usize,
        axis2: usize,
    ) -> Self::Tensor<D>;
    /// Add `b` to `a` with broadcasting; the result keeps `a`'s dimension
    /// type `D1`.
    fn broadcast_add<D1: Dimension, D2: Dimension>(
        a: &Self::Tensor<D1>,
        b: &Self::Tensor<D2>,
    ) -> Self::Tensor<D1>;

    /// Softmax. NOTE(review): the reduced axis is not a parameter, so it is
    /// fixed by the backend (likely the last axis) — confirm per impl.
    fn softmax<D: Dimension>(tensor: &Self::Tensor<D>) -> Self::Tensor<D>;
    /// Vector–Jacobian product for softmax: given `z` (presumably the softmax
    /// output — confirm) and `grad_output`, produce the input gradient.
    fn softmax_vjp<D: Dimension>(
        z: &Self::Tensor<D>,
        grad_output: &Self::Tensor<D>,
    ) -> Self::Tensor<D>;

    /// Overwrite `dst` in place with `src` (consumes `src`).
    fn assign<D: Dimension>(dst: &mut Self::Tensor<D>, src: Self::Tensor<D>);
    /// Shape of the tensor as a plain vector of axis lengths.
    fn shape<D: Dimension>(tensor: &Self::Tensor<D>) -> Vec<usize>;
    /// Length of a single axis.
    fn len_of<D: Dimension>(tensor: &Self::Tensor<D>, axis: usize) -> usize;
    /// Gather the given `indices` along `axis`; rank is unchanged.
    fn select<D: Dimension + RemoveAxis>(
        tensor: &Self::Tensor<D>,
        axis: usize,
        indices: &[usize],
    ) -> Self::Tensor<D>;

    /// Flush any pending work. No-op by default; backends with deferred or
    /// asynchronous execution (presumably the GPU backend) can override.
    fn flush() {}
}
77
/// Numeric op codes accepted by `Backend::unary`.
///
/// Kept as plain `u32` constants (rather than an enum) so codes can cross the
/// backend boundary unchanged. The values are part of the backend contract —
/// do not renumber. Declarations are grouped so each activation sits next to
/// its derivative; declaration order has no effect on the values.
pub mod unary_ops {
    // Activations and their derivatives.
    pub const TANH: u32 = 0;
    pub const TANH_DERIV: u32 = 3;
    pub const SIGMOID: u32 = 1;
    pub const SIGMOID_DERIV: u32 = 4;
    pub const RELU: u32 = 2;
    pub const RELU_DERIV: u32 = 5;

    // Plain elementwise math primitives.
    pub const EXP: u32 = 6;
    pub const LN: u32 = 7;
    pub const ABS: u32 = 8;
    pub const NEG: u32 = 9;
    pub const SQRT: u32 = 10;
}