candle_core/
backend.rs

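//! Traits to define backend behavior.
//!
//! [`BackendStorage`] lists the operations a backend's storage type must provide, and
//! [`BackendDevice`] lists device-level operations such as allocating storage, creating
//! storage from CPU data, random initialization, and synchronization.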
3use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
4use crate::{CpuStorage, DType, Layout, Result, Shape};
5
6pub trait BackendStorage: Sized {
7    type Device: BackendDevice;
8
9    fn try_clone(&self, _: &Layout) -> Result<Self>;
10
11    fn dtype(&self) -> DType;
12
13    fn device(&self) -> &Self::Device;
14
15    // Maybe this should return a Cow instead so that no copy is done on the cpu case.
16    fn to_cpu_storage(&self) -> Result<CpuStorage>;
17
18    fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
19
20    fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
21
22    fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
23
24    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
25
26    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
27
28    fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
29
30    fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
31
32    fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
33
34    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
35
36    fn conv1d(
37        &self,
38        _l: &Layout,
39        _kernel: &Self,
40        _kernel_l: &Layout,
41        _params: &crate::conv::ParamsConv1D,
42    ) -> Result<Self>;
43
44    fn conv_transpose1d(
45        &self,
46        _l: &Layout,
47        _kernel: &Self,
48        _kernel_l: &Layout,
49        _params: &crate::conv::ParamsConvTranspose1D,
50    ) -> Result<Self>;
51
52    fn conv2d(
53        &self,
54        _l: &Layout,
55        _kernel: &Self,
56        _kernel_l: &Layout,
57        _params: &crate::conv::ParamsConv2D,
58    ) -> Result<Self>;
59
60    fn conv_transpose2d(
61        &self,
62        _l: &Layout,
63        _kernel: &Self,
64        _kernel_l: &Layout,
65        _params: &crate::conv::ParamsConvTranspose2D,
66    ) -> Result<Self>;
67
68    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
69    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
70    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
71    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
72
73    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
74
75    fn scatter_set(
76        &mut self,
77        _: &Layout,
78        _: &Self,
79        _: &Layout,
80        _: &Self,
81        _: &Layout,
82        _: usize,
83    ) -> Result<()>;
84
85    fn scatter_add_set(
86        &mut self,
87        _: &Layout,
88        _: &Self,
89        _: &Layout,
90        _: &Self,
91        _: &Layout,
92        _: usize,
93    ) -> Result<()>;
94
95    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
96    fn index_add(
97        &self,
98        _: &Layout,
99        _: &Self,
100        _: &Layout,
101        _: &Self,
102        _: &Layout,
103        _: usize,
104    ) -> Result<Self>;
105
106    fn matmul(
107        &self,
108        _: &Self,
109        _: (usize, usize, usize, usize),
110        _: &Layout,
111        _: &Layout,
112    ) -> Result<Self>;
113
114    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
115
116    #[allow(clippy::too_many_arguments)]
117    // Similar to cudaMemcpy2D, though values are in elements and not in bytes.
118    fn copy2d(
119        &self,
120        _: &mut Self,
121        _d1: usize,
122        _d2: usize,
123        _src_stride1: usize,
124        _dst_stride1: usize,
125        _src_offset: usize,
126        _dst_offset: usize,
127    ) -> Result<()>;
128
129    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
130}
131
132pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
133    type Storage: BackendStorage;
134
135    // TODO: Make the usize generic and part of a generic DeviceLocation.
136    fn new(_: usize) -> Result<Self>;
137
138    fn location(&self) -> crate::DeviceLocation;
139
140    fn same_device(&self, _: &Self) -> bool;
141
142    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
143
144    /// # Safety
145    /// This function is unsafe as it doesn't initialize the underlying data store.
146    /// The caller should ensure that the data is properly initialized as early as possible
147    /// after this call.
148    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
149
150    fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
151
152    fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
153
154    fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
155
156    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
157
158    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
159
160    fn set_seed(&self, _: u64) -> Result<()>;
161
162    /// Synchronize should block until all the operations on the device are completed.
163    fn synchronize(&self) -> Result<()>;
164}