burn_tensor/tensor/ops/
tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_bilinear;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::{BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor};
5use crate::ops::InterpolateMode;
6use crate::{Distribution, ElementConversion, Float, TensorData, backend::Backend, tensor::Shape};
7use crate::{FloatDType, TensorMetadata, TensorPrimitive};
8use crate::{argsort, sort, sort_with_indices};
9use alloc::vec::Vec;
10
11/// Operations on float tensors.
12pub trait FloatTensorOps<B: Backend> {
    /// Creates a new float tensor on `device` from the given data.
    ///
    /// The default implementations of [`Self::float_zeros`], [`Self::float_ones`] and
    /// [`Self::float_full`] are all built on top of this method.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure holding the tensor's contents.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given data.
    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
24
    /// Creates a new tensor filled with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution the values are sampled from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
    -> FloatTensor<B>;
38
39    /// Creates a new tensor with zeros.
40    ///
41    /// # Arguments
42    ///
43    /// * `shape` - The shape of the tensor.
44    /// * `device` - The device to create the tensor on.
45    /// * `dtype` - The target data type.
46    ///
47    /// # Returns
48    ///
49    /// The tensor with the given shape and zeros.
50    fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
51        Self::float_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
52    }
53
54    /// Creates a new tensor with ones.
55    ///
56    /// # Arguments
57    ///
58    /// * `shape` - The shape of the tensor.
59    /// * `device` - The device to create the tensor on.
60    /// * `dtype` - The target data type.
61    ///
62    /// # Returns
63    ///
64    /// The tensor with the given shape and ones.
65    fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
66        Self::float_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
67    }
68
69    /// Creates a tensor filled with given value.
70    ///
71    /// # Arguments
72    ///
73    /// * `shape` - The shape of the tensor.
74    /// * `fill_value` - The value with which to fill the tensor.
75    /// * `device` - The device to create the tensor on.
76    /// * `dtype` - The target data type.
77    ///
78    /// # Returns
79    ///
80    /// The tensor filled with given value
81    fn float_full(
82        shape: Shape,
83        fill_value: FloatElem<B>,
84        device: &Device<B>,
85        dtype: FloatDType,
86    ) -> FloatTensor<B> {
87        Self::float_from_data(
88            TensorData::full_dtype(shape, fill_value, dtype.into()),
89            device,
90        )
91    }
92
    /// Converts the tensor into a data structure.
    ///
    /// The conversion is asynchronous: the method returns a `Send` future that must be
    /// awaited to obtain the data.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor (consumed by the conversion).
    ///
    /// # Returns
    ///
    /// A future resolving to the data structure with the tensor's data.
    fn float_into_data(tensor: FloatTensor<B>) -> impl Future<Output = TensorData> + Send;
103
    /// Gets the device on which the tensor lives.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to query (borrowed, not consumed).
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
114
    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to move (consumed).
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
126
    /// Converts a float tensor into an int tensor.
    ///
    /// NOTE(review): the rounding/truncation rule for non-integral values is not specified
    /// here — confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to convert.
    ///
    /// # Returns
    ///
    /// The int tensor with the same data as the float tensor.
    fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;
137
    /// Creates an empty tensor with the given shape.
    ///
    /// NOTE(review): the initial contents are presumably unspecified — callers should not
    /// rely on them; confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// The empty tensor with the given shape.
    fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
150
151    /// Repeat the tensor along the given dimension.
152    ///
153    /// # Arguments
154    ///
155    /// * `tensor` - The tensor.
156    /// * `dim` - The dimension to repeat.
157    /// * `times` - The number of times to repeat the dimension.
158    ///
159    /// # Returns
160    ///
161    /// The tensor with the given dimension repeated.
162    fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
163        repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
164    }
165
    /// Adds two tensors together.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The sum `lhs + rhs`.
    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
177
    /// Adds a scalar to a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The tensor `lhs + rhs`, i.e. the result of adding the scalar to the tensor.
    fn float_add_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
189
190    /// Clamps a tensor under a minimum value.
191    ///
192    /// # Arguments
193    ///
194    /// * `tensor` - The tensor to clamp.
195    /// * `min` - The minimum value.
196    ///
197    /// # Returns
198    ///
199    /// The clamped tensor.
200    fn float_clamp_min(tensor: FloatTensor<B>, min: FloatElem<B>) -> FloatTensor<B> {
201        // Default implementation
202        let mask = Self::float_lower_elem(tensor.clone(), min);
203        B::float_mask_fill(tensor, mask, min)
204    }
205
206    /// Clamps a tensor over a maximum value.
207    ///
208    /// # Arguments
209    ///
210    /// * `tensor` - The tensor to clamp.
211    /// * `max` - The maximum value.
212    ///
213    /// # Returns
214    ///
215    /// The clamped tensor.
216    fn float_clamp_max(tensor: FloatTensor<B>, max: FloatElem<B>) -> FloatTensor<B> {
217        // Default implementation
218        let mask = Self::float_greater_elem(tensor.clone(), max);
219        B::float_mask_fill(tensor, mask, max)
220    }
221
222    /// Clamps a tensor between a minimum and maximum value.
223    ///
224    /// # Arguments
225    ///
226    /// * `tensor` - The tensor to clamp.
227    /// * `min` - The minimum value.
228    /// * `max` - The maximum value.
229    ///
230    /// # Returns
231    ///
232    /// The clamped tensor.
233    fn float_clamp(tensor: FloatTensor<B>, min: FloatElem<B>, max: FloatElem<B>) -> FloatTensor<B> {
234        // Default implementation
235        Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
236    }
237
    /// Subtracts one tensor from another.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The difference `lhs - rhs`.
    fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
249
    /// Subtracts a scalar from a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The tensor `lhs - rhs`, i.e. the result of subtracting the scalar from the tensor.
    fn float_sub_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
261
    /// Multiplies two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The element-wise product `lhs * rhs`.
    fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
264
    /// Multiplies a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The tensor `lhs * rhs`, i.e. the result of multiplying the tensor by the scalar.
    fn float_mul_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
276
    /// Divides two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The element-wise quotient `lhs / rhs`.
    fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
288
    /// Divides a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The tensor `lhs / rhs`, i.e. the result of dividing the tensor by the scalar.
    fn float_div_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
300
    /// Computes the remainder of division between two tensors element-wise.
    ///
    /// NOTE(review): the sign convention for negative operands is not specified here —
    /// confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The element-wise remainder when dividing `lhs` by `rhs`.
    fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
312
    /// Computes the modulus of a tensor given a scalar.
    ///
    /// NOTE(review): the sign convention for negative operands is not specified here —
    /// confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of the scalar to the tensor.
    fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
323
    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// Unlike [`Self::float_mul`], this is not an element-wise product.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors together using matrix multiplication.
    fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
335
    /// Computes the cross product of two tensors along a given dimension.
    ///
    /// NOTE(review): presumably `dim` must have size 3 in both operands — confirm.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `dim` - The dimension to compute the cross product along.
    ///
    /// # Returns
    ///
    /// The cross product of the two tensors.
    fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
348
349    /// Negates a tensor element-wise.
350    fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
351        Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
352    }
353
    /// Calculates the reciprocals element-wise.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose reciprocals are computed.
    ///
    /// # Returns
    ///
    /// A tensor holding the reciprocal of each element of `tensor`.
    fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
356
357    /// Transposes a tensor.
358    ///
359    /// # Arguments
360    ///
361    /// * `tensor` - The tensor to transpose.
362    ///
363    /// # Returns
364    ///
365    /// The transposed tensor.
366    fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
367        let ndims = tensor.shape().num_dims();
368        Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
369    }
370
    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with dimensions `dim1` and `dim2` swapped.
    fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
383
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
394
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
404
    /// Reshapes a tensor.
    ///
    /// NOTE(review): presumably `shape` must describe the same number of elements as
    /// `tensor` — confirm.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
416
    /// Gathers elements from a tensor.
    ///
    /// Note that, unlike most operations in this trait, the dimension comes first in the
    /// argument list.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor to gather from.
    /// * `indices` - The indices to gather.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
429
    /// Scatters elements into a tensor.
    ///
    /// NOTE(review): whether scattered values overwrite or accumulate onto the existing
    /// elements is not specified here — confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter into.
    /// * `tensor` - The tensor to scatter into.
    /// * `indices` - The indices to scatter into.
    /// * `value` - The values to scatter.
    ///
    /// # Returns
    ///
    /// The tensor with the scattered elements.
    fn float_scatter(
        dim: usize,
        tensor: FloatTensor<B>,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
448
    /// Selects tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
461
    /// Assigns the given value to the elements selected along the given dimension by the
    /// given indices.
    ///
    /// NOTE(review): whether the value overwrites or accumulates onto the selected elements
    /// is not specified here — confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_select_assign(
        tensor: FloatTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
481
    /// Selects the tensor elements corresponding to the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    fn float_slice(tensor: FloatTensor<B>, slices: &[crate::Slice]) -> FloatTensor<B>;
493
    /// Assigns the given value to the elements selected by the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices selecting the elements to overwrite.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_slice_assign(
        tensor: FloatTensor<B>,
        slices: &[crate::Slice],
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
510
    /// Updates the given tensor with elements of the value tensor where the mask is true.
    ///
    /// Where the mask is false, the elements of `tensor` are kept.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The tensor providing the replacement for the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_mask_where(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
527
    /// Updates the given tensor with a scalar value where the mask is true.
    ///
    /// Where the mask is false, the elements of `tensor` are kept. The default clamp
    /// implementations ([`Self::float_clamp_min`], [`Self::float_clamp_max`]) rely on this
    /// operation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The scalar to assign to the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_mask_fill(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatElem<B>,
    ) -> FloatTensor<B>;
544
    /// Element-wise equality comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs == rhs`.
    fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
556
557    /// Element-wise non-equality comparison.
558    ///
559    /// # Arguments
560    ///
561    /// * `lhs` - The left-hand side tensor.
562    /// * `rhs` - The right-hand side tensor.
563    ///
564    /// # Returns
565    ///
566    /// A boolean tensor with the result of the comparison.
567    fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {
568        let equal_tensor = B::float_equal(lhs, rhs);
569        B::bool_not(equal_tensor)
570    }
571
    /// Element-wise equality comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs == rhs`.
    fn float_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
583
584    /// Element-wise non-equality comparison with a scalar.
585    ///
586    /// # Arguments
587    ///
588    /// * `lhs` - The left-hand side tensor.
589    /// * `rhs` - The right-hand side scalar.
590    ///
591    /// # Returns
592    ///
593    /// A boolean tensor with the result of the comparison.
594    fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B> {
595        let equal_tensor = B::float_equal_elem(lhs, rhs);
596        B::bool_not(equal_tensor)
597    }
598
    /// Greater than comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs > rhs`.
    fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
610
    /// Greater than comparison of a tensor and a scalar.
    ///
    /// The default [`Self::float_clamp_max`] implementation uses this to find the elements
    /// to replace.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs > rhs`.
    fn float_greater_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
622
    /// Greater than or equal comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs >= rhs`.
    fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
634
    /// Greater than or equal comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs >= rhs`.
    fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
646
    /// Less than comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs < rhs`.
    fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
658
    /// Less than comparison of a tensor and a scalar.
    ///
    /// The default [`Self::float_clamp_min`] implementation uses this to find the elements
    /// to replace.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs < rhs`.
    fn float_lower_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
670
    /// Less than or equal comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs <= rhs`.
    fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
682
    /// Less than or equal comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison `lhs <= rhs`.
    fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
694
    /// Detaches a tensor from the computation graph.
    ///
    /// The default implementation returns the tensor unchanged.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to detach.
    ///
    /// # Returns
    ///
    /// The detached tensor.
    fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
700
    /// Sets the `require_grad` flag of a tensor.
    ///
    /// The default implementation ignores the flag and returns the tensor unchanged.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `_require_grad` - The requested flag value (unused by the default implementation).
    ///
    /// # Returns
    ///
    /// The tensor with the flag applied.
    fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
706
    /// Returns the `require_grad` flag of a tensor.
    ///
    /// The default implementation always reports `false`.
    ///
    /// # Arguments
    ///
    /// * `_tensor` - The tensor to query (unused by the default implementation).
    ///
    /// # Returns
    ///
    /// Whether the tensor requires gradients.
    fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
712
    /// Sums all elements in a tensor.
    ///
    /// The default [`Self::float_mean`] and [`Self::float_prod`] implementations are built
    /// on this reduction.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the sum of all elements in `tensor`.
    fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
723
    /// Sums all elements in a tensor along a dimension.
    ///
    /// NOTE(review): presumably the reduced dimension is kept with size 1 — confirm.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension along which to sum.
    ///
    /// # Returns
    ///
    /// A tensor with the sum of all elements in `tensor` along `dim`.
    fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
735
736    /// Product of all elements in a tensor.
737    ///
738    /// # Arguments
739    ///
740    /// * `tensor` - The tensor to product.
741    ///
742    /// # Returns
743    ///
744    /// A scalar tensor with the product of all elements in `tensor`.
745    fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
746        // Product of all elements in a tensor
747        B::float_exp(B::float_sum(B::float_log(tensor)))
748    }
749
750    /// Product of all elements in a tensor along a dimension.
751    ///
752    /// # Arguments
753    ///
754    /// * `tensor` - The tensor to product.
755    ///
756    /// # Returns
757    ///
758    /// A tensor with the product of all elements in `tensor` along `dim`.
759    fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
760        // Product of all elements in a tensor along a dimension
761        B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
762    }
763
764    /// Mean of all elements in a tensor.
765    ///
766    /// # Arguments
767    ///
768    /// * `tensor` - The tensor to mean.
769    ///
770    /// # Returns
771    ///
772    /// A scalar tensor with the mean of all elements in `tensor`.
773    fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
774        let num_elems = tensor.shape().num_elements();
775        B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
776    }
777
    /// Averages the elements of a tensor along a dimension.
    ///
    /// NOTE(review): presumably the reduced dimension is kept with size 1 — confirm.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to mean.
    /// * `dim` - The dimension along which to mean.
    ///
    /// # Returns
    ///
    /// A tensor with the mean of all elements in `tensor` along `dim`.
    fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
789
    /// Computes the (inclusive) cumulative sum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative sum of.
    /// * `dim` - The dimension along which to compute the cumulative sum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative sum
    /// of all elements up to and including that position along the dimension.
    fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
802
    /// Computes the (inclusive) cumulative product of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative product of.
    /// * `dim` - The dimension along which to compute the cumulative product.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative product
    /// of all elements up to and including that position along the dimension.
    fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
815
    /// Computes the (inclusive) cumulative minimum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative minimum of.
    /// * `dim` - The dimension along which to compute the cumulative minimum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the minimum
    /// of all elements up to and including that position along the dimension.
    fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
828
    /// Computes the (inclusive) cumulative maximum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative maximum of.
    /// * `dim` - The dimension along which to compute the cumulative maximum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the maximum
    /// of all elements up to and including that position along the dimension.
    fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
841
    /// Converts a tensor to another floating point data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target floating point data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target floating point data type.
    fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
853
    /// Returns a new tensor with the exponential of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with exponential values.
    fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
864
    /// Returns a new tensor with the natural logarithm of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
875
    /// Returns a new tensor with the logarithm of `1 + x_i` for each element `x_i`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with logarithm values of `1 + x_i`.
    fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
886
    /// Element-wise power with a float tensor exponent.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (the bases).
    /// * `rhs` - The right-hand side tensor (the exponents).
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
898
899    /// Element-wise power with an IntTensor.
900    ///
901    /// # Arguments
902    ///
903    /// * `lhs` - The left-hand side tensor.
904    /// * `rhs` - The right-hand side floatTensor.
905    ///
906    /// # Returns
907    ///
908    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
909    fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
910        Self::float_powf(lhs, B::int_into_float(rhs))
911    }
912
913    /// Raises a tensor to the power of an int scalar.
914    ///
915    /// # Backend Implementors Note
916    ///
917    /// A number of common exponent cases can be implemented with operations
918    /// which are much cheaper than generic exponentiation.
919    ///
920    /// This (`Backend` impl overridable) operation handles generic optimizations
921    /// for several common integer exponent cases; and then dispatches to
922    /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
923    /// operation to handle the generic case.
924    ///
925    /// # Arguments
926    ///
927    /// * `lhs` - The left-hand side tensor.
928    /// * `rhs` - The right-hand side scalar.
929    ///
930    /// # Returns
931    ///
932    /// The elements of `lhs` raised to the value of `rhs`.
933    fn float_powi_scalar(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
934        let exp = rhs.elem::<i32>();
935        match exp {
936            0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
937            1 => lhs,
938            2 => B::float_mul(lhs.clone(), lhs),
939            -1 => Self::float_recip(lhs),
940            -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
941            _ => Self::float_powi_scalar_impl(lhs, rhs),
942        }
943    }
944
945    /// Raises a tensor to the power of an int scalar.
946    ///
947    /// # Backend Implementors Note
948    ///
949    /// This is the generic implementation of integer exponentiation
950    /// called by [`Self::float_powi_scalar`] in the fallback case.
951    ///
952    /// As a general rule, this should not be called directly.
953    ///
954    /// # Arguments
955    ///
956    /// * `lhs` - The left-hand side tensor.
957    /// * `rhs` - The right-hand side scalar.
958    ///
959    /// # Returns
960    ///
961    /// The elements of `lhs` raised to the value of `rhs`.
962    fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
963        // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.
964        Self::float_powf_scalar_impl(lhs, rhs.elem::<f32>())
965    }
966
967    /// Returns a new tensor with values raised to the power of float `value`.
968    ///
969    /// # Backend Implementors Note
970    ///
971    /// This (`Backend` impl overridable) operation dispatches integer exponentiation
972    /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
973    /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
974    /// operation to handle the generic case.
975    ///
976    /// # Arguments
977    ///
978    /// * `tensor` - The tensor to exponentiate.
979    /// * `value` - The exponent.
980    ///
981    /// # Returns
982    ///
983    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
984    fn float_powf_scalar(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B> {
985        if num_traits::Float::floor(value) == value {
986            // When the exponent is an integer, use the integer exponentiation implementation.
987            let exp = B::IntElem::from_elem(value as i32);
988            Self::float_powi_scalar(tensor, exp)
989        } else {
990            Self::float_powf_scalar_impl(tensor, value)
991        }
992    }
993
    /// Returns a new tensor with values raised to the power of float `value`.
    ///
    /// # Backend Implementors Note
    ///
    /// This is the generic implementation of float-scalar exponentiation
    /// called by [`Self::float_powf_scalar`] in the fallback case. The default
    /// [`Self::float_powi_scalar_impl`] also delegates to it.
    ///
    /// This is the minimal required support a `Backend` must implement
    /// for exponentiation.
    ///
    /// As a general rule, this should not be called directly.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B>;
1015
    /// Returns a new tensor with the element-wise square root of the input.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
1026
    /// Returns a new tensor with the element-wise absolute value of the input.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1037
    /// Returns a new tensor with the element-wise cosine of the input.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1048
    /// Returns a new tensor with the element-wise sine of the input.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1059
1060    /// Returns a new tensor with tangent values.
1061    ///
1062    /// # Arguments
1063    ///
1064    /// * `tensor` - The tensor to take the tangent of.
1065    ///
1066    /// # Returns
1067    ///
1068    /// A tensor with the same shape as `tensor` with tangent values.
1069    fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B> {
1070        let sin = B::float_sin(tensor.clone());
1071        let cos = B::float_cos(tensor);
1072        B::float_div(sin, cos)
1073    }
1074
1075    /// Returns a new tensor with hyperbolic cosine values.
1076    ///
1077    /// # Arguments
1078    ///
1079    /// * `tensor` - The tensor to take the hyperbolic cosine of.
1080    ///
1081    /// # Returns
1082    ///
1083    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
1084    fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1085        // cosh = ( e^x + e^(-x) ) / 2
1086        let e_x = B::float_exp(tensor.clone());
1087        let e_neg_x = B::float_exp(B::float_neg(tensor));
1088        let num = B::float_add(e_x, e_neg_x); // e^x + e^(-x)
1089        B::float_div_scalar(num, 2.0.elem())
1090    }
1091
1092    /// Returns a new tensor with hyperbolic sine values.
1093    ///
1094    /// # Arguments
1095    ///
1096    /// * `tensor` - The tensor to take the hyperbolic sine of.
1097    ///
1098    /// # Returns
1099    ///
1100    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1101    fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1102        // sinh = ( e^x - e^(-x) ) / 2
1103        let e_x = B::float_exp(tensor.clone());
1104        let e_neg_x = B::float_exp(B::float_neg(tensor));
1105        let num = B::float_sub(e_x, e_neg_x); // e^x - e^(-x)
1106        B::float_div_scalar(num, 2.0.elem())
1107    }
1108
1109    /// Returns a new tensor with hyperbolic tangent values.
1110    ///
1111    /// # Arguments
1112    ///
1113    /// * `tensor` - The tensor to take the hyperbolic tangent of.
1114    ///
1115    /// # Returns
1116    ///
1117    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1118    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1119        let sinh = B::float_sinh(tensor.clone());
1120        let cosh = B::float_cosh(tensor);
1121        B::float_div(sinh, cosh)
1122    }
1123
    /// Returns a new tensor with values rounded to the nearest integer.
    ///
    /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
1137
    /// Returns a new tensor with floored values.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1148
    /// Returns a new tensor with ceiled values.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be ceiled.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with ceiled values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1159
    /// Returns a new tensor with truncated values.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be truncated.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with truncated values.
    fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
1170
    /// Returns a new tensor with the element-wise error function of the input.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1181
1182    /// Concatenates tensors along a dimension.
1183    ///
1184    /// # Arguments
1185    ///
1186    /// * `tensors` - The tensors to concatenate.
1187    /// * `dim` - The dimension along which to concatenate.
1188    ///
1189    /// # Returns
1190    ///
1191    /// A tensor with the concatenated tensors along `dim`.
1192    fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1193        cat_with_slice_assign::<B, Float>(
1194            tensors.into_iter().map(TensorPrimitive::Float).collect(),
1195            dim,
1196        )
1197        .tensor()
1198    }
1199
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// This is a required method: the trait provides no default implementation.
    /// The default `float_max_dim` and `float_max_dim_with_indices` are built on it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1211
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// This is a required method: the trait provides no default implementation.
    /// The default `float_min_dim` and `float_min_dim_with_indices` are built on it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1223
1224    /// Gets the maximum element of a tensor.
1225    ///
1226    /// # Arguments
1227    ///
1228    /// * `tensor` - The tensor to get the maximum elements of.
1229    ///
1230    /// # Returns
1231    ///
1232    /// A tensor with the maximum element of `tensor`.
1233    fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1234        let shape = tensor.shape();
1235        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1236
1237        B::float_max_dim(tensor, 0)
1238    }
1239
1240    /// Gets the maximum elements of a tensor along an axis.
1241    ///
1242    /// # Arguments
1243    ///
1244    /// * `tensor` - The tensor to get the maximum elements of.
1245    /// * `dim` - The dimension along which to get the maximum elements.
1246    ///
1247    /// # Returns
1248    ///
1249    /// A tensor with the maximum elements of `tensor` along `dim`.
1250    fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1251        let index = B::float_argmax(tensor.clone(), dim);
1252
1253        B::float_gather(dim, tensor, index)
1254    }
1255
1256    /// Gets the maximum elements of a tensor along an axis and their indices.
1257    ///
1258    /// # Arguments
1259    ///
1260    /// * `tensor` - The tensor to get the maximum elements of.
1261    /// * `dim` - The dimension along which to get the maximum elements.
1262    ///
1263    /// # Returns
1264    ///
1265    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1266    fn float_max_dim_with_indices(
1267        tensor: FloatTensor<B>,
1268        dim: usize,
1269    ) -> (FloatTensor<B>, IntTensor<B>) {
1270        let index = B::float_argmax(tensor.clone(), dim);
1271        let values = B::float_gather(dim, tensor, index.clone());
1272
1273        (values, index)
1274    }
1275
1276    /// Gets the minimum element of a tensor.
1277    ///
1278    /// # Arguments
1279    ///
1280    /// * `tensor` - The tensor to get the minimum elements of.
1281    ///
1282    /// # Returns
1283    ///
1284    /// A tensor with the minimum element of `tensor`.
1285    fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1286        let shape = tensor.shape();
1287        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1288
1289        B::float_min_dim(tensor, 0)
1290    }
1291
1292    /// Gets the minimum elements of a tensor along an axis.
1293    ///
1294    /// # Arguments
1295    ///
1296    /// * `tensor` - The tensor to get the minimum elements of.
1297    /// * `dim` - The dimension along which to get the minimum elements.
1298    ///
1299    /// # Returns
1300    ///
1301    /// A tensor with the minimum elements of `tensor` along `dim`.
1302    fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1303        let index = B::float_argmin(tensor.clone(), dim);
1304
1305        B::float_gather(dim, tensor, index)
1306    }
1307
1308    /// Gets the minimum elements of a tensor along an axis and their indices.
1309    ///
1310    /// # Arguments
1311    ///
1312    /// * `tensor` - The tensor to get the minimum elements of.
1313    /// * `dim` - The dimension along which to get the minimum elements.
1314    ///
1315    /// # Returns
1316    ///
1317    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1318    fn float_min_dim_with_indices(
1319        tensor: FloatTensor<B>,
1320        dim: usize,
1321    ) -> (FloatTensor<B>, IntTensor<B>) {
1322        let index = B::float_argmin(tensor.clone(), dim);
1323        let values = B::float_gather(dim, tensor, index.clone());
1324
1325        (values, index)
1326    }
1327
1328    /// Gets the maximum absolute element of a tensor.
1329    ///
1330    /// # Arguments
1331    ///
1332    /// * `tensor` - The tensor to get the maximum elements of.
1333    ///
1334    /// # Returns
1335    ///
1336    /// A tensor with the maximum element of `tensor`.
1337    fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1338        let shape = tensor.shape();
1339        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1340
1341        B::float_max_abs_dim(tensor, 0)
1342    }
1343
1344    /// Gets the maximum absolute elements of a tensor along an axis.
1345    ///
1346    /// # Arguments
1347    ///
1348    /// * `tensor` - The tensor to get the maximum elements of.
1349    /// * `dim` - The dimension along which to get the maximum elements.
1350    ///
1351    /// # Returns
1352    ///
1353    /// A tensor with the maximum elements of `tensor` along `dim`.
1354    fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1355        B::float_max_dim(B::float_abs(tensor), dim)
1356    }
1357
1358    /// Tests if any element in the float `tensor` evaluates to True.
1359    ///
1360    /// # Arguments
1361    ///
1362    /// * `tensor` - The tensor to test.
1363    ///
1364    /// # Returns
1365    ///
1366    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1367    fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {
1368        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1369        let bool_tensor = B::bool_not(bool_tensor);
1370        let sum = B::float_sum(B::bool_into_float(bool_tensor));
1371        B::float_greater_elem(sum, 0.0f32.elem())
1372    }
1373
1374    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1375    ///
1376    /// # Arguments
1377    ///
1378    /// * `tensor` - The tensor to test.
1379    /// * `dim` - The axis along which to test.
1380    ///
1381    /// # Returns
1382    ///
1383    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1384    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1385    /// input evaluates to True, False otherwise.
1386    fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1387        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1388        let bool_tensor = B::bool_not(bool_tensor);
1389        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1390        B::float_greater_elem(sum, 0.0f32.elem())
1391    }
1392
1393    /// Tests if all elements in the float `tensor` evaluate to True.
1394    ///
1395    /// # Arguments
1396    ///
1397    /// * `tensor` - The tensor to test.
1398    ///
1399    /// # Returns
1400    ///
1401    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1402    /// evaluate to True, False otherwise.
1403    fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
1404        let num_elems = tensor.shape().num_elements();
1405        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1406        let bool_tensor = B::bool_not(bool_tensor);
1407        let sum = B::float_sum(B::bool_into_float(bool_tensor));
1408        B::float_equal_elem(sum, (num_elems as f32).elem())
1409    }
1410
1411    /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1412    ///
1413    /// # Arguments
1414    ///
1415    /// * `tensor` - The tensor to test.
1416    /// * `dim` - The axis along which to test.
1417    ///
1418    /// # Returns
1419    ///
1420    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1421    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1422    /// evaluates to True, False otherwise.
1423    fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1424        let num_elems = tensor.shape().dims[dim];
1425        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1426        let bool_tensor = B::bool_not(bool_tensor);
1427        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1428        B::float_equal_elem(sum, (num_elems as f32).elem())
1429    }
1430
1431    /// Returns the signs of the float `tensor`.
1432    ///
1433    /// # Arguments
1434    ///
1435    /// * `tensor` - The tensor to extract the signs from.
1436    ///
1437    /// # Returns
1438    ///
1439    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1440    fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1441        let zeros = B::float_zeros(
1442            tensor.shape(),
1443            &B::float_device(&tensor),
1444            tensor.dtype().into(),
1445        );
1446        let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
1447        let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
1448
1449        let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1450        result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
1451        result
1452    }
1453
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The broadcast tensor.
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1456
1457    /// Sort the elements of the input `tensor` by value in along a given dimension.
1458    ///
1459    /// This sort is unstable (i.e., may reorder equal elements).
1460    ///
1461    /// # Arguments
1462    ///
1463    /// * `tensor` - The input tensor.
1464    /// * `dim` - The axis along which to sort.
1465    /// * `descending` - The sorting order.
1466    ///
1467    /// # Returns
1468    ///
1469    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1470    fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1471        sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1472    }
1473
1474    /// Sort the elements of the input `tensor` by value in along a given dimension.
1475    ///
1476    /// This sort is unstable (i.e., may reorder equal elements).
1477    ///
1478    /// # Arguments
1479    ///
1480    /// * `tensor` - The input tensor.
1481    /// * `dim` - The axis along which to sort.
1482    /// * `descending` - The sorting order.
1483    ///
1484    /// # Returns
1485    ///
1486    /// A tensor with the same shape as the input tensor and corresponding indices, where
1487    /// the elements are sorted by value and the indices map back to the original input tensor.
1488    fn float_sort_with_indices(
1489        tensor: FloatTensor<B>,
1490        dim: usize,
1491        descending: bool,
1492    ) -> (FloatTensor<B>, IntTensor<B>) {
1493        let (values, indices) =
1494            sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
1495        (values.tensor(), indices)
1496    }
1497
1498    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1499    ///
1500    /// This sort is unstable (i.e., may reorder equal elements).
1501    ///
1502    /// # Arguments
1503    ///
1504    /// * `tensor` - The input tensor.
1505    /// * `dim` - The axis along which to sort.
1506    /// * `descending` - The sorting order.
1507    ///
1508    /// # Returns
1509    ///
1510    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1511    fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1512        argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
1513    }
1514
1515    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
1516    /// using the given locations in [-1, 1].
1517    ///
1518    /// Interpolation is bilinear.
1519    /// Padding is border: out of bounds locations will be clamped to the nearest border
1520    ///
1521    /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in)
1522    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
1523    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
1524    /// * `method` - How to interpolate between samples
1525    ///
1526    /// # Returns
1527    ///
1528    /// A tensor with shape (N, C, H_out, W_out)
1529    fn float_grid_sample_2d(
1530        tensor: FloatTensor<B>,
1531        grid: FloatTensor<B>,
1532        method: InterpolateMode,
1533    ) -> FloatTensor<B> {
1534        match method {
1535            InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>(tensor, grid),
1536            _ => todo!("Default implementation for grid_sample_2d with {method:?} unimplemented"),
1537        }
1538    }
1539
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - The selected dim.
    /// * `size` - The size of each unfolded window.
    /// * `step` - The step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
    -> FloatTensor<B>;
1559
1560    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1561    ///
1562    /// # Returns
1563    ///
1564    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1565    fn float_is_nan(tensor: FloatTensor<B>) -> BoolTensor<B> {
1566        // Check if the input tensor is NaN by comparing it to itself
1567        // NaN is the only value that is not equal to itself
1568        B::float_not_equal(tensor.clone(), tensor)
1569    }
1570
1571    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1572    ///
1573    /// # Returns
1574    ///
1575    /// A boolean tensor where `true` indicates that the value is infinite
1576    fn float_is_inf(tensor: FloatTensor<B>) -> BoolTensor<B> {
1577        B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.elem())
1578    }
1579}