1use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
4use crate::{CpuStorage, DType, Layout, Result, Shape};
5
6pub trait BackendStorage: Sized {
7 type Device: BackendDevice;
8
9 fn try_clone(&self, _: &Layout) -> Result<Self>;
10
11 fn dtype(&self) -> DType;
12
13 fn device(&self) -> &Self::Device;
14
15 fn to_cpu_storage(&self) -> Result<CpuStorage>;
17
18 fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
19
20 fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
21
22 fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
23
24 fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
25
26 fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
27
28 fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
29
30 fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
31
32 fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
33
34 fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
35
36 fn conv1d(
37 &self,
38 _l: &Layout,
39 _kernel: &Self,
40 _kernel_l: &Layout,
41 _params: &crate::conv::ParamsConv1D,
42 ) -> Result<Self>;
43
44 fn conv_transpose1d(
45 &self,
46 _l: &Layout,
47 _kernel: &Self,
48 _kernel_l: &Layout,
49 _params: &crate::conv::ParamsConvTranspose1D,
50 ) -> Result<Self>;
51
52 fn conv2d(
53 &self,
54 _l: &Layout,
55 _kernel: &Self,
56 _kernel_l: &Layout,
57 _params: &crate::conv::ParamsConv2D,
58 ) -> Result<Self>;
59
60 fn conv_transpose2d(
61 &self,
62 _l: &Layout,
63 _kernel: &Self,
64 _kernel_l: &Layout,
65 _params: &crate::conv::ParamsConvTranspose2D,
66 ) -> Result<Self>;
67
68 fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
69 fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
70 fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
71 fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
72 fn upsample_bilinear2d(
73 &self,
74 _: &Layout,
75 _: usize,
76 _: usize,
77 _: bool,
78 _: Option<f64>,
79 _: Option<f64>,
80 ) -> Result<Self>;
81
82 fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
83
84 fn scatter_set(
85 &mut self,
86 _: &Layout,
87 _: &Self,
88 _: &Layout,
89 _: &Self,
90 _: &Layout,
91 _: usize,
92 ) -> Result<()>;
93
94 fn scatter_add_set(
95 &mut self,
96 _: &Layout,
97 _: &Self,
98 _: &Layout,
99 _: &Self,
100 _: &Layout,
101 _: usize,
102 ) -> Result<()>;
103
104 fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
105 fn index_add(
106 &self,
107 _: &Layout,
108 _: &Self,
109 _: &Layout,
110 _: &Self,
111 _: &Layout,
112 _: usize,
113 ) -> Result<Self>;
114
115 fn matmul(
116 &self,
117 _: &Self,
118 _: (usize, usize, usize, usize),
119 _: &Layout,
120 _: &Layout,
121 ) -> Result<Self>;
122
123 fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
124
125 #[allow(clippy::too_many_arguments)]
126 fn copy2d(
128 &self,
129 _: &mut Self,
130 _d1: usize,
131 _d2: usize,
132 _src_stride1: usize,
133 _dst_stride1: usize,
134 _src_offset: usize,
135 _dst_offset: usize,
136 ) -> Result<()>;
137
138 fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
139}
140
141pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
142 type Storage: BackendStorage;
143
144 fn new(_: usize) -> Result<Self>;
146
147 fn location(&self) -> crate::DeviceLocation;
148
149 fn same_device(&self, _: &Self) -> bool;
150
151 fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
152
153 unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
158
159 fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
160
161 fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
162
163 fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
164
165 fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
166
167 fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
168
169 fn set_seed(&self, _: u64) -> Result<()>;
170 fn get_current_seed(&self) -> Result<u64>;
171
172 fn synchronize(&self) -> Result<()>;
174}