Skip to main content

burn_backend/backend/ops/
tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_ref;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::sort::{argsort, sort, sort_with_indices};
5use crate::ops::GridSampleOptions;
6use crate::tensor::{BoolTensor, Device, Float, FloatTensor, IntTensor};
7use crate::{Backend, Distribution, TensorData, get_device_settings};
8use crate::{ExecutionError, Scalar, TensorMetadata, TensorPrimitive};
9use alloc::vec::Vec;
10use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
11
12/// Operations on float tensors.
13pub trait FloatTensorOps<B: Backend> {
14    /// Creates a new tensor from the data structure.
15    ///
16    /// # Arguments
17    ///
18    /// * `data` - The data structure.
19    /// * `device` - The device to create the tensor on.
20    ///
21    /// # Returns
22    ///
23    /// The tensor with the given data.
24    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
25
26    /// Creates a new tensor with random values.
27    ///
28    /// # Arguments
29    ///
30    /// * `shape` - The shape of the tensor.
31    /// * `distribution` - The distribution to sample from.
32    /// * `device` - The device to create the tensor on.
33    /// * `dtype` - The target data type.
34    ///
35    /// # Returns
36    ///
37    /// The tensor with the given shape and random values.
38    fn float_random(
39        shape: Shape,
40        distribution: Distribution,
41        device: &Device<B>,
42        dtype: FloatDType,
43    ) -> FloatTensor<B>;
44
45    /// Creates a new tensor with zeros.
46    ///
47    /// # Arguments
48    ///
49    /// * `shape` - The shape of the tensor.
50    /// * `device` - The device to create the tensor on.
51    /// * `dtype` - The target data type.
52    ///
53    /// # Returns
54    ///
55    /// The tensor with the given shape and zeros.
56    fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
57        Self::float_from_data(TensorData::full_dtype(shape, 0., dtype.into()), device)
58    }
59
60    /// Creates a new tensor with ones.
61    ///
62    /// # Arguments
63    ///
64    /// * `shape` - The shape of the tensor.
65    /// * `device` - The device to create the tensor on.
66    /// * `dtype` - The target data type.
67    ///
68    /// # Returns
69    ///
70    /// The tensor with the given shape and ones.
71    fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
72        Self::float_from_data(TensorData::full_dtype(shape, 1., dtype.into()), device)
73    }
74
75    /// Creates a tensor filled with given value.
76    ///
77    /// # Arguments
78    ///
79    /// * `shape` - The shape of the tensor.
80    /// * `fill_value` - The value with which to fill the tensor.
81    /// * `device` - The device to create the tensor on.
82    /// * `dtype` - The target data type.
83    ///
84    /// # Returns
85    ///
86    /// The tensor filled with given value
87    fn float_full(
88        shape: Shape,
89        fill_value: Scalar,
90        device: &Device<B>,
91        dtype: FloatDType,
92    ) -> FloatTensor<B> {
93        Self::float_from_data(
94            TensorData::full_dtype(shape, fill_value, dtype.into()),
95            device,
96        )
97    }
98
99    /// Converts the tensor to a data structure.
100    ///
101    /// # Arguments
102    ///
103    /// * `tensor` - The tensor.
104    ///
105    /// # Returns
106    ///
107    /// The data structure with the tensor's data.
108    fn float_into_data(
109        tensor: FloatTensor<B>,
110    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
111
112    /// Gets the device of the tensor.
113    ///
114    /// # Arguments
115    ///
116    /// * `tensor` - The tensor.
117    ///
118    /// # Returns
119    ///
120    /// The device of the tensor.
121    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
122
123    /// Moves the tensor to the given device.
124    ///
125    /// # Arguments
126    ///
127    /// * `tensor` - The tensor.
128    /// * `device` - The device to move the tensor to.
129    ///
130    /// # Returns
131    ///
132    /// The tensor on the given device.
133    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
134
135    /// Converts float tensor to int tensor.
136    ///
137    /// # Arguments
138    ///
139    /// * `tensor` - The tensor.
140    /// * `out_dtype` - The output tensor dtype.
141    ///
142    /// # Returns
143    ///
144    /// The int tensor with the same data as the float tensor.
145    fn float_into_int(tensor: FloatTensor<B>, out_dtype: IntDType) -> IntTensor<B>;
146
147    /// Creates an empty tensor with the given shape.
148    ///
149    /// # Arguments
150    ///
151    /// * `shape` - The shape of the tensor.
152    /// * `device` - The device to create the tensor on.
153    /// * `dtype` - The target data type.
154    ///
155    /// # Returns
156    ///
157    /// The empty tensor with the given shape.
158    fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
159
160    /// Repeat the tensor along the given dimension.
161    ///
162    /// # Arguments
163    ///
164    /// * `tensor` - The tensor.
165    /// * `dim` - The dimension to repeat.
166    /// * `times` - The number of times to repeat the dimension.
167    ///
168    /// # Returns
169    ///
170    /// The tensor with the given dimension repeated.
171    fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
172        repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
173    }
174
175    /// Adds two tensors together.
176    ///
177    /// # Arguments
178    ///
179    /// * `lhs` - The left-hand side tensor.
180    /// * `rhs` - The right-hand side tensor.
181    ///
182    /// # Returns
183    ///
184    /// The result of adding the two tensors together.
185    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
186
187    /// Adds a scalar to a tensor.
188    ///
189    /// # Arguments
190    ///
191    /// * `lhs` - The left-hand side tensor.
192    /// * `rhs` - The right-hand side scalar.
193    ///
194    /// # Returns
195    ///
196    /// The result of adding the scalar to the tensor.
197    fn float_add_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
198
199    /// Clamps a tensor under a minimum value.
200    ///
201    /// # Arguments
202    ///
203    /// * `tensor` - The tensor to clamp.
204    /// * `min` - The minimum value.
205    ///
206    /// # Returns
207    ///
208    /// The clamped tensor.
209    fn float_clamp_min(tensor: FloatTensor<B>, min: Scalar) -> FloatTensor<B> {
210        let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
211        let mask = Self::float_lower_elem(tensor.clone(), min, dtype);
212        B::float_mask_fill(tensor, mask, min)
213    }
214
215    /// Clamps a tensor over a maximum value.
216    ///
217    /// # Arguments
218    ///
219    /// * `tensor` - The tensor to clamp.
220    /// * `max` - The maximum value.
221    ///
222    /// # Returns
223    ///
224    /// The clamped tensor.
225    fn float_clamp_max(tensor: FloatTensor<B>, max: Scalar) -> FloatTensor<B> {
226        let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
227        let mask = Self::float_greater_elem(tensor.clone(), max, dtype);
228        B::float_mask_fill(tensor, mask, max)
229    }
230
231    /// Clamps a tensor between a minimum and maximum value.
232    ///
233    /// # Arguments
234    ///
235    /// * `tensor` - The tensor to clamp.
236    /// * `min` - The minimum value.
237    /// * `max` - The maximum value.
238    ///
239    /// # Returns
240    ///
241    /// The clamped tensor.
242    fn float_clamp(tensor: FloatTensor<B>, min: Scalar, max: Scalar) -> FloatTensor<B> {
243        // Default implementation
244        Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
245    }
246
247    /// Subtracts two tensors.
248    ///
249    /// # Arguments
250    ///
251    /// * `lhs` - The left-hand side tensor.
252    /// * `rhs` - The right-hand side tensor.
253    ///
254    /// # Returns
255    ///
256    /// The result of subtracting the two tensors.
257    fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
258
259    /// Subtracts a scalar from a tensor.
260    ///
261    /// # Arguments
262    ///
263    /// * `lhs` - The left-hand side tensor.
264    /// * `rhs` - The right-hand side scalar.
265    ///
266    /// # Returns
267    ///
268    /// The result of subtracting the scalar from the tensor.
269    fn float_sub_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
270
    /// Multiplies two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors element-wise.
    fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
273
274    /// Multiplies a tensor by a scalar.
275    ///
276    /// # Arguments
277    ///
278    /// * `lhs` - The left-hand side tensor.
279    /// * `rhs` - The right-hand side scalar.
280    ///
281    /// # Returns
282    ///
283    /// The result of multiplying the tensor by the scalar.
284    fn float_mul_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
285
286    /// Divides two tensors element-wise.
287    ///
288    /// # Arguments
289    ///
290    /// * `lhs` - The left-hand side tensor.
291    /// * `rhs` - The right-hand side tensor.
292    ///
293    /// # Returns
294    ///
295    /// The result of dividing the two tensors.
296    fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
297
298    /// Divides a tensor by a scalar.
299    ///
300    /// # Arguments
301    ///
302    /// * `lhs` - The left-hand side tensor.
303    /// * `rhs` - The right-hand side scalar.
304    ///
305    /// # Returns
306    ///
307    /// The result of dividing the tensor by the scalar.
308    fn float_div_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
309
310    /// Computes the remainder of division between two tensors element-wise.
311    ///
312    /// # Arguments
313    ///
314    /// * `lhs` - The left-hand side tensor.
315    /// * `rhs` - The right-hand side tensor.
316    ///
317    /// # Returns
318    ///
319    /// The element-wise remainder when dividing `lhs` by `rhs`.
320    fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
321
    /// Computes the modulus of a tensor given a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of the scalar to the tensor.
    fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
332
333    /// Multiplies two tensors together using matrix multiplication.
334    ///
335    /// # Arguments
336    ///
337    /// * `lhs` - The left-hand side tensor.
338    /// * `rhs` - The right-hand side tensor.
339    ///
340    /// # Returns
341    ///
342    /// The result of multiplying the two tensors together using matrix multiplication.
343    fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
344
345    /// Computes the cross product of two tensors along a given dimension.
346    ///
347    /// # Arguments
348    ///
349    /// * `lhs` - The left-hand side tensor.
350    /// * `rhs` - The right-hand side tensor.
351    /// * `dim` - The dimension to compute the cross product along.
352    ///
353    /// # Returns
354    ///
355    /// The cross product of the two tensors.
356    fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
357
358    /// Negates a tensor element-wise.
359    fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
360        Self::float_mul_scalar(tensor, (-1f32).into())
361    }
362
    /// Calculates the reciprocals element-wise.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements are inverted.
    ///
    /// # Returns
    ///
    /// A tensor containing the element-wise reciprocals of `tensor`.
    fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
365
366    /// Transposes a tensor.
367    ///
368    /// # Arguments
369    ///
370    /// * `tensor` - The tensor to transpose.
371    ///
372    /// # Returns
373    ///
374    /// The transposed tensor.
375    fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
376        let ndims = tensor.shape().num_dims();
377        Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
378    }
379
380    /// Swaps two dimensions of a tensor.
381    ///
382    /// # Arguments
383    ///
384    /// * `tensor` - The tensor to swap the dimensions of.
385    /// * `dim1` - The first dimension to swap.
386    /// * `dim2` - The second dimension to swap.
387    ///
388    /// # Returns
389    ///
390    /// The tensor with the dimensions swapped.
391    fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
392
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
403
    /// Reverse the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
413
414    /// Reshapes a tensor.
415    ///
416    /// # Arguments
417    ///
418    /// * `tensor` - The tensor to reshape.
419    /// * `shape` - The new shape of the tensor.
420    ///
421    /// # Returns
422    ///
423    /// The tensor with the new shape.
424    fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
425
426    /// Gather elements from a tensor.
427    ///
428    /// # Arguments
429    ///
430    /// * `dim` - The dimension to gather from.
431    /// * `tensor` - The tensor to gather from.
432    /// * `indices` - The indices to gather.
433    ///
434    /// # Returns
435    ///
436    /// The gathered elements.
437    fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
438
439    /// Scatter elements into a tensor using sum reduction.
440    ///
441    /// # Arguments
442    ///
443    /// * `dim` - The dimension to scatter into.
444    /// * `tensor` - The tensor to scatter into.
445    /// * `indices` - The indices to scatter into.
446    /// * `value` - The value to scatter.
447    ///
448    /// # Returns
449    ///
450    /// The tensor with the scattered elements.
451    fn float_scatter_add(
452        dim: usize,
453        tensor: FloatTensor<B>,
454        indices: IntTensor<B>,
455        value: FloatTensor<B>,
456    ) -> FloatTensor<B>;
457
    /// Multi-dimensional scatter: update `data` at locations specified by `indices` with `values`.
    ///
    /// # Arguments
    ///
    /// * `data` - The tensor to scatter into.
    /// * `indices` - An M-dimensional integer tensor whose last dimension indexes into `data`.
    /// * `values` - The values to scatter.
    /// * `reduction` - How to combine with existing values.
    ///
    /// # Returns
    ///
    /// The tensor with scattered values.
    ///
    /// # Panics
    ///
    /// The default implementation always panics through `unimplemented!`; a backend that
    /// supports this operation has to override it.
    fn float_scatter_nd(
        _data: FloatTensor<B>,
        _indices: IntTensor<B>,
        _values: FloatTensor<B>,
        _reduction: crate::tensor::IndexingUpdateOp,
    ) -> FloatTensor<B> {
        unimplemented!("float_scatter_nd is not implemented for this backend")
    }
478
    /// Multi-dimensional gather: collect slices from `data` at locations specified by `indices`.
    ///
    /// # Arguments
    ///
    /// * `data` - The tensor to gather from.
    /// * `indices` - An M-dimensional integer tensor whose last dimension indexes into `data`.
    ///
    /// # Returns
    ///
    /// The gathered tensor.
    ///
    /// # Panics
    ///
    /// The default implementation always panics through `unimplemented!`; a backend that
    /// supports this operation has to override it.
    fn float_gather_nd(_data: FloatTensor<B>, _indices: IntTensor<B>) -> FloatTensor<B> {
        unimplemented!("float_gather_nd is not implemented for this backend")
    }
492
    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
505
    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// to the given value using sum reduction.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_select_add(
        tensor: FloatTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
525
526    /// Select tensor elements corresponding to the given slices.
527    ///
528    /// # Arguments
529    ///
530    /// * `tensor` - The tensor to select from.
531    /// * `slices` - The slices specifying ranges and steps for each dimension.
532    ///
533    /// # Returns
534    ///
535    /// The selected elements in a new tensor.
536    ///
537    /// # Note
538    ///
539    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
540    /// be passed to this method. Backend implementations do not need to handle empty slices.
541    fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;
542
    /// Assign the selected elements corresponding to the given slices to the given value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices selecting the elements to overwrite.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    ///
    /// # Note
    ///
    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
    /// high-level tensor API and will not be passed to this method. Backend implementations do
    /// not need to handle empty slice assignments.
    fn float_slice_assign(
        tensor: FloatTensor<B>,
        slices: &[Slice],
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
565
566    /// Update the given tensor with the value tensor where the mask is true.
567    ///
568    /// # Arguments
569    ///
570    /// * `tensor` - The tensor to select from.
571    /// * `mask` - The boolean mask to select with.
572    /// * `value` - The value to assign to the selected elements from the value tensor.
573    ///
574    /// # Returns
575    ///
576    /// The tensor with the selected elements assigned to the given value.
577    fn float_mask_where(
578        tensor: FloatTensor<B>,
579        mask: BoolTensor<B>,
580        value: FloatTensor<B>,
581    ) -> FloatTensor<B>;
582
583    /// Update the given tensor with the value where the mask is true.
584    ///
585    /// # Arguments
586    ///
587    /// * `tensor` - The tensor to select from.
588    /// * `mask` - The boolean mask to select with.
589    /// * `value` - The value to assign to the selected elements.
590    ///
591    /// # Returns
592    ///
593    /// The tensor with the selected elements assigned to the given value.
594    fn float_mask_fill(
595        tensor: FloatTensor<B>,
596        mask: BoolTensor<B>,
597        value: Scalar,
598    ) -> FloatTensor<B>;
599
600    /// Equal comparison of two tensors.
601    ///
602    /// # Arguments
603    ///
604    /// * `lhs` - The left-hand side tensor.
605    /// * `rhs` - The right-hand side tensor.
606    /// * `out_dtype` - The output tensor dtype.
607    ///
608    /// # Returns
609    ///
610    /// A boolean tensor with the result of the comparison.
611    fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
612    -> BoolTensor<B>;
613
614    /// Element-wise non-equality comparison.
615    ///
616    /// # Arguments
617    ///
618    /// * `lhs` - The left-hand side tensor.
619    /// * `rhs` - The right-hand side tensor.
620    /// * `out_dtype` - The output tensor dtype.
621    ///
622    /// # Returns
623    ///
624    /// A boolean tensor with the result of the comparison.
625    fn float_not_equal(
626        lhs: FloatTensor<B>,
627        rhs: FloatTensor<B>,
628        out_dtype: BoolDType,
629    ) -> BoolTensor<B> {
630        let equal_tensor = B::float_equal(lhs, rhs, out_dtype);
631        B::bool_not(equal_tensor)
632    }
633
634    /// Equal comparison of a tensor and a scalar.
635    ///
636    /// # Arguments
637    ///
638    /// * `lhs` - The left-hand side tensor.
639    /// * `rhs` - The right-hand side scalar.
640    /// * `out_dtype` - The output tensor dtype.
641    ///
642    /// # Returns
643    ///
644    /// A boolean tensor with the result of the comparison.
645    fn float_equal_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
646
647    /// Element-wise non-equality comparison with a scalar.
648    ///
649    /// # Arguments
650    ///
651    /// * `lhs` - The left-hand side tensor.
652    /// * `rhs` - The right-hand side scalar.
653    /// * `out_dtype` - The output tensor dtype.
654    ///
655    /// # Returns
656    ///
657    /// A boolean tensor with the result of the comparison.
658    fn float_not_equal_elem(
659        lhs: FloatTensor<B>,
660        rhs: Scalar,
661        out_dtype: BoolDType,
662    ) -> BoolTensor<B> {
663        let equal_tensor = B::float_equal_elem(lhs, rhs, out_dtype);
664        B::bool_not(equal_tensor)
665    }
666
667    /// Greater than comparison of two tensors.
668    ///
669    /// # Arguments
670    ///
671    /// * `lhs` - The left-hand side tensor.
672    /// * `rhs` - The right-hand side tensor.
673    /// * `out_dtype` - The output tensor dtype.
674    ///
675    /// # Returns
676    ///
677    /// A boolean tensor with the result of the comparison.
678    fn float_greater(
679        lhs: FloatTensor<B>,
680        rhs: FloatTensor<B>,
681        out_dtype: BoolDType,
682    ) -> BoolTensor<B>;
683
684    /// Greater than comparison of a tensor and a scalar.
685    ///
686    /// # Arguments
687    ///
688    /// * `lhs` - The left-hand side tensor.
689    /// * `rhs` - The right-hand side scalar.
690    /// * `out_dtype` - The output tensor dtype.
691    ///
692    /// # Returns
693    ///
694    /// A boolean tensor with the result of the comparison.
695    fn float_greater_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
696
697    /// Greater than or equal comparison of two tensors.
698    ///
699    /// # Arguments
700    ///
701    /// * `lhs` - The left-hand side tensor.
702    /// * `rhs` - The right-hand side tensor.
703    /// * `out_dtype` - The output tensor dtype.
704    ///
705    /// # Returns
706    ///
707    /// A boolean tensor with the result of the comparison.
708    fn float_greater_equal(
709        lhs: FloatTensor<B>,
710        rhs: FloatTensor<B>,
711        out_dtype: BoolDType,
712    ) -> BoolTensor<B>;
713
714    /// Greater than or equal comparison of a tensor and a scalar.
715    ///
716    /// # Arguments
717    ///
718    /// * `lhs` - The left-hand side tensor.
719    /// * `rhs` - The right-hand side scalar.
720    /// * `out_dtype` - The output tensor dtype.
721    ///
722    /// # Returns
723    ///
724    /// A boolean tensor with the result of the comparison.
725    fn float_greater_equal_elem(
726        lhs: FloatTensor<B>,
727        rhs: Scalar,
728        out_dtype: BoolDType,
729    ) -> BoolTensor<B>;
730
731    /// Less than comparison of two tensors.
732    ///
733    /// # Arguments
734    ///
735    /// * `lhs` - The left-hand side tensor.
736    /// * `rhs` - The right-hand side tensor.
737    /// * `out_dtype` - The output tensor dtype.
738    ///
739    /// # Returns
740    ///
741    /// A boolean tensor with the result of the comparison.
742    fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
743    -> BoolTensor<B>;
744
745    /// Less than comparison of a tensor and a scalar.
746    ///
747    /// # Arguments
748    ///
749    /// * `lhs` - The left-hand side tensor.
750    /// * `rhs` - The right-hand side scalar.
751    /// * `out_dtype` - The output tensor dtype.
752    ///
753    /// # Returns
754    ///
755    /// A boolean tensor with the result of the comparison.
756    fn float_lower_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
757
758    /// Less than or equal comparison of two tensors.
759    ///
760    /// # Arguments
761    ///
762    /// * `lhs` - The left-hand side tensor.
763    /// * `rhs` - The right-hand side tensor.
764    /// * `out_dtype` - The output tensor dtype.
765    ///
766    /// # Returns
767    ///
768    /// A boolean tensor with the result of the comparison.
769    fn float_lower_equal(
770        lhs: FloatTensor<B>,
771        rhs: FloatTensor<B>,
772        out_dtype: BoolDType,
773    ) -> BoolTensor<B>;
774
775    /// Less than or equal comparison of a tensor and a scalar.
776    ///
777    /// # Arguments
778    ///
779    /// * `lhs` - The left-hand side tensor.
780    /// * `rhs` - The right-hand side scalar.
781    /// * `out_dtype` - The output tensor dtype.
782    ///
783    /// # Returns
784    ///
785    /// A boolean tensor with the result of the comparison.
786    fn float_lower_equal_elem(
787        lhs: FloatTensor<B>,
788        rhs: Scalar,
789        out_dtype: BoolDType,
790    ) -> BoolTensor<B>;
791
    /// Detaches a tensor from the computation graph.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to detach.
    ///
    /// # Returns
    ///
    /// The detached tensor. The default implementation returns `tensor` unchanged.
    fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
797
    /// Sets the `require_grad` flag of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose flag is set.
    /// * `_require_grad` - The requested flag value; ignored by the default implementation.
    ///
    /// # Returns
    ///
    /// The tensor. The default implementation returns `tensor` unchanged.
    fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
803
    /// Returns the `require_grad` flag of a tensor.
    ///
    /// # Arguments
    ///
    /// * `_tensor` - The tensor to inspect; ignored by the default implementation.
    ///
    /// # Returns
    ///
    /// The flag value. The default implementation always returns `false`.
    fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
809
810    /// Sum of all elements in a tensor.
811    ///
812    /// # Arguments
813    ///
814    /// * `tensor` - The tensor to sum.
815    ///
816    /// # Returns
817    ///
818    /// A scalar tensor with the sum of all elements in `tensor`.
819    fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
820
821    /// Sum of all elements in a tensor along a dimension.
822    ///
823    /// # Arguments
824    ///
825    /// * `tensor` - The tensor to sum.
826    /// * `dim` - The dimension along which to sum.
827    ///
828    /// # Returns
829    ///
830    /// A tensor with the sum of all elements in `tensor` along `dim`.
831    fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
832
833    /// Product of all elements in a tensor.
834    ///
835    /// # Arguments
836    ///
837    /// * `tensor` - The tensor to product.
838    ///
839    /// # Returns
840    ///
841    /// A scalar tensor with the product of all elements in `tensor`.
842    fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
843        // Product of all elements in a tensor
844        B::float_exp(B::float_sum(B::float_log(tensor)))
845    }
846
847    /// Product of all elements in a tensor along a dimension.
848    ///
849    /// # Arguments
850    ///
851    /// * `tensor` - The tensor to product.
852    ///
853    /// # Returns
854    ///
855    /// A tensor with the product of all elements in `tensor` along `dim`.
856    fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
857        // Product of all elements in a tensor along a dimension
858        B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
859    }
860
861    /// Mean of all elements in a tensor.
862    ///
863    /// # Arguments
864    ///
865    /// * `tensor` - The tensor to mean.
866    ///
867    /// # Returns
868    ///
869    /// A scalar tensor with the mean of all elements in `tensor`.
870    fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
871        let num_elems = tensor.shape().num_elements() as f32;
872        B::float_div_scalar(B::float_sum(tensor), num_elems.into())
873    }
874
875    /// Mean of all elements in a tensor along a dimension.
876    ///
877    /// # Arguments
878    ///
879    /// * `tensor` - The tensor to mean.
880    /// * `dim` - The dimension along which to mean.
881    ///
882    /// # Returns
883    ///
884    /// A tensor with the mean of all elements in `tensor` along `dim`.
885    fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
886
887    /// Computes the cumulative sum of elements along a dimension.
888    ///
889    /// # Arguments
890    ///
891    /// * `tensor` - The tensor to compute the cumulative sum of.
892    /// * `dim` - The dimension along which to compute the cumulative sum.
893    ///
894    /// # Returns
895    ///
896    /// A tensor with the same shape where each element is the cumulative sum
897    /// of all elements up to and including that position along the dimension.
898    fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
899
900    /// Computes the cumulative product of elements along a dimension.
901    ///
902    /// # Arguments
903    ///
904    /// * `tensor` - The tensor to compute the cumulative product of.
905    /// * `dim` - The dimension along which to compute the cumulative product.
906    ///
907    /// # Returns
908    ///
909    /// A tensor with the same shape where each element is the cumulative product
910    /// of all elements up to and including that position along the dimension.
911    fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
912
913    /// Computes the cumulative minimum of elements along a dimension.
914    ///
915    /// # Arguments
916    ///
917    /// * `tensor` - The tensor to compute the cumulative minimum of.
918    /// * `dim` - The dimension along which to compute the cumulative minimum.
919    ///
920    /// # Returns
921    ///
922    /// A tensor with the same shape where each element is the minimum
923    /// of all elements up to and including that position along the dimension.
924    fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
925
926    /// Computes the cumulative maximum of elements along a dimension.
927    ///
928    /// # Arguments
929    ///
930    /// * `tensor` - The tensor to compute the cumulative maximum of.
931    /// * `dim` - The dimension along which to compute the cumulative maximum.
932    ///
933    /// # Returns
934    ///
935    /// A tensor with the same shape where each element is the maximum
936    /// of all elements up to and including that position along the dimension.
937    fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
938
    /// Converts a tensor to another floating point data type.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target floating point data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor`, stored in the target floating point data type.
    fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
950
    /// Returns a new tensor with the exponential of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with exponential values.
    fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
961
    /// Returns a new tensor with the natural logarithm of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
972
    /// Returns a new tensor with the logarithm of `1 + x` for each element `x`.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with logarithm values of `1 + x`.
    fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
983
    /// Element-wise power with a float tensor exponent.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (the bases).
    /// * `rhs` - The right-hand side tensor (the exponents).
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
995
996    /// Element-wise power with an IntTensor.
997    ///
998    /// # Arguments
999    ///
1000    /// * `lhs` - The left-hand side tensor.
1001    /// * `rhs` - The right-hand side floatTensor.
1002    ///
1003    /// # Returns
1004    ///
1005    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
1006    fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
1007        let dtype = lhs.dtype();
1008        Self::float_powf(lhs, B::int_into_float(rhs, dtype.into()))
1009    }
1010
1011    /// Raises a tensor to the power of an int scalar.
1012    ///
1013    /// # Backend Implementors Note
1014    ///
1015    /// A number of common exponent cases can be implemented with operations
1016    /// which are much cheaper than generic exponentiation.
1017    ///
1018    /// This (`Backend` impl overridable) operation handles generic optimizations
1019    /// for several common integer exponent cases; and then dispatches to
1020    /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
1021    /// operation to handle the generic case.
1022    ///
1023    /// # Arguments
1024    ///
1025    /// * `lhs` - The left-hand side tensor.
1026    /// * `rhs` - The right-hand side scalar.
1027    ///
1028    /// # Returns
1029    ///
1030    /// The elements of `lhs` raised to the value of `rhs`.
1031    fn float_powi_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
1032        match rhs.elem::<i64>() {
1033            0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
1034            1 => lhs,
1035            2 => B::float_mul(lhs.clone(), lhs),
1036            -1 => Self::float_recip(lhs),
1037            -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
1038            _ => Self::float_powi_scalar_impl(lhs, rhs),
1039        }
1040    }
1041
    /// Raises a tensor to the power of an int scalar.
    ///
    /// # Backend Implementors Note
    ///
    /// This is the generic implementation of integer exponentiation
    /// called by [`Self::float_powi_scalar`] in the fallback case.
    ///
    /// The default implementation forwards to [`Self::float_powf_scalar_impl`].
    ///
    /// As a general rule, this should not be called directly.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the value of `rhs`.
    fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
        // Call the `_impl` variant rather than `float_powf_scalar`: the latter dispatches
        // integer exponents back to `float_powi_scalar`, which would recurse into this method.
        Self::float_powf_scalar_impl(lhs, rhs)
    }
1063
1064    /// Returns a new tensor with values raised to the power of float `value`.
1065    ///
1066    /// # Backend Implementors Note
1067    ///
1068    /// This (`Backend` impl overridable) operation dispatches integer exponentiation
1069    /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
1070    /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
1071    /// operation to handle the generic case.
1072    ///
1073    /// # Arguments
1074    ///
1075    /// * `tensor` - The tensor to exponentiate.
1076    /// * `value` - The exponent.
1077    ///
1078    /// # Returns
1079    ///
1080    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
1081    fn float_powf_scalar(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B> {
1082        if let Some(exp) = value.try_as_integer() {
1083            Self::float_powi_scalar(tensor, exp)
1084        } else {
1085            Self::float_powf_scalar_impl(tensor, value)
1086        }
1087    }
1088
    /// Returns a new tensor with values raised to the power of float `value`.
    ///
    /// # Backend Implementors Note
    ///
    /// This is the generic implementation of float exponentiation
    /// called by [`Self::float_powf_scalar`] in the fallback case.
    ///
    /// This is the minimal required support a `Backend` must implement
    /// for exponentiation; the trait provides no default implementation.
    ///
    /// As a general rule, this should not be called directly.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B>;
1110
    /// Returns a new tensor with the square root of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
1121
    /// Returns a new tensor with the absolute value of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1132
    /// Returns a new tensor with the cosine of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1143
    /// Returns a new tensor with the sine of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1154
    /// Returns a new tensor with the tangent of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with tangent values.
    fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1165
    /// Returns a new tensor with the hyperbolic cosine of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
    fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1176
    /// Returns a new tensor with the hyperbolic sine of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
    fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1187
    /// Returns a new tensor with the hyperbolic tangent of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1198
    /// Returns a new tensor with the inverse cosine of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with inverse cosine values.
    fn float_acos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1209
    /// Returns a new tensor with the inverse hyperbolic cosine of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with inverse hyperbolic cosine values.
    fn float_acosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1220
    /// Returns a new tensor with the inverse sine of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with inverse sine values.
    fn float_asin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1231
    /// Returns a new tensor with the inverse hyperbolic sine of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with inverse hyperbolic sine values.
    fn float_asinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1242
    /// Returns a new tensor with the inverse tangent of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with the inverse tangent values.
    fn float_atan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1253
    /// Returns a new tensor with the inverse hyperbolic tangent of each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with the inverse hyperbolic tangent values.
    fn float_atanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1264
    /// Returns a tensor with the four-quadrant inverse tangent values of `y` and `x`.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor with y coordinates.
    /// * `rhs` - The tensor with x coordinates.
    ///
    /// # Returns
    ///
    /// A tensor with the four-quadrant inverse tangent values.
    fn float_atan2(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
1276
    /// Returns a new tensor with rounded values.
    ///
    /// Implementations should follow the
    /// [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
1290
    /// Returns a new tensor with floored values.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1301
    /// Returns a new tensor with ceiled values.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be ceiled.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with ceiled values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1312
    /// Returns a new tensor with truncated values.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be truncated.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with truncated values.
    fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
1323
    /// Returns a new tensor with the error function applied to each element.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1334
1335    /// Concatenates tensors along a dimension.
1336    ///
1337    /// # Arguments
1338    ///
1339    /// * `tensors` - The tensors to concatenate.
1340    /// * `dim` - The dimension along which to concatenate.
1341    ///
1342    /// # Returns
1343    ///
1344    /// A tensor with the concatenated tensors along `dim`.
1345    ///
1346    /// # Note
1347    ///
1348    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
1349    /// high-level tensor API and will not be passed to this method. Backend implementations do
1350    /// not need to handle empty tensors.
1351    fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1352        cat_with_slice_assign::<B, Float>(
1353            tensors.into_iter().map(TensorPrimitive::Float).collect(),
1354            dim,
1355        )
1356        .tensor()
1357    }
1358
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    /// * `out_dtype` - The dtype of the returned index tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1371
    /// Gets the indices of the `k` maximum elements of a tensor along an axis.
    ///
    /// When two elements are equal, the one with the lower index is ordered first.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    /// * `k` - The number of maximum elements to select.
    /// * `out_dtype` - The dtype of the returned index tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the `k` maximum elements of `tensor` along `dim`.
    fn float_argtopk(
        tensor: FloatTensor<B>,
        dim: usize,
        k: usize,
        out_dtype: IntDType,
    ) -> IntTensor<B>;
1391
    /// Gets the values of the `k` maximum elements of a tensor along an axis.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    /// * `k` - The number of maximum elements to select.
    ///
    /// # Returns
    ///
    /// A tensor with the values of the `k` maximum elements of `tensor` along `dim`.
    fn float_topk(tensor: FloatTensor<B>, dim: usize, k: usize) -> FloatTensor<B>;
1405
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    /// * `out_dtype` - The dtype of the returned index tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1418
1419    /// Gets the maximum element of a tensor.
1420    ///
1421    /// # Arguments
1422    ///
1423    /// * `tensor` - The tensor to get the maximum elements of.
1424    ///
1425    /// # Returns
1426    ///
1427    /// A tensor with the maximum element of `tensor`.
1428    fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1429        let shape = tensor.shape();
1430        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1431
1432        B::float_max_dim(tensor, 0)
1433    }
1434
1435    /// Gets the maximum elements of a tensor along an axis.
1436    ///
1437    /// # Arguments
1438    ///
1439    /// * `tensor` - The tensor to get the maximum elements of.
1440    /// * `dim` - The dimension along which to get the maximum elements.
1441    ///
1442    /// # Returns
1443    ///
1444    /// A tensor with the maximum elements of `tensor` along `dim`.
1445    fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1446        let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1447        let index = B::float_argmax(tensor.clone(), dim, dtype);
1448
1449        B::float_gather(dim, tensor, index)
1450    }
1451
1452    /// Gets the maximum elements of a tensor along an axis and their indices.
1453    ///
1454    /// # Arguments
1455    ///
1456    /// * `tensor` - The tensor to get the maximum elements of.
1457    /// * `dim` - The dimension along which to get the maximum elements.
1458    /// * `indices_dtype` - The indices tensor dtype.
1459    ///
1460    /// # Returns
1461    ///
1462    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1463    fn float_max_dim_with_indices(
1464        tensor: FloatTensor<B>,
1465        dim: usize,
1466        indices_dtype: IntDType,
1467    ) -> (FloatTensor<B>, IntTensor<B>) {
1468        let index = B::float_argmax(tensor.clone(), dim, indices_dtype);
1469        let values = B::float_gather(dim, tensor, index.clone());
1470
1471        (values, index)
1472    }
1473
1474    /// Gets the minimum element of a tensor.
1475    ///
1476    /// # Arguments
1477    ///
1478    /// * `tensor` - The tensor to get the minimum elements of.
1479    ///
1480    /// # Returns
1481    ///
1482    /// A tensor with the minimum element of `tensor`.
1483    fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1484        let shape = tensor.shape();
1485        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1486
1487        B::float_min_dim(tensor, 0)
1488    }
1489
1490    /// Gets the minimum elements of a tensor along an axis.
1491    ///
1492    /// # Arguments
1493    ///
1494    /// * `tensor` - The tensor to get the minimum elements of.
1495    /// * `dim` - The dimension along which to get the minimum elements.
1496    ///
1497    /// # Returns
1498    ///
1499    /// A tensor with the minimum elements of `tensor` along `dim`.
1500    fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1501        let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1502        let index = B::float_argmin(tensor.clone(), dim, dtype);
1503
1504        B::float_gather(dim, tensor, index)
1505    }
1506
1507    /// Gets the minimum elements of a tensor along an axis and their indices.
1508    ///
1509    /// # Arguments
1510    ///
1511    /// * `tensor` - The tensor to get the minimum elements of.
1512    /// * `dim` - The dimension along which to get the minimum elements.
1513    /// * `indices_dtype` - The indices tensor dtype.
1514    ///
1515    /// # Returns
1516    ///
1517    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1518    fn float_min_dim_with_indices(
1519        tensor: FloatTensor<B>,
1520        dim: usize,
1521        indices_dtype: IntDType,
1522    ) -> (FloatTensor<B>, IntTensor<B>) {
1523        let index = B::float_argmin(tensor.clone(), dim, indices_dtype);
1524        let values = B::float_gather(dim, tensor, index.clone());
1525
1526        (values, index)
1527    }
1528
1529    /// Gets the maximum absolute element of a tensor.
1530    ///
1531    /// # Arguments
1532    ///
1533    /// * `tensor` - The tensor to get the maximum elements of.
1534    ///
1535    /// # Returns
1536    ///
1537    /// A tensor with the maximum element of `tensor`.
1538    fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1539        let shape = tensor.shape();
1540        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1541
1542        B::float_max_abs_dim(tensor, 0)
1543    }
1544
1545    /// Gets the maximum absolute elements of a tensor along an axis.
1546    ///
1547    /// # Arguments
1548    ///
1549    /// * `tensor` - The tensor to get the maximum elements of.
1550    /// * `dim` - The dimension along which to get the maximum elements.
1551    ///
1552    /// # Returns
1553    ///
1554    /// A tensor with the maximum elements of `tensor` along `dim`.
1555    fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1556        B::float_max_dim(B::float_abs(tensor), dim)
1557    }
1558
1559    /// Tests if any element in the float `tensor` evaluates to True.
1560    ///
1561    /// # Arguments
1562    ///
1563    /// * `tensor` - The tensor to test.
1564    /// * `out_dtype` - The output tensor dtype.
1565    ///
1566    /// # Returns
1567    ///
1568    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1569    fn float_any(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1570        let float_dtype = tensor.dtype();
1571        let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1572        let bool_tensor = B::bool_not(bool_tensor);
1573        let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1574        B::float_greater_elem(sum, 0f32.into(), out_dtype)
1575    }
1576
1577    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1578    ///
1579    /// # Arguments
1580    ///
1581    /// * `tensor` - The tensor to test.
1582    /// * `dim` - The axis along which to test.
1583    /// * `out_dtype` - The output tensor dtype.
1584    ///
1585    /// # Returns
1586    ///
1587    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1588    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1589    /// input evaluates to True, False otherwise.
1590    fn float_any_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1591        let float_dtype = tensor.dtype();
1592        let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1593        let bool_tensor = B::bool_not(bool_tensor);
1594        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1595        B::float_greater_elem(sum, 0f32.into(), out_dtype)
1596    }
1597
1598    /// Tests if all elements in the float `tensor` evaluate to True.
1599    ///
1600    /// # Arguments
1601    ///
1602    /// * `tensor` - The tensor to test.
1603    /// * `out_dtype` - The output tensor dtype.
1604    ///
1605    /// # Returns
1606    ///
1607    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1608    /// evaluate to True, False otherwise.
1609    fn float_all(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1610        let float_dtype = tensor.dtype();
1611        let num_elems = tensor.shape().num_elements() as f32;
1612        let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1613        let bool_tensor = B::bool_not(bool_tensor);
1614        let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1615        B::float_equal_elem(sum, num_elems.into(), out_dtype)
1616    }
1617
1618    /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1619    ///
1620    /// # Arguments
1621    ///
1622    /// * `tensor` - The tensor to test.
1623    /// * `dim` - The axis along which to test.
1624    /// * `out_dtype` - The output tensor dtype.
1625    ///
1626    /// # Returns
1627    ///
1628    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1629    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1630    /// evaluates to True, False otherwise.
1631    fn float_all_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1632        let float_dtype = tensor.dtype();
1633        let num_elems = tensor.shape()[dim] as f32;
1634        let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1635        let bool_tensor = B::bool_not(bool_tensor);
1636        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1637        B::float_equal_elem(sum, num_elems.into(), out_dtype)
1638    }
1639
1640    /// Returns the signs of the float `tensor`.
1641    ///
1642    /// # Arguments
1643    ///
1644    /// * `tensor` - The tensor to extract the signs from.
1645    ///
1646    /// # Returns
1647    ///
1648    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1649    fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1650        let device = B::float_device(&tensor);
1651        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
1652        let zeros = B::float_zeros(tensor.shape(), &device, tensor.dtype().into());
1653        let less_than_zero = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
1654        let greater_than_zero = B::float_greater_elem(tensor, 0f32.into(), bool_dtype);
1655
1656        let mut result = B::float_mask_fill(zeros, less_than_zero, (-1f32).into());
1657        result = B::float_mask_fill(result, greater_than_zero, 1f32.into());
1658        result
1659    }
1660
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The tensor broadcast to `shape`.
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1663
1664    /// Sort the elements of the input `tensor` by value in along a given dimension.
1665    ///
1666    /// This sort is unstable (i.e., may reorder equal elements).
1667    ///
1668    /// # Arguments
1669    ///
1670    /// * `tensor` - The input tensor.
1671    /// * `dim` - The axis along which to sort.
1672    /// * `descending` - The sorting order.
1673    ///
1674    /// # Returns
1675    ///
1676    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1677    fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1678        sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1679    }
1680
1681    /// Sort the elements of the input `tensor` by value in along a given dimension.
1682    ///
1683    /// This sort is unstable (i.e., may reorder equal elements).
1684    ///
1685    /// # Arguments
1686    ///
1687    /// * `tensor` - The input tensor.
1688    /// * `dim` - The axis along which to sort.
1689    /// * `descending` - The sorting order.
1690    /// * `indices_dtype` - The indices tensor dtype.
1691    ///
1692    /// # Returns
1693    ///
1694    /// A tensor with the same shape as the input tensor and corresponding indices, where
1695    /// the elements are sorted by value and the indices map back to the original input tensor.
1696    fn float_sort_with_indices(
1697        tensor: FloatTensor<B>,
1698        dim: usize,
1699        descending: bool,
1700        indices_dtype: IntDType,
1701    ) -> (FloatTensor<B>, IntTensor<B>) {
1702        let (values, indices) = sort_with_indices::<B, Float>(
1703            TensorPrimitive::Float(tensor),
1704            dim,
1705            descending,
1706            indices_dtype,
1707        );
1708        (values.tensor(), indices)
1709    }
1710
1711    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1712    ///
1713    /// This sort is unstable (i.e., may reorder equal elements).
1714    ///
1715    /// # Arguments
1716    ///
1717    /// * `tensor` - The input tensor.
1718    /// * `dim` - The axis along which to sort.
1719    /// * `descending` - The sorting order.
1720    /// * `out_dtype` - The output tensor dtype.
1721    ///
1722    /// # Returns
1723    ///
1724    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1725    fn float_argsort(
1726        tensor: FloatTensor<B>,
1727        dim: usize,
1728        descending: bool,
1729        out_dtype: IntDType,
1730    ) -> IntTensor<B> {
1731        argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending, out_dtype)
1732    }
1733
    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// The default implementation delegates to the reference implementation
    /// `float_grid_sample_2d_ref`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    fn float_grid_sample_2d(
        tensor: FloatTensor<B>,
        grid: FloatTensor<B>,
        options: GridSampleOptions,
    ) -> FloatTensor<B> {
        // TODO: default impl should get int default dtype
        float_grid_sample_2d_ref::<B>(tensor, grid, options)
    }
1755
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
    ///
    /// NOTE(review): for "all complete windows" one would expect
    /// `(shape[dim] - size) / step + 1` windows when `shape[dim] >= size`
    /// (and 0 otherwise); confirm the formula above against the backend implementations.
    ///
    /// This is a required operation: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - The selected dim.
    /// * `size` - The size of each unfolded window.
    /// * `step` - The step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
    -> FloatTensor<B>;
1775
1776    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1777    ///
1778    /// # Returns
1779    ///
1780    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1781    fn float_is_nan(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1782        // Check if the input tensor is NaN by comparing it to itself
1783        // NaN is the only value that is not equal to itself
1784        B::float_not_equal(tensor.clone(), tensor, out_dtype)
1785    }
1786
1787    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1788    ///
1789    /// # Returns
1790    ///
1791    /// A boolean tensor where `true` indicates that the value is infinite
1792    fn float_is_inf(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1793        B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.into(), out_dtype)
1794    }
1795}