Skip to main content

burn_backend/backend/ops/
tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_ref;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::sort::{argsort, sort, sort_with_indices};
5use crate::ops::GridSampleOptions;
6use crate::tensor::{BoolTensor, Device, Float, FloatTensor, IntTensor};
7use crate::{Backend, Distribution, TensorData, get_device_settings};
8use crate::{ExecutionError, Scalar, TensorMetadata, TensorPrimitive};
9use alloc::vec::Vec;
10use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
11
12/// Operations on float tensors.
13pub trait FloatTensorOps<B: Backend> {
14    /// Creates a new tensor from the data structure.
15    ///
16    /// # Arguments
17    ///
18    /// * `data` - The data structure.
19    /// * `device` - The device to create the tensor on.
20    ///
21    /// # Returns
22    ///
23    /// The tensor with the given data.
24    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
25
26    /// Creates a new tensor with random values.
27    ///
28    /// # Arguments
29    ///
30    /// * `shape` - The shape of the tensor.
31    /// * `distribution` - The distribution to sample from.
32    /// * `device` - The device to create the tensor on.
33    /// * `dtype` - The target data type.
34    ///
35    /// # Returns
36    ///
37    /// The tensor with the given shape and random values.
38    fn float_random(
39        shape: Shape,
40        distribution: Distribution,
41        device: &Device<B>,
42        dtype: FloatDType,
43    ) -> FloatTensor<B>;
44
45    /// Creates a new tensor with zeros.
46    ///
47    /// # Arguments
48    ///
49    /// * `shape` - The shape of the tensor.
50    /// * `device` - The device to create the tensor on.
51    /// * `dtype` - The target data type.
52    ///
53    /// # Returns
54    ///
55    /// The tensor with the given shape and zeros.
56    fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
57        Self::float_from_data(TensorData::full_dtype(shape, 0., dtype.into()), device)
58    }
59
60    /// Creates a new tensor with ones.
61    ///
62    /// # Arguments
63    ///
64    /// * `shape` - The shape of the tensor.
65    /// * `device` - The device to create the tensor on.
66    /// * `dtype` - The target data type.
67    ///
68    /// # Returns
69    ///
70    /// The tensor with the given shape and ones.
71    fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
72        Self::float_from_data(TensorData::full_dtype(shape, 1., dtype.into()), device)
73    }
74
75    /// Creates a tensor filled with given value.
76    ///
77    /// # Arguments
78    ///
79    /// * `shape` - The shape of the tensor.
80    /// * `fill_value` - The value with which to fill the tensor.
81    /// * `device` - The device to create the tensor on.
82    /// * `dtype` - The target data type.
83    ///
84    /// # Returns
85    ///
86    /// The tensor filled with given value
87    fn float_full(
88        shape: Shape,
89        fill_value: Scalar,
90        device: &Device<B>,
91        dtype: FloatDType,
92    ) -> FloatTensor<B> {
93        Self::float_from_data(
94            TensorData::full_dtype(shape, fill_value, dtype.into()),
95            device,
96        )
97    }
98
99    /// Converts the tensor to a data structure.
100    ///
101    /// # Arguments
102    ///
103    /// * `tensor` - The tensor.
104    ///
105    /// # Returns
106    ///
107    /// The data structure with the tensor's data.
108    fn float_into_data(
109        tensor: FloatTensor<B>,
110    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
111
112    /// Gets the device of the tensor.
113    ///
114    /// # Arguments
115    ///
116    /// * `tensor` - The tensor.
117    ///
118    /// # Returns
119    ///
120    /// The device of the tensor.
121    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
122
123    /// Moves the tensor to the given device.
124    ///
125    /// # Arguments
126    ///
127    /// * `tensor` - The tensor.
128    /// * `device` - The device to move the tensor to.
129    ///
130    /// # Returns
131    ///
132    /// The tensor on the given device.
133    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
134
135    /// Converts float tensor to int tensor.
136    ///
137    /// # Arguments
138    ///
139    /// * `tensor` - The tensor.
140    /// * `out_dtype` - The output tensor dtype.
141    ///
142    /// # Returns
143    ///
144    /// The int tensor with the same data as the float tensor.
145    fn float_into_int(tensor: FloatTensor<B>, out_dtype: IntDType) -> IntTensor<B>;
146
147    /// Creates an empty tensor with the given shape.
148    ///
149    /// # Arguments
150    ///
151    /// * `shape` - The shape of the tensor.
152    /// * `device` - The device to create the tensor on.
153    /// * `dtype` - The target data type.
154    ///
155    /// # Returns
156    ///
157    /// The empty tensor with the given shape.
158    fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
159
160    /// Repeat the tensor along the given dimension.
161    ///
162    /// # Arguments
163    ///
164    /// * `tensor` - The tensor.
165    /// * `dim` - The dimension to repeat.
166    /// * `times` - The number of times to repeat the dimension.
167    ///
168    /// # Returns
169    ///
170    /// The tensor with the given dimension repeated.
171    fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
172        repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
173    }
174
175    /// Adds two tensors together.
176    ///
177    /// # Arguments
178    ///
179    /// * `lhs` - The left-hand side tensor.
180    /// * `rhs` - The right-hand side tensor.
181    ///
182    /// # Returns
183    ///
184    /// The result of adding the two tensors together.
185    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
186
187    /// Adds a scalar to a tensor.
188    ///
189    /// # Arguments
190    ///
191    /// * `lhs` - The left-hand side tensor.
192    /// * `rhs` - The right-hand side scalar.
193    ///
194    /// # Returns
195    ///
196    /// The result of adding the scalar to the tensor.
197    fn float_add_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
198
199    /// Clamps a tensor under a minimum value.
200    ///
201    /// # Arguments
202    ///
203    /// * `tensor` - The tensor to clamp.
204    /// * `min` - The minimum value.
205    ///
206    /// # Returns
207    ///
208    /// The clamped tensor.
209    fn float_clamp_min(tensor: FloatTensor<B>, min: Scalar) -> FloatTensor<B> {
210        let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
211        let mask = Self::float_lower_elem(tensor.clone(), min, dtype);
212        B::float_mask_fill(tensor, mask, min)
213    }
214
215    /// Clamps a tensor over a maximum value.
216    ///
217    /// # Arguments
218    ///
219    /// * `tensor` - The tensor to clamp.
220    /// * `max` - The maximum value.
221    ///
222    /// # Returns
223    ///
224    /// The clamped tensor.
225    fn float_clamp_max(tensor: FloatTensor<B>, max: Scalar) -> FloatTensor<B> {
226        let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
227        let mask = Self::float_greater_elem(tensor.clone(), max, dtype);
228        B::float_mask_fill(tensor, mask, max)
229    }
230
231    /// Clamps a tensor between a minimum and maximum value.
232    ///
233    /// # Arguments
234    ///
235    /// * `tensor` - The tensor to clamp.
236    /// * `min` - The minimum value.
237    /// * `max` - The maximum value.
238    ///
239    /// # Returns
240    ///
241    /// The clamped tensor.
242    fn float_clamp(tensor: FloatTensor<B>, min: Scalar, max: Scalar) -> FloatTensor<B> {
243        // Default implementation
244        Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
245    }
246
247    /// Subtracts two tensors.
248    ///
249    /// # Arguments
250    ///
251    /// * `lhs` - The left-hand side tensor.
252    /// * `rhs` - The right-hand side tensor.
253    ///
254    /// # Returns
255    ///
256    /// The result of subtracting the two tensors.
257    fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
258
259    /// Subtracts a scalar from a tensor.
260    ///
261    /// # Arguments
262    ///
263    /// * `lhs` - The left-hand side tensor.
264    /// * `rhs` - The right-hand side scalar.
265    ///
266    /// # Returns
267    ///
268    /// The result of subtracting the scalar from the tensor.
269    fn float_sub_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
270
271    /// Multiplies two tensors together element-wise.
272    fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
273
274    /// Multiplies a tensor by a scalar.
275    ///
276    /// # Arguments
277    ///
278    /// * `lhs` - The left-hand side tensor.
279    /// * `rhs` - The right-hand side scalar.
280    ///
281    /// # Returns
282    ///
283    /// The result of multiplying the tensor by the scalar.
284    fn float_mul_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
285
286    /// Divides two tensors element-wise.
287    ///
288    /// # Arguments
289    ///
290    /// * `lhs` - The left-hand side tensor.
291    /// * `rhs` - The right-hand side tensor.
292    ///
293    /// # Returns
294    ///
295    /// The result of dividing the two tensors.
296    fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
297
298    /// Divides a tensor by a scalar.
299    ///
300    /// # Arguments
301    ///
302    /// * `lhs` - The left-hand side tensor.
303    /// * `rhs` - The right-hand side scalar.
304    ///
305    /// # Returns
306    ///
307    /// The result of dividing the tensor by the scalar.
308    fn float_div_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
309
310    /// Computes the remainder of division between two tensors element-wise.
311    ///
312    /// # Arguments
313    ///
314    /// * `lhs` - The left-hand side tensor.
315    /// * `rhs` - The right-hand side tensor.
316    ///
317    /// # Returns
318    ///
319    /// The element-wise remainder when dividing `lhs` by `rhs`.
320    fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
321
322    /// Computes the modulus of a tensor given a scalar.
323    ///
324    /// # Arguments
325    /// * `lhs` - The left-hand side tensor.
326    /// * `rhs` - The right-hand side scalar.
327    ///
328    /// # Returns
329    ///
330    /// The result of applying the modulus of the scalar to the tensor.
331    fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
332
333    /// Multiplies two tensors together using matrix multiplication.
334    ///
335    /// # Arguments
336    ///
337    /// * `lhs` - The left-hand side tensor.
338    /// * `rhs` - The right-hand side tensor.
339    ///
340    /// # Returns
341    ///
342    /// The result of multiplying the two tensors together using matrix multiplication.
343    fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
344
345    /// Computes the cross product of two tensors along a given dimension.
346    ///
347    /// # Arguments
348    ///
349    /// * `lhs` - The left-hand side tensor.
350    /// * `rhs` - The right-hand side tensor.
351    /// * `dim` - The dimension to compute the cross product along.
352    ///
353    /// # Returns
354    ///
355    /// The cross product of the two tensors.
356    fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
357
358    /// Negates a tensor element-wise.
359    fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
360        Self::float_mul_scalar(tensor, (-1f32).into())
361    }
362
363    /// Calculates the reciprocals element-wise
364    fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
365
366    /// Transposes a tensor.
367    ///
368    /// # Arguments
369    ///
370    /// * `tensor` - The tensor to transpose.
371    ///
372    /// # Returns
373    ///
374    /// The transposed tensor.
375    fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
376        let ndims = tensor.shape().num_dims();
377        Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
378    }
379
380    /// Swaps two dimensions of a tensor.
381    ///
382    /// # Arguments
383    ///
384    /// * `tensor` - The tensor to swap the dimensions of.
385    /// * `dim1` - The first dimension to swap.
386    /// * `dim2` - The second dimension to swap.
387    ///
388    /// # Returns
389    ///
390    /// The tensor with the dimensions swapped.
391    fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
392
393    /// Permutes the dimensions of a tensor.
394    ///
395    /// # Arguments
396    ///
397    /// * `tensor` - The tensor to permute the dimensions of.
398    /// * `axes` - The new order of the dimensions.
399    /// # Returns
400    ///
401    /// The tensor with the dimensions permuted.
402    fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
403
404    /// Reverse the order of elements in a tensor along the given axes.
405    ///
406    /// # Arguments
407    ///
408    /// * `tensor` - The tensor to reverse.
409    /// * `axes` - The axes to reverse.
410    ///
411    /// The tensor with the elements reversed.
412    fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
413
414    /// Reshapes a tensor.
415    ///
416    /// # Arguments
417    ///
418    /// * `tensor` - The tensor to reshape.
419    /// * `shape` - The new shape of the tensor.
420    ///
421    /// # Returns
422    ///
423    /// The tensor with the new shape.
424    fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
425
426    /// Gather elements from a tensor.
427    ///
428    /// # Arguments
429    ///
430    /// * `dim` - The dimension to gather from.
431    /// * `tensor` - The tensor to gather from.
432    /// * `indices` - The indices to gather.
433    ///
434    /// # Returns
435    ///
436    /// The gathered elements.
437    fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
438
439    /// Scatter elements into a tensor using sum reduction.
440    ///
441    /// # Arguments
442    ///
443    /// * `dim` - The dimension to scatter into.
444    /// * `tensor` - The tensor to scatter into.
445    /// * `indices` - The indices to scatter into.
446    /// * `value` - The value to scatter.
447    ///
448    /// # Returns
449    ///
450    /// The tensor with the scattered elements.
451    fn float_scatter_add(
452        dim: usize,
453        tensor: FloatTensor<B>,
454        indices: IntTensor<B>,
455        value: FloatTensor<B>,
456    ) -> FloatTensor<B>;
457
458    /// Select tensor elements along the given dimension corresponding for the given indices.
459    ///
460    /// # Arguments
461    ///
462    /// * `tensor` - The tensor to select from.
463    /// * `dim` - The dimension to select from.
464    /// * `indices` - The indices to select.
465    ///
466    /// # Returns
467    ///
468    /// The selected elements.
469    fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
470
471    /// Assign the selected elements along the given dimension corresponding for the given indices
472    /// to the given value using sum reduction.
473    ///
474    /// # Arguments
475    ///
476    /// * `tensor` - The tensor to select from.
477    /// * `dim` - The dimension to select from.
478    /// * `indices` - The indices to select.
479    /// * `value` - The value to assign.
480    ///
481    /// # Returns
482    ///
483    /// The tensor with the selected elements assigned to the given value.
484    fn float_select_add(
485        tensor: FloatTensor<B>,
486        dim: usize,
487        indices: IntTensor<B>,
488        value: FloatTensor<B>,
489    ) -> FloatTensor<B>;
490
491    /// Select tensor elements corresponding to the given slices.
492    ///
493    /// # Arguments
494    ///
495    /// * `tensor` - The tensor to select from.
496    /// * `slices` - The slices specifying ranges and steps for each dimension.
497    ///
498    /// # Returns
499    ///
500    /// The selected elements in a new tensor.
501    ///
502    /// # Note
503    ///
504    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
505    /// be passed to this method. Backend implementations do not need to handle empty slices.
506    fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;
507
508    /// Assign the selected elements corresponding to the given slices to the given value.
509    ///
510    /// # Arguments
511    ///
512    /// * `tensor` - The tensor to select from.
513    /// * `ranges` - The ranges to select.
514    /// * `value` - The value to assign.
515    ///
516    /// # Returns
517    ///
518    /// The tensor with the selected elements assigned to the given value.
519    ///
520    /// # Note
521    ///
522    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
523    /// high-level tensor API and will not be passed to this method. Backend implementations do
524    /// not need to handle empty slice assignments.
525    fn float_slice_assign(
526        tensor: FloatTensor<B>,
527        slices: &[Slice],
528        value: FloatTensor<B>,
529    ) -> FloatTensor<B>;
530
531    /// Update the given tensor with the value tensor where the mask is true.
532    ///
533    /// # Arguments
534    ///
535    /// * `tensor` - The tensor to select from.
536    /// * `mask` - The boolean mask to select with.
537    /// * `value` - The value to assign to the selected elements from the value tensor.
538    ///
539    /// # Returns
540    ///
541    /// The tensor with the selected elements assigned to the given value.
542    fn float_mask_where(
543        tensor: FloatTensor<B>,
544        mask: BoolTensor<B>,
545        value: FloatTensor<B>,
546    ) -> FloatTensor<B>;
547
548    /// Update the given tensor with the value where the mask is true.
549    ///
550    /// # Arguments
551    ///
552    /// * `tensor` - The tensor to select from.
553    /// * `mask` - The boolean mask to select with.
554    /// * `value` - The value to assign to the selected elements.
555    ///
556    /// # Returns
557    ///
558    /// The tensor with the selected elements assigned to the given value.
559    fn float_mask_fill(
560        tensor: FloatTensor<B>,
561        mask: BoolTensor<B>,
562        value: Scalar,
563    ) -> FloatTensor<B>;
564
565    /// Equal comparison of two tensors.
566    ///
567    /// # Arguments
568    ///
569    /// * `lhs` - The left-hand side tensor.
570    /// * `rhs` - The right-hand side tensor.
571    /// * `out_dtype` - The output tensor dtype.
572    ///
573    /// # Returns
574    ///
575    /// A boolean tensor with the result of the comparison.
576    fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
577    -> BoolTensor<B>;
578
579    /// Element-wise non-equality comparison.
580    ///
581    /// # Arguments
582    ///
583    /// * `lhs` - The left-hand side tensor.
584    /// * `rhs` - The right-hand side tensor.
585    /// * `out_dtype` - The output tensor dtype.
586    ///
587    /// # Returns
588    ///
589    /// A boolean tensor with the result of the comparison.
590    fn float_not_equal(
591        lhs: FloatTensor<B>,
592        rhs: FloatTensor<B>,
593        out_dtype: BoolDType,
594    ) -> BoolTensor<B> {
595        let equal_tensor = B::float_equal(lhs, rhs, out_dtype);
596        B::bool_not(equal_tensor)
597    }
598
599    /// Equal comparison of a tensor and a scalar.
600    ///
601    /// # Arguments
602    ///
603    /// * `lhs` - The left-hand side tensor.
604    /// * `rhs` - The right-hand side scalar.
605    /// * `out_dtype` - The output tensor dtype.
606    ///
607    /// # Returns
608    ///
609    /// A boolean tensor with the result of the comparison.
610    fn float_equal_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
611
612    /// Element-wise non-equality comparison with a scalar.
613    ///
614    /// # Arguments
615    ///
616    /// * `lhs` - The left-hand side tensor.
617    /// * `rhs` - The right-hand side scalar.
618    /// * `out_dtype` - The output tensor dtype.
619    ///
620    /// # Returns
621    ///
622    /// A boolean tensor with the result of the comparison.
623    fn float_not_equal_elem(
624        lhs: FloatTensor<B>,
625        rhs: Scalar,
626        out_dtype: BoolDType,
627    ) -> BoolTensor<B> {
628        let equal_tensor = B::float_equal_elem(lhs, rhs, out_dtype);
629        B::bool_not(equal_tensor)
630    }
631
632    /// Greater than comparison of two tensors.
633    ///
634    /// # Arguments
635    ///
636    /// * `lhs` - The left-hand side tensor.
637    /// * `rhs` - The right-hand side tensor.
638    /// * `out_dtype` - The output tensor dtype.
639    ///
640    /// # Returns
641    ///
642    /// A boolean tensor with the result of the comparison.
643    fn float_greater(
644        lhs: FloatTensor<B>,
645        rhs: FloatTensor<B>,
646        out_dtype: BoolDType,
647    ) -> BoolTensor<B>;
648
649    /// Greater than comparison of a tensor and a scalar.
650    ///
651    /// # Arguments
652    ///
653    /// * `lhs` - The left-hand side tensor.
654    /// * `rhs` - The right-hand side scalar.
655    /// * `out_dtype` - The output tensor dtype.
656    ///
657    /// # Returns
658    ///
659    /// A boolean tensor with the result of the comparison.
660    fn float_greater_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
661
662    /// Greater than or equal comparison of two tensors.
663    ///
664    /// # Arguments
665    ///
666    /// * `lhs` - The left-hand side tensor.
667    /// * `rhs` - The right-hand side tensor.
668    /// * `out_dtype` - The output tensor dtype.
669    ///
670    /// # Returns
671    ///
672    /// A boolean tensor with the result of the comparison.
673    fn float_greater_equal(
674        lhs: FloatTensor<B>,
675        rhs: FloatTensor<B>,
676        out_dtype: BoolDType,
677    ) -> BoolTensor<B>;
678
679    /// Greater than or equal comparison of a tensor and a scalar.
680    ///
681    /// # Arguments
682    ///
683    /// * `lhs` - The left-hand side tensor.
684    /// * `rhs` - The right-hand side scalar.
685    /// * `out_dtype` - The output tensor dtype.
686    ///
687    /// # Returns
688    ///
689    /// A boolean tensor with the result of the comparison.
690    fn float_greater_equal_elem(
691        lhs: FloatTensor<B>,
692        rhs: Scalar,
693        out_dtype: BoolDType,
694    ) -> BoolTensor<B>;
695
696    /// Less than comparison of two tensors.
697    ///
698    /// # Arguments
699    ///
700    /// * `lhs` - The left-hand side tensor.
701    /// * `rhs` - The right-hand side tensor.
702    /// * `out_dtype` - The output tensor dtype.
703    ///
704    /// # Returns
705    ///
706    /// A boolean tensor with the result of the comparison.
707    fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
708    -> BoolTensor<B>;
709
710    /// Less than comparison of a tensor and a scalar.
711    ///
712    /// # Arguments
713    ///
714    /// * `lhs` - The left-hand side tensor.
715    /// * `rhs` - The right-hand side scalar.
716    /// * `out_dtype` - The output tensor dtype.
717    ///
718    /// # Returns
719    ///
720    /// A boolean tensor with the result of the comparison.
721    fn float_lower_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
722
723    /// Less than or equal comparison of two tensors.
724    ///
725    /// # Arguments
726    ///
727    /// * `lhs` - The left-hand side tensor.
728    /// * `rhs` - The right-hand side tensor.
729    /// * `out_dtype` - The output tensor dtype.
730    ///
731    /// # Returns
732    ///
733    /// A boolean tensor with the result of the comparison.
734    fn float_lower_equal(
735        lhs: FloatTensor<B>,
736        rhs: FloatTensor<B>,
737        out_dtype: BoolDType,
738    ) -> BoolTensor<B>;
739
740    /// Less than or equal comparison of a tensor and a scalar.
741    ///
742    /// # Arguments
743    ///
744    /// * `lhs` - The left-hand side tensor.
745    /// * `rhs` - The right-hand side scalar.
746    /// * `out_dtype` - The output tensor dtype.
747    ///
748    /// # Returns
749    ///
750    /// A boolean tensor with the result of the comparison.
751    fn float_lower_equal_elem(
752        lhs: FloatTensor<B>,
753        rhs: Scalar,
754        out_dtype: BoolDType,
755    ) -> BoolTensor<B>;
756
757    /// Detaches a tensor from the computation graph.
758    fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
759        // Should only be overridden by autodiff backends.
760        tensor
761    }
762
763    /// Sets the `require_grad` flag of a tensor.
764    fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
765        // Should only be overridden by autodiff backends.
766        tensor
767    }
768
769    /// Returns the `require_grad` flag of a tensor.
770    fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
771        // Should only be overridden by autodiff backends.
772        false
773    }
774
775    /// Sum of all elements in a tensor.
776    ///
777    /// # Arguments
778    ///
779    /// * `tensor` - The tensor to sum.
780    ///
781    /// # Returns
782    ///
783    /// A scalar tensor with the sum of all elements in `tensor`.
784    fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
785
786    /// Sum of all elements in a tensor along a dimension.
787    ///
788    /// # Arguments
789    ///
790    /// * `tensor` - The tensor to sum.
791    /// * `dim` - The dimension along which to sum.
792    ///
793    /// # Returns
794    ///
795    /// A tensor with the sum of all elements in `tensor` along `dim`.
796    fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
797
798    /// Product of all elements in a tensor.
799    ///
800    /// # Arguments
801    ///
802    /// * `tensor` - The tensor to product.
803    ///
804    /// # Returns
805    ///
806    /// A scalar tensor with the product of all elements in `tensor`.
807    fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
808        // Product of all elements in a tensor
809        B::float_exp(B::float_sum(B::float_log(tensor)))
810    }
811
812    /// Product of all elements in a tensor along a dimension.
813    ///
814    /// # Arguments
815    ///
816    /// * `tensor` - The tensor to product.
817    ///
818    /// # Returns
819    ///
820    /// A tensor with the product of all elements in `tensor` along `dim`.
821    fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
822        // Product of all elements in a tensor along a dimension
823        B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
824    }
825
826    /// Mean of all elements in a tensor.
827    ///
828    /// # Arguments
829    ///
830    /// * `tensor` - The tensor to mean.
831    ///
832    /// # Returns
833    ///
834    /// A scalar tensor with the mean of all elements in `tensor`.
835    fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
836        let num_elems = tensor.shape().num_elements() as f32;
837        B::float_div_scalar(B::float_sum(tensor), num_elems.into())
838    }
839
840    /// Mean of all elements in a tensor along a dimension.
841    ///
842    /// # Arguments
843    ///
844    /// * `tensor` - The tensor to mean.
845    /// * `dim` - The dimension along which to mean.
846    ///
847    /// # Returns
848    ///
849    /// A tensor with the mean of all elements in `tensor` along `dim`.
850    fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
851
852    /// Computes the cumulative sum of elements along a dimension.
853    ///
854    /// # Arguments
855    ///
856    /// * `tensor` - The tensor to compute the cumulative sum of.
857    /// * `dim` - The dimension along which to compute the cumulative sum.
858    ///
859    /// # Returns
860    ///
861    /// A tensor with the same shape where each element is the cumulative sum
862    /// of all elements up to and including that position along the dimension.
863    fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
864
865    /// Computes the cumulative product of elements along a dimension.
866    ///
867    /// # Arguments
868    ///
869    /// * `tensor` - The tensor to compute the cumulative product of.
870    /// * `dim` - The dimension along which to compute the cumulative product.
871    ///
872    /// # Returns
873    ///
874    /// A tensor with the same shape where each element is the cumulative product
875    /// of all elements up to and including that position along the dimension.
876    fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
877
878    /// Computes the cumulative minimum of elements along a dimension.
879    ///
880    /// # Arguments
881    ///
882    /// * `tensor` - The tensor to compute the cumulative minimum of.
883    /// * `dim` - The dimension along which to compute the cumulative minimum.
884    ///
885    /// # Returns
886    ///
887    /// A tensor with the same shape where each element is the minimum
888    /// of all elements up to and including that position along the dimension.
889    fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
890
891    /// Computes the cumulative maximum of elements along a dimension.
892    ///
893    /// # Arguments
894    ///
895    /// * `tensor` - The tensor to compute the cumulative maximum of.
896    /// * `dim` - The dimension along which to compute the cumulative maximum.
897    ///
898    /// # Returns
899    ///
900    /// A tensor with the same shape where each element is the maximum
901    /// of all elements up to and including that position along the dimension.
902    fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
903
904    /// Converts a tensor to another floating point data type.
905    ///
906    /// # Arguments
907    ///
908    /// * `tensor` - The tensor to convert.
909    /// * `dtype` - The target data type.
910    ///
911    /// # Returns
912    ///
913    /// A tensor with the same values as `tensor` but in the target floating point data type.
914    fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
915
916    /// Returns a new tensor with exponential values.
917    ///
918    /// # Arguments
919    ///
920    /// * `tensor` - The tensor to exponentiate.
921    ///
922    /// # Returns
923    ///
924    /// A tensor with the same shape as `tensor` with exponential values.
925    fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
926
927    /// Returns a new tensor with natural logarithm values.
928    ///
929    /// # Arguments
930    ///
931    /// * `tensor` - The tensor to take the logarithm of.
932    ///
933    /// # Returns
934    ///
935    /// A tensor with the same shape as `tensor` with natural logarithm values.
936    fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
937
938    /// Returns a new tensor with logarithm values of (1 + Xi).
939    ///
940    /// # Arguments
941    ///
942    /// * `tensor` - The tensor to take the logarithm of.
943    ///
944    /// # Returns
945    ///
946    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
947    fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
948
    /// Element-wise power with a float exponent tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side (base) tensor.
    /// * `rhs` - The right-hand side (exponent) tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the elements of `lhs` raised to the power of the elements of `rhs`.
    fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
960
961    /// Element-wise power with an IntTensor.
962    ///
963    /// # Arguments
964    ///
965    /// * `lhs` - The left-hand side tensor.
966    /// * `rhs` - The right-hand side floatTensor.
967    ///
968    /// # Returns
969    ///
970    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
971    fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
972        let dtype = lhs.dtype();
973        Self::float_powf(lhs, B::int_into_float(rhs, dtype.into()))
974    }
975
976    /// Raises a tensor to the power of an int scalar.
977    ///
978    /// # Backend Implementors Note
979    ///
980    /// A number of common exponent cases can be implemented with operations
981    /// which are much cheaper than generic exponentiation.
982    ///
983    /// This (`Backend` impl overridable) operation handles generic optimizations
984    /// for several common integer exponent cases; and then dispatches to
985    /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
986    /// operation to handle the generic case.
987    ///
988    /// # Arguments
989    ///
990    /// * `lhs` - The left-hand side tensor.
991    /// * `rhs` - The right-hand side scalar.
992    ///
993    /// # Returns
994    ///
995    /// The elements of `lhs` raised to the value of `rhs`.
996    fn float_powi_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
997        match rhs.elem::<i64>() {
998            0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
999            1 => lhs,
1000            2 => B::float_mul(lhs.clone(), lhs),
1001            -1 => Self::float_recip(lhs),
1002            -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
1003            _ => Self::float_powi_scalar_impl(lhs, rhs),
1004        }
1005    }
1006
1007    /// Raises a tensor to the power of an int scalar.
1008    ///
1009    /// # Backend Implementors Note
1010    ///
1011    /// This is the generic implementation of integer exponentiation
1012    /// called by [`Self::float_powi_scalar`] in the fallback case.
1013    ///
1014    /// As a general rule, this should not be called directly.
1015    ///
1016    /// # Arguments
1017    ///
1018    /// * `lhs` - The left-hand side tensor.
1019    /// * `rhs` - The right-hand side scalar.
1020    ///
1021    /// # Returns
1022    ///
1023    /// The elements of `lhs` raised to the value of `rhs`.
1024    fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
1025        // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.
1026        Self::float_powf_scalar_impl(lhs, rhs)
1027    }
1028
1029    /// Returns a new tensor with values raised to the power of float `value`.
1030    ///
1031    /// # Backend Implementors Note
1032    ///
1033    /// This (`Backend` impl overridable) operation dispatches integer exponentiation
1034    /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
1035    /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
1036    /// operation to handle the generic case.
1037    ///
1038    /// # Arguments
1039    ///
1040    /// * `tensor` - The tensor to exponentiate.
1041    /// * `value` - The exponent.
1042    ///
1043    /// # Returns
1044    ///
1045    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
1046    fn float_powf_scalar(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B> {
1047        if let Some(exp) = value.try_as_integer() {
1048            Self::float_powi_scalar(tensor, exp)
1049        } else {
1050            Self::float_powf_scalar_impl(tensor, value)
1051        }
1052    }
1053
    /// Returns a new tensor with values raised to the power of float `value`.
    ///
    /// # Backend Implementors Note
    ///
    /// This is the generic implementation of scalar exponentiation
    /// called by [`Self::float_powf_scalar`] in the fallback case.
    ///
    /// This is the minimal required support a `Backend` must implement
    /// for exponentiation.
    ///
    /// As a general rule, this should not be called directly.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B>;
1075
    /// Returns a new tensor with the square root of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the square root values.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
1086
    /// Returns a new tensor with the absolute value of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the absolute values.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1097
    /// Returns a new tensor with the cosine of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the cosine values.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1108
    /// Returns a new tensor with the sine of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the sine values.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1119
    /// Returns a new tensor with the tangent of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the tangent values.
    fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1130
    /// Returns a new tensor with the hyperbolic cosine of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the hyperbolic cosine values.
    fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1141
    /// Returns a new tensor with the hyperbolic sine of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the hyperbolic sine values.
    fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1152
    /// Returns a new tensor with the hyperbolic tangent of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the hyperbolic tangent values.
    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1163
    /// Returns a new tensor with the inverse cosine of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the inverse cosine values.
    fn float_acos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1174
    /// Returns a new tensor with the inverse hyperbolic cosine of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the inverse hyperbolic cosine values.
    fn float_acosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1185
    /// Returns a new tensor with the inverse sine of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the inverse sine values.
    fn float_asin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1196
    /// Returns a new tensor with the inverse hyperbolic sine of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the inverse hyperbolic sine values.
    fn float_asinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1207
    /// Returns a new tensor with the inverse tangent of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the inverse tangent values.
    fn float_atan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1218
    /// Returns a new tensor with the inverse hyperbolic tangent of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the inverse hyperbolic tangent values.
    fn float_atanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1229
    /// Returns a tensor with the four-quadrant inverse tangent values of `y` and `x`.
    ///
    /// Note the argument order: `lhs` carries the `y` coordinates and `rhs` the `x` coordinates.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor with y coordinates.
    /// * `rhs` - The tensor with x coordinates.
    ///
    /// # Returns
    ///
    /// A tensor with the four-quadrant inverse tangent values.
    fn float_atan2(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
1241
    /// Returns a new tensor with rounded values.
    ///
    /// Implementations should follow the
    /// [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
1255
    /// Returns a new tensor with each element floored.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1266
    /// Returns a new tensor with each element ceiled.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be ceiled.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the ceiled values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1277
    /// Returns a new tensor with each element truncated.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be truncated.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the truncated values.
    fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
1288
    /// Returns a new tensor with the error function applied to each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding the error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1299
1300    /// Concatenates tensors along a dimension.
1301    ///
1302    /// # Arguments
1303    ///
1304    /// * `tensors` - The tensors to concatenate.
1305    /// * `dim` - The dimension along which to concatenate.
1306    ///
1307    /// # Returns
1308    ///
1309    /// A tensor with the concatenated tensors along `dim`.
1310    ///
1311    /// # Note
1312    ///
1313    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
1314    /// high-level tensor API and will not be passed to this method. Backend implementations do
1315    /// not need to handle empty tensors.
1316    fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1317        cat_with_slice_assign::<B, Float>(
1318            tensors.into_iter().map(TensorPrimitive::Float).collect(),
1319            dim,
1320        )
1321        .tensor()
1322    }
1323
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    /// * `out_dtype` - The dtype of the returned index tensor.
    ///
    /// # Returns
    ///
    /// An int tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1336
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    /// * `out_dtype` - The dtype of the returned index tensor.
    ///
    /// # Returns
    ///
    /// An int tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1349
1350    /// Gets the maximum element of a tensor.
1351    ///
1352    /// # Arguments
1353    ///
1354    /// * `tensor` - The tensor to get the maximum elements of.
1355    ///
1356    /// # Returns
1357    ///
1358    /// A tensor with the maximum element of `tensor`.
1359    fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1360        let shape = tensor.shape();
1361        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1362
1363        B::float_max_dim(tensor, 0)
1364    }
1365
1366    /// Gets the maximum elements of a tensor along an axis.
1367    ///
1368    /// # Arguments
1369    ///
1370    /// * `tensor` - The tensor to get the maximum elements of.
1371    /// * `dim` - The dimension along which to get the maximum elements.
1372    ///
1373    /// # Returns
1374    ///
1375    /// A tensor with the maximum elements of `tensor` along `dim`.
1376    fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1377        let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1378        let index = B::float_argmax(tensor.clone(), dim, dtype);
1379
1380        B::float_gather(dim, tensor, index)
1381    }
1382
1383    /// Gets the maximum elements of a tensor along an axis and their indices.
1384    ///
1385    /// # Arguments
1386    ///
1387    /// * `tensor` - The tensor to get the maximum elements of.
1388    /// * `dim` - The dimension along which to get the maximum elements.
1389    /// * `indices_dtype` - The indices tensor dtype.
1390    ///
1391    /// # Returns
1392    ///
1393    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1394    fn float_max_dim_with_indices(
1395        tensor: FloatTensor<B>,
1396        dim: usize,
1397        indices_dtype: IntDType,
1398    ) -> (FloatTensor<B>, IntTensor<B>) {
1399        let index = B::float_argmax(tensor.clone(), dim, indices_dtype);
1400        let values = B::float_gather(dim, tensor, index.clone());
1401
1402        (values, index)
1403    }
1404
1405    /// Gets the minimum element of a tensor.
1406    ///
1407    /// # Arguments
1408    ///
1409    /// * `tensor` - The tensor to get the minimum elements of.
1410    ///
1411    /// # Returns
1412    ///
1413    /// A tensor with the minimum element of `tensor`.
1414    fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1415        let shape = tensor.shape();
1416        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1417
1418        B::float_min_dim(tensor, 0)
1419    }
1420
1421    /// Gets the minimum elements of a tensor along an axis.
1422    ///
1423    /// # Arguments
1424    ///
1425    /// * `tensor` - The tensor to get the minimum elements of.
1426    /// * `dim` - The dimension along which to get the minimum elements.
1427    ///
1428    /// # Returns
1429    ///
1430    /// A tensor with the minimum elements of `tensor` along `dim`.
1431    fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1432        let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1433        let index = B::float_argmin(tensor.clone(), dim, dtype);
1434
1435        B::float_gather(dim, tensor, index)
1436    }
1437
1438    /// Gets the minimum elements of a tensor along an axis and their indices.
1439    ///
1440    /// # Arguments
1441    ///
1442    /// * `tensor` - The tensor to get the minimum elements of.
1443    /// * `dim` - The dimension along which to get the minimum elements.
1444    /// * `indices_dtype` - The indices tensor dtype.
1445    ///
1446    /// # Returns
1447    ///
1448    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1449    fn float_min_dim_with_indices(
1450        tensor: FloatTensor<B>,
1451        dim: usize,
1452        indices_dtype: IntDType,
1453    ) -> (FloatTensor<B>, IntTensor<B>) {
1454        let index = B::float_argmin(tensor.clone(), dim, indices_dtype);
1455        let values = B::float_gather(dim, tensor, index.clone());
1456
1457        (values, index)
1458    }
1459
1460    /// Gets the maximum absolute element of a tensor.
1461    ///
1462    /// # Arguments
1463    ///
1464    /// * `tensor` - The tensor to get the maximum elements of.
1465    ///
1466    /// # Returns
1467    ///
1468    /// A tensor with the maximum element of `tensor`.
1469    fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1470        let shape = tensor.shape();
1471        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1472
1473        B::float_max_abs_dim(tensor, 0)
1474    }
1475
1476    /// Gets the maximum absolute elements of a tensor along an axis.
1477    ///
1478    /// # Arguments
1479    ///
1480    /// * `tensor` - The tensor to get the maximum elements of.
1481    /// * `dim` - The dimension along which to get the maximum elements.
1482    ///
1483    /// # Returns
1484    ///
1485    /// A tensor with the maximum elements of `tensor` along `dim`.
1486    fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1487        B::float_max_dim(B::float_abs(tensor), dim)
1488    }
1489
1490    /// Tests if any element in the float `tensor` evaluates to True.
1491    ///
1492    /// # Arguments
1493    ///
1494    /// * `tensor` - The tensor to test.
1495    /// * `out_dtype` - The output tensor dtype.
1496    ///
1497    /// # Returns
1498    ///
1499    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1500    fn float_any(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1501        let float_dtype = tensor.dtype();
1502        let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1503        let bool_tensor = B::bool_not(bool_tensor);
1504        let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1505        B::float_greater_elem(sum, 0f32.into(), out_dtype)
1506    }
1507
1508    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1509    ///
1510    /// # Arguments
1511    ///
1512    /// * `tensor` - The tensor to test.
1513    /// * `dim` - The axis along which to test.
1514    /// * `out_dtype` - The output tensor dtype.
1515    ///
1516    /// # Returns
1517    ///
1518    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1519    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1520    /// input evaluates to True, False otherwise.
1521    fn float_any_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1522        let float_dtype = tensor.dtype();
1523        let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1524        let bool_tensor = B::bool_not(bool_tensor);
1525        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1526        B::float_greater_elem(sum, 0f32.into(), out_dtype)
1527    }
1528
1529    /// Tests if all elements in the float `tensor` evaluate to True.
1530    ///
1531    /// # Arguments
1532    ///
1533    /// * `tensor` - The tensor to test.
1534    /// * `out_dtype` - The output tensor dtype.
1535    ///
1536    /// # Returns
1537    ///
1538    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1539    /// evaluate to True, False otherwise.
1540    fn float_all(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1541        let float_dtype = tensor.dtype();
1542        let num_elems = tensor.shape().num_elements() as f32;
1543        let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1544        let bool_tensor = B::bool_not(bool_tensor);
1545        let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1546        B::float_equal_elem(sum, num_elems.into(), out_dtype)
1547    }
1548
1549    /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1550    ///
1551    /// # Arguments
1552    ///
1553    /// * `tensor` - The tensor to test.
1554    /// * `dim` - The axis along which to test.
1555    /// * `out_dtype` - The output tensor dtype.
1556    ///
1557    /// # Returns
1558    ///
1559    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1560    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1561    /// evaluates to True, False otherwise.
1562    fn float_all_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1563        let float_dtype = tensor.dtype();
1564        let num_elems = tensor.shape()[dim] as f32;
1565        let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1566        let bool_tensor = B::bool_not(bool_tensor);
1567        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1568        B::float_equal_elem(sum, num_elems.into(), out_dtype)
1569    }
1570
1571    /// Returns the signs of the float `tensor`.
1572    ///
1573    /// # Arguments
1574    ///
1575    /// * `tensor` - The tensor to extract the signs from.
1576    ///
1577    /// # Returns
1578    ///
1579    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1580    fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1581        let device = B::float_device(&tensor);
1582        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
1583        let zeros = B::float_zeros(tensor.shape(), &device, tensor.dtype().into());
1584        let less_than_zero = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
1585        let greater_than_zero = B::float_greater_elem(tensor, 0f32.into(), bool_dtype);
1586
1587        let mut result = B::float_mask_fill(zeros, less_than_zero, (-1f32).into());
1588        result = B::float_mask_fill(result, greater_than_zero, 1f32.into());
1589        result
1590    }
1591
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The float tensor broadcast to `shape`.
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1594
1595    /// Sort the elements of the input `tensor` by value in along a given dimension.
1596    ///
1597    /// This sort is unstable (i.e., may reorder equal elements).
1598    ///
1599    /// # Arguments
1600    ///
1601    /// * `tensor` - The input tensor.
1602    /// * `dim` - The axis along which to sort.
1603    /// * `descending` - The sorting order.
1604    ///
1605    /// # Returns
1606    ///
1607    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1608    fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1609        sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1610    }
1611
1612    /// Sort the elements of the input `tensor` by value in along a given dimension.
1613    ///
1614    /// This sort is unstable (i.e., may reorder equal elements).
1615    ///
1616    /// # Arguments
1617    ///
1618    /// * `tensor` - The input tensor.
1619    /// * `dim` - The axis along which to sort.
1620    /// * `descending` - The sorting order.
1621    /// * `indices_dtype` - The indices tensor dtype.
1622    ///
1623    /// # Returns
1624    ///
1625    /// A tensor with the same shape as the input tensor and corresponding indices, where
1626    /// the elements are sorted by value and the indices map back to the original input tensor.
1627    fn float_sort_with_indices(
1628        tensor: FloatTensor<B>,
1629        dim: usize,
1630        descending: bool,
1631        indices_dtype: IntDType,
1632    ) -> (FloatTensor<B>, IntTensor<B>) {
1633        let (values, indices) = sort_with_indices::<B, Float>(
1634            TensorPrimitive::Float(tensor),
1635            dim,
1636            descending,
1637            indices_dtype,
1638        );
1639        (values.tensor(), indices)
1640    }
1641
1642    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1643    ///
1644    /// This sort is unstable (i.e., may reorder equal elements).
1645    ///
1646    /// # Arguments
1647    ///
1648    /// * `tensor` - The input tensor.
1649    /// * `dim` - The axis along which to sort.
1650    /// * `descending` - The sorting order.
1651    /// * `out_dtype` - The output tensor dtype.
1652    ///
1653    /// # Returns
1654    ///
1655    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1656    fn float_argsort(
1657        tensor: FloatTensor<B>,
1658        dim: usize,
1659        descending: bool,
1660        out_dtype: IntDType,
1661    ) -> IntTensor<B> {
1662        argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending, out_dtype)
1663    }
1664
1665    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
1666    /// using the given locations in [-1, 1].
1667    ///
1668    /// # Arguments
1669    ///
1670    /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
1671    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
1672    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
1673    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
1674    ///
1675    /// # Returns
1676    ///
1677    /// A tensor with shape (N, C, H_out, W_out)
1678    fn float_grid_sample_2d(
1679        tensor: FloatTensor<B>,
1680        grid: FloatTensor<B>,
1681        options: GridSampleOptions,
1682    ) -> FloatTensor<B> {
1683        // TODO: default impl should get int default dtype
1684        float_grid_sample_2d_ref::<B>(tensor, grid, options)
1685    }
1686
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `0` when `shape[dim] < size`, and
    /// `(shape[dim] - size) / step + 1` (integer division) otherwise.
    ///
    /// NOTE(review): this count is derived from the "all complete windows" definition
    /// above; confirm it against the backend implementations.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the selected dim.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
    -> FloatTensor<B>;
1706
1707    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1708    ///
1709    /// # Returns
1710    ///
1711    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1712    fn float_is_nan(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1713        // Check if the input tensor is NaN by comparing it to itself
1714        // NaN is the only value that is not equal to itself
1715        B::float_not_equal(tensor.clone(), tensor, out_dtype)
1716    }
1717
1718    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1719    ///
1720    /// # Returns
1721    ///
1722    /// A boolean tensor where `true` indicates that the value is infinite
1723    fn float_is_inf(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1724        B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.into(), out_dtype)
1725    }
1726}