burn_backend/backend/ops/tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::grid_sample::float_grid_sample_2d_ref;
3use super::repeat_dim::repeat_with_slice_assign;
4use super::sort::{argsort, sort, sort_with_indices};
5use crate::ops::GridSampleOptions;
6use crate::tensor::{BoolTensor, Device, Float, FloatTensor, IntTensor};
7use crate::{Backend, Distribution, TensorData, get_device_settings};
8use crate::{ExecutionError, Scalar, TensorMetadata, TensorPrimitive};
9use alloc::vec::Vec;
10use burn_std::{BoolDType, FloatDType, IntDType, Shape, Slice};
11
12/// Operations on float tensors.
13pub trait FloatTensorOps<B: Backend> {
14 /// Creates a new tensor from the data structure.
15 ///
16 /// # Arguments
17 ///
18 /// * `data` - The data structure.
19 /// * `device` - The device to create the tensor on.
20 ///
21 /// # Returns
22 ///
23 /// The tensor with the given data.
24 fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;
25
26 /// Creates a new tensor with random values.
27 ///
28 /// # Arguments
29 ///
30 /// * `shape` - The shape of the tensor.
31 /// * `distribution` - The distribution to sample from.
32 /// * `device` - The device to create the tensor on.
33 /// * `dtype` - The target data type.
34 ///
35 /// # Returns
36 ///
37 /// The tensor with the given shape and random values.
38 fn float_random(
39 shape: Shape,
40 distribution: Distribution,
41 device: &Device<B>,
42 dtype: FloatDType,
43 ) -> FloatTensor<B>;
44
45 /// Creates a new tensor with zeros.
46 ///
47 /// # Arguments
48 ///
49 /// * `shape` - The shape of the tensor.
50 /// * `device` - The device to create the tensor on.
51 /// * `dtype` - The target data type.
52 ///
53 /// # Returns
54 ///
55 /// The tensor with the given shape and zeros.
56 fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
57 Self::float_from_data(TensorData::full_dtype(shape, 0., dtype.into()), device)
58 }
59
60 /// Creates a new tensor with ones.
61 ///
62 /// # Arguments
63 ///
64 /// * `shape` - The shape of the tensor.
65 /// * `device` - The device to create the tensor on.
66 /// * `dtype` - The target data type.
67 ///
68 /// # Returns
69 ///
70 /// The tensor with the given shape and ones.
71 fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {
72 Self::float_from_data(TensorData::full_dtype(shape, 1., dtype.into()), device)
73 }
74
75 /// Creates a tensor filled with given value.
76 ///
77 /// # Arguments
78 ///
79 /// * `shape` - The shape of the tensor.
80 /// * `fill_value` - The value with which to fill the tensor.
81 /// * `device` - The device to create the tensor on.
82 /// * `dtype` - The target data type.
83 ///
84 /// # Returns
85 ///
86 /// The tensor filled with given value
87 fn float_full(
88 shape: Shape,
89 fill_value: Scalar,
90 device: &Device<B>,
91 dtype: FloatDType,
92 ) -> FloatTensor<B> {
93 Self::float_from_data(
94 TensorData::full_dtype(shape, fill_value, dtype.into()),
95 device,
96 )
97 }
98
99 /// Converts the tensor to a data structure.
100 ///
101 /// # Arguments
102 ///
103 /// * `tensor` - The tensor.
104 ///
105 /// # Returns
106 ///
107 /// The data structure with the tensor's data.
108 fn float_into_data(
109 tensor: FloatTensor<B>,
110 ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;
111
112 /// Gets the device of the tensor.
113 ///
114 /// # Arguments
115 ///
116 /// * `tensor` - The tensor.
117 ///
118 /// # Returns
119 ///
120 /// The device of the tensor.
121 fn float_device(tensor: &FloatTensor<B>) -> Device<B>;
122
123 /// Moves the tensor to the given device.
124 ///
125 /// # Arguments
126 ///
127 /// * `tensor` - The tensor.
128 /// * `device` - The device to move the tensor to.
129 ///
130 /// # Returns
131 ///
132 /// The tensor on the given device.
133 fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;
134
135 /// Converts float tensor to int tensor.
136 ///
137 /// # Arguments
138 ///
139 /// * `tensor` - The tensor.
140 /// * `out_dtype` - The output tensor dtype.
141 ///
142 /// # Returns
143 ///
144 /// The int tensor with the same data as the float tensor.
145 fn float_into_int(tensor: FloatTensor<B>, out_dtype: IntDType) -> IntTensor<B>;
146
147 /// Creates an empty tensor with the given shape.
148 ///
149 /// # Arguments
150 ///
151 /// * `shape` - The shape of the tensor.
152 /// * `device` - The device to create the tensor on.
153 /// * `dtype` - The target data type.
154 ///
155 /// # Returns
156 ///
157 /// The empty tensor with the given shape.
158 fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;
159
160 /// Repeat the tensor along the given dimension.
161 ///
162 /// # Arguments
163 ///
164 /// * `tensor` - The tensor.
165 /// * `dim` - The dimension to repeat.
166 /// * `times` - The number of times to repeat the dimension.
167 ///
168 /// # Returns
169 ///
170 /// The tensor with the given dimension repeated.
171 fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
172 repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
173 }
174
175 /// Adds two tensors together.
176 ///
177 /// # Arguments
178 ///
179 /// * `lhs` - The left-hand side tensor.
180 /// * `rhs` - The right-hand side tensor.
181 ///
182 /// # Returns
183 ///
184 /// The result of adding the two tensors together.
185 fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
186
187 /// Adds a scalar to a tensor.
188 ///
189 /// # Arguments
190 ///
191 /// * `lhs` - The left-hand side tensor.
192 /// * `rhs` - The right-hand side scalar.
193 ///
194 /// # Returns
195 ///
196 /// The result of adding the scalar to the tensor.
197 fn float_add_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
198
199 /// Clamps a tensor under a minimum value.
200 ///
201 /// # Arguments
202 ///
203 /// * `tensor` - The tensor to clamp.
204 /// * `min` - The minimum value.
205 ///
206 /// # Returns
207 ///
208 /// The clamped tensor.
209 fn float_clamp_min(tensor: FloatTensor<B>, min: Scalar) -> FloatTensor<B> {
210 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
211 let mask = Self::float_lower_elem(tensor.clone(), min, dtype);
212 B::float_mask_fill(tensor, mask, min)
213 }
214
215 /// Clamps a tensor over a maximum value.
216 ///
217 /// # Arguments
218 ///
219 /// * `tensor` - The tensor to clamp.
220 /// * `max` - The maximum value.
221 ///
222 /// # Returns
223 ///
224 /// The clamped tensor.
225 fn float_clamp_max(tensor: FloatTensor<B>, max: Scalar) -> FloatTensor<B> {
226 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
227 let mask = Self::float_greater_elem(tensor.clone(), max, dtype);
228 B::float_mask_fill(tensor, mask, max)
229 }
230
231 /// Clamps a tensor between a minimum and maximum value.
232 ///
233 /// # Arguments
234 ///
235 /// * `tensor` - The tensor to clamp.
236 /// * `min` - The minimum value.
237 /// * `max` - The maximum value.
238 ///
239 /// # Returns
240 ///
241 /// The clamped tensor.
242 fn float_clamp(tensor: FloatTensor<B>, min: Scalar, max: Scalar) -> FloatTensor<B> {
243 // Default implementation
244 Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
245 }
246
247 /// Subtracts two tensors.
248 ///
249 /// # Arguments
250 ///
251 /// * `lhs` - The left-hand side tensor.
252 /// * `rhs` - The right-hand side tensor.
253 ///
254 /// # Returns
255 ///
256 /// The result of subtracting the two tensors.
257 fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
258
259 /// Subtracts a scalar from a tensor.
260 ///
261 /// # Arguments
262 ///
263 /// * `lhs` - The left-hand side tensor.
264 /// * `rhs` - The right-hand side scalar.
265 ///
266 /// # Returns
267 ///
268 /// The result of subtracting the scalar from the tensor.
269 fn float_sub_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
270
271 /// Multiplies two tensors together element-wise.
272 fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
273
274 /// Multiplies a tensor by a scalar.
275 ///
276 /// # Arguments
277 ///
278 /// * `lhs` - The left-hand side tensor.
279 /// * `rhs` - The right-hand side scalar.
280 ///
281 /// # Returns
282 ///
283 /// The result of multiplying the tensor by the scalar.
284 fn float_mul_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
285
286 /// Divides two tensors element-wise.
287 ///
288 /// # Arguments
289 ///
290 /// * `lhs` - The left-hand side tensor.
291 /// * `rhs` - The right-hand side tensor.
292 ///
293 /// # Returns
294 ///
295 /// The result of dividing the two tensors.
296 fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
297
298 /// Divides a tensor by a scalar.
299 ///
300 /// # Arguments
301 ///
302 /// * `lhs` - The left-hand side tensor.
303 /// * `rhs` - The right-hand side scalar.
304 ///
305 /// # Returns
306 ///
307 /// The result of dividing the tensor by the scalar.
308 fn float_div_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
309
310 /// Computes the remainder of division between two tensors element-wise.
311 ///
312 /// # Arguments
313 ///
314 /// * `lhs` - The left-hand side tensor.
315 /// * `rhs` - The right-hand side tensor.
316 ///
317 /// # Returns
318 ///
319 /// The element-wise remainder when dividing `lhs` by `rhs`.
320 fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
321
322 /// Computes the modulus of a tensor given a scalar.
323 ///
324 /// # Arguments
325 /// * `lhs` - The left-hand side tensor.
326 /// * `rhs` - The right-hand side scalar.
327 ///
328 /// # Returns
329 ///
330 /// The result of applying the modulus of the scalar to the tensor.
331 fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;
332
333 /// Multiplies two tensors together using matrix multiplication.
334 ///
335 /// # Arguments
336 ///
337 /// * `lhs` - The left-hand side tensor.
338 /// * `rhs` - The right-hand side tensor.
339 ///
340 /// # Returns
341 ///
342 /// The result of multiplying the two tensors together using matrix multiplication.
343 fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
344
345 /// Computes the cross product of two tensors along a given dimension.
346 ///
347 /// # Arguments
348 ///
349 /// * `lhs` - The left-hand side tensor.
350 /// * `rhs` - The right-hand side tensor.
351 /// * `dim` - The dimension to compute the cross product along.
352 ///
353 /// # Returns
354 ///
355 /// The cross product of the two tensors.
356 fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
357
358 /// Negates a tensor element-wise.
359 fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
360 Self::float_mul_scalar(tensor, (-1f32).into())
361 }
362
363 /// Calculates the reciprocals element-wise
364 fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
365
366 /// Transposes a tensor.
367 ///
368 /// # Arguments
369 ///
370 /// * `tensor` - The tensor to transpose.
371 ///
372 /// # Returns
373 ///
374 /// The transposed tensor.
375 fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
376 let ndims = tensor.shape().num_dims();
377 Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
378 }
379
380 /// Swaps two dimensions of a tensor.
381 ///
382 /// # Arguments
383 ///
384 /// * `tensor` - The tensor to swap the dimensions of.
385 /// * `dim1` - The first dimension to swap.
386 /// * `dim2` - The second dimension to swap.
387 ///
388 /// # Returns
389 ///
390 /// The tensor with the dimensions swapped.
391 fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;
392
393 /// Permutes the dimensions of a tensor.
394 ///
395 /// # Arguments
396 ///
397 /// * `tensor` - The tensor to permute the dimensions of.
398 /// * `axes` - The new order of the dimensions.
399 /// # Returns
400 ///
401 /// The tensor with the dimensions permuted.
402 fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
403
404 /// Reverse the order of elements in a tensor along the given axes.
405 ///
406 /// # Arguments
407 ///
408 /// * `tensor` - The tensor to reverse.
409 /// * `axes` - The axes to reverse.
410 ///
411 /// The tensor with the elements reversed.
412 fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;
413
414 /// Reshapes a tensor.
415 ///
416 /// # Arguments
417 ///
418 /// * `tensor` - The tensor to reshape.
419 /// * `shape` - The new shape of the tensor.
420 ///
421 /// # Returns
422 ///
423 /// The tensor with the new shape.
424 fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
425
426 /// Gather elements from a tensor.
427 ///
428 /// # Arguments
429 ///
430 /// * `dim` - The dimension to gather from.
431 /// * `tensor` - The tensor to gather from.
432 /// * `indices` - The indices to gather.
433 ///
434 /// # Returns
435 ///
436 /// The gathered elements.
437 fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;
438
439 /// Scatter elements into a tensor using sum reduction.
440 ///
441 /// # Arguments
442 ///
443 /// * `dim` - The dimension to scatter into.
444 /// * `tensor` - The tensor to scatter into.
445 /// * `indices` - The indices to scatter into.
446 /// * `value` - The value to scatter.
447 ///
448 /// # Returns
449 ///
450 /// The tensor with the scattered elements.
451 fn float_scatter_add(
452 dim: usize,
453 tensor: FloatTensor<B>,
454 indices: IntTensor<B>,
455 value: FloatTensor<B>,
456 ) -> FloatTensor<B>;
457
458 /// Multi-dimensional scatter: update `data` at locations specified by `indices` with `values`.
459 ///
460 /// # Arguments
461 ///
462 /// * `data` - The tensor to scatter into.
463 /// * `indices` - An M-dimensional integer tensor whose last dimension indexes into `data`.
464 /// * `values` - The values to scatter.
465 /// * `reduction` - How to combine with existing values.
466 ///
467 /// # Returns
468 ///
469 /// The tensor with scattered values.
470 fn float_scatter_nd(
471 _data: FloatTensor<B>,
472 _indices: IntTensor<B>,
473 _values: FloatTensor<B>,
474 _reduction: crate::tensor::IndexingUpdateOp,
475 ) -> FloatTensor<B> {
476 unimplemented!("float_scatter_nd is not implemented for this backend")
477 }
478
479 /// Multi-dimensional gather: collect slices from `data` at locations specified by `indices`.
480 ///
481 /// # Arguments
482 ///
483 /// * `data` - The tensor to gather from.
484 /// * `indices` - An M-dimensional integer tensor whose last dimension indexes into `data`.
485 ///
486 /// # Returns
487 ///
488 /// The gathered tensor.
489 fn float_gather_nd(_data: FloatTensor<B>, _indices: IntTensor<B>) -> FloatTensor<B> {
490 unimplemented!("float_gather_nd is not implemented for this backend")
491 }
492
493 /// Select tensor elements along the given dimension corresponding for the given indices.
494 ///
495 /// # Arguments
496 ///
497 /// * `tensor` - The tensor to select from.
498 /// * `dim` - The dimension to select from.
499 /// * `indices` - The indices to select.
500 ///
501 /// # Returns
502 ///
503 /// The selected elements.
504 fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;
505
506 /// Assign the selected elements along the given dimension corresponding for the given indices
507 /// to the given value using sum reduction.
508 ///
509 /// # Arguments
510 ///
511 /// * `tensor` - The tensor to select from.
512 /// * `dim` - The dimension to select from.
513 /// * `indices` - The indices to select.
514 /// * `value` - The value to assign.
515 ///
516 /// # Returns
517 ///
518 /// The tensor with the selected elements assigned to the given value.
519 fn float_select_add(
520 tensor: FloatTensor<B>,
521 dim: usize,
522 indices: IntTensor<B>,
523 value: FloatTensor<B>,
524 ) -> FloatTensor<B>;
525
526 /// Select tensor elements corresponding to the given slices.
527 ///
528 /// # Arguments
529 ///
530 /// * `tensor` - The tensor to select from.
531 /// * `slices` - The slices specifying ranges and steps for each dimension.
532 ///
533 /// # Returns
534 ///
535 /// The selected elements in a new tensor.
536 ///
537 /// # Note
538 ///
539 /// Empty slices (where start >= end) are handled at the high-level tensor API and will not
540 /// be passed to this method. Backend implementations do not need to handle empty slices.
541 fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;
542
543 /// Assign the selected elements corresponding to the given slices to the given value.
544 ///
545 /// # Arguments
546 ///
547 /// * `tensor` - The tensor to select from.
548 /// * `ranges` - The ranges to select.
549 /// * `value` - The value to assign.
550 ///
551 /// # Returns
552 ///
553 /// The tensor with the selected elements assigned to the given value.
554 ///
555 /// # Note
556 ///
557 /// Empty slice assignments (where any slice range produces 0 elements) are handled at the
558 /// high-level tensor API and will not be passed to this method. Backend implementations do
559 /// not need to handle empty slice assignments.
560 fn float_slice_assign(
561 tensor: FloatTensor<B>,
562 slices: &[Slice],
563 value: FloatTensor<B>,
564 ) -> FloatTensor<B>;
565
566 /// Update the given tensor with the value tensor where the mask is true.
567 ///
568 /// # Arguments
569 ///
570 /// * `tensor` - The tensor to select from.
571 /// * `mask` - The boolean mask to select with.
572 /// * `value` - The value to assign to the selected elements from the value tensor.
573 ///
574 /// # Returns
575 ///
576 /// The tensor with the selected elements assigned to the given value.
577 fn float_mask_where(
578 tensor: FloatTensor<B>,
579 mask: BoolTensor<B>,
580 value: FloatTensor<B>,
581 ) -> FloatTensor<B>;
582
583 /// Update the given tensor with the value where the mask is true.
584 ///
585 /// # Arguments
586 ///
587 /// * `tensor` - The tensor to select from.
588 /// * `mask` - The boolean mask to select with.
589 /// * `value` - The value to assign to the selected elements.
590 ///
591 /// # Returns
592 ///
593 /// The tensor with the selected elements assigned to the given value.
594 fn float_mask_fill(
595 tensor: FloatTensor<B>,
596 mask: BoolTensor<B>,
597 value: Scalar,
598 ) -> FloatTensor<B>;
599
600 /// Equal comparison of two tensors.
601 ///
602 /// # Arguments
603 ///
604 /// * `lhs` - The left-hand side tensor.
605 /// * `rhs` - The right-hand side tensor.
606 /// * `out_dtype` - The output tensor dtype.
607 ///
608 /// # Returns
609 ///
610 /// A boolean tensor with the result of the comparison.
611 fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
612 -> BoolTensor<B>;
613
614 /// Element-wise non-equality comparison.
615 ///
616 /// # Arguments
617 ///
618 /// * `lhs` - The left-hand side tensor.
619 /// * `rhs` - The right-hand side tensor.
620 /// * `out_dtype` - The output tensor dtype.
621 ///
622 /// # Returns
623 ///
624 /// A boolean tensor with the result of the comparison.
625 fn float_not_equal(
626 lhs: FloatTensor<B>,
627 rhs: FloatTensor<B>,
628 out_dtype: BoolDType,
629 ) -> BoolTensor<B> {
630 let equal_tensor = B::float_equal(lhs, rhs, out_dtype);
631 B::bool_not(equal_tensor)
632 }
633
634 /// Equal comparison of a tensor and a scalar.
635 ///
636 /// # Arguments
637 ///
638 /// * `lhs` - The left-hand side tensor.
639 /// * `rhs` - The right-hand side scalar.
640 /// * `out_dtype` - The output tensor dtype.
641 ///
642 /// # Returns
643 ///
644 /// A boolean tensor with the result of the comparison.
645 fn float_equal_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
646
647 /// Element-wise non-equality comparison with a scalar.
648 ///
649 /// # Arguments
650 ///
651 /// * `lhs` - The left-hand side tensor.
652 /// * `rhs` - The right-hand side scalar.
653 /// * `out_dtype` - The output tensor dtype.
654 ///
655 /// # Returns
656 ///
657 /// A boolean tensor with the result of the comparison.
658 fn float_not_equal_elem(
659 lhs: FloatTensor<B>,
660 rhs: Scalar,
661 out_dtype: BoolDType,
662 ) -> BoolTensor<B> {
663 let equal_tensor = B::float_equal_elem(lhs, rhs, out_dtype);
664 B::bool_not(equal_tensor)
665 }
666
667 /// Greater than comparison of two tensors.
668 ///
669 /// # Arguments
670 ///
671 /// * `lhs` - The left-hand side tensor.
672 /// * `rhs` - The right-hand side tensor.
673 /// * `out_dtype` - The output tensor dtype.
674 ///
675 /// # Returns
676 ///
677 /// A boolean tensor with the result of the comparison.
678 fn float_greater(
679 lhs: FloatTensor<B>,
680 rhs: FloatTensor<B>,
681 out_dtype: BoolDType,
682 ) -> BoolTensor<B>;
683
684 /// Greater than comparison of a tensor and a scalar.
685 ///
686 /// # Arguments
687 ///
688 /// * `lhs` - The left-hand side tensor.
689 /// * `rhs` - The right-hand side scalar.
690 /// * `out_dtype` - The output tensor dtype.
691 ///
692 /// # Returns
693 ///
694 /// A boolean tensor with the result of the comparison.
695 fn float_greater_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
696
697 /// Greater than or equal comparison of two tensors.
698 ///
699 /// # Arguments
700 ///
701 /// * `lhs` - The left-hand side tensor.
702 /// * `rhs` - The right-hand side tensor.
703 /// * `out_dtype` - The output tensor dtype.
704 ///
705 /// # Returns
706 ///
707 /// A boolean tensor with the result of the comparison.
708 fn float_greater_equal(
709 lhs: FloatTensor<B>,
710 rhs: FloatTensor<B>,
711 out_dtype: BoolDType,
712 ) -> BoolTensor<B>;
713
714 /// Greater than or equal comparison of a tensor and a scalar.
715 ///
716 /// # Arguments
717 ///
718 /// * `lhs` - The left-hand side tensor.
719 /// * `rhs` - The right-hand side scalar.
720 /// * `out_dtype` - The output tensor dtype.
721 ///
722 /// # Returns
723 ///
724 /// A boolean tensor with the result of the comparison.
725 fn float_greater_equal_elem(
726 lhs: FloatTensor<B>,
727 rhs: Scalar,
728 out_dtype: BoolDType,
729 ) -> BoolTensor<B>;
730
731 /// Less than comparison of two tensors.
732 ///
733 /// # Arguments
734 ///
735 /// * `lhs` - The left-hand side tensor.
736 /// * `rhs` - The right-hand side tensor.
737 /// * `out_dtype` - The output tensor dtype.
738 ///
739 /// # Returns
740 ///
741 /// A boolean tensor with the result of the comparison.
742 fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>, out_dtype: BoolDType)
743 -> BoolTensor<B>;
744
745 /// Less than comparison of a tensor and a scalar.
746 ///
747 /// # Arguments
748 ///
749 /// * `lhs` - The left-hand side tensor.
750 /// * `rhs` - The right-hand side scalar.
751 /// * `out_dtype` - The output tensor dtype.
752 ///
753 /// # Returns
754 ///
755 /// A boolean tensor with the result of the comparison.
756 fn float_lower_elem(lhs: FloatTensor<B>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<B>;
757
758 /// Less than or equal comparison of two tensors.
759 ///
760 /// # Arguments
761 ///
762 /// * `lhs` - The left-hand side tensor.
763 /// * `rhs` - The right-hand side tensor.
764 /// * `out_dtype` - The output tensor dtype.
765 ///
766 /// # Returns
767 ///
768 /// A boolean tensor with the result of the comparison.
769 fn float_lower_equal(
770 lhs: FloatTensor<B>,
771 rhs: FloatTensor<B>,
772 out_dtype: BoolDType,
773 ) -> BoolTensor<B>;
774
775 /// Less than or equal comparison of a tensor and a scalar.
776 ///
777 /// # Arguments
778 ///
779 /// * `lhs` - The left-hand side tensor.
780 /// * `rhs` - The right-hand side scalar.
781 /// * `out_dtype` - The output tensor dtype.
782 ///
783 /// # Returns
784 ///
785 /// A boolean tensor with the result of the comparison.
786 fn float_lower_equal_elem(
787 lhs: FloatTensor<B>,
788 rhs: Scalar,
789 out_dtype: BoolDType,
790 ) -> BoolTensor<B>;
791
792 /// Detaches a tensor from the computation graph.
793 fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
794 // Should only be overridden by autodiff backends.
795 tensor
796 }
797
798 /// Sets the `require_grad` flag of a tensor.
799 fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
800 // Should only be overridden by autodiff backends.
801 tensor
802 }
803
804 /// Returns the `require_grad` flag of a tensor.
805 fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
806 // Should only be overridden by autodiff backends.
807 false
808 }
809
810 /// Sum of all elements in a tensor.
811 ///
812 /// # Arguments
813 ///
814 /// * `tensor` - The tensor to sum.
815 ///
816 /// # Returns
817 ///
818 /// A scalar tensor with the sum of all elements in `tensor`.
819 fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;
820
821 /// Sum of all elements in a tensor along a dimension.
822 ///
823 /// # Arguments
824 ///
825 /// * `tensor` - The tensor to sum.
826 /// * `dim` - The dimension along which to sum.
827 ///
828 /// # Returns
829 ///
830 /// A tensor with the sum of all elements in `tensor` along `dim`.
831 fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
832
833 /// Product of all elements in a tensor.
834 ///
835 /// # Arguments
836 ///
837 /// * `tensor` - The tensor to product.
838 ///
839 /// # Returns
840 ///
841 /// A scalar tensor with the product of all elements in `tensor`.
842 fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
843 // Product of all elements in a tensor
844 B::float_exp(B::float_sum(B::float_log(tensor)))
845 }
846
847 /// Product of all elements in a tensor along a dimension.
848 ///
849 /// # Arguments
850 ///
851 /// * `tensor` - The tensor to product.
852 ///
853 /// # Returns
854 ///
855 /// A tensor with the product of all elements in `tensor` along `dim`.
856 fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
857 // Product of all elements in a tensor along a dimension
858 B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
859 }
860
861 /// Mean of all elements in a tensor.
862 ///
863 /// # Arguments
864 ///
865 /// * `tensor` - The tensor to mean.
866 ///
867 /// # Returns
868 ///
869 /// A scalar tensor with the mean of all elements in `tensor`.
870 fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
871 let num_elems = tensor.shape().num_elements() as f32;
872 B::float_div_scalar(B::float_sum(tensor), num_elems.into())
873 }
874
875 /// Mean of all elements in a tensor along a dimension.
876 ///
877 /// # Arguments
878 ///
879 /// * `tensor` - The tensor to mean.
880 /// * `dim` - The dimension along which to mean.
881 ///
882 /// # Returns
883 ///
884 /// A tensor with the mean of all elements in `tensor` along `dim`.
885 fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
886
887 /// Computes the cumulative sum of elements along a dimension.
888 ///
889 /// # Arguments
890 ///
891 /// * `tensor` - The tensor to compute the cumulative sum of.
892 /// * `dim` - The dimension along which to compute the cumulative sum.
893 ///
894 /// # Returns
895 ///
896 /// A tensor with the same shape where each element is the cumulative sum
897 /// of all elements up to and including that position along the dimension.
898 fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
899
900 /// Computes the cumulative product of elements along a dimension.
901 ///
902 /// # Arguments
903 ///
904 /// * `tensor` - The tensor to compute the cumulative product of.
905 /// * `dim` - The dimension along which to compute the cumulative product.
906 ///
907 /// # Returns
908 ///
909 /// A tensor with the same shape where each element is the cumulative product
910 /// of all elements up to and including that position along the dimension.
911 fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
912
913 /// Computes the cumulative minimum of elements along a dimension.
914 ///
915 /// # Arguments
916 ///
917 /// * `tensor` - The tensor to compute the cumulative minimum of.
918 /// * `dim` - The dimension along which to compute the cumulative minimum.
919 ///
920 /// # Returns
921 ///
922 /// A tensor with the same shape where each element is the minimum
923 /// of all elements up to and including that position along the dimension.
924 fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
925
926 /// Computes the cumulative maximum of elements along a dimension.
927 ///
928 /// # Arguments
929 ///
930 /// * `tensor` - The tensor to compute the cumulative maximum of.
931 /// * `dim` - The dimension along which to compute the cumulative maximum.
932 ///
933 /// # Returns
934 ///
935 /// A tensor with the same shape where each element is the maximum
936 /// of all elements up to and including that position along the dimension.
937 fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
938
939 /// Converts a tensor to another floating point data type.
940 ///
941 /// # Arguments
942 ///
943 /// * `tensor` - The tensor to convert.
944 /// * `dtype` - The target data type.
945 ///
946 /// # Returns
947 ///
948 /// A tensor with the same values as `tensor` but in the target floating point data type.
949 fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;
950
951 /// Returns a new tensor with exponential values.
952 ///
953 /// # Arguments
954 ///
955 /// * `tensor` - The tensor to exponentiate.
956 ///
957 /// # Returns
958 ///
959 /// A tensor with the same shape as `tensor` with exponential values.
960 fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;
961
962 /// Returns a new tensor with natural logarithm values.
963 ///
964 /// # Arguments
965 ///
966 /// * `tensor` - The tensor to take the logarithm of.
967 ///
968 /// # Returns
969 ///
970 /// A tensor with the same shape as `tensor` with natural logarithm values.
971 fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;
972
973 /// Returns a new tensor with logarithm values of (1 + Xi).
974 ///
975 /// # Arguments
976 ///
977 /// * `tensor` - The tensor to take the logarithm of.
978 ///
979 /// # Returns
980 ///
981 /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
982 fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;
983
984 /// Element-wise power with a FloatTensor.
985 ///
986 /// # Arguments
987 ///
988 /// * `lhs` - The left-hand side tensor.
989 /// * `rhs` - The right-hand side tensor.
990 ///
991 /// # Returns
992 ///
993 /// The elements of `lhs` raised to the power of the elements of `rhs`.
994 fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
995
996 /// Element-wise power with an IntTensor.
997 ///
998 /// # Arguments
999 ///
1000 /// * `lhs` - The left-hand side tensor.
1001 /// * `rhs` - The right-hand side floatTensor.
1002 ///
1003 /// # Returns
1004 ///
1005 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
1006 fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
1007 let dtype = lhs.dtype();
1008 Self::float_powf(lhs, B::int_into_float(rhs, dtype.into()))
1009 }
1010
1011 /// Raises a tensor to the power of an int scalar.
1012 ///
1013 /// # Backend Implementors Note
1014 ///
1015 /// A number of common exponent cases can be implemented with operations
1016 /// which are much cheaper than generic exponentiation.
1017 ///
1018 /// This (`Backend` impl overridable) operation handles generic optimizations
1019 /// for several common integer exponent cases; and then dispatches to
1020 /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]
1021 /// operation to handle the generic case.
1022 ///
1023 /// # Arguments
1024 ///
1025 /// * `lhs` - The left-hand side tensor.
1026 /// * `rhs` - The right-hand side scalar.
1027 ///
1028 /// # Returns
1029 ///
1030 /// The elements of `lhs` raised to the value of `rhs`.
1031 fn float_powi_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
1032 match rhs.elem::<i64>() {
1033 0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),
1034 1 => lhs,
1035 2 => B::float_mul(lhs.clone(), lhs),
1036 -1 => Self::float_recip(lhs),
1037 -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),
1038 _ => Self::float_powi_scalar_impl(lhs, rhs),
1039 }
1040 }
1041
1042 /// Raises a tensor to the power of an int scalar.
1043 ///
1044 /// # Backend Implementors Note
1045 ///
1046 /// This is the generic implementation of integer exponentiation
1047 /// called by [`Self::float_powi_scalar`] in the fallback case.
1048 ///
1049 /// As a general rule, this should not be called directly.
1050 ///
1051 /// # Arguments
1052 ///
1053 /// * `lhs` - The left-hand side tensor.
1054 /// * `rhs` - The right-hand side scalar.
1055 ///
1056 /// # Returns
1057 ///
1058 /// The elements of `lhs` raised to the value of `rhs`.
1059 fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {
1060 // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.
1061 Self::float_powf_scalar_impl(lhs, rhs)
1062 }
1063
1064 /// Returns a new tensor with values raised to the power of float `value`.
1065 ///
1066 /// # Backend Implementors Note
1067 ///
1068 /// This (`Backend` impl overridable) operation dispatches integer exponentiation
1069 /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to
1070 /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]
1071 /// operation to handle the generic case.
1072 ///
1073 /// # Arguments
1074 ///
1075 /// * `tensor` - The tensor to exponentiate.
1076 /// * `value` - The exponent.
1077 ///
1078 /// # Returns
1079 ///
1080 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
1081 fn float_powf_scalar(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B> {
1082 if let Some(exp) = value.try_as_integer() {
1083 Self::float_powi_scalar(tensor, exp)
1084 } else {
1085 Self::float_powf_scalar_impl(tensor, value)
1086 }
1087 }
1088
1089 /// Returns a new tensor with values raised to the power of float `value`.
1090 ///
1091 /// # Backend Implementors Note
1092 ///
1093 /// This is the generic implementation of integer exponentiation
1094 /// called by [`Self::float_powf_scalar`] in the fallback case.
1095 ///
1096 /// This is the minimal required support a `Backend` must implement
1097 /// for exponentiation.
1098 ///
1099 /// As a general rule, this should not be called directly.
1100 ///
1101 /// # Arguments
1102 ///
1103 /// * `tensor` - The tensor to exponentiate.
1104 /// * `value` - The exponent.
1105 ///
1106 /// # Returns
1107 ///
1108 /// A tensor with the same shape as `tensor` with values raised to the power of `value`.
1109 fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B>;
1110
/// Returns a new tensor with the square root of each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the square root of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise square roots.
fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;
1121
/// Returns a new tensor with the absolute value of each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the absolute value of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise absolute values.
fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;
1132
/// Returns a new tensor with the cosine of each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the cosine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise cosines.
fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1143
/// Returns a new tensor with the sine of each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the sine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise sines.
fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1154
/// Returns a new tensor with the tangent of each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the tangent of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise tangents.
fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1165
/// Returns a new tensor with the hyperbolic cosine of each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the hyperbolic cosine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise hyperbolic cosines.
fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1176
/// Returns a new tensor with the hyperbolic sine of each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the hyperbolic sine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise hyperbolic sines.
fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1187
/// Returns a new tensor with the hyperbolic tangent of each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the hyperbolic tangent of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise hyperbolic tangents.
fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1198
/// Returns a new tensor with the inverse cosine (arccosine) of each element.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise inverse cosines.
fn float_acos(tensor: FloatTensor<B>) -> FloatTensor<B>;
1209
/// Returns a new tensor with the inverse hyperbolic cosine of each element.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise inverse hyperbolic
/// cosines.
fn float_acosh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1220
/// Returns a new tensor with the inverse sine (arcsine) of each element.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise inverse sines.
fn float_asin(tensor: FloatTensor<B>) -> FloatTensor<B>;
1231
/// Returns a new tensor with the inverse hyperbolic sine of each element.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise inverse hyperbolic
/// sines.
fn float_asinh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1242
/// Returns a new tensor with the inverse tangent (arctangent) of each element.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise inverse tangents.
fn float_atan(tensor: FloatTensor<B>) -> FloatTensor<B>;
1253
/// Returns a new tensor with the inverse hyperbolic tangent of each element.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise inverse hyperbolic
/// tangents.
fn float_atanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
1264
/// Returns a tensor with the four-quadrant inverse tangent of `y` and `x`.
///
/// `lhs` supplies the `y` coordinates and `rhs` the `x` coordinates, so each output
/// element is `atan2(y, x)`; the signs of both inputs select the quadrant.
///
/// # Arguments
///
/// * `lhs` - The tensor with y coordinates.
/// * `rhs` - The tensor with x coordinates.
///
/// # Returns
///
/// A tensor with the four-quadrant inverse tangent values.
fn float_atan2(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
1276
/// Returns a new tensor with each element rounded to the nearest integer value.
///
/// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
/// strategy: values exactly halfway between two integers round to the even one
/// (e.g. `0.5 -> 0`, `1.5 -> 2`, `2.5 -> 2`).
///
/// # Arguments
///
/// * `tensor` - The tensor to be rounded.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with rounded values.
fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
1290
/// Returns a new tensor with each element rounded down (toward negative infinity).
///
/// # Arguments
///
/// * `tensor` - The tensor to be floored.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with floored values.
fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
1301
/// Returns a new tensor with each element rounded up (toward positive infinity).
///
/// # Arguments
///
/// * `tensor` - The tensor to be ceiled.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with ceiled values.
fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
1312
/// Returns a new tensor with each element truncated (rounded toward zero).
///
/// # Arguments
///
/// * `tensor` - The tensor to be truncated.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with truncated values.
fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;
1323
/// Returns a new tensor with the (Gauss) error function applied to each element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the error function of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise error function values.
fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1334
1335 /// Concatenates tensors along a dimension.
1336 ///
1337 /// # Arguments
1338 ///
1339 /// * `tensors` - The tensors to concatenate.
1340 /// * `dim` - The dimension along which to concatenate.
1341 ///
1342 /// # Returns
1343 ///
1344 /// A tensor with the concatenated tensors along `dim`.
1345 ///
1346 /// # Note
1347 ///
1348 /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the
1349 /// high-level tensor API and will not be passed to this method. Backend implementations do
1350 /// not need to handle empty tensors.
1351 fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1352 cat_with_slice_assign::<B, Float>(
1353 tensors.into_iter().map(TensorPrimitive::Float).collect(),
1354 dim,
1355 )
1356 .tensor()
1357 }
1358
/// Gets the indices of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
/// * `out_dtype` - The integer dtype of the returned index tensor.
///
/// # Returns
///
/// An integer tensor with the indices of the maximum elements of `tensor` along `dim`.
fn float_argmax(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1371
/// Gets the indices of the `k` largest elements of a tensor along an axis.
///
/// When two elements are equal, the one with the lower index comes first.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
/// * `k` - The number of maximum elements to select.
/// * `out_dtype` - The integer dtype of the returned index tensor.
///
/// # Returns
///
/// An integer tensor with the indices of the `k` maximum elements of `tensor` along `dim`.
fn float_argtopk(
    tensor: FloatTensor<B>,
    dim: usize,
    k: usize,
    out_dtype: IntDType,
) -> IntTensor<B>;
1391
1392 /// Gets the values of the k maximum elements of a tensor along an axis.
1393 ///
1394 /// # Arguments
1395 ///
1396 /// * `tensor` - The tensor to get the maximum elements of.
1397 /// * `dim` - The dimension along which to get the maximum elements.
1398 /// * `k` - number of maximum elements
1399 /// * `out_dtype` - The output tensor dtype.
1400 ///
1401 /// # Returns
1402 ///
1403 /// A tensor with the values of the maximum elements of `tensor` along `dim`.
1404 fn float_topk(tensor: FloatTensor<B>, dim: usize, k: usize) -> FloatTensor<B>;
1405
/// Gets the indices of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements of.
/// * `dim` - The dimension along which to get the minimum elements.
/// * `out_dtype` - The integer dtype of the returned index tensor.
///
/// # Returns
///
/// An integer tensor with the indices of the minimum elements of `tensor` along `dim`.
fn float_argmin(tensor: FloatTensor<B>, dim: usize, out_dtype: IntDType) -> IntTensor<B>;
1418
1419 /// Gets the maximum element of a tensor.
1420 ///
1421 /// # Arguments
1422 ///
1423 /// * `tensor` - The tensor to get the maximum elements of.
1424 ///
1425 /// # Returns
1426 ///
1427 /// A tensor with the maximum element of `tensor`.
1428 fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1429 let shape = tensor.shape();
1430 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1431
1432 B::float_max_dim(tensor, 0)
1433 }
1434
1435 /// Gets the maximum elements of a tensor along an axis.
1436 ///
1437 /// # Arguments
1438 ///
1439 /// * `tensor` - The tensor to get the maximum elements of.
1440 /// * `dim` - The dimension along which to get the maximum elements.
1441 ///
1442 /// # Returns
1443 ///
1444 /// A tensor with the maximum elements of `tensor` along `dim`.
1445 fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1446 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1447 let index = B::float_argmax(tensor.clone(), dim, dtype);
1448
1449 B::float_gather(dim, tensor, index)
1450 }
1451
1452 /// Gets the maximum elements of a tensor along an axis and their indices.
1453 ///
1454 /// # Arguments
1455 ///
1456 /// * `tensor` - The tensor to get the maximum elements of.
1457 /// * `dim` - The dimension along which to get the maximum elements.
1458 /// * `indices_dtype` - The indices tensor dtype.
1459 ///
1460 /// # Returns
1461 ///
1462 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1463 fn float_max_dim_with_indices(
1464 tensor: FloatTensor<B>,
1465 dim: usize,
1466 indices_dtype: IntDType,
1467 ) -> (FloatTensor<B>, IntTensor<B>) {
1468 let index = B::float_argmax(tensor.clone(), dim, indices_dtype);
1469 let values = B::float_gather(dim, tensor, index.clone());
1470
1471 (values, index)
1472 }
1473
1474 /// Gets the minimum element of a tensor.
1475 ///
1476 /// # Arguments
1477 ///
1478 /// * `tensor` - The tensor to get the minimum elements of.
1479 ///
1480 /// # Returns
1481 ///
1482 /// A tensor with the minimum element of `tensor`.
1483 fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1484 let shape = tensor.shape();
1485 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1486
1487 B::float_min_dim(tensor, 0)
1488 }
1489
1490 /// Gets the minimum elements of a tensor along an axis.
1491 ///
1492 /// # Arguments
1493 ///
1494 /// * `tensor` - The tensor to get the minimum elements of.
1495 /// * `dim` - The dimension along which to get the minimum elements.
1496 ///
1497 /// # Returns
1498 ///
1499 /// A tensor with the minimum elements of `tensor` along `dim`.
1500 fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1501 let dtype = get_device_settings::<B>(&B::float_device(&tensor)).int_dtype;
1502 let index = B::float_argmin(tensor.clone(), dim, dtype);
1503
1504 B::float_gather(dim, tensor, index)
1505 }
1506
1507 /// Gets the minimum elements of a tensor along an axis and their indices.
1508 ///
1509 /// # Arguments
1510 ///
1511 /// * `tensor` - The tensor to get the minimum elements of.
1512 /// * `dim` - The dimension along which to get the minimum elements.
1513 /// * `indices_dtype` - The indices tensor dtype.
1514 ///
1515 /// # Returns
1516 ///
1517 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1518 fn float_min_dim_with_indices(
1519 tensor: FloatTensor<B>,
1520 dim: usize,
1521 indices_dtype: IntDType,
1522 ) -> (FloatTensor<B>, IntTensor<B>) {
1523 let index = B::float_argmin(tensor.clone(), dim, indices_dtype);
1524 let values = B::float_gather(dim, tensor, index.clone());
1525
1526 (values, index)
1527 }
1528
1529 /// Gets the maximum absolute element of a tensor.
1530 ///
1531 /// # Arguments
1532 ///
1533 /// * `tensor` - The tensor to get the maximum elements of.
1534 ///
1535 /// # Returns
1536 ///
1537 /// A tensor with the maximum element of `tensor`.
1538 fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1539 let shape = tensor.shape();
1540 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1541
1542 B::float_max_abs_dim(tensor, 0)
1543 }
1544
1545 /// Gets the maximum absolute elements of a tensor along an axis.
1546 ///
1547 /// # Arguments
1548 ///
1549 /// * `tensor` - The tensor to get the maximum elements of.
1550 /// * `dim` - The dimension along which to get the maximum elements.
1551 ///
1552 /// # Returns
1553 ///
1554 /// A tensor with the maximum elements of `tensor` along `dim`.
1555 fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1556 B::float_max_dim(B::float_abs(tensor), dim)
1557 }
1558
1559 /// Tests if any element in the float `tensor` evaluates to True.
1560 ///
1561 /// # Arguments
1562 ///
1563 /// * `tensor` - The tensor to test.
1564 /// * `out_dtype` - The output tensor dtype.
1565 ///
1566 /// # Returns
1567 ///
1568 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1569 fn float_any(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1570 let float_dtype = tensor.dtype();
1571 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1572 let bool_tensor = B::bool_not(bool_tensor);
1573 let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1574 B::float_greater_elem(sum, 0f32.into(), out_dtype)
1575 }
1576
1577 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1578 ///
1579 /// # Arguments
1580 ///
1581 /// * `tensor` - The tensor to test.
1582 /// * `dim` - The axis along which to test.
1583 /// * `out_dtype` - The output tensor dtype.
1584 ///
1585 /// # Returns
1586 ///
1587 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1588 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1589 /// input evaluates to True, False otherwise.
1590 fn float_any_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1591 let float_dtype = tensor.dtype();
1592 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1593 let bool_tensor = B::bool_not(bool_tensor);
1594 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1595 B::float_greater_elem(sum, 0f32.into(), out_dtype)
1596 }
1597
1598 /// Tests if all elements in the float `tensor` evaluate to True.
1599 ///
1600 /// # Arguments
1601 ///
1602 /// * `tensor` - The tensor to test.
1603 /// * `out_dtype` - The output tensor dtype.
1604 ///
1605 /// # Returns
1606 ///
1607 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1608 /// evaluate to True, False otherwise.
1609 fn float_all(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1610 let float_dtype = tensor.dtype();
1611 let num_elems = tensor.shape().num_elements() as f32;
1612 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1613 let bool_tensor = B::bool_not(bool_tensor);
1614 let sum = B::float_sum(B::bool_into_float(bool_tensor, float_dtype.into()));
1615 B::float_equal_elem(sum, num_elems.into(), out_dtype)
1616 }
1617
1618 /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1619 ///
1620 /// # Arguments
1621 ///
1622 /// * `tensor` - The tensor to test.
1623 /// * `dim` - The axis along which to test.
1624 /// * `out_dtype` - The output tensor dtype.
1625 ///
1626 /// # Returns
1627 ///
1628 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1629 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1630 /// evaluates to True, False otherwise.
1631 fn float_all_dim(tensor: FloatTensor<B>, dim: usize, out_dtype: BoolDType) -> BoolTensor<B> {
1632 let float_dtype = tensor.dtype();
1633 let num_elems = tensor.shape()[dim] as f32;
1634 let bool_tensor = B::float_equal_elem(tensor, 0f32.into(), out_dtype);
1635 let bool_tensor = B::bool_not(bool_tensor);
1636 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor, float_dtype.into()), dim);
1637 B::float_equal_elem(sum, num_elems.into(), out_dtype)
1638 }
1639
1640 /// Returns the signs of the float `tensor`.
1641 ///
1642 /// # Arguments
1643 ///
1644 /// * `tensor` - The tensor to extract the signs from.
1645 ///
1646 /// # Returns
1647 ///
1648 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1649 fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1650 let device = B::float_device(&tensor);
1651 let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
1652 let zeros = B::float_zeros(tensor.shape(), &device, tensor.dtype().into());
1653 let less_than_zero = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
1654 let greater_than_zero = B::float_greater_elem(tensor, 0f32.into(), bool_dtype);
1655
1656 let mut result = B::float_mask_fill(zeros, less_than_zero, (-1f32).into());
1657 result = B::float_mask_fill(result, greater_than_zero, 1f32.into());
1658 result
1659 }
1660
/// Broadcasts the float `tensor` to the given `shape`.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The target shape.
///
/// # Returns
///
/// A tensor with shape `shape`.
// NOTE(review): the exact broadcasting rules (which dims may be expanded, and whether the
// result is a view or a copy) are backend-defined and not visible here; confirm.
fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1663
1664 /// Sort the elements of the input `tensor` by value in along a given dimension.
1665 ///
1666 /// This sort is unstable (i.e., may reorder equal elements).
1667 ///
1668 /// # Arguments
1669 ///
1670 /// * `tensor` - The input tensor.
1671 /// * `dim` - The axis along which to sort.
1672 /// * `descending` - The sorting order.
1673 ///
1674 /// # Returns
1675 ///
1676 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1677 fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1678 sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1679 }
1680
1681 /// Sort the elements of the input `tensor` by value in along a given dimension.
1682 ///
1683 /// This sort is unstable (i.e., may reorder equal elements).
1684 ///
1685 /// # Arguments
1686 ///
1687 /// * `tensor` - The input tensor.
1688 /// * `dim` - The axis along which to sort.
1689 /// * `descending` - The sorting order.
1690 /// * `indices_dtype` - The indices tensor dtype.
1691 ///
1692 /// # Returns
1693 ///
1694 /// A tensor with the same shape as the input tensor and corresponding indices, where
1695 /// the elements are sorted by value and the indices map back to the original input tensor.
1696 fn float_sort_with_indices(
1697 tensor: FloatTensor<B>,
1698 dim: usize,
1699 descending: bool,
1700 indices_dtype: IntDType,
1701 ) -> (FloatTensor<B>, IntTensor<B>) {
1702 let (values, indices) = sort_with_indices::<B, Float>(
1703 TensorPrimitive::Float(tensor),
1704 dim,
1705 descending,
1706 indices_dtype,
1707 );
1708 (values.tensor(), indices)
1709 }
1710
1711 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1712 ///
1713 /// This sort is unstable (i.e., may reorder equal elements).
1714 ///
1715 /// # Arguments
1716 ///
1717 /// * `tensor` - The input tensor.
1718 /// * `dim` - The axis along which to sort.
1719 /// * `descending` - The sorting order.
1720 /// * `out_dtype` - The output tensor dtype.
1721 ///
1722 /// # Returns
1723 ///
1724 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1725 fn float_argsort(
1726 tensor: FloatTensor<B>,
1727 dim: usize,
1728 descending: bool,
1729 out_dtype: IntDType,
1730 ) -> IntTensor<B> {
1731 argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending, out_dtype)
1732 }
1733
1734 /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,
1735 /// using the given locations in [-1, 1].
1736 ///
1737 /// # Arguments
1738 ///
1739 /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
1740 /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
1741 /// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
1742 /// * `options` - Grid sampling options (mode, padding_mode, align_corners)
1743 ///
1744 /// # Returns
1745 ///
1746 /// A tensor with shape (N, C, H_out, W_out)
1747 fn float_grid_sample_2d(
1748 tensor: FloatTensor<B>,
1749 grid: FloatTensor<B>,
1750 options: GridSampleOptions,
1751 ) -> FloatTensor<B> {
1752 // TODO: default impl should get int default dtype
1753 float_grid_sample_2d_ref::<B>(tensor, grid, options)
1754 }
1755
1756 /// Unfold windows along a dimension.
1757 ///
1758 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
1759 /// where windows are advanced by `step` at each index.
1760 ///
1761 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
1762 ///
1763 /// # Arguments
1764 ///
1765 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
1766 /// * `dim` - the selected dim.
1767 /// * `size` - the size of each unfolded window.
1768 /// * `step` - the step between each window.
1769 ///
1770 /// # Returns
1771 ///
1772 /// A tensor view with shape ``[pre=..., windows, size, post=...]``.
1773 fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)
1774 -> FloatTensor<B>;
1775
1776 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
1777 ///
1778 /// # Returns
1779 ///
1780 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
1781 fn float_is_nan(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1782 // Check if the input tensor is NaN by comparing it to itself
1783 // NaN is the only value that is not equal to itself
1784 B::float_not_equal(tensor.clone(), tensor, out_dtype)
1785 }
1786
1787 /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).
1788 ///
1789 /// # Returns
1790 ///
1791 /// A boolean tensor where `true` indicates that the value is infinite
1792 fn float_is_inf(tensor: FloatTensor<B>, out_dtype: BoolDType) -> BoolTensor<B> {
1793 B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.into(), out_dtype)
1794 }
1795}