burn_backend/backend/ops/tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_ref;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::sort::{argsort, sort, sort_with_indices};
5use crate::ops::GridSampleOptions;
6use crate::tensor::{BoolTensor, Device, Float, FloatTensor, IntTensor};
7use crate::{Backend, Distribution, TensorData, get_device_settings};
8use crate::{ExecutionError, Scalar, TensorMetadata, TensorPrimitive};
9use alloc::vec::Vec;
10use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
11
12/// Operations on float tensors.
13pub trait FloatTensorOps<B: Backend> {
14 /// Creates a new tensor from the data structure.
15 ///
16 /// # Arguments
17 ///
18 /// * `data` - The data structure.
19 /// * `device` - The device to create the tensor on.
20 ///
21 /// # Returns
22 ///
23 /// The tensor with the given data.
24 fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
25
26 /// Creates a new tensor with random values.
27 ///
28 /// # Arguments
29 ///
30 /// * `shape` - The shape of the tensor.
31 /// * `distribution` - The distribution to sample from.
32 /// * `device` - The device to create the tensor on.
33 /// * `dtype` - The target data type.
34 ///
35 /// # Returns
36 ///
37 /// The tensor with the given shape and random values.
38 fn float_random(
39 shape: Shape,
40 distribution: Distribution,
41 device: &Device<B>,
42 dtype: FloatDType,
43 ) -> FloatTensor<B>;
44
45 /// Creates a new tensor with zeros.
46 ///
47 /// # Arguments
48 ///
49 /// * `shape` - The shape of the tensor.
50 /// * `device` - The device to create the tensor on.
51 /// * `dtype` - The target data type.
52 ///
53 /// # Returns
54 ///
55 /// The tensor with the given shape and zeros.
56 fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
57 Self::float_from_data(TensorData::full_dtype(shape, 0., dtype.into()), device)
58 }
59
60 /// Creates a new tensor with ones.
61 ///
62 /// # Arguments
63 ///
64 /// * `shape` - The shape of the tensor.
65 /// * `device` - The device to create the tensor on.
66 /// * `dtype` - The target data type.
67 ///
68 /// # Returns
69 ///
70 /// The tensor with the given shape and ones.
71 fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
72 Self::float_from_data(TensorData::full_dtype(shape, 1., dtype.into()), device)
73 }
74
75 /// Creates a tensor filled with given value.
76 ///
77 /// # Arguments
78 ///
79 /// * `shape` - The shape of the tensor.
80 /// * `fill_value` - The value with which to fill the tensor.
81 /// * `device` - The device to create the tensor on.
82 /// * `dtype` - The target data type.
83 ///
84 /// # Returns
85 ///
86 /// The tensor filled with given value
87 fn float_full(
88 shape: Shape,
89 fill_value: Scalar,
90 device: &Device<B>,
91 dtype: FloatDType,
92 ) -> FloatTensor<B> {
93 Self::float_from_data(
94 TensorData::full_dtype(shape, fill_value, dtype.into()),
95 device,
96 )
97 }
98
99 /// Converts the tensor to a data structure.
100 ///
101 /// # Arguments
102 ///
103 /// * `tensor` - The tensor.
104 ///
105 /// # Returns
106 ///
107 /// The data structure with the tensor's data.
108 fn float_into_data(
109 tensor: FloatTensor<B>,
110 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
111
112 /// Gets the device of the tensor.
113 ///
114 /// # Arguments
115 ///
116 /// * `tensor` - The tensor.
117 ///
118 /// # Returns
119 ///
120 /// The device of the tensor.
121 fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
122
123 /// Moves the tensor to the given device.
124 ///
125 /// # Arguments
126 ///
127 /// * `tensor` - The tensor.
128 /// * `device` - The device to move the tensor to.
129 ///
130 /// # Returns
131 ///
132 /// The tensor on the given device.
133 fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
134
135 /// Converts float tensor to int tensor.
136 ///
137 /// # Arguments
138 ///
139 /// * `tensor` - The tensor.
140 /// * `out_dtype` - The output tensor dtype.
141 ///
142 /// # Returns
143 ///
144 /// The int tensor with the same data as the float tensor.
145 fn float_into_int(tensor: FloatTensor<B>, out_dtype: IntDType) -> IntTensor<B>;
146
147 /// Creates an empty tensor with the given shape.
148 ///
149 /// # Arguments
150 ///
151 /// * `shape` - The shape of the tensor.
152 /// * `device` - The device to create the tensor on.
153 /// * `dtype` - The target data type.
154 ///
155 /// # Returns
156 ///
157 /// The empty tensor with the given shape.
158 fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
159
160 /// Repeat the tensor along the given dimension.
161 ///
162 /// # Arguments
163 ///
164 /// * `tensor` - The tensor.
165 /// * `dim` - The dimension to repeat.
166 /// * `times` - The number of times to repeat the dimension.
167 ///
168 /// # Returns
169 ///
170 /// The tensor with the given dimension repeated.
171 fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
172 repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
173 }
174
175 /// Adds two tensors together.
176 ///
177 /// # Arguments
178 ///
179 /// * `lhs` - The left-hand side tensor.
180 /// * `rhs` - The right-hand side tensor.
181 ///
182 /// # Returns
183 ///
184 /// The result of adding the two tensors together.
185 fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
186
187 /// Adds a scalar to a tensor.
188 ///
189 /// # Arguments
190 ///
191 /// * `lhs` - The left-hand side tensor.
192 /// * `rhs` - The right-hand side scalar.
193 ///
194 /// # Returns
195 ///
196 /// The result of adding the scalar to the tensor.
197 fn float_add_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
198
199 /// Clamps a tensor under a minimum value.
200 ///
201 /// # Arguments
202 ///
203 /// * `tensor` - The tensor to clamp.
204 /// * `min` - The minimum value.
205 ///
206 /// # Returns
207 ///
208 /// The clamped tensor.
209 fn float_clamp_min(tensor: FloatTensor<B>, min: Scalar) -> FloatTensor<B> {
210 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
211 let mask = Self::float_lower_elem(tensor.clone(), min, dtype);
212 B::float_mask_fill(tensor, mask, min)
213 }
214
215 /// Clamps a tensor over a maximum value.
216 ///
217 /// # Arguments
218 ///
219 /// * `tensor` - The tensor to clamp.
220 /// * `max` - The maximum value.
221 ///
222 /// # Returns
223 ///
224 /// The clamped tensor.
225 fn float_clamp_max(tensor: FloatTensor<B>, max: Scalar) -> FloatTensor<B> {
226 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
227 let mask = Self::float_greater_elem(tensor.clone(), max, dtype);
228 B::float_mask_fill(tensor, mask, max)
229 }
230
231 /// Clamps a tensor between a minimum and maximum value.
232 ///
233 /// # Arguments
234 ///
235 /// * `tensor` - The tensor to clamp.
236 /// * `min` - The minimum value.
237 /// * `max` - The maximum value.
238 ///
239 /// # Returns
240 ///
241 /// The clamped tensor.
242 fn float_clamp(tensor: FloatTensor<B>, min: Scalar, max: Scalar) -> FloatTensor<B> {
243 // Default implementation
244 Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
245 }
246
247 /// Subtracts two tensors.
248 ///
249 /// # Arguments
250 ///
251 /// * `lhs` - The left-hand side tensor.
252 /// * `rhs` - The right-hand side tensor.
253 ///
254 /// # Returns
255 ///
256 /// The result of subtracting the two tensors.
257 fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
258
259 /// Subtracts a scalar from a tensor.
260 ///
261 /// # Arguments
262 ///
263 /// * `lhs` - The left-hand side tensor.
264 /// * `rhs` - The right-hand side scalar.
265 ///
266 /// # Returns
267 ///
268 /// The result of subtracting the scalar from the tensor.
269 fn float_sub_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
270
271 /// Multiplies two tensors together element-wise.
272 fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
273
274 /// Multiplies a tensor by a scalar.
275 ///
276 /// # Arguments
277 ///
278 /// * `lhs` - The left-hand side tensor.
279 /// * `rhs` - The right-hand side scalar.
280 ///
281 /// # Returns
282 ///
283 /// The result of multiplying the tensor by the scalar.
284 fn float_mul_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
285
286 /// Divides two tensors element-wise.
287 ///
288 /// # Arguments
289 ///
290 /// * `lhs` - The left-hand side tensor.
291 /// * `rhs` - The right-hand side tensor.
292 ///
293 /// # Returns
294 ///
295 /// The result of dividing the two tensors.
296 fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
297
298 /// Divides a tensor by a scalar.
299 ///
300 /// # Arguments
301 ///
302 /// * `lhs` - The left-hand side tensor.
303 /// * `rhs` - The right-hand side scalar.
304 ///
305 /// # Returns
306 ///
307 /// The result of dividing the tensor by the scalar.
308 fn float_div_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
309
310 /// Computes the remainder of division between two tensors element-wise.
311 ///
312 /// # Arguments
313 ///
314 /// * `lhs` - The left-hand side tensor.
315 /// * `rhs` - The right-hand side tensor.
316 ///
317 /// # Returns
318 ///
319 /// The element-wise remainder when dividing `lhs` by `rhs`.
320 fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
321
322 /// Computes the modulus of a tensor given a scalar.
323 ///
324 /// # Arguments
325 /// * `lhs` - The left-hand side tensor.
326 /// * `rhs` - The right-hand side scalar.
327 ///
328 /// # Returns
329 ///
330 /// The result of applying the modulus of the scalar to the tensor.
331 fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
332
333 /// Multiplies two tensors together using matrix multiplication.
334 ///
335 /// # Arguments
336 ///
337 /// * `lhs` - The left-hand side tensor.
338 /// * `rhs` - The right-hand side tensor.
339 ///
340 /// # Returns
341 ///
342 /// The result of multiplying the two tensors together using matrix multiplication.
343 fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
344
345 /// Computes the cross product of two tensors along a given dimension.
346 ///
347 /// # Arguments
348 ///
349 /// * `lhs` - The left-hand side tensor.
350 /// * `rhs` - The right-hand side tensor.
351 /// * `dim` - The dimension to compute the cross product along.
352 ///
353 /// # Returns
354 ///
355 /// The cross product of the two tensors.
356 fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
357
358 /// Negates a tensor element-wise.
359 fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
360 Self::float_mul_scalar(tensor, (-1f32).into())
361 }
362
363 /// Calculates the reciprocals element-wise
364 fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
365
366 /// Transposes a tensor.
367 ///
368 /// # Arguments
369 ///
370 /// * `tensor` - The tensor to transpose.
371 ///
372 /// # Returns
373 ///
374 /// The transposed tensor.
375 fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
376 let ndims = tensor.shape().num_dims();
377 Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
378 }
379
380 /// Swaps two dimensions of a tensor.
381 ///
382 /// # Arguments
383 ///
384 /// * `tensor` - The tensor to swap the dimensions of.
385 /// * `dim1` - The first dimension to swap.
386 /// * `dim2` - The second dimension to swap.
387 ///
388 /// # Returns
389 ///
390 /// The tensor with the dimensions swapped.
391 fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
392
393 /// Permutes the dimensions of a tensor.
394 ///
395 /// # Arguments
396 ///
397 /// * `tensor` - The tensor to permute the dimensions of.
398 /// * `axes` - The new order of the dimensions.
399 /// # Returns
400 ///
401 /// The tensor with the dimensions permuted.
402 fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
403
404 /// Reverse the order of elements in a tensor along the given axes.
405 ///
406 /// # Arguments
407 ///
408 /// * `tensor` - The tensor to reverse.
409 /// * `axes` - The axes to reverse.
410 ///
411 /// The tensor with the elements reversed.
412 fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
413
414 /// Reshapes a tensor.
415 ///
416 /// # Arguments
417 ///
418 /// * `tensor` - The tensor to reshape.
419 /// * `shape` - The new shape of the tensor.
420 ///
421 /// # Returns
422 ///
423 /// The tensor with the new shape.
424 fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
425
426 /// Gather elements from a tensor.
427 ///
428 /// # Arguments
429 ///
430 /// * `dim` - The dimension to gather from.
431 /// * `tensor` - The tensor to gather from.
432 /// * `indices` - The indices to gather.
433 ///
434 /// # Returns
435 ///
436 /// The gathered elements.
437 fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
438
439 /// Scatter elements into a tensor using sum reduction.
440 ///
441 /// # Arguments
442 ///
443 /// * `dim` - The dimension to scatter into.
444 /// * `tensor` - The tensor to scatter into.
445 /// * `indices` - The indices to scatter into.
446 /// * `value` - The value to scatter.
447 ///
448 /// # Returns
449 ///
450 /// The tensor with the scattered elements.
451 fn float_scatter_add(
452 dim: usize,
453 tensor: FloatTensor<B>,
454 indices: IntTensor<B>,
455 value: FloatTensor<B>,
456 ) -> FloatTensor<B>;
457
458 /// Select tensor elements along the given dimension corresponding for the given indices.
459 ///
460 /// # Arguments
461 ///
462 /// * `tensor` - The tensor to select from.
463 /// * `dim` - The dimension to select from.
464 /// * `indices` - The indices to select.
465 ///
466 /// # Returns
467 ///
468 /// The selected elements.
469 fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
470
471 /// Assign the selected elements along the given dimension corresponding for the given indices
472 /// to the given value using sum reduction.
473 ///
474 /// # Arguments
475 ///
476 /// * `tensor` - The tensor to select from.
477 /// * `dim` - The dimension to select from.
478 /// * `indices` - The indices to select.
479 /// * `value` - The value to assign.
480 ///
481 /// # Returns
482 ///
483 /// The tensor with the selected elements assigned to the given value.
484 fn float_select_add(
485 tensor: FloatTensor<B>,
486 dim: usize,
487 indices: IntTensor<B>,
488 value: FloatTensor<B>,
489 ) -> FloatTensor<B>;
490
491 /// Select tensor elements corresponding to the given slices.
492 ///
493 /// # Arguments
494 ///
495 /// * `tensor` - The tensor to select from.
496 /// * `slices` - The slices specifying ranges and steps for each dimension.
497 ///
498 /// # Returns
499 ///
500 /// The selected elements in a new tensor.
501 ///
502 /// # Note
503 ///
504 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
505 /// be passed to this method. Backend implementations do not need to handle empty slices.
506 fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;
507
508 /// Assign the selected elements corresponding to the given slices to the given value.
509 ///
510 /// # Arguments
511 ///
512 /// * `tensor` - The tensor to select from.
513 /// * `ranges` - The ranges to select.
514 /// * `value` - The value to assign.
515 ///
516 /// # Returns
517 ///
518 /// The tensor with the selected elements assigned to the given value.
519 ///
520 /// # Note
521 ///
522 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
523 /// high-level tensor API and will not be passed to this method. Backend implementations do
524 /// not need to handle empty slice assignments.
525 fn float_slice_assign(
526 tensor: FloatTensor<B>,
527 slices: &[Slice],
528 value: FloatTensor<B>,
529 ) -> FloatTensor<B>;
530
531 /// Update the given tensor with the value tensor where the mask is true.
532 ///
533 /// # Arguments
534 ///
535 /// * `tensor` - The tensor to select from.
536 /// * `mask` - The boolean mask to select with.
537 /// * `value` - The value to assign to the selected elements from the value tensor.
538 ///
539 /// # Returns
540 ///
541 /// The tensor with the selected elements assigned to the given value.
542 fn float_mask_where(
543 tensor: FloatTensor<B>,
544 mask: BoolTensor<B>,
545 value: FloatTensor<B>,
546 ) -> FloatTensor<B>;
547
548 /// Update the given tensor with the value where the mask is true.
549 ///
550 /// # Arguments
551 ///
552 /// * `tensor` - The tensor to select from.
553 /// * `mask` - The boolean mask to select with.
554 /// * `value` - The value to assign to the selected elements.
555 ///
556 /// # Returns
557 ///
558 /// The tensor with the selected elements assigned to the given value.
559 fn float_mask_fill(
560 tensor: FloatTensor<B>,
561 mask: BoolTensor<B>,
562 value: Scalar,
563 ) -> FloatTensor<B>;
564
565 /// Equal comparison of two tensors.
566 ///
567 /// # Arguments
568 ///
569 /// * `lhs` - The left-hand side tensor.
570 /// * `rhs` - The right-hand side tensor.
571 /// * `out_dtype` - The output tensor dtype.
572 ///
573 /// # Returns
574 ///
575 /// A boolean tensor with the result of the comparison.
576 fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
577 -> BoolTensor<B>;
578
579 /// Element-wise non-equality comparison.
580 ///
581 /// # Arguments
582 ///
583 /// * `lhs` - The left-hand side tensor.
584 /// * `rhs` - The right-hand side tensor.
585 /// * `out_dtype` - The output tensor dtype.
586 ///
587 /// # Returns
588 ///
589 /// A boolean tensor with the result of the comparison.
590 fn float_not_equal(
591 lhs: FloatTensor<B>,
592 rhs: FloatTensor<B>,
593 out_dtype: BoolDType,
594 ) -> BoolTensor<B> {
595 let equal_tensor = B::float_equal(lhs, rhs, out_dtype);
596 B::bool_not(equal_tensor)
597 }
598
599 /// Equal comparison of a tensor and a scalar.
600 ///
601 /// # Arguments
602 ///
603 /// * `lhs` - The left-hand side tensor.
604 /// * `rhs` - The right-hand side scalar.
605 /// * `out_dtype` - The output tensor dtype.
606 ///
607 /// # Returns
608 ///
609 /// A boolean tensor with the result of the comparison.
610 fn float_equal_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
611
612 /// Element-wise non-equality comparison with a scalar.
613 ///
614 /// # Arguments
615 ///
616 /// * `lhs` - The left-hand side tensor.
617 /// * `rhs` - The right-hand side scalar.
618 /// * `out_dtype` - The output tensor dtype.
619 ///
620 /// # Returns
621 ///
622 /// A boolean tensor with the result of the comparison.
623 fn float_not_equal_elem(
624 lhs: FloatTensor<B>,
625 rhs: Scalar,
626 out_dtype: BoolDType,
627 ) -> BoolTensor<B> {
628 let equal_tensor = B::float_equal_elem(lhs, rhs, out_dtype);
629 B::bool_not(equal_tensor)
630 }
631
632 /// Greater than comparison of two tensors.
633 ///
634 /// # Arguments
635 ///
636 /// * `lhs` - The left-hand side tensor.
637 /// * `rhs` - The right-hand side tensor.
638 /// * `out_dtype` - The output tensor dtype.
639 ///
640 /// # Returns
641 ///
642 /// A boolean tensor with the result of the comparison.
643 fn float_greater(
644 lhs: FloatTensor<B>,
645 rhs: FloatTensor<B>,
646 out_dtype: BoolDType,
647 ) -> BoolTensor<B>;
648
649 /// Greater than comparison of a tensor and a scalar.
650 ///
651 /// # Arguments
652 ///
653 /// * `lhs` - The left-hand side tensor.
654 /// * `rhs` - The right-hand side scalar.
655 /// * `out_dtype` - The output tensor dtype.
656 ///
657 /// # Returns
658 ///
659 /// A boolean tensor with the result of the comparison.
660 fn float_greater_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
661
662 /// Greater than or equal comparison of two tensors.
663 ///
664 /// # Arguments
665 ///
666 /// * `lhs` - The left-hand side tensor.
667 /// * `rhs` - The right-hand side tensor.
668 /// * `out_dtype` - The output tensor dtype.
669 ///
670 /// # Returns
671 ///
672 /// A boolean tensor with the result of the comparison.
673 fn float_greater_equal(
674 lhs: FloatTensor<B>,
675 rhs: FloatTensor<B>,
676 out_dtype: BoolDType,
677 ) -> BoolTensor<B>;
678
679 /// Greater than or equal comparison of a tensor and a scalar.
680 ///
681 /// # Arguments
682 ///
683 /// * `lhs` - The left-hand side tensor.
684 /// * `rhs` - The right-hand side scalar.
685 /// * `out_dtype` - The output tensor dtype.
686 ///
687 /// # Returns
688 ///
689 /// A boolean tensor with the result of the comparison.
690 fn float_greater_equal_elem(
691 lhs: FloatTensor<B>,
692 rhs: Scalar,
693 out_dtype: BoolDType,
694 ) -> BoolTensor<B>;
695
696 /// Less than comparison of two tensors.
697 ///
698 /// # Arguments
699 ///
700 /// * `lhs` - The left-hand side tensor.
701 /// * `rhs` - The right-hand side tensor.
702 /// * `out_dtype` - The output tensor dtype.
703 ///
704 /// # Returns
705 ///
706 /// A boolean tensor with the result of the comparison.
707 fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
708 -> BoolTensor<B>;
709
710 /// Less than comparison of a tensor and a scalar.
711 ///
712 /// # Arguments
713 ///
714 /// * `lhs` - The left-hand side tensor.
715 /// * `rhs` - The right-hand side scalar.
716 /// * `out_dtype` - The output tensor dtype.
717 ///
718 /// # Returns
719 ///
720 /// A boolean tensor with the result of the comparison.
721 fn float_lower_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
722
723 /// Less than or equal comparison of two tensors.
724 ///
725 /// # Arguments
726 ///
727 /// * `lhs` - The left-hand side tensor.
728 /// * `rhs` - The right-hand side tensor.
729 /// * `out_dtype` - The output tensor dtype.
730 ///
731 /// # Returns
732 ///
733 /// A boolean tensor with the result of the comparison.
734 fn float_lower_equal(
735 lhs: FloatTensor<B>,
736 rhs: FloatTensor<B>,
737 out_dtype: BoolDType,
738 ) -> BoolTensor<B>;
739
740 /// Less than or equal comparison of a tensor and a scalar.
741 ///
742 /// # Arguments
743 ///
744 /// * `lhs` - The left-hand side tensor.
745 /// * `rhs` - The right-hand side scalar.
746 /// * `out_dtype` - The output tensor dtype.
747 ///
748 /// # Returns
749 ///
750 /// A boolean tensor with the result of the comparison.
751 fn float_lower_equal_elem(
752 lhs: FloatTensor<B>,
753 rhs: Scalar,
754 out_dtype: BoolDType,
755 ) -> BoolTensor<B>;
756
757 /// Detaches a tensor from the computation graph.
758 fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
759 // Should only be overridden by autodiff backends.
760 tensor
761 }
762
763 /// Sets the `require_grad` flag of a tensor.
764 fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
765 // Should only be overridden by autodiff backends.
766 tensor
767 }
768
769 /// Returns the `require_grad` flag of a tensor.
770 fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
771 // Should only be overridden by autodiff backends.
772 false
773 }
774
775 /// Sum of all elements in a tensor.
776 ///
777 /// # Arguments
778 ///
779 /// * `tensor` - The tensor to sum.
780 ///
781 /// # Returns
782 ///
783 /// A scalar tensor with the sum of all elements in `tensor`.
784 fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
785
786 /// Sum of all elements in a tensor along a dimension.
787 ///
788 /// # Arguments
789 ///
790 /// * `tensor` - The tensor to sum.
791 /// * `dim` - The dimension along which to sum.
792 ///
793 /// # Returns
794 ///
795 /// A tensor with the sum of all elements in `tensor` along `dim`.
796 fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
797
798 /// Product of all elements in a tensor.
799 ///
800 /// # Arguments
801 ///
802 /// * `tensor` - The tensor to product.
803 ///
804 /// # Returns
805 ///
806 /// A scalar tensor with the product of all elements in `tensor`.
807 fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
808 // Product of all elements in a tensor
809 B::float_exp(B::float_sum(B::float_log(tensor)))
810 }
811
812 /// Product of all elements in a tensor along a dimension.
813 ///
814 /// # Arguments
815 ///
816 /// * `tensor` - The tensor to product.
817 ///
818 /// # Returns
819 ///
820 /// A tensor with the product of all elements in `tensor` along `dim`.
821 fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
822 // Product of all elements in a tensor along a dimension
823 B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
824 }
825
826 /// Mean of all elements in a tensor.
827 ///
828 /// # Arguments
829 ///
830 /// * `tensor` - The tensor to mean.
831 ///
832 /// # Returns
833 ///
834 /// A scalar tensor with the mean of all elements in `tensor`.
835 fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
836 let num_elems = tensor.shape().num_elements() as f32;
837 B::float_div_scalar(B::float_sum(tensor), num_elems.into())
838 }
839
840 /// Mean of all elements in a tensor along a dimension.
841 ///
842 /// # Arguments
843 ///
844 /// * `tensor` - The tensor to mean.
845 /// * `dim` - The dimension along which to mean.
846 ///
847 /// # Returns
848 ///
849 /// A tensor with the mean of all elements in `tensor` along `dim`.
850 fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
851
852 /// Computes the cumulative sum of elements along a dimension.
853 ///
854 /// # Arguments
855 ///
856 /// * `tensor` - The tensor to compute the cumulative sum of.
857 /// * `dim` - The dimension along which to compute the cumulative sum.
858 ///
859 /// # Returns
860 ///
861 /// A tensor with the same shape where each element is the cumulative sum
862 /// of all elements up to and including that position along the dimension.
863 fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
864
865 /// Computes the cumulative product of elements along a dimension.
866 ///
867 /// # Arguments
868 ///
869 /// * `tensor` - The tensor to compute the cumulative product of.
870 /// * `dim` - The dimension along which to compute the cumulative product.
871 ///
872 /// # Returns
873 ///
874 /// A tensor with the same shape where each element is the cumulative product
875 /// of all elements up to and including that position along the dimension.
876 fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
877
878 /// Computes the cumulative minimum of elements along a dimension.
879 ///
880 /// # Arguments
881 ///
882 /// * `tensor` - The tensor to compute the cumulative minimum of.
883 /// * `dim` - The dimension along which to compute the cumulative minimum.
884 ///
885 /// # Returns
886 ///
887 /// A tensor with the same shape where each element is the minimum
888 /// of all elements up to and including that position along the dimension.
889 fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
890
891 /// Computes the cumulative maximum of elements along a dimension.
892 ///
893 /// # Arguments
894 ///
895 /// * `tensor` - The tensor to compute the cumulative maximum of.
896 /// * `dim` - The dimension along which to compute the cumulative maximum.
897 ///
898 /// # Returns
899 ///
900 /// A tensor with the same shape where each element is the maximum
901 /// of all elements up to and including that position along the dimension.
902 fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
903
904 /// Converts a tensor to another floating point data type.
905 ///
906 /// # Arguments
907 ///
908 /// * `tensor` - The tensor to convert.
909 /// * `dtype` - The target data type.
910 ///
911 /// # Returns
912 ///
913 /// A tensor with the same values as `tensor` but in the target floating point data type.
914 fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
915
916 /// Returns a new tensor with exponential values.
917 ///
918 /// # Arguments
919 ///
920 /// * `tensor` - The tensor to exponentiate.
921 ///
922 /// # Returns
923 ///
924 /// A tensor with the same shape as `tensor` with exponential values.
925 fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
926
927 /// Returns a new tensor with natural logarithm values.
928 ///
929 /// # Arguments
930 ///
931 /// * `tensor` - The tensor to take the logarithm of.
932 ///
933 /// # Returns
934 ///
935 /// A tensor with the same shape as `tensor` with natural logarithm values.
936 fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
937
938 /// Returns a new tensor with logarithm values of (1 + Xi).
939 ///
940 /// # Arguments
941 ///
942 /// * `tensor` - The tensor to take the logarithm of.
943 ///
944 /// # Returns
945 ///
946 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
947 fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
948
949 /// Element-wise power with a FloatTensor.
950 ///
951 /// # Arguments
952 ///
953 /// * `lhs` - The left-hand side tensor.
954 /// * `rhs` - The right-hand side tensor.
955 ///
956 /// # Returns
957 ///
958 /// The elements of `lhs` raised to the power of the elements of `rhs`.
959 fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
960
961 /// Element-wise power with an IntTensor.
962 ///
963 /// # Arguments
964 ///
965 /// * `lhs` - The left-hand side tensor.
966 /// * `rhs` - The right-hand side floatTensor.
967 ///
968 /// # Returns
969 ///
970 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
971 fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
972 let dtype = lhs.dtype();
973 Self::float_powf(lhs, B::int_into_float(rhs, dtype.into()))
974 }
975
976 /// Raises a tensor to the power of an int scalar.
977 ///
978 /// # Backend Implementors Note
979 ///
980 /// A number of common exponent cases can be implemented with operations
981 /// which are much cheaper than generic exponentiation.
982 ///
983 /// This (`Backend` impl overridable) operation handles generic optimizations
984 /// for several common integer exponent cases; and then dispatches to
985 /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
986 /// operation to handle the generic case.
987 ///
988 /// # Arguments
989 ///
990 /// * `lhs` - The left-hand side tensor.
991 /// * `rhs` - The right-hand side scalar.
992 ///
993 /// # Returns
994 ///
995 /// The elements of `lhs` raised to the value of `rhs`.
996 fn float_powi_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
997 match rhs.elem::<i64>() {
998 0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
999 1 => lhs,
1000 2 => B::float_mul(lhs.clone(), lhs),
1001 -1 => Self::float_recip(lhs),
1002 -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
1003 _ => Self::float_powi_scalar_impl(lhs, rhs),
1004 }
1005 }
1006
1007 /// Raises a tensor to the power of an int scalar.
1008 ///
1009 /// # Backend Implementors Note
1010 ///
1011 /// This is the generic implementation of integer exponentiation
1012 /// called by [`Self::float_powi_scalar`] in the fallback case.
1013 ///
1014 /// As a general rule, this should not be called directly.
1015 ///
1016 /// # Arguments
1017 ///
1018 /// * `lhs` - The left-hand side tensor.
1019 /// * `rhs` - The right-hand side scalar.
1020 ///
1021 /// # Returns
1022 ///
1023 /// The elements of `lhs` raised to the value of `rhs`.
1024 fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
1025 // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.
1026 Self::float_powf_scalar_impl(lhs, rhs)
1027 }
1028
1029 /// Returns a new tensor with values raised to the power of float `value`.
1030 ///
1031 /// # Backend Implementors Note
1032 ///
1033 /// This (`Backend` impl overridable) operation dispatches integer exponentiation
1034 /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
1035 /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
1036 /// operation to handle the generic case.
1037 ///
1038 /// # Arguments
1039 ///
1040 /// * `tensor` - The tensor to exponentiate.
1041 /// * `value` - The exponent.
1042 ///
1043 /// # Returns
1044 ///
1045 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
1046 fn float_powf_scalar(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B> {
1047 if let Some(exp) = value.try_as_integer() {
1048 Self::float_powi_scalar(tensor, exp)
1049 } else {
1050 Self::float_powf_scalar_impl(tensor, value)
1051 }
1052 }
1053
    /// Returns a new tensor with values raised to the power of float `value`.
    ///
    /// # Backend Implementors Note
    ///
    /// This is the generic implementation of non-integer (float) exponentiation,
    /// called by [`Self::float_powf_scalar`] in the fallback case. The default
    /// [`Self::float_powi_scalar_impl`] also defers here for integer exponents that
    /// have no cheaper special case, so this kernel must accept integer-valued
    /// exponents as well.
    ///
    /// This is the minimal required support a `Backend` must implement
    /// for exponentiation.
    ///
    /// As a general rule, this should not be called directly.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B>;
1075
    /// Returns a new tensor with the element-wise square root of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise absolute value of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1097
    /// Returns a new tensor with the element-wise cosine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise sine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise tangent of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with tangent values.
    fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise hyperbolic cosine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
    fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise hyperbolic sine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic sine values.
    fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise hyperbolic tangent of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1163
    /// Returns a new tensor with the element-wise inverse cosine (arccosine) of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with inverse cosine values.
    fn float_acos(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise inverse hyperbolic cosine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with inverse hyperbolic cosine values.
    fn float_acosh(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise inverse sine (arcsine) of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with inverse sine values.
    fn float_asin(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise inverse hyperbolic sine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with inverse hyperbolic sine values.
    fn float_asinh(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise inverse tangent (arctangent) of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with the inverse tangent values.
    fn float_atan(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise inverse hyperbolic tangent of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with the inverse hyperbolic tangent values.
    fn float_atanh(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a tensor with the four-quadrant inverse tangent values of `y` and `x`.
    ///
    /// Element-wise, this is `atan2(y, x)` with `y` taken from `lhs` and `x` from `rhs`,
    /// i.e. the angle of the point `(x, y)`. Note the argument order: `y` comes first.
    ///
    /// NOTE(review): the shape requirements on `lhs`/`rhs` (identical vs. broadcastable)
    /// are not stated here — confirm against the high-level tensor API.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor with y coordinates.
    /// * `rhs` - The tensor with x coordinates.
    ///
    /// # Returns
    ///
    /// A tensor with the four-quadrant inverse tangent values.
    fn float_atan2(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
1241
    /// Returns a new tensor with rounded values.
    ///
    /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value
    /// (e.g. `0.5 -> 0`, `1.5 -> 2`, `2.5 -> 2`).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with each element rounded down to the nearest integer (floor).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with each element rounded up to the nearest integer (ceiling).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded up.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with ceiling values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with each element rounded toward zero (fractional part discarded).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be truncated.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with truncated values.
    fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;

    /// Returns a new tensor with the element-wise (Gauss) error function of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1299
1300 /// Concatenates tensors along a dimension.
1301 ///
1302 /// # Arguments
1303 ///
1304 /// * `tensors` - The tensors to concatenate.
1305 /// * `dim` - The dimension along which to concatenate.
1306 ///
1307 /// # Returns
1308 ///
1309 /// A tensor with the concatenated tensors along `dim`.
1310 ///
1311 /// # Note
1312 ///
1313 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
1314 /// high-level tensor API and will not be passed to this method. Backend implementations do
1315 /// not need to handle empty tensors.
1316 fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1317 cat_with_slice_assign::<B, Float>(
1318 tensors.into_iter().map(TensorPrimitive::Float).collect(),
1319 dim,
1320 )
1321 .tensor()
1322 }
1323
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// The default implementations of [`Self::float_max_dim`] and
    /// [`Self::float_max_dim_with_indices`] gather `tensor` along `dim` with the
    /// indices returned here.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    /// * `out_dtype` - The output tensor dtype.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;

    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// The default implementations of [`Self::float_min_dim`] and
    /// [`Self::float_min_dim_with_indices`] gather `tensor` along `dim` with the
    /// indices returned here.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    /// * `out_dtype` - The output tensor dtype.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1349
1350 /// Gets the maximum element of a tensor.
1351 ///
1352 /// # Arguments
1353 ///
1354 /// * `tensor` - The tensor to get the maximum elements of.
1355 ///
1356 /// # Returns
1357 ///
1358 /// A tensor with the maximum element of `tensor`.
1359 fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1360 let shape = tensor.shape();
1361 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1362
1363 B::float_max_dim(tensor, 0)
1364 }
1365
1366 /// Gets the maximum elements of a tensor along an axis.
1367 ///
1368 /// # Arguments
1369 ///
1370 /// * `tensor` - The tensor to get the maximum elements of.
1371 /// * `dim` - The dimension along which to get the maximum elements.
1372 ///
1373 /// # Returns
1374 ///
1375 /// A tensor with the maximum elements of `tensor` along `dim`.
1376 fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1377 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1378 let index = B::float_argmax(tensor.clone(), dim, dtype);
1379
1380 B::float_gather(dim, tensor, index)
1381 }
1382
1383 /// Gets the maximum elements of a tensor along an axis and their indices.
1384 ///
1385 /// # Arguments
1386 ///
1387 /// * `tensor` - The tensor to get the maximum elements of.
1388 /// * `dim` - The dimension along which to get the maximum elements.
1389 /// * `indices_dtype` - The indices tensor dtype.
1390 ///
1391 /// # Returns
1392 ///
1393 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1394 fn float_max_dim_with_indices(
1395 tensor: FloatTensor<B>,
1396 dim: usize,
1397 indices_dtype: IntDType,
1398 ) -> (FloatTensor<B>, IntTensor<B>) {
1399 let index = B::float_argmax(tensor.clone(), dim, indices_dtype);
1400 let values = B::float_gather(dim, tensor, index.clone());
1401
1402 (values, index)
1403 }
1404
1405 /// Gets the minimum element of a tensor.
1406 ///
1407 /// # Arguments
1408 ///
1409 /// * `tensor` - The tensor to get the minimum elements of.
1410 ///
1411 /// # Returns
1412 ///
1413 /// A tensor with the minimum element of `tensor`.
1414 fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1415 let shape = tensor.shape();
1416 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1417
1418 B::float_min_dim(tensor, 0)
1419 }
1420
1421 /// Gets the minimum elements of a tensor along an axis.
1422 ///
1423 /// # Arguments
1424 ///
1425 /// * `tensor` - The tensor to get the minimum elements of.
1426 /// * `dim` - The dimension along which to get the minimum elements.
1427 ///
1428 /// # Returns
1429 ///
1430 /// A tensor with the minimum elements of `tensor` along `dim`.
1431 fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1432 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1433 let index = B::float_argmin(tensor.clone(), dim, dtype);
1434
1435 B::float_gather(dim, tensor, index)
1436 }
1437
1438 /// Gets the minimum elements of a tensor along an axis and their indices.
1439 ///
1440 /// # Arguments
1441 ///
1442 /// * `tensor` - The tensor to get the minimum elements of.
1443 /// * `dim` - The dimension along which to get the minimum elements.
1444 /// * `indices_dtype` - The indices tensor dtype.
1445 ///
1446 /// # Returns
1447 ///
1448 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1449 fn float_min_dim_with_indices(
1450 tensor: FloatTensor<B>,
1451 dim: usize,
1452 indices_dtype: IntDType,
1453 ) -> (FloatTensor<B>, IntTensor<B>) {
1454 let index = B::float_argmin(tensor.clone(), dim, indices_dtype);
1455 let values = B::float_gather(dim, tensor, index.clone());
1456
1457 (values, index)
1458 }
1459
1460 /// Gets the maximum absolute element of a tensor.
1461 ///
1462 /// # Arguments
1463 ///
1464 /// * `tensor` - The tensor to get the maximum elements of.
1465 ///
1466 /// # Returns
1467 ///
1468 /// A tensor with the maximum element of `tensor`.
1469 fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1470 let shape = tensor.shape();
1471 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1472
1473 B::float_max_abs_dim(tensor, 0)
1474 }
1475
1476 /// Gets the maximum absolute elements of a tensor along an axis.
1477 ///
1478 /// # Arguments
1479 ///
1480 /// * `tensor` - The tensor to get the maximum elements of.
1481 /// * `dim` - The dimension along which to get the maximum elements.
1482 ///
1483 /// # Returns
1484 ///
1485 /// A tensor with the maximum elements of `tensor` along `dim`.
1486 fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1487 B::float_max_dim(B::float_abs(tensor), dim)
1488 }
1489
1490 /// Tests if any element in the float `tensor` evaluates to True.
1491 ///
1492 /// # Arguments
1493 ///
1494 /// * `tensor` - The tensor to test.
1495 /// * `out_dtype` - The output tensor dtype.
1496 ///
1497 /// # Returns
1498 ///
1499 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1500 fn float_any(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1501 let float_dtype = tensor.dtype();
1502 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1503 let bool_tensor = B::bool_not(bool_tensor);
1504 let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1505 B::float_greater_elem(sum, 0f32.into(), out_dtype)
1506 }
1507
1508 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1509 ///
1510 /// # Arguments
1511 ///
1512 /// * `tensor` - The tensor to test.
1513 /// * `dim` - The axis along which to test.
1514 /// * `out_dtype` - The output tensor dtype.
1515 ///
1516 /// # Returns
1517 ///
1518 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1519 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1520 /// input evaluates to True, False otherwise.
1521 fn float_any_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1522 let float_dtype = tensor.dtype();
1523 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1524 let bool_tensor = B::bool_not(bool_tensor);
1525 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1526 B::float_greater_elem(sum, 0f32.into(), out_dtype)
1527 }
1528
1529 /// Tests if all elements in the float `tensor` evaluate to True.
1530 ///
1531 /// # Arguments
1532 ///
1533 /// * `tensor` - The tensor to test.
1534 /// * `out_dtype` - The output tensor dtype.
1535 ///
1536 /// # Returns
1537 ///
1538 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1539 /// evaluate to True, False otherwise.
1540 fn float_all(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1541 let float_dtype = tensor.dtype();
1542 let num_elems = tensor.shape().num_elements() as f32;
1543 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1544 let bool_tensor = B::bool_not(bool_tensor);
1545 let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1546 B::float_equal_elem(sum, num_elems.into(), out_dtype)
1547 }
1548
1549 /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1550 ///
1551 /// # Arguments
1552 ///
1553 /// * `tensor` - The tensor to test.
1554 /// * `dim` - The axis along which to test.
1555 /// * `out_dtype` - The output tensor dtype.
1556 ///
1557 /// # Returns
1558 ///
1559 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1560 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1561 /// evaluates to True, False otherwise.
1562 fn float_all_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1563 let float_dtype = tensor.dtype();
1564 let num_elems = tensor.shape()[dim] as f32;
1565 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1566 let bool_tensor = B::bool_not(bool_tensor);
1567 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1568 B::float_equal_elem(sum, num_elems.into(), out_dtype)
1569 }
1570
1571 /// Returns the signs of the float `tensor`.
1572 ///
1573 /// # Arguments
1574 ///
1575 /// * `tensor` - The tensor to extract the signs from.
1576 ///
1577 /// # Returns
1578 ///
1579 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1580 fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1581 let device = B::float_device(&tensor);
1582 let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
1583 let zeros = B::float_zeros(tensor.shape(), &device, tensor.dtype().into());
1584 let less_than_zero = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
1585 let greater_than_zero = B::float_greater_elem(tensor, 0f32.into(), bool_dtype);
1586
1587 let mut result = B::float_mask_fill(zeros, less_than_zero, (-1f32).into());
1588 result = B::float_mask_fill(result, greater_than_zero, 1f32.into());
1589 result
1590 }
1591
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// NOTE(review): the exact broadcasting rules (e.g. which input dimensions may be
    /// expanded) are not stated here; confirm against the high-level `expand` API.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// A tensor with the given `shape`.
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1594
1595 /// Sort the elements of the input `tensor` by value in along a given dimension.
1596 ///
1597 /// This sort is unstable (i.e., may reorder equal elements).
1598 ///
1599 /// # Arguments
1600 ///
1601 /// * `tensor` - The input tensor.
1602 /// * `dim` - The axis along which to sort.
1603 /// * `descending` - The sorting order.
1604 ///
1605 /// # Returns
1606 ///
1607 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1608 fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1609 sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1610 }
1611
1612 /// Sort the elements of the input `tensor` by value in along a given dimension.
1613 ///
1614 /// This sort is unstable (i.e., may reorder equal elements).
1615 ///
1616 /// # Arguments
1617 ///
1618 /// * `tensor` - The input tensor.
1619 /// * `dim` - The axis along which to sort.
1620 /// * `descending` - The sorting order.
1621 /// * `indices_dtype` - The indices tensor dtype.
1622 ///
1623 /// # Returns
1624 ///
1625 /// A tensor with the same shape as the input tensor and corresponding indices, where
1626 /// the elements are sorted by value and the indices map back to the original input tensor.
1627 fn float_sort_with_indices(
1628 tensor: FloatTensor<B>,
1629 dim: usize,
1630 descending: bool,
1631 indices_dtype: IntDType,
1632 ) -> (FloatTensor<B>, IntTensor<B>) {
1633 let (values, indices) = sort_with_indices::<B, Float>(
1634 TensorPrimitive::Float(tensor),
1635 dim,
1636 descending,
1637 indices_dtype,
1638 );
1639 (values.tensor(), indices)
1640 }
1641
1642 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1643 ///
1644 /// This sort is unstable (i.e., may reorder equal elements).
1645 ///
1646 /// # Arguments
1647 ///
1648 /// * `tensor` - The input tensor.
1649 /// * `dim` - The axis along which to sort.
1650 /// * `descending` - The sorting order.
1651 /// * `out_dtype` - The output tensor dtype.
1652 ///
1653 /// # Returns
1654 ///
1655 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1656 fn float_argsort(
1657 tensor: FloatTensor<B>,
1658 dim: usize,
1659 descending: bool,
1660 out_dtype: IntDType,
1661 ) -> IntTensor<B> {
1662 argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending, out_dtype)
1663 }
1664
1665 /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
1666 /// using the given locations in [-1, 1].
1667 ///
1668 /// # Arguments
1669 ///
1670 /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
1671 /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
1672 /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
1673 /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
1674 ///
1675 /// # Returns
1676 ///
1677 /// A tensor with shape (N, C, H_out, W_out)
1678 fn float_grid_sample_2d(
1679 tensor: FloatTensor<B>,
1680 grid: FloatTensor<B>,
1681 options: GridSampleOptions,
1682 ) -> FloatTensor<B> {
1683 // TODO: default impl should get int default dtype
1684 float_grid_sample_2d_ref::<B>(tensor, grid, options)
1685 }
1686
    /// Unfold windows along a dimension.
    ///
    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `(shape[dim] - size) / step + 1` (integer division) when
    /// `shape[dim] >= size`, and `0` otherwise. For example, `shape[dim] = 5`, `size = 2`,
    /// `step = 1` yields 4 windows starting at indices 0, 1, 2 and 3.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the selected dim.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
    fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
        -> FloatTensor<B>;
1706
1707 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1708 ///
1709 /// # Returns
1710 ///
1711 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1712 fn float_is_nan(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1713 // Check if the input tensor is NaN by comparing it to itself
1714 // NaN is the only value that is not equal to itself
1715 B::float_not_equal(tensor.clone(), tensor, out_dtype)
1716 }
1717
1718 /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1719 ///
1720 /// # Returns
1721 ///
1722 /// A boolean tensor where `true` indicates that the value is infinite
1723 fn float_is_inf(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1724 B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.into(), out_dtype)
1725 }
1726}