1#![allow(unreachable_patterns)]
2
3use std::ops::Range;
4
5use burn::tensor::{ops::FloatTensorOps, Distribution, FloatDType, Shape, TensorData};
6
7#[cfg(feature = "burn-candle")]
8use crate::backend::CandleBackend;
9#[cfg(feature = "burn-ndarray")]
10use crate::backend::NdArrayBackend;
11#[cfg(feature = "burn-wgpu")]
12use crate::backend::WgpuBackend;
13use crate::{
14 backend::{MultiBackend, MultiDevice},
15 tensor::{MultiBoolTensor, MultiFloatTensor, MultiIntTensor},
16};
17
18#[allow(unused_variables)]
19impl FloatTensorOps<Self> for MultiBackend {
20 fn float_from_data(data: TensorData, device: &MultiDevice) -> MultiFloatTensor {
21 ops_rest_device!(float(data ; device) => float_from_data)
22 }
23 fn float_random(shape: Shape, distribution: Distribution, device: &MultiDevice) -> MultiFloatTensor {
24 ops_rest_device!(float(shape, distribution ; device) => float_random)
25 }
26 fn float_repeat_dim(tensor: MultiFloatTensor, dim: usize, times: usize) -> MultiFloatTensor {
27 ops_tensor_rest!(float(tensor, dim, times) => float_repeat_dim)
28 }
29 fn float_zeros(shape: Shape, device: &MultiDevice) -> MultiFloatTensor {
30 ops_rest_device!(float(shape ; device) => float_zeros)
31 }
32 fn float_ones(shape: Shape, device: &MultiDevice) -> MultiFloatTensor {
33 ops_rest_device!(float(shape ; device) => float_ones)
34 }
35 async fn float_into_data(tensor: MultiFloatTensor) -> TensorData {
36 match tensor {
38 #[cfg(feature = "burn-candle")]
39 MultiFloatTensor::Candle(t) => <CandleBackend as FloatTensorOps<CandleBackend>>::float_into_data(t).await,
40 #[cfg(feature = "burn-ndarray")]
41 MultiFloatTensor::NdArray(t) => <NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_into_data(t).await,
42 #[cfg(feature = "burn-wgpu")]
43 MultiFloatTensor::Wgpu(t) => <WgpuBackend as FloatTensorOps<WgpuBackend>>::float_into_data(t).await,
44 }
45 }
46 fn float_device(tensor: &MultiFloatTensor) -> MultiDevice {
47 match tensor {
48 #[cfg(feature = "burn-candle")]
49 MultiFloatTensor::Candle(t) => MultiDevice::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_device(t)),
50 #[cfg(feature = "burn-ndarray")]
51 MultiFloatTensor::NdArray(t) => MultiDevice::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_device(t)),
52 #[cfg(feature = "burn-wgpu")]
53 MultiFloatTensor::Wgpu(t) => MultiDevice::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_device(t)),
54 }
55 }
56 fn float_to_device(tensor: MultiFloatTensor, device: &MultiDevice) -> MultiFloatTensor {
57 match tensor {
58 #[cfg(feature = "burn-candle")]
60 MultiFloatTensor::Candle(ref t) => match device {
61 MultiDevice::Candle(_) => {
62 tensor.clone()
64 }
65 #[cfg(feature = "burn-wgpu")]
66 MultiDevice::Wgpu(d) => {
67 let data = burn::tensor::try_read_sync(<CandleBackend as FloatTensorOps<CandleBackend>>::float_into_data(t.clone())).expect(
69 "Failed to read tensor data synchronously.
70 This can happen on platforms that don't support blocking futures like WASM.
71 If possible, try using into_data_async instead.",
72 );
73 MultiFloatTensor::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_from_data(data, d))
74 }
75 #[cfg(feature = "burn-ndarray")]
76 MultiDevice::NdArray(d) => {
77 let data = burn::tensor::try_read_sync(<CandleBackend as FloatTensorOps<CandleBackend>>::float_into_data(t.clone())).expect(
79 "Failed to read tensor data synchronously.
80 This can happen on platforms that don't support blocking futures like WASM.
81 If possible, try using into_data_async instead.",
82 );
83 MultiFloatTensor::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_from_data(data, d))
84 }
85 },
86
87 #[cfg(feature = "burn-ndarray")]
89 MultiFloatTensor::NdArray(ref t) => match device {
90 MultiDevice::NdArray(_) => {
91 tensor.clone()
93 }
94 #[cfg(feature = "burn-wgpu")]
95 MultiDevice::Wgpu(d) => {
96 let data = burn::tensor::try_read_sync(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_into_data(t.clone())).expect(
98 "Failed to read tensor data synchronously.
99 This can happen on platforms that don't support blocking futures like WASM.
100 If possible, try using into_data_async instead.",
101 );
102 MultiFloatTensor::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_from_data(data, d))
103 }
104 #[cfg(feature = "burn-candle")]
105 MultiDevice::Candle(d) => {
106 let data = burn::tensor::try_read_sync(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_into_data(t.clone())).expect(
108 "Failed to read tensor data synchronously.
109 This can happen on platforms that don't support blocking futures like WASM.
110 If possible, try using into_data_async instead.",
111 );
112 MultiFloatTensor::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_from_data(data, d))
113 }
114 },
115
116 #[cfg(feature = "burn-wgpu")]
118 MultiFloatTensor::Wgpu(ref t) => match device {
119 MultiDevice::Wgpu(_) => {
120 tensor.clone()
122 }
123 #[cfg(feature = "burn-ndarray")]
124 MultiDevice::NdArray(d) => {
125 let data = burn::tensor::try_read_sync(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_into_data(t.clone())).expect(
127 "Failed to read tensor data synchronously.
128 This can happen on platforms that don't support blocking futures like WASM.
129 If possible, try using into_data_async instead.",
130 );
131 MultiFloatTensor::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_from_data(data, d))
132 }
133 #[cfg(feature = "burn-candle")]
134 MultiDevice::Candle(d) => {
135 let data = burn::tensor::try_read_sync(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_into_data(t.clone())).expect(
137 "Failed to read tensor data synchronously.
138 This can happen on platforms that don't support blocking futures like WASM.
139 If possible, try using into_data_async instead.",
140 );
141 MultiFloatTensor::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_from_data(data, d))
142 }
143 },
144 }
145 }
146 fn float_empty(shape: Shape, device: &MultiDevice) -> MultiFloatTensor {
147 ops_rest_device!(float(shape ; device) => float_empty)
148 }
149 fn float_add(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
150 ops_tensor_tensor!(float(lhs, rhs) => float_add)
151 }
152 fn float_add_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
153 ops_tensor_scalar!(float(lhs, rhs) => float_add_scalar)
154 }
155 fn float_sub(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
156 ops_tensor_tensor!(float(lhs, rhs) => float_sub)
157 }
158 fn float_sub_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
159 ops_tensor_scalar!(float(lhs, rhs) => float_sub_scalar)
160 }
161 fn float_mul(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
162 ops_tensor_tensor!(float(lhs, rhs) => float_mul)
163 }
164 fn float_mul_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
165 ops_tensor_scalar!(float(lhs, rhs) => float_mul_scalar)
166 }
167 fn float_div(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
168 ops_tensor_tensor!(float(lhs, rhs) => float_div)
169 }
170 fn float_div_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
171 ops_tensor_scalar!(float(lhs, rhs) => float_div_scalar)
172 }
173 fn float_remainder(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
174 ops_tensor_tensor!(float(lhs, rhs) => float_remainder)
175 }
176 fn float_remainder_scalar(lhs: MultiFloatTensor, rhs: f32) -> MultiFloatTensor {
177 ops_tensor_scalar!(float(lhs, rhs) => float_remainder_scalar)
178 }
179 fn float_matmul(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
180 ops_tensor_tensor!(float(lhs, rhs) => float_matmul)
181 }
182 fn float_neg(tensor: MultiFloatTensor) -> MultiFloatTensor {
183 ops_tensor!(float(tensor) => float_neg)
184 }
185 fn float_recip(tensor: MultiFloatTensor) -> MultiFloatTensor {
186 ops_tensor!(float(tensor) => float_recip)
187 }
188 fn float_swap_dims(tensor: MultiFloatTensor, dim1: usize, dim2: usize) -> MultiFloatTensor {
189 ops_tensor_rest!(float(tensor, dim1, dim2) => float_swap_dims)
190 }
191 fn float_reshape(tensor: MultiFloatTensor, shape: Shape) -> MultiFloatTensor {
192 ops_tensor_rest!(float(tensor, shape) => float_reshape)
193 }
194 fn float_gather(dim: usize, tensor: MultiFloatTensor, indices: MultiIntTensor) -> MultiFloatTensor {
195 ops_dim_tensor_indices!(float(dim, tensor, indices) => float_gather)
196 }
197 fn float_scatter(dim: usize, tensor: MultiFloatTensor, indices: MultiIntTensor, value: MultiFloatTensor) -> MultiFloatTensor {
198 ops_dim_tensor_indices_values!(float(dim, tensor, indices, value) => float_scatter)
199 }
200 fn float_select(tensor: MultiFloatTensor, dim: usize, indices: MultiIntTensor) -> MultiFloatTensor {
201 ops_tensor_dim_indices!(float(tensor, dim, indices) => float_select)
202 }
203 fn float_select_assign(tensor: MultiFloatTensor, dim: usize, indices: MultiIntTensor, value: MultiFloatTensor) -> MultiFloatTensor {
204 ops_tensor_dim_indices_values!(float(tensor, dim, indices, value) => float_select_assign)
205 }
206 fn float_slice(tensor: MultiFloatTensor, ranges: &[Range<usize>]) -> MultiFloatTensor {
207 ops_tensor_rest!(float(tensor, ranges) => float_slice)
208 }
209 fn float_slice_assign(tensor: MultiFloatTensor, ranges: &[Range<usize>], value: MultiFloatTensor) -> MultiFloatTensor {
210 ops_tensor_other_values!(float(tensor, ranges, value) => float_slice_assign)
211 }
212 fn float_mask_where(tensor: MultiFloatTensor, mask: MultiBoolTensor, value: MultiFloatTensor) -> MultiFloatTensor {
213 unimplemented!()
214 }
215 fn float_mask_fill(tensor: MultiFloatTensor, mask: MultiBoolTensor, value: f32) -> MultiFloatTensor {
216 unimplemented!()
217 }
218 fn float_equal(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
219 unimplemented!()
220 }
221 fn float_equal_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
222 unimplemented!()
223 }
224 fn float_greater(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
225 unimplemented!()
226 }
227 fn float_greater_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
228 ops_tensor_rest_ret_bool!(float(lhs, rhs) => float_greater_elem)
229 }
230 fn float_greater_equal(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
231 unimplemented!()
232 }
233 fn float_greater_equal_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
234 unimplemented!()
235 }
236 fn float_lower(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
237 unimplemented!()
238 }
239 fn float_lower_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
240 ops_tensor_rest_ret_bool!(float(lhs, rhs) => float_lower_elem)
241 }
242
243 fn float_lower_equal(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiBoolTensor {
244 unimplemented!()
245 }
246 fn float_lower_equal_elem(lhs: MultiFloatTensor, rhs: f32) -> MultiBoolTensor {
247 unimplemented!()
248 }
249 fn float_mean(tensor: MultiFloatTensor) -> MultiFloatTensor {
250 ops_tensor!(float(tensor) => float_mean)
251 }
252 fn float_sum(tensor: MultiFloatTensor) -> MultiFloatTensor {
253 ops_tensor!(float(tensor) => float_sum)
254 }
255 fn float_sum_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
256 ops_tensor_rest!(float(tensor, dim) => float_sum_dim)
257 }
258 fn float_mean_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
259 ops_tensor_rest!(float(tensor, dim) => float_mean_dim)
260 }
261 fn float_prod(tensor: MultiFloatTensor) -> MultiFloatTensor {
262 ops_tensor!(float(tensor) => float_prod)
263 }
264 fn float_prod_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
265 ops_tensor_rest!(float(tensor, dim) => float_prod_dim)
266 }
267 fn float_argmax(tensor: MultiFloatTensor, dim: usize) -> MultiIntTensor {
268 unimplemented!()
269 }
270 fn float_argmin(tensor: MultiFloatTensor, dim: usize) -> MultiIntTensor {
271 unimplemented!()
272 }
273 fn float_max_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
274 ops_tensor_rest!(float(tensor, dim) => float_max_dim)
275 }
276 fn float_max_dim_with_indices(tensor: MultiFloatTensor, dim: usize) -> (MultiFloatTensor, MultiIntTensor) {
277 unimplemented!()
278 }
279 fn float_min_dim(tensor: MultiFloatTensor, dim: usize) -> MultiFloatTensor {
280 ops_tensor_rest!(float(tensor, dim) => float_min_dim)
281 }
282 fn float_min_dim_with_indices(tensor: MultiFloatTensor, dim: usize) -> (MultiFloatTensor, MultiIntTensor) {
283 unimplemented!()
284 }
285 fn float_exp(tensor: MultiFloatTensor) -> MultiFloatTensor {
286 ops_tensor!(float(tensor) => float_exp)
287 }
288 fn float_log(tensor: MultiFloatTensor) -> MultiFloatTensor {
289 ops_tensor!(float(tensor) => float_log)
290 }
291 fn float_log1p(tensor: MultiFloatTensor) -> MultiFloatTensor {
292 ops_tensor!(float(tensor) => float_log1p)
293 }
294 fn float_powf_scalar(tensor: MultiFloatTensor, value: f32) -> MultiFloatTensor {
295 ops_tensor_rest!(float(tensor, value) => float_powf_scalar)
296 }
297 fn float_sqrt(tensor: MultiFloatTensor) -> MultiFloatTensor {
298 ops_tensor!(float(tensor) => float_sqrt)
299 }
300 fn float_abs(tensor: MultiFloatTensor) -> MultiFloatTensor {
301 ops_tensor!(float(tensor) => float_abs)
302 }
303 fn float_cos(tensor: MultiFloatTensor) -> MultiFloatTensor {
304 ops_tensor!(float(tensor) => float_cos)
305 }
306 fn float_sin(tensor: MultiFloatTensor) -> MultiFloatTensor {
307 ops_tensor!(float(tensor) => float_sin)
308 }
309 fn float_tanh(tensor: MultiFloatTensor) -> MultiFloatTensor {
310 ops_tensor!(float(tensor) => float_tanh)
311 }
312 fn float_round(tensor: MultiFloatTensor) -> MultiFloatTensor {
313 ops_tensor!(float(tensor) => float_round)
314 }
315 fn float_floor(tensor: MultiFloatTensor) -> MultiFloatTensor {
316 ops_tensor!(float(tensor) => float_floor)
317 }
318 fn float_ceil(tensor: MultiFloatTensor) -> MultiFloatTensor {
319 ops_tensor!(float(tensor) => float_ceil)
320 }
321 fn float_erf(tensor: MultiFloatTensor) -> MultiFloatTensor {
322 ops_tensor!(float(tensor) => float_erf)
323 }
324 #[allow(clippy::match_wildcard_for_single_variants)]
325 fn float_cat(tensors: Vec<MultiFloatTensor>, dim: usize) -> MultiFloatTensor {
326 assert!(!tensors.is_empty(), "Cannot concatenate an empty list of tensors");
327 match &tensors[0] {
328 #[cfg(feature = "burn-candle")]
329 MultiFloatTensor::Candle(_) => {
330 use crate::backend::CandleBackend;
331 let inner: Vec<_> = tensors
332 .into_iter()
333 .map(|t| match t {
334 MultiFloatTensor::Candle(inner) => inner,
335 _ => panic!("Mismatched tensor backends in float_cat: expected Candle"),
336 })
337 .collect();
338 MultiFloatTensor::Candle(<CandleBackend as FloatTensorOps<CandleBackend>>::float_cat(inner, dim))
339 }
340
341 #[cfg(feature = "burn-ndarray")]
342 MultiFloatTensor::NdArray(_) => {
343 use crate::backend::NdArrayBackend;
344 let inner: Vec<_> = tensors
345 .into_iter()
346 .map(|t| match t {
347 MultiFloatTensor::NdArray(inner) => inner,
348 _ => panic!("Mismatched tensor backends in float_cat: expected NdArray"),
349 })
350 .collect();
351 MultiFloatTensor::NdArray(<NdArrayBackend as FloatTensorOps<NdArrayBackend>>::float_cat(inner, dim))
352 }
353
354 #[cfg(feature = "burn-wgpu")]
355 MultiFloatTensor::Wgpu(_) => {
356 use crate::backend::WgpuBackend;
357 let inner: Vec<_> = tensors
358 .into_iter()
359 .map(|t| match t {
360 MultiFloatTensor::Wgpu(inner) => inner,
361 _ => panic!("Mismatched tensor backends in float_cat: expected Wgpu"),
362 })
363 .collect();
364 MultiFloatTensor::Wgpu(<WgpuBackend as FloatTensorOps<WgpuBackend>>::float_cat(inner, dim))
365 }
366 }
367 }
368 fn float_clamp_min(tensor: MultiFloatTensor, min: f32) -> MultiFloatTensor {
369 ops_tensor_rest!(float(tensor, min) => float_clamp_min)
370 }
371 fn float_clamp_max(tensor: MultiFloatTensor, max: f32) -> MultiFloatTensor {
372 ops_tensor_rest!(float(tensor, max) => float_clamp_max)
373 }
374 fn float_clamp(tensor: MultiFloatTensor, min: f32, max: f32) -> MultiFloatTensor {
375 ops_tensor_rest!(float(tensor, min, max) => float_clamp)
376 }
377 fn float_into_int(tensor: MultiFloatTensor) -> MultiIntTensor {
378 unimplemented!()
379 }
380 fn float_powf(lhs: MultiFloatTensor, rhs: MultiFloatTensor) -> MultiFloatTensor {
381 ops_tensor_tensor!(float(lhs, rhs) => float_powf)
382 }
383 fn float_permute(tensor: MultiFloatTensor, axes: &[usize]) -> MultiFloatTensor {
384 ops_tensor_rest!(float(tensor, axes) => float_permute)
385 }
386 fn float_flip(tensor: MultiFloatTensor, axes: &[usize]) -> MultiFloatTensor {
387 ops_tensor_rest!(float(tensor, axes) => float_flip)
388 }
389 fn float_sign(tensor: MultiFloatTensor) -> MultiFloatTensor {
390 ops_tensor!(float(tensor) => float_sign)
391 }
392 fn float_expand(tensor: MultiFloatTensor, shape: Shape) -> MultiFloatTensor {
393 ops_tensor_rest!(float(tensor, shape) => float_expand)
394 }
395 fn float_sort(tensor: MultiFloatTensor, dim: usize, descending: bool) -> MultiFloatTensor {
396 ops_tensor_rest!(float(tensor, dim, descending) => float_sort)
397 }
398 fn float_sort_with_indices(tensor: MultiFloatTensor, dim: usize, descending: bool) -> (MultiFloatTensor, MultiIntTensor) {
399 unimplemented!()
400 }
401 fn float_argsort(tensor: MultiFloatTensor, dim: usize, descending: bool) -> MultiIntTensor {
402 unimplemented!()
403 }
404 fn float_cast(tensor: MultiFloatTensor, dtype: FloatDType) -> MultiFloatTensor {
405 ops_tensor_rest!(float(tensor, dtype) => float_cast)
406 }
407}