1use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
4use crate::{CpuStorage, DType, Layout, Result, Shape};
5
6pub trait BackendStorage: Sized {
7 type Device: BackendDevice;
8
9 fn try_clone(&self, _: &Layout) -> Result<Self>;
10
11 fn dtype(&self) -> DType;
12
13 fn device(&self) -> &Self::Device;
14
15 fn to_cpu_storage(&self) -> Result<CpuStorage>;
17
18 fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
19
20 fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
21
22 fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
23
24 fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
25
26 fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
27
28 fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
29
30 fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
31
32 fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
33
34 fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
35
36 fn conv1d(
37 &self,
38 _l: &Layout,
39 _kernel: &Self,
40 _kernel_l: &Layout,
41 _params: &crate::conv::ParamsConv1D,
42 ) -> Result<Self>;
43
44 fn conv_transpose1d(
45 &self,
46 _l: &Layout,
47 _kernel: &Self,
48 _kernel_l: &Layout,
49 _params: &crate::conv::ParamsConvTranspose1D,
50 ) -> Result<Self>;
51
52 fn conv2d(
53 &self,
54 _l: &Layout,
55 _kernel: &Self,
56 _kernel_l: &Layout,
57 _params: &crate::conv::ParamsConv2D,
58 ) -> Result<Self>;
59
60 fn conv_transpose2d(
61 &self,
62 _l: &Layout,
63 _kernel: &Self,
64 _kernel_l: &Layout,
65 _params: &crate::conv::ParamsConvTranspose2D,
66 ) -> Result<Self>;
67
68 fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
69 fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
70 fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
71 fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
72
73 fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
74
75 fn scatter_set(
76 &mut self,
77 _: &Layout,
78 _: &Self,
79 _: &Layout,
80 _: &Self,
81 _: &Layout,
82 _: usize,
83 ) -> Result<()>;
84
85 fn scatter_add_set(
86 &mut self,
87 _: &Layout,
88 _: &Self,
89 _: &Layout,
90 _: &Self,
91 _: &Layout,
92 _: usize,
93 ) -> Result<()>;
94
95 fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
96 fn index_add(
97 &self,
98 _: &Layout,
99 _: &Self,
100 _: &Layout,
101 _: &Self,
102 _: &Layout,
103 _: usize,
104 ) -> Result<Self>;
105
106 fn matmul(
107 &self,
108 _: &Self,
109 _: (usize, usize, usize, usize),
110 _: &Layout,
111 _: &Layout,
112 ) -> Result<Self>;
113
114 fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
115
116 #[allow(clippy::too_many_arguments)]
117 fn copy2d(
119 &self,
120 _: &mut Self,
121 _d1: usize,
122 _d2: usize,
123 _src_stride1: usize,
124 _dst_stride1: usize,
125 _src_offset: usize,
126 _dst_offset: usize,
127 ) -> Result<()>;
128
129 fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
130}
131
132pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
133 type Storage: BackendStorage;
134
135 fn new(_: usize) -> Result<Self>;
137
138 fn location(&self) -> crate::DeviceLocation;
139
140 fn same_device(&self, _: &Self) -> bool;
141
142 fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
143
144 unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
149
150 fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
151
152 fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
153
154 fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
155
156 fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
157
158 fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
159
160 fn set_seed(&self, _: u64) -> Result<()>;
161
162 fn synchronize(&self) -> Result<()>;
164}