burn_tensor/tensor/ops/tensor.rs
1use super::cat::cat_with_slice_assign;
2use super::repeat_dim::repeat_with_slice_assign;
3use super::{BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor};
4use crate::tensor::cast::ToElement;
5use crate::{Distribution, ElementConversion, Float, TensorData, backend::Backend, tensor::Shape};
6use crate::{
7 FloatDType, TensorMetadata, TensorPrimitive, tensor::api::chunk, tensor::api::narrow,
8 tensor::api::split, tensor::api::split_with_sizes,
9};
10use alloc::vec::Vec;
11use core::future::Future;
12use core::ops::Range;
13
14use crate::{argsort, sort, sort_with_indices};
15
16/// Operations on float tensors.
17pub trait FloatTensorOps<B: Backend> {
/// Creates a new tensor from the data structure.
///
/// # Arguments
///
/// * `data` - The data structure holding the tensor's values.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given data.
fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;

/// Creates a new tensor with random values.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `distribution` - The distribution each element is sampled from.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given shape and random values.
fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)
    -> FloatTensor<B>;
43
44 /// Creates a new tensor with zeros.
45 ///
46 /// # Arguments
47 ///
48 /// * `shape` - The shape of the tensor.
49 /// * `device` - The device to create the tensor on.
50 ///
51 /// # Returns
52 ///
53 /// The tensor with the given shape and zeros.
54 fn float_zeros(shape: Shape, device: &Device<B>) -> FloatTensor<B> {
55 Self::float_from_data(TensorData::zeros::<FloatElem<B>, _>(shape), device)
56 }
57
58 /// Creates a new tensor with ones.
59 ///
60 /// # Arguments
61 ///
62 /// * `shape` - The shape of the tensor.
63 /// * `device` - The device to create the tensor on.
64 ///
65 /// # Returns
66 ///
67 /// The tensor with the given shape and ones.
68 fn float_ones(shape: Shape, device: &Device<B>) -> FloatTensor<B> {
69 Self::float_from_data(TensorData::ones::<FloatElem<B>, _>(shape), device)
70 }
71
72 /// Creates a tensor filled with given value.
73 ///
74 /// # Arguments
75 ///
76 /// * `shape` - The shape of the tensor.
77 /// * `fill_value` - The value with which to fill the tensor.
78 /// * `device` - The device to create the tensor on.
79 ///
80 /// # Returns
81 ///
82 /// The tensor filled with given value
83 fn float_full(shape: Shape, fill_value: FloatElem<B>, device: &Device<B>) -> FloatTensor<B> {
84 Self::float_add_scalar(Self::float_zeros(shape, device), fill_value)
85 }
86
/// Converts the tensor to a data structure.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// A `'static + Send` future that resolves to the data structure with the tensor's data;
/// it must be awaited (or otherwise driven) to obtain the data.
fn float_into_data(tensor: FloatTensor<B>)
    -> impl Future<Output = TensorData> + 'static + Send;

/// Gets the device of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The device of the tensor.
fn float_device(tensor: &FloatTensor<B>) -> Device<B>;

/// Moves the tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device to move the tensor to.
///
/// # Returns
///
/// The tensor on the given device.
fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;

/// Converts float tensor to int tensor.
///
/// NOTE(review): how fractional values are converted (truncation vs. rounding) is
/// decided by each backend and is not specified here; confirm before relying on it.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The int tensor with the same data as the float tensor.
fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;

/// Creates an empty tensor with the given shape.
///
/// NOTE(review): the element values of the returned tensor are presumably unspecified,
/// so callers should overwrite them before reading; confirm per backend.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The empty tensor with the given shape.
fn float_empty(shape: Shape, device: &Device<B>) -> FloatTensor<B>;
144
145 /// Repeat the tensor along the given dimension.
146 ///
147 /// # Arguments
148 ///
149 /// * `tensor` - The tensor.
150 /// * `dim` - The dimension to repeat.
151 /// * `times` - The number of times to repeat the dimension.
152 ///
153 /// # Returns
154 ///
155 /// The tensor with the given dimension repeated.
156 fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {
157 repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()
158 }
159
/// Adds two tensors together element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of adding the two tensors together.
fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

/// Adds a scalar to every element of a tensor.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of adding the scalar to the tensor.
fn float_add_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
183
184 /// Clamps a tensor under a minimum value.
185 ///
186 /// # Arguments
187 ///
188 /// * `tensor` - The tensor to clamp.
189 /// * `min` - The minimum value.
190 ///
191 /// # Returns
192 ///
193 /// The clamped tensor.
194 fn float_clamp_min(tensor: FloatTensor<B>, min: FloatElem<B>) -> FloatTensor<B> {
195 // Default implementation
196 let mask = Self::float_lower_elem(tensor.clone(), min);
197 B::float_mask_fill(tensor, mask, min)
198 }
199
200 /// Clamps a tensor over a maximum value.
201 ///
202 /// # Arguments
203 ///
204 /// * `tensor` - The tensor to clamp.
205 /// * `max` - The maximum value.
206 ///
207 /// # Returns
208 ///
209 /// The clamped tensor.
210 fn float_clamp_max(tensor: FloatTensor<B>, max: FloatElem<B>) -> FloatTensor<B> {
211 // Default implementation
212 let mask = Self::float_greater_elem(tensor.clone(), max);
213 B::float_mask_fill(tensor, mask, max)
214 }
215
216 /// Clamps a tensor between a minimum and maximum value.
217 ///
218 /// # Arguments
219 ///
220 /// * `tensor` - The tensor to clamp.
221 /// * `min` - The minimum value.
222 /// * `max` - The maximum value.
223 ///
224 /// # Returns
225 ///
226 /// The clamped tensor.
227 fn float_clamp(tensor: FloatTensor<B>, min: FloatElem<B>, max: FloatElem<B>) -> FloatTensor<B> {
228 // Default implementation
229 Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)
230 }
231
/// Subtracts two tensors element-wise (`lhs - rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of subtracting the two tensors.
fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

/// Subtracts a scalar from every element of a tensor.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of subtracting the scalar from the tensor.
fn float_sub_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;

/// Multiplies two tensors together element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The element-wise product of the two tensors.
fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

/// Multiplies every element of a tensor by a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of multiplying the tensor by the scalar.
fn float_mul_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;

/// Divides two tensors element-wise (`lhs / rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of dividing the two tensors.
fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

/// Divides every element of a tensor by a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of dividing the tensor by the scalar.
fn float_div_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;
294
/// Computes the remainder of division between two tensors element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The element-wise remainder when dividing `lhs` by `rhs`.
fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;

/// Computes the remainder of dividing every element of a tensor by a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The element-wise remainder when dividing `lhs` by `rhs`.
fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> FloatTensor<B>;

/// Multiplies two tensors together using matrix multiplication.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of multiplying the two tensors together using matrix multiplication.
fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
329
330 /// Negates a tensor element-wise.
331 fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {
332 Self::float_mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
333 }
334
/// Calculates the reciprocals element-wise.
///
/// # Arguments
///
/// * `tensor` - The tensor whose elements are inverted.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding `1 / x` for every element `x`.
fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;
337
338 /// Transposes a tensor.
339 ///
340 /// # Arguments
341 ///
342 /// * `tensor` - The tensor to transpose.
343 ///
344 /// # Returns
345 ///
346 /// The transposed tensor.
347 fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {
348 let ndims = tensor.shape().num_dims();
349 Self::float_swap_dims(tensor, ndims - 2, ndims - 1)
350 }
351
/// Swaps two dimensions of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to swap the dimensions of.
/// * `dim1` - The first dimension to swap.
/// * `dim2` - The second dimension to swap.
///
/// # Returns
///
/// The tensor with the dimensions swapped.
fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;

/// Permutes the dimensions of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to permute the dimensions of.
/// * `axes` - The new order of the dimensions.
///
/// # Returns
///
/// The tensor with the dimensions permuted.
fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;

/// Reverses the order of elements in a tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to reverse.
/// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;

/// Reshapes a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to reshape.
/// * `shape` - The new shape of the tensor.
///
/// # Returns
///
/// The tensor with the new shape.
fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
397
/// Gathers elements from a tensor along a dimension.
///
/// # Arguments
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor to gather from.
/// * `indices` - The indices to gather.
///
/// # Returns
///
/// The gathered elements.
fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;

/// Scatters elements into a tensor along a dimension.
///
/// NOTE(review): whether duplicate indices overwrite or accumulate is not specified by
/// this signature; confirm against the backend implementations.
///
/// # Arguments
///
/// * `dim` - The dimension to scatter into.
/// * `tensor` - The tensor to scatter into.
/// * `indices` - The indices to scatter into.
/// * `value` - The value to scatter.
///
/// # Returns
///
/// The tensor with the scattered elements.
fn float_scatter(
    dim: usize,
    tensor: FloatTensor<B>,
    indices: IntTensor<B>,
    value: FloatTensor<B>,
) -> FloatTensor<B>;

/// Selects tensor elements along the given dimension at the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices to select.
///
/// # Returns
///
/// The selected elements.
fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;

/// Assigns the given value to the elements selected along `dim` at the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indices` - The indices to select.
/// * `value` - The value to assign.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn float_select_assign(
    tensor: FloatTensor<B>,
    dim: usize,
    indices: IntTensor<B>,
    value: FloatTensor<B>,
) -> FloatTensor<B>;
462
/// Selects the tensor elements covered by the given ranges (one range per dimension).
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `ranges` - The ranges to select.
///
/// # Returns
///
/// The selected elements in a new tensor.
fn float_slice(tensor: FloatTensor<B>, ranges: &[Range<usize>]) -> FloatTensor<B>;

/// Assigns the given value to the elements covered by the given ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `ranges` - The ranges to select.
/// * `value` - The value to assign.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn float_slice_assign(
    tensor: FloatTensor<B>,
    ranges: &[Range<usize>],
    value: FloatTensor<B>,
) -> FloatTensor<B>;

/// Updates the given tensor with the value tensor where the mask is true.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `mask` - The boolean mask to select with.
/// * `value` - The tensor whose elements are copied where `mask` is true.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn float_mask_where(
    tensor: FloatTensor<B>,
    mask: BoolTensor<B>,
    value: FloatTensor<B>,
) -> FloatTensor<B>;

/// Updates the given tensor with a scalar value where the mask is true.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `mask` - The boolean mask to select with.
/// * `value` - The value to assign to the selected elements.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn float_mask_fill(
    tensor: FloatTensor<B>,
    mask: BoolTensor<B>,
    value: FloatElem<B>,
) -> FloatTensor<B>;
525
/// Element-wise equality comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;
537
538 /// Element-wise non-equality comparison.
539 ///
540 /// # Arguments
541 ///
542 /// * `lhs` - The left hand side tensor.
543 /// * `rhs` - The right hand side tensor.
544 ///
545 /// # Returns
546 ///
547 /// A boolean tensor with the result of the comparison.
548 fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {
549 let equal_tensor = B::float_equal(lhs, rhs);
550 B::bool_not(equal_tensor)
551 }
552
/// Element-wise equality comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
564
565 /// Element-wise non-equality comparison with a scalar.
566 ///
567 /// # Arguments
568 ///
569 /// * `lhs` - The left hand side tensor.
570 /// * `rhs` - The right hand side scalar.
571 ///
572 /// # Returns
573 ///
574 /// A boolean tensor with the result of the comparison.
575 fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B> {
576 let equal_tensor = B::float_equal_elem(lhs, rhs);
577 B::bool_not(equal_tensor)
578 }
579
/// Element-wise greater-than comparison of two tensors (`lhs > rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;

/// Element-wise greater-than comparison of a tensor and a scalar (`lhs > rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_greater_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;

/// Element-wise greater-than-or-equal comparison of two tensors (`lhs >= rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;

/// Element-wise greater-than-or-equal comparison of a tensor and a scalar (`lhs >= rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
627
/// Element-wise less-than comparison of two tensors (`lhs < rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;

/// Element-wise less-than comparison of a tensor and a scalar (`lhs < rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_lower_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;

/// Element-wise less-than-or-equal comparison of two tensors (`lhs <= rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;

/// Element-wise less-than-or-equal comparison of a tensor and a scalar (`lhs <= rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: FloatElem<B>) -> BoolTensor<B>;
675
/// Detaches a tensor from the computation graph.
///
/// The default implementation returns `tensor` unchanged.
fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {
    // Should only be overridden by autodiff backends.
    tensor
}

/// Sets the `require_grad` flag of a tensor.
///
/// The default implementation ignores the flag and returns `tensor` unchanged.
fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {
    // Should only be overridden by autodiff backends.
    tensor
}

/// Returns the `require_grad` flag of a tensor.
///
/// The default implementation always returns `false`.
fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {
    // Should only be overridden by autodiff backends.
    false
}
693
/// Sum of all elements in a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
///
/// # Returns
///
/// A scalar tensor with the sum of all elements in `tensor`.
fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Sum of all elements in a tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
/// * `dim` - The dimension along which to sum.
///
/// # Returns
///
/// A tensor with the sum of all elements in `tensor` along `dim`.
fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;
716
717 /// Product of all elements in a tensor.
718 ///
719 /// # Arguments
720 ///
721 /// * `tensor` - The tensor to product.
722 ///
723 /// # Returns
724 ///
725 /// A scalar tensor with the product of all elements in `tensor`.
726 fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {
727 // Product of all elements in a tensor
728 B::float_exp(B::float_sum(B::float_log(tensor)))
729 }
730
731 /// Product of all elements in a tensor along a dimension.
732 ///
733 /// # Arguments
734 ///
735 /// * `tensor` - The tensor to product.
736 ///
737 /// # Returns
738 ///
739 /// A tensor with the product of all elements in `tensor` along `dim`.
740 fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
741 // Product of all elements in a tensor along a dimension
742 B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))
743 }
744
745 /// Mean of all elements in a tensor.
746 ///
747 /// # Arguments
748 ///
749 /// * `tensor` - The tensor to mean.
750 ///
751 /// # Returns
752 ///
753 /// A scalar tensor with the mean of all elements in `tensor`.
754 fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {
755 let num_elems = tensor.shape().num_elements();
756 B::float_div_scalar(B::float_sum(tensor), (num_elems as i64).elem())
757 }
758
/// Mean of all elements in a tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to mean.
/// * `dim` - The dimension along which to mean.
///
/// # Returns
///
/// A tensor with the mean of all elements in `tensor` along `dim`.
fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;

/// Converts a tensor to another floating point data type.
///
/// # Arguments
///
/// * `tensor` - The tensor to convert.
/// * `dtype` - The target data type.
///
/// # Returns
///
/// A tensor with the same values as `tensor` but in the target floating point data type.
fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;

/// Returns a new tensor with exponential values (`e^x`).
///
/// # Arguments
///
/// * `tensor` - The tensor to exponentiate.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with exponential values.
fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Returns a new tensor with natural logarithm values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the logarithm of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with natural logarithm values.
fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Returns a new tensor with values `ln(1 + x)` for every element `x`.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the logarithm of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with values `ln(1 + x)`.
fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Element-wise power with a FloatTensor.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor (the bases).
/// * `rhs` - The right hand side tensor (the exponents).
///
/// # Returns
///
/// The elements of `lhs` raised to the power of the elements of `rhs`.
fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;
827
828 /// Element-wise power with an IntTensor.
829 ///
830 /// # Arguments
831 ///
832 /// * `lhs` - The left hand side tensor.
833 /// * `rhs` - The right hand side floatTensor.
834 ///
835 /// # Returns
836 ///
837 /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.
838 fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {
839 Self::float_powf(lhs, B::int_into_float(rhs))
840 }
841
842 /// raises a tensor to the power of an int scalar.
843 ///
844 /// # Arguments
845 ///
846 /// * `lhs` - The left hand side tensor.
847 /// * `rhs` - The right hand side scalar.
848 ///
849 /// # Returns
850 ///
851 /// The elements of `lhs` raised to the value of `rhs`.
852 fn float_powi_scalar(lhs: FloatTensor<B>, rhs: IntElem<B>) -> FloatTensor<B> {
853 Self::float_powf_scalar(lhs, rhs.to_f32())
854 }
855
/// Returns a new tensor with values raised to the power of float `value`.
///
/// # Arguments
///
/// * `tensor` - The tensor to exponentiate.
/// * `value` - The exponent.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with values raised to the power of `value`.
fn float_powf_scalar(tensor: FloatTensor<B>, value: f32) -> FloatTensor<B>;

/// Returns a new tensor with square root values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the square root of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with square root values.
fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Returns a new tensor with absolute values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take absolute value of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with absolute values.
fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Returns a new tensor with cosine values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the cosine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with cosine values.
fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Returns a new tensor with sine values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the sine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with sine values.
fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;
911
912 /// Returns a new tensor with tangent values.
913 ///
914 /// # Arguments
915 ///
916 /// * `tensor` - The tensor to take the tangent of.
917 ///
918 /// # Returns
919 ///
920 /// A tensor with the same shape as `tensor` with tangent values.
921 fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B> {
922 let sin = B::float_sin(tensor.clone());
923 let cos = B::float_cos(tensor);
924 B::float_div(sin, cos)
925 }
926
927 /// Returns a new tensor with hyperbolic cosine values.
928 ///
929 /// # Arguments
930 ///
931 /// * `tensor` - The tensor to take the hyperbolic cosine of.
932 ///
933 /// # Returns
934 ///
935 /// A tensor with the same shape as `tensor` with hyperbolic cosine values.
936 fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B> {
937 // cosh = ( e^x + e^(-x) ) / 2
938 let e_x = B::float_exp(tensor.clone());
939 let e_neg_x = B::float_exp(B::float_neg(tensor));
940 let num = B::float_add(e_x, e_neg_x); // e^x + e^(-x)
941 B::float_div_scalar(num, 2.0.elem())
942 }
943
944 /// Returns a new tensor with hyperbolic sine values.
945 ///
946 /// # Arguments
947 ///
948 /// * `tensor` - The tensor to take the hyperbolic sine of.
949 ///
950 /// # Returns
951 ///
952 /// A tensor with the same shape as `tensor` with hyperbolic sine values.
953 fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B> {
954 // sinh = ( e^x - e^(-x) ) / 2
955 let e_x = B::float_exp(tensor.clone());
956 let e_neg_x = B::float_exp(B::float_neg(tensor));
957 let num = B::float_sub(e_x, e_neg_x); // e^x - e^(-x)
958 B::float_div_scalar(num, 2.0.elem())
959 }
960
961 /// Returns a new tensor with hyperbolic tangent values.
962 ///
963 /// # Arguments
964 ///
965 /// * `tensor` - The tensor to take the hyperbolic tangent of.
966 ///
967 /// # Returns
968 ///
969 /// A tensor with the same shape as `tensor` with hyperbolic tangent values.
970 fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B> {
971 let sinh = B::float_sinh(tensor.clone());
972 let cosh = B::float_cosh(tensor);
973 B::float_div(sinh, cosh)
974 }
975
/// Returns a new tensor with rounded values.
///
/// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
/// strategy, with halfway cases rounded to the nearest even integer value.
///
/// # Arguments
///
/// * `tensor` - The tensor to be rounded.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with rounded values.
fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Returns a new tensor with floored values (rounded toward negative infinity).
///
/// # Arguments
///
/// * `tensor` - The tensor to be floored.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with floored values.
fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Returns a new tensor with ceiled values (rounded toward positive infinity).
///
/// # Arguments
///
/// * `tensor` - The tensor to be ceiled.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with ceiled values.
fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;

/// Returns a new tensor with the error function values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the error function of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with error function values.
fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;
1022
1023 /// Concatenates tensors along a dimension.
1024 ///
1025 /// # Arguments
1026 ///
1027 /// * `tensors` - The tensors to concatenate.
1028 /// * `dim` - The dimension along which to concatenate.
1029 ///
1030 /// # Returns
1031 ///
1032 /// A tensor with the concatenated tensors along `dim`.
1033 fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {
1034 cat_with_slice_assign::<B, Float>(
1035 tensors.into_iter().map(TensorPrimitive::Float).collect(),
1036 dim,
1037 )
1038 .tensor()
1039 }
1040
/// Gets the indices of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
///
/// # Returns
///
/// A tensor with the indices of the maximum elements of `tensor` along `dim`.
fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;

/// Gets the indices of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements of.
/// * `dim` - The dimension along which to get the minimum elements.
///
/// # Returns
///
/// A tensor with the indices of the minimum elements of `tensor` along `dim`.
fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;
1064
1065 /// Gets the maximum element of a tensor.
1066 ///
1067 /// # Arguments
1068 ///
1069 /// * `tensor` - The tensor to get the maximum elements of.
1070 ///
1071 /// # Returns
1072 ///
1073 /// A tensor with the maximum element of `tensor`.
1074 fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {
1075 let shape = tensor.shape();
1076 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1077
1078 B::float_max_dim(tensor, 0)
1079 }
1080
1081 /// Gets the maximum elements of a tensor along an axis.
1082 ///
1083 /// # Arguments
1084 ///
1085 /// * `tensor` - The tensor to get the maximum elements of.
1086 /// * `dim` - The dimension along which to get the maximum elements.
1087 ///
1088 /// # Returns
1089 ///
1090 /// A tensor with the maximum elements of `tensor` along `dim`.
1091 fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1092 let index = B::float_argmax(tensor.clone(), dim);
1093
1094 B::float_gather(dim, tensor, index)
1095 }
1096
1097 /// Gets the maximum elements of a tensor along an axis and their indices.
1098 ///
1099 /// # Arguments
1100 ///
1101 /// * `tensor` - The tensor to get the maximum elements of.
1102 /// * `dim` - The dimension along which to get the maximum elements.
1103 ///
1104 /// # Returns
1105 ///
1106 /// A tuple with the maximum elements of `tensor` along `dim` and their indices.
1107 fn float_max_dim_with_indices(
1108 tensor: FloatTensor<B>,
1109 dim: usize,
1110 ) -> (FloatTensor<B>, IntTensor<B>) {
1111 let index = B::float_argmax(tensor.clone(), dim);
1112 let values = B::float_gather(dim, tensor, index.clone());
1113
1114 (values, index)
1115 }
1116
1117 /// Gets the minimum element of a tensor.
1118 ///
1119 /// # Arguments
1120 ///
1121 /// * `tensor` - The tensor to get the minimum elements of.
1122 ///
1123 /// # Returns
1124 ///
1125 /// A tensor with the minimum element of `tensor`.
1126 fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {
1127 let shape = tensor.shape();
1128 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1129
1130 B::float_min_dim(tensor, 0)
1131 }
1132
1133 /// Gets the minimum elements of a tensor along an axis.
1134 ///
1135 /// # Arguments
1136 ///
1137 /// * `tensor` - The tensor to get the minimum elements of.
1138 /// * `dim` - The dimension along which to get the minimum elements.
1139 ///
1140 /// # Returns
1141 ///
1142 /// A tensor with the minimum elements of `tensor` along `dim`.
1143 fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1144 let index = B::float_argmin(tensor.clone(), dim);
1145
1146 B::float_gather(dim, tensor, index)
1147 }
1148
1149 /// Gets the minimum elements of a tensor along an axis and their indices.
1150 ///
1151 /// # Arguments
1152 ///
1153 /// * `tensor` - The tensor to get the minimum elements of.
1154 /// * `dim` - The dimension along which to get the minimum elements.
1155 ///
1156 /// # Returns
1157 ///
1158 /// A tuple with the minimum elements of `tensor` along `dim` and their indices.
1159 fn float_min_dim_with_indices(
1160 tensor: FloatTensor<B>,
1161 dim: usize,
1162 ) -> (FloatTensor<B>, IntTensor<B>) {
1163 let index = B::float_argmin(tensor.clone(), dim);
1164 let values = B::float_gather(dim, tensor, index.clone());
1165
1166 (values, index)
1167 }
1168
1169 /// Gets the maximum absolute element of a tensor.
1170 ///
1171 /// # Arguments
1172 ///
1173 /// * `tensor` - The tensor to get the maximum elements of.
1174 ///
1175 /// # Returns
1176 ///
1177 /// A tensor with the maximum element of `tensor`.
1178 fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {
1179 let shape = tensor.shape();
1180 let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));
1181
1182 B::float_max_abs_dim(tensor, 0)
1183 }
1184
1185 /// Gets the maximum absolute elements of a tensor along an axis.
1186 ///
1187 /// # Arguments
1188 ///
1189 /// * `tensor` - The tensor to get the maximum elements of.
1190 /// * `dim` - The dimension along which to get the maximum elements.
1191 ///
1192 /// # Returns
1193 ///
1194 /// A tensor with the maximum elements of `tensor` along `dim`.
1195 fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
1196 B::float_max_dim(B::float_abs(tensor), dim)
1197 }
1198
1199 /// Returns a new tensor with the given dimension narrowed to the given range.
1200 ///
1201 /// # Arguments
1202 ///
1203 /// * `dim` - The dimension along which the tensor will be narrowed.
1204 /// * `start` - The starting point of the given range.
1205 /// * `length` - The ending point of the given range.
1206 /// # Panics
1207 ///
1208 /// - If the dimension is greater than the number of dimensions of the tensor.
1209 /// - If the given range exceeds the number of elements on the given dimension.
1210 ///
1211 /// # Returns
1212 ///
1213 /// A new tensor with the given dimension narrowed to the given range.
1214 fn float_narrow(
1215 tensor: FloatTensor<B>,
1216 dim: usize,
1217 start: usize,
1218 length: usize,
1219 ) -> FloatTensor<B> {
1220 narrow::<B, Float>(TensorPrimitive::Float(tensor), dim, start, length).tensor()
1221 }
1222
1223 /// Split the tensor along the given dimension into chunks.
1224 ///
1225 /// # Arguments
1226 ///
1227 /// * `tensor` - The tensor.
1228 /// * `chunks` - The number of chunks to be produced
1229 /// * `times` - The dimension along which the tensor will be split.
1230 ///
1231 /// # Returns
1232 ///
1233 /// A vector of tensors
1234 fn float_chunk(tensor: FloatTensor<B>, chunks: usize, dim: usize) -> Vec<FloatTensor<B>> {
1235 chunk::<B, Float>(TensorPrimitive::Float(tensor), chunks, dim)
1236 .into_iter()
1237 .map(|t| t.tensor())
1238 .collect()
1239 }
1240
1241 /// Split the tensor along the given dimension into chunks of `split_size`.
1242 ///
1243 /// # Arguments
1244 ///
1245 /// * `tensor` - The tensor.
1246 /// * `split_size` - The size of a single chunk.
1247 /// * `times` - The dimension along which the tensor will be split.
1248 ///
1249 /// # Returns
1250 ///
1251 /// A vector of tensors.
1252 fn float_split(tensor: FloatTensor<B>, split_size: usize, dim: usize) -> Vec<FloatTensor<B>> {
1253 split::<B, Float>(TensorPrimitive::Float(tensor), split_size, dim)
1254 .into_iter()
1255 .map(|t| t.tensor())
1256 .collect()
1257 }
1258
1259 /// Split the tensor along the given dimension into chunks with sizes in
1260 /// `dim` according to `split_sizes`.
1261 ///
1262 /// # Arguments
1263 ///
1264 /// * `tensor` - The tensor.
1265 /// * `split_sizes` - Vector of sizes for each chunk.
1266 /// * `times` - The dimension along which the tensor will be split.
1267 ///
1268 /// # Returns
1269 ///
1270 /// A vector of tensors.
1271 fn float_split_with_sizes(
1272 tensor: FloatTensor<B>,
1273 split_sizes: Vec<usize>,
1274 dim: usize,
1275 ) -> Vec<FloatTensor<B>> {
1276 split_with_sizes::<B, Float>(TensorPrimitive::Float(tensor), split_sizes, dim)
1277 .into_iter()
1278 .map(|t| t.tensor())
1279 .collect()
1280 }
1281
1282 /// Tests if any element in the float `tensor` evaluates to True.
1283 ///
1284 /// # Arguments
1285 ///
1286 /// * `tensor` - The tensor to test.
1287 ///
1288 /// # Returns
1289 ///
1290 /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.
1291 fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {
1292 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1293 let bool_tensor = B::bool_not(bool_tensor);
1294 let sum = B::float_sum(B::bool_into_float(bool_tensor));
1295 B::float_greater_elem(sum, 0.0f32.elem())
1296 }
1297
1298 /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.
1299 ///
1300 /// # Arguments
1301 ///
1302 /// * `tensor` - The tensor to test.
1303 /// * `dim` - The axis along which to test.
1304 ///
1305 /// # Returns
1306 ///
1307 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1308 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the
1309 /// input evaluates to True, False otherwise.
1310 fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1311 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1312 let bool_tensor = B::bool_not(bool_tensor);
1313 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1314 B::float_greater_elem(sum, 0.0f32.elem())
1315 }
1316
1317 /// Tests if all elements in the float `tensor` evaluate to True.
1318 ///
1319 /// # Arguments
1320 ///
1321 /// * `tensor` - The tensor to test.
1322 ///
1323 /// # Returns
1324 ///
1325 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1326 /// evaluate to True, False otherwise.
1327 fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {
1328 let num_elems = tensor.shape().num_elements();
1329 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1330 let bool_tensor = B::bool_not(bool_tensor);
1331 let sum = B::float_sum(B::bool_into_float(bool_tensor));
1332 B::float_equal_elem(sum, (num_elems as f32).elem())
1333 }
1334
1335 /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.
1336 ///
1337 /// # Arguments
1338 ///
1339 /// * `tensor` - The tensor to test.
1340 /// * `dim` - The axis along which to test.
1341 ///
1342 /// # Returns
1343 ///
1344 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1345 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1346 /// evaluates to True, False otherwise.
1347 fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {
1348 let num_elems = tensor.shape().dims[dim];
1349 let bool_tensor = B::float_equal_elem(tensor, 0.0f32.elem());
1350 let bool_tensor = B::bool_not(bool_tensor);
1351 let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);
1352 B::float_equal_elem(sum, (num_elems as f32).elem())
1353 }
1354
1355 /// Returns the signs of the float `tensor`.
1356 ///
1357 /// # Arguments
1358 ///
1359 /// * `tensor` - The tensor to extract the signs from.
1360 ///
1361 /// # Returns
1362 ///
1363 /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.
1364 fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {
1365 let zeros = B::float_zeros(tensor.shape(), &B::float_device(&tensor));
1366 let less_than_zero = B::float_lower_elem(tensor.clone(), 0.0f32.elem());
1367 let greater_than_zero = B::float_greater_elem(tensor, 0.0f32.elem());
1368
1369 let mut result = B::float_mask_fill(zeros, less_than_zero, (-1.0f32).elem());
1370 result = B::float_mask_fill(result, greater_than_zero, 1.0f32.elem());
1371 result
1372 }
1373
/// Broadcasts the float `tensor` to the given `shape`.
///
/// NOTE(review): the rules for which input shapes can be broadcast to `shape` are not
/// visible here — confirm against the backend implementations / public `Tensor::expand`.
///
/// # Arguments
///
/// * `tensor` - The tensor to broadcast.
/// * `shape` - The target shape.
///
/// # Returns
///
/// A tensor with the given `shape`.
fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;
1376
1377 /// Sort the elements of the input `tensor` by value in along a given dimension.
1378 ///
1379 /// This sort is unstable (i.e., may reorder equal elements).
1380 ///
1381 /// # Arguments
1382 ///
1383 /// * `tensor` - The input tensor.
1384 /// * `dim` - The axis along which to sort.
1385 /// * `descending` - The sorting order.
1386 ///
1387 /// # Returns
1388 ///
1389 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
1390 fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {
1391 sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()
1392 }
1393
1394 /// Sort the elements of the input `tensor` by value in along a given dimension.
1395 ///
1396 /// This sort is unstable (i.e., may reorder equal elements).
1397 ///
1398 /// # Arguments
1399 ///
1400 /// * `tensor` - The input tensor.
1401 /// * `dim` - The axis along which to sort.
1402 /// * `descending` - The sorting order.
1403 ///
1404 /// # Returns
1405 ///
1406 /// A tensor with the same shape as the input tensor and corresponding indices, where
1407 /// the elements are sorted by value and the indices map back to the original input tensor.
1408 fn float_sort_with_indices(
1409 tensor: FloatTensor<B>,
1410 dim: usize,
1411 descending: bool,
1412 ) -> (FloatTensor<B>, IntTensor<B>) {
1413 let (values, indices) =
1414 sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);
1415 (values.tensor(), indices)
1416 }
1417
1418 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
1419 ///
1420 /// This sort is unstable (i.e., may reorder equal elements).
1421 ///
1422 /// # Arguments
1423 ///
1424 /// * `tensor` - The input tensor.
1425 /// * `dim` - The axis along which to sort.
1426 /// * `descending` - The sorting order.
1427 ///
1428 /// # Returns
1429 ///
1430 /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.
1431 fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
1432 argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)
1433 }
1434}