// candle_core/backend.rs

//! Traits to Define Backend Behavior
//!
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};

6pub trait BackendStorage: Sized {
7    type Device: BackendDevice;
8
9    fn try_clone(&self, _: &Layout) -> Result<Self>;
10
11    fn dtype(&self) -> DType;
12
13    fn device(&self) -> &Self::Device;
14
15    // Maybe this should return a Cow instead so that no copy is done on the cpu case.
16    fn to_cpu_storage(&self) -> Result<CpuStorage>;
17
18    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
19
20    fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
21
22    fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
23
24    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
25
26    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
27
28    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
29
30    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
31
32    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
33
34    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
35
36    fn conv1d(
37        &self,
38        _l: &Layout,
39        _kernel: &Self,
40        _kernel_l: &Layout,
41        _params: &crate::conv::ParamsConv1D,
42    ) -> Result<Self>;
43
44    fn conv_transpose1d(
45        &self,
46        _l: &Layout,
47        _kernel: &Self,
48        _kernel_l: &Layout,
49        _params: &crate::conv::ParamsConvTranspose1D,
50    ) -> Result<Self>;
51
52    fn conv2d(
53        &self,
54        _l: &Layout,
55        _kernel: &Self,
56        _kernel_l: &Layout,
57        _params: &crate::conv::ParamsConv2D,
58    ) -> Result<Self>;
59
60    fn conv_transpose2d(
61        &self,
62        _l: &Layout,
63        _kernel: &Self,
64        _kernel_l: &Layout,
65        _params: &crate::conv::ParamsConvTranspose2D,
66    ) -> Result<Self>;
67
68    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
69    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
70    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
71    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
72    fn upsample_bilinear2d(
73        &self,
74        _: &Layout,
75        _: usize,
76        _: usize,
77        _: bool,
78        _: Option<f64>,
79        _: Option<f64>,
80    ) -> Result<Self>;
81
82    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
83
84    fn scatter_set(
85        &mut self,
86        _: &Layout,
87        _: &Self,
88        _: &Layout,
89        _: &Self,
90        _: &Layout,
91        _: usize,
92    ) -> Result<()>;
93
94    fn scatter_add_set(
95        &mut self,
96        _: &Layout,
97        _: &Self,
98        _: &Layout,
99        _: &Self,
100        _: &Layout,
101        _: usize,
102    ) -> Result<()>;
103
104    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
105    fn index_add(
106        &self,
107        _: &Layout,
108        _: &Self,
109        _: &Layout,
110        _: &Self,
111        _: &Layout,
112        _: usize,
113    ) -> Result<Self>;
114
115    fn matmul(
116        &self,
117        _: &Self,
118        _: (usize, usize, usize, usize),
119        _: &Layout,
120        _: &Layout,
121    ) -> Result<Self>;
122
123    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
124
125    #[allow(clippy::too_many_arguments)]
126    // Similar to cudaMemcpy2D, though values are in elements and not in bytes.
127    fn copy2d(
128        &self,
129        _: &mut Self,
130        _d1: usize,
131        _d2: usize,
132        _src_stride1: usize,
133        _dst_stride1: usize,
134        _src_offset: usize,
135        _dst_offset: usize,
136    ) -> Result<()>;
137
138    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
139}
140
141pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
142    type Storage: BackendStorage;
143
144    // TODO: Make the usize generic and part of a generic DeviceLocation.
145    fn new(_: usize) -> Result<Self>;
146
147    fn location(&self) -> crate::DeviceLocation;
148
149    fn same_device(&self, _: &Self) -> bool;
150
151    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
152
153    /// # Safety
154    /// This function is unsafe as it doesn't initialize the underlying data store.
155    /// The caller should ensure that the data is properly initialized as early as possible
156    /// after this call.
157    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
158
159    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
160
161    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
162
163    fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
164
165    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
166
167    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
168
169    fn set_seed(&self, _: u64) -> Result<()>;
170    fn get_current_seed(&self) -> Result<u64>;
171
172    /// Synchronize should block until all the operations on the device are completed.
173    fn synchronize(&self) -> Result<()>;
174}