// gloss_burn_multibackend/ops/int_tensor.rs

1#![allow(unreachable_patterns)]
2
3use std::ops::Range;
4
5use burn::tensor::{ops::IntTensorOps, Distribution, Shape, TensorData};
6
7#[cfg(feature = "burn-candle")]
8use crate::backend::CandleBackend;
9#[cfg(feature = "burn-ndarray")]
10use crate::backend::NdArrayBackend;
11#[cfg(feature = "burn-wgpu")]
12use crate::backend::WgpuBackend;
13use crate::{
14    backend::{MultiBackend, MultiDevice},
15    tensor::{MultiBoolTensor, MultiFloatTensor, MultiIntTensor},
16};
17
18#[allow(unused_variables)]
19impl IntTensorOps<Self> for MultiBackend {
20    fn int_from_data(data: TensorData, device: &MultiDevice) -> MultiIntTensor {
21        let data = match device {
22            #[cfg(feature = "burn-candle")]
23            MultiDevice::Candle(dev) => data.convert_dtype(burn::tensor::DType::I64),
24            #[cfg(feature = "burn-ndarray")]
25            MultiDevice::NdArray(d) => data.convert_dtype(burn::tensor::DType::I32),
26            #[cfg(feature = "burn-wgpu")]
27            MultiDevice::Wgpu(d) => data.convert_dtype(burn::tensor::DType::I32),
28        };
29        ops_rest_device!(int(data ; device) => int_from_data)
30    }
31    fn int_repeat_dim(tensor: MultiIntTensor, dim: usize, times: usize) -> MultiIntTensor {
32        ops_tensor_rest!(int(tensor, dim, times) => int_repeat_dim)
33    }
34    async fn int_into_data(tensor: MultiIntTensor) -> TensorData {
35        match tensor {
36            #[cfg(feature = "burn-candle")]
37            MultiIntTensor::Candle(t) => <CandleBackend as IntTensorOps<CandleBackend>>::int_into_data(t).await,
38            #[cfg(feature = "burn-ndarray")]
39            MultiIntTensor::NdArray(t) => <NdArrayBackend as IntTensorOps<NdArrayBackend>>::int_into_data(t).await,
40            #[cfg(feature = "burn-wgpu")]
41            MultiIntTensor::Wgpu(t) => <WgpuBackend as IntTensorOps<WgpuBackend>>::int_into_data(t).await,
42        }
43    }
44    fn int_to_device(tensor: MultiIntTensor, device: &MultiDevice) -> MultiIntTensor {
45        unimplemented!()
46    }
47    fn int_reshape(tensor: MultiIntTensor, shape: Shape) -> MultiIntTensor {
48        ops_tensor_rest!(int(tensor, shape) => int_reshape)
49    }
50    fn int_device(tensor: &MultiIntTensor) -> MultiDevice {
51        match tensor {
52            #[cfg(feature = "burn-candle")]
53            MultiIntTensor::Candle(t) => MultiDevice::Candle(<CandleBackend as IntTensorOps<CandleBackend>>::int_device(t)),
54            #[cfg(feature = "burn-ndarray")]
55            MultiIntTensor::NdArray(t) => MultiDevice::NdArray(<NdArrayBackend as IntTensorOps<NdArrayBackend>>::int_device(t)),
56            #[cfg(feature = "burn-wgpu")]
57            MultiIntTensor::Wgpu(t) => MultiDevice::Wgpu(<WgpuBackend as IntTensorOps<WgpuBackend>>::int_device(t)),
58        }
59    }
60    fn int_empty(shape: Shape, device: &MultiDevice) -> MultiIntTensor {
61        unimplemented!()
62    }
63    fn int_slice(tensor: MultiIntTensor, ranges: &[Range<usize>]) -> MultiIntTensor {
64        ops_tensor_rest!(int(tensor, ranges) => int_slice)
65    }
66    fn int_slice_assign(tensor: MultiIntTensor, ranges: &[Range<usize>], value: MultiIntTensor) -> MultiIntTensor {
67        ops_tensor_other_values!(int(tensor, ranges, value) => int_slice_assign)
68    }
69    fn int_cat(tensors: Vec<MultiIntTensor>, dim: usize) -> MultiIntTensor {
70        unimplemented!()
71    }
72    // fn int_matmul(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
73    //     unimplemented!()
74    // }
75    fn int_equal(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
76        unimplemented!()
77    }
78    fn int_equal_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
79        unimplemented!()
80    }
81    fn int_greater(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
82        unimplemented!()
83    }
84    fn int_greater_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
85        unimplemented!()
86    }
87    fn int_greater_equal(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
88        unimplemented!()
89    }
90    fn int_greater_equal_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
91        unimplemented!()
92    }
93    fn int_lower(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
94        unimplemented!()
95    }
96    fn int_lower_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
97        unimplemented!()
98    }
99    fn int_lower_equal(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiBoolTensor {
100        unimplemented!()
101    }
102    fn int_lower_equal_elem(lhs: MultiIntTensor, rhs: i32) -> MultiBoolTensor {
103        unimplemented!()
104    }
105    fn int_add(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
106        unimplemented!()
107    }
108    fn int_add_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
109        unimplemented!()
110    }
111    fn int_sub(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
112        unimplemented!()
113    }
114    fn int_sub_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
115        ops_tensor_scalar!(int(lhs, rhs) => int_sub_scalar)
116        // unimplemented!()
117    }
118    fn int_mul(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
119        unimplemented!()
120    }
121    fn int_mul_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
122        unimplemented!()
123    }
124    fn int_div(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
125        unimplemented!()
126    }
127    fn int_div_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
128        unimplemented!()
129    }
130    fn int_remainder(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
131        unimplemented!()
132    }
133    fn int_remainder_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
134        unimplemented!()
135    }
136    fn int_neg(tensor: MultiIntTensor) -> MultiIntTensor {
137        unimplemented!()
138    }
139    fn int_zeros(shape: Shape, device: &MultiDevice) -> MultiIntTensor {
140        unimplemented!()
141    }
142    fn int_ones(shape: Shape, device: &MultiDevice) -> MultiIntTensor {
143        unimplemented!()
144    }
145    fn int_full(shape: Shape, fill_value: i32, device: &MultiDevice) -> MultiIntTensor {
146        unimplemented!()
147    }
148    fn int_sum(tensor: MultiIntTensor) -> MultiIntTensor {
149        unimplemented!()
150    }
151    fn int_sum_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
152        unimplemented!()
153    }
154    fn int_prod(tensor: MultiIntTensor) -> MultiIntTensor {
155        unimplemented!()
156    }
157    fn int_prod_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
158        unimplemented!()
159    }
160    fn int_mean(tensor: MultiIntTensor) -> MultiIntTensor {
161        unimplemented!()
162    }
163    fn int_mean_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
164        unimplemented!()
165    }
166    fn int_gather(dim: usize, tensor: MultiIntTensor, indices: MultiIntTensor) -> MultiIntTensor {
167        unimplemented!()
168    }
169    fn int_scatter(dim: usize, tensor: MultiIntTensor, indices: MultiIntTensor, value: MultiIntTensor) -> MultiIntTensor {
170        unimplemented!()
171    }
172    fn int_select(tensor: MultiIntTensor, dim: usize, indices: MultiIntTensor) -> MultiIntTensor {
173        ops_tensor_dim_indices!(int(tensor, dim, indices) => int_select)
174    }
175    fn int_select_assign(tensor: MultiIntTensor, dim: usize, indices: MultiIntTensor, value: MultiIntTensor) -> MultiIntTensor {
176        unimplemented!()
177    }
178    fn int_mask_where(tensor: MultiIntTensor, mask: MultiBoolTensor, source: MultiIntTensor) -> MultiIntTensor {
179        unimplemented!()
180    }
181    fn int_mask_fill(tensor: MultiIntTensor, mask: MultiBoolTensor, value: i32) -> MultiIntTensor {
182        unimplemented!()
183    }
184    fn int_argmax(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
185        unimplemented!()
186    }
187    fn int_argmin(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
188        unimplemented!()
189    }
190    fn int_max_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
191        unimplemented!()
192    }
193    fn int_max_dim_with_indices(tensor: MultiIntTensor, dim: usize) -> (MultiIntTensor, MultiIntTensor) {
194        unimplemented!()
195    }
196    fn int_min_dim(tensor: MultiIntTensor, dim: usize) -> MultiIntTensor {
197        unimplemented!()
198    }
199    fn int_min_dim_with_indices(tensor: MultiIntTensor, dim: usize) -> (MultiIntTensor, MultiIntTensor) {
200        unimplemented!()
201    }
202    fn int_clamp_min(tensor: MultiIntTensor, min: i32) -> MultiIntTensor {
203        unimplemented!()
204    }
205    fn int_clamp_max(tensor: MultiIntTensor, max: i32) -> MultiIntTensor {
206        unimplemented!()
207    }
208    fn int_clamp(tensor: MultiIntTensor, min: i32, max: i32) -> MultiIntTensor {
209        unimplemented!()
210    }
211    fn int_abs(tensor: MultiIntTensor) -> MultiIntTensor {
212        unimplemented!()
213    }
214    fn int_into_float(tensor: MultiIntTensor) -> MultiFloatTensor {
215        unimplemented!()
216    }
217    fn int_swap_dims(tensor: MultiIntTensor, dim1: usize, dim2: usize) -> MultiIntTensor {
218        unimplemented!()
219    }
220    fn int_random(shape: Shape, distribution: Distribution, device: &MultiDevice) -> MultiIntTensor {
221        unimplemented!()
222    }
223    fn int_arange(range: Range<i64>, device: &MultiDevice) -> MultiIntTensor {
224        ops_rest_device!(int(range ; device) => int_arange)
225    }
226    fn int_permute(tensor: MultiIntTensor, axes: &[usize]) -> MultiIntTensor {
227        unimplemented!()
228    }
229    fn int_flip(tensor: MultiIntTensor, axes: &[usize]) -> MultiIntTensor {
230        unimplemented!()
231    }
232    fn int_sign(tensor: MultiIntTensor) -> MultiIntTensor {
233        unimplemented!()
234    }
235    fn int_expand(tensor: MultiIntTensor, shape: Shape) -> MultiIntTensor {
236        ops_tensor_rest!(int(tensor, shape) => int_expand)
237    }
238    fn int_sort(tensor: MultiIntTensor, dim: usize, descending: bool) -> MultiIntTensor {
239        unimplemented!()
240    }
241    fn int_argsort(tensor: MultiIntTensor, dim: usize, descending: bool) -> MultiIntTensor {
242        unimplemented!()
243    }
244    fn bitwise_and(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
245        unimplemented!()
246    }
247
248    fn bitwise_or(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
249        unimplemented!()
250    }
251
252    fn bitwise_xor(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
253        unimplemented!()
254    }
255
256    fn bitwise_not(tensor: MultiIntTensor) -> MultiIntTensor {
257        unimplemented!()
258    }
259
260    fn bitwise_and_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
261        unimplemented!()
262    }
263
264    fn bitwise_or_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
265        unimplemented!()
266    }
267
268    fn bitwise_xor_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
269        unimplemented!()
270    }
271
272    fn bitwise_left_shift(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
273        unimplemented!()
274    }
275
276    fn bitwise_right_shift(lhs: MultiIntTensor, rhs: MultiIntTensor) -> MultiIntTensor {
277        unimplemented!()
278    }
279
280    fn bitwise_left_shift_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
281        unimplemented!()
282    }
283
284    fn bitwise_right_shift_scalar(lhs: MultiIntTensor, rhs: i32) -> MultiIntTensor {
285        unimplemented!()
286    }
287}