gloss_burn_multibackend/ops/float_tensor.rs

#![allow(unreachable_patterns)]

use std::ops::Range;

use burn::tensor::{ops::FloatTensorOps, Distribution, FloatDType, Shape, TensorData};

#[cfg(feature = "burn-candle")]
use crate::backend::CandleBackend;
#[cfg(feature = "burn-ndarray")]
use crate::backend::NdArrayBackend;
#[cfg(feature = "burn-wgpu")]
use crate::backend::WgpuBackend;
use crate::{
    backend::{MultiBackend, MultiDevice},
    tensor::{MultiBoolTensor, MultiFloatTensor, MultiIntTensor},
};

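// Most ops below delegate through the `ops_*` dispatch macros, which are
// defined elsewhere in this crate; each macro presumably expands to a match
// over the feature-enabled `MultiFloatTensor` variants and forwards the call
// to that backend's `FloatTensorOps` implementation. The hand-written ops
// below (`float_into_data`, `float_device`, `float_to_device`, `float_cat`)
// show what such an expansion looks like.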
#[allow(unused_variables)]
impl FloatTensorOps<Self> for MultiBackend {
    fn float_from_data(data: TensorData, device: &MultiDevice) -> MultiFloatTensor {
        ops_rest_device!(float(data ; device) => float_from_data)
    }
    fn float_random(shape: Shape, distribution: Distribution, device: &MultiDevice) -> MultiFloatTensor {
        ops_rest_device!(float(shape, distribution ; device) => float_random)
    }
    fn float_repeat_dim(tensor: MultiFloatTensor, dim: usize, times: usize) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dim, times) => float_repeat_dim)
    }
    fn float_zeros(shape: Shape, device: &MultiDevice) -> MultiFloatTensor {
        ops_rest_device!(float(shape ; device) => float_zeros)
    }
    fn float_ones(shape: Shape, device: &MultiDevice) -> MultiFloatTensor {
        ops_rest_device!(float(shape ; device) => float_ones)
    }
    async fn float_into_data(tensor: MultiFloatTensor) -> TensorData {
        // Expanded by hand (instead of `ops_tensor!`) so that each match arm
        // can `.await` the backend's future.
        match tensor {
            #[cfg(feature = "burn-candle")]
            MultiFloatTensor::Candle(t) => <CandleBackend as FloatTensorOps<CandleBackend>>::float_into_data(t).await,
            #[cfg(feature = "burn-ndarray")]
            MultiFloatTensor::NdArray(t) => <NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_into_data(t).await,
            #[cfg(feature = "burn-wgpu")]
            MultiFloatTensor::Wgpu(t) => <WgpuBackend as FloatTensorOps<WgpuBackend>>::float_into_data(t).await,
        }
    }
    fn float_device(tensor: &MultiFloatTensor) -> MultiDevice {
        match tensor {
            #[cfg(feature = "burn-candle")]
            MultiFloatTensor::Candle(t) => MultiDevice::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_device(t)),
            #[cfg(feature = "burn-ndarray")]
            MultiFloatTensor::NdArray(t) => MultiDevice::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_device(t)),
            #[cfg(feature = "burn-wgpu")]
            MultiFloatTensor::Wgpu(t) => MultiDevice::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_device(t)),
        }
    }
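    // Cross-backend moves below round-trip through host memory: the source
    // tensor is read back with `try_read_sync` and re-uploaded on the target
    // device with `float_from_data`. This is semantically correct but costly
    // for large tensors, and it panics on platforms (e.g. WASM) where futures
    // cannot be blocked on synchronously.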
    fn float_to_device(tensor: MultiFloatTensor, device: &MultiDevice) -> MultiFloatTensor {
        match tensor {
            // Current tensor is on candle
            #[cfg(feature = "burn-candle")]
            MultiFloatTensor::Candle(t) => match device {
                MultiDevice::Candle(d) => {
                    // Same backend: delegate, so a move between two candle devices is still handled
                    MultiFloatTensor::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_to_device(t, d))
                }
                #[cfg(feature = "burn-wgpu")]
                MultiDevice::Wgpu(d) => {
                    // Need to move candle to wgpu
                    let data = burn::tensor::try_read_sync(<CandleBackend as FloatTensorOps<CandleBackend>>::float_into_data(t)).expect(
                        "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
                    );
                    MultiFloatTensor::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_from_data(data, d))
                }
                #[cfg(feature = "burn-ndarray")]
                MultiDevice::NdArray(d) => {
                    // Need to move candle to ndarray
                    let data = burn::tensor::try_read_sync(<CandleBackend as FloatTensorOps<CandleBackend>>::float_into_data(t)).expect(
                        "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
                    );
                    MultiFloatTensor::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_from_data(data, d))
                }
            },

            // Current tensor is on ndarray
            #[cfg(feature = "burn-ndarray")]
            MultiFloatTensor::NdArray(t) => match device {
                MultiDevice::NdArray(d) => {
                    // Same backend: delegate, so a move between two ndarray devices is still handled
                    MultiFloatTensor::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_to_device(t, d))
                }
                #[cfg(feature = "burn-wgpu")]
                MultiDevice::Wgpu(d) => {
                    // Need to move ndarray to wgpu
                    let data = burn::tensor::try_read_sync(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_into_data(t)).expect(
                        "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
                    );
                    MultiFloatTensor::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_from_data(data, d))
                }
                #[cfg(feature = "burn-candle")]
                MultiDevice::Candle(d) => {
                    // Need to move ndarray to candle
                    let data = burn::tensor::try_read_sync(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_into_data(t)).expect(
                        "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
                    );
                    MultiFloatTensor::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_from_data(data, d))
                }
            },

            // Current tensor is on wgpu
            #[cfg(feature = "burn-wgpu")]
            MultiFloatTensor::Wgpu(t) => match device {
                MultiDevice::Wgpu(d) => {
                    // Same backend: delegate, so a move between two wgpu devices is still handled
                    MultiFloatTensor::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_to_device(t, d))
                }
                #[cfg(feature = "burn-ndarray")]
                MultiDevice::NdArray(d) => {
                    // Need to move wgpu to ndarray
                    let data = burn::tensor::try_read_sync(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_into_data(t)).expect(
                        "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
                    );
                    MultiFloatTensor::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_from_data(data, d))
                }
                #[cfg(feature = "burn-candle")]
                MultiDevice::Candle(d) => {
                    // Need to move wgpu to candle
                    let data = burn::tensor::try_read_sync(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_into_data(t)).expect(
                        "Failed to read tensor data synchronously.
        This can happen on platforms that don't support blocking futures like WASM.
        If possible, try using into_data_async instead.",
                    );
                    MultiFloatTensor::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_from_data(data, d))
                }
            },
        }
    }
    fn float_empty(shape: Shape, device: &MultiDevice) -> MultiFloatTensor {
        ops_rest_device!(float(shape ; device) => float_empty)
    }
    fn float_add(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_tensor!(float(lhs, rhs) => float_add)
    }
    fn float_add_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
        ops_tensor_scalar!(float(lhs, rhs) => float_add_scalar)
    }
    fn float_sub(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_tensor!(float(lhs, rhs) => float_sub)
    }
    fn float_sub_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
        ops_tensor_scalar!(float(lhs, rhs) => float_sub_scalar)
    }
    fn float_mul(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_tensor!(float(lhs, rhs) => float_mul)
    }
    fn float_mul_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
        ops_tensor_scalar!(float(lhs, rhs) => float_mul_scalar)
    }
    fn float_div(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_tensor!(float(lhs, rhs) => float_div)
    }
    fn float_div_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
        ops_tensor_scalar!(float(lhs, rhs) => float_div_scalar)
    }
    fn float_remainder(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_tensor!(float(lhs, rhs) => float_remainder)
    }
    fn float_remainder_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
        ops_tensor_scalar!(float(lhs, rhs) => float_remainder_scalar)
    }
    fn float_matmul(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_tensor!(float(lhs, rhs) => float_matmul)
    }
    fn float_neg(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_neg)
    }
    fn float_recip(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_recip)
    }
    fn float_swap_dims(tensor: MultiFloatTensor, dim1: usize, dim2: usize) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dim1, dim2) => float_swap_dims)
    }
    fn float_reshape(tensor: MultiFloatTensor, shape: Shape) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, shape) => float_reshape)
    }
    fn float_gather(dim: usize, tensor: MultiFloatTensor, indices: MultiIntTensor) -> MultiFloatTensor {
        ops_dim_tensor_indices!(float(dim, tensor, indices) => float_gather)
    }
    fn float_scatter(dim: usize, tensor: MultiFloatTensor, indices: MultiIntTensor, value: MultiFloatTensor) -> MultiFloatTensor {
        ops_dim_tensor_indices_values!(float(dim, tensor, indices, value) => float_scatter)
    }
    fn float_select(tensor: MultiFloatTensor, dim: usize, indices: MultiIntTensor) -> MultiFloatTensor {
        ops_tensor_dim_indices!(float(tensor, dim, indices) => float_select)
    }
    fn float_select_assign(tensor: MultiFloatTensor, dim: usize, indices: MultiIntTensor, value: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_dim_indices_values!(float(tensor, dim, indices, value) => float_select_assign)
    }
    fn float_slice(tensor: MultiFloatTensor, ranges: &[Range<usize>]) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, ranges) => float_slice)
    }
    fn float_slice_assign(tensor: MultiFloatTensor, ranges: &[Range<usize>], value: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_other_values!(float(tensor, ranges, value) => float_slice_assign)
    }
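    // The ops below that mix tensor kinds (bool masks, bool comparison
    // results, int index outputs) are still stubs; presumably they are
    // waiting on dedicated dispatch macros like the
    // `ops_tensor_rest_ret_bool!` used for the two implemented `_elem`
    // comparisons.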
    fn float_mask_where(tensor: MultiFloatTensor, mask: MultiBoolTensor, value: MultiFloatTensor) -> MultiFloatTensor {
        unimplemented!()
    }
    fn float_mask_fill(tensor: MultiFloatTensor, mask: MultiBoolTensor, value: f32) -> MultiFloatTensor {
        unimplemented!()
    }
    fn float_equal(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
        unimplemented!()
    }
    fn float_equal_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
        unimplemented!()
    }
    fn float_greater(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
        unimplemented!()
    }
    fn float_greater_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
        ops_tensor_rest_ret_bool!(float(lhs, rhs) => float_greater_elem)
    }
    fn float_greater_equal(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
        unimplemented!()
    }
    fn float_greater_equal_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
        unimplemented!()
    }
    fn float_lower(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
        unimplemented!()
    }
    fn float_lower_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
        ops_tensor_rest_ret_bool!(float(lhs, rhs) => float_lower_elem)
    }
    fn float_lower_equal(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
        unimplemented!()
    }
    fn float_lower_equal_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
        unimplemented!()
    }
    fn float_mean(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_mean)
    }
    fn float_sum(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_sum)
    }
    fn float_sum_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dim) => float_sum_dim)
    }
    fn float_mean_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dim) => float_mean_dim)
    }
    fn float_prod(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_prod)
    }
    fn float_prod_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dim) => float_prod_dim)
    }
    fn float_argmax(tensor: MultiFloatTensor, dim: usize) -> MultiIntTensor {
        unimplemented!()
    }
    fn float_argmin(tensor: MultiFloatTensor, dim: usize) -> MultiIntTensor {
        unimplemented!()
    }
    fn float_max_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dim) => float_max_dim)
    }
    fn float_max_dim_with_indices(tensor: MultiFloatTensor, dim: usize) -> (MultiFloatTensor, MultiIntTensor) {
        unimplemented!()
    }
    fn float_min_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dim) => float_min_dim)
    }
    fn float_min_dim_with_indices(tensor: MultiFloatTensor, dim: usize) -> (MultiFloatTensor, MultiIntTensor) {
        unimplemented!()
    }
    fn float_exp(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_exp)
    }
    fn float_log(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_log)
    }
    fn float_log1p(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_log1p)
    }
    fn float_powf_scalar(tensor: MultiFloatTensor, value: f32) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, value) => float_powf_scalar)
    }
    fn float_sqrt(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_sqrt)
    }
    fn float_abs(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_abs)
    }
    fn float_cos(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_cos)
    }
    fn float_sin(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_sin)
    }
    fn float_tanh(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_tanh)
    }
    fn float_round(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_round)
    }
    fn float_floor(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_floor)
    }
    fn float_ceil(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_ceil)
    }
    fn float_erf(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_erf)
    }
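    // `float_cat` dispatches on the backend of the first tensor and requires
    // every input to live on that same backend; mixed inputs panic rather
    // than being moved implicitly.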
    #[allow(clippy::match_wildcard_for_single_variants)]
    fn float_cat(tensors: Vec<MultiFloatTensor>, dim: usize) -> MultiFloatTensor {
        assert!(!tensors.is_empty(), "Cannot concatenate an empty list of tensors");
        match &tensors[0] {
            #[cfg(feature = "burn-candle")]
            MultiFloatTensor::Candle(_) => {
                let inner: Vec<_> = tensors
                    .into_iter()
                    .map(|t| match t {
                        MultiFloatTensor::Candle(inner) => inner,
                        _ => panic!("Mismatched tensor backends in float_cat: expected Candle"),
                    })
                    .collect();
                MultiFloatTensor::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_cat(inner, dim))
            }

            #[cfg(feature = "burn-ndarray")]
            MultiFloatTensor::NdArray(_) => {
                let inner: Vec<_> = tensors
                    .into_iter()
                    .map(|t| match t {
                        MultiFloatTensor::NdArray(inner) => inner,
                        _ => panic!("Mismatched tensor backends in float_cat: expected NdArray"),
                    })
                    .collect();
                MultiFloatTensor::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_cat(inner, dim))
            }

            #[cfg(feature = "burn-wgpu")]
            MultiFloatTensor::Wgpu(_) => {
                let inner: Vec<_> = tensors
                    .into_iter()
                    .map(|t| match t {
                        MultiFloatTensor::Wgpu(inner) => inner,
                        _ => panic!("Mismatched tensor backends in float_cat: expected Wgpu"),
                    })
                    .collect();
                MultiFloatTensor::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_cat(inner, dim))
            }
        }
    }
    fn float_clamp_min(tensor: MultiFloatTensor, min: f32) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, min) => float_clamp_min)
    }
    fn float_clamp_max(tensor: MultiFloatTensor, max: f32) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, max) => float_clamp_max)
    }
    fn float_clamp(tensor: MultiFloatTensor, min: f32, max: f32) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, min, max) => float_clamp)
    }
    fn float_into_int(tensor: MultiFloatTensor) -> MultiIntTensor {
        unimplemented!()
    }
    fn float_powf(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor_tensor!(float(lhs, rhs) => float_powf)
    }
    fn float_permute(tensor: MultiFloatTensor, axes: &[usize]) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, axes) => float_permute)
    }
    fn float_flip(tensor: MultiFloatTensor, axes: &[usize]) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, axes) => float_flip)
    }
    fn float_sign(tensor: MultiFloatTensor) -> MultiFloatTensor {
        ops_tensor!(float(tensor) => float_sign)
    }
    fn float_expand(tensor: MultiFloatTensor, shape: Shape) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, shape) => float_expand)
    }
    fn float_sort(tensor: MultiFloatTensor, dim: usize, descending: bool) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dim, descending) => float_sort)
    }
    fn float_sort_with_indices(tensor: MultiFloatTensor, dim: usize, descending: bool) -> (MultiFloatTensor, MultiIntTensor) {
        unimplemented!()
    }
    fn float_argsort(tensor: MultiFloatTensor, dim: usize, descending: bool) -> MultiIntTensor {
        unimplemented!()
    }
    fn float_cast(tensor: MultiFloatTensor, dtype: FloatDType) -> MultiFloatTensor {
        ops_tensor_rest!(float(tensor, dtype) => float_cast)
    }
}
407}