gloss_burn_multibackend/ops/
int_tensor.rs1#![allow(unreachable_patterns)]
2
3use std::ops::Range;
4
5use burn::tensor::{ops::IntTensorOps, Distribution, Shape, TensorData};
6
7#[cfg(feature = "burn-candle")]
8use crate::backend::CandleBackend;
9#[cfg(feature = "burn-ndarray")]
10use crate::backend::NdArrayBackend;
11#[cfg(feature = "burn-wgpu")]
12use crate::backend::WgpuBackend;
13use crate::{
14 backend::{MultiBackend, MultiDevice},
15 tensor::{MultiBoolTensor, MultiFloatTensor, MultiIntTensor},
16};
17
18#[allow(unused_variables)]
19impl IntTensorOps<Self> for MultiBackend {
20 fn int_from_data(data: TensorData, device: &MultiDevice) -> MultiIntTensor {
21 let data = match device {
22 #[cfg(feature = "burn-candle")]
23 MultiDevice::Candle(dev) => data.convert_dtype(burn::tensor::DType::I64),
24 #[cfg(feature = "burn-ndarray")]
25 MultiDevice::NdArray(d) => data.convert_dtype(burn::tensor::DType::I32),
26 #[cfg(feature = "burn-wgpu")]
27 MultiDevice::Wgpu(d) => data.convert_dtype(burn::tensor::DType::I32),
28 };
29 ops_rest_device!(int(data ; device) => int_from_data)
30 }
31 fn int_repeat_dim(tensor: MultiIntTensor, dim: usize, times: usize) -> MultiIntTensor {
32 ops_tensor_rest!(int(tensor, dim, times) => int_repeat_dim)
33 }
34 async fn int_into_data(tensor: MultiIntTensor) -> TensorData {
35 match tensor {
36 #[cfg(feature = "burn-candle")]
37 MultiIntTensor::Candle(t) => <CandleBackend as IntTensorOps<CandleBackend>>::int_into_data(t).await,
38 #[cfg(feature = "burn-ndarray")]
39 MultiIntTensor::NdArray(t) => <NdArrayBackend as IntTensorOps<NdArrayBackend>>::int_into_data(t).await,
40 #[cfg(feature = "burn-wgpu")]
41 MultiIntTensor::Wgpu(t) => <WgpuBackend as IntTensorOps<WgpuBackend>>::int_into_data(t).await,
42 }
43 }
44 fn int_to_device(tensor: MultiIntTensor, device: &MultiDevice) -> MultiIntTensor {
45 unimplemented!()
46 }
47 fn int_reshape(tensor: MultiIntTensor, shape: Shape) -> MultiIntTensor {
48 ops_tensor_rest!(int(tensor, shape) => int_reshape)
49 }
50 fn int_device(tensor: &MultiIntTensor) -> MultiDevice {
51 match tensor {
52 #[cfg(feature = "burn-candle")]
53 MultiIntTensor::Candle(t) => MultiDevice::Candle(<CandleBackend as IntTensorOps<CandleBackend>>::int_device(t)),
54 #[cfg(feature = "burn-ndarray")]
55 MultiIntTensor::NdArray(t) => MultiDevice::NdArray(<NdArrayBackend as IntTensorOps<NdArrayBackend>>::int_device(t)),
56 #[cfg(feature = "burn-wgpu")]
57 MultiIntTensor::Wgpu(t) => MultiDevice::Wgpu(<WgpuBackend as IntTensorOps<WgpuBackend>>::int_device(t)),
58 }
59 }
60 fn int_empty(shape: Shape, device: &MultiDevice) -> MultiIntTensor {
61 unimplemented!()
62 }
63 fn int_slice(tensor: MultiIntTensor, ranges: &[Range<usize>]) -> MultiIntTensor {
64 ops_tensor_rest!(int(tensor, ranges) => int_slice)
65 }
66 fn int_slice_assign(tensor: MultiIntTensor, ranges: &[Range<usize>], value: MultiIntTensor) -> MultiIntTensor {
67 ops_tensor_other_values!(int(tensor, ranges, value) => int_slice_assign)
68 }
69 fn int_cat(tensors: Vec<MultiIntTensor>, dim: usize) -> MultiIntTensor {
70 unimplemented!()
71 }
72 fn int_equal(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
76 unimplemented!()
77 }
78 fn int_equal_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
79 unimplemented!()
80 }
81 fn int_greater(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
82 unimplemented!()
83 }
84 fn int_greater_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
85 unimplemented!()
86 }
87 fn int_greater_equal(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
88 unimplemented!()
89 }
90 fn int_greater_equal_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
91 unimplemented!()
92 }
93 fn int_lower(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
94 unimplemented!()
95 }
96 fn int_lower_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
97 unimplemented!()
98 }
99 fn int_lower_equal(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
100 unimplemented!()
101 }
102 fn int_lower_equal_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
103 unimplemented!()
104 }
105 fn int_add(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
106 unimplemented!()
107 }
108 fn int_add_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
109 unimplemented!()
110 }
111 fn int_sub(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
112 unimplemented!()
113 }
114 fn int_sub_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
115 ops_tensor_scalar!(int(lhs, rhs) => int_sub_scalar)
116 }
118 fn int_mul(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
119 unimplemented!()
120 }
121 fn int_mul_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
122 unimplemented!()
123 }
124 fn int_div(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
125 unimplemented!()
126 }
127 fn int_div_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
128 unimplemented!()
129 }
130 fn int_remainder(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
131 unimplemented!()
132 }
133 fn int_remainder_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
134 unimplemented!()
135 }
136 fn int_neg(tensor: MultiIntTensor) -> MultiIntTensor {
137 unimplemented!()
138 }
139 fn int_zeros(shape: Shape, device: &MultiDevice) -> MultiIntTensor {
140 unimplemented!()
141 }
142 fn int_ones(shape: Shape, device: &MultiDevice) -> MultiIntTensor {
143 unimplemented!()
144 }
145 fn int_full(shape: Shape, fill_value: i32, device: &MultiDevice) -> MultiIntTensor {
146 unimplemented!()
147 }
148 fn int_sum(tensor: MultiIntTensor) -> MultiIntTensor {
149 unimplemented!()
150 }
151 fn int_sum_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
152 unimplemented!()
153 }
154 fn int_prod(tensor: MultiIntTensor) -> MultiIntTensor {
155 unimplemented!()
156 }
157 fn int_prod_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
158 unimplemented!()
159 }
160 fn int_mean(tensor: MultiIntTensor) -> MultiIntTensor {
161 unimplemented!()
162 }
163 fn int_mean_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
164 unimplemented!()
165 }
166 fn int_gather(dim: usize, tensor: MultiIntTensor, indices: MultiIntTensor) -> MultiIntTensor {
167 unimplemented!()
168 }
169 fn int_scatter(dim: usize, tensor: MultiIntTensor, indices: MultiIntTensor, value: MultiIntTensor) -> MultiIntTensor {
170 unimplemented!()
171 }
172 fn int_select(tensor: MultiIntTensor, dim: usize, indices: MultiIntTensor) -> MultiIntTensor {
173 ops_tensor_dim_indices!(int(tensor, dim, indices) => int_select)
174 }
175 fn int_select_assign(tensor: MultiIntTensor, dim: usize, indices: MultiIntTensor, value: MultiIntTensor) -> MultiIntTensor {
176 unimplemented!()
177 }
178 fn int_mask_where(tensor: MultiIntTensor, mask: MultiBoolTensor, source: MultiIntTensor) -> MultiIntTensor {
179 unimplemented!()
180 }
181 fn int_mask_fill(tensor: MultiIntTensor, mask: MultiBoolTensor, value: i32) -> MultiIntTensor {
182 unimplemented!()
183 }
184 fn int_argmax(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
185 unimplemented!()
186 }
187 fn int_argmin(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
188 unimplemented!()
189 }
190 fn int_max_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
191 unimplemented!()
192 }
193 fn int_max_dim_with_indices(tensor: MultiIntTensor, dim: usize) -> (MultiIntTensor, MultiIntTensor) {
194 unimplemented!()
195 }
196 fn int_min_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
197 unimplemented!()
198 }
199 fn int_min_dim_with_indices(tensor: MultiIntTensor, dim: usize) -> (MultiIntTensor, MultiIntTensor) {
200 unimplemented!()
201 }
202 fn int_clamp_min(tensor: MultiIntTensor, min: i32) -> MultiIntTensor {
203 unimplemented!()
204 }
205 fn int_clamp_max(tensor: MultiIntTensor, max: i32) -> MultiIntTensor {
206 unimplemented!()
207 }
208 fn int_clamp(tensor: MultiIntTensor, min: i32, max: i32) -> MultiIntTensor {
209 unimplemented!()
210 }
211 fn int_abs(tensor: MultiIntTensor) -> MultiIntTensor {
212 unimplemented!()
213 }
214 fn int_into_float(tensor: MultiIntTensor) -> MultiFloatTensor {
215 unimplemented!()
216 }
217 fn int_swap_dims(tensor: MultiIntTensor, dim1: usize, dim2: usize) -> MultiIntTensor {
218 unimplemented!()
219 }
220 fn int_random(shape: Shape, distribution: Distribution, device: &MultiDevice) -> MultiIntTensor {
221 unimplemented!()
222 }
223 fn int_arange(range: Range<i64>, device: &MultiDevice) -> MultiIntTensor {
224 ops_rest_device!(int(range ; device) => int_arange)
225 }
226 fn int_permute(tensor: MultiIntTensor, axes: &[usize]) -> MultiIntTensor {
227 unimplemented!()
228 }
229 fn int_flip(tensor: MultiIntTensor, axes: &[usize]) -> MultiIntTensor {
230 unimplemented!()
231 }
232 fn int_sign(tensor: MultiIntTensor) -> MultiIntTensor {
233 unimplemented!()
234 }
235 fn int_expand(tensor: MultiIntTensor, shape: Shape) -> MultiIntTensor {
236 ops_tensor_rest!(int(tensor, shape) => int_expand)
237 }
238 fn int_sort(tensor: MultiIntTensor, dim: usize, descending: bool) -> MultiIntTensor {
239 unimplemented!()
240 }
241 fn int_argsort(tensor: MultiIntTensor, dim: usize, descending: bool) -> MultiIntTensor {
242 unimplemented!()
243 }
244 fn bitwise_and(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
245 unimplemented!()
246 }
247
248 fn bitwise_or(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
249 unimplemented!()
250 }
251
252 fn bitwise_xor(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
253 unimplemented!()
254 }
255
256 fn bitwise_not(tensor: MultiIntTensor) -> MultiIntTensor {
257 unimplemented!()
258 }
259
260 fn bitwise_and_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
261 unimplemented!()
262 }
263
264 fn bitwise_or_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
265 unimplemented!()
266 }
267
268 fn bitwise_xor_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
269 unimplemented!()
270 }
271
272 fn bitwise_left_shift(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
273 unimplemented!()
274 }
275
276 fn bitwise_right_shift(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
277 unimplemented!()
278 }
279
280 fn bitwise_left_shift_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
281 unimplemented!()
282 }
283
284 fn bitwise_right_shift_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
285 unimplemented!()
286 }
287}