burn_backend/backend/ops/tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_ref;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::sort::{argsort, sort, sort_with_indices};
5use crate::ops::GridSampleOptions;
6use crate::tensor::{BoolTensor, Device, Float, FloatTensor, IntTensor};
7use crate::{Backend, Distribution, TensorData, get_device_settings};
8use crate::{ExecutionError, Scalar, TensorMetadata, TensorPrimitive};
9use alloc::vec::Vec;
10use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
11
12/// Operations on float tensors.
13pub trait FloatTensorOps<B: Backend> {
14 /// Creates a new tensor from the data structure.
15 ///
16 /// # Arguments
17 ///
18 /// * `data` - The data structure.
19 /// * `device` - The device to create the tensor on.
20 ///
21 /// # Returns
22 ///
23 /// The tensor with the given data.
24 fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
25
26 /// Creates a new tensor with random values.
27 ///
28 /// # Arguments
29 ///
30 /// * `shape` - The shape of the tensor.
31 /// * `distribution` - The distribution to sample from.
32 /// * `device` - The device to create the tensor on.
33 /// * `dtype` - The target data type.
34 ///
35 /// # Returns
36 ///
37 /// The tensor with the given shape and random values.
38 fn float_random(
39 shape: Shape,
40 distribution: Distribution,
41 device: &Device<B>,
42 dtype: FloatDType,
43 ) -> FloatTensor<B>;
44
45 /// Creates a new tensor with zeros.
46 ///
47 /// # Arguments
48 ///
49 /// * `shape` - The shape of the tensor.
50 /// * `device` - The device to create the tensor on.
51 /// * `dtype` - The target data type.
52 ///
53 /// # Returns
54 ///
55 /// The tensor with the given shape and zeros.
56 fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
57 Self::float_from_data(TensorData::full_dtype(shape, 0., dtype.into()), device)
58 }
59
60 /// Creates a new tensor with ones.
61 ///
62 /// # Arguments
63 ///
64 /// * `shape` - The shape of the tensor.
65 /// * `device` - The device to create the tensor on.
66 /// * `dtype` - The target data type.
67 ///
68 /// # Returns
69 ///
70 /// The tensor with the given shape and ones.
71 fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
72 Self::float_from_data(TensorData::full_dtype(shape, 1., dtype.into()), device)
73 }
74
75 /// Creates a tensor filled with given value.
76 ///
77 /// # Arguments
78 ///
79 /// * `shape` - The shape of the tensor.
80 /// * `fill_value` - The value with which to fill the tensor.
81 /// * `device` - The device to create the tensor on.
82 /// * `dtype` - The target data type.
83 ///
84 /// # Returns
85 ///
86 /// The tensor filled with given value
87 fn float_full(
88 shape: Shape,
89 fill_value: Scalar,
90 device: &Device<B>,
91 dtype: FloatDType,
92 ) -> FloatTensor<B> {
93 Self::float_from_data(
94 TensorData::full_dtype(shape, fill_value, dtype.into()),
95 device,
96 )
97 }
98
99 /// Converts the tensor to a data structure.
100 ///
101 /// # Arguments
102 ///
103 /// * `tensor` - The tensor.
104 ///
105 /// # Returns
106 ///
107 /// The data structure with the tensor's data.
108 fn float_into_data(
109 tensor: FloatTensor<B>,
110 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
111
112 /// Gets the device of the tensor.
113 ///
114 /// # Arguments
115 ///
116 /// * `tensor` - The tensor.
117 ///
118 /// # Returns
119 ///
120 /// The device of the tensor.
121 fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
122
123 /// Moves the tensor to the given device.
124 ///
125 /// # Arguments
126 ///
127 /// * `tensor` - The tensor.
128 /// * `device` - The device to move the tensor to.
129 ///
130 /// # Returns
131 ///
132 /// The tensor on the given device.
133 fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
134
135 /// Converts float tensor to int tensor.
136 ///
137 /// # Arguments
138 ///
139 /// * `tensor` - The tensor.
140 /// * `out_dtype` - The output tensor dtype.
141 ///
142 /// # Returns
143 ///
144 /// The int tensor with the same data as the float tensor.
145 fn float_into_int(tensor: FloatTensor<B>, out_dtype: IntDType) -> IntTensor<B>;
146
147 /// Creates an empty tensor with the given shape.
148 ///
149 /// # Arguments
150 ///
151 /// * `shape` - The shape of the tensor.
152 /// * `device` - The device to create the tensor on.
153 /// * `dtype` - The target data type.
154 ///
155 /// # Returns
156 ///
157 /// The empty tensor with the given shape.
158 fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
159
160 /// Repeat the tensor along the given dimension.
161 ///
162 /// # Arguments
163 ///
164 /// * `tensor` - The tensor.
165 /// * `dim` - The dimension to repeat.
166 /// * `times` - The number of times to repeat the dimension.
167 ///
168 /// # Returns
169 ///
170 /// The tensor with the given dimension repeated.
171 fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
172 repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
173 }
174
175 /// Adds two tensors together.
176 ///
177 /// # Arguments
178 ///
179 /// * `lhs` - The left-hand side tensor.
180 /// * `rhs` - The right-hand side tensor.
181 ///
182 /// # Returns
183 ///
184 /// The result of adding the two tensors together.
185 fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
186
187 /// Adds a scalar to a tensor.
188 ///
189 /// # Arguments
190 ///
191 /// * `lhs` - The left-hand side tensor.
192 /// * `rhs` - The right-hand side scalar.
193 ///
194 /// # Returns
195 ///
196 /// The result of adding the scalar to the tensor.
197 fn float_add_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
198
199 /// Clamps a tensor under a minimum value.
200 ///
201 /// # Arguments
202 ///
203 /// * `tensor` - The tensor to clamp.
204 /// * `min` - The minimum value.
205 ///
206 /// # Returns
207 ///
208 /// The clamped tensor.
209 fn float_clamp_min(tensor: FloatTensor<B>, min: Scalar) -> FloatTensor<B> {
210 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
211 let mask = Self::float_lower_elem(tensor.clone(), min, dtype);
212 B::float_mask_fill(tensor, mask, min)
213 }
214
215 /// Clamps a tensor over a maximum value.
216 ///
217 /// # Arguments
218 ///
219 /// * `tensor` - The tensor to clamp.
220 /// * `max` - The maximum value.
221 ///
222 /// # Returns
223 ///
224 /// The clamped tensor.
225 fn float_clamp_max(tensor: FloatTensor<B>, max: Scalar) -> FloatTensor<B> {
226 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
227 let mask = Self::float_greater_elem(tensor.clone(), max, dtype);
228 B::float_mask_fill(tensor, mask, max)
229 }
230
231 /// Clamps a tensor between a minimum and maximum value.
232 ///
233 /// # Arguments
234 ///
235 /// * `tensor` - The tensor to clamp.
236 /// * `min` - The minimum value.
237 /// * `max` - The maximum value.
238 ///
239 /// # Returns
240 ///
241 /// The clamped tensor.
242 fn float_clamp(tensor: FloatTensor<B>, min: Scalar, max: Scalar) -> FloatTensor<B> {
243 // Default implementation
244 Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
245 }
246
247 /// Subtracts two tensors.
248 ///
249 /// # Arguments
250 ///
251 /// * `lhs` - The left-hand side tensor.
252 /// * `rhs` - The right-hand side tensor.
253 ///
254 /// # Returns
255 ///
256 /// The result of subtracting the two tensors.
257 fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
258
259 /// Subtracts a scalar from a tensor.
260 ///
261 /// # Arguments
262 ///
263 /// * `lhs` - The left-hand side tensor.
264 /// * `rhs` - The right-hand side scalar.
265 ///
266 /// # Returns
267 ///
268 /// The result of subtracting the scalar from the tensor.
269 fn float_sub_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
270
271 /// Multiplies two tensors together element-wise.
272 fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
273
274 /// Multiplies a tensor by a scalar.
275 ///
276 /// # Arguments
277 ///
278 /// * `lhs` - The left-hand side tensor.
279 /// * `rhs` - The right-hand side scalar.
280 ///
281 /// # Returns
282 ///
283 /// The result of multiplying the tensor by the scalar.
284 fn float_mul_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
285
286 /// Divides two tensors element-wise.
287 ///
288 /// # Arguments
289 ///
290 /// * `lhs` - The left-hand side tensor.
291 /// * `rhs` - The right-hand side tensor.
292 ///
293 /// # Returns
294 ///
295 /// The result of dividing the two tensors.
296 fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
297
298 /// Divides a tensor by a scalar.
299 ///
300 /// # Arguments
301 ///
302 /// * `lhs` - The left-hand side tensor.
303 /// * `rhs` - The right-hand side scalar.
304 ///
305 /// # Returns
306 ///
307 /// The result of dividing the tensor by the scalar.
308 fn float_div_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
309
310 /// Computes the remainder of division between two tensors element-wise.
311 ///
312 /// # Arguments
313 ///
314 /// * `lhs` - The left-hand side tensor.
315 /// * `rhs` - The right-hand side tensor.
316 ///
317 /// # Returns
318 ///
319 /// The element-wise remainder when dividing `lhs` by `rhs`.
320 fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
321
322 /// Computes the modulus of a tensor given a scalar.
323 ///
324 /// # Arguments
325 /// * `lhs` - The left-hand side tensor.
326 /// * `rhs` - The right-hand side scalar.
327 ///
328 /// # Returns
329 ///
330 /// The result of applying the modulus of the scalar to the tensor.
331 fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
332
333 /// Multiplies two tensors together using matrix multiplication.
334 ///
335 /// # Arguments
336 ///
337 /// * `lhs` - The left-hand side tensor.
338 /// * `rhs` - The right-hand side tensor.
339 ///
340 /// # Returns
341 ///
342 /// The result of multiplying the two tensors together using matrix multiplication.
343 fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
344
345 /// Computes the cross product of two tensors along a given dimension.
346 ///
347 /// # Arguments
348 ///
349 /// * `lhs` - The left-hand side tensor.
350 /// * `rhs` - The right-hand side tensor.
351 /// * `dim` - The dimension to compute the cross product along.
352 ///
353 /// # Returns
354 ///
355 /// The cross product of the two tensors.
356 fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
357
358 /// Negates a tensor element-wise.
359 fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
360 Self::float_mul_scalar(tensor, (-1f32).into())
361 }
362
363 /// Calculates the reciprocals element-wise
364 fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
365
366 /// Transposes a tensor.
367 ///
368 /// # Arguments
369 ///
370 /// * `tensor` - The tensor to transpose.
371 ///
372 /// # Returns
373 ///
374 /// The transposed tensor.
375 fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
376 let ndims = tensor.shape().num_dims();
377 Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
378 }
379
380 /// Swaps two dimensions of a tensor.
381 ///
382 /// # Arguments
383 ///
384 /// * `tensor` - The tensor to swap the dimensions of.
385 /// * `dim1` - The first dimension to swap.
386 /// * `dim2` - The second dimension to swap.
387 ///
388 /// # Returns
389 ///
390 /// The tensor with the dimensions swapped.
391 fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
392
393 /// Permutes the dimensions of a tensor.
394 ///
395 /// # Arguments
396 ///
397 /// * `tensor` - The tensor to permute the dimensions of.
398 /// * `axes` - The new order of the dimensions.
///
/// # Returns
400 ///
401 /// The tensor with the dimensions permuted.
402 fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
403
404 /// Reverse the order of elements in a tensor along the given axes.
405 ///
406 /// # Arguments
407 ///
408 /// * `tensor` - The tensor to reverse.
409 /// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
412 fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
413
414 /// Reshapes a tensor.
415 ///
416 /// # Arguments
417 ///
418 /// * `tensor` - The tensor to reshape.
419 /// * `shape` - The new shape of the tensor.
420 ///
421 /// # Returns
422 ///
423 /// The tensor with the new shape.
424 fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
425
426 /// Gather elements from a tensor.
427 ///
428 /// # Arguments
429 ///
430 /// * `dim` - The dimension to gather from.
431 /// * `tensor` - The tensor to gather from.
432 /// * `indices` - The indices to gather.
433 ///
434 /// # Returns
435 ///
436 /// The gathered elements.
437 fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
438
439 /// Scatter elements into a tensor using sum reduction.
440 ///
441 /// # Arguments
442 ///
443 /// * `dim` - The dimension to scatter into.
444 /// * `tensor` - The tensor to scatter into.
445 /// * `indices` - The indices to scatter into.
446 /// * `value` - The value to scatter.
447 ///
448 /// # Returns
449 ///
450 /// The tensor with the scattered elements.
451 fn float_scatter_add(
452 dim: usize,
453 tensor: FloatTensor<B>,
454 indices: IntTensor<B>,
455 value: FloatTensor<B>,
456 ) -> FloatTensor<B>;
457
458 /// Multi-dimensional scatter: update `data` at locations specified by `indices` with `values`.
459 ///
460 /// # Arguments
461 ///
462 /// * `data` - The tensor to scatter into.
463 /// * `indices` - An M-dimensional integer tensor whose last dimension indexes into `data`.
464 /// * `values` - The values to scatter.
465 /// * `reduction` - How to combine with existing values.
466 ///
467 /// # Returns
468 ///
469 /// The tensor with scattered values.
470 fn float_scatter_nd(
471 _data: FloatTensor<B>,
472 _indices: IntTensor<B>,
473 _values: FloatTensor<B>,
474 _reduction: crate::tensor::IndexingUpdateOp,
475 ) -> FloatTensor<B> {
476 unimplemented!("float_scatter_nd is not implemented for this backend")
477 }
478
479 /// Multi-dimensional gather: collect slices from `data` at locations specified by `indices`.
480 ///
481 /// # Arguments
482 ///
483 /// * `data` - The tensor to gather from.
484 /// * `indices` - An M-dimensional integer tensor whose last dimension indexes into `data`.
485 ///
486 /// # Returns
487 ///
488 /// The gathered tensor.
489 fn float_gather_nd(_data: FloatTensor<B>, _indices: IntTensor<B>) -> FloatTensor<B> {
490 unimplemented!("float_gather_nd is not implemented for this backend")
491 }
492
/// Select tensor elements along the given dimension corresponding to the given indices.
494 ///
495 /// # Arguments
496 ///
497 /// * `tensor` - The tensor to select from.
498 /// * `dim` - The dimension to select from.
499 /// * `indices` - The indices to select.
500 ///
501 /// # Returns
502 ///
503 /// The selected elements.
504 fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
505
/// Assign the selected elements along the given dimension corresponding to the given indices
507 /// to the given value using sum reduction.
508 ///
509 /// # Arguments
510 ///
511 /// * `tensor` - The tensor to select from.
512 /// * `dim` - The dimension to select from.
513 /// * `indices` - The indices to select.
514 /// * `value` - The value to assign.
515 ///
516 /// # Returns
517 ///
518 /// The tensor with the selected elements assigned to the given value.
519 fn float_select_add(
520 tensor: FloatTensor<B>,
521 dim: usize,
522 indices: IntTensor<B>,
523 value: FloatTensor<B>,
524 ) -> FloatTensor<B>;
525
526 /// Select tensor elements corresponding to the given slices.
527 ///
528 /// # Arguments
529 ///
530 /// * `tensor` - The tensor to select from.
531 /// * `slices` - The slices specifying ranges and steps for each dimension.
532 ///
533 /// # Returns
534 ///
535 /// The selected elements in a new tensor.
536 ///
537 /// # Note
538 ///
539 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
540 /// be passed to this method. Backend implementations do not need to handle empty slices.
541 fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;
542
543 /// Assign the selected elements corresponding to the given slices to the given value.
544 ///
545 /// # Arguments
546 ///
547 /// * `tensor` - The tensor to select from.
/// * `slices` - The slices specifying ranges and steps for each dimension.
549 /// * `value` - The value to assign.
550 ///
551 /// # Returns
552 ///
553 /// The tensor with the selected elements assigned to the given value.
554 ///
555 /// # Note
556 ///
557 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
558 /// high-level tensor API and will not be passed to this method. Backend implementations do
559 /// not need to handle empty slice assignments.
560 fn float_slice_assign(
561 tensor: FloatTensor<B>,
562 slices: &[Slice],
563 value: FloatTensor<B>,
564 ) -> FloatTensor<B>;
565
566 /// Update the given tensor with the value tensor where the mask is true.
567 ///
568 /// # Arguments
569 ///
570 /// * `tensor` - The tensor to select from.
571 /// * `mask` - The boolean mask to select with.
572 /// * `value` - The value to assign to the selected elements from the value tensor.
573 ///
574 /// # Returns
575 ///
576 /// The tensor with the selected elements assigned to the given value.
577 fn float_mask_where(
578 tensor: FloatTensor<B>,
579 mask: BoolTensor<B>,
580 value: FloatTensor<B>,
581 ) -> FloatTensor<B>;
582
583 /// Update the given tensor with the value where the mask is true.
584 ///
585 /// # Arguments
586 ///
587 /// * `tensor` - The tensor to select from.
588 /// * `mask` - The boolean mask to select with.
589 /// * `value` - The value to assign to the selected elements.
590 ///
591 /// # Returns
592 ///
593 /// The tensor with the selected elements assigned to the given value.
594 fn float_mask_fill(
595 tensor: FloatTensor<B>,
596 mask: BoolTensor<B>,
597 value: Scalar,
598 ) -> FloatTensor<B>;
599
600 /// Equal comparison of two tensors.
601 ///
602 /// # Arguments
603 ///
604 /// * `lhs` - The left-hand side tensor.
605 /// * `rhs` - The right-hand side tensor.
606 /// * `out_dtype` - The output tensor dtype.
607 ///
608 /// # Returns
609 ///
610 /// A boolean tensor with the result of the comparison.
611 fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
612 -> BoolTensor<B>;
613
614 /// Element-wise non-equality comparison.
615 ///
616 /// # Arguments
617 ///
618 /// * `lhs` - The left-hand side tensor.
619 /// * `rhs` - The right-hand side tensor.
620 /// * `out_dtype` - The output tensor dtype.
621 ///
622 /// # Returns
623 ///
624 /// A boolean tensor with the result of the comparison.
625 fn float_not_equal(
626 lhs: FloatTensor<B>,
627 rhs: FloatTensor<B>,
628 out_dtype: BoolDType,
629 ) -> BoolTensor<B> {
630 let equal_tensor = B::float_equal(lhs, rhs, out_dtype);
631 B::bool_not(equal_tensor)
632 }
633
634 /// Equal comparison of a tensor and a scalar.
635 ///
636 /// # Arguments
637 ///
638 /// * `lhs` - The left-hand side tensor.
639 /// * `rhs` - The right-hand side scalar.
640 /// * `out_dtype` - The output tensor dtype.
641 ///
642 /// # Returns
643 ///
644 /// A boolean tensor with the result of the comparison.
645 fn float_equal_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
646
647 /// Element-wise non-equality comparison with a scalar.
648 ///
649 /// # Arguments
650 ///
651 /// * `lhs` - The left-hand side tensor.
652 /// * `rhs` - The right-hand side scalar.
653 /// * `out_dtype` - The output tensor dtype.
654 ///
655 /// # Returns
656 ///
657 /// A boolean tensor with the result of the comparison.
658 fn float_not_equal_elem(
659 lhs: FloatTensor<B>,
660 rhs: Scalar,
661 out_dtype: BoolDType,
662 ) -> BoolTensor<B> {
663 let equal_tensor = B::float_equal_elem(lhs, rhs, out_dtype);
664 B::bool_not(equal_tensor)
665 }
666
667 /// Greater than comparison of two tensors.
668 ///
669 /// # Arguments
670 ///
671 /// * `lhs` - The left-hand side tensor.
672 /// * `rhs` - The right-hand side tensor.
673 /// * `out_dtype` - The output tensor dtype.
674 ///
675 /// # Returns
676 ///
677 /// A boolean tensor with the result of the comparison.
678 fn float_greater(
679 lhs: FloatTensor<B>,
680 rhs: FloatTensor<B>,
681 out_dtype: BoolDType,
682 ) -> BoolTensor<B>;
683
684 /// Greater than comparison of a tensor and a scalar.
685 ///
686 /// # Arguments
687 ///
688 /// * `lhs` - The left-hand side tensor.
689 /// * `rhs` - The right-hand side scalar.
690 /// * `out_dtype` - The output tensor dtype.
691 ///
692 /// # Returns
693 ///
694 /// A boolean tensor with the result of the comparison.
695 fn float_greater_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
696
697 /// Greater than or equal comparison of two tensors.
698 ///
699 /// # Arguments
700 ///
701 /// * `lhs` - The left-hand side tensor.
702 /// * `rhs` - The right-hand side tensor.
703 /// * `out_dtype` - The output tensor dtype.
704 ///
705 /// # Returns
706 ///
707 /// A boolean tensor with the result of the comparison.
708 fn float_greater_equal(
709 lhs: FloatTensor<B>,
710 rhs: FloatTensor<B>,
711 out_dtype: BoolDType,
712 ) -> BoolTensor<B>;
713
714 /// Greater than or equal comparison of a tensor and a scalar.
715 ///
716 /// # Arguments
717 ///
718 /// * `lhs` - The left-hand side tensor.
719 /// * `rhs` - The right-hand side scalar.
720 /// * `out_dtype` - The output tensor dtype.
721 ///
722 /// # Returns
723 ///
724 /// A boolean tensor with the result of the comparison.
725 fn float_greater_equal_elem(
726 lhs: FloatTensor<B>,
727 rhs: Scalar,
728 out_dtype: BoolDType,
729 ) -> BoolTensor<B>;
730
731 /// Less than comparison of two tensors.
732 ///
733 /// # Arguments
734 ///
735 /// * `lhs` - The left-hand side tensor.
736 /// * `rhs` - The right-hand side tensor.
737 /// * `out_dtype` - The output tensor dtype.
738 ///
739 /// # Returns
740 ///
741 /// A boolean tensor with the result of the comparison.
742 fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
743 -> BoolTensor<B>;
744
745 /// Less than comparison of a tensor and a scalar.
746 ///
747 /// # Arguments
748 ///
749 /// * `lhs` - The left-hand side tensor.
750 /// * `rhs` - The right-hand side scalar.
751 /// * `out_dtype` - The output tensor dtype.
752 ///
753 /// # Returns
754 ///
755 /// A boolean tensor with the result of the comparison.
756 fn float_lower_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
757
758 /// Less than or equal comparison of two tensors.
759 ///
760 /// # Arguments
761 ///
762 /// * `lhs` - The left-hand side tensor.
763 /// * `rhs` - The right-hand side tensor.
764 /// * `out_dtype` - The output tensor dtype.
765 ///
766 /// # Returns
767 ///
768 /// A boolean tensor with the result of the comparison.
769 fn float_lower_equal(
770 lhs: FloatTensor<B>,
771 rhs: FloatTensor<B>,
772 out_dtype: BoolDType,
773 ) -> BoolTensor<B>;
774
775 /// Less than or equal comparison of a tensor and a scalar.
776 ///
777 /// # Arguments
778 ///
779 /// * `lhs` - The left-hand side tensor.
780 /// * `rhs` - The right-hand side scalar.
781 /// * `out_dtype` - The output tensor dtype.
782 ///
783 /// # Returns
784 ///
785 /// A boolean tensor with the result of the comparison.
786 fn float_lower_equal_elem(
787 lhs: FloatTensor<B>,
788 rhs: Scalar,
789 out_dtype: BoolDType,
790 ) -> BoolTensor<B>;
791
792 /// Detaches a tensor from the computation graph.
793 fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
794 // Should only be overridden by autodiff backends.
795 tensor
796 }
797
798 /// Sets the `require_grad` flag of a tensor.
799 fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
800 // Should only be overridden by autodiff backends.
801 tensor
802 }
803
804 /// Returns the `require_grad` flag of a tensor.
805 fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
806 // Should only be overridden by autodiff backends.
807 false
808 }
809
810 /// Sum of all elements in a tensor.
811 ///
812 /// # Arguments
813 ///
814 /// * `tensor` - The tensor to sum.
815 ///
816 /// # Returns
817 ///
818 /// A scalar tensor with the sum of all elements in `tensor`.
819 fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
820
821 /// Sum of all elements in a tensor along a dimension.
822 ///
823 /// # Arguments
824 ///
825 /// * `tensor` - The tensor to sum.
826 /// * `dim` - The dimension along which to sum.
827 ///
828 /// # Returns
829 ///
830 /// A tensor with the sum of all elements in `tensor` along `dim`.
831 fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
832
833 /// Product of all elements in a tensor.
834 ///
835 /// # Arguments
836 ///
837 /// * `tensor` - The tensor to product.
838 ///
839 /// # Returns
840 ///
841 /// A scalar tensor with the product of all elements in `tensor`.
842 fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
843 // Product of all elements in a tensor
844 B::float_exp(B::float_sum(B::float_log(tensor)))
845 }
846
847 /// Product of all elements in a tensor along a dimension.
848 ///
849 /// # Arguments
850 ///
851 /// * `tensor` - The tensor to product.
852 ///
853 /// # Returns
854 ///
855 /// A tensor with the product of all elements in `tensor` along `dim`.
856 fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
857 // Product of all elements in a tensor along a dimension
858 B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
859 }
860
861 /// Mean of all elements in a tensor.
862 ///
863 /// # Arguments
864 ///
865 /// * `tensor` - The tensor to mean.
866 ///
867 /// # Returns
868 ///
869 /// A scalar tensor with the mean of all elements in `tensor`.
870 fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
871 let num_elems = tensor.shape().num_elements() as f32;
872 B::float_div_scalar(B::float_sum(tensor), num_elems.into())
873 }
874
875 /// Mean of all elements in a tensor along a dimension.
876 ///
877 /// # Arguments
878 ///
879 /// * `tensor` - The tensor to mean.
880 /// * `dim` - The dimension along which to mean.
881 ///
882 /// # Returns
883 ///
884 /// A tensor with the mean of all elements in `tensor` along `dim`.
885 fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
886
887 /// Computes the cumulative sum of elements along a dimension.
888 ///
889 /// # Arguments
890 ///
891 /// * `tensor` - The tensor to compute the cumulative sum of.
892 /// * `dim` - The dimension along which to compute the cumulative sum.
893 ///
894 /// # Returns
895 ///
896 /// A tensor with the same shape where each element is the cumulative sum
897 /// of all elements up to and including that position along the dimension.
898 fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
899
900 /// Computes the cumulative product of elements along a dimension.
901 ///
902 /// # Arguments
903 ///
904 /// * `tensor` - The tensor to compute the cumulative product of.
905 /// * `dim` - The dimension along which to compute the cumulative product.
906 ///
907 /// # Returns
908 ///
909 /// A tensor with the same shape where each element is the cumulative product
910 /// of all elements up to and including that position along the dimension.
911 fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
912
913 /// Computes the cumulative minimum of elements along a dimension.
914 ///
915 /// # Arguments
916 ///
917 /// * `tensor` - The tensor to compute the cumulative minimum of.
918 /// * `dim` - The dimension along which to compute the cumulative minimum.
919 ///
920 /// # Returns
921 ///
922 /// A tensor with the same shape where each element is the minimum
923 /// of all elements up to and including that position along the dimension.
924 fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
925
926 /// Computes the cumulative maximum of elements along a dimension.
927 ///
928 /// # Arguments
929 ///
930 /// * `tensor` - The tensor to compute the cumulative maximum of.
931 /// * `dim` - The dimension along which to compute the cumulative maximum.
932 ///
933 /// # Returns
934 ///
935 /// A tensor with the same shape where each element is the maximum
936 /// of all elements up to and including that position along the dimension.
937 fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
938
939 /// Converts a tensor to another floating point data type.
940 ///
941 /// # Arguments
942 ///
943 /// * `tensor` - The tensor to convert.
944 /// * `dtype` - The target data type.
945 ///
946 /// # Returns
947 ///
948 /// A tensor with the same values as `tensor` but in the target floating point data type.
949 fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
950
951 /// Returns a new tensor with exponential values.
952 ///
953 /// # Arguments
954 ///
955 /// * `tensor` - The tensor to exponentiate.
956 ///
957 /// # Returns
958 ///
959 /// A tensor with the same shape as `tensor` with exponential values.
960 fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
961
962 /// Returns a new tensor with natural logarithm values.
963 ///
964 /// # Arguments
965 ///
966 /// * `tensor` - The tensor to take the logarithm of.
967 ///
968 /// # Returns
969 ///
970 /// A tensor with the same shape as `tensor` with natural logarithm values.
971 fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
972
973 /// Returns a new tensor with logarithm values of (1 + Xi).
974 ///
975 /// # Arguments
976 ///
977 /// * `tensor` - The tensor to take the logarithm of.
978 ///
979 /// # Returns
980 ///
981 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
982 fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
983
984 /// Element-wise power with a FloatTensor.
985 ///
986 /// # Arguments
987 ///
988 /// * `lhs` - The left-hand side tensor.
989 /// * `rhs` - The right-hand side tensor.
990 ///
991 /// # Returns
992 ///
993 /// The elements of `lhs` raised to the power of the elements of `rhs`.
994 fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
995
996 /// Element-wise power with an IntTensor.
997 ///
998 /// # Arguments
999 ///
1000 /// * `lhs` - The left-hand side tensor.
1001 /// * `rhs` - The right-hand side floatTensor.
1002 ///
1003 /// # Returns
1004 ///
1005 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
1006 fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
1007 let dtype = lhs.dtype();
1008 Self::float_powf(lhs, B::int_into_float(rhs, dtype.into()))
1009 }
1010
1011 /// Raises a tensor to the power of an int scalar.
1012 ///
1013 /// # Backend Implementors Note
1014 ///
1015 /// A number of common exponent cases can be implemented with operations
1016 /// which are much cheaper than generic exponentiation.
1017 ///
1018 /// This (`Backend` impl overridable) operation handles generic optimizations
1019 /// for several common integer exponent cases; and then dispatches to
1020 /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
1021 /// operation to handle the generic case.
1022 ///
1023 /// # Arguments
1024 ///
1025 /// * `lhs` - The left-hand side tensor.
1026 /// * `rhs` - The right-hand side scalar.
1027 ///
1028 /// # Returns
1029 ///
1030 /// The elements of `lhs` raised to the value of `rhs`.
1031 fn float_powi_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
1032 match rhs.elem::<i64>() {
1033 0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
1034 1 => lhs,
1035 2 => B::float_mul(lhs.clone(), lhs),
1036 -1 => Self::float_recip(lhs),
1037 -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
1038 _ => Self::float_powi_scalar_impl(lhs, rhs),
1039 }
1040 }
1041
/// Raises a tensor to the power of an int scalar.
///
/// # Backend Implementors Note
///
/// This is the generic implementation of integer exponentiation
/// called by [`Self::float_powi_scalar`] in the fallback case.
///
/// The default implementation forwards to [`Self::float_powf_scalar_impl`];
/// backends with a dedicated integer-power operation may override it.
///
/// As a general rule, this should not be called directly.
///
/// # Arguments
///
/// * `lhs` - The left-hand side tensor.
/// * `rhs` - The right-hand side scalar.
///
/// # Returns
///
/// The elements of `lhs` raised to the value of `rhs`.
fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
    // Avoid a recursive loop by deferring directly to float_powf_scalar_impl:
    // float_powf_scalar dispatches integer exponents back to float_powi_scalar,
    // which would end up here again.
    Self::float_powf_scalar_impl(lhs, rhs)
}
1063
1064 /// Returns a new tensor with values raised to the power of float `value`.
1065 ///
1066 /// # Backend Implementors Note
1067 ///
1068 /// This (`Backend` impl overridable) operation dispatches integer exponentiation
1069 /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
1070 /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
1071 /// operation to handle the generic case.
1072 ///
1073 /// # Arguments
1074 ///
1075 /// * `tensor` - The tensor to exponentiate.
1076 /// * `value` - The exponent.
1077 ///
1078 /// # Returns
1079 ///
1080 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
1081 fn float_powf_scalar(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B> {
1082 if let Some(exp) = value.try_as_integer() {
1083 Self::float_powi_scalar(tensor, exp)
1084 } else {
1085 Self::float_powf_scalar_impl(tensor, value)
1086 }
1087 }
1088
/// Returns a new tensor with values raised to the power of float `value`.
///
/// # Backend Implementors Note
///
/// This is the generic implementation of scalar exponentiation. It is
/// called by [`Self::float_powf_scalar`] in the fallback (non-integer exponent)
/// case and, by default, by [`Self::float_powi_scalar_impl`].
///
/// This is the minimal required support a `Backend` must implement
/// for exponentiation.
///
/// As a general rule, this should not be called directly.
///
/// # Arguments
///
/// * `tensor` - The tensor to exponentiate.
/// * `value` - The exponent.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with values raised to the power of `value`.
fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B>;
1110
/// Returns a new tensor with the element-wise square root of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the square root of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with square root values.
fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
1121
/// Returns a new tensor with the element-wise absolute value of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the absolute value of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with absolute values.
fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1132
/// Returns a new tensor with the element-wise cosine of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the cosine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with cosine values.
fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1143
/// Returns a new tensor with the element-wise sine of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the sine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with sine values.
fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1154
/// Returns a new tensor with the element-wise tangent of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the tangent of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with tangent values.
fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1165
/// Returns a new tensor with the element-wise hyperbolic cosine of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the hyperbolic cosine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with hyperbolic cosine values.
fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1176
/// Returns a new tensor with the element-wise hyperbolic sine of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the hyperbolic sine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with hyperbolic sine values.
fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1187
/// Returns a new tensor with the element-wise hyperbolic tangent of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the hyperbolic tangent of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with hyperbolic tangent values.
fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1198
/// Returns a new tensor with the element-wise inverse cosine (arccosine) of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with inverse cosine values.
fn float_acos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1209
/// Returns a new tensor with the element-wise inverse hyperbolic cosine of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with inverse hyperbolic cosine values.
fn float_acosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1220
/// Returns a new tensor with the element-wise inverse sine (arcsine) of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with inverse sine values.
fn float_asin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1231
/// Returns a new tensor with the element-wise inverse hyperbolic sine of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with inverse hyperbolic sine values.
fn float_asinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1242
/// Returns a new tensor with the element-wise inverse tangent (arctangent) of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with the inverse tangent values.
fn float_atan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1253
/// Returns a new tensor with the element-wise inverse hyperbolic tangent of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with the inverse hyperbolic tangent values.
fn float_atanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1264
/// Returns a tensor with the four-quadrant inverse tangent values of `y` and `x`.
///
/// Note the argument order: the `y` coordinates come first (`lhs`), the `x`
/// coordinates second (`rhs`).
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `lhs` - The tensor with y coordinates.
/// * `rhs` - The tensor with x coordinates.
///
/// # Returns
///
/// A tensor with the four-quadrant inverse tangent values.
fn float_atan2(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
1276
/// Returns a new tensor with rounded values.
///
/// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
/// strategy, with halfway cases rounded to the nearest even integer value.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to be rounded.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with rounded values.
fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
1290
/// Returns a new tensor with floored values.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to be floored.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with floored values.
fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1301
/// Returns a new tensor with ceiled values.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to be ceiled.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with ceiled values.
fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1312
/// Returns a new tensor with truncated values.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to be truncated.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with truncated values.
fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
1323
/// Returns a new tensor with the element-wise error function (`erf`) of the input.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the error function of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with error function values.
fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1334
1335 /// Concatenates tensors along a dimension.
1336 ///
1337 /// # Arguments
1338 ///
1339 /// * `tensors` - The tensors to concatenate.
1340 /// * `dim` - The dimension along which to concatenate.
1341 ///
1342 /// # Returns
1343 ///
1344 /// A tensor with the concatenated tensors along `dim`.
1345 ///
1346 /// # Note
1347 ///
1348 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
1349 /// high-level tensor API and will not be passed to this method. Backend implementations do
1350 /// not need to handle empty tensors.
1351 fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1352 cat_with_slice_assign::<B, Float>(
1353 tensors.into_iter().map(TensorPrimitive::Float).collect(),
1354 dim,
1355 )
1356 .tensor()
1357 }
1358
/// Gets the indices of the maximum elements of a tensor along an axis.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
/// * `out_dtype` - The output tensor dtype.
///
/// # Returns
///
/// A tensor with the indices of the maximum elements of `tensor` along `dim`.
fn float_argmax(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1371
/// Gets the indices of the `k` maximum elements of a tensor along an axis.
///
/// If two elements are equal, they are ordered by lowest index first.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
/// * `k` - The number of maximum elements.
/// * `out_dtype` - The output tensor dtype.
///
/// # Returns
///
/// A tensor with the indices of the `k` maximum elements of `tensor` along `dim`.
fn float_argtopk(
    tensor: FloatTensor<B>,
    dim: usize,
    k: usize,
    out_dtype: IntDType,
) -> IntTensor<B>;
1391
1392 /// Gets the values of the k maximum elements of a tensor along an axis.
1393 ///
1394 /// # Arguments
1395 ///
1396 /// * `tensor` - The tensor to get the maximum elements of.
1397 /// * `dim` - The dimension along which to get the maximum elements.
1398 /// * `k` - number of maximum elements
1399 /// * `out_dtype` - The output tensor dtype.
1400 ///
1401 /// # Returns
1402 ///
1403 /// A tensor with the values of the maximum elements of `tensor` along `dim`.
1404 fn float_topk(tensor: FloatTensor<B>, dim: usize, k: usize) -> FloatTensor<B> {
1405 let device = Self::float_device(&tensor);
1406 let dtype = get_device_settings::<B>(&device).int_dtype;
1407 let k_indices = B::int_arange(0..k as i64, &device, dtype);
1408 Self::float_select(Self::float_sort(tensor, dim, true), dim, k_indices)
1409 }
1410
/// Gets the indices of the minimum elements of a tensor along an axis.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements of.
/// * `dim` - The dimension along which to get the minimum elements.
/// * `out_dtype` - The output tensor dtype.
///
/// # Returns
///
/// A tensor with the indices of the minimum elements of `tensor` along `dim`.
fn float_argmin(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1423
1424 /// Gets the maximum element of a tensor.
1425 ///
1426 /// # Arguments
1427 ///
1428 /// * `tensor` - The tensor to get the maximum elements of.
1429 ///
1430 /// # Returns
1431 ///
1432 /// A tensor with the maximum element of `tensor`.
1433 fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1434 let shape = tensor.shape();
1435 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1436
1437 B::float_max_dim(tensor, 0)
1438 }
1439
1440 /// Gets the maximum elements of a tensor along an axis.
1441 ///
1442 /// # Arguments
1443 ///
1444 /// * `tensor` - The tensor to get the maximum elements of.
1445 /// * `dim` - The dimension along which to get the maximum elements.
1446 ///
1447 /// # Returns
1448 ///
1449 /// A tensor with the maximum elements of `tensor` along `dim`.
1450 fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1451 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1452 let index = B::float_argmax(tensor.clone(), dim, dtype);
1453
1454 B::float_gather(dim, tensor, index)
1455 }
1456
1457 /// Gets the maximum elements of a tensor along an axis and their indices.
1458 ///
1459 /// # Arguments
1460 ///
1461 /// * `tensor` - The tensor to get the maximum elements of.
1462 /// * `dim` - The dimension along which to get the maximum elements.
1463 /// * `indices_dtype` - The indices tensor dtype.
1464 ///
1465 /// # Returns
1466 ///
1467 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1468 fn float_max_dim_with_indices(
1469 tensor: FloatTensor<B>,
1470 dim: usize,
1471 indices_dtype: IntDType,
1472 ) -> (FloatTensor<B>, IntTensor<B>) {
1473 let index = B::float_argmax(tensor.clone(), dim, indices_dtype);
1474 let values = B::float_gather(dim, tensor, index.clone());
1475
1476 (values, index)
1477 }
1478
1479 /// Gets the minimum element of a tensor.
1480 ///
1481 /// # Arguments
1482 ///
1483 /// * `tensor` - The tensor to get the minimum elements of.
1484 ///
1485 /// # Returns
1486 ///
1487 /// A tensor with the minimum element of `tensor`.
1488 fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1489 let shape = tensor.shape();
1490 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1491
1492 B::float_min_dim(tensor, 0)
1493 }
1494
1495 /// Gets the minimum elements of a tensor along an axis.
1496 ///
1497 /// # Arguments
1498 ///
1499 /// * `tensor` - The tensor to get the minimum elements of.
1500 /// * `dim` - The dimension along which to get the minimum elements.
1501 ///
1502 /// # Returns
1503 ///
1504 /// A tensor with the minimum elements of `tensor` along `dim`.
1505 fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1506 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1507 let index = B::float_argmin(tensor.clone(), dim, dtype);
1508
1509 B::float_gather(dim, tensor, index)
1510 }
1511
1512 /// Gets the minimum elements of a tensor along an axis and their indices.
1513 ///
1514 /// # Arguments
1515 ///
1516 /// * `tensor` - The tensor to get the minimum elements of.
1517 /// * `dim` - The dimension along which to get the minimum elements.
1518 /// * `indices_dtype` - The indices tensor dtype.
1519 ///
1520 /// # Returns
1521 ///
1522 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1523 fn float_min_dim_with_indices(
1524 tensor: FloatTensor<B>,
1525 dim: usize,
1526 indices_dtype: IntDType,
1527 ) -> (FloatTensor<B>, IntTensor<B>) {
1528 let index = B::float_argmin(tensor.clone(), dim, indices_dtype);
1529 let values = B::float_gather(dim, tensor, index.clone());
1530
1531 (values, index)
1532 }
1533
1534 /// Gets the maximum absolute element of a tensor.
1535 ///
1536 /// # Arguments
1537 ///
1538 /// * `tensor` - The tensor to get the maximum elements of.
1539 ///
1540 /// # Returns
1541 ///
1542 /// A tensor with the maximum element of `tensor`.
1543 fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1544 let shape = tensor.shape();
1545 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1546
1547 B::float_max_abs_dim(tensor, 0)
1548 }
1549
1550 /// Gets the maximum absolute elements of a tensor along an axis.
1551 ///
1552 /// # Arguments
1553 ///
1554 /// * `tensor` - The tensor to get the maximum elements of.
1555 /// * `dim` - The dimension along which to get the maximum elements.
1556 ///
1557 /// # Returns
1558 ///
1559 /// A tensor with the maximum elements of `tensor` along `dim`.
1560 fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1561 B::float_max_dim(B::float_abs(tensor), dim)
1562 }
1563
1564 /// Tests if any element in the float `tensor` evaluates to True.
1565 ///
1566 /// # Arguments
1567 ///
1568 /// * `tensor` - The tensor to test.
1569 /// * `out_dtype` - The output tensor dtype.
1570 ///
1571 /// # Returns
1572 ///
1573 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1574 fn float_any(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1575 let float_dtype = tensor.dtype();
1576 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1577 let bool_tensor = B::bool_not(bool_tensor);
1578 let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1579 B::float_greater_elem(sum, 0f32.into(), out_dtype)
1580 }
1581
1582 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1583 ///
1584 /// # Arguments
1585 ///
1586 /// * `tensor` - The tensor to test.
1587 /// * `dim` - The axis along which to test.
1588 /// * `out_dtype` - The output tensor dtype.
1589 ///
1590 /// # Returns
1591 ///
1592 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1593 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1594 /// input evaluates to True, False otherwise.
1595 fn float_any_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1596 let float_dtype = tensor.dtype();
1597 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1598 let bool_tensor = B::bool_not(bool_tensor);
1599 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1600 B::float_greater_elem(sum, 0f32.into(), out_dtype)
1601 }
1602
1603 /// Tests if all elements in the float `tensor` evaluate to True.
1604 ///
1605 /// # Arguments
1606 ///
1607 /// * `tensor` - The tensor to test.
1608 /// * `out_dtype` - The output tensor dtype.
1609 ///
1610 /// # Returns
1611 ///
1612 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1613 /// evaluate to True, False otherwise.
1614 fn float_all(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1615 let float_dtype = tensor.dtype();
1616 let num_elems = tensor.shape().num_elements() as f32;
1617 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1618 let bool_tensor = B::bool_not(bool_tensor);
1619 let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1620 B::float_equal_elem(sum, num_elems.into(), out_dtype)
1621 }
1622
1623 /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1624 ///
1625 /// # Arguments
1626 ///
1627 /// * `tensor` - The tensor to test.
1628 /// * `dim` - The axis along which to test.
1629 /// * `out_dtype` - The output tensor dtype.
1630 ///
1631 /// # Returns
1632 ///
1633 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1634 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1635 /// evaluates to True, False otherwise.
1636 fn float_all_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1637 let float_dtype = tensor.dtype();
1638 let num_elems = tensor.shape()[dim] as f32;
1639 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1640 let bool_tensor = B::bool_not(bool_tensor);
1641 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1642 B::float_equal_elem(sum, num_elems.into(), out_dtype)
1643 }
1644
1645 /// Returns the signs of the float `tensor`.
1646 ///
1647 /// # Arguments
1648 ///
1649 /// * `tensor` - The tensor to extract the signs from.
1650 ///
1651 /// # Returns
1652 ///
1653 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1654 fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1655 let device = B::float_device(&tensor);
1656 let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
1657 let zeros = B::float_zeros(tensor.shape(), &device, tensor.dtype().into());
1658 let less_than_zero = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
1659 let greater_than_zero = B::float_greater_elem(tensor, 0f32.into(), bool_dtype);
1660
1661 let mut result = B::float_mask_fill(zeros, less_than_zero, (-1f32).into());
1662 result = B::float_mask_fill(result, greater_than_zero, 1f32.into());
1663 result
1664 }
1665
/// Broadcasts the float `tensor` to the given `shape`.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The target shape.
///
/// # Returns
///
/// The tensor broadcast to `shape`.
fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1668
1669 /// Sort the elements of the input `tensor` by value in along a given dimension.
1670 ///
1671 /// This sort is unstable (i.e., may reorder equal elements).
1672 ///
1673 /// # Arguments
1674 ///
1675 /// * `tensor` - The input tensor.
1676 /// * `dim` - The axis along which to sort.
1677 /// * `descending` - The sorting order.
1678 ///
1679 /// # Returns
1680 ///
1681 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1682 fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1683 sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1684 }
1685
1686 /// Sort the elements of the input `tensor` by value in along a given dimension.
1687 ///
1688 /// This sort is unstable (i.e., may reorder equal elements).
1689 ///
1690 /// # Arguments
1691 ///
1692 /// * `tensor` - The input tensor.
1693 /// * `dim` - The axis along which to sort.
1694 /// * `descending` - The sorting order.
1695 /// * `indices_dtype` - The indices tensor dtype.
1696 ///
1697 /// # Returns
1698 ///
1699 /// A tensor with the same shape as the input tensor and corresponding indices, where
1700 /// the elements are sorted by value and the indices map back to the original input tensor.
1701 fn float_sort_with_indices(
1702 tensor: FloatTensor<B>,
1703 dim: usize,
1704 descending: bool,
1705 indices_dtype: IntDType,
1706 ) -> (FloatTensor<B>, IntTensor<B>) {
1707 let (values, indices) = sort_with_indices::<B, Float>(
1708 TensorPrimitive::Float(tensor),
1709 dim,
1710 descending,
1711 indices_dtype,
1712 );
1713 (values.tensor(), indices)
1714 }
1715
1716 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1717 ///
1718 /// This sort is unstable (i.e., may reorder equal elements).
1719 ///
1720 /// # Arguments
1721 ///
1722 /// * `tensor` - The input tensor.
1723 /// * `dim` - The axis along which to sort.
1724 /// * `descending` - The sorting order.
1725 /// * `out_dtype` - The output tensor dtype.
1726 ///
1727 /// # Returns
1728 ///
1729 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1730 fn float_argsort(
1731 tensor: FloatTensor<B>,
1732 dim: usize,
1733 descending: bool,
1734 out_dtype: IntDType,
1735 ) -> IntTensor<B> {
1736 argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending, out_dtype)
1737 }
1738
/// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
/// using the given locations in [-1, 1].
///
/// The default implementation delegates to the reference implementation
/// `float_grid_sample_2d_ref`.
///
/// # Arguments
///
/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
/// * `options` - Grid sampling options (mode, padding_mode, align_corners)
///
/// # Returns
///
/// A tensor with shape (N, C, H_out, W_out)
fn float_grid_sample_2d(
    tensor: FloatTensor<B>,
    grid: FloatTensor<B>,
    options: GridSampleOptions,
) -> FloatTensor<B> {
    // TODO: default impl should get int default dtype
    float_grid_sample_2d_ref::<B>(tensor, grid, options)
}
1760
/// Unfold windows along a dimension.
///
/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
/// where windows are advanced by `step` at each index.
///
/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
///
/// NOTE(review): with `size == shape[dim]` the formula above yields 0 windows although one
/// complete window exists; the count may actually be `(shape[dim] - size) / step + 1` when
/// `shape[dim] >= size` -- confirm against the backend implementations.
///
/// This is a required method: the trait provides no default implementation.
///
/// # Arguments
///
/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
/// * `dim` - the selected dim.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between each window.
///
/// # Returns
///
/// A tensor view with shape ``[pre=..., windows, size, post=...]``.
fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
-> FloatTensor<B>;
1780
1781 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1782 ///
1783 /// # Returns
1784 ///
1785 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1786 fn float_is_nan(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1787 // Check if the input tensor is NaN by comparing it to itself
1788 // NaN is the only value that is not equal to itself
1789 B::float_not_equal(tensor.clone(), tensor, out_dtype)
1790 }
1791
1792 /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1793 ///
1794 /// # Returns
1795 ///
1796 /// A boolean tensor where `true` indicates that the value is infinite
1797 fn float_is_inf(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1798 B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.into(), out_dtype)
1799 }
1800}