burn_backend/backend/ops/tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_ref;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::sort::{argsort, sort, sort_with_indices};
5use crate::ops::GridSampleOptions;
6use crate::tensor::{BoolTensor, Device, Float, FloatTensor, IntTensor};
7use crate::{Backend, Distribution, TensorData};
8use crate::{ExecutionError, Scalar, TensorMetadata, TensorPrimitive};
9use alloc::vec::Vec;
10use burn_std::{FloatDType, Shape, Slice};
11
12/// Operations on float tensors.
13pub trait FloatTensorOps<B: Backend> {
14 /// Creates a new tensor from the data structure.
15 ///
16 /// # Arguments
17 ///
18 /// * `data` - The data structure.
19 /// * `device` - The device to create the tensor on.
20 ///
21 /// # Returns
22 ///
23 /// The tensor with the given data.
24 fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
25
26 /// Creates a new tensor with random values.
27 ///
28 /// # Arguments
29 ///
30 /// * `shape` - The shape of the tensor.
31 /// * `distribution` - The distribution to sample from.
32 /// * `device` - The device to create the tensor on.
33 ///
34 /// # Returns
35 ///
36 /// The tensor with the given shape and random values.
37 fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
38 -> FloatTensor<B>;
39
40 /// Creates a new tensor with zeros.
41 ///
42 /// # Arguments
43 ///
44 /// * `shape` - The shape of the tensor.
45 /// * `device` - The device to create the tensor on.
46 /// * `dtype` - The target data type.
47 ///
48 /// # Returns
49 ///
50 /// The tensor with the given shape and zeros.
51 fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
52 Self::float_from_data(TensorData::full_dtype(shape, 0., dtype.into()), device)
53 }
54
55 /// Creates a new tensor with ones.
56 ///
57 /// # Arguments
58 ///
59 /// * `shape` - The shape of the tensor.
60 /// * `device` - The device to create the tensor on.
61 /// * `dtype` - The target data type.
62 ///
63 /// # Returns
64 ///
65 /// The tensor with the given shape and ones.
66 fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
67 Self::float_from_data(TensorData::full_dtype(shape, 1., dtype.into()), device)
68 }
69
70 /// Creates a tensor filled with given value.
71 ///
72 /// # Arguments
73 ///
74 /// * `shape` - The shape of the tensor.
75 /// * `fill_value` - The value with which to fill the tensor.
76 /// * `device` - The device to create the tensor on.
77 /// * `dtype` - The target data type.
78 ///
79 /// # Returns
80 ///
81 /// The tensor filled with given value
82 fn float_full(
83 shape: Shape,
84 fill_value: Scalar,
85 device: &Device<B>,
86 dtype: FloatDType,
87 ) -> FloatTensor<B> {
88 Self::float_from_data(
89 TensorData::full_dtype(shape, fill_value, dtype.into()),
90 device,
91 )
92 }
93
94 /// Converts the tensor to a data structure.
95 ///
96 /// # Arguments
97 ///
98 /// * `tensor` - The tensor.
99 ///
100 /// # Returns
101 ///
102 /// The data structure with the tensor's data.
103 fn float_into_data(
104 tensor: FloatTensor<B>,
105 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
106
107 /// Gets the device of the tensor.
108 ///
109 /// # Arguments
110 ///
111 /// * `tensor` - The tensor.
112 ///
113 /// # Returns
114 ///
115 /// The device of the tensor.
116 fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
117
118 /// Moves the tensor to the given device.
119 ///
120 /// # Arguments
121 ///
122 /// * `tensor` - The tensor.
123 /// * `device` - The device to move the tensor to.
124 ///
125 /// # Returns
126 ///
127 /// The tensor on the given device.
128 fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
129
130 /// Converts float tensor to int tensor.
131 ///
132 /// # Arguments
133 ///
134 /// * `tensor` - The tensor.
135 ///
136 /// # Returns
137 ///
138 /// The int tensor with the same data as the float tensor.
139 fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;
140
141 /// Creates an empty tensor with the given shape.
142 ///
143 /// # Arguments
144 ///
145 /// * `shape` - The shape of the tensor.
146 /// * `device` - The device to create the tensor on.
147 /// * `dtype` - The target data type.
148 ///
149 /// # Returns
150 ///
151 /// The empty tensor with the given shape.
152 fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
153
154 /// Repeat the tensor along the given dimension.
155 ///
156 /// # Arguments
157 ///
158 /// * `tensor` - The tensor.
159 /// * `dim` - The dimension to repeat.
160 /// * `times` - The number of times to repeat the dimension.
161 ///
162 /// # Returns
163 ///
164 /// The tensor with the given dimension repeated.
165 fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
166 repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
167 }
168
169 /// Adds two tensors together.
170 ///
171 /// # Arguments
172 ///
173 /// * `lhs` - The left-hand side tensor.
174 /// * `rhs` - The right-hand side tensor.
175 ///
176 /// # Returns
177 ///
178 /// The result of adding the two tensors together.
179 fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
180
181 /// Adds a scalar to a tensor.
182 ///
183 /// # Arguments
184 ///
185 /// * `lhs` - The left-hand side tensor.
186 /// * `rhs` - The right-hand side scalar.
187 ///
188 /// # Returns
189 ///
190 /// The result of adding the scalar to the tensor.
191 fn float_add_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
192
193 /// Clamps a tensor under a minimum value.
194 ///
195 /// # Arguments
196 ///
197 /// * `tensor` - The tensor to clamp.
198 /// * `min` - The minimum value.
199 ///
200 /// # Returns
201 ///
202 /// The clamped tensor.
203 fn float_clamp_min(tensor: FloatTensor<B>, min: Scalar) -> FloatTensor<B> {
204 // Default implementation
205 let mask = Self::float_lower_elem(tensor.clone(), min);
206 B::float_mask_fill(tensor, mask, min)
207 }
208
209 /// Clamps a tensor over a maximum value.
210 ///
211 /// # Arguments
212 ///
213 /// * `tensor` - The tensor to clamp.
214 /// * `max` - The maximum value.
215 ///
216 /// # Returns
217 ///
218 /// The clamped tensor.
219 fn float_clamp_max(tensor: FloatTensor<B>, max: Scalar) -> FloatTensor<B> {
220 // Default implementation
221 let mask = Self::float_greater_elem(tensor.clone(), max);
222 B::float_mask_fill(tensor, mask, max)
223 }
224
225 /// Clamps a tensor between a minimum and maximum value.
226 ///
227 /// # Arguments
228 ///
229 /// * `tensor` - The tensor to clamp.
230 /// * `min` - The minimum value.
231 /// * `max` - The maximum value.
232 ///
233 /// # Returns
234 ///
235 /// The clamped tensor.
236 fn float_clamp(tensor: FloatTensor<B>, min: Scalar, max: Scalar) -> FloatTensor<B> {
237 // Default implementation
238 Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
239 }
240
241 /// Subtracts two tensors.
242 ///
243 /// # Arguments
244 ///
245 /// * `lhs` - The left-hand side tensor.
246 /// * `rhs` - The right-hand side tensor.
247 ///
248 /// # Returns
249 ///
250 /// The result of subtracting the two tensors.
251 fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
252
253 /// Subtracts a scalar from a tensor.
254 ///
255 /// # Arguments
256 ///
257 /// * `lhs` - The left-hand side tensor.
258 /// * `rhs` - The right-hand side scalar.
259 ///
260 /// # Returns
261 ///
262 /// The result of subtracting the scalar from the tensor.
263 fn float_sub_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
264
265 /// Multiplies two tensors together element-wise.
266 fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
267
268 /// Multiplies a tensor by a scalar.
269 ///
270 /// # Arguments
271 ///
272 /// * `lhs` - The left-hand side tensor.
273 /// * `rhs` - The right-hand side scalar.
274 ///
275 /// # Returns
276 ///
277 /// The result of multiplying the tensor by the scalar.
278 fn float_mul_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
279
280 /// Divides two tensors element-wise.
281 ///
282 /// # Arguments
283 ///
284 /// * `lhs` - The left-hand side tensor.
285 /// * `rhs` - The right-hand side tensor.
286 ///
287 /// # Returns
288 ///
289 /// The result of dividing the two tensors.
290 fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
291
292 /// Divides a tensor by a scalar.
293 ///
294 /// # Arguments
295 ///
296 /// * `lhs` - The left-hand side tensor.
297 /// * `rhs` - The right-hand side scalar.
298 ///
299 /// # Returns
300 ///
301 /// The result of dividing the tensor by the scalar.
302 fn float_div_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
303
304 /// Computes the remainder of division between two tensors element-wise.
305 ///
306 /// # Arguments
307 ///
308 /// * `lhs` - The left-hand side tensor.
309 /// * `rhs` - The right-hand side tensor.
310 ///
311 /// # Returns
312 ///
313 /// The element-wise remainder when dividing `lhs` by `rhs`.
314 fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
315
316 /// Computes the modulus of a tensor given a scalar.
317 ///
318 /// # Arguments
319 /// * `lhs` - The left-hand side tensor.
320 /// * `rhs` - The right-hand side scalar.
321 ///
322 /// # Returns
323 ///
324 /// The result of applying the modulus of the scalar to the tensor.
325 fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
326
327 /// Multiplies two tensors together using matrix multiplication.
328 ///
329 /// # Arguments
330 ///
331 /// * `lhs` - The left-hand side tensor.
332 /// * `rhs` - The right-hand side tensor.
333 ///
334 /// # Returns
335 ///
336 /// The result of multiplying the two tensors together using matrix multiplication.
337 fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
338
339 /// Computes the cross product of two tensors along a given dimension.
340 ///
341 /// # Arguments
342 ///
343 /// * `lhs` - The left-hand side tensor.
344 /// * `rhs` - The right-hand side tensor.
345 /// * `dim` - The dimension to compute the cross product along.
346 ///
347 /// # Returns
348 ///
349 /// The cross product of the two tensors.
350 fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
351
352 /// Negates a tensor element-wise.
353 fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
354 Self::float_mul_scalar(tensor, (-1f32).into())
355 }
356
357 /// Calculates the reciprocals element-wise
358 fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
359
360 /// Transposes a tensor.
361 ///
362 /// # Arguments
363 ///
364 /// * `tensor` - The tensor to transpose.
365 ///
366 /// # Returns
367 ///
368 /// The transposed tensor.
369 fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
370 let ndims = tensor.shape().num_dims();
371 Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
372 }
373
374 /// Swaps two dimensions of a tensor.
375 ///
376 /// # Arguments
377 ///
378 /// * `tensor` - The tensor to swap the dimensions of.
379 /// * `dim1` - The first dimension to swap.
380 /// * `dim2` - The second dimension to swap.
381 ///
382 /// # Returns
383 ///
384 /// The tensor with the dimensions swapped.
385 fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
386
387 /// Permutes the dimensions of a tensor.
388 ///
389 /// # Arguments
390 ///
391 /// * `tensor` - The tensor to permute the dimensions of.
392 /// * `axes` - The new order of the dimensions.
393 /// # Returns
394 ///
395 /// The tensor with the dimensions permuted.
396 fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
397
398 /// Reverse the order of elements in a tensor along the given axes.
399 ///
400 /// # Arguments
401 ///
402 /// * `tensor` - The tensor to reverse.
403 /// * `axes` - The axes to reverse.
404 ///
405 /// The tensor with the elements reversed.
406 fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
407
408 /// Reshapes a tensor.
409 ///
410 /// # Arguments
411 ///
412 /// * `tensor` - The tensor to reshape.
413 /// * `shape` - The new shape of the tensor.
414 ///
415 /// # Returns
416 ///
417 /// The tensor with the new shape.
418 fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
419
420 /// Gather elements from a tensor.
421 ///
422 /// # Arguments
423 ///
424 /// * `dim` - The dimension to gather from.
425 /// * `tensor` - The tensor to gather from.
426 /// * `indices` - The indices to gather.
427 ///
428 /// # Returns
429 ///
430 /// The gathered elements.
431 fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
432
433 /// Scatter elements into a tensor using sum reduction.
434 ///
435 /// # Arguments
436 ///
437 /// * `dim` - The dimension to scatter into.
438 /// * `tensor` - The tensor to scatter into.
439 /// * `indices` - The indices to scatter into.
440 /// * `value` - The value to scatter.
441 ///
442 /// # Returns
443 ///
444 /// The tensor with the scattered elements.
445 fn float_scatter_add(
446 dim: usize,
447 tensor: FloatTensor<B>,
448 indices: IntTensor<B>,
449 value: FloatTensor<B>,
450 ) -> FloatTensor<B>;
451
452 /// Select tensor elements along the given dimension corresponding for the given indices.
453 ///
454 /// # Arguments
455 ///
456 /// * `tensor` - The tensor to select from.
457 /// * `dim` - The dimension to select from.
458 /// * `indices` - The indices to select.
459 ///
460 /// # Returns
461 ///
462 /// The selected elements.
463 fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
464
465 /// Assign the selected elements along the given dimension corresponding for the given indices
466 /// to the given value using sum reduction.
467 ///
468 /// # Arguments
469 ///
470 /// * `tensor` - The tensor to select from.
471 /// * `dim` - The dimension to select from.
472 /// * `indices` - The indices to select.
473 /// * `value` - The value to assign.
474 ///
475 /// # Returns
476 ///
477 /// The tensor with the selected elements assigned to the given value.
478 fn float_select_add(
479 tensor: FloatTensor<B>,
480 dim: usize,
481 indices: IntTensor<B>,
482 value: FloatTensor<B>,
483 ) -> FloatTensor<B>;
484
485 /// Select tensor elements corresponding to the given slices.
486 ///
487 /// # Arguments
488 ///
489 /// * `tensor` - The tensor to select from.
490 /// * `slices` - The slices specifying ranges and steps for each dimension.
491 ///
492 /// # Returns
493 ///
494 /// The selected elements in a new tensor.
495 ///
496 /// # Note
497 ///
498 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
499 /// be passed to this method. Backend implementations do not need to handle empty slices.
500 fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;
501
502 /// Assign the selected elements corresponding to the given slices to the given value.
503 ///
504 /// # Arguments
505 ///
506 /// * `tensor` - The tensor to select from.
507 /// * `ranges` - The ranges to select.
508 /// * `value` - The value to assign.
509 ///
510 /// # Returns
511 ///
512 /// The tensor with the selected elements assigned to the given value.
513 ///
514 /// # Note
515 ///
516 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
517 /// high-level tensor API and will not be passed to this method. Backend implementations do
518 /// not need to handle empty slice assignments.
519 fn float_slice_assign(
520 tensor: FloatTensor<B>,
521 slices: &[Slice],
522 value: FloatTensor<B>,
523 ) -> FloatTensor<B>;
524
525 /// Update the given tensor with the value tensor where the mask is true.
526 ///
527 /// # Arguments
528 ///
529 /// * `tensor` - The tensor to select from.
530 /// * `mask` - The boolean mask to select with.
531 /// * `value` - The value to assign to the selected elements from the value tensor.
532 ///
533 /// # Returns
534 ///
535 /// The tensor with the selected elements assigned to the given value.
536 fn float_mask_where(
537 tensor: FloatTensor<B>,
538 mask: BoolTensor<B>,
539 value: FloatTensor<B>,
540 ) -> FloatTensor<B>;
541
542 /// Update the given tensor with the value where the mask is true.
543 ///
544 /// # Arguments
545 ///
546 /// * `tensor` - The tensor to select from.
547 /// * `mask` - The boolean mask to select with.
548 /// * `value` - The value to assign to the selected elements.
549 ///
550 /// # Returns
551 ///
552 /// The tensor with the selected elements assigned to the given value.
553 fn float_mask_fill(
554 tensor: FloatTensor<B>,
555 mask: BoolTensor<B>,
556 value: Scalar,
557 ) -> FloatTensor<B>;
558
559 /// Equal comparison of two tensors.
560 ///
561 /// # Arguments
562 ///
563 /// * `lhs` - The left-hand side tensor.
564 /// * `rhs` - The right-hand side tensor.
565 ///
566 /// # Returns
567 ///
568 /// A boolean tensor with the result of the comparison.
569 fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
570
571 /// Element-wise non-equality comparison.
572 ///
573 /// # Arguments
574 ///
575 /// * `lhs` - The left-hand side tensor.
576 /// * `rhs` - The right-hand side tensor.
577 ///
578 /// # Returns
579 ///
580 /// A boolean tensor with the result of the comparison.
581 fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {
582 let equal_tensor = B::float_equal(lhs, rhs);
583 B::bool_not(equal_tensor)
584 }
585
586 /// Equal comparison of a tensor and a scalar.
587 ///
588 /// # Arguments
589 ///
590 /// * `lhs` - The left-hand side tensor.
591 /// * `rhs` - The right-hand side scalar.
592 ///
593 /// # Returns
594 ///
595 /// A boolean tensor with the result of the comparison.
596 fn float_equal_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;
597
598 /// Element-wise non-equality comparison with a scalar.
599 ///
600 /// # Arguments
601 ///
602 /// * `lhs` - The left-hand side tensor.
603 /// * `rhs` - The right-hand side scalar.
604 ///
605 /// # Returns
606 ///
607 /// A boolean tensor with the result of the comparison.
608 fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B> {
609 let equal_tensor = B::float_equal_elem(lhs, rhs);
610 B::bool_not(equal_tensor)
611 }
612
613 /// Greater than comparison of two tensors.
614 ///
615 /// # Arguments
616 ///
617 /// * `lhs` - The left-hand side tensor.
618 /// * `rhs` - The right-hand side tensor.
619 ///
620 /// # Returns
621 ///
622 /// A boolean tensor with the result of the comparison.
623 fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
624
625 /// Greater than comparison of a tensor and a scalar.
626 ///
627 /// # Arguments
628 ///
629 /// * `lhs` - The left-hand side tensor.
630 /// * `rhs` - The right-hand side scalar.
631 ///
632 /// # Returns
633 ///
634 /// A boolean tensor with the result of the comparison.
635 fn float_greater_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;
636
637 /// Greater than or equal comparison of two tensors.
638 ///
639 /// # Arguments
640 ///
641 /// * `lhs` - The left-hand side tensor.
642 /// * `rhs` - The right-hand side tensor.
643 ///
644 /// # Returns
645 ///
646 /// A boolean tensor with the result of the comparison.
647 fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
648
649 /// Greater than or equal comparison of a tensor and a scalar.
650 ///
651 /// # Arguments
652 ///
653 /// * `lhs` - The left-hand side tensor.
654 /// * `rhs` - The right-hand side scalar.
655 ///
656 /// # Returns
657 ///
658 /// A boolean tensor with the result of the comparison.
659 fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;
660
661 /// Less than comparison of two tensors.
662 ///
663 /// # Arguments
664 ///
665 /// * `lhs` - The left-hand side tensor.
666 /// * `rhs` - The right-hand side tensor.
667 ///
668 /// # Returns
669 ///
670 /// A boolean tensor with the result of the comparison.
671 fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
672
673 /// Less than comparison of a tensor and a scalar.
674 ///
675 /// # Arguments
676 ///
677 /// * `lhs` - The left-hand side tensor.
678 /// * `rhs` - The right-hand side scalar.
679 ///
680 /// # Returns
681 ///
682 /// A boolean tensor with the result of the comparison.
683 fn float_lower_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;
684
685 /// Less than or equal comparison of two tensors.
686 ///
687 /// # Arguments
688 ///
689 /// * `lhs` - The left-hand side tensor.
690 /// * `rhs` - The right-hand side tensor.
691 ///
692 /// # Returns
693 ///
694 /// A boolean tensor with the result of the comparison.
695 fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
696
697 /// Less than or equal comparison of a tensor and a scalar.
698 ///
699 /// # Arguments
700 ///
701 /// * `lhs` - The left-hand side tensor.
702 /// * `rhs` - The right-hand side scalar.
703 ///
704 /// # Returns
705 ///
706 /// A boolean tensor with the result of the comparison.
707 fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;
708
709 /// Detaches a tensor from the computation graph.
710 fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
711 // Should only be overridden by autodiff backends.
712 tensor
713 }
714
715 /// Sets the `require_grad` flag of a tensor.
716 fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
717 // Should only be overridden by autodiff backends.
718 tensor
719 }
720
721 /// Returns the `require_grad` flag of a tensor.
722 fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
723 // Should only be overridden by autodiff backends.
724 false
725 }
726
727 /// Sum of all elements in a tensor.
728 ///
729 /// # Arguments
730 ///
731 /// * `tensor` - The tensor to sum.
732 ///
733 /// # Returns
734 ///
735 /// A scalar tensor with the sum of all elements in `tensor`.
736 fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
737
738 /// Sum of all elements in a tensor along a dimension.
739 ///
740 /// # Arguments
741 ///
742 /// * `tensor` - The tensor to sum.
743 /// * `dim` - The dimension along which to sum.
744 ///
745 /// # Returns
746 ///
747 /// A tensor with the sum of all elements in `tensor` along `dim`.
748 fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
749
750 /// Product of all elements in a tensor.
751 ///
752 /// # Arguments
753 ///
754 /// * `tensor` - The tensor to product.
755 ///
756 /// # Returns
757 ///
758 /// A scalar tensor with the product of all elements in `tensor`.
759 fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
760 // Product of all elements in a tensor
761 B::float_exp(B::float_sum(B::float_log(tensor)))
762 }
763
764 /// Product of all elements in a tensor along a dimension.
765 ///
766 /// # Arguments
767 ///
768 /// * `tensor` - The tensor to product.
769 ///
770 /// # Returns
771 ///
772 /// A tensor with the product of all elements in `tensor` along `dim`.
773 fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
774 // Product of all elements in a tensor along a dimension
775 B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
776 }
777
778 /// Mean of all elements in a tensor.
779 ///
780 /// # Arguments
781 ///
782 /// * `tensor` - The tensor to mean.
783 ///
784 /// # Returns
785 ///
786 /// A scalar tensor with the mean of all elements in `tensor`.
787 fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
788 let num_elems = tensor.shape().num_elements() as f32;
789 B::float_div_scalar(B::float_sum(tensor), num_elems.into())
790 }
791
792 /// Mean of all elements in a tensor along a dimension.
793 ///
794 /// # Arguments
795 ///
796 /// * `tensor` - The tensor to mean.
797 /// * `dim` - The dimension along which to mean.
798 ///
799 /// # Returns
800 ///
801 /// A tensor with the mean of all elements in `tensor` along `dim`.
802 fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
803
804 /// Computes the cumulative sum of elements along a dimension.
805 ///
806 /// # Arguments
807 ///
808 /// * `tensor` - The tensor to compute the cumulative sum of.
809 /// * `dim` - The dimension along which to compute the cumulative sum.
810 ///
811 /// # Returns
812 ///
813 /// A tensor with the same shape where each element is the cumulative sum
814 /// of all elements up to and including that position along the dimension.
815 fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
816
817 /// Computes the cumulative product of elements along a dimension.
818 ///
819 /// # Arguments
820 ///
821 /// * `tensor` - The tensor to compute the cumulative product of.
822 /// * `dim` - The dimension along which to compute the cumulative product.
823 ///
824 /// # Returns
825 ///
826 /// A tensor with the same shape where each element is the cumulative product
827 /// of all elements up to and including that position along the dimension.
828 fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
829
830 /// Computes the cumulative minimum of elements along a dimension.
831 ///
832 /// # Arguments
833 ///
834 /// * `tensor` - The tensor to compute the cumulative minimum of.
835 /// * `dim` - The dimension along which to compute the cumulative minimum.
836 ///
837 /// # Returns
838 ///
839 /// A tensor with the same shape where each element is the minimum
840 /// of all elements up to and including that position along the dimension.
841 fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
842
843 /// Computes the cumulative maximum of elements along a dimension.
844 ///
845 /// # Arguments
846 ///
847 /// * `tensor` - The tensor to compute the cumulative maximum of.
848 /// * `dim` - The dimension along which to compute the cumulative maximum.
849 ///
850 /// # Returns
851 ///
852 /// A tensor with the same shape where each element is the maximum
853 /// of all elements up to and including that position along the dimension.
854 fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
855
856 /// Converts a tensor to another floating point data type.
857 ///
858 /// # Arguments
859 ///
860 /// * `tensor` - The tensor to convert.
861 /// * `dtype` - The target data type.
862 ///
863 /// # Returns
864 ///
865 /// A tensor with the same values as `tensor` but in the target floating point data type.
866 fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
867
868 /// Returns a new tensor with exponential values.
869 ///
870 /// # Arguments
871 ///
872 /// * `tensor` - The tensor to exponentiate.
873 ///
874 /// # Returns
875 ///
876 /// A tensor with the same shape as `tensor` with exponential values.
877 fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
878
879 /// Returns a new tensor with natural logarithm values.
880 ///
881 /// # Arguments
882 ///
883 /// * `tensor` - The tensor to take the logarithm of.
884 ///
885 /// # Returns
886 ///
887 /// A tensor with the same shape as `tensor` with natural logarithm values.
888 fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
889
890 /// Returns a new tensor with logarithm values of (1 + Xi).
891 ///
892 /// # Arguments
893 ///
894 /// * `tensor` - The tensor to take the logarithm of.
895 ///
896 /// # Returns
897 ///
898 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
899 fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
900
901 /// Element-wise power with a FloatTensor.
902 ///
903 /// # Arguments
904 ///
905 /// * `lhs` - The left-hand side tensor.
906 /// * `rhs` - The right-hand side tensor.
907 ///
908 /// # Returns
909 ///
910 /// The elements of `lhs` raised to the power of the elements of `rhs`.
911 fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
912
913 /// Element-wise power with an IntTensor.
914 ///
915 /// # Arguments
916 ///
917 /// * `lhs` - The left-hand side tensor.
918 /// * `rhs` - The right-hand side floatTensor.
919 ///
920 /// # Returns
921 ///
922 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
923 fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
924 Self::float_powf(lhs, B::int_into_float(rhs))
925 }
926
927 /// Raises a tensor to the power of an int scalar.
928 ///
929 /// # Backend Implementors Note
930 ///
931 /// A number of common exponent cases can be implemented with operations
932 /// which are much cheaper than generic exponentiation.
933 ///
934 /// This (`Backend` impl overridable) operation handles generic optimizations
935 /// for several common integer exponent cases; and then dispatches to
936 /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
937 /// operation to handle the generic case.
938 ///
939 /// # Arguments
940 ///
941 /// * `lhs` - The left-hand side tensor.
942 /// * `rhs` - The right-hand side scalar.
943 ///
944 /// # Returns
945 ///
946 /// The elements of `lhs` raised to the value of `rhs`.
947 fn float_powi_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
948 match rhs.elem::<i64>() {
949 0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
950 1 => lhs,
951 2 => B::float_mul(lhs.clone(), lhs),
952 -1 => Self::float_recip(lhs),
953 -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
954 _ => Self::float_powi_scalar_impl(lhs, rhs),
955 }
956 }
957
958 /// Raises a tensor to the power of an int scalar.
959 ///
960 /// # Backend Implementors Note
961 ///
962 /// This is the generic implementation of integer exponentiation
963 /// called by [`Self::float_powi_scalar`] in the fallback case.
964 ///
965 /// As a general rule, this should not be called directly.
966 ///
967 /// # Arguments
968 ///
969 /// * `lhs` - The left-hand side tensor.
970 /// * `rhs` - The right-hand side scalar.
971 ///
972 /// # Returns
973 ///
974 /// The elements of `lhs` raised to the value of `rhs`.
975 fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
976 // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.
977 Self::float_powf_scalar_impl(lhs, rhs)
978 }
979
980 /// Returns a new tensor with values raised to the power of float `value`.
981 ///
982 /// # Backend Implementors Note
983 ///
984 /// This (`Backend` impl overridable) operation dispatches integer exponentiation
985 /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
986 /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
987 /// operation to handle the generic case.
988 ///
989 /// # Arguments
990 ///
991 /// * `tensor` - The tensor to exponentiate.
992 /// * `value` - The exponent.
993 ///
994 /// # Returns
995 ///
996 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
997 fn float_powf_scalar(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B> {
998 if let Some(exp) = value.try_as_integer() {
999 Self::float_powi_scalar(tensor, exp)
1000 } else {
1001 Self::float_powf_scalar_impl(tensor, value)
1002 }
1003 }
1004
1005 /// Returns a new tensor with values raised to the power of float `value`.
1006 ///
1007 /// # Backend Implementors Note
1008 ///
1009 /// This is the generic implementation of integer exponentiation
1010 /// called by [`Self::float_powf_scalar`] in the fallback case.
1011 ///
1012 /// This is the minimal required support a `Backend` must implement
1013 /// for exponentiation.
1014 ///
1015 /// As a general rule, this should not be called directly.
1016 ///
1017 /// # Arguments
1018 ///
1019 /// * `tensor` - The tensor to exponentiate.
1020 /// * `value` - The exponent.
1021 ///
1022 /// # Returns
1023 ///
1024 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
1025 fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B>;
1026
    /// Returns a new tensor with the element-wise square root of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the square root of the
    /// corresponding input element.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
1037
    /// Returns a new tensor with the element-wise absolute value of `tensor`.
    ///
    /// Also used by the default `float_max_abs_dim` and `float_is_inf` implementations.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the absolute value of
    /// the corresponding input element.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1048
    /// Returns a new tensor with the element-wise cosine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the cosine of the
    /// corresponding input element.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1059
    /// Returns a new tensor with the element-wise sine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the sine of the
    /// corresponding input element.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1070
    /// Returns a new tensor with the element-wise tangent of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the tangent of the
    /// corresponding input element.
    fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1081
    /// Returns a new tensor with the element-wise hyperbolic cosine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the hyperbolic cosine
    /// of the corresponding input element.
    fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1092
    /// Returns a new tensor with the element-wise hyperbolic sine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the hyperbolic sine of
    /// the corresponding input element.
    fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1103
    /// Returns a new tensor with the element-wise hyperbolic tangent of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the hyperbolic tangent of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the hyperbolic tangent
    /// of the corresponding input element.
    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1114
    /// Returns a new tensor with the element-wise inverse cosine (arccosine) of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the inverse cosine of
    /// the corresponding input element.
    fn float_acos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1125
    /// Returns a new tensor with the element-wise inverse hyperbolic cosine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the inverse hyperbolic
    /// cosine of the corresponding input element.
    fn float_acosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1136
    /// Returns a new tensor with the element-wise inverse sine (arcsine) of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the inverse sine of
    /// the corresponding input element.
    fn float_asin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1147
    /// Returns a new tensor with the element-wise inverse hyperbolic sine of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the inverse hyperbolic
    /// sine of the corresponding input element.
    fn float_asinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1158
    /// Returns a new tensor with the element-wise inverse tangent (arctangent) of `tensor`.
    ///
    /// For the two-argument, four-quadrant variant see `float_atan2`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the inverse tangent of
    /// the corresponding input element.
    fn float_atan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1169
    /// Returns a new tensor with the element-wise inverse hyperbolic tangent of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the inverse hyperbolic
    /// tangent of the corresponding input element.
    fn float_atanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1180
    /// Returns a tensor with the four-quadrant inverse tangent values of `y` and `x`.
    ///
    /// Element-wise this is `atan2(y, x)`, with `y` taken from `lhs` and `x` taken from `rhs`.
    /// Unlike `float_atan(y / x)`, the signs of both arguments are used to pick the quadrant.
    ///
    /// NOTE(review): whether `lhs` and `rhs` must share a shape or may broadcast is not
    /// specified here. Confirm against the high-level tensor API.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor with y coordinates.
    /// * `rhs` - The tensor with x coordinates.
    ///
    /// # Returns
    ///
    /// A tensor with the four-quadrant inverse tangent values.
    fn float_atan2(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
1192
    /// Returns a new tensor with rounded values.
    ///
    /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value.
    /// For example `0.5 -> 0`, `1.5 -> 2`, `2.5 -> 2` and `-0.5 -> -0`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
1206
    /// Returns a new tensor with floored values.
    ///
    /// Each element is replaced by the largest integer value not greater than it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1217
    /// Returns a new tensor with ceiled values.
    ///
    /// Each element is replaced by the smallest integer value not less than it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be ceiled.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with ceiled values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1228
    /// Returns a new tensor with truncated values.
    ///
    /// Each element has its fractional part discarded, i.e. it is rounded toward zero.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be truncated.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with truncated values.
    fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
1239
    /// Returns a new tensor with the element-wise (Gauss) error function `erf` of `tensor`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1250
1251 /// Concatenates tensors along a dimension.
1252 ///
1253 /// # Arguments
1254 ///
1255 /// * `tensors` - The tensors to concatenate.
1256 /// * `dim` - The dimension along which to concatenate.
1257 ///
1258 /// # Returns
1259 ///
1260 /// A tensor with the concatenated tensors along `dim`.
1261 ///
1262 /// # Note
1263 ///
1264 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
1265 /// high-level tensor API and will not be passed to this method. Backend implementations do
1266 /// not need to handle empty tensors.
1267 fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1268 cat_with_slice_assign::<B, Float>(
1269 tensors.into_iter().map(TensorPrimitive::Float).collect(),
1270 dim,
1271 )
1272 .tensor()
1273 }
1274
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// The default `float_max_dim` and `float_max_dim_with_indices` feed these indices
    /// straight into `float_gather` along the same `dim`. They must therefore be valid
    /// positions along `dim`.
    /// NOTE(review): presumably `dim` is kept with size 1 in the result, which is what that
    /// gather requires. Confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// An integer tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1286
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// The default `float_min_dim` and `float_min_dim_with_indices` feed these indices
    /// straight into `float_gather` along the same `dim`. They must therefore be valid
    /// positions along `dim`.
    /// NOTE(review): presumably `dim` is kept with size 1 in the result, which is what that
    /// gather requires. Confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// An integer tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1298
1299 /// Gets the maximum element of a tensor.
1300 ///
1301 /// # Arguments
1302 ///
1303 /// * `tensor` - The tensor to get the maximum elements of.
1304 ///
1305 /// # Returns
1306 ///
1307 /// A tensor with the maximum element of `tensor`.
1308 fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1309 let shape = tensor.shape();
1310 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1311
1312 B::float_max_dim(tensor, 0)
1313 }
1314
1315 /// Gets the maximum elements of a tensor along an axis.
1316 ///
1317 /// # Arguments
1318 ///
1319 /// * `tensor` - The tensor to get the maximum elements of.
1320 /// * `dim` - The dimension along which to get the maximum elements.
1321 ///
1322 /// # Returns
1323 ///
1324 /// A tensor with the maximum elements of `tensor` along `dim`.
1325 fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1326 let index = B::float_argmax(tensor.clone(), dim);
1327
1328 B::float_gather(dim, tensor, index)
1329 }
1330
1331 /// Gets the maximum elements of a tensor along an axis and their indices.
1332 ///
1333 /// # Arguments
1334 ///
1335 /// * `tensor` - The tensor to get the maximum elements of.
1336 /// * `dim` - The dimension along which to get the maximum elements.
1337 ///
1338 /// # Returns
1339 ///
1340 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1341 fn float_max_dim_with_indices(
1342 tensor: FloatTensor<B>,
1343 dim: usize,
1344 ) -> (FloatTensor<B>, IntTensor<B>) {
1345 let index = B::float_argmax(tensor.clone(), dim);
1346 let values = B::float_gather(dim, tensor, index.clone());
1347
1348 (values, index)
1349 }
1350
1351 /// Gets the minimum element of a tensor.
1352 ///
1353 /// # Arguments
1354 ///
1355 /// * `tensor` - The tensor to get the minimum elements of.
1356 ///
1357 /// # Returns
1358 ///
1359 /// A tensor with the minimum element of `tensor`.
1360 fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1361 let shape = tensor.shape();
1362 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1363
1364 B::float_min_dim(tensor, 0)
1365 }
1366
1367 /// Gets the minimum elements of a tensor along an axis.
1368 ///
1369 /// # Arguments
1370 ///
1371 /// * `tensor` - The tensor to get the minimum elements of.
1372 /// * `dim` - The dimension along which to get the minimum elements.
1373 ///
1374 /// # Returns
1375 ///
1376 /// A tensor with the minimum elements of `tensor` along `dim`.
1377 fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1378 let index = B::float_argmin(tensor.clone(), dim);
1379
1380 B::float_gather(dim, tensor, index)
1381 }
1382
1383 /// Gets the minimum elements of a tensor along an axis and their indices.
1384 ///
1385 /// # Arguments
1386 ///
1387 /// * `tensor` - The tensor to get the minimum elements of.
1388 /// * `dim` - The dimension along which to get the minimum elements.
1389 ///
1390 /// # Returns
1391 ///
1392 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1393 fn float_min_dim_with_indices(
1394 tensor: FloatTensor<B>,
1395 dim: usize,
1396 ) -> (FloatTensor<B>, IntTensor<B>) {
1397 let index = B::float_argmin(tensor.clone(), dim);
1398 let values = B::float_gather(dim, tensor, index.clone());
1399
1400 (values, index)
1401 }
1402
1403 /// Gets the maximum absolute element of a tensor.
1404 ///
1405 /// # Arguments
1406 ///
1407 /// * `tensor` - The tensor to get the maximum elements of.
1408 ///
1409 /// # Returns
1410 ///
1411 /// A tensor with the maximum element of `tensor`.
1412 fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1413 let shape = tensor.shape();
1414 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1415
1416 B::float_max_abs_dim(tensor, 0)
1417 }
1418
1419 /// Gets the maximum absolute elements of a tensor along an axis.
1420 ///
1421 /// # Arguments
1422 ///
1423 /// * `tensor` - The tensor to get the maximum elements of.
1424 /// * `dim` - The dimension along which to get the maximum elements.
1425 ///
1426 /// # Returns
1427 ///
1428 /// A tensor with the maximum elements of `tensor` along `dim`.
1429 fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1430 B::float_max_dim(B::float_abs(tensor), dim)
1431 }
1432
1433 /// Tests if any element in the float `tensor` evaluates to True.
1434 ///
1435 /// # Arguments
1436 ///
1437 /// * `tensor` - The tensor to test.
1438 ///
1439 /// # Returns
1440 ///
1441 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1442 fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {
1443 let bool_tensor = B::float_equal_elem(tensor, 0f32.into());
1444 let bool_tensor = B::bool_not(bool_tensor);
1445 let sum = B::float_sum(B::bool_into_float(bool_tensor));
1446 B::float_greater_elem(sum, 0f32.into())
1447 }
1448
1449 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1450 ///
1451 /// # Arguments
1452 ///
1453 /// * `tensor` - The tensor to test.
1454 /// * `dim` - The axis along which to test.
1455 ///
1456 /// # Returns
1457 ///
1458 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1459 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1460 /// input evaluates to True, False otherwise.
1461 fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1462 let bool_tensor = B::float_equal_elem(tensor, 0f32.into());
1463 let bool_tensor = B::bool_not(bool_tensor);
1464 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1465 B::float_greater_elem(sum, 0f32.into())
1466 }
1467
1468 /// Tests if all elements in the float `tensor` evaluate to True.
1469 ///
1470 /// # Arguments
1471 ///
1472 /// * `tensor` - The tensor to test.
1473 ///
1474 /// # Returns
1475 ///
1476 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1477 /// evaluate to True, False otherwise.
1478 fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
1479 let num_elems = tensor.shape().num_elements() as f32;
1480 let bool_tensor = B::float_equal_elem(tensor, 0f32.into());
1481 let bool_tensor = B::bool_not(bool_tensor);
1482 let sum = B::float_sum(B::bool_into_float(bool_tensor));
1483 B::float_equal_elem(sum, num_elems.into())
1484 }
1485
1486 /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1487 ///
1488 /// # Arguments
1489 ///
1490 /// * `tensor` - The tensor to test.
1491 /// * `dim` - The axis along which to test.
1492 ///
1493 /// # Returns
1494 ///
1495 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1496 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1497 /// evaluates to True, False otherwise.
1498 fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1499 let num_elems = tensor.shape().dims[dim] as f32;
1500 let bool_tensor = B::float_equal_elem(tensor, 0f32.into());
1501 let bool_tensor = B::bool_not(bool_tensor);
1502 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1503 B::float_equal_elem(sum, num_elems.into())
1504 }
1505
1506 /// Returns the signs of the float `tensor`.
1507 ///
1508 /// # Arguments
1509 ///
1510 /// * `tensor` - The tensor to extract the signs from.
1511 ///
1512 /// # Returns
1513 ///
1514 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1515 fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1516 let zeros = B::float_zeros(
1517 tensor.shape(),
1518 &B::float_device(&tensor),
1519 tensor.dtype().into(),
1520 );
1521 let less_than_zero = B::float_lower_elem(tensor.clone(), 0f32.into());
1522 let greater_than_zero = B::float_greater_elem(tensor, 0f32.into());
1523
1524 let mut result = B::float_mask_fill(zeros, less_than_zero, (-1f32).into());
1525 result = B::float_mask_fill(result, greater_than_zero, 1f32.into());
1526 result
1527 }
1528
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// NOTE(review): which (input, target) shape pairs are accepted is not defined in this
    /// trait. For example, whether only size-1 dimensions may grow or new leading dimensions
    /// may be added needs confirming against the high-level `expand` API.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// A tensor with shape `shape` holding the values of `tensor` broadcast to it.
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1531
1532 /// Sort the elements of the input `tensor` by value in along a given dimension.
1533 ///
1534 /// This sort is unstable (i.e., may reorder equal elements).
1535 ///
1536 /// # Arguments
1537 ///
1538 /// * `tensor` - The input tensor.
1539 /// * `dim` - The axis along which to sort.
1540 /// * `descending` - The sorting order.
1541 ///
1542 /// # Returns
1543 ///
1544 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1545 fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1546 sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1547 }
1548
1549 /// Sort the elements of the input `tensor` by value in along a given dimension.
1550 ///
1551 /// This sort is unstable (i.e., may reorder equal elements).
1552 ///
1553 /// # Arguments
1554 ///
1555 /// * `tensor` - The input tensor.
1556 /// * `dim` - The axis along which to sort.
1557 /// * `descending` - The sorting order.
1558 ///
1559 /// # Returns
1560 ///
1561 /// A tensor with the same shape as the input tensor and corresponding indices, where
1562 /// the elements are sorted by value and the indices map back to the original input tensor.
1563 fn float_sort_with_indices(
1564 tensor: FloatTensor<B>,
1565 dim: usize,
1566 descending: bool,
1567 ) -> (FloatTensor<B>, IntTensor<B>) {
1568 let (values, indices) =
1569 sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
1570 (values.tensor(), indices)
1571 }
1572
1573 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1574 ///
1575 /// This sort is unstable (i.e., may reorder equal elements).
1576 ///
1577 /// # Arguments
1578 ///
1579 /// * `tensor` - The input tensor.
1580 /// * `dim` - The axis along which to sort.
1581 /// * `descending` - The sorting order.
1582 ///
1583 /// # Returns
1584 ///
1585 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1586 fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1587 argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
1588 }
1589
    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
    /// using the given locations in [-1, 1].
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
    ///
    /// # Returns
    ///
    /// A tensor with shape (N, C, H_out, W_out)
    ///
    /// # Backend Implementors Note
    ///
    /// The default implementation delegates to `float_grid_sample_2d_ref` from the sibling
    /// `grid_sample` module. That is presumably a generic, backend-agnostic reference
    /// implementation; backends with a native kernel may override this method.
    fn float_grid_sample_2d(
        tensor: FloatTensor<B>,
        grid: FloatTensor<B>,
        options: GridSampleOptions,
    ) -> FloatTensor<B> {
        float_grid_sample_2d_ref::<B>(tensor, grid, options)
    }
1610
1611 /// Unfold windows along a dimension.
1612 ///
1613 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
1614 /// where windows are advanced by `step` at each index.
1615 ///
1616 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
1617 ///
1618 /// # Arguments
1619 ///
1620 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
1621 /// * `dim` - the selected dim.
1622 /// * `size` - the size of each unfolded window.
1623 /// * `step` - the step between each window.
1624 ///
1625 /// # Returns
1626 ///
1627 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
1628 fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
1629 -> FloatTensor<B>;
1630
1631 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1632 ///
1633 /// # Returns
1634 ///
1635 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1636 fn float_is_nan(tensor: FloatTensor<B>) -> BoolTensor<B> {
1637 // Check if the input tensor is NaN by comparing it to itself
1638 // NaN is the only value that is not equal to itself
1639 B::float_not_equal(tensor.clone(), tensor)
1640 }
1641
1642 /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1643 ///
1644 /// # Returns
1645 ///
1646 /// A boolean tensor where `true` indicates that the value is infinite
1647 fn float_is_inf(tensor: FloatTensor<B>) -> BoolTensor<B> {
1648 B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.into())
1649 }
1650}