burn_backend/backend/ops/tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_ref;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::sort::{argsort, sort, sort_with_indices};
5use crate::ops::GridSampleOptions;
6use crate::tensor::{BoolTensor, Device, Float, FloatElem, FloatTensor, IntElem, IntTensor};
7use crate::{Backend, Distribution, TensorData, element::ElementConversion};
8use crate::{ExecutionError, TensorMetadata, TensorPrimitive};
9use alloc::vec::Vec;
10use burn_std::{FloatDType, Shape, Slice};
11
12/// Operations on float tensors.
13pub trait FloatTensorOps<B: Backend> {
    /// Creates a float tensor on `device` from a [`TensorData`] value.
    ///
    /// This is a required method: the trait provides no default implementation.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure holding the tensor contents.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given data.
    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
25
    /// Creates a new float tensor filled with random values.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The [`Distribution`] to sample the values from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the given shape and random values.
    fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
        -> FloatTensor<B>;
39
40 /// Creates a new tensor with zeros.
41 ///
42 /// # Arguments
43 ///
44 /// * `shape` - The shape of the tensor.
45 /// * `device` - The device to create the tensor on.
46 /// * `dtype` - The target data type.
47 ///
48 /// # Returns
49 ///
50 /// The tensor with the given shape and zeros.
51 fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
52 Self::float_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
53 }
54
55 /// Creates a new tensor with ones.
56 ///
57 /// # Arguments
58 ///
59 /// * `shape` - The shape of the tensor.
60 /// * `device` - The device to create the tensor on.
61 /// * `dtype` - The target data type.
62 ///
63 /// # Returns
64 ///
65 /// The tensor with the given shape and ones.
66 fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
67 Self::float_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
68 }
69
70 /// Creates a tensor filled with given value.
71 ///
72 /// # Arguments
73 ///
74 /// * `shape` - The shape of the tensor.
75 /// * `fill_value` - The value with which to fill the tensor.
76 /// * `device` - The device to create the tensor on.
77 /// * `dtype` - The target data type.
78 ///
79 /// # Returns
80 ///
81 /// The tensor filled with given value
82 fn float_full(
83 shape: Shape,
84 fill_value: FloatElem<B>,
85 device: &Device<B>,
86 dtype: FloatDType,
87 ) -> FloatTensor<B> {
88 Self::float_from_data(
89 TensorData::full_dtype(shape, fill_value, dtype.into()),
90 device,
91 )
92 }
93
    /// Converts the tensor into a [`TensorData`] structure.
    ///
    /// The conversion is exposed as a `Send` future, so the caller must await (or otherwise
    /// drive) it to obtain the data.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to read; it is consumed by the call.
    ///
    /// # Returns
    ///
    /// A future resolving to the data structure with the tensor's data, or to an
    /// [`ExecutionError`] on failure.
    fn float_into_data(
        tensor: FloatTensor<B>,
    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
106
    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to query; it is only borrowed.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
117
    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to move; it is consumed by the call.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
129
    /// Converts a float tensor into an int tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The float tensor to convert.
    ///
    /// # Returns
    ///
    /// The int tensor holding the data of the float tensor.
    // NOTE(review): how fractional values are converted (truncation vs. rounding) is not
    // specified here — confirm against the backend implementations.
    fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;
140
    /// Creates an empty float tensor with the given shape.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    /// * `dtype` - The target floating point data type.
    ///
    /// # Returns
    ///
    /// The empty tensor with the given shape.
    // NOTE(review): presumably the element values are unspecified until written — confirm.
    fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
153
154 /// Repeat the tensor along the given dimension.
155 ///
156 /// # Arguments
157 ///
158 /// * `tensor` - The tensor.
159 /// * `dim` - The dimension to repeat.
160 /// * `times` - The number of times to repeat the dimension.
161 ///
162 /// # Returns
163 ///
164 /// The tensor with the given dimension repeated.
165 fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
166 repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
167 }
168
    /// Adds two float tensors together.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of adding the two tensors together.
    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
180
    /// Adds a scalar to a float tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, given as the backend's float element type.
    ///
    /// # Returns
    ///
    /// The result of adding the scalar to the tensor.
    fn float_add_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
192
193 /// Clamps a tensor under a minimum value.
194 ///
195 /// # Arguments
196 ///
197 /// * `tensor` - The tensor to clamp.
198 /// * `min` - The minimum value.
199 ///
200 /// # Returns
201 ///
202 /// The clamped tensor.
203 fn float_clamp_min(tensor: FloatTensor<B>, min: FloatElem<B>) -> FloatTensor<B> {
204 // Default implementation
205 let mask = Self::float_lower_elem(tensor.clone(), min);
206 B::float_mask_fill(tensor, mask, min)
207 }
208
209 /// Clamps a tensor over a maximum value.
210 ///
211 /// # Arguments
212 ///
213 /// * `tensor` - The tensor to clamp.
214 /// * `max` - The maximum value.
215 ///
216 /// # Returns
217 ///
218 /// The clamped tensor.
219 fn float_clamp_max(tensor: FloatTensor<B>, max: FloatElem<B>) -> FloatTensor<B> {
220 // Default implementation
221 let mask = Self::float_greater_elem(tensor.clone(), max);
222 B::float_mask_fill(tensor, mask, max)
223 }
224
225 /// Clamps a tensor between a minimum and maximum value.
226 ///
227 /// # Arguments
228 ///
229 /// * `tensor` - The tensor to clamp.
230 /// * `min` - The minimum value.
231 /// * `max` - The maximum value.
232 ///
233 /// # Returns
234 ///
235 /// The clamped tensor.
236 fn float_clamp(tensor: FloatTensor<B>, min: FloatElem<B>, max: FloatElem<B>) -> FloatTensor<B> {
237 // Default implementation
238 Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
239 }
240
    /// Subtracts one float tensor from another.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The result of subtracting the two tensors.
    fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
252
    /// Subtracts a scalar from a float tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, given as the backend's float element type.
    ///
    /// # Returns
    ///
    /// The result of subtracting the scalar from the tensor.
    fn float_sub_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
264
    /// Multiplies two tensors together element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The element-wise product of the two tensors.
    fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
267
    /// Multiplies a float tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar, given as the backend's float element type.
    ///
    /// # Returns
    ///
    /// The result of multiplying the tensor by the scalar.
    fn float_mul_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
279
    /// Divides two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (the dividend).
    /// * `rhs` - The right-hand side tensor (the divisor).
    ///
    /// # Returns
    ///
    /// The result of dividing the two tensors.
    fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
291
    /// Divides a float tensor by a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (the dividend).
    /// * `rhs` - The right-hand side scalar (the divisor).
    ///
    /// # Returns
    ///
    /// The result of dividing the tensor by the scalar.
    fn float_div_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
303
    /// Computes the remainder of division between two tensors element-wise.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (the dividend).
    /// * `rhs` - The right-hand side tensor (the divisor).
    ///
    /// # Returns
    ///
    /// The element-wise remainder when dividing `lhs` by `rhs`.
    // NOTE(review): the sign convention of the remainder is not specified here — confirm
    // against the backend implementations.
    fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
315
    /// Computes the modulus of a tensor given a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// The result of applying the modulus of the scalar to each element of the tensor.
    fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
326
    /// Multiplies two tensors together using matrix multiplication.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// The matrix product of `lhs` and `rhs`.
    fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
338
    /// Computes the cross product of two tensors along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    /// * `dim` - The dimension to compute the cross product along.
    ///
    /// # Returns
    ///
    /// The cross product of the two tensors.
    // NOTE(review): presumably `dim` must have size 3 in both tensors — confirm.
    fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
351
352 /// Negates a tensor element-wise.
353 fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
354 Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
355 }
356
    /// Calculates the reciprocals element-wise.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements are inverted.
    ///
    /// # Returns
    ///
    /// A tensor holding the reciprocal of each element of `tensor`.
    fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
359
360 /// Transposes a tensor.
361 ///
362 /// # Arguments
363 ///
364 /// * `tensor` - The tensor to transpose.
365 ///
366 /// # Returns
367 ///
368 /// The transposed tensor.
369 fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
370 let ndims = tensor.shape().num_dims();
371 Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
372 }
373
    /// Swaps two dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to swap the dimensions of.
    /// * `dim1` - The index of the first dimension to swap.
    /// * `dim2` - The index of the second dimension to swap.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions swapped.
    fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
386
    /// Permutes the dimensions of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to permute the dimensions of.
    /// * `axes` - The new order of the dimensions.
    ///
    /// # Returns
    ///
    /// The tensor with the dimensions permuted.
    fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
397
    /// Reverses the order of elements in a tensor along the given axes.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reverse.
    /// * `axes` - The axes to reverse.
    ///
    /// # Returns
    ///
    /// The tensor with the elements reversed.
    fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
407
    /// Reshapes a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to reshape.
    /// * `shape` - The new shape of the tensor.
    ///
    /// # Returns
    ///
    /// The tensor with the new shape.
    // NOTE(review): presumably `shape` must describe the same number of elements as
    // `tensor` — confirm where this is validated.
    fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
419
    /// Gathers elements from a tensor.
    ///
    /// Note that, unlike most operations of this trait, the dimension is the first parameter.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to gather from.
    /// * `tensor` - The tensor to gather from.
    /// * `indices` - The int tensor holding the indices to gather.
    ///
    /// # Returns
    ///
    /// The gathered elements.
    fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
432
    /// Scatters elements into a tensor using sum reduction.
    ///
    /// Note that, unlike most operations of this trait, the dimension is the first parameter.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to scatter into.
    /// * `tensor` - The tensor to scatter into.
    /// * `indices` - The int tensor holding the indices to scatter into.
    /// * `value` - The tensor holding the values to scatter.
    ///
    /// # Returns
    ///
    /// The tensor with the scattered elements.
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<B>,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
451
    /// Selects tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The int tensor holding the indices to select.
    ///
    /// # Returns
    ///
    /// The selected elements.
    fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
464
    /// Assigns the given value to the elements selected along the given dimension by the given
    /// indices, using sum reduction.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `dim` - The dimension to select from.
    /// * `indices` - The int tensor holding the indices to select.
    /// * `value` - The tensor holding the values to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_select_add(
        tensor: FloatTensor<B>,
        dim: usize,
        indices: IntTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
484
    /// Selects the tensor elements corresponding to the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select from.
    /// * `slices` - The slices specifying ranges and steps for each dimension.
    ///
    /// # Returns
    ///
    /// The selected elements in a new tensor.
    ///
    /// # Note
    ///
    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
    /// be passed to this method. Backend implementations do not need to handle empty slices.
    fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;
501
    /// Assigns the given value to the elements selected by the given slices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to assign into.
    /// * `slices` - The slices selecting the elements to overwrite.
    /// * `value` - The tensor holding the values to assign.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    ///
    /// # Note
    ///
    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
    /// high-level tensor API and will not be passed to this method. Backend implementations do
    /// not need to handle empty slice assignments.
    fn float_slice_assign(
        tensor: FloatTensor<B>,
        slices: &[Slice],
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
524
    /// Updates the given tensor with the value tensor where the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The tensor providing the replacement values for the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_mask_where(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatTensor<B>,
    ) -> FloatTensor<B>;
541
    /// Updates the given tensor with a scalar value where the mask is true.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to update.
    /// * `mask` - The boolean mask to select with.
    /// * `value` - The scalar to assign to the selected elements.
    ///
    /// # Returns
    ///
    /// The tensor with the selected elements assigned to the given value.
    fn float_mask_fill(
        tensor: FloatTensor<B>,
        mask: BoolTensor<B>,
        value: FloatElem<B>,
    ) -> FloatTensor<B>;
558
    /// Element-wise equality comparison of two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
570
571 /// Element-wise non-equality comparison.
572 ///
573 /// # Arguments
574 ///
575 /// * `lhs` - The left-hand side tensor.
576 /// * `rhs` - The right-hand side tensor.
577 ///
578 /// # Returns
579 ///
580 /// A boolean tensor with the result of the comparison.
581 fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {
582 let equal_tensor = B::float_equal(lhs, rhs);
583 B::bool_not(equal_tensor)
584 }
585
    /// Element-wise equality comparison of a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
597
598 /// Element-wise non-equality comparison with a scalar.
599 ///
600 /// # Arguments
601 ///
602 /// * `lhs` - The left-hand side tensor.
603 /// * `rhs` - The right-hand side scalar.
604 ///
605 /// # Returns
606 ///
607 /// A boolean tensor with the result of the comparison.
608 fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B> {
609 let equal_tensor = B::float_equal_elem(lhs, rhs);
610 B::bool_not(equal_tensor)
611 }
612
    /// Element-wise greater-than comparison of two tensors (`lhs > rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
624
    /// Element-wise greater-than comparison of a tensor and a scalar (`lhs > rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
636
    /// Element-wise greater-than-or-equal comparison of two tensors (`lhs >= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
648
    /// Element-wise greater-than-or-equal comparison of a tensor and a scalar (`lhs >= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
660
    /// Element-wise less-than comparison of two tensors (`lhs < rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
672
    /// Element-wise less-than comparison of a tensor and a scalar (`lhs < rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
684
    /// Element-wise less-than-or-equal comparison of two tensors (`lhs <= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
696
    /// Element-wise less-than-or-equal comparison of a tensor and a scalar (`lhs <= rhs`).
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor.
    /// * `rhs` - The right-hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the result of the comparison.
    fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
708
709 /// Detaches a tensor from the computation graph.
710 fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
711 // Should only be overridden by autodiff backends.
712 tensor
713 }
714
715 /// Sets the `require_grad` flag of a tensor.
716 fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
717 // Should only be overridden by autodiff backends.
718 tensor
719 }
720
721 /// Returns the `require_grad` flag of a tensor.
722 fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
723 // Should only be overridden by autodiff backends.
724 false
725 }
726
    /// Computes the sum of all elements in a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    ///
    /// # Returns
    ///
    /// A scalar tensor with the sum of all elements in `tensor`.
    fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
737
    /// Computes the sum of the elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to sum.
    /// * `dim` - The dimension along which to sum.
    ///
    /// # Returns
    ///
    /// A tensor with the sum of all elements in `tensor` along `dim`.
    fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
749
750 /// Product of all elements in a tensor.
751 ///
752 /// # Arguments
753 ///
754 /// * `tensor` - The tensor to product.
755 ///
756 /// # Returns
757 ///
758 /// A scalar tensor with the product of all elements in `tensor`.
759 fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
760 // Product of all elements in a tensor
761 B::float_exp(B::float_sum(B::float_log(tensor)))
762 }
763
764 /// Product of all elements in a tensor along a dimension.
765 ///
766 /// # Arguments
767 ///
768 /// * `tensor` - The tensor to product.
769 ///
770 /// # Returns
771 ///
772 /// A tensor with the product of all elements in `tensor` along `dim`.
773 fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
774 // Product of all elements in a tensor along a dimension
775 B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
776 }
777
778 /// Mean of all elements in a tensor.
779 ///
780 /// # Arguments
781 ///
782 /// * `tensor` - The tensor to mean.
783 ///
784 /// # Returns
785 ///
786 /// A scalar tensor with the mean of all elements in `tensor`.
787 fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
788 let num_elems = tensor.shape().num_elements();
789 B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
790 }
791
    /// Computes the mean of the elements in a tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to average.
    /// * `dim` - The dimension along which to compute the mean.
    ///
    /// # Returns
    ///
    /// A tensor with the mean of all elements in `tensor` along `dim`.
    fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
803
    /// Computes the cumulative sum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative sum of.
    /// * `dim` - The dimension along which to compute the cumulative sum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the sum of all
    /// elements up to and including that position along the dimension.
    fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
816
    /// Computes the cumulative product of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative product of.
    /// * `dim` - The dimension along which to compute the cumulative product.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the product of all
    /// elements up to and including that position along the dimension.
    fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
829
    /// Computes the cumulative minimum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative minimum of.
    /// * `dim` - The dimension along which to compute the cumulative minimum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the minimum of all
    /// elements up to and including that position along the dimension.
    fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
842
    /// Computes the cumulative maximum of elements along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the cumulative maximum of.
    /// * `dim` - The dimension along which to compute the cumulative maximum.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, where each element is the maximum of all
    /// elements up to and including that position along the dimension.
    fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
855
    /// Converts a tensor to another floating point data type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to convert.
    /// * `dtype` - The target floating point data type.
    ///
    /// # Returns
    ///
    /// A tensor with the same values as `tensor` but in the target floating point data type.
    fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
867
    /// Returns a new tensor with the exponential of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with exponential values.
    fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
878
    /// Returns a new tensor with the natural logarithm of each element.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with natural logarithm values.
    fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
889
    /// Returns a new tensor with the logarithm of `1 + x` for each element `x`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the logarithm of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with logarithm values of `1 + x`.
    fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
900
    /// Element-wise power with a float tensor of exponents.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand side tensor (the bases).
    /// * `rhs` - The right-hand side tensor (the exponents).
    ///
    /// # Returns
    ///
    /// The elements of `lhs` raised to the power of the elements of `rhs`.
    fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
912
913 /// Element-wise power with an IntTensor.
914 ///
915 /// # Arguments
916 ///
917 /// * `lhs` - The left-hand side tensor.
918 /// * `rhs` - The right-hand side floatTensor.
919 ///
920 /// # Returns
921 ///
922 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
923 fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
924 Self::float_powf(lhs, B::int_into_float(rhs))
925 }
926
927 /// Raises a tensor to the power of an int scalar.
928 ///
929 /// # Backend Implementors Note
930 ///
931 /// A number of common exponent cases can be implemented with operations
932 /// which are much cheaper than generic exponentiation.
933 ///
934 /// This (`Backend` impl overridable) operation handles generic optimizations
935 /// for several common integer exponent cases; and then dispatches to
936 /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
937 /// operation to handle the generic case.
938 ///
939 /// # Arguments
940 ///
941 /// * `lhs` - The left-hand side tensor.
942 /// * `rhs` - The right-hand side scalar.
943 ///
944 /// # Returns
945 ///
946 /// The elements of `lhs` raised to the value of `rhs`.
947 fn float_powi_scalar(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
948 let exp = rhs.elem::<i32>();
949 match exp {
950 0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
951 1 => lhs,
952 2 => B::float_mul(lhs.clone(), lhs),
953 -1 => Self::float_recip(lhs),
954 -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
955 _ => Self::float_powi_scalar_impl(lhs, rhs),
956 }
957 }
958
959 /// Raises a tensor to the power of an int scalar.
960 ///
961 /// # Backend Implementors Note
962 ///
963 /// This is the generic implementation of integer exponentiation
964 /// called by [`Self::float_powi_scalar`] in the fallback case.
965 ///
966 /// As a general rule, this should not be called directly.
967 ///
968 /// # Arguments
969 ///
970 /// * `lhs` - The left-hand side tensor.
971 /// * `rhs` - The right-hand side scalar.
972 ///
973 /// # Returns
974 ///
975 /// The elements of `lhs` raised to the value of `rhs`.
976 fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
977 // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.
978 Self::float_powf_scalar_impl(lhs, rhs.elem::<f32>())
979 }
980
981 /// Returns a new tensor with values raised to the power of float `value`.
982 ///
983 /// # Backend Implementors Note
984 ///
985 /// This (`Backend` impl overridable) operation dispatches integer exponentiation
986 /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
987 /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
988 /// operation to handle the generic case.
989 ///
990 /// # Arguments
991 ///
992 /// * `tensor` - The tensor to exponentiate.
993 /// * `value` - The exponent.
994 ///
995 /// # Returns
996 ///
997 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
998 fn float_powf_scalar(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B> {
999 if num_traits::Float::floor(value) == value {
1000 // When the exponent is an integer, use the integer exponentiation implementation.
1001 let exp = B::IntElem::from_elem(value as i32);
1002 Self::float_powi_scalar(tensor, exp)
1003 } else {
1004 Self::float_powf_scalar_impl(tensor, value)
1005 }
1006 }
1007
    /// Returns a new tensor with values raised to the power of float `value`.
    ///
    /// # Backend Implementors Note
    ///
    /// This is the generic implementation of float exponentiation
    /// called by [`Self::float_powf_scalar`] in the fallback case.
    ///
    /// This is the minimal required support a `Backend` must implement
    /// for exponentiation.
    ///
    /// As a general rule, this should not be called directly.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to exponentiate.
    /// * `value` - The exponent.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
    fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B>;
1029
1030 /// Returns a new tensor with square root values.
1031 ///
1032 /// # Arguments
1033 ///
1034 /// * `tensor` - The tensor to take the square root of.
1035 ///
1036 /// # Returns
1037 ///
1038 /// A tensor with the same shape as `tensor` with square root values.
1039 fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
1040
1041 /// Returns a new tensor with absolute values.
1042 ///
1043 /// # Arguments
1044 ///
1045 /// * `tensor` - The tensor to take absolute value of.
1046 ///
1047 /// # Returns
1048 ///
1049 /// A tensor with the same shape as `tensor` with absolute values.
1050 fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1051
1052 /// Returns a new tensor with cosine values.
1053 ///
1054 /// # Arguments
1055 ///
1056 /// * `tensor` - The tensor to take the cosine of.
1057 ///
1058 /// # Returns
1059 ///
1060 /// A tensor with the same shape as `tensor` with cosine values.
1061 fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1062
1063 /// Returns a new tensor with sine values.
1064 ///
1065 /// # Arguments
1066 ///
1067 /// * `tensor` - The tensor to take the sine of.
1068 ///
1069 /// # Returns
1070 ///
1071 /// A tensor with the same shape as `tensor` with sine values.
1072 fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1073
1074 /// Returns a new tensor with tangent values.
1075 ///
1076 /// # Arguments
1077 ///
1078 /// * `tensor` - The tensor to take the tangent of.
1079 ///
1080 /// # Returns
1081 ///
1082 /// A tensor with the same shape as `tensor` with tangent values.
1083 fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B> {
1084 let sin = B::float_sin(tensor.clone());
1085 let cos = B::float_cos(tensor);
1086 B::float_div(sin, cos)
1087 }
1088
1089 /// Returns a new tensor with hyperbolic cosine values.
1090 ///
1091 /// # Arguments
1092 ///
1093 /// * `tensor` - The tensor to take the hyperbolic cosine of.
1094 ///
1095 /// # Returns
1096 ///
1097 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
1098 fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1099 // cosh = ( e^x + e^(-x) ) / 2
1100 let e_x = B::float_exp(tensor.clone());
1101 let e_neg_x = B::float_exp(B::float_neg(tensor));
1102 let num = B::float_add(e_x, e_neg_x); // e^x + e^(-x)
1103 B::float_div_scalar(num, 2.0.elem())
1104 }
1105
1106 /// Returns a new tensor with hyperbolic sine values.
1107 ///
1108 /// # Arguments
1109 ///
1110 /// * `tensor` - The tensor to take the hyperbolic sine of.
1111 ///
1112 /// # Returns
1113 ///
1114 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1115 fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1116 // sinh = ( e^x - e^(-x) ) / 2
1117 let e_x = B::float_exp(tensor.clone());
1118 let e_neg_x = B::float_exp(B::float_neg(tensor));
1119 let num = B::float_sub(e_x, e_neg_x); // e^x - e^(-x)
1120 B::float_div_scalar(num, 2.0.elem())
1121 }
1122
1123 /// Returns a new tensor with hyperbolic tangent values.
1124 ///
1125 /// # Arguments
1126 ///
1127 /// * `tensor` - The tensor to take the hyperbolic tangent of.
1128 ///
1129 /// # Returns
1130 ///
1131 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1132 fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1133 let sinh = B::float_sinh(tensor.clone());
1134 let cosh = B::float_cosh(tensor);
1135 B::float_div(sinh, cosh)
1136 }
1137
1138 /// Returns a new tensor with rounded values.
1139 ///
1140 /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
1141 /// strategy, with halfway cases rounded to the nearest even integer value.
1142 ///
1143 /// # Arguments
1144 ///
1145 /// * `tensor` - The tensor to be rounded.
1146 ///
1147 /// # Returns
1148 ///
1149 /// A tensor with the same shape as `tensor` with rounded values.
1150 fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
1151
1152 /// Returns a new tensor with floored values.
1153 ///
1154 /// # Arguments
1155 ///
1156 /// * `tensor` - The tensor to be floored.
1157 ///
1158 /// # Returns
1159 ///
1160 /// A tensor with the same shape as `tensor` with floored values.
1161 fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1162
1163 /// Returns a new tensor with ceiled values.
1164 ///
1165 /// # Arguments
1166 ///
1167 /// * `tensor` - The tensor to be ceiled.
1168 ///
1169 /// # Returns
1170 ///
1171 /// A tensor with the same shape as `tensor` with ceiled values.
1172 fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1173
1174 /// Returns a new tensor with truncated values.
1175 ///
1176 /// # Arguments
1177 ///
1178 /// * `tensor` - The tensor to be truncated.
1179 ///
1180 /// # Returns
1181 ///
1182 /// A tensor with the same shape as `tensor` with truncated values.
1183 fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
1184
1185 /// Returns a new tensor with the error function values.
1186 ///
1187 /// # Arguments
1188 ///
1189 /// * `tensor` - The tensor to take the error function of.
1190 ///
1191 /// # Returns
1192 ///
1193 /// A tensor with the same shape as `tensor` with error function values.
1194 fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1195
1196 /// Concatenates tensors along a dimension.
1197 ///
1198 /// # Arguments
1199 ///
1200 /// * `tensors` - The tensors to concatenate.
1201 /// * `dim` - The dimension along which to concatenate.
1202 ///
1203 /// # Returns
1204 ///
1205 /// A tensor with the concatenated tensors along `dim`.
1206 ///
1207 /// # Note
1208 ///
1209 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
1210 /// high-level tensor API and will not be passed to this method. Backend implementations do
1211 /// not need to handle empty tensors.
1212 fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1213 cat_with_slice_assign::<B, Float>(
1214 tensors.into_iter().map(TensorPrimitive::Float).collect(),
1215 dim,
1216 )
1217 .tensor()
1218 }
1219
1220 /// Gets the indices of the maximum elements of a tensor along an axis.
1221 ///
1222 /// # Arguments
1223 ///
1224 /// * `tensor` - The tensor to get the maximum elements of.
1225 /// * `dim` - The dimension along which to get the maximum elements.
1226 ///
1227 /// # Returns
1228 ///
1229 /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
1230 fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1231
1232 /// Gets the indices of the minimum elements of a tensor along an axis.
1233 ///
1234 /// # Arguments
1235 ///
1236 /// * `tensor` - The tensor to get the minimum elements of.
1237 /// * `dim` - The dimension along which to get the minimum elements.
1238 ///
1239 /// # Returns
1240 ///
1241 /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
1242 fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1243
1244 /// Gets the maximum element of a tensor.
1245 ///
1246 /// # Arguments
1247 ///
1248 /// * `tensor` - The tensor to get the maximum elements of.
1249 ///
1250 /// # Returns
1251 ///
1252 /// A tensor with the maximum element of `tensor`.
1253 fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1254 let shape = tensor.shape();
1255 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1256
1257 B::float_max_dim(tensor, 0)
1258 }
1259
1260 /// Gets the maximum elements of a tensor along an axis.
1261 ///
1262 /// # Arguments
1263 ///
1264 /// * `tensor` - The tensor to get the maximum elements of.
1265 /// * `dim` - The dimension along which to get the maximum elements.
1266 ///
1267 /// # Returns
1268 ///
1269 /// A tensor with the maximum elements of `tensor` along `dim`.
1270 fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1271 let index = B::float_argmax(tensor.clone(), dim);
1272
1273 B::float_gather(dim, tensor, index)
1274 }
1275
1276 /// Gets the maximum elements of a tensor along an axis and their indices.
1277 ///
1278 /// # Arguments
1279 ///
1280 /// * `tensor` - The tensor to get the maximum elements of.
1281 /// * `dim` - The dimension along which to get the maximum elements.
1282 ///
1283 /// # Returns
1284 ///
1285 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1286 fn float_max_dim_with_indices(
1287 tensor: FloatTensor<B>,
1288 dim: usize,
1289 ) -> (FloatTensor<B>, IntTensor<B>) {
1290 let index = B::float_argmax(tensor.clone(), dim);
1291 let values = B::float_gather(dim, tensor, index.clone());
1292
1293 (values, index)
1294 }
1295
1296 /// Gets the minimum element of a tensor.
1297 ///
1298 /// # Arguments
1299 ///
1300 /// * `tensor` - The tensor to get the minimum elements of.
1301 ///
1302 /// # Returns
1303 ///
1304 /// A tensor with the minimum element of `tensor`.
1305 fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1306 let shape = tensor.shape();
1307 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1308
1309 B::float_min_dim(tensor, 0)
1310 }
1311
1312 /// Gets the minimum elements of a tensor along an axis.
1313 ///
1314 /// # Arguments
1315 ///
1316 /// * `tensor` - The tensor to get the minimum elements of.
1317 /// * `dim` - The dimension along which to get the minimum elements.
1318 ///
1319 /// # Returns
1320 ///
1321 /// A tensor with the minimum elements of `tensor` along `dim`.
1322 fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1323 let index = B::float_argmin(tensor.clone(), dim);
1324
1325 B::float_gather(dim, tensor, index)
1326 }
1327
1328 /// Gets the minimum elements of a tensor along an axis and their indices.
1329 ///
1330 /// # Arguments
1331 ///
1332 /// * `tensor` - The tensor to get the minimum elements of.
1333 /// * `dim` - The dimension along which to get the minimum elements.
1334 ///
1335 /// # Returns
1336 ///
1337 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1338 fn float_min_dim_with_indices(
1339 tensor: FloatTensor<B>,
1340 dim: usize,
1341 ) -> (FloatTensor<B>, IntTensor<B>) {
1342 let index = B::float_argmin(tensor.clone(), dim);
1343 let values = B::float_gather(dim, tensor, index.clone());
1344
1345 (values, index)
1346 }
1347
1348 /// Gets the maximum absolute element of a tensor.
1349 ///
1350 /// # Arguments
1351 ///
1352 /// * `tensor` - The tensor to get the maximum elements of.
1353 ///
1354 /// # Returns
1355 ///
1356 /// A tensor with the maximum element of `tensor`.
1357 fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1358 let shape = tensor.shape();
1359 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1360
1361 B::float_max_abs_dim(tensor, 0)
1362 }
1363
1364 /// Gets the maximum absolute elements of a tensor along an axis.
1365 ///
1366 /// # Arguments
1367 ///
1368 /// * `tensor` - The tensor to get the maximum elements of.
1369 /// * `dim` - The dimension along which to get the maximum elements.
1370 ///
1371 /// # Returns
1372 ///
1373 /// A tensor with the maximum elements of `tensor` along `dim`.
1374 fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1375 B::float_max_dim(B::float_abs(tensor), dim)
1376 }
1377
1378 /// Tests if any element in the float `tensor` evaluates to True.
1379 ///
1380 /// # Arguments
1381 ///
1382 /// * `tensor` - The tensor to test.
1383 ///
1384 /// # Returns
1385 ///
1386 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1387 fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {
1388 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1389 let bool_tensor = B::bool_not(bool_tensor);
1390 let sum = B::float_sum(B::bool_into_float(bool_tensor));
1391 B::float_greater_elem(sum, 0.0f32.elem())
1392 }
1393
1394 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1395 ///
1396 /// # Arguments
1397 ///
1398 /// * `tensor` - The tensor to test.
1399 /// * `dim` - The axis along which to test.
1400 ///
1401 /// # Returns
1402 ///
1403 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1404 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1405 /// input evaluates to True, False otherwise.
1406 fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1407 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1408 let bool_tensor = B::bool_not(bool_tensor);
1409 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1410 B::float_greater_elem(sum, 0.0f32.elem())
1411 }
1412
1413 /// Tests if all elements in the float `tensor` evaluate to True.
1414 ///
1415 /// # Arguments
1416 ///
1417 /// * `tensor` - The tensor to test.
1418 ///
1419 /// # Returns
1420 ///
1421 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1422 /// evaluate to True, False otherwise.
1423 fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
1424 let num_elems = tensor.shape().num_elements();
1425 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1426 let bool_tensor = B::bool_not(bool_tensor);
1427 let sum = B::float_sum(B::bool_into_float(bool_tensor));
1428 B::float_equal_elem(sum, (num_elems as f32).elem())
1429 }
1430
1431 /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1432 ///
1433 /// # Arguments
1434 ///
1435 /// * `tensor` - The tensor to test.
1436 /// * `dim` - The axis along which to test.
1437 ///
1438 /// # Returns
1439 ///
1440 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1441 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1442 /// evaluates to True, False otherwise.
1443 fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1444 let num_elems = tensor.shape().dims[dim];
1445 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1446 let bool_tensor = B::bool_not(bool_tensor);
1447 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1448 B::float_equal_elem(sum, (num_elems as f32).elem())
1449 }
1450
1451 /// Returns the signs of the float `tensor`.
1452 ///
1453 /// # Arguments
1454 ///
1455 /// * `tensor` - The tensor to extract the signs from.
1456 ///
1457 /// # Returns
1458 ///
1459 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1460 fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1461 let zeros = B::float_zeros(
1462 tensor.shape(),
1463 &B::float_device(&tensor),
1464 tensor.dtype().into(),
1465 );
1466 let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
1467 let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
1468
1469 let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1470 result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
1471 result
1472 }
1473
1474 /// Broadcasts the float `tensor` to the given `shape`.
1475 fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1476
1477 /// Sort the elements of the input `tensor` by value in along a given dimension.
1478 ///
1479 /// This sort is unstable (i.e., may reorder equal elements).
1480 ///
1481 /// # Arguments
1482 ///
1483 /// * `tensor` - The input tensor.
1484 /// * `dim` - The axis along which to sort.
1485 /// * `descending` - The sorting order.
1486 ///
1487 /// # Returns
1488 ///
1489 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1490 fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1491 sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1492 }
1493
1494 /// Sort the elements of the input `tensor` by value in along a given dimension.
1495 ///
1496 /// This sort is unstable (i.e., may reorder equal elements).
1497 ///
1498 /// # Arguments
1499 ///
1500 /// * `tensor` - The input tensor.
1501 /// * `dim` - The axis along which to sort.
1502 /// * `descending` - The sorting order.
1503 ///
1504 /// # Returns
1505 ///
1506 /// A tensor with the same shape as the input tensor and corresponding indices, where
1507 /// the elements are sorted by value and the indices map back to the original input tensor.
1508 fn float_sort_with_indices(
1509 tensor: FloatTensor<B>,
1510 dim: usize,
1511 descending: bool,
1512 ) -> (FloatTensor<B>, IntTensor<B>) {
1513 let (values, indices) =
1514 sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
1515 (values.tensor(), indices)
1516 }
1517
1518 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1519 ///
1520 /// This sort is unstable (i.e., may reorder equal elements).
1521 ///
1522 /// # Arguments
1523 ///
1524 /// * `tensor` - The input tensor.
1525 /// * `dim` - The axis along which to sort.
1526 /// * `descending` - The sorting order.
1527 ///
1528 /// # Returns
1529 ///
1530 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1531 fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1532 argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
1533 }
1534
1535 /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
1536 /// using the given locations in [-1, 1].
1537 ///
1538 /// # Arguments
1539 ///
1540 /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in)
1541 /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
1542 /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
1543 /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
1544 ///
1545 /// # Returns
1546 ///
1547 /// A tensor with shape (N, C, H_out, W_out)
1548 fn float_grid_sample_2d(
1549 tensor: FloatTensor<B>,
1550 grid: FloatTensor<B>,
1551 options: GridSampleOptions,
1552 ) -> FloatTensor<B> {
1553 float_grid_sample_2d_ref::<B>(tensor, grid, options)
1554 }
1555
1556 /// Unfold windows along a dimension.
1557 ///
1558 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
1559 /// where windows are advanced by `step` at each index.
1560 ///
1561 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
1562 ///
1563 /// # Arguments
1564 ///
1565 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
1566 /// * `dim` - the selected dim.
1567 /// * `size` - the size of each unfolded window.
1568 /// * `step` - the step between each window.
1569 ///
1570 /// # Returns
1571 ///
1572 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
1573 fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
1574 -> FloatTensor<B>;
1575
1576 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1577 ///
1578 /// # Returns
1579 ///
1580 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1581 fn float_is_nan(tensor: FloatTensor<B>) -> BoolTensor<B> {
1582 // Check if the input tensor is NaN by comparing it to itself
1583 // NaN is the only value that is not equal to itself
1584 B::float_not_equal(tensor.clone(), tensor)
1585 }
1586
1587 /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1588 ///
1589 /// # Returns
1590 ///
1591 /// A boolean tensor where `true` indicates that the value is infinite
1592 fn float_is_inf(tensor: FloatTensor<B>) -> BoolTensor<B> {
1593 B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.elem())
1594 }
1595}