burn_tensor/tensor/ops/tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_bilinear;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::{BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor};
5use crate::ops::InterpolateMode;
6use crate::{Distribution, ElementConversion, Float, TensorData, backend::Backend, tensor::Shape};
7use crate::{FloatDType, TensorMetadata, TensorPrimitive};
8use crate::{argsort, sort, sort_with_indices};
9use alloc::vec::Vec;
10
11/// Operations on float tensors.
12pub trait FloatTensorOps<B: Backend> {
13 /// Creates a new tensor from the data structure.
14 ///
15 /// # Arguments
16 ///
17 /// * `data` - The data structure.
18 /// * `device` - The device to create the tensor on.
19 ///
20 /// # Returns
21 ///
22 /// The tensor with the given data.
23 fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
24
25 /// Creates a new tensor with random values.
26 ///
27 /// # Arguments
28 ///
29 /// * `shape` - The shape of the tensor.
30 /// * `distribution` - The distribution to sample from.
31 /// * `device` - The device to create the tensor on.
32 ///
33 /// # Returns
34 ///
35 /// The tensor with the given shape and random values.
36 fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
37 -> FloatTensor<B>;
38
39 /// Creates a new tensor with zeros.
40 ///
41 /// # Arguments
42 ///
43 /// * `shape` - The shape of the tensor.
44 /// * `device` - The device to create the tensor on.
45 /// * `dtype` - The target data type.
46 ///
47 /// # Returns
48 ///
49 /// The tensor with the given shape and zeros.
50 fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
51 Self::float_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)
52 }
53
54 /// Creates a new tensor with ones.
55 ///
56 /// # Arguments
57 ///
58 /// * `shape` - The shape of the tensor.
59 /// * `device` - The device to create the tensor on.
60 /// * `dtype` - The target data type.
61 ///
62 /// # Returns
63 ///
64 /// The tensor with the given shape and ones.
65 fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
66 Self::float_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)
67 }
68
69 /// Creates a tensor filled with given value.
70 ///
71 /// # Arguments
72 ///
73 /// * `shape` - The shape of the tensor.
74 /// * `fill_value` - The value with which to fill the tensor.
75 /// * `device` - The device to create the tensor on.
76 /// * `dtype` - The target data type.
77 ///
78 /// # Returns
79 ///
80 /// The tensor filled with given value
81 fn float_full(
82 shape: Shape,
83 fill_value: FloatElem<B>,
84 device: &Device<B>,
85 dtype: FloatDType,
86 ) -> FloatTensor<B> {
87 Self::float_from_data(
88 TensorData::full_dtype(shape, fill_value, dtype.into()),
89 device,
90 )
91 }
92
93 /// Converts the tensor to a data structure.
94 ///
95 /// # Arguments
96 ///
97 /// * `tensor` - The tensor.
98 ///
99 /// # Returns
100 ///
101 /// The data structure with the tensor's data.
102 fn float_into_data(tensor: FloatTensor<B>) -> impl Future<Output = TensorData> + Send;
103
104 /// Gets the device of the tensor.
105 ///
106 /// # Arguments
107 ///
108 /// * `tensor` - The tensor.
109 ///
110 /// # Returns
111 ///
112 /// The device of the tensor.
113 fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
114
115 /// Moves the tensor to the given device.
116 ///
117 /// # Arguments
118 ///
119 /// * `tensor` - The tensor.
120 /// * `device` - The device to move the tensor to.
121 ///
122 /// # Returns
123 ///
124 /// The tensor on the given device.
125 fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
126
127 /// Converts float tensor to int tensor.
128 ///
129 /// # Arguments
130 ///
131 /// * `tensor` - The tensor.
132 ///
133 /// # Returns
134 ///
135 /// The int tensor with the same data as the float tensor.
136 fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;
137
138 /// Creates an empty tensor with the given shape.
139 ///
140 /// # Arguments
141 ///
142 /// * `shape` - The shape of the tensor.
143 /// * `device` - The device to create the tensor on.
144 /// * `dtype` - The target data type.
145 ///
146 /// # Returns
147 ///
148 /// The empty tensor with the given shape.
149 fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
150
151 /// Repeat the tensor along the given dimension.
152 ///
153 /// # Arguments
154 ///
155 /// * `tensor` - The tensor.
156 /// * `dim` - The dimension to repeat.
157 /// * `times` - The number of times to repeat the dimension.
158 ///
159 /// # Returns
160 ///
161 /// The tensor with the given dimension repeated.
162 fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
163 repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
164 }
165
166 /// Adds two tensors together.
167 ///
168 /// # Arguments
169 ///
170 /// * `lhs` - The left-hand side tensor.
171 /// * `rhs` - The right-hand side tensor.
172 ///
173 /// # Returns
174 ///
175 /// The result of adding the two tensors together.
176 fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
177
178 /// Adds a scalar to a tensor.
179 ///
180 /// # Arguments
181 ///
182 /// * `lhs` - The left-hand side tensor.
183 /// * `rhs` - The right-hand side scalar.
184 ///
185 /// # Returns
186 ///
187 /// The result of adding the scalar to the tensor.
188 fn float_add_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
189
190 /// Clamps a tensor under a minimum value.
191 ///
192 /// # Arguments
193 ///
194 /// * `tensor` - The tensor to clamp.
195 /// * `min` - The minimum value.
196 ///
197 /// # Returns
198 ///
199 /// The clamped tensor.
200 fn float_clamp_min(tensor: FloatTensor<B>, min: FloatElem<B>) -> FloatTensor<B> {
201 // Default implementation
202 let mask = Self::float_lower_elem(tensor.clone(), min);
203 B::float_mask_fill(tensor, mask, min)
204 }
205
206 /// Clamps a tensor over a maximum value.
207 ///
208 /// # Arguments
209 ///
210 /// * `tensor` - The tensor to clamp.
211 /// * `max` - The maximum value.
212 ///
213 /// # Returns
214 ///
215 /// The clamped tensor.
216 fn float_clamp_max(tensor: FloatTensor<B>, max: FloatElem<B>) -> FloatTensor<B> {
217 // Default implementation
218 let mask = Self::float_greater_elem(tensor.clone(), max);
219 B::float_mask_fill(tensor, mask, max)
220 }
221
222 /// Clamps a tensor between a minimum and maximum value.
223 ///
224 /// # Arguments
225 ///
226 /// * `tensor` - The tensor to clamp.
227 /// * `min` - The minimum value.
228 /// * `max` - The maximum value.
229 ///
230 /// # Returns
231 ///
232 /// The clamped tensor.
233 fn float_clamp(tensor: FloatTensor<B>, min: FloatElem<B>, max: FloatElem<B>) -> FloatTensor<B> {
234 // Default implementation
235 Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
236 }
237
238 /// Subtracts two tensors.
239 ///
240 /// # Arguments
241 ///
242 /// * `lhs` - The left-hand side tensor.
243 /// * `rhs` - The right-hand side tensor.
244 ///
245 /// # Returns
246 ///
247 /// The result of subtracting the two tensors.
248 fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
249
250 /// Subtracts a scalar from a tensor.
251 ///
252 /// # Arguments
253 ///
254 /// * `lhs` - The left-hand side tensor.
255 /// * `rhs` - The right-hand side scalar.
256 ///
257 /// # Returns
258 ///
259 /// The result of subtracting the scalar from the tensor.
260 fn float_sub_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
261
262 /// Multiplies two tensors together element-wise.
263 fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
264
265 /// Multiplies a tensor by a scalar.
266 ///
267 /// # Arguments
268 ///
269 /// * `lhs` - The left-hand side tensor.
270 /// * `rhs` - The right-hand side scalar.
271 ///
272 /// # Returns
273 ///
274 /// The result of multiplying the tensor by the scalar.
275 fn float_mul_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
276
277 /// Divides two tensors element-wise.
278 ///
279 /// # Arguments
280 ///
281 /// * `lhs` - The left-hand side tensor.
282 /// * `rhs` - The right-hand side tensor.
283 ///
284 /// # Returns
285 ///
286 /// The result of dividing the two tensors.
287 fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
288
289 /// Divides a tensor by a scalar.
290 ///
291 /// # Arguments
292 ///
293 /// * `lhs` - The left-hand side tensor.
294 /// * `rhs` - The right-hand side scalar.
295 ///
296 /// # Returns
297 ///
298 /// The result of dividing the tensor by the scalar.
299 fn float_div_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
300
301 /// Computes the remainder of division between two tensors element-wise.
302 ///
303 /// # Arguments
304 ///
305 /// * `lhs` - The left-hand side tensor.
306 /// * `rhs` - The right-hand side tensor.
307 ///
308 /// # Returns
309 ///
310 /// The element-wise remainder when dividing `lhs` by `rhs`.
311 fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
312
313 /// Computes the modulus of a tensor given a scalar.
314 ///
315 /// # Arguments
316 /// * `lhs` - The left-hand side tensor.
317 /// * `rhs` - The right-hand side scalar.
318 ///
319 /// # Returns
320 ///
321 /// The result of applying the modulus of the scalar to the tensor.
322 fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
323
324 /// Multiplies two tensors together using matrix multiplication.
325 ///
326 /// # Arguments
327 ///
328 /// * `lhs` - The left-hand side tensor.
329 /// * `rhs` - The right-hand side tensor.
330 ///
331 /// # Returns
332 ///
333 /// The result of multiplying the two tensors together using matrix multiplication.
334 fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
335
336 /// Computes the cross product of two tensors along a given dimension.
337 ///
338 /// # Arguments
339 ///
340 /// * `lhs` - The left-hand side tensor.
341 /// * `rhs` - The right-hand side tensor.
342 /// * `dim` - The dimension to compute the cross product along.
343 ///
344 /// # Returns
345 ///
346 /// The cross product of the two tensors.
347 fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
348
349 /// Negates a tensor element-wise.
350 fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
351 Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
352 }
353
354 /// Calculates the reciprocals element-wise
355 fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
356
357 /// Transposes a tensor.
358 ///
359 /// # Arguments
360 ///
361 /// * `tensor` - The tensor to transpose.
362 ///
363 /// # Returns
364 ///
365 /// The transposed tensor.
366 fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
367 let ndims = tensor.shape().num_dims();
368 Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
369 }
370
371 /// Swaps two dimensions of a tensor.
372 ///
373 /// # Arguments
374 ///
375 /// * `tensor` - The tensor to swap the dimensions of.
376 /// * `dim1` - The first dimension to swap.
377 /// * `dim2` - The second dimension to swap.
378 ///
379 /// # Returns
380 ///
381 /// The tensor with the dimensions swapped.
382 fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
383
384 /// Permutes the dimensions of a tensor.
385 ///
386 /// # Arguments
387 ///
388 /// * `tensor` - The tensor to permute the dimensions of.
389 /// * `axes` - The new order of the dimensions.
390 /// # Returns
391 ///
392 /// The tensor with the dimensions permuted.
393 fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
394
395 /// Reverse the order of elements in a tensor along the given axes.
396 ///
397 /// # Arguments
398 ///
399 /// * `tensor` - The tensor to reverse.
400 /// * `axes` - The axes to reverse.
401 ///
402 /// The tensor with the elements reversed.
403 fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
404
405 /// Reshapes a tensor.
406 ///
407 /// # Arguments
408 ///
409 /// * `tensor` - The tensor to reshape.
410 /// * `shape` - The new shape of the tensor.
411 ///
412 /// # Returns
413 ///
414 /// The tensor with the new shape.
415 fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
416
417 /// Gather elements from a tensor.
418 ///
419 /// # Arguments
420 ///
421 /// * `dim` - The dimension to gather from.
422 /// * `tensor` - The tensor to gather from.
423 /// * `indices` - The indices to gather.
424 ///
425 /// # Returns
426 ///
427 /// The gathered elements.
428 fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
429
430 /// Scatter elements into a tensor.
431 ///
432 /// # Arguments
433 ///
434 /// * `dim` - The dimension to scatter into.
435 /// * `tensor` - The tensor to scatter into.
436 /// * `indices` - The indices to scatter into.
437 /// * `value` - The value to scatter.
438 ///
439 /// # Returns
440 ///
441 /// The tensor with the scattered elements.
442 fn float_scatter(
443 dim: usize,
444 tensor: FloatTensor<B>,
445 indices: IntTensor<B>,
446 value: FloatTensor<B>,
447 ) -> FloatTensor<B>;
448
449 /// Select tensor elements along the given dimension corresponding for the given indices.
450 ///
451 /// # Arguments
452 ///
453 /// * `tensor` - The tensor to select from.
454 /// * `dim` - The dimension to select from.
455 /// * `indices` - The indices to select.
456 ///
457 /// # Returns
458 ///
459 /// The selected elements.
460 fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
461
462 /// Assign the selected elements along the given dimension corresponding for the given indices
463 /// to the given value.
464 ///
465 /// # Arguments
466 ///
467 /// * `tensor` - The tensor to select from.
468 /// * `dim` - The dimension to select from.
469 /// * `indices` - The indices to select.
470 /// * `value` - The value to assign.
471 ///
472 /// # Returns
473 ///
474 /// The tensor with the selected elements assigned to the given value.
475 fn float_select_assign(
476 tensor: FloatTensor<B>,
477 dim: usize,
478 indices: IntTensor<B>,
479 value: FloatTensor<B>,
480 ) -> FloatTensor<B>;
481
482 /// Select tensor elements corresponding to the given slices.
483 ///
484 /// # Arguments
485 ///
486 /// * `tensor` - The tensor to select from.
487 /// * `slices` - The slices specifying ranges and steps for each dimension.
488 ///
489 /// # Returns
490 ///
491 /// The selected elements in a new tensor.
492 fn float_slice(tensor: FloatTensor<B>, slices: &[crate::Slice]) -> FloatTensor<B>;
493
494 /// Assign the selected elements corresponding to the given slices to the given value.
495 ///
496 /// # Arguments
497 ///
498 /// * `tensor` - The tensor to select from.
499 /// * `ranges` - The ranges to select.
500 /// * `value` - The value to assign.
501 ///
502 /// # Returns
503 ///
504 /// The tensor with the selected elements assigned to the given value.
505 fn float_slice_assign(
506 tensor: FloatTensor<B>,
507 slices: &[crate::Slice],
508 value: FloatTensor<B>,
509 ) -> FloatTensor<B>;
510
511 /// Update the given tensor with the value tensor where the mask is true.
512 ///
513 /// # Arguments
514 ///
515 /// * `tensor` - The tensor to select from.
516 /// * `mask` - The boolean mask to select with.
517 /// * `value` - The value to assign to the selected elements from the value tensor.
518 ///
519 /// # Returns
520 ///
521 /// The tensor with the selected elements assigned to the given value.
522 fn float_mask_where(
523 tensor: FloatTensor<B>,
524 mask: BoolTensor<B>,
525 value: FloatTensor<B>,
526 ) -> FloatTensor<B>;
527
528 /// Update the given tensor with the value where the mask is true.
529 ///
530 /// # Arguments
531 ///
532 /// * `tensor` - The tensor to select from.
533 /// * `mask` - The boolean mask to select with.
534 /// * `value` - The value to assign to the selected elements.
535 ///
536 /// # Returns
537 ///
538 /// The tensor with the selected elements assigned to the given value.
539 fn float_mask_fill(
540 tensor: FloatTensor<B>,
541 mask: BoolTensor<B>,
542 value: FloatElem<B>,
543 ) -> FloatTensor<B>;
544
545 /// Equal comparison of two tensors.
546 ///
547 /// # Arguments
548 ///
549 /// * `lhs` - The left-hand side tensor.
550 /// * `rhs` - The right-hand side tensor.
551 ///
552 /// # Returns
553 ///
554 /// A boolean tensor with the result of the comparison.
555 fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
556
557 /// Element-wise non-equality comparison.
558 ///
559 /// # Arguments
560 ///
561 /// * `lhs` - The left-hand side tensor.
562 /// * `rhs` - The right-hand side tensor.
563 ///
564 /// # Returns
565 ///
566 /// A boolean tensor with the result of the comparison.
567 fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {
568 let equal_tensor = B::float_equal(lhs, rhs);
569 B::bool_not(equal_tensor)
570 }
571
572 /// Equal comparison of a tensor and a scalar.
573 ///
574 /// # Arguments
575 ///
576 /// * `lhs` - The left-hand side tensor.
577 /// * `rhs` - The right-hand side scalar.
578 ///
579 /// # Returns
580 ///
581 /// A boolean tensor with the result of the comparison.
582 fn float_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
583
584 /// Element-wise non-equality comparison with a scalar.
585 ///
586 /// # Arguments
587 ///
588 /// * `lhs` - The left-hand side tensor.
589 /// * `rhs` - The right-hand side scalar.
590 ///
591 /// # Returns
592 ///
593 /// A boolean tensor with the result of the comparison.
594 fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B> {
595 let equal_tensor = B::float_equal_elem(lhs, rhs);
596 B::bool_not(equal_tensor)
597 }
598
599 /// Greater than comparison of two tensors.
600 ///
601 /// # Arguments
602 ///
603 /// * `lhs` - The left-hand side tensor.
604 /// * `rhs` - The right-hand side tensor.
605 ///
606 /// # Returns
607 ///
608 /// A boolean tensor with the result of the comparison.
609 fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
610
611 /// Greater than comparison of a tensor and a scalar.
612 ///
613 /// # Arguments
614 ///
615 /// * `lhs` - The left-hand side tensor.
616 /// * `rhs` - The right-hand side scalar.
617 ///
618 /// # Returns
619 ///
620 /// A boolean tensor with the result of the comparison.
621 fn float_greater_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
622
623 /// Greater than or equal comparison of two tensors.
624 ///
625 /// # Arguments
626 ///
627 /// * `lhs` - The left-hand side tensor.
628 /// * `rhs` - The right-hand side tensor.
629 ///
630 /// # Returns
631 ///
632 /// A boolean tensor with the result of the comparison.
633 fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
634
635 /// Greater than or equal comparison of a tensor and a scalar.
636 ///
637 /// # Arguments
638 ///
639 /// * `lhs` - The left-hand side tensor.
640 /// * `rhs` - The right-hand side scalar.
641 ///
642 /// # Returns
643 ///
644 /// A boolean tensor with the result of the comparison.
645 fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
646
647 /// Less than comparison of two tensors.
648 ///
649 /// # Arguments
650 ///
651 /// * `lhs` - The left-hand side tensor.
652 /// * `rhs` - The right-hand side tensor.
653 ///
654 /// # Returns
655 ///
656 /// A boolean tensor with the result of the comparison.
657 fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
658
659 /// Less than comparison of a tensor and a scalar.
660 ///
661 /// # Arguments
662 ///
663 /// * `lhs` - The left-hand side tensor.
664 /// * `rhs` - The right-hand side scalar.
665 ///
666 /// # Returns
667 ///
668 /// A boolean tensor with the result of the comparison.
669 fn float_lower_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
670
671 /// Less than or equal comparison of two tensors.
672 ///
673 /// # Arguments
674 ///
675 /// * `lhs` - The left-hand side tensor.
676 /// * `rhs` - The right-hand side tensor.
677 ///
678 /// # Returns
679 ///
680 /// A boolean tensor with the result of the comparison.
681 fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
682
683 /// Less than or equal comparison of a tensor and a scalar.
684 ///
685 /// # Arguments
686 ///
687 /// * `lhs` - The left-hand side tensor.
688 /// * `rhs` - The right-hand side scalar.
689 ///
690 /// # Returns
691 ///
692 /// A boolean tensor with the result of the comparison.
693 fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
694
695 /// Detaches a tensor from the computation graph.
696 fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
697 // Should only be overridden by autodiff backends.
698 tensor
699 }
700
701 /// Sets the `require_grad` flag of a tensor.
702 fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
703 // Should only be overridden by autodiff backends.
704 tensor
705 }
706
707 /// Returns the `require_grad` flag of a tensor.
708 fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
709 // Should only be overridden by autodiff backends.
710 false
711 }
712
713 /// Sum of all elements in a tensor.
714 ///
715 /// # Arguments
716 ///
717 /// * `tensor` - The tensor to sum.
718 ///
719 /// # Returns
720 ///
721 /// A scalar tensor with the sum of all elements in `tensor`.
722 fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
723
724 /// Sum of all elements in a tensor along a dimension.
725 ///
726 /// # Arguments
727 ///
728 /// * `tensor` - The tensor to sum.
729 /// * `dim` - The dimension along which to sum.
730 ///
731 /// # Returns
732 ///
733 /// A tensor with the sum of all elements in `tensor` along `dim`.
734 fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
735
736 /// Product of all elements in a tensor.
737 ///
738 /// # Arguments
739 ///
740 /// * `tensor` - The tensor to product.
741 ///
742 /// # Returns
743 ///
744 /// A scalar tensor with the product of all elements in `tensor`.
745 fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
746 // Product of all elements in a tensor
747 B::float_exp(B::float_sum(B::float_log(tensor)))
748 }
749
750 /// Product of all elements in a tensor along a dimension.
751 ///
752 /// # Arguments
753 ///
754 /// * `tensor` - The tensor to product.
755 ///
756 /// # Returns
757 ///
758 /// A tensor with the product of all elements in `tensor` along `dim`.
759 fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
760 // Product of all elements in a tensor along a dimension
761 B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
762 }
763
764 /// Mean of all elements in a tensor.
765 ///
766 /// # Arguments
767 ///
768 /// * `tensor` - The tensor to mean.
769 ///
770 /// # Returns
771 ///
772 /// A scalar tensor with the mean of all elements in `tensor`.
773 fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
774 let num_elems = tensor.shape().num_elements();
775 B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
776 }
777
778 /// Mean of all elements in a tensor along a dimension.
779 ///
780 /// # Arguments
781 ///
782 /// * `tensor` - The tensor to mean.
783 /// * `dim` - The dimension along which to mean.
784 ///
785 /// # Returns
786 ///
787 /// A tensor with the mean of all elements in `tensor` along `dim`.
788 fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
789
790 /// Computes the cumulative sum of elements along a dimension.
791 ///
792 /// # Arguments
793 ///
794 /// * `tensor` - The tensor to compute the cumulative sum of.
795 /// * `dim` - The dimension along which to compute the cumulative sum.
796 ///
797 /// # Returns
798 ///
799 /// A tensor with the same shape where each element is the cumulative sum
800 /// of all elements up to and including that position along the dimension.
801 fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
802
803 /// Computes the cumulative product of elements along a dimension.
804 ///
805 /// # Arguments
806 ///
807 /// * `tensor` - The tensor to compute the cumulative product of.
808 /// * `dim` - The dimension along which to compute the cumulative product.
809 ///
810 /// # Returns
811 ///
812 /// A tensor with the same shape where each element is the cumulative product
813 /// of all elements up to and including that position along the dimension.
814 fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
815
816 /// Computes the cumulative minimum of elements along a dimension.
817 ///
818 /// # Arguments
819 ///
820 /// * `tensor` - The tensor to compute the cumulative minimum of.
821 /// * `dim` - The dimension along which to compute the cumulative minimum.
822 ///
823 /// # Returns
824 ///
825 /// A tensor with the same shape where each element is the minimum
826 /// of all elements up to and including that position along the dimension.
827 fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
828
829 /// Computes the cumulative maximum of elements along a dimension.
830 ///
831 /// # Arguments
832 ///
833 /// * `tensor` - The tensor to compute the cumulative maximum of.
834 /// * `dim` - The dimension along which to compute the cumulative maximum.
835 ///
836 /// # Returns
837 ///
838 /// A tensor with the same shape where each element is the maximum
839 /// of all elements up to and including that position along the dimension.
840 fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
841
842 /// Converts a tensor to another floating point data type.
843 ///
844 /// # Arguments
845 ///
846 /// * `tensor` - The tensor to convert.
847 /// * `dtype` - The target data type.
848 ///
849 /// # Returns
850 ///
851 /// A tensor with the same values as `tensor` but in the target floating point data type.
852 fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
853
854 /// Returns a new tensor with exponential values.
855 ///
856 /// # Arguments
857 ///
858 /// * `tensor` - The tensor to exponentiate.
859 ///
860 /// # Returns
861 ///
862 /// A tensor with the same shape as `tensor` with exponential values.
863 fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
864
865 /// Returns a new tensor with natural logarithm values.
866 ///
867 /// # Arguments
868 ///
869 /// * `tensor` - The tensor to take the logarithm of.
870 ///
871 /// # Returns
872 ///
873 /// A tensor with the same shape as `tensor` with natural logarithm values.
874 fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
875
876 /// Returns a new tensor with logarithm values of (1 + Xi).
877 ///
878 /// # Arguments
879 ///
880 /// * `tensor` - The tensor to take the logarithm of.
881 ///
882 /// # Returns
883 ///
884 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
885 fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
886
887 /// Element-wise power with a FloatTensor.
888 ///
889 /// # Arguments
890 ///
891 /// * `lhs` - The left-hand side tensor.
892 /// * `rhs` - The right-hand side tensor.
893 ///
894 /// # Returns
895 ///
896 /// The elements of `lhs` raised to the power of the elements of `rhs`.
897 fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
898
899 /// Element-wise power with an IntTensor.
900 ///
901 /// # Arguments
902 ///
903 /// * `lhs` - The left-hand side tensor.
904 /// * `rhs` - The right-hand side floatTensor.
905 ///
906 /// # Returns
907 ///
908 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
909 fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
910 Self::float_powf(lhs, B::int_into_float(rhs))
911 }
912
913 /// Raises a tensor to the power of an int scalar.
914 ///
915 /// # Backend Implementors Note
916 ///
917 /// A number of common exponent cases can be implemented with operations
918 /// which are much cheaper than generic exponentiation.
919 ///
920 /// This (`Backend` impl overridable) operation handles generic optimizations
921 /// for several common integer exponent cases; and then dispatches to
922 /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
923 /// operation to handle the generic case.
924 ///
925 /// # Arguments
926 ///
927 /// * `lhs` - The left-hand side tensor.
928 /// * `rhs` - The right-hand side scalar.
929 ///
930 /// # Returns
931 ///
932 /// The elements of `lhs` raised to the value of `rhs`.
933 fn float_powi_scalar(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
934 let exp = rhs.elem::<i32>();
935 match exp {
936 0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
937 1 => lhs,
938 2 => B::float_mul(lhs.clone(), lhs),
939 -1 => Self::float_recip(lhs),
940 -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
941 _ => Self::float_powi_scalar_impl(lhs, rhs),
942 }
943 }
944
945 /// Raises a tensor to the power of an int scalar.
946 ///
947 /// # Backend Implementors Note
948 ///
949 /// This is the generic implementation of integer exponentiation
950 /// called by [`Self::float_powi_scalar`] in the fallback case.
951 ///
952 /// As a general rule, this should not be called directly.
953 ///
954 /// # Arguments
955 ///
956 /// * `lhs` - The left-hand side tensor.
957 /// * `rhs` - The right-hand side scalar.
958 ///
959 /// # Returns
960 ///
961 /// The elements of `lhs` raised to the value of `rhs`.
962 fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
963 // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.
964 Self::float_powf_scalar_impl(lhs, rhs.elem::<f32>())
965 }
966
967 /// Returns a new tensor with values raised to the power of float `value`.
968 ///
969 /// # Backend Implementors Note
970 ///
971 /// This (`Backend` impl overridable) operation dispatches integer exponentiation
972 /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
973 /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
974 /// operation to handle the generic case.
975 ///
976 /// # Arguments
977 ///
978 /// * `tensor` - The tensor to exponentiate.
979 /// * `value` - The exponent.
980 ///
981 /// # Returns
982 ///
983 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
984 fn float_powf_scalar(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B> {
985 if num_traits::Float::floor(value) == value {
986 // When the exponent is an integer, use the integer exponentiation implementation.
987 let exp = B::IntElem::from_elem(value as i32);
988 Self::float_powi_scalar(tensor, exp)
989 } else {
990 Self::float_powf_scalar_impl(tensor, value)
991 }
992 }
993
994 /// Returns a new tensor with values raised to the power of float `value`.
995 ///
996 /// # Backend Implementors Note
997 ///
998 /// This is the generic implementation of integer exponentiation
999 /// called by [`Self::float_powf_scalar`] in the fallback case.
1000 ///
1001 /// This is the minimal required support a `Backend` must implement
1002 /// for exponentiation.
1003 ///
1004 /// As a general rule, this should not be called directly.
1005 ///
1006 /// # Arguments
1007 ///
1008 /// * `tensor` - The tensor to exponentiate.
1009 /// * `value` - The exponent.
1010 ///
1011 /// # Returns
1012 ///
1013 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
1014 fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B>;
1015
    /// Returns a new tensor with element-wise square root values.
    ///
    /// This is a required operation: the trait provides no default body, so every
    /// backend must implement it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the square root of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with square root values.
    /// NOTE(review): the result for negative inputs is left to the backend (IEEE `sqrt`
    /// would yield NaN) — confirm against backend implementations.
    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
1026
    /// Returns a new tensor with element-wise absolute values.
    ///
    /// This is a required operation with no default body. Several default implementations
    /// in this trait build on it (for example `float_max_abs_dim` and `float_is_inf`).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take absolute value of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with absolute values.
    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1037
    /// Returns a new tensor with element-wise cosine values.
    ///
    /// This is a required operation with no default body. The default `float_tan`
    /// uses it as the denominator of `sin / cos`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the cosine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with cosine values.
    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1048
    /// Returns a new tensor with element-wise sine values.
    ///
    /// This is a required operation with no default body. The default `float_tan`
    /// uses it as the numerator of `sin / cos`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the sine of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with sine values.
    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1059
1060 /// Returns a new tensor with tangent values.
1061 ///
1062 /// # Arguments
1063 ///
1064 /// * `tensor` - The tensor to take the tangent of.
1065 ///
1066 /// # Returns
1067 ///
1068 /// A tensor with the same shape as `tensor` with tangent values.
1069 fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B> {
1070 let sin = B::float_sin(tensor.clone());
1071 let cos = B::float_cos(tensor);
1072 B::float_div(sin, cos)
1073 }
1074
1075 /// Returns a new tensor with hyperbolic cosine values.
1076 ///
1077 /// # Arguments
1078 ///
1079 /// * `tensor` - The tensor to take the hyperbolic cosine of.
1080 ///
1081 /// # Returns
1082 ///
1083 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
1084 fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1085 // cosh = ( e^x + e^(-x) ) / 2
1086 let e_x = B::float_exp(tensor.clone());
1087 let e_neg_x = B::float_exp(B::float_neg(tensor));
1088 let num = B::float_add(e_x, e_neg_x); // e^x + e^(-x)
1089 B::float_div_scalar(num, 2.0.elem())
1090 }
1091
1092 /// Returns a new tensor with hyperbolic sine values.
1093 ///
1094 /// # Arguments
1095 ///
1096 /// * `tensor` - The tensor to take the hyperbolic sine of.
1097 ///
1098 /// # Returns
1099 ///
1100 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
1101 fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1102 // sinh = ( e^x - e^(-x) ) / 2
1103 let e_x = B::float_exp(tensor.clone());
1104 let e_neg_x = B::float_exp(B::float_neg(tensor));
1105 let num = B::float_sub(e_x, e_neg_x); // e^x - e^(-x)
1106 B::float_div_scalar(num, 2.0.elem())
1107 }
1108
1109 /// Returns a new tensor with hyperbolic tangent values.
1110 ///
1111 /// # Arguments
1112 ///
1113 /// * `tensor` - The tensor to take the hyperbolic tangent of.
1114 ///
1115 /// # Returns
1116 ///
1117 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
1118 fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B> {
1119 let sinh = B::float_sinh(tensor.clone());
1120 let cosh = B::float_cosh(tensor);
1121 B::float_div(sinh, cosh)
1122 }
1123
    /// Returns a new tensor with element-wise rounded values.
    ///
    /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
    /// strategy, with halfway cases rounded to the nearest even integer value
    /// (e.g. `0.5 -> 0`, `1.5 -> 2`, `2.5 -> 2`).
    ///
    /// This is a required operation with no default body.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be rounded.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with rounded values.
    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
1137
    /// Returns a new tensor with floored values, i.e. each element rounded toward
    /// negative infinity.
    ///
    /// This is a required operation with no default body.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be floored.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with floored values.
    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1148
    /// Returns a new tensor with ceiled values, i.e. each element rounded toward
    /// positive infinity.
    ///
    /// This is a required operation with no default body.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be ceiled.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with ceiled values.
    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1159
    /// Returns a new tensor with truncated values, i.e. each element with its fractional
    /// part discarded (rounded toward zero).
    ///
    /// This is a required operation with no default body.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to be truncated.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with truncated values.
    fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
1170
    /// Returns a new tensor with the element-wise error function values,
    /// `erf(x) = 2/sqrt(pi) * integral from 0 to x of e^(-t^2) dt`.
    ///
    /// This is a required operation with no default body.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to take the error function of.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor` with error function values.
    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1181
1182 /// Concatenates tensors along a dimension.
1183 ///
1184 /// # Arguments
1185 ///
1186 /// * `tensors` - The tensors to concatenate.
1187 /// * `dim` - The dimension along which to concatenate.
1188 ///
1189 /// # Returns
1190 ///
1191 /// A tensor with the concatenated tensors along `dim`.
1192 fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1193 cat_with_slice_assign::<B, Float>(
1194 tensors.into_iter().map(TensorPrimitive::Float).collect(),
1195 dim,
1196 )
1197 .tensor()
1198 }
1199
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// This is a required operation with no default body. The default `float_max_dim`
    /// and `float_max_dim_with_indices` implementations gather values using the indices
    /// it returns.
    ///
    /// NOTE(review): tie-breaking between equal maxima and NaN handling are not specified
    /// by this signature — confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements of.
    /// * `dim` - The dimension along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.
    fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1211
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// This is a required operation with no default body. The default `float_min_dim`
    /// and `float_min_dim_with_indices` implementations gather values using the indices
    /// it returns.
    ///
    /// NOTE(review): tie-breaking between equal minima and NaN handling are not specified
    /// by this signature — confirm per backend.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements of.
    /// * `dim` - The dimension along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.
    fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1223
1224 /// Gets the maximum element of a tensor.
1225 ///
1226 /// # Arguments
1227 ///
1228 /// * `tensor` - The tensor to get the maximum elements of.
1229 ///
1230 /// # Returns
1231 ///
1232 /// A tensor with the maximum element of `tensor`.
1233 fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1234 let shape = tensor.shape();
1235 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1236
1237 B::float_max_dim(tensor, 0)
1238 }
1239
1240 /// Gets the maximum elements of a tensor along an axis.
1241 ///
1242 /// # Arguments
1243 ///
1244 /// * `tensor` - The tensor to get the maximum elements of.
1245 /// * `dim` - The dimension along which to get the maximum elements.
1246 ///
1247 /// # Returns
1248 ///
1249 /// A tensor with the maximum elements of `tensor` along `dim`.
1250 fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1251 let index = B::float_argmax(tensor.clone(), dim);
1252
1253 B::float_gather(dim, tensor, index)
1254 }
1255
1256 /// Gets the maximum elements of a tensor along an axis and their indices.
1257 ///
1258 /// # Arguments
1259 ///
1260 /// * `tensor` - The tensor to get the maximum elements of.
1261 /// * `dim` - The dimension along which to get the maximum elements.
1262 ///
1263 /// # Returns
1264 ///
1265 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1266 fn float_max_dim_with_indices(
1267 tensor: FloatTensor<B>,
1268 dim: usize,
1269 ) -> (FloatTensor<B>, IntTensor<B>) {
1270 let index = B::float_argmax(tensor.clone(), dim);
1271 let values = B::float_gather(dim, tensor, index.clone());
1272
1273 (values, index)
1274 }
1275
1276 /// Gets the minimum element of a tensor.
1277 ///
1278 /// # Arguments
1279 ///
1280 /// * `tensor` - The tensor to get the minimum elements of.
1281 ///
1282 /// # Returns
1283 ///
1284 /// A tensor with the minimum element of `tensor`.
1285 fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1286 let shape = tensor.shape();
1287 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1288
1289 B::float_min_dim(tensor, 0)
1290 }
1291
1292 /// Gets the minimum elements of a tensor along an axis.
1293 ///
1294 /// # Arguments
1295 ///
1296 /// * `tensor` - The tensor to get the minimum elements of.
1297 /// * `dim` - The dimension along which to get the minimum elements.
1298 ///
1299 /// # Returns
1300 ///
1301 /// A tensor with the minimum elements of `tensor` along `dim`.
1302 fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1303 let index = B::float_argmin(tensor.clone(), dim);
1304
1305 B::float_gather(dim, tensor, index)
1306 }
1307
1308 /// Gets the minimum elements of a tensor along an axis and their indices.
1309 ///
1310 /// # Arguments
1311 ///
1312 /// * `tensor` - The tensor to get the minimum elements of.
1313 /// * `dim` - The dimension along which to get the minimum elements.
1314 ///
1315 /// # Returns
1316 ///
1317 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1318 fn float_min_dim_with_indices(
1319 tensor: FloatTensor<B>,
1320 dim: usize,
1321 ) -> (FloatTensor<B>, IntTensor<B>) {
1322 let index = B::float_argmin(tensor.clone(), dim);
1323 let values = B::float_gather(dim, tensor, index.clone());
1324
1325 (values, index)
1326 }
1327
1328 /// Gets the maximum absolute element of a tensor.
1329 ///
1330 /// # Arguments
1331 ///
1332 /// * `tensor` - The tensor to get the maximum elements of.
1333 ///
1334 /// # Returns
1335 ///
1336 /// A tensor with the maximum element of `tensor`.
1337 fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1338 let shape = tensor.shape();
1339 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1340
1341 B::float_max_abs_dim(tensor, 0)
1342 }
1343
1344 /// Gets the maximum absolute elements of a tensor along an axis.
1345 ///
1346 /// # Arguments
1347 ///
1348 /// * `tensor` - The tensor to get the maximum elements of.
1349 /// * `dim` - The dimension along which to get the maximum elements.
1350 ///
1351 /// # Returns
1352 ///
1353 /// A tensor with the maximum elements of `tensor` along `dim`.
1354 fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1355 B::float_max_dim(B::float_abs(tensor), dim)
1356 }
1357
1358 /// Tests if any element in the float `tensor` evaluates to True.
1359 ///
1360 /// # Arguments
1361 ///
1362 /// * `tensor` - The tensor to test.
1363 ///
1364 /// # Returns
1365 ///
1366 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1367 fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {
1368 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1369 let bool_tensor = B::bool_not(bool_tensor);
1370 let sum = B::float_sum(B::bool_into_float(bool_tensor));
1371 B::float_greater_elem(sum, 0.0f32.elem())
1372 }
1373
1374 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1375 ///
1376 /// # Arguments
1377 ///
1378 /// * `tensor` - The tensor to test.
1379 /// * `dim` - The axis along which to test.
1380 ///
1381 /// # Returns
1382 ///
1383 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1384 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1385 /// input evaluates to True, False otherwise.
1386 fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1387 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1388 let bool_tensor = B::bool_not(bool_tensor);
1389 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1390 B::float_greater_elem(sum, 0.0f32.elem())
1391 }
1392
1393 /// Tests if all elements in the float `tensor` evaluate to True.
1394 ///
1395 /// # Arguments
1396 ///
1397 /// * `tensor` - The tensor to test.
1398 ///
1399 /// # Returns
1400 ///
1401 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1402 /// evaluate to True, False otherwise.
1403 fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
1404 let num_elems = tensor.shape().num_elements();
1405 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1406 let bool_tensor = B::bool_not(bool_tensor);
1407 let sum = B::float_sum(B::bool_into_float(bool_tensor));
1408 B::float_equal_elem(sum, (num_elems as f32).elem())
1409 }
1410
1411 /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1412 ///
1413 /// # Arguments
1414 ///
1415 /// * `tensor` - The tensor to test.
1416 /// * `dim` - The axis along which to test.
1417 ///
1418 /// # Returns
1419 ///
1420 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1421 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1422 /// evaluates to True, False otherwise.
1423 fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1424 let num_elems = tensor.shape().dims[dim];
1425 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1426 let bool_tensor = B::bool_not(bool_tensor);
1427 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1428 B::float_equal_elem(sum, (num_elems as f32).elem())
1429 }
1430
1431 /// Returns the signs of the float `tensor`.
1432 ///
1433 /// # Arguments
1434 ///
1435 /// * `tensor` - The tensor to extract the signs from.
1436 ///
1437 /// # Returns
1438 ///
1439 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1440 fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1441 let zeros = B::float_zeros(
1442 tensor.shape(),
1443 &B::float_device(&tensor),
1444 tensor.dtype().into(),
1445 );
1446 let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
1447 let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
1448
1449 let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1450 result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
1451 result
1452 }
1453
    /// Broadcasts the float `tensor` to the given `shape`.
    ///
    /// This is a required operation with no default body.
    ///
    /// NOTE(review): which input/target shape combinations are accepted (e.g. expanding
    /// only size-1 axes) is not visible here — confirm against backend implementations.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to broadcast.
    /// * `shape` - The target shape.
    ///
    /// # Returns
    ///
    /// A tensor with shape `shape` holding the broadcast values of `tensor`.
    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1456
1457 /// Sort the elements of the input `tensor` by value in along a given dimension.
1458 ///
1459 /// This sort is unstable (i.e., may reorder equal elements).
1460 ///
1461 /// # Arguments
1462 ///
1463 /// * `tensor` - The input tensor.
1464 /// * `dim` - The axis along which to sort.
1465 /// * `descending` - The sorting order.
1466 ///
1467 /// # Returns
1468 ///
1469 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1470 fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1471 sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1472 }
1473
1474 /// Sort the elements of the input `tensor` by value in along a given dimension.
1475 ///
1476 /// This sort is unstable (i.e., may reorder equal elements).
1477 ///
1478 /// # Arguments
1479 ///
1480 /// * `tensor` - The input tensor.
1481 /// * `dim` - The axis along which to sort.
1482 /// * `descending` - The sorting order.
1483 ///
1484 /// # Returns
1485 ///
1486 /// A tensor with the same shape as the input tensor and corresponding indices, where
1487 /// the elements are sorted by value and the indices map back to the original input tensor.
1488 fn float_sort_with_indices(
1489 tensor: FloatTensor<B>,
1490 dim: usize,
1491 descending: bool,
1492 ) -> (FloatTensor<B>, IntTensor<B>) {
1493 let (values, indices) =
1494 sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
1495 (values.tensor(), indices)
1496 }
1497
1498 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1499 ///
1500 /// This sort is unstable (i.e., may reorder equal elements).
1501 ///
1502 /// # Arguments
1503 ///
1504 /// * `tensor` - The input tensor.
1505 /// * `dim` - The axis along which to sort.
1506 /// * `descending` - The sorting order.
1507 ///
1508 /// # Returns
1509 ///
1510 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1511 fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1512 argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
1513 }
1514
1515 /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
1516 /// using the given locations in [-1, 1].
1517 ///
1518 /// Interpolation is bilinear.
1519 /// Padding is border: out of bounds locations will be clamped to the nearest border
1520 ///
1521 /// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in)
1522 /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
1523 /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
1524 /// * `method` - How to interpolate between samples
1525 ///
1526 /// # Returns
1527 ///
1528 /// A tensor with shape (N, C, H_out, W_out)
1529 fn float_grid_sample_2d(
1530 tensor: FloatTensor<B>,
1531 grid: FloatTensor<B>,
1532 method: InterpolateMode,
1533 ) -> FloatTensor<B> {
1534 match method {
1535 InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>(tensor, grid),
1536 _ => todo!("Default implementation for grid_sample_2d with {method:?} unimplemented"),
1537 }
1538 }
1539
1540 /// Unfold windows along a dimension.
1541 ///
1542 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
1543 /// where windows are advanced by `step` at each index.
1544 ///
1545 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
1546 ///
1547 /// # Arguments
1548 ///
1549 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
1550 /// * `dim` - the selected dim.
1551 /// * `size` - the size of each unfolded window.
1552 /// * `step` - the step between each window.
1553 ///
1554 /// # Returns
1555 ///
1556 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
1557 fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
1558 -> FloatTensor<B>;
1559
1560 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1561 ///
1562 /// # Returns
1563 ///
1564 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1565 fn float_is_nan(tensor: FloatTensor<B>) -> BoolTensor<B> {
1566 // Check if the input tensor is NaN by comparing it to itself
1567 // NaN is the only value that is not equal to itself
1568 B::float_not_equal(tensor.clone(), tensor)
1569 }
1570
1571 /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1572 ///
1573 /// # Returns
1574 ///
1575 /// A boolean tensor where `true` indicates that the value is infinite
1576 fn float_is_inf(tensor: FloatTensor<B>) -> BoolTensor<B> {
1577 B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.elem())
1578 }
1579}