burn_tensor/tensor/ops/tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor};
4use crate::tensor::cast::ToElement;
5use crate::{Distribution, ElementConversion, Float, TensorData, backend::Backend, tensor::Shape};
6use crate::{
7    FloatDType, TensorMetadata, TensorPrimitive, tensor::api::chunk, tensor::api::narrow,
8    tensor::api::split, tensor::api::split_with_sizes,
9};
10use alloc::vec::Vec;
11use core::future::Future;
12use core::ops::Range;
13
14use crate::{argsort, sort, sort_with_indices};
15
16/// Operations on float tensors.
17pub trait FloatTensorOps<B: Backend> {
    /// Creates a new tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure holding the tensor's values.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// A tensor on `device` containing the values of `data`.
    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
29
    /// Creates a new tensor filled with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution each value is sampled from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// A tensor with the given shape whose values are sampled from `distribution`.
    fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
    -> FloatTensor<B>;
43
44    /// Creates a new tensor with zeros.
45    ///
46    /// # Arguments
47    ///
48    /// * `shape` - The shape of the tensor.
49    /// * `device` - The device to create the tensor on.
50    ///
51    /// # Returns
52    ///
53    /// The tensor with the given shape and zeros.
54    fn float_zeros(shape: Shape, device: &Device<B>) -> FloatTensor<B> {
55        Self::float_from_data(TensorData::zeros::<FloatElem<B>, _>(shape), device)
56    }
57
58    /// Creates a new tensor with ones.
59    ///
60    /// # Arguments
61    ///
62    /// * `shape` - The shape of the tensor.
63    /// * `device` - The device to create the tensor on.
64    ///
65    /// # Returns
66    ///
67    /// The tensor with the given shape and ones.
68    fn float_ones(shape: Shape, device: &Device<B>) -> FloatTensor<B> {
69        Self::float_from_data(TensorData::ones::<FloatElem<B>, _>(shape), device)
70    }
71
72    /// Creates a tensor filled with given value.
73    ///
74    /// # Arguments
75    ///
76    /// * `shape` - The shape of the tensor.
77    /// * `fill_value` - The value with which to fill the tensor.
78    /// * `device` - The device to create the tensor on.
79    ///
80    /// # Returns
81    ///
82    /// The tensor filled with given value
83    fn float_full(shape: Shape, fill_value: FloatElem<B>, device: &Device<B>) -> FloatTensor<B> {
84        Self::float_add_scalar(Self::float_zeros(shape, device), fill_value)
85    }
86
    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to read.
    ///
    /// # Returns
    ///
    /// A future that resolves to the data structure holding the tensor's data. The future is
    /// `'static` and `Send`, so it holds no borrows from the caller.
    fn float_into_data(tensor: FloatTensor<B>)
    -> impl Future<Output = TensorData> + 'static + Send;
98
    /// Gets the device the tensor lives on.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to inspect.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
109
    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to move.
    /// * `device` - The destination device.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
121
    /// Converts a float tensor to an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to convert.
    ///
    /// # Returns
    ///
    /// An int tensor holding the values of `tensor` converted to the backend's int element type.
    fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;
132
    /// Creates an empty tensor with the given shape.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The empty tensor with the given shape.
    fn float_empty(shape: Shape, device: &Device<B>) -> FloatTensor<B>;
144
145    /// Repeat the tensor along the given dimension.
146    ///
147    /// # Arguments
148    ///
149    /// * `tensor` - The tensor.
150    /// * `dim` - The dimension to repeat.
151    /// * `times` - The number of times to repeat the dimension.
152    ///
153    /// # Returns
154    ///
155    /// The tensor with the given dimension repeated.
156    fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
157        repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
158    }
159
    /// Adds two tensors together.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of adding `lhs` and `rhs`.
    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
171
    /// Adds a scalar to every element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of adding the scalar to the tensor.
    fn float_add_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
183
184    /// Clamps a tensor under a minimum value.
185    ///
186    /// # Arguments
187    ///
188    /// * `tensor` - The tensor to clamp.
189    /// * `min` - The minimum value.
190    ///
191    /// # Returns
192    ///
193    /// The clamped tensor.
194    fn float_clamp_min(tensor: FloatTensor<B>, min: FloatElem<B>) -> FloatTensor<B> {
195        // Default implementation
196        let mask = Self::float_lower_elem(tensor.clone(), min);
197        B::float_mask_fill(tensor, mask, min)
198    }
199
200    /// Clamps a tensor over a maximum value.
201    ///
202    /// # Arguments
203    ///
204    /// * `tensor` - The tensor to clamp.
205    /// * `max` - The maximum value.
206    ///
207    /// # Returns
208    ///
209    /// The clamped tensor.
210    fn float_clamp_max(tensor: FloatTensor<B>, max: FloatElem<B>) -> FloatTensor<B> {
211        // Default implementation
212        let mask = Self::float_greater_elem(tensor.clone(), max);
213        B::float_mask_fill(tensor, mask, max)
214    }
215
216    /// Clamps a tensor between a minimum and maximum value.
217    ///
218    /// # Arguments
219    ///
220    /// * `tensor` - The tensor to clamp.
221    /// * `min` - The minimum value.
222    /// * `max` - The maximum value.
223    ///
224    /// # Returns
225    ///
226    /// The clamped tensor.
227    fn float_clamp(tensor: FloatTensor<B>, min: FloatElem<B>, max: FloatElem<B>) -> FloatTensor<B> {
228        // Default implementation
229        Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
230    }
231
    /// Subtracts two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of subtracting `rhs` from `lhs`.
    fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
243
    /// Subtracts a scalar from every element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of subtracting the scalar from the tensor.
    fn float_sub_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
255
    /// Multiplies two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The element-wise product of `lhs` and `rhs`.
    fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
258
    /// Multiplies every element of a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of multiplying the tensor by the scalar.
    fn float_mul_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
270
    /// Divides two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor (dividend).
    /// * `rhs` - The right hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The result of dividing `lhs` by `rhs`.
    fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
282
    /// Divides every element of a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor (dividend).
    /// * `rhs` - The right hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The result of dividing the tensor by the scalar.
    fn float_div_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
294
    /// Computes the remainder of division between two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor (dividend).
    /// * `rhs` - The right hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The element-wise remainder when dividing `lhs` by `rhs`.
    fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
306
    /// Computes the remainder of dividing every element of a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor (dividend).
    /// * `rhs` - The right hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The element-wise remainder when dividing `lhs` by `rhs`.
    fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
317
    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The matrix product of `lhs` and `rhs`.
    fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
329
330    /// Negates a tensor element-wise.
331    fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
332        Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
333    }
334
    /// Calculates the reciprocal of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements are inverted.
    ///
    /// # Returns
    ///
    /// A tensor where each element is `1 / x` for the corresponding element `x` of `tensor`.
    fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
337
338    /// Transposes a tensor.
339    ///
340    /// # Arguments
341    ///
342    /// * `tensor` - The tensor to transpose.
343    ///
344    /// # Returns
345    ///
346    /// The transposed tensor.
347    fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
348        let ndims = tensor.shape().num_dims();
349        Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
350    }
351
    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with dimensions `dim1` and `dim2` swapped.
    fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
364
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
375
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed along every axis in `axes`.
    fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
385
    /// Reshapes a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
397
    /// Gathers elements from a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor to gather from.
    /// * `indices` - The indices of the elements to gather.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
410
    /// Scatters elements into a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter into.
    /// * `tensor` - The tensor to scatter into.
    /// * `indices` - The indices to scatter into.
    /// * `value` - The values to scatter.
    ///
    /// # Returns
    ///
    /// The tensor with the scattered elements.
    fn float_scatter(
        dim: usize,
        tensor: FloatTensor<B>,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
429
    /// Selects tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
442
    /// Assigns the given value to the elements selected along the given dimension by the
    /// given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_select_assign(
        tensor: FloatTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
462
    /// Selects the tensor elements covered by the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `ranges` - The ranges to select, one per leading dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    fn float_slice(tensor: FloatTensor<B>, ranges: &[Range<usize>]) -> FloatTensor<B>;
474
    /// Assigns the given value to the tensor elements covered by the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `ranges` - The ranges to select.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_slice_assign(
        tensor: FloatTensor<B>,
        ranges: &[Range<usize>],
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
491
    /// Updates the given tensor with the value tensor wherever the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The tensor whose elements are written where `mask` is true.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements taken from `value`.
    fn float_mask_where(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
508
    /// Updates the given tensor with a scalar value wherever the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The scalar assigned to the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_mask_fill(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatElem<B>,
    ) -> FloatTensor<B>;
525
    /// Element-wise equality comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
537
538    /// Element-wise non-equality comparison.
539    ///
540    /// # Arguments
541    ///
542    /// * `lhs` - The left hand side tensor.
543    /// * `rhs` - The right hand side tensor.
544    ///
545    /// # Returns
546    ///
547    /// A boolean tensor with the result of the comparison.
548    fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {
549        let equal_tensor = B::float_equal(lhs, rhs);
550        B::bool_not(equal_tensor)
551    }
552
    /// Element-wise equality comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
564
565    /// Element-wise non-equality comparison with a scalar.
566    ///
567    /// # Arguments
568    ///
569    /// * `lhs` - The left hand side tensor.
570    /// * `rhs` - The right hand side scalar.
571    ///
572    /// # Returns
573    ///
574    /// A boolean tensor with the result of the comparison.
575    fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B> {
576        let equal_tensor = B::float_equal_elem(lhs, rhs);
577        B::bool_not(equal_tensor)
578    }
579
    /// Element-wise greater than comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of `lhs > rhs`.
    fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
591
    /// Element-wise greater than comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of `lhs > rhs`.
    fn float_greater_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
603
    /// Element-wise greater than or equal comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of `lhs >= rhs`.
    fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
615
    /// Element-wise greater than or equal comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of `lhs >= rhs`.
    fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
627
    /// Element-wise less than comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of `lhs < rhs`.
    fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
639
    /// Element-wise less than comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of `lhs < rhs`.
    fn float_lower_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
651
    /// Element-wise less than or equal comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of `lhs <= rhs`.
    fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
663
    /// Element-wise less than or equal comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of `lhs <= rhs`.
    fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
675
    /// Detaches a tensor from the computation graph.
    ///
    /// The default implementation returns `tensor` unchanged, since a backend without
    /// autodiff has no graph to detach from.
    fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
681
    /// Sets the `require_grad` flag of a tensor.
    ///
    /// The default implementation ignores `_require_grad` and returns `tensor` unchanged.
    fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
        // Should only be overridden by autodiff backends.
        tensor
    }
687
    /// Returns the `require_grad` flag of a tensor.
    ///
    /// The default implementation always returns `false`.
    fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
        // Should only be overridden by autodiff backends.
        false
    }
693
    /// Sum of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the sum of all elements in `tensor`.
    fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
704
    /// Sum of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension along which to sum.
    ///
    /// # Returns
    ///
    /// A tensor with the sum of all elements in `tensor` along `dim`.
    fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
716
717    /// Product of all elements in a tensor.
718    ///
719    /// # Arguments
720    ///
721    /// * `tensor` - The tensor to product.
722    ///
723    /// # Returns
724    ///
725    /// A scalar tensor with the product of all elements in `tensor`.
726    fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
727        // Product of all elements in a tensor
728        B::float_exp(B::float_sum(B::float_log(tensor)))
729    }
730
731    /// Product of all elements in a tensor along a dimension.
732    ///
733    /// # Arguments
734    ///
735    /// * `tensor` - The tensor to product.
736    ///
737    /// # Returns
738    ///
739    /// A tensor with the product of all elements in `tensor` along `dim`.
740    fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
741        // Product of all elements in a tensor along a dimension
742        B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
743    }
744
745    /// Mean of all elements in a tensor.
746    ///
747    /// # Arguments
748    ///
749    /// * `tensor` - The tensor to mean.
750    ///
751    /// # Returns
752    ///
753    /// A scalar tensor with the mean of all elements in `tensor`.
754    fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
755        let num_elems = tensor.shape().num_elements();
756        B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
757    }
758
    /// Arithmetic mean of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to average.
    /// * `dim` - The dimension along which to take the mean.
    ///
    /// # Returns
    ///
    /// A tensor with the mean of all elements in `tensor` along `dim`.
    fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
770
    /// Converts a tensor to another floating point data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target floating point data type.
    fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
782
    /// Returns a new tensor with exponential values (`e^x` for every element `x`).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with exponential values.
    fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
793
    /// Returns a new tensor with natural logarithm values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
804
    /// Returns a new tensor with the logarithm of `1 + x` for every element `x`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with logarithm values of `1 + x`.
    fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
815
    /// Element-wise power with a float tensor exponent.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The base tensor.
    /// * `rhs` - The tensor of exponents.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
827
828    /// Element-wise power with an IntTensor.
829    ///
830    /// # Arguments
831    ///
832    /// * `lhs` - The left hand side tensor.
833    /// * `rhs` - The right hand side floatTensor.
834    ///
835    /// # Returns
836    ///
837    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
838    fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
839        Self::float_powf(lhs, B::int_into_float(rhs))
840    }
841
842    /// raises a tensor to the power of an int scalar.
843    ///
844    /// # Arguments
845    ///
846    /// * `lhs` - The left hand side tensor.
847    /// * `rhs` - The right hand side scalar.
848    ///
849    /// # Returns
850    ///
851    /// The elements of `lhs` raised to the value of `rhs`.
852    fn float_powi_scalar(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
853        Self::float_powf_scalar(lhs, rhs.to_f32())
854    }
855
    /// Returns a new tensor with values raised to the power of float `value`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The base tensor.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn float_powf_scalar(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B>;
867
    /// Returns a new tensor with square root values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
878
    /// Returns a new tensor with absolute values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
889
    /// Returns a new tensor with cosine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
900
    /// Returns a new tensor with sine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
911
912    /// Returns a new tensor with tangent values.
913    ///
914    /// # Arguments
915    ///
916    /// * `tensor` - The tensor to take the tangent of.
917    ///
918    /// # Returns
919    ///
920    /// A tensor with the same shape as `tensor` with tangent values.
921    fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B> {
922        let sin = B::float_sin(tensor.clone());
923        let cos = B::float_cos(tensor);
924        B::float_div(sin, cos)
925    }
926
927    /// Returns a new tensor with hyperbolic cosine values.
928    ///
929    /// # Arguments
930    ///
931    /// * `tensor` - The tensor to take the hyperbolic cosine of.
932    ///
933    /// # Returns
934    ///
935    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
936    fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B> {
937        // cosh = ( e^x + e^(-x) ) / 2
938        let e_x = B::float_exp(tensor.clone());
939        let e_neg_x = B::float_exp(B::float_neg(tensor));
940        let num = B::float_add(e_x, e_neg_x); // e^x + e^(-x)
941        B::float_div_scalar(num, 2.0.elem())
942    }
943
944    /// Returns a new tensor with hyperbolic sine values.
945    ///
946    /// # Arguments
947    ///
948    /// * `tensor` - The tensor to take the hyperbolic sine of.
949    ///
950    /// # Returns
951    ///
952    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
953    fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B> {
954        // sinh = ( e^x - e^(-x) ) / 2
955        let e_x = B::float_exp(tensor.clone());
956        let e_neg_x = B::float_exp(B::float_neg(tensor));
957        let num = B::float_sub(e_x, e_neg_x); // e^x - e^(-x)
958        B::float_div_scalar(num, 2.0.elem())
959    }
960
961    /// Returns a new tensor with hyperbolic tangent values.
962    ///
963    /// # Arguments
964    ///
965    /// * `tensor` - The tensor to take the hyperbolic tangent of.
966    ///
967    /// # Returns
968    ///
969    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
970    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B> {
971        let sinh = B::float_sinh(tensor.clone());
972        let cosh = B::float_cosh(tensor);
973        B::float_div(sinh, cosh)
974    }
975
    /// Returns a new tensor with rounded values.
    ///
    /// Implementations must use the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy: halfway cases are rounded to the nearest even integer value
    /// (so `0.5 -> 0.0`, `1.5 -> 2.0`, `2.5 -> 2.0`).
    ///
    /// This is a required method; there is no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
989
    /// Returns a new tensor with floored values (the floor of each element).
    ///
    /// This is a required method; there is no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1000
    /// Returns a new tensor with ceiled values (the ceiling of each element).
    ///
    /// This is a required method; there is no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be ceiled.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with ceiled values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1011
    /// Returns a new tensor with the (Gauss) error function applied element-wise.
    ///
    /// This is a required method; there is no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1022
1023    /// Concatenates tensors along a dimension.
1024    ///
1025    /// # Arguments
1026    ///
1027    /// * `tensors` - The tensors to concatenate.
1028    /// * `dim` - The dimension along which to concatenate.
1029    ///
1030    /// # Returns
1031    ///
1032    /// A tensor with the concatenated tensors along `dim`.
1033    fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1034        cat_with_slice_assign::<B, Float>(
1035            tensors.into_iter().map(TensorPrimitive::Float).collect(),
1036            dim,
1037        )
1038        .tensor()
1039    }
1040
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// This is a required method. The default implementations of `float_max_dim` and
    /// `float_max_dim_with_indices` in this trait gather `tensor` along `dim` with the
    /// indices returned here, so they must be valid `float_gather` indices for `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1052
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// This is a required method. The default implementations of `float_min_dim` and
    /// `float_min_dim_with_indices` in this trait gather `tensor` along `dim` with the
    /// indices returned here, so they must be valid `float_gather` indices for `dim`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1064
1065    /// Gets the maximum element of a tensor.
1066    ///
1067    /// # Arguments
1068    ///
1069    /// * `tensor` - The tensor to get the maximum elements of.
1070    ///
1071    /// # Returns
1072    ///
1073    /// A tensor with the maximum element of `tensor`.
1074    fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1075        let shape = tensor.shape();
1076        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1077
1078        B::float_max_dim(tensor, 0)
1079    }
1080
1081    /// Gets the maximum elements of a tensor along an axis.
1082    ///
1083    /// # Arguments
1084    ///
1085    /// * `tensor` - The tensor to get the maximum elements of.
1086    /// * `dim` - The dimension along which to get the maximum elements.
1087    ///
1088    /// # Returns
1089    ///
1090    /// A tensor with the maximum elements of `tensor` along `dim`.
1091    fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1092        let index = B::float_argmax(tensor.clone(), dim);
1093
1094        B::float_gather(dim, tensor, index)
1095    }
1096
1097    /// Gets the maximum elements of a tensor along an axis and their indices.
1098    ///
1099    /// # Arguments
1100    ///
1101    /// * `tensor` - The tensor to get the maximum elements of.
1102    /// * `dim` - The dimension along which to get the maximum elements.
1103    ///
1104    /// # Returns
1105    ///
1106    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1107    fn float_max_dim_with_indices(
1108        tensor: FloatTensor<B>,
1109        dim: usize,
1110    ) -> (FloatTensor<B>, IntTensor<B>) {
1111        let index = B::float_argmax(tensor.clone(), dim);
1112        let values = B::float_gather(dim, tensor, index.clone());
1113
1114        (values, index)
1115    }
1116
1117    /// Gets the minimum element of a tensor.
1118    ///
1119    /// # Arguments
1120    ///
1121    /// * `tensor` - The tensor to get the minimum elements of.
1122    ///
1123    /// # Returns
1124    ///
1125    /// A tensor with the minimum element of `tensor`.
1126    fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1127        let shape = tensor.shape();
1128        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1129
1130        B::float_min_dim(tensor, 0)
1131    }
1132
1133    /// Gets the minimum elements of a tensor along an axis.
1134    ///
1135    /// # Arguments
1136    ///
1137    /// * `tensor` - The tensor to get the minimum elements of.
1138    /// * `dim` - The dimension along which to get the minimum elements.
1139    ///
1140    /// # Returns
1141    ///
1142    /// A tensor with the minimum elements of `tensor` along `dim`.
1143    fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1144        let index = B::float_argmin(tensor.clone(), dim);
1145
1146        B::float_gather(dim, tensor, index)
1147    }
1148
1149    /// Gets the minimum elements of a tensor along an axis and their indices.
1150    ///
1151    /// # Arguments
1152    ///
1153    /// * `tensor` - The tensor to get the minimum elements of.
1154    /// * `dim` - The dimension along which to get the minimum elements.
1155    ///
1156    /// # Returns
1157    ///
1158    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1159    fn float_min_dim_with_indices(
1160        tensor: FloatTensor<B>,
1161        dim: usize,
1162    ) -> (FloatTensor<B>, IntTensor<B>) {
1163        let index = B::float_argmin(tensor.clone(), dim);
1164        let values = B::float_gather(dim, tensor, index.clone());
1165
1166        (values, index)
1167    }
1168
1169    /// Gets the maximum absolute element of a tensor.
1170    ///
1171    /// # Arguments
1172    ///
1173    /// * `tensor` - The tensor to get the maximum elements of.
1174    ///
1175    /// # Returns
1176    ///
1177    /// A tensor with the maximum element of `tensor`.
1178    fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1179        let shape = tensor.shape();
1180        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1181
1182        B::float_max_abs_dim(tensor, 0)
1183    }
1184
1185    /// Gets the maximum absolute elements of a tensor along an axis.
1186    ///
1187    /// # Arguments
1188    ///
1189    /// * `tensor` - The tensor to get the maximum elements of.
1190    /// * `dim` - The dimension along which to get the maximum elements.
1191    ///
1192    /// # Returns
1193    ///
1194    /// A tensor with the maximum elements of `tensor` along `dim`.
1195    fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1196        B::float_max_dim(B::float_abs(tensor), dim)
1197    }
1198
1199    /// Returns a new tensor with the given dimension narrowed to the given range.
1200    ///
1201    /// # Arguments
1202    ///
1203    /// * `dim` - The dimension along which the tensor will be narrowed.
1204    /// * `start` - The starting point of the given range.
1205    /// * `length` - The ending point of the given range.
1206    /// # Panics
1207    ///
1208    /// - If the dimension is greater than the number of dimensions of the tensor.
1209    /// - If the given range exceeds the number of elements on the given dimension.
1210    ///
1211    /// # Returns
1212    ///
1213    /// A new tensor with the given dimension narrowed to the given range.
1214    fn float_narrow(
1215        tensor: FloatTensor<B>,
1216        dim: usize,
1217        start: usize,
1218        length: usize,
1219    ) -> FloatTensor<B> {
1220        narrow::<B, Float>(TensorPrimitive::Float(tensor), dim, start, length).tensor()
1221    }
1222
1223    /// Split the tensor along the given dimension into chunks.
1224    ///
1225    /// # Arguments
1226    ///
1227    /// * `tensor` - The tensor.
1228    /// * `chunks` - The number of chunks to be produced
1229    /// * `times` - The dimension along which the tensor will be split.
1230    ///
1231    /// # Returns
1232    ///
1233    /// A vector of tensors
1234    fn float_chunk(tensor: FloatTensor<B>, chunks: usize, dim: usize) -> Vec<FloatTensor<B>> {
1235        chunk::<B, Float>(TensorPrimitive::Float(tensor), chunks, dim)
1236            .into_iter()
1237            .map(|t| t.tensor())
1238            .collect()
1239    }
1240
1241    /// Split the tensor along the given dimension into chunks of `split_size`.
1242    ///
1243    /// # Arguments
1244    ///
1245    /// * `tensor` - The tensor.
1246    /// * `split_size` - The size of a single chunk.
1247    /// * `times` - The dimension along which the tensor will be split.
1248    ///
1249    /// # Returns
1250    ///
1251    /// A vector of tensors.
1252    fn float_split(tensor: FloatTensor<B>, split_size: usize, dim: usize) -> Vec<FloatTensor<B>> {
1253        split::<B, Float>(TensorPrimitive::Float(tensor), split_size, dim)
1254            .into_iter()
1255            .map(|t| t.tensor())
1256            .collect()
1257    }
1258
1259    /// Split the tensor along the given dimension into chunks with sizes in
1260    /// `dim` according to `split_sizes`.
1261    ///
1262    /// # Arguments
1263    ///
1264    /// * `tensor` - The tensor.
1265    /// * `split_sizes` - Vector of sizes for each chunk.
1266    /// * `times` - The dimension along which the tensor will be split.
1267    ///
1268    /// # Returns
1269    ///
1270    /// A vector of tensors.
1271    fn float_split_with_sizes(
1272        tensor: FloatTensor<B>,
1273        split_sizes: Vec<usize>,
1274        dim: usize,
1275    ) -> Vec<FloatTensor<B>> {
1276        split_with_sizes::<B, Float>(TensorPrimitive::Float(tensor), split_sizes, dim)
1277            .into_iter()
1278            .map(|t| t.tensor())
1279            .collect()
1280    }
1281
1282    /// Tests if any element in the float `tensor` evaluates to True.
1283    ///
1284    /// # Arguments
1285    ///
1286    /// * `tensor` - The tensor to test.
1287    ///
1288    /// # Returns
1289    ///
1290    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1291    fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {
1292        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1293        let bool_tensor = B::bool_not(bool_tensor);
1294        let sum = B::float_sum(B::bool_into_float(bool_tensor));
1295        B::float_greater_elem(sum, 0.0f32.elem())
1296    }
1297
1298    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1299    ///
1300    /// # Arguments
1301    ///
1302    /// * `tensor` - The tensor to test.
1303    /// * `dim` - The axis along which to test.
1304    ///
1305    /// # Returns
1306    ///
1307    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1308    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1309    /// input evaluates to True, False otherwise.
1310    fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1311        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1312        let bool_tensor = B::bool_not(bool_tensor);
1313        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1314        B::float_greater_elem(sum, 0.0f32.elem())
1315    }
1316
1317    /// Tests if all elements in the float `tensor` evaluate to True.
1318    ///
1319    /// # Arguments
1320    ///
1321    /// * `tensor` - The tensor to test.
1322    ///
1323    /// # Returns
1324    ///
1325    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1326    /// evaluate to True, False otherwise.
1327    fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
1328        let num_elems = tensor.shape().num_elements();
1329        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1330        let bool_tensor = B::bool_not(bool_tensor);
1331        let sum = B::float_sum(B::bool_into_float(bool_tensor));
1332        B::float_equal_elem(sum, (num_elems as f32).elem())
1333    }
1334
1335    /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1336    ///
1337    /// # Arguments
1338    ///
1339    /// * `tensor` - The tensor to test.
1340    /// * `dim` - The axis along which to test.
1341    ///
1342    /// # Returns
1343    ///
1344    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1345    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1346    /// evaluates to True, False otherwise.
1347    fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1348        let num_elems = tensor.shape().dims[dim];
1349        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1350        let bool_tensor = B::bool_not(bool_tensor);
1351        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1352        B::float_equal_elem(sum, (num_elems as f32).elem())
1353    }
1354
1355    /// Returns the signs of the float `tensor`.
1356    ///
1357    /// # Arguments
1358    ///
1359    /// * `tensor` - The tensor to extract the signs from.
1360    ///
1361    /// # Returns
1362    ///
1363    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1364    fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1365        let zeros = B::float_zeros(tensor.shape(), &B::float_device(&tensor));
1366        let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
1367        let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
1368
1369        let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1370        result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
1371        result
1372    }
1373
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// This is a required method; there is no default implementation.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// The broadcast tensor.
    // NOTE(review): which source dimensions may be expanded (presumably size-1 ones) is
    // decided by the backend implementations; confirm before relying on other cases.
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1376
1377    /// Sort the elements of the input `tensor` by value in along a given dimension.
1378    ///
1379    /// This sort is unstable (i.e., may reorder equal elements).
1380    ///
1381    /// # Arguments
1382    ///
1383    /// * `tensor` - The input tensor.
1384    /// * `dim` - The axis along which to sort.
1385    /// * `descending` - The sorting order.
1386    ///
1387    /// # Returns
1388    ///
1389    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1390    fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1391        sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1392    }
1393
1394    /// Sort the elements of the input `tensor` by value in along a given dimension.
1395    ///
1396    /// This sort is unstable (i.e., may reorder equal elements).
1397    ///
1398    /// # Arguments
1399    ///
1400    /// * `tensor` - The input tensor.
1401    /// * `dim` - The axis along which to sort.
1402    /// * `descending` - The sorting order.
1403    ///
1404    /// # Returns
1405    ///
1406    /// A tensor with the same shape as the input tensor and corresponding indices, where
1407    /// the elements are sorted by value and the indices map back to the original input tensor.
1408    fn float_sort_with_indices(
1409        tensor: FloatTensor<B>,
1410        dim: usize,
1411        descending: bool,
1412    ) -> (FloatTensor<B>, IntTensor<B>) {
1413        let (values, indices) =
1414            sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
1415        (values.tensor(), indices)
1416    }
1417
1418    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1419    ///
1420    /// This sort is unstable (i.e., may reorder equal elements).
1421    ///
1422    /// # Arguments
1423    ///
1424    /// * `tensor` - The input tensor.
1425    /// * `dim` - The axis along which to sort.
1426    /// * `descending` - The sorting order.
1427    ///
1428    /// # Returns
1429    ///
1430    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1431    fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1432        argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
1433    }
1434}