burn_tensor/tensor/api/base.rs
1#![allow(clippy::single_range_in_vec_init)]
2use crate::backend::ExecutionError;
3
4pub use burn_backend::tensor::BasicOps;
5
6use alloc::vec::Vec;
7
8use alloc::format;
9use alloc::string::String;
10use alloc::vec;
11
12use burn_std::stub::RwLock;
13use core::iter::repeat;
14use core::{fmt::Debug, ops::Range};
15use serde::{Deserialize, Deserializer};
16
17use crate::IndexingUpdateOp;
18use crate::{AsIndex, Slice, SliceArg, wrap_index};
19use crate::{
20 Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, TensorMetadata,
21 backend::Backend, check,
22};
23use crate::{DType, Element};
24use crate::{cast::ToElement, check::TensorCheck};
25use serde::{Serialize, Serializer};
26
27/// A tensor with a given backend, shape and data type.
28///
29/// # Indexing
30/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
31/// or [`select`](Tensor::select) for numeric types.
32///
33/// ## Example
34///
35/// ```rust
36/// use burn_tensor::backend::Backend;
37/// use burn_tensor::Tensor;
38/// use burn_tensor::Int;
39///
40/// fn example<B: Backend>() {
41/// let device = Default::default();
42///
43/// let tensor = Tensor::<B, 2>::from_data(
44/// [
45/// [3.0, 4.9, 2.0],
46/// [2.0, 1.9, 3.0],
47/// [6.0, 1.5, 7.0],
48/// [3.0, 4.9, 9.0],
49/// ],
50/// &device,
51/// );
52///
53/// // Slice the tensor to get the second and third rows:
54/// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
55/// // The resulting tensor will have dimensions [2, 3].
56/// let slice = tensor.clone().slice([1..3]);
57/// println!("{slice}");
58///
59/// // Slice the tensor to get the first two rows and the first 2 columns:
60/// // [[3.0, 4.9], [2.0, 1.9]]
61/// // The resulting tensor will have dimensions [2, 2].
62/// let slice = tensor.clone().slice([0..2, 0..2]);
63/// println!("{slice}");
64///
65/// // Index the tensor along the dimension 1 to get the elements 0 and 2:
66/// // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
67/// // The resulting tensor will have dimensions [4, 2]
68/// let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
69/// let indexed = tensor.select(1, indices);
70/// println!("{indexed}");
71/// }
72/// ```
73#[derive(new, Clone, Debug)]
74pub struct Tensor<B, const D: usize, K = Float>
75where
76 B: Backend,
77 K: TensorKind<B>,
78{
79 pub(crate) primitive: K::Primitive,
80}
81
82impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
83where
84 B: Backend,
85 K: BasicOps<B>,
86 T: Into<TensorData>,
87{
88 fn from(value: T) -> Self {
89 Tensor::from_data(value.into(), &Default::default())
90 }
91}
92
93impl<B, const D: usize, K> Tensor<B, D, K>
94where
95 B: Backend,
96 K: BasicOps<B>,
97 K::Elem: Element,
98{
99 /// Executes an operation on the tensor and modifies its value.
100 ///
101 /// # Notes
102 ///
103 /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
104 /// no other reference pointing to the same tensor.
105 ///
106 /// Wrapping operations with inplace is not an optimization, it's mainly there if you
107 /// want to mutate a tensor by using owned operations. A plausible usage would be to
108 /// update the weights of a mutable model reference.
109 pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
110 let mut tensor_owned = Tensor::empty([0; D], &self.device());
111 core::mem::swap(&mut tensor_owned, self);
112
113 let mut tensor_new = func(tensor_owned);
114 core::mem::swap(&mut tensor_new, self);
115 }
116
117 /// Converts the tensor into a primitive tensor.
118 pub fn into_primitive(self) -> K::Primitive {
119 self.primitive
120 }
121
122 /// Converts from a primitive tensor into a tensor.
123 pub fn from_primitive(tensor: K::Primitive) -> Self {
124 Self::new(tensor)
125 }
126
127 /// Returns the number of dimensions of the tensor.
128 pub fn rank(&self) -> usize {
129 self.primitive.rank()
130 }
131
132 /// Returns the tensor primitive data type.
133 ///
134 /// # Note
135 /// Some element types are encoded in different primitive types depending on the backend
136 /// (e.g., bool could be encoded as `u8` or `u32`).
137 pub fn dtype(&self) -> DType {
138 self.primitive.dtype()
139 }
140
141 /// Create an empty tensor of the given shape.
142 ///
143 /// # Arguments
144 ///
145 /// - `shape`: The shape of the tensor.
146 /// - `device`: The device where the tensor will be created.
147 ///
148 /// # Example
149 /// ```rust
150 /// use burn_tensor::backend::Backend;
151 /// use burn_tensor::Tensor;
152 ///
153 /// fn example<B: Backend>() {
154 /// let device = Default::default();
155 /// // Create an empty tensor with dimensions [2, 3, 4].
156 /// let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
157 /// }
158 /// ```
159 pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
160 let shape = shape.into();
161 check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
162 Self::new(K::empty(shape, device, K::Elem::dtype()))
163 }
164
165 /// Create a tensor of the given shape where each element is zero.
166 ///
167 /// # Example
168 ///
169 /// ```rust
170 /// use burn_tensor::backend::Backend;
171 /// use burn_tensor::{Tensor, Shape};
172 ///
173 /// fn example<B: Backend>() {
174 /// let device = B::Device::default();
175 /// let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
176 /// println!("{tensor}");
177 /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
178 /// }
179 /// ```
180 pub fn zeros<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
181 let shape = shape.into();
182 check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
183 Self::new(K::zeros(shape, device, K::Elem::dtype()))
184 }
185
186 /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.
187 ///
188 /// # Example
189 ///
190 /// ```rust
191 /// use burn_tensor::backend::Backend;
192 /// use burn_tensor::{Tensor, Shape};
193 ///
194 /// fn example<B: Backend>() {
195 /// let device = B::Device::default();
196 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
197 /// let tensor = tensor.zeros_like();
198 /// println!("{tensor}");
199 /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
200 /// }
201 /// ```
202 pub fn zeros_like(&self) -> Self {
203 Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))
204 }
205
206 /// Create a tensor of the given shape where each element is one.
207 ///
208 /// # Example
209 ///
210 /// ```rust
211 /// use burn_tensor::backend::Backend;
212 /// use burn_tensor::{Tensor, Shape};
213 ///
214 /// fn example<B: Backend>() {
215 /// let device = B::Device::default();
216 /// let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
217 /// println!("{tensor}");
218 /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
219 /// }
220 /// ```
221 pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
222 let shape = shape.into();
223 check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
224 Self::new(K::ones(shape, device, K::Elem::dtype()))
225 }
226
227 /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.
228 ///
229 /// # Example
230 ///
231 /// ```rust
232 /// use burn_tensor::backend::Backend;
233 /// use burn_tensor::{Tensor, Shape};
234 ///
235 /// fn example<B: Backend>() {
236 /// let device = B::Device::default();
237 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
238 /// let tensor = tensor.ones_like();
239 /// println!("{tensor}");
240 /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
241 /// }
242 /// ```
243 pub fn ones_like(&self) -> Self {
244 Self::new(K::ones(self.shape(), &self.device(), self.dtype()))
245 }
246
247 /// Create a tensor of the given shape where each element is equal to the provided value.
248 ///
249 /// # Example
250 ///
251 /// ```rust
252 /// use burn_tensor::backend::Backend;
253 /// use burn_tensor::{Tensor, Shape};
254 ///
255 /// fn example<B: Backend>() {
256 /// let device = B::Device::default();
257 /// let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
258 /// println!("{tensor}");
259 /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
260 /// }
261 /// ```
262 pub fn full<S: Into<Shape>, E: ElementConversion>(
263 shape: S,
264 fill_value: E,
265 device: &B::Device,
266 ) -> Self {
267 let shape = shape.into();
268 check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
269 Self::new(K::full(shape, fill_value, device, K::Elem::dtype()))
270 }
271
272 /// Returns a new tensor with the same shape, dtype, and device as the current tensor,
273 /// filled with the provided value.
274 ///
275 /// # Example
276 ///
277 /// ```rust
278 /// use burn_tensor::backend::Backend;
279 /// use burn_tensor::{Tensor, Shape};
280 ///
281 /// fn example<B: Backend>() {
282 /// let device = B::Device::default();
283 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
284 /// let tensor = tensor.full_like(5.0);
285 /// println!("{tensor}");
286 /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
287 /// }
288 /// ```
289 pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
290 Self::new(K::full(
291 self.shape(),
292 fill_value,
293 &self.device(),
294 self.dtype(),
295 ))
296 }
297
298 /// Returns the dimensions of the current tensor.
299 ///
300 /// # Example
301 /// ```rust
302 /// use burn_tensor::backend::Backend;
303 /// use burn_tensor::Tensor;
304 ///
305 /// fn example<B: Backend>() {
306 /// let device = Default::default();
307 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
308 /// let dims = tensor.dims(); // [2, 3, 4]
309 /// println!("{dims:?}");
310 /// }
311 /// ```
312 pub fn dims(&self) -> [usize; D] {
313 Self::shape(self).dims()
314 }
315
316 /// Returns the shape of the current tensor.
317 ///
318 /// # Example
319 /// ```rust
320 /// use burn_tensor::backend::Backend;
321 /// use burn_tensor::Tensor;
322 ///
323 /// fn example<B: Backend>() {
324 /// let device = Default::default();
325 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
326 /// // Shape { dims: [2, 3, 4] }
327 /// let shape = tensor.shape();
328 /// }
329 /// ```
330 pub fn shape(&self) -> Shape {
331 self.primitive.shape()
332 }
333
334 /// Reshape the tensor to have the given shape.
335 ///
336 /// The tensor has the same data and number of elements as the input.
337 ///
338 /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
339 /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
340 ///
341 /// A `0` in the shape instructs to keep the current dimension from the original tensor,
342 /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
343 /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
344 /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
345 /// with [1, 3, 4] dimensions to [1, 12].
346 ///
347 /// # Arguments
348 /// - `shape`: The new shape of the tensor.
349 ///
350 /// # Panics
351 /// - If the tensor contains more than one `-1` in the shape.
352 /// - If the tensor contains values that are not positive (other than -1).
353 /// - If the shape does not match the number of elements of the original shape.
354 ///
355 /// # Example
356 ///
357 /// ```rust
358 /// use burn_tensor::backend::Backend;
359 /// use burn_tensor::Tensor;
360 ///
361 /// fn example<B: Backend>() {
362 /// let device = Default::default();
363 /// // Create a tensor with dimensions [2, 3, 4]
364 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
365 /// // Reshape it to [2, 12], where 12 is inferred from the number of elements.
366 /// let reshaped = tensor.reshape([2, -1]);
367 /// println!("{reshaped}");
368 /// }
369 /// ```
370 pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
371 // Convert reshape args to shape
372 let shape = shape.into_shape(&self);
373 Tensor::new(K::reshape(self.primitive, shape))
374 }
375
376 /// Transpose the tensor.
377 ///
378 /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
379 /// applied on the last two dimensions. For example, the transpose of a tensor with shape
380 /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
381 ///
382 /// See also [`permute`](Tensor::permute).
383 ///
384 /// # Arguments
385 ///
386 /// * `tensor` - The tensor to transpose.
387 ///
388 /// # Returns
389 ///
390 /// The transposed tensor.
391 ///
392 /// # Example
393 ///
394 /// ```rust
395 /// use burn_tensor::backend::Backend;
396 /// use burn_tensor::Tensor;
397 ///
398 /// fn example<B: Backend>() {
399 /// let device = Default::default();
400 /// // Create a 2D tensor of shape [2, 3]
401 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
402 ///
403 /// // Transpose the tensor:
404 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
405 /// // The resulting tensor will have dimensions [3, 2].
406 /// let transposed = tensor.transpose();
407 /// println!("{transposed}");
408 /// }
409 /// ```
410 pub fn transpose(self) -> Tensor<B, D, K> {
411 Tensor::new(K::transpose(self.primitive))
412 }
413
414 /// Alias for `transpose`.
415 #[inline(always)]
416 pub fn t(self) -> Tensor<B, D, K> {
417 self.transpose()
418 }
419
420 /// Swaps two dimensions of a tensor.
421 ///
422 /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.
423 ///
424 /// # Arguments
425 ///
426 /// * `tensor` - The tensor to swap the dimensions of.
427 /// * `dim1` - The first dimension to swap, supports negative indexing.
428 /// * `dim2` - The second dimension to swap, supports negative indexing.
429 ///
430 /// # Returns
431 ///
432 /// The tensor with the dimensions swapped.
433 ///
434 /// # Panics
435 ///
436 /// When dimensions are out of bounds.
437 ///
438 /// # Example
439 ///
440 /// ```rust
441 /// use burn_tensor::backend::Backend;
442 /// use burn_tensor::Tensor;
443 ///
444 /// fn example<B: Backend>() {
445 /// let device = Default::default();
446 /// // Create a 2D tensor of shape [2, 3]
447 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
448 ///
449 /// // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):
450 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
451 /// // The resulting tensor will have dimensions [3, 2].
452 /// let swapped = tensor.swap_dims(0, -1);
453 /// println!("{swapped}");
454 /// }
455 /// ```
456 pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>
457 where
458 Dim1: AsIndex,
459 Dim2: AsIndex,
460 {
461 let dim1 = dim1.expect_dim_index(D);
462 let dim2 = dim2.expect_dim_index(D);
463 check!(TensorCheck::swap_dims::<D>(dim1, dim2));
464 if dim1 == dim2 {
465 self
466 } else {
467 Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
468 }
469 }
470
471 /// Permute the dimensions of the tensor.
472 ///
473 /// This is a no-op when the resolved `axes` match the current order.
474 ///
475 /// # Arguments
476 ///
477 /// * `axes` - The new order of the dimensions. The length of the axes
478 /// must be equal to the number of dimensions of the tensor.
479 /// The values must be unique and in the range of the number of dimensions.
480 /// The values can be negative, in which case they are used as an offset from the end.
481 ///
482 /// # Returns
483 ///
484 /// The tensor with the dimensions permuted.
485 ///
486 /// # Example
487 ///
488 /// ```rust
489 /// use burn_tensor::backend::Backend;
490 /// use burn_tensor::Tensor;
491 ///
492 /// fn example<B: Backend>() {
493 /// let device = Default::default();
494 /// // Create a 2D tensor of shape [3, 2]
495 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
496 ///
497 /// // Permute the dimensions 1 and 0:
498 /// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
499 /// // The resulting tensor will have dimensions [3, 2].
500 /// let permuted = tensor.permute([1, 0]);
501 /// println!("{permuted}");
502 /// }
503 /// ```
504 pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>
505 where
506 Dim: AsIndex,
507 {
508 let mut no_op = true;
509 let mut fixed_axes = [0; D];
510 for (i, axis) in axes.into_iter().enumerate() {
511 let dim = axis.expect_dim_index(D);
512 no_op &= dim == i;
513 fixed_axes[i] = dim;
514 }
515
516 if no_op {
517 self
518 } else {
519 check!(TensorCheck::permute(fixed_axes));
520 Tensor::new(K::permute(self.primitive, &fixed_axes))
521 }
522 }
523
524 /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
525 ///
526 /// Other dimensions of input that are not explicitly moved remain in their original order and appear
527 /// at the positions not specified in destination.
528 ///
529 /// # Arguments
530 ///
531 /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
532 /// The values can be negative, in which case they are used as an offset from the end.
533 ///
534 /// * `dst` - Destination positions for each of the original dims. These must also be unique.
535 ///
536 /// # Panics
537 ///
538 /// - If the source and destination dimensions are not of the same length.
539 /// - If the source and destination vectors contain duplicate values.
540 /// - If the source and destination vectors contain values that are out of bounds.
541 ///
542 /// # Returns
543 ///
544 /// The tensor with the dimensions moved.
545 ///
546 /// # Example
547 ///
548 /// ```rust
549 /// use burn_tensor::backend::Backend;
550 /// use burn_tensor::Tensor;
551 ///
552 /// fn example<B: Backend>() {
553 /// let device = Default::default();
554 /// // Create a 3D tensor of shape [3, 2, 1]
555 /// let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
556 ///
557 /// // Move the dimensions 0 and 1:
558 /// // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
559 /// // The resulting tensor will have dimensions [2, 3, 1].
560 /// let moved = tensor.movedim(1, 0);
561 /// println!("{moved}");
562 /// }
563 /// ```
564 ///
565 /// # Note
566 ///
567 /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
568 /// for it
569 pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
570 let source_dims = src.into_dim_vec::<D>();
571 let destination_dims = dst.into_dim_vec::<D>();
572
573 check!(TensorCheck::movedim_args_length(
574 &source_dims,
575 &destination_dims
576 ));
577
578 let mut m = [-1; D];
579 for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
580 m[d] = s as isize;
581 }
582 let mut axes: [isize; D] = [0; D];
583 let mut source_i = 0;
584 for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
585 *item = if m[dest_i] != -1 {
586 m[dest_i]
587 } else {
588 while source_dims.contains(&source_i) {
589 source_i += 1;
590 }
591 let result = source_i as isize;
592 source_i += 1;
593 result
594 };
595 }
596
597 self.permute(axes)
598 }
599
600 /// Reverse the order of elements in the tensor along the given dimensions.
601 ///
602 /// # Arguments
603 ///
604 /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
605 /// The values can be negative, in which case they are used as an offset from the end.
606 ///
607 /// # Returns
608 ///
609 /// The tensor with the axes flipped.
610 ///
611 /// # Example
612 ///
613 /// ```rust
614 /// use burn_tensor::backend::Backend;
615 /// use burn_tensor::Tensor;
616 ///
617 /// fn example<B: Backend>() {
618 /// let device = Default::default();
619 /// // Create a 2D tensor with dimensions [4, 3]
620 /// let tensor = Tensor::<B, 2>::from_data(
621 /// [
622 /// [3.0, 4.9, 2.0],
623 /// [2.0, 1.9, 3.0],
624 /// [4.0, 5.9, 8.0],
625 /// [1.4, 5.8, 6.0],
626 /// ],
627 /// &device,
628 /// );
629 ///
630 /// // Flip the elements in dimensions 0 and 1:
631 /// // [[6.0, 5.8, 1.4],
632 /// // [8.0, 5.9, 4.0],
633 /// // [3.0, 1.9, 2.0],
634 /// // [2.0, 4.9, 3.0]]
635 /// // The resulting tensor will have dimensions [4, 3].
636 /// let flipped = tensor.flip([0, 1]);
637 /// println!("{flipped}");
638 /// }
639 /// ```
640 pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
641 // Convert the axes to usize and handle negative values without using vector
642 let mut transformed_axes: [usize; N] = [0; N];
643 for (i, &x) in axes.iter().enumerate() {
644 transformed_axes[i] = if x < 0 {
645 (D as isize + x) as usize
646 } else {
647 x as usize
648 };
649 }
650
651 // Check if the axes are valid
652 check!(TensorCheck::flip(D, &transformed_axes));
653
654 Tensor::new(K::flip(self.primitive, &transformed_axes))
655 }
656
657 /// Flatten the tensor along a given range of dimensions.
658 ///
659 /// This function collapses the specified range of dimensions into a single dimension,
660 /// effectively flattening the tensor in that range.
661 ///
662 /// # Arguments
663 ///
664 /// - `start_dim`: The starting dimension of the range to be flattened,
665 /// supports negative indexing.
666 /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
667 /// supports negative indexing.
668 ///
669 /// # Type Parameters
670 ///
671 /// - `D2`: The resulting number of dimensions in the flattened tensor.
672 ///
673 /// # Returns
674 ///
675 /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
676 ///
677 /// # Example
678 ///
679 /// ```rust
680 ///
681 /// use burn_tensor::backend::Backend;
682 /// use burn_tensor::{Tensor, Shape};
683 ///
684 /// fn example<B: Backend>() {
685 /// let device = Default::default();
686 /// // Create a 3D tensor with dimensions [2, 3, 4]
687 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
688 ///
689 /// // Flatten the tensor from dimensions 1 to 2 (inclusive).
690 /// // The resulting tensor will have dimensions [2, 12]
691 /// let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
692 /// println!("{flattened}");
693 /// }
694 /// ```
695 pub fn flatten<const D2: usize>(
696 self,
697 start_dim: impl AsIndex,
698 end_dim: impl AsIndex,
699 ) -> Tensor<B, D2, K> {
700 let start_dim = start_dim.expect_dim_index(D);
701 let end_dim = end_dim.expect_dim_index(D);
702 check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
703 let new_shape = self.shape().flatten_dims(start_dim, end_dim);
704
705 Tensor::new(K::reshape(self.primitive, new_shape))
706 }
707
708 /// Squeeze the tensor along all dimensions, removing dimensions
709 /// of size one, and effectively reducing the rank of the tensor.
710 ///
711 /// # Type Parameters
712 ///
713 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
714 ///
715 /// # Returns
716 ///
717 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
718 ///
719 /// # Example
720 ///
721 /// ```rust
722 ///
723 /// use burn_tensor::backend::Backend;
724 /// use burn_tensor::{Tensor, Shape};
725 ///
726 /// fn example<B: Backend>() {
727 /// let device = Default::default();
728 /// // Create a 4D tensor with dimensions [1, 3, 1, 3]
729 /// let tensor = Tensor::<B, 4>::from_data(
730 /// [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],
731 /// &device,
732 /// );
733 ///
734 /// // Squeeze the tensor dimensions.
735 /// // The resulting tensor will have dimensions [3, 3].
736 /// let squeezed = tensor.squeeze::<2>();
737 /// println!("{squeezed}");
738 /// }
739 /// ```
740 pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {
741 let new_dims = self
742 .shape()
743 .dims
744 .iter()
745 .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })
746 .collect::<Vec<_>>();
747 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
748
749 Tensor::new(K::reshape(self.primitive, new_dims.into()))
750 }
751
752 /// Squeeze the tensor along the given dimension, removing the specified dimension
753 /// of size one, and effectively reducing the rank of the tensor by one.
754 ///
755 /// # Arguments
756 ///
757 /// - `dim`: The dimension to be squeezed.
758 ///
759 /// # Type Parameters
760 ///
761 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
762 ///
763 /// # Panics
764 ///
765 /// If the size in the squeezed dimension is not 1.
766 ///
767 /// # Returns
768 ///
769 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
770 ///
771 /// # Example
772 ///
773 /// ```rust
774 ///
775 /// use burn_tensor::backend::Backend;
776 /// use burn_tensor::{Tensor, Shape};
777 ///
778 /// fn example<B: Backend>() {
779 /// let device = Default::default();
780 /// // Create a 3D tensor with dimensions [3, 1, 3]
781 /// let tensor = Tensor::<B, 3>::from_data(
782 /// [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
783 /// &device,
784 /// );
785 ///
786 /// // Squeeze the dimension 1.
787 /// // The resulting tensor will have dimensions [3, 3].
788 /// let squeezed = tensor.squeeze_dim::<2>(1);
789 /// println!("{squeezed}");
790 /// }
791 /// ```
792 pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
793 check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));
794
795 let current_dims = self.shape().dims;
796 let mut new_dims: [usize; D2] = [0; D2];
797
798 new_dims[..dim].copy_from_slice(¤t_dims[..dim]);
799 new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]);
800
801 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
802 Tensor::new(K::reshape(self.primitive, new_dims.into()))
803 }
804
805 /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
806 /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
807 /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
808 /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
809 /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
810 /// from the back.
811 ///
812 /// # Arguments
813 ///
814 /// - `dims`: The dimension(s) to be squeezed.
815 ///
816 /// # Type Parameters
817 ///
818 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
819 ///
820 /// # Returns
821 ///
822 /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
823 ///
824 /// # Example
825 ///
826 /// ```rust
827 ///
828 /// use burn_tensor::backend::Backend;
829 /// use burn_tensor::{Tensor, Shape};
830 ///
831 /// fn example<B: Backend>() {
832 /// let device = Default::default();
833 /// // Create a 4D tensor with dimensions [2, 1, 4, 1]
834 /// let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
835 ///
836 /// // Squeeze the dimensions 1 and 3.
837 /// // The resulting tensor will have dimensions [2, 4].
838 /// let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
839 /// println!("{squeezed}");
840 /// }
841 /// ```
842 pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
843 let current_dims = self.shape().dims;
844 let mut dim_indices: Vec<usize>;
845
846 // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
847 if dims.is_empty() {
848 dim_indices = current_dims
849 .iter()
850 .enumerate()
851 .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
852 .collect();
853 } else {
854 // If negative dims, count from the back
855 dim_indices = dims
856 .iter()
857 .map(|&d| {
858 if d < 0 {
859 (current_dims.len() as isize + d) as usize
860 } else {
861 d as usize
862 }
863 })
864 .collect();
865 }
866
867 // Sort indices and remove duplicates
868 dim_indices.sort_unstable();
869 dim_indices.dedup();
870
871 // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
872 check!(TensorCheck::squeeze_dims_input::<D2>(
873 &dim_indices,
874 ¤t_dims
875 ));
876
877 // Calculate new dimensions
878 let mut new_dims = Vec::new();
879 for (index, &dim_size) in current_dims.iter().enumerate() {
880 // Exclude the dimension if it's explicitly marked for squeezing
881 if dim_indices.contains(&index) {
882 check!(TensorCheck::squeeze::<D2>(index, ¤t_dims));
883 continue;
884 }
885 new_dims.push(dim_size);
886 }
887
888 // Check that after squeezing, we still respect the D2 size
889 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
890
891 Tensor::new(K::reshape(self.primitive, new_dims.into()))
892 }
893
894 /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
895 ///
896 /// # Type Parameters
897 ///
898 /// - `D2`: The resulting number of dimensions in the unsqueezed tensor.
899 ///
900 /// # Panics
901 ///
902 /// If the output size `D2` is smaller than the current number of dimensions.
903 ///
904 /// # Returns
905 ///
906 /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
907 ///
908 /// # Example
909 ///
910 /// ```rust
911 /// use burn_tensor::backend::Backend;
912 /// use burn_tensor::{Tensor, Shape};
913 ///
914 /// fn example<B: Backend>() {
915 /// let device = Default::default();
916 /// // Create a 2D tensor with dimensions [3, 3]
917 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
918 /// // Unsqueeze the tensor up to 4 dimensions.
919 /// // The resulting tensor will have dimensions [1, 1, 3, 3].
920 /// let unsqueezed = tensor.unsqueeze::<4>();
921 /// println!("{unsqueezed}");
922 /// }
923 /// ```
924 pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
925 check!(TensorCheck::unsqueeze::<D, D2>());
926
927 let mut dims = [1; D2];
928 let num_ones = D2 - D;
929 let shape = self.shape();
930
931 dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);
932
933 let shape = Shape::new(dims);
934 self.reshape(shape)
935 }
936
937 /// Creates a new tensor with a dimension of size one inserted at the specified position.
938 ///
939 /// # Example
940 ///
941 /// ```rust
942 /// use burn_tensor::backend::Backend;
943 /// use burn_tensor::{Tensor, Shape};
944 ///
945 /// fn example<B: Backend>() {
946 /// let device = Default::default();
947 /// // Create a 2D tensor with dimensions [3, 3]
948 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
949 /// // Unsqueeze the dimension 1.
950 /// // The resulting tensor will have dimensions [3, 1, 3].
951 /// let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
952 /// println!("{unsqueezed}");
953 /// }
954 /// ```
955 pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
956 check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));
957
958 let mut dims = [1; D2];
959 let shape = self.shape();
960
961 dims[0..dim].copy_from_slice(&shape[0..dim]);
962
963 if dim < D {
964 dims[dim] = 1;
965 dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
966 } else {
967 dims[dim] = 1;
968 }
969
970 let shape = Shape::new(dims);
971 self.reshape(shape)
972 }
973
974 /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
975 /// The indices can be negative, in which case they are counted from the last to the first dimension.
976 /// the axes can contain duplicates, in which case the number of dimensions inserted at the index
977 /// is the number of duplicates.
978 /// # Example
979 ///
980 /// ```rust
981 /// use burn_tensor::backend::Backend;
982 /// use burn_tensor::{Tensor, Shape};
983 ///
984 /// fn example<B: Backend>() {
985 /// let device = Default::default();
986 /// // Create a 3D tensor with dimensions [3, 4, 5]
987 /// let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
988 /// // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
989 /// // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
990 /// let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
991 /// println!("{unsqueezed}");
992 /// }
993 /// ```
994 pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
995 let mut new_dims = [1; D2];
996 let old_dims = self.shape().dims;
997 //for checking if the dimension is in the acceptable range
998
999 //part 1: convert the negative indices to positive
1000 let mut neg_offset = D2;
1001 let mut dim_indices = axes
1002 .iter()
1003 .map(|d| {
1004 // check if the dimension is in the acceptable range
1005 check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
1006 (if *d < 0 {
1007 neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
1008 d + neg_offset as isize + 1
1009 } else {
1010 *d
1011 }) as usize
1012 })
1013 .collect::<Vec<usize>>();
1014
1015 //sort the indices
1016 dim_indices.sort_unstable();
1017
1018 //Now use this to copy the chunks of the dims
1019 let mut prev_idx: usize = 0;
1020 let mut current_left_b: usize = 0;
1021 let mut current_right_b: usize = 0;
1022 let mut offset: usize = 0;
1023 dim_indices.iter().for_each(|d| {
1024 //check if there is space for at least one dimension
1025 if prev_idx < *d {
1026 current_right_b = *d - offset;
1027 //copy the chunks of the dims
1028 if current_right_b < D {
1029 new_dims[prev_idx..*d]
1030 .copy_from_slice(&old_dims[current_left_b..current_right_b]);
1031 } else {
1032 new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
1033 }
1034 prev_idx = *d + 1;
1035 //offset is equal to the number of extracted elements from the original shape
1036 offset += current_right_b - current_left_b;
1037 current_left_b = current_right_b;
1038 } else {
1039 //it's sorted so the only reason this would happen
1040 //is if multiple indices are the same
1041 prev_idx += 1;
1042 }
1043 });
1044 //copy over anything past the index of the last new dimension
1045 if current_left_b < D {
1046 new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
1047 }
1048
1049 //lastly, create the shape and reshape
1050 let shape = Shape::new(new_dims);
1051 self.reshape(shape)
1052 }
1053
1054 /// Roll operation along a specific dimension; wrapping around the elements.
1055 ///
1056 /// ## Parameters
1057 ///
1058 /// - `shift`: The roll extent; supports negative values and wraps around.
1059 /// - `dim`: The dimension to roll; supports negative indexing.
1060 ///
1061 /// ## Returns
1062 ///
1063 /// A new tensor with the specified dimension rolled by the given shift amount.
1064 pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self
1065 where
1066 Shift: AsIndex,
1067 Dim: AsIndex,
1068 {
1069 let dim = dim.expect_dim_index(D);
1070 let size = self.shape().dims[dim];
1071 if size == 0 {
1072 // If the dimension is empty, return the tensor as is.
1073 return self;
1074 }
1075
1076 let shift = wrap_index(shift, size);
1077 if shift == 0 {
1078 // If the shift is zero, return the tensor as is.
1079 return self;
1080 }
1081
1082 self.unchecked_roll_dim(shift, dim)
1083 }
1084
1085 /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.
1086 ///
1087 /// ## Parameters
1088 ///
1089 /// - `shift`: The number of positions to shift; must be (0 < shift < size).
1090 /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.
1091 ///
1092 /// ## Returns
1093 ///
1094 /// A new tensor with the specified dimension rolled by the given shift amount.
1095 #[inline(always)]
1096 fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {
1097 #[cfg(debug_assertions)]
1098 {
1099 let size = self.shape().dims[dim];
1100 assert!(
1101 0 < shift && shift < size,
1102 "Expected: 0 < shift < size: found shift={shift}, size={size}",
1103 );
1104 assert!(
1105 dim < self.shape().num_dims(),
1106 "Expected: dim < num_dims: found dim={dim}, num_dims={size}",
1107 );
1108 }
1109
1110 Tensor::cat(
1111 vec![
1112 self.clone().slice_dim(dim, shift..),
1113 self.slice_dim(dim, ..shift),
1114 ],
1115 dim,
1116 )
1117 }
1118
1119 /// Roll operation.
1120 ///
1121 /// Note: unlike ``pytorch``, `dims` and `shifts` must have the same length.
1122 ///
1123 /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.
1124 ///
1125 /// ## Parameters
1126 ///
1127 /// - `shifts`: A slice of shifts corresponding to each dimension;
1128 /// supports negative values and wraps around.
1129 /// - `dims`: A slice of dimensions to roll; supports negative indexing.
1130 ///
1131 /// ## Returns
1132 ///
1133 /// A new tensor with the specified dimensions rolled by the given shifts.
1134 pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self
1135 where
1136 Shift: AsIndex,
1137 Dim: AsIndex,
1138 {
1139 assert_eq!(
1140 dims.len(),
1141 shifts.len(),
1142 "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}",
1143 );
1144
1145 // This is a fair amount of complexity, which could be replaced
1146 // by a simple canonicalization of `dims` and wrapping of `shifts`.
1147 // The work is done here to ensure that any roll operation
1148 // which could be a no-op is a no-op; simplifying the accounting
1149 // needed by backend-specific implementations of the inner roll op.
1150
1151 let item_count = dims.len();
1152
1153 let shape = self.shape().dims;
1154
1155 // Accumulate the effective shifts for each dimension.
1156 let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];
1157 for i in 0..item_count {
1158 let dim = dims[i].expect_dim_index(D);
1159 accumulated_shifts[dim] += shifts[i].index();
1160 }
1161
1162 // Do this after we've checked the validity of `dims` and `shifts`.
1163 if self.shape().num_elements() == 0 {
1164 // If the tensor is empty, return it as is.
1165 return self;
1166 }
1167
1168 // Wrap the accumulated shifts, and filter out empty dimensions.
1169 let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);
1170 let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);
1171 for dim in 0..shape.len() {
1172 // `wrap_index` should inline, and has a fast-exit path for zero shifts.
1173 let shift = wrap_index(accumulated_shifts[dim], shape[dim]);
1174 if shift == 0 {
1175 continue;
1176 }
1177
1178 effective_dims.push(dim);
1179 effective_shifts.push(shift);
1180 }
1181
1182 // If no shifts are needed, return the original tensor.
1183 if effective_shifts.is_empty() {
1184 return self;
1185 }
1186
1187 // At this point:
1188 // - `dims` contains the effective dimensions to roll, in index order,
1189 // - `shifts` contains the effective usize shifts for each dimension.
1190 // - Every shift is non-zero, and less than the size of the corresponding dimension.
1191 self.unchecked_roll(&effective_shifts, &effective_dims)
1192 }
1193
1194 /// `roll` internal implementation.
1195 ///
1196 /// ## Parameters
1197 ///
1198 /// - `shifts`: A slice of shifts corresponding to each dimension;
1199 /// must be non-empty, the same length as `dims`, and all ``1..<size>``.
1200 /// - `dims`: A slice of dimensions to roll; must be non-empty;
1201 /// the same length as `shifts`, and must not contain repeats.
1202 ///
1203 /// ## Panics
1204 ///
1205 /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.
1206 ///
1207 /// ## Returns
1208 ///
1209 /// A new tensor with the specified dimensions rolled by the given shifts.
1210 #[inline(always)]
1211 fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {
1212 #[cfg(debug_assertions)]
1213 {
1214 assert!(!shifts.is_empty());
1215 assert_eq!(
1216 shifts.len(),
1217 dims.len(),
1218 "Shifts and dimensions must align; found {} shifts and {} dims",
1219 shifts.len(),
1220 dims.len()
1221 );
1222
1223 let mut unique_dims = dims.to_vec();
1224 unique_dims.dedup();
1225
1226 assert_eq!(
1227 unique_dims.len(),
1228 dims.len(),
1229 "Dimensions must not contain repeats; found {} unique dims and {} total dims",
1230 unique_dims.len(),
1231 dims.len()
1232 )
1233 }
1234
1235 let x = self.unchecked_roll_dim(shifts[0], dims[0]);
1236
1237 if dims.len() == 1 {
1238 x
1239 } else {
1240 x.unchecked_roll(&shifts[1..], &dims[1..])
1241 }
1242 }
1243
1244 /// Returns a tensor containing the elements selected from the given slices.
1245 ///
1246 /// This method provides flexible tensor slicing with support for various range types,
1247 /// negative indices, and stepped slicing. The method accepts both single slices and
1248 /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.
1249 ///
1250 /// # Arguments
1251 ///
1252 /// * `slices` - Can be:
1253 /// - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)
1254 /// - An array of ranges (e.g., `[0..2, 1..4]`)
1255 /// - The [`s!`] macro output for advanced slicing with steps
1256 /// - a `&Vec<Slice>` or `&[Slice]`
1257 ///
1258 /// # Behavior
1259 ///
1260 /// - Supports partial and full slicing in any number of dimensions
1261 /// - Handles negative indices by wrapping from the end (-1 is the last element)
1262 /// - Automatically clamps ranges that exceed tensor dimensions
1263 /// - Supports stepped slicing for selecting every nth element
1264 /// - Negative steps reverse the selection order
1265 ///
1266 /// # Panics
1267 ///
1268 /// - If the number of slices exceeds the tensor's dimensions
1269 /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step
1270 /// - If a step is zero
1271 ///
1272 /// # Examples
1273 ///
1274 /// ```rust
1275 /// use burn_tensor::backend::Backend;
1276 /// use burn_tensor::{Tensor, Shape, s};
1277 ///
1278 /// fn example<B: Backend>() {
1279 /// let device = B::Device::default();
1280 ///
1281 /// // Single dimension slicing - no brackets needed!
1282 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);
1283 /// let slice = tensor.clone().slice(2..8); // Simple range
1284 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);
1285 ///
1286 /// // Using s! macro for single dimension with step
1287 /// let slice = tensor.clone().slice(s![0..10;2]); // Every 2nd element
1288 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);
1289 ///
1290 /// // Reverse a dimension with negative step
1291 /// let slice = tensor.slice(s![..;-1]); // Reverse entire tensor
1292 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
1293 ///
1294 /// // Multi-dimensional slicing
1295 /// let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1296 ///
1297 /// // Array syntax for simple ranges
1298 /// let slice = tensor.clone().slice([1..3, 2..5]);
1299 /// assert_eq!(slice.dims(), [2, 3]);
1300 ///
1301 /// // Advanced multi-dimensional with s! macro
1302 /// let slice = tensor.clone().slice(s![0..4;2, ..;-1]); // Every 2nd row, reverse columns
1303 /// assert_eq!(slice.dims(), [2, 6]);
1304 ///
1305 /// // Complex 3D example with mixed slice types
1306 /// let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);
1307 /// let slice = tensor.slice(s![1..3, ..;2, -3..]); // Rows 1-2, every 2nd col, last 3 depth
1308 /// assert_eq!(slice.dims(), [2, 3, 3]);
1309 ///
1310 /// // Using negative indices
1311 /// let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1312 /// let slice = tensor.slice(s![-2.., ..-1]); // Last 2 rows, all but last column
1313 /// assert_eq!(slice.dims(), [2, 5]);
1314 /// }
1315 /// ```
1316 ///
1317 /// # See Also
1318 ///
1319 /// - [`s!`] - The recommended macro for creating complex slice specifications
1320 /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice
1321 /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1322 /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension
1323 ///
1324 /// [`s!`]: crate::s!
1325 pub fn slice<S>(self, slices: S) -> Self
1326 where
1327 S: SliceArg,
1328 {
1329 let shape = self.shape();
1330 let slices = slices.into_slices(&shape);
1331
1332 // Validate slices
1333 check!(TensorCheck::slice::<D>(&shape, &slices));
1334
1335 // Calculate output shape and check for empty slices
1336 let mut output_dims = shape.dims.clone();
1337 for (dim, slice) in slices.iter().enumerate() {
1338 output_dims[dim] = slice.output_size(shape.dims[dim]);
1339 }
1340
1341 // Return empty tensor if any dimension is 0 (empty slice)
1342 if output_dims.contains(&0) {
1343 return Self::empty(output_dims, &self.device());
1344 }
1345 Self::new(K::slice(self.primitive, &slices))
1346 }
1347
1348 /// Assigns values to a slice of the tensor and returns the updated tensor.
1349 ///
1350 /// This method supports advanced slicing with steps, including negative steps for reverse
1351 /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro
1352 /// providing powerful syntax for complex patterns.
1353 ///
1354 /// # Arguments
1355 ///
1356 /// * `slices` - Slice specification (same format as `slice` method)
1357 /// * `values` - Tensor with values to assign (must match slice dimensions)
1358 ///
1359 /// # Panics
1360 ///
1361 /// - If slices exceed tensor dimensions
1362 /// - If values dimensions don't match the selected slice shape
1363 /// - If a step is zero
1364 ///
1365 /// # Examples
1366 ///
1367 /// ```rust
1368 /// use burn_tensor::backend::Backend;
1369 /// use burn_tensor::{Tensor, s};
1370 ///
1371 /// fn example<B: Backend>() {
1372 /// let device = B::Device::default();
1373 ///
1374 /// // Simple assignment to a sub-region
1375 /// let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1376 /// let values = Tensor::<B, 2>::ones([2, 3], &device);
1377 /// tensor = tensor.slice_assign([1..3, 2..5], values);
1378 /// // Now tensor[1..3, 2..5] contains ones
1379 ///
1380 /// // Single dimension assignment with step
1381 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1382 /// let values = Tensor::<B, 1>::ones([5], &device);
1383 /// tensor = tensor.slice_assign(s![0..10;2], values);
1384 /// // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1385 ///
1386 /// // Reverse assignment with negative step
1387 /// let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1388 /// let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);
1389 /// tensor = tensor.slice_assign(s![..;-1], values);
1390 /// // Assigns in reverse: [14, 13, 12, 11, 10]
1391 ///
1392 /// // Complex multi-dimensional assignment
1393 /// let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);
1394 /// let values = Tensor::<B, 3>::ones([2, 3, 3], &device);
1395 /// tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);
1396 /// // Assigns to every 2nd row, every 2nd column, last 3 in depth
1397 ///
1398 /// // Mixed syntax example
1399 /// let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);
1400 /// let pattern = Tensor::<B, 2>::ones([4, 4], &device);
1401 /// tensor = tensor.slice_assign(s![..;2, ..;2], pattern);
1402 /// // Creates a checkerboard pattern with ones
1403 /// }
1404 /// ```
1405 ///
1406 /// # See Also
1407 ///
1408 /// - [`s!`] - The recommended macro for creating complex slice specifications
1409 /// - [`slice`](Self::slice) - Extract a slice from a tensor
1410 /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1411 ///
1412 /// [`s!`]: crate::s!
1413 pub fn slice_assign<S>(self, slices: S, values: Self) -> Self
1414 where
1415 S: SliceArg,
1416 {
1417 let shape = self.shape();
1418 let slices = slices.into_slices(&shape);
1419
1420 // Check if any slice produces 0 elements (empty assignment).
1421 // Empty assignments are no-ops and would cause issues in backend implementations.
1422 let is_empty_assignment = slices
1423 .iter()
1424 .enumerate()
1425 .any(|(i, slice)| slice.output_size(shape.dims[i]) == 0);
1426
1427 if is_empty_assignment {
1428 return self;
1429 }
1430
1431 check!(TensorCheck::slice_assign::<D>(
1432 &shape,
1433 &values.shape(),
1434 &slices
1435 ));
1436
1437 Self::new(K::slice_assign(self.primitive, &slices, values.primitive))
1438 }
1439
1440 /// Fills a slice of the tensor with a constant value and returns the updated tensor.
1441 ///
1442 /// Like other slice methods, accepts both single slices and arrays. However, this method
1443 /// currently **does not support stepped slicing** - use [`slice_assign`](Self::slice_assign)
1444 /// with a constant tensor for stepped patterns.
1445 ///
1446 /// # Arguments
1447 ///
1448 /// * `slices` - Slice specification (same format as `slice` method, but no steps)
1449 /// * `value` - The value to fill the slice with
1450 ///
1451 /// # Panics
1452 ///
1453 /// - If slices exceed tensor dimensions
1454 /// - If any slice has a step != 1 (not yet supported)
1455 ///
1456 /// # Examples
1457 ///
1458 /// ```rust
1459 /// use burn_tensor::backend::Backend;
1460 /// use burn_tensor::{Tensor, s};
1461 ///
1462 /// fn example<B: Backend>() {
1463 /// let device = B::Device::default();
1464 ///
1465 /// // Simple fill for a single dimension
1466 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1467 /// tensor = tensor.slice_fill(2..5, 1.0);
1468 /// // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
1469 ///
1470 /// // Multi-dimensional fill
1471 /// let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1472 /// tensor = tensor.slice_fill([1..3, 2..5], -1.0);
1473 /// // Fills the rectangle at rows 1-2, columns 2-4 with -1
1474 ///
1475 /// // Using negative indices
1476 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1477 /// tensor = tensor.slice_fill(-3.., 2.0);
1478 /// // Fills the last 3 elements with 2.0
1479 ///
1480 /// // Complex multi-dimensional example
1481 /// let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);
1482 /// tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);
1483 /// // Sets rows 1-2, all columns, last 2 in depth to 0
1484 ///
1485 /// // Stepped slicing is supported
1486 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1487 /// tensor = tensor.slice_fill(s![0..10;2], 1.0);
1488 /// // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1489 /// }
1490 /// ```
1491 ///
1492 /// # See Also
1493 ///
1494 /// - [`s!`] - The macro for creating slice specifications with steps
1495 /// - [`slice`](Self::slice) - Extract a slice from a tensor
1496 /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice
1497 ///
1498 /// [`s!`]: crate::s!
1499 pub fn slice_fill<S, E: ElementConversion>(self, slices: S, value: E) -> Self
1500 where
1501 S: SliceArg,
1502 {
1503 let shape = self.shape();
1504 let slices = slices.into_slices(&shape);
1505
1506 check!(TensorCheck::slice::<D>(&shape, &slices));
1507
1508 let slice_shape = shape.slice(&slices).unwrap();
1509 let value = Tensor::<B, 1, K>::from_data_dtype(
1510 [value.elem::<K::Elem>()],
1511 &self.device(),
1512 self.dtype(),
1513 );
1514 let value = value.expand(slice_shape);
1515 self.slice_assign(&slices, value)
1516 }
1517
1518 /// Returns a new tensor with the specified dimension sliced.
1519 ///
1520 /// # Arguments
1521 ///
1522 /// * `dim`: The dimension to slice.
1523 /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),
1524 /// slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.
1525 ///
1526 /// # Returns
1527 ///
1528 /// A new tensor with the specified dimension sliced.
1529 ///
1530 /// # Panics
1531 ///
1532 /// If the slice is out of bounds for the specified dimension.
1533 ///
1534 /// # Examples
1535 ///
1536 /// ```rust
1537 /// # use burn_tensor::{Tensor, s};
1538 /// # use burn_tensor::backend::Backend;
1539 /// #
1540 /// # fn example<B: Backend>() {
1541 /// # let device = B::Device::default();
1542 /// let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);
1543 ///
1544 /// // Simple range slicing
1545 /// let sliced = tensor.clone().slice_dim(1, 1..3);
1546 /// assert_eq!(sliced.shape().dims, [3, 2, 5]);
1547 ///
1548 /// // Slicing with step - take every 2nd element
1549 /// let sliced = tensor.clone().slice_dim(2, s![0..5;2]);
1550 /// assert_eq!(sliced.shape().dims, [3, 4, 3]); // Takes indices 0, 2, 4
1551 ///
1552 /// // Reverse slicing with negative step
1553 /// let sliced = tensor.clone().slice_dim(1, s![..;-1]);
1554 /// assert_eq!(sliced.shape().dims, [3, 4, 5]); // Reverses dimension 1
1555 ///
1556 /// // Select from index 2 with step 3
1557 /// let sliced = tensor.clone().slice_dim(0, s![2..;3]);
1558 /// assert_eq!(sliced.shape().dims, [1, 4, 5]); // Takes only index 2
1559 ///
1560 /// // Select single index (reduces dimension to size 1)
1561 /// let sliced = tensor.slice_dim(0, 1);
1562 /// assert_eq!(sliced.shape().dims, [1, 4, 5]);
1563 /// # }
1564 /// ```
1565 ///
1566 /// # See Also
1567 ///
1568 /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously
1569 /// - [`s!`] - The macro for creating complex slice specifications
1570 ///
1571 /// [`s!`]: crate::s!
1572 pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self
1573 where
1574 S: Into<Slice>,
1575 {
1576 check!(TensorCheck::check_dim::<D>(dim));
1577 let slice: Slice = slice.into();
1578
1579 let mut slices = vec![Slice::full(); D];
1580 slices[dim] = slice;
1581
1582 self.slice(&slices)
1583 }
1584
1585 /// Returns the device of the current tensor.
1586 pub fn device(&self) -> B::Device {
1587 K::device(&self.primitive)
1588 }
1589
1590 /// Move the tensor to the given device.
1591 pub fn to_device(self, device: &B::Device) -> Self {
1592 Self::new(K::to_device(self.primitive, device))
1593 }
1594
1595 /// Select tensor elements along the given dimension corresponding to the given indices.
1596 ///
1597 /// # Arguments
1598 ///
1599 /// * `dim` - The dimension to select from. Supports negative indexing.
1600 /// * `indices` - The indices of the elements to select.
1601 ///
1602 /// # Example
1603 ///
1604 /// ```rust
1605 /// use burn_tensor::backend::Backend;
1606 /// use burn_tensor::{Tensor, Int};
1607 ///
1608 /// fn example<B: Backend>() {
1609 /// let device = B::Device::default();
1610 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);
1611 /// let indices = Tensor::<B, 1, Int>::from_data([0], &device);
1612 /// let tensor = tensor.select(0, indices);
1613 /// println!("{tensor}");
1614 /// // [[1.0, -2.0, 3.0]]
1615 /// }
1616 /// ```
1617 pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {
1618 let dim = dim.expect_dim_index(D);
1619 check!(TensorCheck::select::<D>(dim));
1620 Self::new(K::select(self.primitive, dim, indices.primitive))
1621 }
1622
1623 /// Assign the selected elements along the given dimension corresponding to the given indices
1624 /// from the value tensor to the original tensor using sum reduction.
1625 ///
1626 /// # Note
1627 /// For booleans, the sum operator is logical or.
1628 ///
1629 /// # Arguments
1630 ///
1631 /// * `dim` - The dimension along which to select. Supports negative indexing.
1632 /// * `indices` - The indices to select from the tensor.
1633 /// * `values` - The values to assign to the selected indices.
1634 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1635 ///
1636 /// # Example
1637 ///
1638 /// Example using a 3D tensor:
1639 ///
1640 /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
1641 /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
1642 /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
1643 /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`
1644 ///
1645 /// # Warning
1646 ///
1647 /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1648 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1649 pub fn select_assign(
1650 self,
1651 dim: impl AsIndex,
1652 indices: Tensor<B, 1, Int>,
1653 values: Tensor<B, D, K>,
1654 update: IndexingUpdateOp,
1655 ) -> Self {
1656 let dim = dim.expect_dim_index(D);
1657 check!(TensorCheck::select_assign::<D>(
1658 dim,
1659 &indices.shape(),
1660 &values.shape()
1661 ));
1662
1663 Self::new(K::select_assign(
1664 self.primitive,
1665 dim,
1666 indices.primitive,
1667 values.primitive,
1668 update,
1669 ))
1670 }
1671
1672 /// Update the given tensor with the value tensor where the mask is true.
1673 ///
1674 /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
1675 /// a scalar.
1676 ///
1677 /// # Example
1678 ///
1679 /// ```rust
1680 /// use burn_tensor::backend::Backend;
1681 /// use burn_tensor::{Tensor, Shape, Bool};
1682 ///
1683 /// fn example<B: Backend>() {
1684 /// let device = B::Device::default();
1685 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1686 /// let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1687 /// let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1688 /// let tensor = tensor.mask_where(mask, value);
1689 /// println!("{tensor}");
1690 /// // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
1691 /// }
1692 /// ```
1693 pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
1694 Self::new(K::mask_where(
1695 self.primitive,
1696 mask.primitive,
1697 value.primitive,
1698 ))
1699 }
1700
1701 /// Update the given tensor with the value where the mask is true.
1702 ///
1703 /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
1704 /// a tensor.
1705 ///
1706 /// # Example
1707 ///
1708 /// ```rust
1709 /// use burn_tensor::backend::Backend;
1710 /// use burn_tensor::{Tensor, Shape, Bool};
1711 ///
1712 /// fn example<B: Backend>() {
1713 /// let device = B::Device::default();
1714 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1715 /// let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1716 /// let tensor = tensor.mask_fill(mask, 3.0);
1717 /// println!("{tensor}");
1718 /// // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
1719 /// }
1720 /// ```
1721 pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
1722 Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
1723 }
1724
1725 /// Gather tensor elements corresponding to the given indices from the specified dim.
1726 ///
1727 /// Example using a 3D tensor:
1728 ///
1729 /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
1730 /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
1731 /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
1732 ///
1733 /// # Notes
1734 ///
1735 /// The index tensor should have the same shape as the original tensor except for the dim
1736 /// specified.
1737 ///
1738 /// # Warning
1739 /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1740 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1741 pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
1742 check!(TensorCheck::gather::<D>(
1743 dim,
1744 &self.shape(),
1745 &indices.shape()
1746 ));
1747
1748 Self::new(K::gather(dim, self.primitive, indices.primitive))
1749 }
1750
1751 /// Assign the gathered elements corresponding to the given indices along the specified dimension
1752 /// from the value tensor to the original tensor using sum reduction.
1753 ///
1754 /// Example using a 3D tensor:
1755 ///
1756 /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
1757 /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
1758 /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
1759 ///
1760 /// # Arguments
1761 /// * `dim` - The axis along which to scatter elements.
1762 /// * `indices` - The indices of the elements to scatter.
1763 /// * `values` - The values to scatter into the tensor.
1764 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1765 ///
1766 /// # Notes
1767 ///
1768 /// The index tensor should have the same shape as the original tensor except for the specified
1769 /// dimension. The value and index tensors should have the same shape.
1770 ///
1771 /// Other references to the input tensor will not be modified by this operation.
1772 ///
1773 /// # Warning
1774 /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1775 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1776 pub fn scatter(
1777 self,
1778 dim: usize,
1779 indices: Tensor<B, D, Int>,
1780 values: Self,
1781 update: IndexingUpdateOp,
1782 ) -> Self {
1783 check!(TensorCheck::scatter::<D>(
1784 dim,
1785 &self.shape(),
1786 &indices.shape(),
1787 &values.shape()
1788 ));
1789
1790 Self::new(K::scatter(
1791 dim,
1792 self.primitive,
1793 indices.primitive,
1794 values.primitive,
1795 update,
1796 ))
1797 }
1798
1799 /// Converts the data of the current tensor.
1800 ///
1801 /// # Note
1802 ///
1803 /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1804 /// tensors at once. This may improve laziness, especially if executed on a different
1805 /// thread in native environments.
1806 pub fn into_data(self) -> TensorData {
1807 self.try_into_data().expect(
1808 "Error while reading data: use `try_into_data` instead to catch the error at runtime",
1809 )
1810 }
1811
/// Converts the data of the current tensor and returns any error that might have occurred since the
/// last time the device was synchronized.
///
/// # Panics
///
/// Panics (via the `expect` below) if the data cannot be read synchronously. According to the
/// panic message, this can happen on platforms that don't support blocking futures, such as
/// WASM; use [`into_data_async`](Self::into_data_async) there instead.
///
/// # Note
///
/// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
/// tensors at once. This may improve laziness, especially if executed on a different
/// thread in native environments.
pub fn try_into_data(self) -> Result<TensorData, ExecutionError> {
    // Drive the async read to completion on the current thread. The `expect` covers the case
    // where `try_read_sync` cannot block on the future; the read's own
    // `Result<TensorData, ExecutionError>` is returned to the caller untouched.
    crate::try_read_sync(self.into_data_async()).expect(
        "Failed to read tensor data synchronously.
This can happen on platforms that don't support blocking futures like WASM.
If possible, try using into_data_async instead.",
    )
}
1827
1828 /// Converts the data of the current tensor.
1829 ///
1830 /// # Note
1831 ///
1832 /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1833 /// tensors at once. This may improve laziness, especially if executed on a different
1834 /// thread in native environments.
1835 pub fn to_data(&self) -> TensorData {
1836 self.clone().into_data()
1837 }
1838
1839 /// Returns the data of the current tensor.
1840 pub async fn into_data_async(self) -> Result<TensorData, ExecutionError> {
1841 K::into_data_async(self.primitive).await
1842 }
1843
1844 /// Returns the data of the current tensor.
1845 pub async fn to_data_async(&self) -> Result<TensorData, ExecutionError> {
1846 self.clone().into_data_async().await
1847 }
1848
1849 /// Create a tensor from the given data on the given device.
1850 pub fn from_data<T>(data: T, device: &B::Device) -> Self
1851 where
1852 T: Into<TensorData>,
1853 {
1854 let data = data.into();
1855 check!(TensorCheck::creation_ops::<D>(
1856 "From Data",
1857 data.shape.as_slice()
1858 ));
1859 Self::new(K::from_data(data, device))
1860 }
1861
1862 /// Create a tensor from the given data on the given device enforcing the given data type.
1863 pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
1864 where
1865 T: Into<TensorData>,
1866 {
1867 let data = data.into();
1868 check!(TensorCheck::creation_ops::<D>(
1869 "From Data",
1870 data.shape.as_slice()
1871 ));
1872 Self::new(K::from_data_dtype(data, device, dtype))
1873 }
1874
1875 /// Repeat the tensor along the given dimension.
1876 ///
1877 /// The output tensor has the same shape, except along the given dimension.
1878 ///
1879 /// # Arguments
1880 /// - `dim`: The dimension to repeat.
1881 /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1882 ///
1883 /// # Returns
1884 ///
1885 /// A new tensor with the given dimension repeated `times` times.
1886 ///
1887 /// # Example
1888 ///
1889 /// ```rust
1890 /// use burn_tensor::backend::Backend;
1891 /// use burn_tensor::Tensor;
1892 ///
1893 /// fn example<B: Backend>() {
1894 /// let device = Default::default();
1895 /// // Create a 2D tensor with dimensions [3, 2]
1896 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1897 ///
1898 /// // Repeat the tensor along the dimension 0 twice.
1899 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1900 /// // The resulting tensor will have dimensions [6, 2].
1901 /// let repeated = tensor.repeat_dim(0, 2);
1902 /// println!("{repeated}");
1903 /// }
1904 /// ```
1905 pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1906 if times > 0 {
1907 Self::new(K::repeat_dim(self.primitive, dim, times))
1908 } else {
1909 let shape = self.shape().repeat(dim, times).unwrap();
1910 Self::empty(shape, &self.device())
1911 }
1912 }
1913
1914 /// Repeat the tensor along the given dimensions.
1915 /// # Arguments
1916 /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
1917 ///
1918 /// # Returns
1919 ///
1920 /// A new tensor with the given dimensions repeated `times` times.
1921 ///
1922 /// # Panics
1923 ///
1924 /// If `sizes` contains more elements than the number of dimensions.
1925 ///
1926 /// # Example
1927 ///
1928 /// ```rust
1929 ///
1930 /// use burn_tensor::backend::Backend;
1931 /// use burn_tensor::Tensor;
1932 ///
1933 /// fn example<B: Backend>() {
1934 /// let device = Default::default();
1935 /// // Create a 2D tensor with dimensions [3, 2]
1936 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1937 ///
1938 /// // Repeat the tensor along the dimension 0 twice and the dimension 0 once.
1939 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1940 /// // The resulting tensor will have dimensions [6, 2].
1941 /// let repeated = tensor.repeat(&[2, 1]);
1942 /// }
1943 /// ```
1944 pub fn repeat(self, sizes: &[usize]) -> Self {
1945 if sizes.contains(&0) {
1946 let mut shape = self.shape();
1947 for (dim, ×) in sizes.iter().enumerate() {
1948 shape = shape.repeat(dim, times).unwrap();
1949 }
1950
1951 return Self::empty(shape, &self.device());
1952 }
1953
1954 let mut tensor = self;
1955 for (dim, ×) in sizes.iter().enumerate() {
1956 if times > 1 {
1957 tensor = tensor.repeat_dim(dim, times);
1958 }
1959 }
1960 tensor
1961 }
1962
1963 /// Applies element-wise equal comparison.
1964 ///
1965 /// # Returns
1966 /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
1967 ///
1968 /// # Panics
1969 ///
1970 /// If the two tensors don't have the same shape.
1971 ///
1972 /// # Example
1973 ///
1974 /// ```rust
1975 /// use burn_tensor::backend::Backend;
1976 /// use burn_tensor::Tensor;
1977 ///
1978 /// fn example<B: Backend>() {
1979 /// let device = Default::default();
1980 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1981 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1982 /// // Compare the elements of the two 2D tensors with dimensions [3, 2].
1983 /// // [[false, true], [true, true], [true, true]]
1984 /// let equal = t1.equal(t2);
1985 /// println!("{equal}");
1986 /// }
1987 /// ```
1988 pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
1989 check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
1990 Tensor::new(K::equal(self.primitive, other.primitive))
1991 }
1992
1993 /// Applies element-wise non-equality comparison.
1994 ///
1995 /// # Returns
1996 /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
1997 ///
1998 /// # Panics
1999 ///
2000 /// If the two tensors don't have the same shape.
2001 ///
2002 /// # Example
2003 ///
2004 /// ```rust
2005 /// use burn_tensor::backend::Backend;
2006 /// use burn_tensor::Tensor;
2007 ///
2008 /// fn example<B: Backend>() {
2009 /// let device = Default::default();
2010 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2011 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2012 /// // Compare the elements of the two 2D tensors for inequality.
2013 /// // [[true, false], [false, false], [false, false]]
2014 /// let not_equal = t1.not_equal(t2);
2015 /// println!("{not_equal}");
2016 /// }
2017 /// ```
2018 pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
2019 check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
2020 Tensor::new(K::not_equal(self.primitive, other.primitive))
2021 }
2022
2023 /// Applies element wise equal comparison and returns a boolean tensor.
2024 ///
2025 /// # Arguments
2026 ///
2027 /// * `other` - The element to compare.
2028 ///
2029 /// # Example
2030 ///
2031 /// ```rust
2032 /// use burn_tensor::backend::Backend;
2033 /// use burn_tensor::{Tensor, Shape};
2034 ///
2035 /// fn example<B: Backend>() {
2036 /// let device = B::Device::default();
2037 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2038 /// let tensor = tensor.equal_elem(3.0);
2039 /// println!("{tensor}");
2040 /// // [[false, false, true], [false, false, false]]
2041 /// }
2042 /// ```
2043 pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2044 Tensor::new(K::equal_elem(self.primitive, other.elem()))
2045 }
2046
2047 /// Applies element wise non-equality comparison and returns a boolean tensor.
2048 ///
2049 /// # Arguments
2050 ///
2051 /// * `other` - The element to compare.
2052 ///
2053 /// # Example
2054 ///
2055 /// ```rust
2056 /// use burn_tensor::backend::Backend;
2057 /// use burn_tensor::{Tensor, Shape};
2058 ///
2059 /// fn example<B: Backend>() {
2060 /// let device = B::Device::default();
2061 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2062 /// let tensor = tensor.not_equal_elem(3.0);
2063 /// println!("{tensor}");
2064 /// // [[true, true, false], [true, true, true]]
2065 /// }
2066 /// ```
2067 pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2068 Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
2069 }
2070
2071 /// Concatenates all tensors into a new one along the given dimension.
2072 ///
2073 /// # Panics
2074 ///
2075 /// - If `dim` is higher than the rank.
2076 /// - If `tensors` is an empty vector.
2077 /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
2078 ///
2079 /// # Example
2080 ///
2081 /// ```rust
2082 /// use burn_tensor::backend::Backend;
2083 /// use burn_tensor::Tensor;
2084 ///
2085 /// fn example<B: Backend>() {
2086 /// let device = Default::default();
2087 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
2088 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2089 ///
2090 /// // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
2091 /// // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
2092 /// // The resulting tensor will have shape [2, 7].
2093 /// let concat = Tensor::cat(vec![t1, t2], 1);
2094 /// println!("{concat}");
2095 /// }
2096 /// ```
2097 pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
2098 check!(TensorCheck::cat(&tensors, dim));
2099
2100 // Filter out tensors with size 0 along the concatenation dimension.
2101 // Empty tensors don't contribute to the output and would cause issues
2102 // in backend implementations (e.g., division by zero in slice_assign).
2103 // Safety: TensorCheck::cat ensures tensors is non-empty
2104 let first_tensor = tensors.first().unwrap();
2105 let device = first_tensor.device();
2106 let mut shape = first_tensor.shape();
2107
2108 let non_empty_primitives: Vec<_> = tensors
2109 .into_iter()
2110 .filter(|t| t.shape().dims[dim] > 0)
2111 .map(|t| t.primitive)
2112 .collect();
2113
2114 // If all tensors were empty, return an empty tensor with size 0 on concat dim
2115 if non_empty_primitives.is_empty() {
2116 shape.dims[dim] = 0;
2117 return Self::empty(shape, &device);
2118 }
2119
2120 Self::new(K::cat(non_empty_primitives, dim))
2121 }
2122
2123 /// Concatenates all tensors into a new one along a new dimension.
2124 ///
2125 /// # Panics
2126 ///
2127 /// - If all tensors don't have the same shape.
2128 /// - If given dimension is not with range of 0..D2
2129 ///
2130 /// # Example
2131 ///
2132 /// ```rust
2133 /// use burn_tensor::backend::Backend;
2134 /// use burn_tensor::Tensor;
2135 ///
2136 /// fn example<B: Backend>() {
2137 /// let device = Default::default();
2138 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2139 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2140 /// let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2141 ///
2142 /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
2143 /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
2144 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
2145 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
2146 /// // The resulting tensor will have shape [3, 2, 3].
2147 /// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
2148 /// println!("{stacked}");
2149 /// }
2150 /// ```
2151 pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
2152 check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
2153 let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
2154 Tensor::<B, D2, K>::cat(tensors, dim)
2155 }
2156
2157 /// Iterate over slices of tensors alongside a given dimension.
2158 ///
2159 /// # Panics
2160 ///
2161 /// If given dimension is greater than or equal to tensor rank.
2162 ///
2163 /// # Returns
2164 ///
2165 /// A tensor iterator.
2166 ///
2167 /// # Example
2168 ///
2169 /// ```rust
2170 /// use burn_tensor::backend::Backend;
2171 /// use burn_tensor::Tensor;
2172 /// fn example<B: Backend>() {
2173 /// let device = Default::default();
2174 /// let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2175 /// // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
2176 /// let iter = tensor.iter_dim(0);
2177 /// for (i,tensor) in iter.enumerate() {
2178 /// println!("Tensor {}: {}", i, tensor);
2179 /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
2180 /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
2181 /// }
2182 /// }
2183 /// ```
2184 pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
2185 check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
2186 DimIter::new(self, dim)
2187 }
2188
2189 /// Returns a new tensor with the given dimension narrowed to the given range.
2190 ///
2191 /// # Panics
2192 ///
2193 /// - If the dimension is greater than the number of dimensions of the tensor.
2194 /// - If the given range exceeds the number of elements on the given dimension.
2195 ///
2196 /// # Returns
2197 ///
2198 /// A new tensor with the given dimension narrowed to the given range.
2199 ///
2200 /// # Example
2201 ///
2202 /// ```rust
2203 /// use burn_tensor::backend::Backend;
2204 /// use burn_tensor::Tensor;
2205 ///
2206 /// fn example<B: Backend>() {
2207 /// let device = Default::default();
2208 /// // Create a 2D tensor with dimensions [4, 3]
2209 /// let tensor = Tensor::<B, 2>::from_data(
2210 /// [
2211 /// [3.0, 4.9, 2.0],
2212 /// [2.0, 1.9, 3.0],
2213 /// [6.0, 1.5, 7.0],
2214 /// [3.0, 4.9, 9.0],
2215 /// ],
2216 /// &device,
2217 /// );
2218 /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
2219 /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
2220 /// // The resulting tensor will have dimensions [3, 3].
2221 /// let narrowed = tensor.narrow(0, 1, 3);
2222 /// println!("{narrowed}");
2223 /// }
2224 /// ```
2225 pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
2226 check!(TensorCheck::dim_ops::<D>("narrow", dim));
2227 check!(TensorCheck::narrow(&self, dim, start, length));
2228 let dims = self.dims();
2229
2230 let ranges: [Range<usize>; D] = dims
2231 .iter()
2232 .enumerate()
2233 .map(|(i, d)| {
2234 if i == dim {
2235 start..(start + length)
2236 } else {
2237 0..*d
2238 }
2239 })
2240 .collect::<Vec<_>>()
2241 .try_into()
2242 .unwrap();
2243
2244 Self::slice(self, ranges)
2245 }
2246
2247 /// Attempts to split the tensor into a specified number of chunks along a given dimension.
2248 /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
2249 ///
2250 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2251 /// Otherwise all chunks will be of equal size except for the last one.
2252 ///
2253 /// # Panics
2254 ///
2255 /// If the dimension is greater than the number of dimensions of the tensor.
2256 ///
2257 /// # Returns
2258 /// A vector of tensors.
2259 ///
2260 /// # Example
2261 ///
2262 /// ```rust
2263 /// use burn_tensor::backend::Backend;
2264 /// use burn_tensor::Tensor;
2265 ///
2266 /// fn example<B: Backend>() {
2267 /// let device = Default::default();
2268 /// // Create a 2D tensor with dimensions [4, 3]
2269 /// let tensor = Tensor::<B, 2>::from_data(
2270 /// [
2271 /// [3.0, 4.9, 2.0],
2272 /// [2.0, 1.9, 3.0],
2273 /// [6.0, 1.5, 7.0],
2274 /// [3.0, 4.9, 9.0],
2275 /// ],
2276 /// &device,
2277 /// );
2278 /// // Split the tensor along the dimension 1 into 2 chunks.
2279 /// // The first chuck will have shape [4, 2]:
2280 /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
2281 /// // The second chunk will have shape [4, 1]:
2282 /// // [[2.0], [3.0], [7.0], [9.0]]
2283 /// let chunks = tensor.chunk(2, 1);
2284 /// println!("{chunks:?}");
2285 /// }
2286 /// ```
2287 pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
2288 check!(TensorCheck::dim_ops::<D>("chunk", dim));
2289 let size = self.shape().dims[dim];
2290 if size < chunks {
2291 return (0..size)
2292 .map(|i| Self::narrow(self.clone(), dim, i, 1))
2293 .collect();
2294 }
2295
2296 let mut tensors = Vec::with_capacity(chunks);
2297 let mut sum_chunk_size = 0;
2298 if size.is_multiple_of(chunks) {
2299 let chunk_size = size / chunks;
2300 for _ in 0..chunks {
2301 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2302 sum_chunk_size += chunk_size;
2303 }
2304 } else {
2305 let chunk_size = (size / chunks) + 1; // assumes not divisible
2306 for _ in 0..chunks - 1 {
2307 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2308 sum_chunk_size += chunk_size;
2309 }
2310 let remainder = size % chunk_size;
2311 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder));
2312 }
2313
2314 tensors
2315 }
2316
2317 /// Splits the tensor into chunks of a specified size along a given dimension.
2318 /// Each chunk is a view of the original tensor.
2319 ///
2320 /// If the tensor size along the given dimension is not divisible by `split_size`,
2321 /// then the last chunk will be smaller.
2322 ///
2323 /// # Panics
2324 ///
2325 /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
2326 ///
2327 /// # Returns
2328 ///
2329 /// A vector of tensors.
2330 ///
2331 /// # Example
2332 /// ```rust
2333 /// use burn_tensor::backend::Backend;
2334 /// use burn_tensor::Tensor;
2335 ///
2336 /// fn example<B: Backend>() {
2337 /// let device = Default::default();
2338 /// // Create a 1D tensor with 5 elements
2339 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2340 /// // Split the tensor into chunks of size 2 along dimension 0
2341 /// let chunks = tensor.split(2, 0);
2342 /// // The result is a vector of tensors:
2343 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
2344 /// println!("{:?}", chunks);
2345 /// }
2346 /// ```
2347 pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
2348 check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));
2349 let size = self.shape().dims[dim];
2350 let mut tensors = Vec::new();
2351
2352 let mut start = 0;
2353 while start < size {
2354 let length = usize::min(split_size, size - start);
2355 tensors.push(Self::narrow(self.clone(), dim, start, length));
2356 start += length;
2357 }
2358
2359 tensors
2360 }
2361
2362 /// Splits the tensor into chunks with the specified sizes along a given dimension.
2363 /// Each chunk is a view of the original tensor.
2364 ///
2365 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2366 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2367 ///
2368 /// # Panics
2369 ///
2370 /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
2371 /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
2372 ///
2373 /// # Returns
2374 ///
2375 /// A vector of tensors.
2376 ///
2377 /// # Example
2378 /// ```rust
2379 /// use burn_tensor::backend::Backend;
2380 /// use burn_tensor::Tensor;
2381 ///
2382 /// fn example<B: Backend>() {
2383 /// let device = Default::default();
2384 /// // Create a 1D tensor with 5 elements
2385 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2386 /// // Split the tensor into chunks with sizes [2, 3] along dimension 0
2387 /// let chunks = tensor.split_with_sizes(vec![2, 3], 0);
2388 /// // The result is a vector of tensors:
2389 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
2390 /// println!("{:?}", chunks);
2391 /// }
2392 /// ```
2393 pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
2394 check!(TensorCheck::split_with_sizes::<D>(
2395 &self.shape(),
2396 &split_sizes,
2397 dim
2398 ));
2399 let mut tensors = Vec::new();
2400
2401 let mut start = 0;
2402 for length in split_sizes {
2403 if length == 0 {
2404 continue;
2405 }
2406 tensors.push(Self::narrow(self.clone(), dim, start, length));
2407 start += length;
2408 }
2409
2410 tensors
2411 }
2412
2413 /// Tests if any element in the `tensor` evaluates to True.
2414 ///
2415 /// # Arguments
2416 ///
2417 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2418 ///
2419 /// # Returns
2420 ///
2421 /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
2422 /// evaluates to True, False otherwise.
2423 ///
2424 /// # Example
2425 ///
2426 /// ```rust
2427 /// use burn_tensor::backend::Backend;
2428 /// use burn_tensor::{Tensor, Bool};
2429 ///
2430 /// fn example<B: Backend>() {
2431 /// let device = Default::default();
2432 /// let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
2433 /// let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
2434 ///
2435 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2436 /// let any_tensor = tensor.any();
2437 /// println!("{}", any_tensor);
2438 /// // Tensor { data: [true], ... }
2439 ///
2440 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2441 /// let any_tensor_two = tensor_two.any();
2442 /// println!("{}", any_tensor_two);
2443 /// // Tensor { data: [false], ... }
2444 /// }
2445 /// ```
2446 pub fn any(self) -> Tensor<B, 1, Bool> {
2447 Tensor::new(K::any(self.primitive))
2448 }
2449
2450 /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2451 ///
2452 /// # Arguments
2453 ///
2454 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2455 /// * `dim` - The axis along which to test.
2456 ///
2457 /// # Returns
2458 ///
2459 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2460 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
2461 /// evaluates to True, False otherwise.
2462 ///
2463 /// # Example
2464 ///
2465 /// ```rust
2466 /// use burn_tensor::backend::Backend;
2467 /// use burn_tensor::{Tensor, Bool};
2468 ///
2469 /// fn example<B: Backend>() {
2470 /// let device = Default::default();
2471 /// let tensor =
2472 /// Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
2473 /// // Check if any element in the tensor evaluates to True along the dimension 1.
2474 /// // [[true], [true]],
2475 /// let any_dim = tensor.clone().any_dim(1);
2476 /// println!("{any_dim}");
2477 /// }
2478 /// ```
2479 pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2480 Tensor::new(K::any_dim(self.primitive, dim))
2481 }
2482
2483 /// Tests if all elements in the `tensor` evaluate to True.
2484 ///
2485 /// # Arguments
2486 ///
2487 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2488 ///
2489 /// # Returns
2490 ///
2491 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
2492 /// evaluate to True, False otherwise.
2493 ///
2494 /// # Example
2495 ///
2496 /// ```rust
2497 /// use burn_tensor::backend::Backend;
2498 /// use burn_tensor::{Tensor, Bool};
2499 ///
2500 /// fn example<B: Backend>() {
2501 /// let device = Default::default();
2502 /// let tensor =
2503 /// Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
2504 /// // Check if all elements in the tensor evaluate to True (which is not the case).
2505 /// // [false]
2506 /// let all = tensor.all();
2507 /// println!("{all}");
2508 /// }
2509 /// ```
2510 pub fn all(self) -> Tensor<B, 1, Bool> {
2511 Tensor::new(K::all(self.primitive))
2512 }
2513
2514 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2515 ///
2516 /// # Arguments
2517 ///
2518 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2519 /// * `dim` - The axis along which to test.
2520 ///
2521 /// # Returns
2522 ///
2523 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2524 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
2525 /// evaluates to True, False otherwise.
2526 ///
2527 /// # Example
2528 ///
2529 /// ```rust
2530 /// use burn_tensor::backend::Backend;
2531 /// use burn_tensor::{Tensor, Bool};
2532 ///
2533 /// fn example<B: Backend>() {
2534 /// let device = Default::default();
2535 /// let tensor =
2536 /// Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
2537 /// // Check if all elements in the tensor evaluate to True along the dimension 1.
2538 /// // [[true, true, false]]
2539 /// let all_dim = tensor.clone().all_dim(0);
2540 /// println!("{all_dim}");
2541 /// }
2542 /// ```
2543 pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2544 Tensor::new(K::all_dim(self.primitive, dim))
2545 }
2546
2547 /// Convert the tensor into a scalar.
2548 ///
2549 /// # Panics
2550 ///
2551 /// - If the tensor doesn't have one element.
2552 /// - If the backend fails to read the tensor data synchronously.
2553 ///
2554 /// # Returns
2555 ///
2556 /// The scalar value of the tensor.
2557 ///
2558 /// # Example
2559 ///
2560 /// ```rust
2561 /// use burn_tensor::backend::Backend;
2562 /// use burn_tensor::Tensor;
2563 ///
2564 /// fn example<B: Backend>() {
2565 /// let device = Default::default();
2566 /// let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
2567 /// // Convert the tensor with a single element into a scalar.
2568 /// let scalar = tensor.into_scalar();
2569 /// println!("{scalar}");
2570 /// }
2571 /// ```
2572 pub fn into_scalar(self) -> K::Elem {
2573 crate::try_read_sync(self.into_scalar_async())
2574 .expect(
2575 "Failed to read tensor data synchronously. This can happen on platforms
2576 that don't support blocking futures like WASM. Try into_scalar_async instead.",
2577 )
2578 .expect("Error while reading data: use `try_into_scalar` instead to catch the error at runtime")
2579 }
2580
2581 /// Convert the tensor into a scalar and returns any error that might have occurred since the
2582 /// last time the device was synchronized.
2583 ///
2584 /// # Panics
2585 ///
2586 /// - If the tensor doesn't have one element.
2587 /// - If the backend fails to read the tensor data synchronously.
2588 ///
2589 /// # Returns
2590 ///
2591 /// The scalar value of the tensor.
2592 pub fn try_into_scalar(self) -> Result<K::Elem, ExecutionError> {
2593 crate::try_read_sync(self.into_scalar_async()).expect(
2594 "Failed to read tensor data synchronously. This can happen on platforms
2595 that don't support blocking futures like WASM. Try into_scalar_async instead.",
2596 )
2597 }
2598
2599 /// Convert the tensor into a scalar.
2600 ///
2601 /// # Panics
2602 ///
2603 /// If the tensor doesn't have one element.
2604 pub async fn into_scalar_async(self) -> Result<K::Elem, ExecutionError> {
2605 check!(TensorCheck::into_scalar::<D>(&self.shape()));
2606
2607 Ok(self.into_data_async().await?.iter().next().unwrap())
2608 }
2609
2610 /// Broadcast the tensor to the given shape.
2611 ///
2612 /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
2613 /// (which can be inferred with `-1`).
2614 ///
2615 /// # Arguments
2616 ///
2617 /// * `shape` - The shape to broadcast the tensor to.
2618 /// Can contain -1 for dimensions that should be inferred.
2619 /// The number of elements in the shape must be greater or equal as
2620 /// the number of dimensions of the tensor.
2621 ///
2622 /// # Panics
2623 ///
2624 /// If the tensor cannot be broadcasted to the given shape.
2625 ///
2626 /// # Returns
2627 ///
2628 /// A new tensor with the given shape.
2629 ///
2630 /// # Example
2631 ///
2632 /// ```rust
2633 /// use burn_tensor::backend::Backend;
2634 /// use burn_tensor::Tensor;
2635 ///
2636 /// fn example<B: Backend>() {
2637 /// let device = Default::default();
2638 /// // Create a 2D tensor with dimensions [3, 1]
2639 /// let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
2640 /// // Expand the tensor to a new shape [3, 4]
2641 /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
2642 /// let expanded = tensor.expand([3, 4]);
2643 /// println!("{}", expanded);
2644 /// }
2645 /// ```
2646 pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
2647 let shape = shape.into_shape(&self.shape());
2648 check!(TensorCheck::expand::<D, D2>(
2649 "expand",
2650 &self.shape(),
2651 &shape,
2652 ));
2653
2654 Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
2655 }
2656
2657 /// Unfold windows along a dimension.
2658 ///
2659 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
2660 /// where windows are advanced by `step` at each index.
2661 ///
2662 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
2663 ///
2664 /// The new view will have the unfolded dimension replaced by two dimensions;
2665 /// one in the position of the original dimension, with size equal to the number of windows,
2666 /// and one appended to the right-most position, with size equal to `size`.
2667 ///
2668 /// # Warning
2669 ///
2670 /// For the `ndarray` and `candle` backends; this is not a view but a copy
2671 /// with duplicated data.
2672 ///
2673 /// # Arguments
2674 ///
2675 /// * `dim` - the dimension to unfold.
2676 /// * `size` - the size of each unfolded window.
2677 /// * `step` - the step between each window.
2678 ///
2679 /// # Returns
2680 ///
2681 /// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
2682 pub fn unfold<const D2: usize, I: AsIndex>(
2683 self,
2684 dim: I,
2685 size: usize,
2686 step: usize,
2687 ) -> Tensor<B, D2, K> {
2688 let dim = dim.expect_dim_index(D);
2689 check!(TensorCheck::unfold::<D, D2>(
2690 "unfold",
2691 &self.shape(),
2692 dim,
2693 size,
2694 step,
2695 ));
2696 Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))
2697 }
2698}
2699
2700/// Iterator given by (Tensor::iter_dim).
2701pub struct DimIter<B, const D: usize, K>
2702where
2703 B: Backend,
2704 K: BasicOps<B>,
2705{
2706 start: usize,
2707 end: usize,
2708 dim: usize,
2709 ranges: [Range<usize>; D],
2710 tensor: Tensor<B, D, K>,
2711}
2712
2713impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
2714 type Item = Tensor<B, D, K>;
2715
2716 fn next(&mut self) -> Option<Self::Item> {
2717 if self.start >= self.end {
2718 return None;
2719 }
2720
2721 let mut ranges = self.ranges.clone();
2722 ranges[self.dim] = self.start..(self.start + 1);
2723
2724 let slice = self.tensor.clone().slice(ranges);
2725 self.start += 1;
2726
2727 Some(slice)
2728 }
2729}
2730
2731impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
2732 fn next_back(&mut self) -> Option<Self::Item> {
2733 if self.start >= self.end {
2734 return None;
2735 }
2736
2737 let mut ranges = self.ranges.clone();
2738 ranges[self.dim] = (self.end - 1)..self.end;
2739
2740 let slice = self.tensor.clone().slice(ranges);
2741 self.end = self.end.saturating_sub(1);
2742
2743 Some(slice)
2744 }
2745}
2746
2747impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
2748 fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
2749 let dims = tensor.dims();
2750 let ranges = dims
2751 .iter()
2752 .map(|&dim| 0..dim)
2753 .collect::<Vec<Range<usize>>>();
2754 let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
2755 Self {
2756 end: dims[dim],
2757 ranges,
2758 start: 0,
2759 dim,
2760 tensor,
2761 }
2762 }
2763}
2764
2765impl<B, const D: usize, K> Tensor<B, D, K>
2766where
2767 B: Backend,
2768 K: BasicOps<B>,
2769 <K as BasicOps<B>>::Elem: Debug,
2770{
2771 #[inline]
2772 fn push_newline_indent(acc: &mut String, indent: usize) {
2773 acc.push('\n');
2774 for _ in 0..indent {
2775 acc.push(' ');
2776 }
2777 }
2778 fn fmt_inner_tensor(
2779 &self,
2780 acc: &mut String,
2781 depth: usize,
2782 multi_index: &mut [usize],
2783 range: (usize, usize),
2784 precision: Option<usize>,
2785 ) {
2786 let (start, end) = range;
2787 for i in start..end {
2788 if i > 0 {
2789 acc.push_str(", ");
2790 }
2791 multi_index[depth] = i;
2792 let range: [Range<usize>; D] =
2793 core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
2794
2795 let data = burn_std::reader::try_read_sync(self.clone().slice(range).into_data_async());
2796
2797 if let Some(Ok(data)) = data {
2798 let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
2799 match (precision, K::name()) {
2800 (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")),
2801 (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
2802 _ => acc.push_str(&format!("{elem:?}")),
2803 }
2804 } else {
2805 acc.push_str("<Tensor data not available>");
2806 }
2807 }
2808 }
2809
2810 fn fmt_outer_tensor(
2811 &self,
2812 acc: &mut String,
2813 depth: usize,
2814 multi_index: &mut [usize],
2815 print_options: &PrintOptions,
2816 summarize: bool,
2817 range: (usize, usize),
2818 ) {
2819 let (start, end) = range;
2820 for i in start..end {
2821 if i > start {
2822 acc.push(',');
2823 Self::push_newline_indent(acc, depth + 1);
2824 }
2825 acc.push('[');
2826 multi_index[depth] = i;
2827 self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
2828 acc.push(']');
2829 }
2830 }
2831
2832 /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
2833 ///
2834 /// This function is designed to work with tensors of any dimensionality.
2835 /// It traverses the tensor dimensions recursively, converting the elements
2836 /// to strings and appending them to the accumulator string with the
2837 /// appropriate formatting.
2838 ///
2839 /// # Arguments
2840 ///
2841 /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
2842 /// * `depth` - The current depth of the tensor dimensions being processed.
2843 /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
2844 fn display_recursive(
2845 &self,
2846 acc: &mut String,
2847 depth: usize,
2848 multi_index: &mut [usize],
2849 print_options: &PrintOptions,
2850 summarize: bool,
2851 ) {
2852 let edge_items = print_options.edge_items;
2853
2854 if depth == 0 {
2855 acc.push('[');
2856 }
2857
2858 if depth == self.dims().len() - 1 {
2859 // if we are at the innermost dimension, just push its elements into the accumulator
2860 if summarize && self.dims()[depth] > 2 * edge_items {
2861 // print the starting `edge_items` elements
2862 self.fmt_inner_tensor(
2863 acc,
2864 depth,
2865 multi_index,
2866 (0, edge_items),
2867 print_options.precision,
2868 );
2869 acc.push_str(", ...");
2870 // print the last `edge_items` elements
2871 self.fmt_inner_tensor(
2872 acc,
2873 depth,
2874 multi_index,
2875 (self.dims()[depth] - edge_items, self.dims()[depth]),
2876 print_options.precision,
2877 );
2878 } else {
2879 // print all the elements
2880 self.fmt_inner_tensor(
2881 acc,
2882 depth,
2883 multi_index,
2884 (0, self.dims()[depth]),
2885 print_options.precision,
2886 );
2887 }
2888 } else {
2889 // otherwise, iterate through the current dimension and recursively display the inner tensors
2890 if summarize && self.dims()[depth] > 2 * edge_items {
2891 self.fmt_outer_tensor(
2892 acc,
2893 depth,
2894 multi_index,
2895 print_options,
2896 summarize,
2897 (0, edge_items),
2898 );
2899
2900 acc.push(',');
2901 Self::push_newline_indent(acc, depth + 1);
2902 acc.push_str("...");
2903 Self::push_newline_indent(acc, depth + 1);
2904
2905 self.fmt_outer_tensor(
2906 acc,
2907 depth,
2908 multi_index,
2909 print_options,
2910 summarize,
2911 (self.dims()[depth] - edge_items, self.dims()[depth]),
2912 );
2913 } else {
2914 self.fmt_outer_tensor(
2915 acc,
2916 depth,
2917 multi_index,
2918 print_options,
2919 summarize,
2920 (0, self.dims()[depth]),
2921 );
2922 }
2923 }
2924
2925 if depth == 0 {
2926 acc.push(']');
2927 }
2928 }
2929}
2930
/// Options for Tensor pretty printing
///
/// The active options are stored globally and can be replaced with `set_print_options`.
#[derive(Clone, Debug)]
pub struct PrintOptions {
    /// Element count above which the printed tensor is summarized.
    pub threshold: usize,

    /// Number of leading and trailing entries kept per dimension when summarizing.
    pub edge_items: usize,

    /// Decimal places used for floating point elements; `None` keeps the default formatting.
    pub precision: Option<usize>,
}
2943
/// Global print options: read (and cloned) by the tensor `Display` implementation and
/// replaced by `set_print_options`. Initialized with `PrintOptions::const_default()`.
static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
2945
2946impl PrintOptions {
2947 /// Print options with default values
2948 pub const fn const_default() -> Self {
2949 Self {
2950 threshold: 1000,
2951 edge_items: 3,
2952 precision: None,
2953 }
2954 }
2955}
2956
2957impl Default for PrintOptions {
2958 fn default() -> Self {
2959 Self::const_default()
2960 }
2961}
2962
2963/// Set print options
2964pub fn set_print_options(options: PrintOptions) {
2965 let mut print_opts = PRINT_OPTS.write().unwrap();
2966 *print_opts = options;
2967}
2968
2969/// Pretty print tensors
2970impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
2971where
2972 B: Backend,
2973 B::IntElem: core::fmt::Display,
2974 K: BasicOps<B>,
2975 <K as BasicOps<B>>::Elem: Debug,
2976{
2977 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2978 writeln!(f, "Tensor {{")?;
2979
2980 {
2981 // Do not lock the mutex for the whole function
2982 let mut po = { PRINT_OPTS.read().unwrap().clone() };
2983
2984 // Override the precision if it is set from the formatter
2985 // This will be possible when the tensor is printed using the `{:.*}` syntax
2986 if let Some(precision) = f.precision() {
2987 po.precision = Some(precision);
2988 }
2989
2990 let mut acc = String::new();
2991 let mut multi_index = vec![0; D];
2992 let summarize = self.shape().num_elements() > po.threshold;
2993
2994 self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
2995
2996 writeln!(f, " data:")?;
2997 write!(f, "{acc}")?;
2998 writeln!(f, ",")?;
2999 }
3000
3001 writeln!(f, " shape: {:?},", self.dims())?;
3002 writeln!(f, " device: {:?},", self.device())?;
3003 writeln!(f, " backend: {:?},", B::name(&self.device()))?;
3004 writeln!(f, " kind: {:?},", K::name())?;
3005
3006 let dtype = self.primitive.dtype();
3007
3008 writeln!(f, " dtype: {:?},", dtype.name())?;
3009 write!(f, "}}")
3010 }
3011}
3012
/// Trait used for movedim arguments
///
/// Implemented in this module for single dimensions (`usize`, `i32`) and lists of dimensions
/// (`Vec<usize>`, `Vec<i32>`); negative `i32` values count from the last dimension.
pub trait MovedimArgs {
    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function
    ///
    /// `D` is the rank of the tensor the dimensions refer to; the implementations use it to
    /// resolve negative indices and to validate the result through `check!`.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
3018
3019impl MovedimArgs for Vec<i32> {
3020 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3021 let set = self
3022 .iter()
3023 .map(|&dim| {
3024 if dim < 0 {
3025 (D as i32 + dim) as usize
3026 } else {
3027 dim as usize
3028 }
3029 })
3030 .collect::<Vec<usize>>();
3031 check!(TensorCheck::movedim_args_vec::<D>(&set));
3032
3033 set
3034 }
3035}
3036
impl MovedimArgs for Vec<usize> {
    /// Validates the dimensions against rank `D` with `TensorCheck::movedim_args_vec` and
    /// returns them unchanged.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_vec::<D>(&self));
        self
    }
}
3043
3044impl MovedimArgs for usize {
3045 #[allow(clippy::vec_init_then_push)]
3046 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3047 check!(TensorCheck::movedim_args_usize::<D>(self));
3048
3049 let mut set = Vec::with_capacity(1);
3050 set.push(self);
3051
3052 set
3053 }
3054}
3055
3056impl MovedimArgs for i32 {
3057 #[allow(clippy::vec_init_then_push)]
3058 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3059 check!(TensorCheck::movedim_args_i32::<D>(self));
3060
3061 let dim = if self < 0 {
3062 (D as i32 + self) as usize
3063 } else {
3064 self as usize
3065 };
3066
3067 let mut set = Vec::with_capacity(1);
3068 set.push(dim);
3069
3070 set
3071 }
3072}
3073
/// Trait used for reshape arguments.
///
/// `D2` is the rank of the target shape. Implemented in this module for [`Shape`],
/// `[usize; D2]`, `[i64; D2]` and `[i32; D2]`.
pub trait ReshapeArgs<const D2: usize> {
    /// Converts to a shape.
    ///
    /// `tensor` is the tensor being reshaped; implementations use it to validate the request
    /// and, for the signed array forms, to resolve `0` and `-1` entries.
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape;
}
3082
3083impl<const D2: usize> ReshapeArgs<D2> for Shape {
3084 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3085 self,
3086 tensor: &Tensor<B, D, K>,
3087 ) -> Shape {
3088 check!(TensorCheck::reshape_args_usize::<D, D2>(
3089 &tensor.shape(),
3090 &self
3091 ));
3092
3093 self
3094 }
3095}
3096impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
3097 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3098 self,
3099 tensor: &Tensor<B, D, K>,
3100 ) -> Shape {
3101 let shape = Shape::from(self);
3102
3103 check!(TensorCheck::reshape_args_usize::<D, D2>(
3104 &tensor.shape(),
3105 &shape
3106 ));
3107
3108 shape
3109 }
3110}
3111
3112impl<const D2: usize> ReshapeArgs<D2> for [i64; D2] {
3113 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3114 self,
3115 tensor: &Tensor<B, D, K>,
3116 ) -> Shape {
3117 // Validate the reshape arguments
3118 check!(TensorCheck::reshape_args_i64(&self));
3119
3120 // Temporary shape
3121 let mut new_shape: [i64; D2] = [1; D2];
3122
3123 // We need to find the index of the 0 dimension and
3124 // replace it with the actual dimension value.
3125 for (i, &s) in self.iter().enumerate() {
3126 if s != 0 {
3127 new_shape[i] = s;
3128 } else {
3129 new_shape[i] = tensor.dims()[i] as i64;
3130 }
3131 }
3132
3133 // Find the index of the inferred dimension (-1)
3134 let infer_index = new_shape.iter().position(|x| x == &-1);
3135
3136 // Handle the case where the dimension is inferred (via -1)
3137 if let Some(index) = infer_index {
3138 // Handle the case where the dimension is inferred
3139 let mut product = 1;
3140 for (i, &s) in new_shape.iter().enumerate() {
3141 if i != index {
3142 product *= s;
3143 }
3144 }
3145 let product_current = tensor.shape().num_elements() as i64;
3146
3147 new_shape[index] = product_current / product;
3148
3149 // Check if the reshape is valid
3150 if product_current % product != 0 {
3151 panic!(
3152 "Cannot reshape tensor of shape {:?} to shape {:?}",
3153 tensor.shape(),
3154 new_shape
3155 );
3156 }
3157 };
3158
3159 // Convert each element to usize
3160 let new_shape: [usize; D2] = new_shape.map(|x| x as usize);
3161
3162 Shape::from(new_shape)
3163 }
3164}
3165
3166impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
3167 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3168 self,
3169 tensor: &Tensor<B, D, K>,
3170 ) -> Shape {
3171 // Convert i32 array to i64 array and use existing implementation
3172 let i64_array: [i64; D2] = self.map(|x| x as i64);
3173 ReshapeArgs::into_shape(i64_array, tensor)
3174 }
3175}
3176
/// Trait used for broadcast arguments.
///
/// `D2` is the rank of the resulting shape. `D1` is not used by the implementations in this
/// module (presumably the rank of the source tensor — TODO confirm).
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    ///
    /// `shape` is the current shape of the tensor being broadcast.
    fn into_shape(self, shape: &Shape) -> Shape;
}

// An explicit `Shape` is used as-is; the current shape is ignored.
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
    fn into_shape(self, _shape: &Shape) -> Shape {
        self
    }
}
3188
3189impl<const D1: usize, const D2: usize, E: AsIndex> BroadcastArgs<D1, D2> for [E; D2] {
3190 // Passing -1 as the size for a dimension means not changing the size of that dimension.
3191 fn into_shape(self, shape: &Shape) -> Shape {
3192 if self.len() < shape.num_dims() {
3193 panic!("Broadcast arguments must be greater than the number of dimensions");
3194 }
3195
3196 // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3197 let new_shape: Vec<_> = self
3198 .iter()
3199 .rev()
3200 .map(|x| {
3201 let primitive = x.index();
3202 if primitive < -1 || primitive == 0 {
3203 panic!("Broadcast arguments must be positive or -1");
3204 }
3205 primitive
3206 })
3207 .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3208 .map(|(x, &y)| if x == -1 { y } else { x as usize })
3209 .collect::<Vec<_>>()
3210 .into_iter()
3211 .rev()
3212 .collect();
3213
3214 if new_shape.contains(&0) {
3215 panic!("Cannot substitute -1 for a non-existing dimension");
3216 }
3217
3218 let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3219
3220 Shape::from(new_shape)
3221 }
3222}
3223
3224impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3225where
3226 B: Backend,
3227 K: BasicOps<B>,
3228 K::Elem: Debug + Copy + Serialize,
3229{
3230 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3231 let data = self.to_data();
3232 data.serialize(serializer)
3233 }
3234}
3235
3236impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3237where
3238 B: Backend,
3239 K: BasicOps<B>,
3240 K::Elem: Debug + Copy + Deserialize<'de>,
3241{
3242 fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3243 let tensor = Tensor::from_data(
3244 TensorData::deserialize(deserializer)?,
3245 &<B::Device as Default>::default(),
3246 );
3247 Ok(tensor)
3248 }
3249}
3250
// Unit tests for converting range-like arguments into slices with `Shape::into_slices`
// and resolving them with `Slice::to_range` (negative bounds count from the end).
#[cfg(test)]
mod tests {
    use crate::{Shape, s};

    // Every standard range form applied to the leading dimension (size 8).
    #[test]
    fn slice_range_single_dim_leading() {
        let shape = Shape::new([8, 4]);

        // Half-open range
        let slices = shape.clone().into_slices([0..5]);
        assert_eq!(slices[0].to_range(8), 0..5);
        let slices = shape.clone().into_slices([-3..-1]);
        assert_eq!(slices[0].to_range(8), 5..7);

        // Inclusive range
        let slices = shape.clone().into_slices([0..=4]);
        assert_eq!(slices[0].to_range(8), 0..5);
        let slices = shape.clone().into_slices([-2..=-1]);
        assert_eq!(slices[0].to_range(8), 6..8);

        // Unbounded start
        let slices = shape.clone().into_slices([..3]);
        assert_eq!(slices[0].to_range(8), 0..3);
        let slices = shape.clone().into_slices([..-5]);
        assert_eq!(slices[0].to_range(8), 0..3);

        // Unbounded end
        let slices = shape.clone().into_slices([5..]);
        assert_eq!(slices[0].to_range(8), 5..8);
        let slices = shape.clone().into_slices([-3..]);
        assert_eq!(slices[0].to_range(8), 5..8);

        // Full range
        let slices = shape.into_slices([..]);
        assert_eq!(slices[0].to_range(8), 0..8);
    }

    // Negative bounds are stored as-is in `Slice` and only resolved by `to_range`.
    #[test]
    fn test_negative_slice_indices() {
        use crate::Slice;

        // Test negative indices conversion
        let slice: Slice = (-3..-1).into();
        assert_eq!(slice.start, -3);
        assert_eq!(slice.end, Some(-1));

        // Test to_range conversion with size 8
        let range = slice.to_range(8);
        assert_eq!(range, 5..7);

        // Test with shape slice
        let shape = Shape::new([8, 4]);
        let result = shape.clone().into_slices([-3..-1]);
        assert_eq!(result[0].to_range(8), 5..7);

        // Test more negative index cases
        let slice2: Slice = (-5..).into();
        assert_eq!(slice2.to_range(10), 5..10);

        let slice3: Slice = (..-2).into();
        assert_eq!(slice3.to_range(10), 0..8);

        // Test with s! macro - single dimension returns Slice directly
        let slice4 = s![-3..-1];
        assert_eq!(slice4.start, -3);
        assert_eq!(slice4.end, Some(-1));
    }

    // One range per dimension, all of the same range type within a call.
    #[test]
    fn slice_range_multi_dim() {
        let shape = Shape::new([8, 4]);

        // Multiple ways to provide ranges
        let slices = shape.clone().into_slices([0..5, 0..4]);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0.., 0..]);
        assert_eq!(slices[0].to_range(8), 0..8);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0..=7, 0..=3]);
        assert_eq!(slices[0].to_range(8), 0..8);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0..5, 0..3]);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..3);

        let slices = shape.into_slices([0.., 0..]);
        assert_eq!(slices[0].to_range(8), 0..8);
        assert_eq!(slices[1].to_range(4), 0..4);
    }

    // A single integer index selects a one-element range in that dimension.
    #[test]
    fn slice_range_multi_dim_index() {
        let shape = Shape::new([8, 4]);

        // Indices (single integer) should also convert to correct range
        let slices = shape.clone().into_slices([0, 2]);
        assert_eq!(slices[0].to_range(8), 0..1);
        assert_eq!(slices[1].to_range(4), 2..3);

        let slices = shape.into_slices([-1, -1]);
        assert_eq!(slices[0].to_range(8), 7..8);
        assert_eq!(slices[1].to_range(4), 3..4);
    }

    // `s![]` mixes ranges, full ranges and single indices in one argument.
    #[test]
    fn slice_range_multi_dim_heterogeneous() {
        // Slice macro `s![]` can be used to provide different range types
        let shape = Shape::new([8, 4, 2]);
        let slice = s![0..5, .., -1];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);
        assert_eq!(slices[2].to_range(2), 1..2);

        let shape = Shape::new([8, 4, 2, 3]);
        let slice = s![..=4, 0..=3, .., -2..];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);
        assert_eq!(slices[2].to_range(2), 0..2);
        assert_eq!(slices[3].to_range(3), 1..3);

        let shape = Shape::new([3, 4]);
        let slice = s![1..-1, ..];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(3), 1..2);
        assert_eq!(slices[1].to_range(4), 0..4);
    }
}