// burn_tensor/tensor/ops/tensor.rs

1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor};
4use crate::tensor::cast::ToElement;
5use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Float, TensorData};
6use crate::{
7    tensor::api::chunk, tensor::api::narrow, tensor::api::split, tensor::api::split_with_sizes,
8    FloatDType, TensorMetadata, TensorPrimitive,
9};
10use alloc::vec::Vec;
11use core::future::Future;
12use core::ops::Range;
13
14use crate::{argsort, sort, sort_with_indices};
15
16/// Operations on float tensors.
17pub trait FloatTensorOps<B: Backend> {
    /// Creates a new tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure holding the tensor's elements and shape.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given data.
    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
29
    /// Creates a new tensor with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and values sampled from `distribution`.
    fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
        -> FloatTensor<B>;
43
44    /// Creates a new tensor with zeros.
45    ///
46    /// # Arguments
47    ///
48    /// * `shape` - The shape of the tensor.
49    /// * `device` - The device to create the tensor on.
50    ///
51    /// # Returns
52    ///
53    /// The tensor with the given shape and zeros.
54    fn float_zeros(shape: Shape, device: &Device<B>) -> FloatTensor<B> {
55        Self::float_from_data(TensorData::zeros::<FloatElem<B>, _>(shape), device)
56    }
57
58    /// Creates a new tensor with ones.
59    ///
60    /// # Arguments
61    ///
62    /// * `shape` - The shape of the tensor.
63    /// * `device` - The device to create the tensor on.
64    ///
65    /// # Returns
66    ///
67    /// The tensor with the given shape and ones.
68    fn float_ones(shape: Shape, device: &Device<B>) -> FloatTensor<B> {
69        Self::float_from_data(TensorData::ones::<FloatElem<B>, _>(shape), device)
70    }
71
72    /// Creates a tensor filled with given value.
73    ///
74    /// # Arguments
75    ///
76    /// * `shape` - The shape of the tensor.
77    /// * `fill_value` - The value with which to fill the tensor.
78    /// * `device` - The device to create the tensor on.
79    ///
80    /// # Returns
81    ///
82    /// The tensor filled with given value
83    fn float_full(shape: Shape, fill_value: FloatElem<B>, device: &Device<B>) -> FloatTensor<B> {
84        Self::float_add_scalar(Self::float_zeros(shape, device), fill_value)
85    }
86
    /// Converts the tensor to a data structure.
    ///
    /// This returns a future (`Send + 'static`) that resolves to the tensor's data. It does not
    /// return the data directly; callers must drive the future to completion.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// A future resolving to the data structure with the tensor's data.
    fn float_into_data(tensor: FloatTensor<B>)
        -> impl Future<Output = TensorData> + 'static + Send;
98
    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device the tensor resides on.
    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
109
    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
121
    /// Converts a float tensor to an int tensor.
    ///
    /// NOTE(review): the rounding rule for fractional values is not specified by this
    /// signature. Confirm it against the backend implementations.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The int tensor with the same data as the float tensor.
    fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;
132
    /// Creates an empty tensor with the given shape.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The empty tensor with the given shape.
    fn float_empty(shape: Shape, device: &Device<B>) -> FloatTensor<B>;
144
145    /// Repeat the tensor along the given dimension.
146    ///
147    /// # Arguments
148    ///
149    /// * `tensor` - The tensor.
150    /// * `dim` - The dimension to repeat.
151    /// * `times` - The number of times to repeat the dimension.
152    ///
153    /// # Returns
154    ///
155    /// The tensor with the given dimension repeated.
156    fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
157        repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
158    }
159
    /// Adds two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of adding the two tensors together.
    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
171
    /// Adds a scalar to every element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of adding the scalar to the tensor.
    fn float_add_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
183
184    /// Clamps a tensor under a minimum value.
185    ///
186    /// # Arguments
187    ///
188    /// * `tensor` - The tensor to clamp.
189    /// * `min` - The minimum value.
190    ///
191    /// # Returns
192    ///
193    /// The clamped tensor.
194    fn float_clamp_min(tensor: FloatTensor<B>, min: FloatElem<B>) -> FloatTensor<B> {
195        // Default implementation
196        let mask = Self::float_lower_elem(tensor.clone(), min);
197        B::float_mask_fill(tensor, mask, min)
198    }
199
200    /// Clamps a tensor over a maximum value.
201    ///
202    /// # Arguments
203    ///
204    /// * `tensor` - The tensor to clamp.
205    /// * `max` - The maximum value.
206    ///
207    /// # Returns
208    ///
209    /// The clamped tensor.
210    fn float_clamp_max(tensor: FloatTensor<B>, max: FloatElem<B>) -> FloatTensor<B> {
211        // Default implementation
212        let mask = Self::float_greater_elem(tensor.clone(), max);
213        B::float_mask_fill(tensor, mask, max)
214    }
215
216    /// Clamps a tensor between a minimum and maximum value.
217    ///
218    /// # Arguments
219    ///
220    /// * `tensor` - The tensor to clamp.
221    /// * `min` - The minimum value.
222    /// * `max` - The maximum value.
223    ///
224    /// # Returns
225    ///
226    /// The clamped tensor.
227    fn float_clamp(tensor: FloatTensor<B>, min: FloatElem<B>, max: FloatElem<B>) -> FloatTensor<B> {
228        // Default implementation
229        Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
230    }
231
    /// Subtracts two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of subtracting `rhs` from `lhs`.
    fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
243
    /// Subtracts a scalar from every element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of subtracting the scalar from the tensor.
    fn float_sub_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
255
    /// Multiplies two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The element-wise product of the two tensors.
    fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
258
    /// Multiplies every element of a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of multiplying the tensor by the scalar.
    fn float_mul_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
270
    /// Divides two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor (dividend).
    /// * `rhs` - The right hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The result of dividing the two tensors.
    fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
282
    /// Divides every element of a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor (dividend).
    /// * `rhs` - The right hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The result of dividing the tensor by the scalar.
    fn float_div_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
294
    /// Computes the remainder of division between two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor (dividend).
    /// * `rhs` - The right hand side tensor (divisor).
    ///
    /// # Returns
    ///
    /// The element-wise remainder when dividing `lhs` by `rhs`.
    fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
306
    /// Computes the remainder of dividing every element of a tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor (dividend).
    /// * `rhs` - The right hand side scalar (divisor).
    ///
    /// # Returns
    ///
    /// The element-wise remainder when dividing `lhs` by `rhs`.
    fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
317
    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of multiplying the two tensors together using matrix multiplication.
    fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
329
330    /// Negates a tensor element-wise.
331    fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
332        Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
333    }
334
    /// Calculates the reciprocals element-wise.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the reciprocal of.
    ///
    /// # Returns
    ///
    /// A tensor holding `1 / x` for every element `x` of `tensor`.
    fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
337
338    /// Transposes a tensor.
339    ///
340    /// # Arguments
341    ///
342    /// * `tensor` - The tensor to transpose.
343    ///
344    /// # Returns
345    ///
346    /// The transposed tensor.
347    fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
348        let ndims = tensor.shape().num_dims();
349        Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
350    }
351
    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The first dimension to swap.
    /// * `dim2` - The second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
364
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
375
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
385
    /// Reshapes a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
397
    /// Gathers elements from a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor to gather from.
    /// * `indices` - The indices to gather.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
410
    /// Scatters elements into a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter into.
    /// * `tensor` - The tensor to scatter into.
    /// * `indices` - The indices to scatter into.
    /// * `value` - The tensor of values to scatter.
    ///
    /// # Returns
    ///
    /// The tensor with the scattered elements.
    fn float_scatter(
        dim: usize,
        tensor: FloatTensor<B>,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
429
    /// Selects tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
442
    /// Assigns `value` to the elements selected along the given dimension by the given
    /// indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements are selected.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The indices to select.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_select_assign(
        tensor: FloatTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
462
    /// Selects the tensor elements covered by the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `ranges` - The ranges to select, one per leading dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    fn float_slice(tensor: FloatTensor<B>, ranges: &[Range<usize>]) -> FloatTensor<B>;
474
    /// Assigns `value` to the tensor elements covered by the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements are selected.
    /// * `ranges` - The ranges to select.
    /// * `value` - The value to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_slice_assign(
        tensor: FloatTensor<B>,
        ranges: &[Range<usize>],
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
491
    /// Updates the given tensor with the value tensor where the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The tensor whose elements are assigned where `mask` is true.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned from `value`.
    fn float_mask_where(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
508
    /// Updates the given tensor with a scalar value where the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The value to assign to the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_mask_fill(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatElem<B>,
    ) -> FloatTensor<B>;
525
    /// Element-wise equality comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
537
538    /// Element-wise non-equality comparison.
539    ///
540    /// # Arguments
541    ///
542    /// * `lhs` - The left hand side tensor.
543    /// * `rhs` - The right hand side tensor.
544    ///
545    /// # Returns
546    ///
547    /// A boolean tensor with the result of the comparison.
548    fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {
549        let equal_tensor = B::float_equal(lhs, rhs);
550        B::bool_not(equal_tensor)
551    }
552
    /// Element-wise equality comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
564
565    /// Element-wise non-equality comparison with a scalar.
566    ///
567    /// # Arguments
568    ///
569    /// * `lhs` - The left hand side tensor.
570    /// * `rhs` - The right hand side scalar.
571    ///
572    /// # Returns
573    ///
574    /// A boolean tensor with the result of the comparison.
575    fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B> {
576        let equal_tensor = B::float_equal_elem(lhs, rhs);
577        B::bool_not(equal_tensor)
578    }
579
    /// Element-wise greater-than comparison of two tensors (`lhs > rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
591
    /// Element-wise greater-than comparison of a tensor and a scalar (`lhs > rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
603
    /// Element-wise greater-than-or-equal comparison of two tensors (`lhs >= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
615
    /// Element-wise greater-than-or-equal comparison of a tensor and a scalar (`lhs >= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
627
    /// Element-wise less-than comparison of two tensors (`lhs < rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
639
    /// Element-wise less-than comparison of a tensor and a scalar (`lhs < rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
651
    /// Element-wise less-than-or-equal comparison of two tensors (`lhs <= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
663
    /// Element-wise less-than-or-equal comparison of a tensor and a scalar (`lhs <= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
675
676    /// Detaches a tensor from the computation graph.
677    fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
678        // Should only be overridden by autodiff backends.
679        tensor
680    }
681
682    /// Sets the `require_grad` flag of a tensor.
683    fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
684        // Should only be overridden by autodiff backends.
685        tensor
686    }
687
688    /// Returns the `require_grad` flag of a tensor.
689    fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
690        // Should only be overridden by autodiff backends.
691        false
692    }
693
    /// Sum of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the sum of all elements in `tensor`.
    fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
704
    /// Sum of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension along which to sum.
    ///
    /// # Returns
    ///
    /// A tensor with the sum of all elements in `tensor` along `dim`.
    fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
716
717    /// Product of all elements in a tensor.
718    ///
719    /// # Arguments
720    ///
721    /// * `tensor` - The tensor to product.
722    ///
723    /// # Returns
724    ///
725    /// A scalar tensor with the product of all elements in `tensor`.
726    fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
727        // Product of all elements in a tensor
728        B::float_exp(B::float_sum(B::float_log(tensor)))
729    }
730
731    /// Product of all elements in a tensor along a dimension.
732    ///
733    /// # Arguments
734    ///
735    /// * `tensor` - The tensor to product.
736    ///
737    /// # Returns
738    ///
739    /// A tensor with the product of all elements in `tensor` along `dim`.
740    fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
741        // Product of all elements in a tensor along a dimension
742        B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
743    }
744
745    /// Mean of all elements in a tensor.
746    ///
747    /// # Arguments
748    ///
749    /// * `tensor` - The tensor to mean.
750    ///
751    /// # Returns
752    ///
753    /// A scalar tensor with the mean of all elements in `tensor`.
754    fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
755        let num_elems = tensor.shape().num_elements();
756        B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
757    }
758
    /// Mean of all elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to average.
    /// * `dim` - The dimension along which to take the mean.
    ///
    /// # Returns
    ///
    /// A tensor with the mean of all elements in `tensor` along `dim`.
    fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
770
    /// Converts a tensor to another floating point data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target floating point data type.
    fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
782
    /// Returns a new tensor with exponential values (`e^x` for each element `x`).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with exponential values.
    fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
793
    /// Returns a new tensor with natural logarithm values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
804
    /// Returns a new tensor with the natural logarithm of `1 + x` for each element `x`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with logarithm values of `1 + x`.
    fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
815
    /// Element-wise power with a float exponent tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The base tensor.
    /// * `rhs` - The exponent tensor.
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
827
828    /// Element-wise power with an IntTensor.
829    ///
830    /// # Arguments
831    ///
832    /// * `lhs` - The left hand side tensor.
833    /// * `rhs` - The right hand side floatTensor.
834    ///
835    /// # Returns
836    ///
837    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
838    fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
839        Self::float_powf(lhs, B::int_into_float(rhs))
840    }
841
842    /// raises a tensor to the power of an int scalar.
843    ///
844    /// # Arguments
845    ///
846    /// * `lhs` - The left hand side tensor.
847    /// * `rhs` - The right hand side scalar.
848    ///
849    /// # Returns
850    ///
851    /// The elements of `lhs` raised to the value of `rhs`.
852    fn float_powi_scalar(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
853        Self::float_powf_scalar(lhs, rhs.to_f32())
854    }
855
    /// Returns a new tensor with values raised to the power of float `value`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn float_powf_scalar(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B>;
867
    /// Returns a new tensor with square root values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
878
    /// Returns a new tensor with absolute values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
889
    /// Returns a new tensor with cosine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
900
    /// Returns a new tensor with sine values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
911
    /// Returns a new tensor with hyperbolic tangent values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
922
    /// Returns a new tensor with rounded values.
    ///
    /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
936
    /// Returns a new tensor with floored values (rounded toward negative infinity).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
947
    /// Returns a new tensor with ceiled values (rounded toward positive infinity).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be ceiled.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with ceiled values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
958
    /// Returns a new tensor with the error function values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
969
970    /// Concatenates tensors along a dimension.
971    ///
972    /// # Arguments
973    ///
974    /// * `tensors` - The tensors to concatenate.
975    /// * `dim` - The dimension along which to concatenate.
976    ///
977    /// # Returns
978    ///
979    /// A tensor with the concatenated tensors along `dim`.
980    fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
981        cat_with_slice_assign::<B, Float>(
982            tensors.into_iter().map(TensorPrimitive::Float).collect(),
983            dim,
984        )
985        .tensor()
986    }
987
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// The default `float_max_dim` and `float_max_dim_with_indices` implementations
    /// gather values from `tensor` at the indices returned here, so the result must be
    /// usable as a gather index along `dim`.
    /// NOTE(review): which index wins when several elements tie is not specified here —
    /// presumably backend-defined; confirm before relying on it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
999
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// The default `float_min_dim` and `float_min_dim_with_indices` implementations
    /// gather values from `tensor` at the indices returned here, so the result must be
    /// usable as a gather index along `dim`.
    /// NOTE(review): which index wins when several elements tie is not specified here —
    /// presumably backend-defined; confirm before relying on it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1011
1012    /// Gets the maximum element of a tensor.
1013    ///
1014    /// # Arguments
1015    ///
1016    /// * `tensor` - The tensor to get the maximum elements of.
1017    ///
1018    /// # Returns
1019    ///
1020    /// A tensor with the maximum element of `tensor`.
1021    fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1022        let shape = tensor.shape();
1023        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1024
1025        B::float_max_dim(tensor, 0)
1026    }
1027
1028    /// Gets the maximum elements of a tensor along an axis.
1029    ///
1030    /// # Arguments
1031    ///
1032    /// * `tensor` - The tensor to get the maximum elements of.
1033    /// * `dim` - The dimension along which to get the maximum elements.
1034    ///
1035    /// # Returns
1036    ///
1037    /// A tensor with the maximum elements of `tensor` along `dim`.
1038    fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1039        let index = B::float_argmax(tensor.clone(), dim);
1040
1041        B::float_gather(dim, tensor, index)
1042    }
1043
1044    /// Gets the maximum elements of a tensor along an axis and their indices.
1045    ///
1046    /// # Arguments
1047    ///
1048    /// * `tensor` - The tensor to get the maximum elements of.
1049    /// * `dim` - The dimension along which to get the maximum elements.
1050    ///
1051    /// # Returns
1052    ///
1053    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1054    fn float_max_dim_with_indices(
1055        tensor: FloatTensor<B>,
1056        dim: usize,
1057    ) -> (FloatTensor<B>, IntTensor<B>) {
1058        let index = B::float_argmax(tensor.clone(), dim);
1059        let values = B::float_gather(dim, tensor, index.clone());
1060
1061        (values, index)
1062    }
1063
1064    /// Gets the minimum element of a tensor.
1065    ///
1066    /// # Arguments
1067    ///
1068    /// * `tensor` - The tensor to get the minimum elements of.
1069    ///
1070    /// # Returns
1071    ///
1072    /// A tensor with the minimum element of `tensor`.
1073    fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1074        let shape = tensor.shape();
1075        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1076
1077        B::float_min_dim(tensor, 0)
1078    }
1079
1080    /// Gets the minimum elements of a tensor along an axis.
1081    ///
1082    /// # Arguments
1083    ///
1084    /// * `tensor` - The tensor to get the minimum elements of.
1085    /// * `dim` - The dimension along which to get the minimum elements.
1086    ///
1087    /// # Returns
1088    ///
1089    /// A tensor with the minimum elements of `tensor` along `dim`.
1090    fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1091        let index = B::float_argmin(tensor.clone(), dim);
1092
1093        B::float_gather(dim, tensor, index)
1094    }
1095
1096    /// Gets the minimum elements of a tensor along an axis and their indices.
1097    ///
1098    /// # Arguments
1099    ///
1100    /// * `tensor` - The tensor to get the minimum elements of.
1101    /// * `dim` - The dimension along which to get the minimum elements.
1102    ///
1103    /// # Returns
1104    ///
1105    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1106    fn float_min_dim_with_indices(
1107        tensor: FloatTensor<B>,
1108        dim: usize,
1109    ) -> (FloatTensor<B>, IntTensor<B>) {
1110        let index = B::float_argmin(tensor.clone(), dim);
1111        let values = B::float_gather(dim, tensor, index.clone());
1112
1113        (values, index)
1114    }
1115
1116    /// Returns a new tensor with the given dimension narrowed to the given range.
1117    ///
1118    /// # Arguments
1119    ///
1120    /// * `dim` - The dimension along which the tensor will be narrowed.
1121    /// * `start` - The starting point of the given range.
1122    /// * `length` - The ending point of the given range.
1123    /// # Panics
1124    ///
1125    /// - If the dimension is greater than the number of dimensions of the tensor.
1126    /// - If the given range exceeds the number of elements on the given dimension.
1127    ///
1128    /// # Returns
1129    ///
1130    /// A new tensor with the given dimension narrowed to the given range.
1131    fn float_narrow(
1132        tensor: FloatTensor<B>,
1133        dim: usize,
1134        start: usize,
1135        length: usize,
1136    ) -> FloatTensor<B> {
1137        narrow::<B, Float>(TensorPrimitive::Float(tensor), dim, start, length).tensor()
1138    }
1139
1140    /// Split the tensor along the given dimension into chunks.
1141    ///
1142    /// # Arguments
1143    ///
1144    /// * `tensor` - The tensor.
1145    /// * `chunks` - The number of chunks to be produced
1146    /// * `times` - The dimension along which the tensor will be split.
1147    ///
1148    /// # Returns
1149    ///
1150    /// A vector of tensors
1151    fn float_chunk(tensor: FloatTensor<B>, chunks: usize, dim: usize) -> Vec<FloatTensor<B>> {
1152        chunk::<B, Float>(TensorPrimitive::Float(tensor), chunks, dim)
1153            .into_iter()
1154            .map(|t| t.tensor())
1155            .collect()
1156    }
1157
1158    /// Split the tensor along the given dimension into chunks of `split_size`.
1159    ///
1160    /// # Arguments
1161    ///
1162    /// * `tensor` - The tensor.
1163    /// * `split_size` - The size of a single chunk.
1164    /// * `times` - The dimension along which the tensor will be split.
1165    ///
1166    /// # Returns
1167    ///
1168    /// A vector of tensors.
1169    fn float_split(tensor: FloatTensor<B>, split_size: usize, dim: usize) -> Vec<FloatTensor<B>> {
1170        split::<B, Float>(TensorPrimitive::Float(tensor), split_size, dim)
1171            .into_iter()
1172            .map(|t| t.tensor())
1173            .collect()
1174    }
1175
1176    /// Split the tensor along the given dimension into chunks with sizes in
1177    /// `dim` according to `split_sizes`.
1178    ///
1179    /// # Arguments
1180    ///
1181    /// * `tensor` - The tensor.
1182    /// * `split_sizes` - Vector of sizes for each chunk.
1183    /// * `times` - The dimension along which the tensor will be split.
1184    ///
1185    /// # Returns
1186    ///
1187    /// A vector of tensors.
1188    fn float_split_with_sizes(
1189        tensor: FloatTensor<B>,
1190        split_sizes: Vec<usize>,
1191        dim: usize,
1192    ) -> Vec<FloatTensor<B>> {
1193        split_with_sizes::<B, Float>(TensorPrimitive::Float(tensor), split_sizes, dim)
1194            .into_iter()
1195            .map(|t| t.tensor())
1196            .collect()
1197    }
1198
1199    /// Tests if any element in the float `tensor` evaluates to True.
1200    ///
1201    /// # Arguments
1202    ///
1203    /// * `tensor` - The tensor to test.
1204    ///
1205    /// # Returns
1206    ///
1207    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1208    fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {
1209        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1210        let bool_tensor = B::bool_not(bool_tensor);
1211        let sum = B::float_sum(B::bool_into_float(bool_tensor));
1212        B::float_greater_elem(sum, 0.0f32.elem())
1213    }
1214
1215    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1216    ///
1217    /// # Arguments
1218    ///
1219    /// * `tensor` - The tensor to test.
1220    /// * `dim` - The axis along which to test.
1221    ///
1222    /// # Returns
1223    ///
1224    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1225    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1226    /// input evaluates to True, False otherwise.
1227    fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1228        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1229        let bool_tensor = B::bool_not(bool_tensor);
1230        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1231        B::float_greater_elem(sum, 0.0f32.elem())
1232    }
1233
1234    /// Tests if all elements in the float `tensor` evaluate to True.
1235    ///
1236    /// # Arguments
1237    ///
1238    /// * `tensor` - The tensor to test.
1239    ///
1240    /// # Returns
1241    ///
1242    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1243    /// evaluate to True, False otherwise.
1244    fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
1245        let num_elems = tensor.shape().num_elements();
1246        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1247        let bool_tensor = B::bool_not(bool_tensor);
1248        let sum = B::float_sum(B::bool_into_float(bool_tensor));
1249        B::float_equal_elem(sum, (num_elems as f32).elem())
1250    }
1251
1252    /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1253    ///
1254    /// # Arguments
1255    ///
1256    /// * `tensor` - The tensor to test.
1257    /// * `dim` - The axis along which to test.
1258    ///
1259    /// # Returns
1260    ///
1261    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1262    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1263    /// evaluates to True, False otherwise.
1264    fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1265        let num_elems = tensor.shape().dims[dim];
1266        let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1267        let bool_tensor = B::bool_not(bool_tensor);
1268        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1269        B::float_equal_elem(sum, (num_elems as f32).elem())
1270    }
1271
1272    /// Returns the signs of the float `tensor`.
1273    ///
1274    /// # Arguments
1275    ///
1276    /// * `tensor` - The tensor to extract the signs from.
1277    ///
1278    /// # Returns
1279    ///
1280    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1281    fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1282        let zeros = B::float_zeros(tensor.shape(), &B::float_device(&tensor));
1283        let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
1284        let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
1285
1286        let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1287        result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
1288        result
1289    }
1290
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// NOTE(review): assumes `shape` is broadcast-compatible with the shape of `tensor`;
    /// whether the caller or the backend validates this is not visible here — confirm.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// A tensor with the values of `tensor` broadcast to `shape`.
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1293
1294    /// Sort the elements of the input `tensor` by value in along a given dimension.
1295    ///
1296    /// This sort is unstable (i.e., may reorder equal elements).
1297    ///
1298    /// # Arguments
1299    ///
1300    /// * `tensor` - The input tensor.
1301    /// * `dim` - The axis along which to sort.
1302    /// * `descending` - The sorting order.
1303    ///
1304    /// # Returns
1305    ///
1306    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1307    fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1308        sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1309    }
1310
1311    /// Sort the elements of the input `tensor` by value in along a given dimension.
1312    ///
1313    /// This sort is unstable (i.e., may reorder equal elements).
1314    ///
1315    /// # Arguments
1316    ///
1317    /// * `tensor` - The input tensor.
1318    /// * `dim` - The axis along which to sort.
1319    /// * `descending` - The sorting order.
1320    ///
1321    /// # Returns
1322    ///
1323    /// A tensor with the same shape as the input tensor and corresponding indices, where
1324    /// the elements are sorted by value and the indices map back to the original input tensor.
1325    fn float_sort_with_indices(
1326        tensor: FloatTensor<B>,
1327        dim: usize,
1328        descending: bool,
1329    ) -> (FloatTensor<B>, IntTensor<B>) {
1330        let (values, indices) =
1331            sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
1332        (values.tensor(), indices)
1333    }
1334
1335    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1336    ///
1337    /// This sort is unstable (i.e., may reorder equal elements).
1338    ///
1339    /// # Arguments
1340    ///
1341    /// * `tensor` - The input tensor.
1342    /// * `dim` - The axis along which to sort.
1343    /// * `descending` - The sorting order.
1344    ///
1345    /// # Returns
1346    ///
1347    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1348    fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1349        argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
1350    }
1351}