burn_backend/backend/ops/
tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_ref;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::sort::{argsort, sort, sort_with_indices};
5use crate::ops::GridSampleOptions;
6use crate::tensor::{BoolTensor, Device, Float, FloatElem, FloatTensor, IntElem, IntTensor};
7use crate::{Backend, Distribution, TensorData, element::ElementConversion};
8use crate::{ExecutionError, TensorMetadata, TensorPrimitive};
9use alloc::vec::Vec;
10use burn_std::{FloatDType, Shape, Slice};
11
12/// Operations on float tensors.
13pub trait FloatTensorOps<B: Backend> {
    /// Creates a new float tensor from a [`TensorData`] value.
    ///
    /// # Arguments
    ///
    /// * `data` - The data the tensor is built from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor holding the given data.
    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;

    /// Creates a new tensor with random values.
    ///
    /// Unlike [`Self::float_zeros`], this takes no `dtype` argument, so the resulting
    /// float type is chosen by the backend implementation.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution every element is sampled from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and randomly sampled values.
    fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
    -> FloatTensor<B>;
39
40    /// Creates a new tensor with zeros.
41    ///
42    /// # Arguments
43    ///
44    /// * `shape` - The shape of the tensor.
45    /// * `device` - The device to create the tensor on.
46    /// * `dtype` - The target data type.
47    ///
48    /// # Returns
49    ///
50    /// The tensor with the given shape and zeros.
51    fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
52        Self::float_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
53    }
54
55    /// Creates a new tensor with ones.
56    ///
57    /// # Arguments
58    ///
59    /// * `shape` - The shape of the tensor.
60    /// * `device` - The device to create the tensor on.
61    /// * `dtype` - The target data type.
62    ///
63    /// # Returns
64    ///
65    /// The tensor with the given shape and ones.
66    fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
67        Self::float_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
68    }
69
70    /// Creates a tensor filled with given value.
71    ///
72    /// # Arguments
73    ///
74    /// * `shape` - The shape of the tensor.
75    /// * `fill_value` - The value with which to fill the tensor.
76    /// * `device` - The device to create the tensor on.
77    /// * `dtype` - The target data type.
78    ///
79    /// # Returns
80    ///
81    /// The tensor filled with given value
82    fn float_full(
83        shape: Shape,
84        fill_value: FloatElem<B>,
85        device: &Device<B>,
86        dtype: FloatDType,
87    ) -> FloatTensor<B> {
88        Self::float_from_data(
89            TensorData::full_dtype(shape, fill_value, dtype.into()),
90            device,
91        )
92    }
93
    /// Converts the tensor to a [`TensorData`] value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to read.
    ///
    /// # Returns
    ///
    /// A `Send` future that resolves to the tensor's data, or to an [`ExecutionError`]
    /// if the data could not be produced.
    fn float_into_data(
        tensor: FloatTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;

    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to inspect.
    ///
    /// # Returns
    ///
    /// The device the tensor lives on.
    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;

    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to move.
    /// * `device` - The destination device.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;

    /// Converts a float tensor to an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    ///
    /// # Returns
    ///
    /// The int tensor converted from the float tensor's values.
    /// NOTE(review): how fractional values are rounded is left to the backend — confirm.
    fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;

    /// Creates an empty tensor with the given shape.
    ///
    /// This trait provides no default implementation and does not specify the initial
    /// element values, so callers should not rely on them.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape.
    fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
153
154    /// Repeat the tensor along the given dimension.
155    ///
156    /// # Arguments
157    ///
158    /// * `tensor` - The tensor.
159    /// * `dim` - The dimension to repeat.
160    /// * `times` - The number of times to repeat the dimension.
161    ///
162    /// # Returns
163    ///
164    /// The tensor with the given dimension repeated.
165    fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
166        repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
167    }
168
    /// Adds two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of `lhs + rhs`.
    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

    /// Adds a scalar to every element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of adding `rhs` to each element of `lhs`.
    fn float_add_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
192
193    /// Clamps a tensor under a minimum value.
194    ///
195    /// # Arguments
196    ///
197    /// * `tensor` - The tensor to clamp.
198    /// * `min` - The minimum value.
199    ///
200    /// # Returns
201    ///
202    /// The clamped tensor.
203    fn float_clamp_min(tensor: FloatTensor<B>, min: FloatElem<B>) -> FloatTensor<B> {
204        // Default implementation
205        let mask = Self::float_lower_elem(tensor.clone(), min);
206        B::float_mask_fill(tensor, mask, min)
207    }
208
209    /// Clamps a tensor over a maximum value.
210    ///
211    /// # Arguments
212    ///
213    /// * `tensor` - The tensor to clamp.
214    /// * `max` - The maximum value.
215    ///
216    /// # Returns
217    ///
218    /// The clamped tensor.
219    fn float_clamp_max(tensor: FloatTensor<B>, max: FloatElem<B>) -> FloatTensor<B> {
220        // Default implementation
221        let mask = Self::float_greater_elem(tensor.clone(), max);
222        B::float_mask_fill(tensor, mask, max)
223    }
224
225    /// Clamps a tensor between a minimum and maximum value.
226    ///
227    /// # Arguments
228    ///
229    /// * `tensor` - The tensor to clamp.
230    /// * `min` - The minimum value.
231    /// * `max` - The maximum value.
232    ///
233    /// # Returns
234    ///
235    /// The clamped tensor.
236    fn float_clamp(tensor: FloatTensor<B>, min: FloatElem<B>, max: FloatElem<B>) -> FloatTensor<B> {
237        // Default implementation
238        Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
239    }
240
    /// Subtracts two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of `lhs - rhs`.
    fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

    /// Subtracts a scalar from every element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of subtracting `rhs` from each element of `lhs`.
    fn float_sub_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;

    /// Multiplies two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The element-wise product `lhs * rhs`.
    fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

    /// Multiplies every element of a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of multiplying each element of `lhs` by `rhs`.
    fn float_mul_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;

    /// Divides two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The result of `lhs / rhs`.
    fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

    /// Divides every element of a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The result of dividing each element of `lhs` by `rhs`.
    fn float_div_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;

    /// Computes the remainder of division between two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The element-wise remainder when dividing `lhs` by `rhs`.
    fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

    /// Computes the element-wise remainder of dividing a tensor by a scalar.
    ///
    /// NOTE(review): the sign convention for negative operands is defined by the backend
    /// implementations, not by this trait — confirm before relying on it.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (dividend).
    /// * `rhs` - The right-hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The remainder of dividing each element of `lhs` by `rhs`.
    fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;

    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The matrix product of the two tensors.
    fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

    /// Computes the cross product of two tensors along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `dim` - The dimension to compute the cross product along.
    ///
    /// # Returns
    ///
    /// The cross product of the two tensors.
    fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
351
352    /// Negates a tensor element-wise.
353    fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
354        Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
355    }
356
    /// Calculates the reciprocal (`1 / x`) of every element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements are inverted.
    ///
    /// # Returns
    ///
    /// The element-wise reciprocal of `tensor`.
    fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
359
360    /// Transposes a tensor.
361    ///
362    /// # Arguments
363    ///
364    /// * `tensor` - The tensor to transpose.
365    ///
366    /// # Returns
367    ///
368    /// The transposed tensor.
369    fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
370        let ndims = tensor.shape().num_dims();
371        Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
372    }
373
    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;

    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;

    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed along every axis in `axes`.
    fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;

    /// Reshapes a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;

    /// Gathers elements from a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor to gather from.
    /// * `indices` - The indices to gather.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;

    /// Scatters elements into a tensor using sum reduction.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter into.
    /// * `tensor` - The tensor to scatter into.
    /// * `indices` - The indices to scatter into.
    /// * `value` - The values to scatter; they are summed into the targeted positions.
    ///
    /// # Returns
    ///
    /// The tensor with the scattered elements.
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<B>,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;

    /// Selects tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;

    /// Accumulates `value` into the elements of `tensor` selected along `dim` by `indices`,
    /// using sum reduction.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `dim` - The dimension along which `indices` select.
    /// * `indices` - The indices to select.
    /// * `value` - The values added to the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with `value` summed into the selected elements.
    fn float_select_add(
        tensor: FloatTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;

    /// Selects tensor elements corresponding to the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    ///
    /// # Note
    ///
    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
    /// be passed to this method. Backend implementations do not need to handle empty slices.
    fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;

    /// Assigns `value` to the elements of `tensor` selected by the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `slices` - The slices selecting the region to overwrite.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    ///
    /// # Note
    ///
    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
    /// high-level tensor API and will not be passed to this method. Backend implementations do
    /// not need to handle empty slice assignments.
    fn float_slice_assign(
        tensor: FloatTensor<B>,
        slices: &[Slice],
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;

    /// Updates the given tensor with the value tensor where the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The tensor whose elements are taken where `mask` is true.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements taken from `value`.
    fn float_mask_where(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;

    /// Updates the given tensor with a scalar value where the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The scalar written to the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements set to `value`.
    fn float_mask_fill(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatElem<B>,
    ) -> FloatTensor<B>;

    /// Element-wise equality comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
570
571    /// Element-wise non-equality comparison.
572    ///
573    /// # Arguments
574    ///
575    /// * `lhs` - The left-hand side tensor.
576    /// * `rhs` - The right-hand side tensor.
577    ///
578    /// # Returns
579    ///
580    /// A boolean tensor with the result of the comparison.
581    fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {
582        let equal_tensor = B::float_equal(lhs, rhs);
583        B::bool_not(equal_tensor)
584    }
585
    /// Element-wise equality comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
597
598    /// Element-wise non-equality comparison with a scalar.
599    ///
600    /// # Arguments
601    ///
602    /// * `lhs` - The left-hand side tensor.
603    /// * `rhs` - The right-hand side scalar.
604    ///
605    /// # Returns
606    ///
607    /// A boolean tensor with the result of the comparison.
608    fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B> {
609        let equal_tensor = B::float_equal_elem(lhs, rhs);
610        B::bool_not(equal_tensor)
611    }
612
    /// Element-wise greater-than comparison (`lhs > rhs`) of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;

    /// Element-wise greater-than comparison (`lhs > rhs`) of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;

    /// Element-wise greater-than-or-equal comparison (`lhs >= rhs`) of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;

    /// Element-wise greater-than-or-equal comparison (`lhs >= rhs`) of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;

    /// Element-wise less-than comparison (`lhs < rhs`) of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;

    /// Element-wise less-than comparison (`lhs < rhs`) of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;

    /// Element-wise less-than-or-equal comparison (`lhs <= rhs`) of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;

    /// Element-wise less-than-or-equal comparison (`lhs <= rhs`) of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
708
709    /// Detaches a tensor from the computation graph.
710    fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
711        // Should only be overridden by autodiff backends.
712        tensor
713    }
714
715    /// Sets the `require_grad` flag of a tensor.
716    fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
717        // Should only be overridden by autodiff backends.
718        tensor
719    }
720
721    /// Returns the `require_grad` flag of a tensor.
722    fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
723        // Should only be overridden by autodiff backends.
724        false
725    }
726
    /// Sum of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the sum of all elements in `tensor`.
    fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Sum of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension along which to sum.
    ///
    /// # Returns
    ///
    /// A tensor with the sum of all elements in `tensor` along `dim`.
    fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
749
750    /// Product of all elements in a tensor.
751    ///
752    /// # Arguments
753    ///
754    /// * `tensor` - The tensor to product.
755    ///
756    /// # Returns
757    ///
758    /// A scalar tensor with the product of all elements in `tensor`.
759    fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
760        // Product of all elements in a tensor
761        B::float_exp(B::float_sum(B::float_log(tensor)))
762    }
763
764    /// Product of all elements in a tensor along a dimension.
765    ///
766    /// # Arguments
767    ///
768    /// * `tensor` - The tensor to product.
769    ///
770    /// # Returns
771    ///
772    /// A tensor with the product of all elements in `tensor` along `dim`.
773    fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
774        // Product of all elements in a tensor along a dimension
775        B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
776    }
777
778    /// Mean of all elements in a tensor.
779    ///
780    /// # Arguments
781    ///
782    /// * `tensor` - The tensor to mean.
783    ///
784    /// # Returns
785    ///
786    /// A scalar tensor with the mean of all elements in `tensor`.
787    fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
788        let num_elems = tensor.shape().num_elements();
789        B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
790    }
791
    /// Mean of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to average.
    /// * `dim` - The dimension along which to average.
    ///
    /// # Returns
    ///
    /// A tensor with the mean of all elements in `tensor` along `dim`.
    fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;

    /// Computes the cumulative sum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative sum of.
    /// * `dim` - The dimension along which to compute the cumulative sum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative sum
    /// of all elements up to and including that position along the dimension.
    fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;

    /// Computes the cumulative product of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative product of.
    /// * `dim` - The dimension along which to compute the cumulative product.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the cumulative product
    /// of all elements up to and including that position along the dimension.
    fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;

    /// Computes the cumulative minimum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative minimum of.
    /// * `dim` - The dimension along which to compute the cumulative minimum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the minimum
    /// of all elements up to and including that position along the dimension.
    fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;

    /// Computes the cumulative maximum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative maximum of.
    /// * `dim` - The dimension along which to compute the cumulative maximum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape where each element is the maximum
    /// of all elements up to and including that position along the dimension.
    fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;

    /// Converts a tensor to another floating point data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` (subject to the target type's
    /// precision) stored in the target floating point data type.
    fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;

    /// Returns a new tensor with `e^x` applied element-wise.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with exponential values.
    fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the natural logarithm applied element-wise.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with `ln(1 + x)` applied element-wise.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` holding `ln(1 + x)` for each element `x`.
    fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Element-wise power with a float exponent tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The base tensor.
    /// * `rhs` - The exponent tensor.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
912
913    /// Element-wise power with an IntTensor.
914    ///
915    /// # Arguments
916    ///
917    /// * `lhs` - The left-hand side tensor.
918    /// * `rhs` - The right-hand side floatTensor.
919    ///
920    /// # Returns
921    ///
922    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
923    fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
924        Self::float_powf(lhs, B::int_into_float(rhs))
925    }
926
927    /// Raises a tensor to the power of an int scalar.
928    ///
929    /// # Backend Implementors Note
930    ///
931    /// A number of common exponent cases can be implemented with operations
932    /// which are much cheaper than generic exponentiation.
933    ///
934    /// This (`Backend` impl overridable) operation handles generic optimizations
935    /// for several common integer exponent cases; and then dispatches to
936    /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
937    /// operation to handle the generic case.
938    ///
939    /// # Arguments
940    ///
941    /// * `lhs` - The left-hand side tensor.
942    /// * `rhs` - The right-hand side scalar.
943    ///
944    /// # Returns
945    ///
946    /// The elements of `lhs` raised to the value of `rhs`.
947    fn float_powi_scalar(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
948        let exp = rhs.elem::<i32>();
949        match exp {
950            0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
951            1 => lhs,
952            2 => B::float_mul(lhs.clone(), lhs),
953            -1 => Self::float_recip(lhs),
954            -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
955            _ => Self::float_powi_scalar_impl(lhs, rhs),
956        }
957    }
958
959    /// Raises a tensor to the power of an int scalar.
960    ///
961    /// # Backend Implementors Note
962    ///
963    /// This is the generic implementation of integer exponentiation
964    /// called by [`Self::float_powi_scalar`] in the fallback case.
965    ///
966    /// As a general rule, this should not be called directly.
967    ///
968    /// # Arguments
969    ///
970    /// * `lhs` - The left-hand side tensor.
971    /// * `rhs` - The right-hand side scalar.
972    ///
973    /// # Returns
974    ///
975    /// The elements of `lhs` raised to the value of `rhs`.
976    fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
977        // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.
978        Self::float_powf_scalar_impl(lhs, rhs.elem::<f32>())
979    }
980
981    /// Returns a new tensor with values raised to the power of float `value`.
982    ///
983    /// # Backend Implementors Note
984    ///
985    /// This (`Backend` impl overridable) operation dispatches integer exponentiation
986    /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
987    /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
988    /// operation to handle the generic case.
989    ///
990    /// # Arguments
991    ///
992    /// * `tensor` - The tensor to exponentiate.
993    /// * `value` - The exponent.
994    ///
995    /// # Returns
996    ///
997    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
998    fn float_powf_scalar(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B> {
999        if num_traits::Float::floor(value) == value {
1000            // When the exponent is an integer, use the integer exponentiation implementation.
1001            let exp = B::IntElem::from_elem(value as i32);
1002            Self::float_powi_scalar(tensor, exp)
1003        } else {
1004            Self::float_powf_scalar_impl(tensor, value)
1005        }
1006    }
1007
    /// Returns a new tensor with values raised to the power of float `value`.
    ///
    /// # Backend Implementors Note
    ///
    /// This is the generic (float exponent) implementation of scalar exponentiation.
    /// [`Self::float_powf_scalar`] calls it for the exponents it does not route to
    /// [`Self::float_powi_scalar`]. The default [`Self::float_powi_scalar_impl`] also
    /// defers to it, with the integer exponent converted to `f32`.
    ///
    /// This is the minimal required support a `Backend` must implement
    /// for exponentiation.
    ///
    /// As a general rule, this should not be called directly.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B>;
1029
    /// Returns a new tensor with the element-wise square root.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise absolute value.
    ///
    /// The default implementations of [`Self::float_max_abs_dim`] and
    /// [`Self::float_is_inf`] are built on this operation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise cosine.
    ///
    /// The default implementation of [`Self::float_tan`] is built on this operation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise sine.
    ///
    /// The default implementation of [`Self::float_tan`] is built on this operation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1073
1074    /// Returns a new tensor with tangent values.
1075    ///
1076    /// # Arguments
1077    ///
1078    /// * `tensor` - The tensor to take the tangent of.
1079    ///
1080    /// # Returns
1081    ///
1082    /// A tensor with the same shape as `tensor` with tangent values.
1083    fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B> {
1084        let sin = B::float_sin(tensor.clone());
1085        let cos = B::float_cos(tensor);
1086        B::float_div(sin, cos)
1087    }
1088
1089    /// Returns a new tensor with hyperbolic cosine values.
1090    ///
1091    /// # Arguments
1092    ///
1093    /// * `tensor` - The tensor to take the hyperbolic cosine of.
1094    ///
1095    /// # Returns
1096    ///
1097    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
1098    fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1099        // cosh = ( e^x + e^(-x) ) / 2
1100        let e_x = B::float_exp(tensor.clone());
1101        let e_neg_x = B::float_exp(B::float_neg(tensor));
1102        let num = B::float_add(e_x, e_neg_x); // e^x + e^(-x)
1103        B::float_div_scalar(num, 2.0.elem())
1104    }
1105
1106    /// Returns a new tensor with hyperbolic sine values.
1107    ///
1108    /// # Arguments
1109    ///
1110    /// * `tensor` - The tensor to take the hyperbolic sine of.
1111    ///
1112    /// # Returns
1113    ///
1114    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1115    fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1116        // sinh = ( e^x - e^(-x) ) / 2
1117        let e_x = B::float_exp(tensor.clone());
1118        let e_neg_x = B::float_exp(B::float_neg(tensor));
1119        let num = B::float_sub(e_x, e_neg_x); // e^x - e^(-x)
1120        B::float_div_scalar(num, 2.0.elem())
1121    }
1122
1123    /// Returns a new tensor with hyperbolic tangent values.
1124    ///
1125    /// # Arguments
1126    ///
1127    /// * `tensor` - The tensor to take the hyperbolic tangent of.
1128    ///
1129    /// # Returns
1130    ///
1131    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1132    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1133        let sinh = B::float_sinh(tensor.clone());
1134        let cosh = B::float_cosh(tensor);
1135        B::float_div(sinh, cosh)
1136    }
1137
    /// Returns a new tensor with each element rounded to the nearest integer value.
    ///
    /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value
    /// (e.g. `0.5 -> 0`, `1.5 -> 2`, `2.5 -> 2`).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with each element rounded down (floor).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with each element rounded up (ceiling).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be ceiled.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with ceiled values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the fractional part of each element discarded (truncation).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be truncated.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with truncated values.
    fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise error function (`erf`).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1195
1196    /// Concatenates tensors along a dimension.
1197    ///
1198    /// # Arguments
1199    ///
1200    /// * `tensors` - The tensors to concatenate.
1201    /// * `dim` - The dimension along which to concatenate.
1202    ///
1203    /// # Returns
1204    ///
1205    /// A tensor with the concatenated tensors along `dim`.
1206    ///
1207    /// # Note
1208    ///
1209    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
1210    /// high-level tensor API and will not be passed to this method. Backend implementations do
1211    /// not need to handle empty tensors.
1212    fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1213        cat_with_slice_assign::<B, Float>(
1214            tensors.into_iter().map(TensorPrimitive::Float).collect(),
1215            dim,
1216        )
1217        .tensor()
1218    }
1219
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// The default implementations of [`Self::float_max_dim`] and
    /// [`Self::float_max_dim_with_indices`] gather values using these indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;

    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// The default implementations of [`Self::float_min_dim`] and
    /// [`Self::float_min_dim_with_indices`] gather values using these indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1243
1244    /// Gets the maximum element of a tensor.
1245    ///
1246    /// # Arguments
1247    ///
1248    /// * `tensor` - The tensor to get the maximum elements of.
1249    ///
1250    /// # Returns
1251    ///
1252    /// A tensor with the maximum element of `tensor`.
1253    fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1254        let shape = tensor.shape();
1255        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1256
1257        B::float_max_dim(tensor, 0)
1258    }
1259
1260    /// Gets the maximum elements of a tensor along an axis.
1261    ///
1262    /// # Arguments
1263    ///
1264    /// * `tensor` - The tensor to get the maximum elements of.
1265    /// * `dim` - The dimension along which to get the maximum elements.
1266    ///
1267    /// # Returns
1268    ///
1269    /// A tensor with the maximum elements of `tensor` along `dim`.
1270    fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1271        let index = B::float_argmax(tensor.clone(), dim);
1272
1273        B::float_gather(dim, tensor, index)
1274    }
1275
1276    /// Gets the maximum elements of a tensor along an axis and their indices.
1277    ///
1278    /// # Arguments
1279    ///
1280    /// * `tensor` - The tensor to get the maximum elements of.
1281    /// * `dim` - The dimension along which to get the maximum elements.
1282    ///
1283    /// # Returns
1284    ///
1285    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1286    fn float_max_dim_with_indices(
1287        tensor: FloatTensor<B>,
1288        dim: usize,
1289    ) -> (FloatTensor<B>, IntTensor<B>) {
1290        let index = B::float_argmax(tensor.clone(), dim);
1291        let values = B::float_gather(dim, tensor, index.clone());
1292
1293        (values, index)
1294    }
1295
1296    /// Gets the minimum element of a tensor.
1297    ///
1298    /// # Arguments
1299    ///
1300    /// * `tensor` - The tensor to get the minimum elements of.
1301    ///
1302    /// # Returns
1303    ///
1304    /// A tensor with the minimum element of `tensor`.
1305    fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1306        let shape = tensor.shape();
1307        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1308
1309        B::float_min_dim(tensor, 0)
1310    }
1311
1312    /// Gets the minimum elements of a tensor along an axis.
1313    ///
1314    /// # Arguments
1315    ///
1316    /// * `tensor` - The tensor to get the minimum elements of.
1317    /// * `dim` - The dimension along which to get the minimum elements.
1318    ///
1319    /// # Returns
1320    ///
1321    /// A tensor with the minimum elements of `tensor` along `dim`.
1322    fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1323        let index = B::float_argmin(tensor.clone(), dim);
1324
1325        B::float_gather(dim, tensor, index)
1326    }
1327
1328    /// Gets the minimum elements of a tensor along an axis and their indices.
1329    ///
1330    /// # Arguments
1331    ///
1332    /// * `tensor` - The tensor to get the minimum elements of.
1333    /// * `dim` - The dimension along which to get the minimum elements.
1334    ///
1335    /// # Returns
1336    ///
1337    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1338    fn float_min_dim_with_indices(
1339        tensor: FloatTensor<B>,
1340        dim: usize,
1341    ) -> (FloatTensor<B>, IntTensor<B>) {
1342        let index = B::float_argmin(tensor.clone(), dim);
1343        let values = B::float_gather(dim, tensor, index.clone());
1344
1345        (values, index)
1346    }
1347
1348    /// Gets the maximum absolute element of a tensor.
1349    ///
1350    /// # Arguments
1351    ///
1352    /// * `tensor` - The tensor to get the maximum elements of.
1353    ///
1354    /// # Returns
1355    ///
1356    /// A tensor with the maximum element of `tensor`.
1357    fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1358        let shape = tensor.shape();
1359        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1360
1361        B::float_max_abs_dim(tensor, 0)
1362    }
1363
1364    /// Gets the maximum absolute elements of a tensor along an axis.
1365    ///
1366    /// # Arguments
1367    ///
1368    /// * `tensor` - The tensor to get the maximum elements of.
1369    /// * `dim` - The dimension along which to get the maximum elements.
1370    ///
1371    /// # Returns
1372    ///
1373    /// A tensor with the maximum elements of `tensor` along `dim`.
1374    fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1375        B::float_max_dim(B::float_abs(tensor), dim)
1376    }
1377
1378    /// Tests if any element in the float `tensor` evaluates to True.
1379    ///
1380    /// # Arguments
1381    ///
1382    /// * `tensor` - The tensor to test.
1383    ///
1384    /// # Returns
1385    ///
1386    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1387    fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {
1388        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1389        let bool_tensor = B::bool_not(bool_tensor);
1390        let sum = B::float_sum(B::bool_into_float(bool_tensor));
1391        B::float_greater_elem(sum, 0.0f32.elem())
1392    }
1393
1394    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1395    ///
1396    /// # Arguments
1397    ///
1398    /// * `tensor` - The tensor to test.
1399    /// * `dim` - The axis along which to test.
1400    ///
1401    /// # Returns
1402    ///
1403    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1404    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1405    /// input evaluates to True, False otherwise.
1406    fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1407        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1408        let bool_tensor = B::bool_not(bool_tensor);
1409        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1410        B::float_greater_elem(sum, 0.0f32.elem())
1411    }
1412
1413    /// Tests if all elements in the float `tensor` evaluate to True.
1414    ///
1415    /// # Arguments
1416    ///
1417    /// * `tensor` - The tensor to test.
1418    ///
1419    /// # Returns
1420    ///
1421    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1422    /// evaluate to True, False otherwise.
1423    fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
1424        let num_elems = tensor.shape().num_elements();
1425        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1426        let bool_tensor = B::bool_not(bool_tensor);
1427        let sum = B::float_sum(B::bool_into_float(bool_tensor));
1428        B::float_equal_elem(sum, (num_elems as f32).elem())
1429    }
1430
1431    /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1432    ///
1433    /// # Arguments
1434    ///
1435    /// * `tensor` - The tensor to test.
1436    /// * `dim` - The axis along which to test.
1437    ///
1438    /// # Returns
1439    ///
1440    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1441    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1442    /// evaluates to True, False otherwise.
1443    fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1444        let num_elems = tensor.shape().dims[dim];
1445        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1446        let bool_tensor = B::bool_not(bool_tensor);
1447        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1448        B::float_equal_elem(sum, (num_elems as f32).elem())
1449    }
1450
1451    /// Returns the signs of the float `tensor`.
1452    ///
1453    /// # Arguments
1454    ///
1455    /// * `tensor` - The tensor to extract the signs from.
1456    ///
1457    /// # Returns
1458    ///
1459    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1460    fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1461        let zeros = B::float_zeros(
1462            tensor.shape(),
1463            &B::float_device(&tensor),
1464            tensor.dtype().into(),
1465        );
1466        let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
1467        let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
1468
1469        let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1470        result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
1471        result
1472    }
1473
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// A tensor with shape `shape`. Which input shapes are compatible is determined
    /// by the backend's broadcasting rules (NOTE(review): not specified here — confirm
    /// against the high-level `expand` documentation).
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1476
1477    /// Sort the elements of the input `tensor` by value in along a given dimension.
1478    ///
1479    /// This sort is unstable (i.e., may reorder equal elements).
1480    ///
1481    /// # Arguments
1482    ///
1483    /// * `tensor` - The input tensor.
1484    /// * `dim` - The axis along which to sort.
1485    /// * `descending` - The sorting order.
1486    ///
1487    /// # Returns
1488    ///
1489    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1490    fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1491        sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1492    }
1493
1494    /// Sort the elements of the input `tensor` by value in along a given dimension.
1495    ///
1496    /// This sort is unstable (i.e., may reorder equal elements).
1497    ///
1498    /// # Arguments
1499    ///
1500    /// * `tensor` - The input tensor.
1501    /// * `dim` - The axis along which to sort.
1502    /// * `descending` - The sorting order.
1503    ///
1504    /// # Returns
1505    ///
1506    /// A tensor with the same shape as the input tensor and corresponding indices, where
1507    /// the elements are sorted by value and the indices map back to the original input tensor.
1508    fn float_sort_with_indices(
1509        tensor: FloatTensor<B>,
1510        dim: usize,
1511        descending: bool,
1512    ) -> (FloatTensor<B>, IntTensor<B>) {
1513        let (values, indices) =
1514            sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
1515        (values.tensor(), indices)
1516    }
1517
1518    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1519    ///
1520    /// This sort is unstable (i.e., may reorder equal elements).
1521    ///
1522    /// # Arguments
1523    ///
1524    /// * `tensor` - The input tensor.
1525    /// * `dim` - The axis along which to sort.
1526    /// * `descending` - The sorting order.
1527    ///
1528    /// # Returns
1529    ///
1530    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1531    fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1532        argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
1533    }
1534
1535    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
1536    /// using the given locations in [-1, 1].
1537    ///
1538    /// # Arguments
1539    ///
1540    /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in)
1541    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
1542    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
1543    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
1544    ///
1545    /// # Returns
1546    ///
1547    /// A tensor with shape (N, C, H_out, W_out)
1548    fn float_grid_sample_2d(
1549        tensor: FloatTensor<B>,
1550        grid: FloatTensor<B>,
1551        options: GridSampleOptions,
1552    ) -> FloatTensor<B> {
1553        float_grid_sample_2d_ref::<B>(tensor, grid, options)
1554    }
1555
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// Windows start at indices `0, step, 2 * step, ...`, and only windows that fit
    /// entirely inside the dimension are kept. The number of windows is therefore
    /// `(shape[dim] - size) / step + 1` (integer division) when `shape[dim] >= size`,
    /// and `0` otherwise. For example, `shape[dim] = 5, size = 2, step = 1` gives 4 windows.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the selected dim.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
    -> FloatTensor<B>;
1575
1576    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1577    ///
1578    /// # Returns
1579    ///
1580    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1581    fn float_is_nan(tensor: FloatTensor<B>) -> BoolTensor<B> {
1582        // Check if the input tensor is NaN by comparing it to itself
1583        // NaN is the only value that is not equal to itself
1584        B::float_not_equal(tensor.clone(), tensor)
1585    }
1586
1587    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1588    ///
1589    /// # Returns
1590    ///
1591    /// A boolean tensor where `true` indicates that the value is infinite
1592    fn float_is_inf(tensor: FloatTensor<B>) -> BoolTensor<B> {
1593        B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.elem())
1594    }
1595}