burn_tensor/tensor/api/base.rs
1#![allow(clippy::single_range_in_vec_init)]
2use crate::backend::ExecutionError;
3use crate::check::unwrap_shape_reshape;
4
5use burn_backend::Scalar;
6pub use burn_backend::tensor::BasicOps;
7
8use alloc::vec::Vec;
9
10use alloc::format;
11use alloc::string::String;
12use alloc::vec;
13
14use burn_std::{SliceOps, stub::RwLock};
15use core::iter::repeat;
16use core::{fmt::Debug, ops::Range};
17use serde::{Deserialize, Deserializer};
18
19use crate::{AsIndex, Slice, SliceArg, wrap_index};
20use crate::{
21 Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, TensorMetadata,
22 backend::Backend, check,
23};
24use crate::{DType, Element};
25use crate::{IndexingUpdateOp, TensorCreationOptions};
26use crate::{cast::ToElement, check::TensorCheck};
27use serde::{Serialize, Serializer};
28
29/// A tensor with a given backend, shape and data type.
30///
31/// # Indexing
32/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
33/// or [`select`](Tensor::select) for numeric types.
34///
35/// ## Example
36///
37/// ```rust
38/// use burn_tensor::backend::Backend;
39/// use burn_tensor::Tensor;
40/// use burn_tensor::Int;
41///
42/// fn example<B: Backend>() {
43/// let device = Default::default();
44///
45/// let tensor = Tensor::<B, 2>::from_data(
46/// [
47/// [3.0, 4.9, 2.0],
48/// [2.0, 1.9, 3.0],
49/// [6.0, 1.5, 7.0],
50/// [3.0, 4.9, 9.0],
51/// ],
52/// &device,
53/// );
54///
55/// // Slice the tensor to get the second and third rows:
56/// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
57/// // The resulting tensor will have dimensions [2, 3].
58/// let slice = tensor.clone().slice([1..3]);
59/// println!("{slice}");
60///
61/// // Slice the tensor to get the first two rows and the first 2 columns:
62/// // [[3.0, 4.9], [2.0, 1.9]]
63/// // The resulting tensor will have dimensions [2, 2].
64/// let slice = tensor.clone().slice([0..2, 0..2]);
65/// println!("{slice}");
66///
67/// // Index the tensor along the dimension 1 to get the elements 0 and 2:
68/// // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
69/// // The resulting tensor will have dimensions [4, 2]
70/// let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
71/// let indexed = tensor.select(1, indices);
72/// println!("{indexed}");
73/// }
74/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    /// Backend-specific primitive holding the tensor's data; the rank `D` and
    /// kind `K` are tracked at the type level only.
    pub(crate) primitive: K::Primitive,
}
83
84impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
85where
86 B: Backend,
87 K: BasicOps<B>,
88 T: Into<TensorData>,
89{
90 fn from(value: T) -> Self {
91 Tensor::from_data(value.into(), &Default::default())
92 }
93}
94
95impl<B, const D: usize, K> Tensor<B, D, K>
96where
97 B: Backend,
98 K: BasicOps<B>,
99 K::Elem: Element,
100{
101 /// Executes an operation on the tensor and modifies its value.
102 ///
103 /// # Notes
104 ///
105 /// This won't necessarily reuse the same tensor data/buffer, but it should if there is
106 /// no other reference pointing to the same tensor.
107 ///
108 /// Wrapping operations with inplace is not an optimization, it's mainly there if you
109 /// want to mutate a tensor by using owned operations. A plausible usage would be to
110 /// update the weights of a mutable model reference.
111 pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {
112 let mut tensor_owned = Tensor::empty([0; D], &self.device());
113 core::mem::swap(&mut tensor_owned, self);
114
115 let mut tensor_new = func(tensor_owned);
116 core::mem::swap(&mut tensor_new, self);
117 }
118
    /// Converts the tensor into a primitive tensor.
    ///
    /// Consumes the wrapper and exposes the backend-specific representation.
    pub fn into_primitive(self) -> K::Primitive {
        self.primitive
    }
123
    /// Converts from a primitive tensor into a tensor.
    ///
    /// The primitive is wrapped as-is; no validation is performed here, so the
    /// caller is responsible for providing a primitive whose rank matches `D`.
    pub fn from_primitive(tensor: K::Primitive) -> Self {
        Self::new(tensor)
    }
128
    /// Returns the number of dimensions of the tensor.
    ///
    /// The value is queried from the underlying primitive at runtime; it is
    /// expected to match the compile-time rank `D`.
    pub fn rank(&self) -> usize {
        self.primitive.rank()
    }
133
    /// Returns the tensor primitive data type.
    ///
    /// Delegates to the underlying primitive's metadata.
    ///
    /// # Note
    /// Some element types are encoded in different primitive types depending on the backend
    /// (e.g., bool could be encoded as `u8` or `u32`), so the returned dtype reflects the
    /// backend's storage type.
    pub fn dtype(&self) -> DType {
        self.primitive.dtype()
    }
142
143 /// Create an empty tensor of the given shape.
144 ///
145 /// # Arguments
146 ///
147 /// - `shape`: The shape of the tensor.
148 /// - `device`: The device where the tensor will be created.
149 ///
150 /// # Example
151 /// ```rust
152 /// use burn_tensor::backend::Backend;
153 /// use burn_tensor::Tensor;
154 ///
155 /// fn example<B: Backend>() {
156 /// let device = Default::default();
157 /// // Create an empty tensor with dimensions [2, 3, 4].
158 /// let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
159 /// }
160 /// ```
161 pub fn empty<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
162 let opt = options.into();
163 let shape = shape.into();
164 let dtype = opt.resolve_dtype::<K>();
165 check!(TensorCheck::creation_ops::<D>("Empty", &shape));
166 Self::new(K::empty(shape, &opt.device, dtype))
167 }
168
169 /// Create a tensor of the given shape where each element is zero.
170 ///
171 /// # Example
172 ///
173 /// ```rust
174 /// use burn_tensor::backend::Backend;
175 /// use burn_tensor::{Tensor, Shape};
176 ///
177 /// fn example<B: Backend>() {
178 /// let device = B::Device::default();
179 /// let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
180 /// println!("{tensor}");
181 /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
182 /// }
183 /// ```
184 pub fn zeros<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
185 let opt = options.into();
186 let shape = shape.into();
187 let dtype = opt.resolve_dtype::<K>();
188 check!(TensorCheck::creation_ops::<D>("Zeros", &shape));
189 Self::new(K::zeros(shape, &opt.device, dtype))
190 }
191
192 /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.
193 ///
194 /// # Example
195 ///
196 /// ```rust
197 /// use burn_tensor::backend::Backend;
198 /// use burn_tensor::{Tensor, Shape};
199 ///
200 /// fn example<B: Backend>() {
201 /// let device = B::Device::default();
202 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
203 /// let tensor = tensor.zeros_like();
204 /// println!("{tensor}");
205 /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
206 /// }
207 /// ```
208 pub fn zeros_like(&self) -> Self {
209 Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))
210 }
211
212 /// Create a tensor of the given shape where each element is one.
213 ///
214 /// # Example
215 ///
216 /// ```rust
217 /// use burn_tensor::backend::Backend;
218 /// use burn_tensor::{Tensor, Shape};
219 ///
220 /// fn example<B: Backend>() {
221 /// let device = B::Device::default();
222 /// let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
223 /// println!("{tensor}");
224 /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
225 /// }
226 /// ```
227 pub fn ones<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {
228 let opt = options.into();
229 let shape = shape.into();
230 let dtype = opt.resolve_dtype::<K>();
231 check!(TensorCheck::creation_ops::<D>("Ones", &shape));
232 Self::new(K::ones(shape, &opt.device, dtype))
233 }
234
235 /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.
236 ///
237 /// # Example
238 ///
239 /// ```rust
240 /// use burn_tensor::backend::Backend;
241 /// use burn_tensor::{Tensor, Shape};
242 ///
243 /// fn example<B: Backend>() {
244 /// let device = B::Device::default();
245 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
246 /// let tensor = tensor.ones_like();
247 /// println!("{tensor}");
248 /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
249 /// }
250 /// ```
251 pub fn ones_like(&self) -> Self {
252 Self::new(K::ones(self.shape(), &self.device(), self.dtype()))
253 }
254
255 /// Create a tensor of the given shape where each element is equal to the provided value.
256 ///
257 /// # Example
258 ///
259 /// ```rust
260 /// use burn_tensor::backend::Backend;
261 /// use burn_tensor::{Tensor, Shape};
262 ///
263 /// fn example<B: Backend>() {
264 /// let device = B::Device::default();
265 /// let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
266 /// println!("{tensor}");
267 /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
268 /// }
269 /// ```
270 pub fn full<S: Into<Shape>, E: ElementConversion>(
271 shape: S,
272 fill_value: E,
273 options: impl Into<TensorCreationOptions<B>>,
274 ) -> Self {
275 let opt = options.into();
276 let shape = shape.into();
277 let dtype = opt.resolve_dtype::<K>();
278 check!(TensorCheck::creation_ops::<D>("Full", &shape));
279 Self::new(K::full(
280 shape,
281 Scalar::new(fill_value, &dtype),
282 &opt.device,
283 dtype,
284 ))
285 }
286
287 /// Returns a new tensor with the same shape, dtype, and device as the current tensor,
288 /// filled with the provided value.
289 ///
290 /// # Example
291 ///
292 /// ```rust
293 /// use burn_tensor::backend::Backend;
294 /// use burn_tensor::{Tensor, Shape};
295 ///
296 /// fn example<B: Backend>() {
297 /// let device = B::Device::default();
298 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
299 /// let tensor = tensor.full_like(5.0);
300 /// println!("{tensor}");
301 /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
302 /// }
303 /// ```
304 pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
305 let dtype = self.dtype();
306 Self::new(K::full(
307 self.shape(),
308 Scalar::new(fill_value, &dtype),
309 &self.device(),
310 dtype,
311 ))
312 }
313
314 /// Returns the dimensions of the current tensor.
315 ///
316 /// # Example
317 /// ```rust
318 /// use burn_tensor::backend::Backend;
319 /// use burn_tensor::Tensor;
320 ///
321 /// fn example<B: Backend>() {
322 /// let device = Default::default();
323 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
324 /// let dims = tensor.dims(); // [2, 3, 4]
325 /// println!("{dims:?}");
326 /// }
327 /// ```
328 pub fn dims(&self) -> [usize; D] {
329 Self::shape(self).dims()
330 }
331
    /// Returns the shape of the current tensor.
    ///
    /// Delegates to the underlying primitive's metadata.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
    ///     // Shape { dims: [2, 3, 4] }
    ///     let shape = tensor.shape();
    /// }
    /// ```
    pub fn shape(&self) -> Shape {
        self.primitive.shape()
    }
349
350 /// Reshape the tensor to have the given shape.
351 ///
352 /// The tensor has the same data and number of elements as the input.
353 ///
354 /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
355 /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
356 ///
357 /// A `0` in the shape instructs to keep the current dimension from the original tensor,
358 /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
359 /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
360 /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
361 /// with [1, 3, 4] dimensions to [1, 12].
362 ///
363 /// # Arguments
364 /// - `shape`: The new shape of the tensor.
365 ///
366 /// # Panics
367 /// - If the tensor contains more than one `-1` in the shape.
368 /// - If the tensor contains values that are not positive (other than -1).
369 /// - If the shape does not match the number of elements of the original shape.
370 ///
371 /// # Example
372 ///
373 /// ```rust
374 /// use burn_tensor::backend::Backend;
375 /// use burn_tensor::Tensor;
376 ///
377 /// fn example<B: Backend>() {
378 /// let device = Default::default();
379 /// // Create a tensor with dimensions [2, 3, 4]
380 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
381 /// // Reshape it to [2, 12], where 12 is inferred from the number of elements.
382 /// let reshaped = tensor.reshape([2, -1]);
383 /// println!("{reshaped}");
384 /// }
385 /// ```
    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
        // Resolve the reshape args (which may contain `-1`/`0` placeholders)
        // into a concrete shape using the current dimensions.
        let shape = shape.into_shape::<D2>(self.shape());
        Tensor::new(K::reshape(self.primitive, shape))
    }
391
392 /// Transpose the tensor.
393 ///
394 /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
395 /// applied on the last two dimensions. For example, the transpose of a tensor with shape
396 /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
397 ///
398 /// See also [`permute`](Tensor::permute).
399 ///
400 /// # Arguments
401 ///
402 /// * `tensor` - The tensor to transpose.
403 ///
404 /// # Returns
405 ///
406 /// The transposed tensor.
407 ///
408 /// # Example
409 ///
410 /// ```rust
411 /// use burn_tensor::backend::Backend;
412 /// use burn_tensor::Tensor;
413 ///
414 /// fn example<B: Backend>() {
415 /// let device = Default::default();
416 /// // Create a 2D tensor of shape [2, 3]
417 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
418 ///
419 /// // Transpose the tensor:
420 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
421 /// // The resulting tensor will have dimensions [3, 2].
422 /// let transposed = tensor.transpose();
423 /// println!("{transposed}");
424 /// }
425 /// ```
    pub fn transpose(self) -> Tensor<B, D, K> {
        // Delegates to the backend; per the contract above, the last two
        // dimensions are swapped.
        Tensor::new(K::transpose(self.primitive))
    }
429
    /// Alias for [`transpose`](Tensor::transpose).
    #[inline(always)]
    pub fn t(self) -> Tensor<B, D, K> {
        self.transpose()
    }
435
436 /// Swaps two dimensions of a tensor.
437 ///
438 /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.
439 ///
440 /// # Arguments
441 ///
442 /// * `tensor` - The tensor to swap the dimensions of.
443 /// * `dim1` - The first dimension to swap, supports negative indexing.
444 /// * `dim2` - The second dimension to swap, supports negative indexing.
445 ///
446 /// # Returns
447 ///
448 /// The tensor with the dimensions swapped.
449 ///
450 /// # Panics
451 ///
452 /// When dimensions are out of bounds.
453 ///
454 /// # Example
455 ///
456 /// ```rust
457 /// use burn_tensor::backend::Backend;
458 /// use burn_tensor::Tensor;
459 ///
460 /// fn example<B: Backend>() {
461 /// let device = Default::default();
462 /// // Create a 2D tensor of shape [2, 3]
463 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
464 ///
465 /// // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):
466 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
467 /// // The resulting tensor will have dimensions [3, 2].
468 /// let swapped = tensor.swap_dims(0, -1);
469 /// println!("{swapped}");
470 /// }
471 /// ```
472 pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>
473 where
474 Dim1: AsIndex,
475 Dim2: AsIndex,
476 {
477 let dim1 = dim1.expect_dim_index(D);
478 let dim2 = dim2.expect_dim_index(D);
479 check!(TensorCheck::swap_dims::<D>(dim1, dim2));
480 if dim1 == dim2 {
481 self
482 } else {
483 Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
484 }
485 }
486
487 /// Permute the dimensions of the tensor.
488 ///
489 /// This is a no-op when the resolved `axes` match the current order.
490 ///
491 /// # Arguments
492 ///
493 /// * `axes` - The new order of the dimensions. The length of the axes
494 /// must be equal to the number of dimensions of the tensor.
495 /// The values must be unique and in the range of the number of dimensions.
496 /// The values can be negative, in which case they are used as an offset from the end.
497 ///
498 /// # Returns
499 ///
500 /// The tensor with the dimensions permuted.
501 ///
502 /// # Example
503 ///
504 /// ```rust
505 /// use burn_tensor::backend::Backend;
506 /// use burn_tensor::Tensor;
507 ///
508 /// fn example<B: Backend>() {
509 /// let device = Default::default();
510 /// // Create a 2D tensor of shape [3, 2]
511 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
512 ///
513 /// // Permute the dimensions 1 and 0:
514 /// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
515 /// // The resulting tensor will have dimensions [3, 2].
516 /// let permuted = tensor.permute([1, 0]);
517 /// println!("{permuted}");
518 /// }
519 /// ```
520 pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>
521 where
522 Dim: AsIndex,
523 {
524 let mut no_op = true;
525 let mut fixed_axes = [0; D];
526 for (i, axis) in axes.into_iter().enumerate() {
527 let dim = axis.expect_dim_index(D);
528 no_op &= dim == i;
529 fixed_axes[i] = dim;
530 }
531
532 if no_op {
533 self
534 } else {
535 check!(TensorCheck::permute(fixed_axes));
536 Tensor::new(K::permute(self.primitive, &fixed_axes))
537 }
538 }
539
540 /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
541 ///
542 /// Other dimensions of input that are not explicitly moved remain in their original order and appear
543 /// at the positions not specified in destination.
544 ///
545 /// # Arguments
546 ///
547 /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
548 /// The values can be negative, in which case they are used as an offset from the end.
549 ///
550 /// * `dst` - Destination positions for each of the original dims. These must also be unique.
551 ///
552 /// # Panics
553 ///
554 /// - If the source and destination dimensions are not of the same length.
555 /// - If the source and destination vectors contain duplicate values.
556 /// - If the source and destination vectors contain values that are out of bounds.
557 ///
558 /// # Returns
559 ///
560 /// The tensor with the dimensions moved.
561 ///
562 /// # Example
563 ///
564 /// ```rust
565 /// use burn_tensor::backend::Backend;
566 /// use burn_tensor::Tensor;
567 ///
568 /// fn example<B: Backend>() {
569 /// let device = Default::default();
570 /// // Create a 3D tensor of shape [3, 2, 1]
571 /// let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
572 ///
573 /// // Move the dimensions 0 and 1:
574 /// // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
575 /// // The resulting tensor will have dimensions [2, 3, 1].
576 /// let moved = tensor.movedim(1, 0);
577 /// println!("{moved}");
578 /// }
579 /// ```
580 ///
581 /// # Note
582 ///
583 /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
584 /// for it
    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
        // Normalize the arguments into vectors of absolute dimension indices.
        let source_dims = src.into_dim_vec::<D>();
        let destination_dims = dst.into_dim_vec::<D>();

        check!(TensorCheck::movedim_args_length(
            &source_dims,
            &destination_dims
        ));

        // `m[d] = s` records that destination position `d` receives source dim `s`;
        // -1 marks destination positions that were not explicitly assigned.
        let mut m = [-1; D];
        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
            m[d] = s as isize;
        }
        // Build the full permutation: explicitly moved dims take their assigned
        // slots, the remaining dims fill the free slots in their original order.
        let mut axes: [isize; D] = [0; D];
        let mut source_i = 0;
        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
            *item = if m[dest_i] != -1 {
                m[dest_i]
            } else {
                // Skip dims that were explicitly moved; they are already placed.
                while source_dims.contains(&source_i) {
                    source_i += 1;
                }
                let result = source_i as isize;
                source_i += 1;
                result
            };
        }

        self.permute(axes)
    }
615
616 /// Reverse the order of elements in the tensor along the given dimensions.
617 ///
618 /// # Arguments
619 ///
620 /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
621 /// The values can be negative, in which case they are used as an offset from the end.
622 ///
623 /// # Returns
624 ///
625 /// The tensor with the axes flipped.
626 ///
627 /// # Example
628 ///
629 /// ```rust
630 /// use burn_tensor::backend::Backend;
631 /// use burn_tensor::Tensor;
632 ///
633 /// fn example<B: Backend>() {
634 /// let device = Default::default();
635 /// // Create a 2D tensor with dimensions [4, 3]
636 /// let tensor = Tensor::<B, 2>::from_data(
637 /// [
638 /// [3.0, 4.9, 2.0],
639 /// [2.0, 1.9, 3.0],
640 /// [4.0, 5.9, 8.0],
641 /// [1.4, 5.8, 6.0],
642 /// ],
643 /// &device,
644 /// );
645 ///
646 /// // Flip the elements in dimensions 0 and 1:
647 /// // [[6.0, 5.8, 1.4],
648 /// // [8.0, 5.9, 4.0],
649 /// // [3.0, 1.9, 2.0],
650 /// // [2.0, 4.9, 3.0]]
651 /// // The resulting tensor will have dimensions [4, 3].
652 /// let flipped = tensor.flip([0, 1]);
653 /// println!("{flipped}");
654 /// }
655 /// ```
656 pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
657 // Convert the axes to usize and handle negative values without using vector
658 let mut transformed_axes: [usize; N] = [0; N];
659 for (i, &x) in axes.iter().enumerate() {
660 transformed_axes[i] = if x < 0 {
661 (D as isize + x) as usize
662 } else {
663 x as usize
664 };
665 }
666
667 // Check if the axes are valid
668 check!(TensorCheck::flip(D, &transformed_axes));
669
670 Tensor::new(K::flip(self.primitive, &transformed_axes))
671 }
672
673 /// Flatten the tensor along a given range of dimensions.
674 ///
675 /// This function collapses the specified range of dimensions into a single dimension,
676 /// effectively flattening the tensor in that range.
677 ///
678 /// # Arguments
679 ///
680 /// - `start_dim`: The starting dimension of the range to be flattened,
681 /// supports negative indexing.
682 /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
683 /// supports negative indexing.
684 ///
685 /// # Type Parameters
686 ///
687 /// - `D2`: The resulting number of dimensions in the flattened tensor.
688 ///
689 /// # Returns
690 ///
691 /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
692 ///
693 /// # Example
694 ///
695 /// ```rust
696 ///
697 /// use burn_tensor::backend::Backend;
698 /// use burn_tensor::{Tensor, Shape};
699 ///
700 /// fn example<B: Backend>() {
701 /// let device = Default::default();
702 /// // Create a 3D tensor with dimensions [2, 3, 4]
703 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
704 ///
705 /// // Flatten the tensor from dimensions 1 to 2 (inclusive).
706 /// // The resulting tensor will have dimensions [2, 12]
707 /// let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
708 /// println!("{flattened}");
709 /// }
710 /// ```
711 pub fn flatten<const D2: usize>(
712 self,
713 start_dim: impl AsIndex,
714 end_dim: impl AsIndex,
715 ) -> Tensor<B, D2, K> {
716 let start_dim = start_dim.expect_dim_index(D);
717 let end_dim = end_dim.expect_dim_index(D);
718 check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
719 let new_shape = self.shape().flatten_dims(start_dim, end_dim);
720
721 Tensor::new(K::reshape(self.primitive, new_shape))
722 }
723
724 /// Squeeze the tensor along all dimensions, removing dimensions
725 /// of size one, and effectively reducing the rank of the tensor.
726 ///
727 /// # Type Parameters
728 ///
729 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
730 ///
731 /// # Returns
732 ///
733 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
734 ///
735 /// # Example
736 ///
737 /// ```rust
738 ///
739 /// use burn_tensor::backend::Backend;
740 /// use burn_tensor::{Tensor, Shape};
741 ///
742 /// fn example<B: Backend>() {
743 /// let device = Default::default();
744 /// // Create a 4D tensor with dimensions [1, 3, 1, 3]
745 /// let tensor = Tensor::<B, 4>::from_data(
746 /// [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],
747 /// &device,
748 /// );
749 ///
750 /// // Squeeze the tensor dimensions.
751 /// // The resulting tensor will have dimensions [3, 3].
752 /// let squeezed = tensor.squeeze::<2>();
753 /// println!("{squeezed}");
754 /// }
755 /// ```
756 pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {
757 let new_dims = self
758 .shape()
759 .iter()
760 .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })
761 .collect::<Vec<_>>();
762 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
763
764 Tensor::new(K::reshape(self.primitive, new_dims.into()))
765 }
766
767 /// Squeeze the tensor along the given dimension, removing the specified dimension
768 /// of size one, and effectively reducing the rank of the tensor by one.
769 ///
770 /// # Arguments
771 ///
772 /// - `dim`: The dimension to be squeezed.
773 ///
774 /// # Type Parameters
775 ///
776 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
777 ///
778 /// # Panics
779 ///
780 /// If the size in the squeezed dimension is not 1.
781 ///
782 /// # Returns
783 ///
784 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
785 ///
786 /// # Example
787 ///
788 /// ```rust
789 ///
790 /// use burn_tensor::backend::Backend;
791 /// use burn_tensor::{Tensor, Shape};
792 ///
793 /// fn example<B: Backend>() {
794 /// let device = Default::default();
795 /// // Create a 3D tensor with dimensions [3, 1, 3]
796 /// let tensor = Tensor::<B, 3>::from_data(
797 /// [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
798 /// &device,
799 /// );
800 ///
801 /// // Squeeze the dimension 1.
802 /// // The resulting tensor will have dimensions [3, 3].
803 /// let squeezed = tensor.squeeze_dim::<2>(1);
804 /// println!("{squeezed}");
805 /// }
806 /// ```
    pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
        // Validates that `dim` is a squeezable dimension for this shape.
        check!(TensorCheck::squeeze::<D2>(dim, &self.shape()));

        let current_dims = self.shape();
        let mut new_dims: [usize; D2] = [0; D2];

        // Copy every dimension except `dim` into the reduced shape.
        new_dims[..dim].copy_from_slice(&current_dims[..dim]);
        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);

        // NOTE(review): `new_dims` is a fixed-size array, so `new_dims.len()` is
        // always `D2` and this check looks redundant — confirm before removing.
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }
819
820 /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
821 /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
822 /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
823 /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
824 /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
825 /// from the back.
826 ///
827 /// # Arguments
828 ///
829 /// - `dims`: The dimension(s) to be squeezed.
830 ///
831 /// # Type Parameters
832 ///
833 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
834 ///
835 /// # Returns
836 ///
837 /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
838 ///
839 /// # Example
840 ///
841 /// ```rust
842 ///
843 /// use burn_tensor::backend::Backend;
844 /// use burn_tensor::{Tensor, Shape};
845 ///
846 /// fn example<B: Backend>() {
847 /// let device = Default::default();
848 /// // Create a 4D tensor with dimensions [2, 1, 4, 1]
849 /// let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
850 ///
851 /// // Squeeze the dimensions 1 and 3.
852 /// // The resulting tensor will have dimensions [2, 4].
853 /// let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
854 /// println!("{squeezed}");
855 /// }
856 /// ```
    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
        let current_dims = self.shape();
        let mut dim_indices: Vec<usize>;

        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
        if dims.is_empty() {
            dim_indices = current_dims
                .iter()
                .enumerate()
                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
                .collect();
        } else {
            // If negative dims, count from the back
            dim_indices = dims
                .iter()
                .map(|&d| {
                    if d < 0 {
                        (current_dims.len() as isize + d) as usize
                    } else {
                        d as usize
                    }
                })
                .collect();
        }

        // Sort indices and remove duplicates
        dim_indices.sort_unstable();
        dim_indices.dedup();

        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
        check!(TensorCheck::squeeze_dims_input::<D2>(
            &dim_indices,
            &current_dims
        ));

        // Calculate new dimensions by keeping everything not marked for squeezing
        let mut new_dims = Vec::new();
        for (index, &dim_size) in current_dims.iter().enumerate() {
            // Exclude the dimension if it's explicitly marked for squeezing
            if dim_indices.contains(&index) {
                // The check rejects marked dimensions whose size is not 1.
                check!(TensorCheck::squeeze::<D2>(index, &current_dims));
                continue;
            }
            new_dims.push(dim_size);
        }

        // Check that after squeezing, we still respect the D2 size
        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));

        Tensor::new(K::reshape(self.primitive, new_dims.into()))
    }
908
909 /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
910 ///
911 /// # Type Parameters
912 ///
913 /// - `D2`: The resulting number of dimensions in the unsqueezed tensor.
914 ///
915 /// # Panics
916 ///
917 /// If the output size `D2` is smaller than the current number of dimensions.
918 ///
919 /// # Returns
920 ///
921 /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
922 ///
923 /// # Example
924 ///
925 /// ```rust
926 /// use burn_tensor::backend::Backend;
927 /// use burn_tensor::{Tensor, Shape};
928 ///
929 /// fn example<B: Backend>() {
930 /// let device = Default::default();
931 /// // Create a 2D tensor with dimensions [3, 3]
932 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
933 /// // Unsqueeze the tensor up to 4 dimensions.
934 /// // The resulting tensor will have dimensions [1, 1, 3, 3].
935 /// let unsqueezed = tensor.unsqueeze::<4>();
936 /// println!("{unsqueezed}");
937 /// }
938 /// ```
939 pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
940 check!(TensorCheck::unsqueeze::<D, D2>());
941
942 let mut dims = [1; D2];
943 let num_ones = D2 - D;
944 let shape = self.shape();
945
946 dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);
947
948 let shape = Shape::new(dims);
949 self.reshape(shape)
950 }
951
952 /// Creates a new tensor with a dimension of size one inserted at the specified position.
953 ///
954 /// # Example
955 ///
956 /// ```rust
957 /// use burn_tensor::backend::Backend;
958 /// use burn_tensor::{Tensor, Shape};
959 ///
960 /// fn example<B: Backend>() {
961 /// let device = Default::default();
962 /// // Create a 2D tensor with dimensions [3, 3]
963 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
964 /// // Unsqueeze the dimension 1.
965 /// // The resulting tensor will have dimensions [3, 1, 3].
966 /// let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
967 /// println!("{unsqueezed}");
968 /// }
969 /// ```
970 pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
971 check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));
972
973 let mut dims = [1; D2];
974 let shape = self.shape();
975
976 dims[0..dim].copy_from_slice(&shape[0..dim]);
977
978 if dim < D {
979 dims[dim] = 1;
980 dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
981 } else {
982 dims[dim] = 1;
983 }
984
985 let shape = Shape::new(dims);
986 self.reshape(shape)
987 }
988
989 /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
990 /// The indices can be negative, in which case they are counted from the last to the first dimension.
991 /// the axes can contain duplicates, in which case the number of dimensions inserted at the index
992 /// is the number of duplicates.
993 /// # Example
994 ///
995 /// ```rust
996 /// use burn_tensor::backend::Backend;
997 /// use burn_tensor::{Tensor, Shape};
998 ///
999 /// fn example<B: Backend>() {
1000 /// let device = Default::default();
1001 /// // Create a 3D tensor with dimensions [3, 4, 5]
1002 /// let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
1003 /// // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
1004 /// // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
1005 /// let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
1006 /// println!("{unsqueezed}");
1007 /// }
1008 /// ```
    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[impl AsIndex]) -> Tensor<B, D2, K> {
        // All entries default to 1; the original dims are copied in below.
        let mut new_dims = [1; D2];
        let old_dims = self.shape();

        // Part 1: convert negative indices to positive ones.
        // `neg_offset` decreases once per negative axis seen, so that several
        // negative entries (e.g. repeated `-1`) resolve to distinct trailing
        // positions counted from the back. NOTE(review): this relies on `map`
        // visiting axes in order and mutating `neg_offset` as a side effect.
        let mut neg_offset = D2;
        let mut dim_indices = axes
            .iter()
            .map(|d| {
                let d = d.as_index();
                // check if the dimension is in the acceptable range
                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(d));
                (if d < 0 {
                    neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
                    d + neg_offset as isize + 1
                } else {
                    d
                }) as usize
            })
            .collect::<Vec<usize>>();

        // Process the insertion points in ascending order.
        dim_indices.sort_unstable();

        // Per the documented semantics, duplicate axes mean "insert N dims at that index".
        // After sorting, N insertions at position `i` logically occupy positions
        // `i, i+1, ..., i+N-1` in the output, so bump each duplicate to the next slot.
        // Example: sorted `[0, 0, 3]` becomes `[0, 1, 3]`, matching the intent of
        // "two 1s starting at index 0, plus one 1 at index 3".
        for i in 1..dim_indices.len() {
            if dim_indices[i] <= dim_indices[i - 1] {
                dim_indices[i] = dim_indices[i - 1] + 1;
            }
        }

        // Re-validate after normalization: bumping duplicates forward can push the
        // last index past `D2 - 1` (e.g. `[2, 2]` targeting rank 3 normalizes to
        // `[2, 3]`). The per-axis check above only runs on pre-normalization values,
        // so we re-check here to surface a clear `TensorCheck` error instead of
        // letting the copy loop panic on an out-of-bounds `old_dims` read.
        for &dim_index in &dim_indices {
            check!(TensorCheck::unsqueeze_dims::<{ D2 }>(dim_index as isize));
        }

        // Loop over the entries/indices of the `new_dims` array, merging the
        // sorted insertion points with the original dims using two cursors.
        // When the current entry should be 1 from the unsqueeze operation, simply increment
        // the index for `dim_indices` to account for "adding" its entry to `new_dims`.
        // Otherwise, the dim from the current entry of `old_dims` should be copied to `new_dims`.
        let mut dim_indices_curr_idx = 0;
        let mut old_dims_curr_idx = 0;
        for new_dims_curr_idx in 0..D2 {
            // If all indices in `dim_indices` have been processed, then
            // simply copy all the remaining dims from `old_dims` to `new_dims`
            if dim_indices_curr_idx == dim_indices.len() {
                new_dims[new_dims_curr_idx..].copy_from_slice(&old_dims[old_dims_curr_idx..]);
                break;
            }

            if new_dims_curr_idx == dim_indices[dim_indices_curr_idx] {
                dim_indices_curr_idx += 1;
            } else {
                new_dims[new_dims_curr_idx] = old_dims[old_dims_curr_idx];
                old_dims_curr_idx += 1;
            }
        }

        // Finally, build the target shape and reshape (unsqueeze is metadata-only).
        let shape = Shape::new(new_dims);
        self.reshape(shape)
    }
1080
1081 /// Roll operation along a specific dimension; wrapping around the elements.
1082 ///
1083 /// ## Parameters
1084 ///
1085 /// - `shift`: The roll extent; supports negative values and wraps around.
1086 /// - `dim`: The dimension to roll; supports negative indexing.
1087 ///
1088 /// ## Returns
1089 ///
1090 /// A new tensor with the specified dimension rolled by the given shift amount.
1091 pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self
1092 where
1093 Shift: AsIndex,
1094 Dim: AsIndex,
1095 {
1096 let dim = dim.expect_dim_index(D);
1097 let size = self.shape()[dim];
1098 if size == 0 {
1099 // If the dimension is empty, return the tensor as is.
1100 return self;
1101 }
1102
1103 let shift = wrap_index(shift, size);
1104 if shift == 0 {
1105 // If the shift is zero, return the tensor as is.
1106 return self;
1107 }
1108
1109 self.unchecked_roll_dim(shift, dim)
1110 }
1111
1112 /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.
1113 ///
1114 /// ## Parameters
1115 ///
1116 /// - `shift`: The number of positions to shift; must be (0 < shift < size).
1117 /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.
1118 ///
1119 /// ## Returns
1120 ///
1121 /// A new tensor with the specified dimension rolled by the given shift amount.
1122 #[inline(always)]
1123 fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {
1124 #[cfg(debug_assertions)]
1125 {
1126 let size = self.shape()[dim];
1127 assert!(
1128 0 < shift && shift < size,
1129 "Expected: 0 < shift < size: found shift={shift}, size={size}",
1130 );
1131 assert!(
1132 dim < self.shape().num_dims(),
1133 "Expected: dim < num_dims: found dim={dim}, num_dims={size}",
1134 );
1135 }
1136
1137 Tensor::cat(
1138 vec![
1139 self.clone().slice_dim(dim, shift..),
1140 self.slice_dim(dim, ..shift),
1141 ],
1142 dim,
1143 )
1144 }
1145
1146 /// Roll operation.
1147 ///
1148 /// Note: unlike ``pytorch``, `dims` and `shifts` must have the same length.
1149 ///
1150 /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.
1151 ///
1152 /// ## Parameters
1153 ///
1154 /// - `shifts`: A slice of shifts corresponding to each dimension;
1155 /// supports negative values and wraps around.
1156 /// - `dims`: A slice of dimensions to roll; supports negative indexing.
1157 ///
1158 /// ## Returns
1159 ///
1160 /// A new tensor with the specified dimensions rolled by the given shifts.
1161 pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self
1162 where
1163 Shift: AsIndex,
1164 Dim: AsIndex,
1165 {
1166 assert_eq!(
1167 dims.len(),
1168 shifts.len(),
1169 "Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}",
1170 );
1171
1172 // This is a fair amount of complexity, which could be replaced
1173 // by a simple canonicalization of `dims` and wrapping of `shifts`.
1174 // The work is done here to ensure that any roll operation
1175 // which could be a no-op is a no-op; simplifying the accounting
1176 // needed by backend-specific implementations of the inner roll op.
1177
1178 let item_count = dims.len();
1179
1180 let shape = self.shape();
1181
1182 // Accumulate the effective shifts for each dimension.
1183 let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];
1184 for i in 0..item_count {
1185 let dim = dims[i].expect_dim_index(D);
1186 accumulated_shifts[dim] += shifts[i].as_index();
1187 }
1188
1189 // Do this after we've checked the validity of `dims` and `shifts`.
1190 if self.shape().num_elements() == 0 {
1191 // If the tensor is empty, return it as is.
1192 return self;
1193 }
1194
1195 // Wrap the accumulated shifts, and filter out empty dimensions.
1196 let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);
1197 let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);
1198 for dim in 0..shape.len() {
1199 // `wrap_index` should inline, and has a fast-exit path for zero shifts.
1200 let shift = wrap_index(accumulated_shifts[dim], shape[dim]);
1201 if shift == 0 {
1202 continue;
1203 }
1204
1205 effective_dims.push(dim);
1206 effective_shifts.push(shift);
1207 }
1208
1209 // If no shifts are needed, return the original tensor.
1210 if effective_shifts.is_empty() {
1211 return self;
1212 }
1213
1214 // At this point:
1215 // - `dims` contains the effective dimensions to roll, in index order,
1216 // - `shifts` contains the effective usize shifts for each dimension.
1217 // - Every shift is non-zero, and less than the size of the corresponding dimension.
1218 self.unchecked_roll(&effective_shifts, &effective_dims)
1219 }
1220
1221 /// `roll` internal implementation.
1222 ///
1223 /// ## Parameters
1224 ///
1225 /// - `shifts`: A slice of shifts corresponding to each dimension;
1226 /// must be non-empty, the same length as `dims`, and all ``1..<size>``.
1227 /// - `dims`: A slice of dimensions to roll; must be non-empty;
1228 /// the same length as `shifts`, and must not contain repeats.
1229 ///
1230 /// ## Panics
1231 ///
1232 /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.
1233 ///
1234 /// ## Returns
1235 ///
1236 /// A new tensor with the specified dimensions rolled by the given shifts.
1237 #[inline(always)]
1238 fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {
1239 #[cfg(debug_assertions)]
1240 {
1241 assert!(!shifts.is_empty());
1242 assert_eq!(
1243 shifts.len(),
1244 dims.len(),
1245 "Shifts and dimensions must align; found {} shifts and {} dims",
1246 shifts.len(),
1247 dims.len()
1248 );
1249
1250 let mut unique_dims = dims.to_vec();
1251 unique_dims.dedup();
1252
1253 assert_eq!(
1254 unique_dims.len(),
1255 dims.len(),
1256 "Dimensions must not contain repeats; found {} unique dims and {} total dims",
1257 unique_dims.len(),
1258 dims.len()
1259 )
1260 }
1261
1262 let x = self.unchecked_roll_dim(shifts[0], dims[0]);
1263
1264 if dims.len() == 1 {
1265 x
1266 } else {
1267 x.unchecked_roll(&shifts[1..], &dims[1..])
1268 }
1269 }
1270
1271 /// Returns a tensor containing the elements selected from the given slices.
1272 ///
1273 /// This method provides flexible tensor slicing with support for various range types,
1274 /// negative indices, and stepped slicing. The method accepts both single slices and
1275 /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.
1276 ///
1277 /// # Arguments
1278 ///
1279 /// * `slices` - Can be:
1280 /// - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)
1281 /// - An array of ranges (e.g., `[0..2, 1..4]`)
1282 /// - The [`s!`] macro output for advanced slicing with steps
1283 /// - a `&Vec<Slice>` or `&[Slice]`
1284 ///
1285 /// # Behavior
1286 ///
1287 /// - Supports partial and full slicing in any number of dimensions
1288 /// - Handles negative indices by wrapping from the end (-1 is the last element)
1289 /// - Automatically clamps ranges that exceed tensor dimensions
1290 /// - Supports stepped slicing for selecting every nth element
1291 /// - Negative steps reverse the selection order
1292 ///
1293 /// # Panics
1294 ///
1295 /// - If the number of slices exceeds the tensor's dimensions
1296 /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step
1297 /// - If a step is zero
1298 ///
1299 /// # Examples
1300 ///
1301 /// ```rust
1302 /// use burn_tensor::backend::Backend;
1303 /// use burn_tensor::{Tensor, Shape, s};
1304 ///
1305 /// fn example<B: Backend>() {
1306 /// let device = B::Device::default();
1307 ///
1308 /// // Single dimension slicing - no brackets needed!
1309 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);
1310 /// let slice = tensor.clone().slice(2..8); // Simple range
1311 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);
1312 ///
1313 /// // Using s! macro for single dimension with step
1314 /// let slice = tensor.clone().slice(s![0..10;2]); // Every 2nd element
1315 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);
1316 ///
1317 /// // Reverse a dimension with negative step
1318 /// let slice = tensor.slice(s![..;-1]); // Reverse entire tensor
1319 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
1320 ///
1321 /// // Multi-dimensional slicing
1322 /// let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1323 ///
1324 /// // Array syntax for simple ranges
1325 /// let slice = tensor.clone().slice([1..3, 2..5]);
1326 /// assert_eq!(slice.dims(), [2, 3]);
1327 ///
1328 /// // Advanced multi-dimensional with s! macro
1329 /// let slice = tensor.clone().slice(s![0..4;2, ..;-1]); // Every 2nd row, reverse columns
1330 /// assert_eq!(slice.dims(), [2, 6]);
1331 ///
1332 /// // Complex 3D example with mixed slice types
1333 /// let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);
1334 /// let slice = tensor.slice(s![1..3, ..;2, -3..]); // Rows 1-2, every 2nd col, last 3 depth
1335 /// assert_eq!(slice.dims(), [2, 3, 3]);
1336 ///
1337 /// // Using negative indices
1338 /// let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);
1339 /// let slice = tensor.slice(s![-2.., ..-1]); // Last 2 rows, all but last column
1340 /// assert_eq!(slice.dims(), [2, 5]);
1341 /// }
1342 /// ```
1343 ///
1344 /// # See Also
1345 ///
1346 /// - [`s!`] - The recommended macro for creating complex slice specifications
1347 /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice
1348 /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1349 /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension
1350 ///
1351 /// [`s!`]: crate::s!
1352 pub fn slice<S>(self, slices: S) -> Self
1353 where
1354 S: SliceArg,
1355 {
1356 let shape = self.shape();
1357 let slices = slices.into_slices(&shape);
1358
1359 // Validate slices
1360 check!(TensorCheck::slice::<D>(&shape, &slices));
1361
1362 // Calculate output shape and check for empty slices
1363 let mut output_dims = shape.clone();
1364 for (dim, slice) in slices.iter().enumerate() {
1365 output_dims[dim] = slice.output_size(shape[dim]);
1366 }
1367
1368 // Return empty tensor if any dimension is 0 (empty slice)
1369 if output_dims.contains(&0) {
1370 return Self::empty(output_dims, &self.device());
1371 }
1372 Self::new(K::slice(self.primitive, &slices))
1373 }
1374
1375 /// Assigns values to a slice of the tensor and returns the updated tensor.
1376 ///
1377 /// This method supports advanced slicing with steps, including negative steps for reverse
1378 /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro
1379 /// providing powerful syntax for complex patterns.
1380 ///
1381 /// # Arguments
1382 ///
1383 /// * `slices` - Slice specification (same format as `slice` method)
1384 /// * `values` - Tensor with values to assign (must match slice dimensions)
1385 ///
1386 /// # Panics
1387 ///
1388 /// - If slices exceed tensor dimensions
1389 /// - If values dimensions don't match the selected slice shape
1390 /// - If a step is zero
1391 ///
1392 /// # Examples
1393 ///
1394 /// ```rust
1395 /// use burn_tensor::backend::Backend;
1396 /// use burn_tensor::{Tensor, s};
1397 ///
1398 /// fn example<B: Backend>() {
1399 /// let device = B::Device::default();
1400 ///
1401 /// // Simple assignment to a sub-region
1402 /// let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1403 /// let values = Tensor::<B, 2>::ones([2, 3], &device);
1404 /// tensor = tensor.slice_assign([1..3, 2..5], values);
1405 /// // Now tensor[1..3, 2..5] contains ones
1406 ///
1407 /// // Single dimension assignment with step
1408 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1409 /// let values = Tensor::<B, 1>::ones([5], &device);
1410 /// tensor = tensor.slice_assign(s![0..10;2], values);
1411 /// // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1412 ///
1413 /// // Reverse assignment with negative step
1414 /// let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1415 /// let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);
1416 /// tensor = tensor.slice_assign(s![..;-1], values);
1417 /// // Assigns in reverse: [14, 13, 12, 11, 10]
1418 ///
1419 /// // Complex multi-dimensional assignment
1420 /// let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);
1421 /// let values = Tensor::<B, 3>::ones([2, 3, 3], &device);
1422 /// tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);
1423 /// // Assigns to every 2nd row, every 2nd column, last 3 in depth
1424 ///
1425 /// // Mixed syntax example
1426 /// let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);
1427 /// let pattern = Tensor::<B, 2>::ones([4, 4], &device);
1428 /// tensor = tensor.slice_assign(s![..;2, ..;2], pattern);
1429 /// // Creates a checkerboard pattern with ones
1430 /// }
1431 /// ```
1432 ///
1433 /// # See Also
1434 ///
1435 /// - [`s!`] - The recommended macro for creating complex slice specifications
1436 /// - [`slice`](Self::slice) - Extract a slice from a tensor
1437 /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value
1438 ///
1439 /// [`s!`]: crate::s!
1440 pub fn slice_assign<S>(self, slices: S, values: Self) -> Self
1441 where
1442 S: SliceArg,
1443 {
1444 let shape = self.shape();
1445 let slices = slices.into_slices(&shape);
1446
1447 // Check if any slice produces 0 elements (empty assignment).
1448 // Empty assignments are no-ops and would cause issues in backend implementations.
1449 let is_empty_assignment = slices
1450 .iter()
1451 .enumerate()
1452 .any(|(i, slice)| slice.output_size(shape[i]) == 0);
1453
1454 if is_empty_assignment {
1455 return self;
1456 }
1457
1458 check!(TensorCheck::slice_assign::<D>(
1459 &shape,
1460 &values.shape(),
1461 &slices
1462 ));
1463
1464 Self::new(K::slice_assign(self.primitive, &slices, values.primitive))
1465 }
1466
1467 /// Fills a slice of the tensor with a constant value and returns the updated tensor.
1468 ///
    /// Like other slice methods, accepts both single slices and arrays. Stepped slicing
    /// (including negative steps) is supported; the fill is implemented by broadcasting a
    /// constant tensor over the selection via [`slice_assign`](Self::slice_assign).
1472 ///
1473 /// # Arguments
1474 ///
1475 /// * `slices` - Slice specification (same format as `slice` method, but no steps)
1476 /// * `value` - The value to fill the slice with
1477 ///
1478 /// # Panics
1479 ///
1480 /// - If slices exceed tensor dimensions
    /// - If a step is zero
1482 ///
1483 /// # Examples
1484 ///
1485 /// ```rust
1486 /// use burn_tensor::backend::Backend;
1487 /// use burn_tensor::{Tensor, s};
1488 ///
1489 /// fn example<B: Backend>() {
1490 /// let device = B::Device::default();
1491 ///
1492 /// // Simple fill for a single dimension
1493 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1494 /// tensor = tensor.slice_fill(2..5, 1.0);
1495 /// // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
1496 ///
1497 /// // Multi-dimensional fill
1498 /// let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);
1499 /// tensor = tensor.slice_fill([1..3, 2..5], -1.0);
1500 /// // Fills the rectangle at rows 1-2, columns 2-4 with -1
1501 ///
1502 /// // Using negative indices
1503 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1504 /// tensor = tensor.slice_fill(-3.., 2.0);
1505 /// // Fills the last 3 elements with 2.0
1506 ///
1507 /// // Complex multi-dimensional example
1508 /// let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);
1509 /// tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);
1510 /// // Sets rows 1-2, all columns, last 2 in depth to 0
1511 ///
1512 /// // Stepped slicing is supported
1513 /// let mut tensor = Tensor::<B, 1>::zeros([10], &device);
1514 /// tensor = tensor.slice_fill(s![0..10;2], 1.0);
1515 /// // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
1516 /// }
1517 /// ```
1518 ///
1519 /// # See Also
1520 ///
1521 /// - [`s!`] - The macro for creating slice specifications with steps
1522 /// - [`slice`](Self::slice) - Extract a slice from a tensor
1523 /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice
1524 ///
1525 /// [`s!`]: crate::s!
1526 pub fn slice_fill<S, E: ElementConversion>(self, slices: S, value: E) -> Self
1527 where
1528 S: SliceArg,
1529 {
1530 let shape = self.shape();
1531 let slices = slices.into_slices(&shape);
1532
1533 check!(TensorCheck::slice::<D>(&shape, &slices));
1534
1535 let slice_shape = shape.slice(&slices).unwrap();
1536 let value =
1537 Tensor::<B, 1, K>::from_data([value.elem::<K::Elem>()], (&self.device(), self.dtype()));
1538 let value = value.expand(slice_shape);
1539 self.slice_assign(&slices, value)
1540 }
1541
1542 /// Returns a new tensor with the specified dimension sliced.
1543 ///
1544 /// # Arguments
1545 ///
1546 /// * `dim`: The dimension to slice.
1547 /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),
1548 /// slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.
1549 ///
1550 /// # Returns
1551 ///
1552 /// A new tensor with the specified dimension sliced.
1553 ///
1554 /// # Panics
1555 ///
1556 /// If the slice is out of bounds for the specified dimension.
1557 ///
1558 /// # Examples
1559 ///
1560 /// ```rust
1561 /// # use burn_tensor::{Tensor, s};
1562 /// # use burn_tensor::backend::Backend;
1563 /// #
1564 /// # fn example<B: Backend>() {
1565 /// # let device = B::Device::default();
1566 /// let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);
1567 ///
1568 /// // Simple range slicing
1569 /// let sliced = tensor.clone().slice_dim(1, 1..3);
1570 /// assert_eq!(sliced.shape().as_slice(), [3, 2, 5]);
1571 ///
1572 /// // Slicing with step - take every 2nd element
1573 /// let sliced = tensor.clone().slice_dim(2, s![0..5;2]);
1574 /// assert_eq!(sliced.shape().as_slice(), [3, 4, 3]); // Takes indices 0, 2, 4
1575 ///
1576 /// // Reverse slicing with negative step
1577 /// let sliced = tensor.clone().slice_dim(1, s![..;-1]);
1578 /// assert_eq!(sliced.shape().as_slice(), [3, 4, 5]); // Reverses dimension 1
1579 ///
1580 /// // Select from index 2 with step 3
1581 /// let sliced = tensor.clone().slice_dim(0, s![2..;3]);
1582 /// assert_eq!(sliced.shape().as_slice(), [1, 4, 5]); // Takes only index 2
1583 ///
1584 /// // Select single index (reduces dimension to size 1)
1585 /// let sliced = tensor.slice_dim(0, 1);
1586 /// assert_eq!(sliced.shape().as_slice(), [1, 4, 5]);
1587 /// # }
1588 /// ```
1589 ///
1590 /// # See Also
1591 ///
1592 /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously
1593 /// - [`s!`] - The macro for creating complex slice specifications
1594 ///
1595 /// [`s!`]: crate::s!
1596 pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self
1597 where
1598 S: Into<Slice>,
1599 {
1600 check!(TensorCheck::check_dim::<D>(dim));
1601 let slice: Slice = slice.into();
1602
1603 let mut slices = vec![Slice::full(); D];
1604 slices[dim] = slice;
1605
1606 self.slice(&slices)
1607 }
1608
1609 /// Returns the device of the current tensor.
1610 pub fn device(&self) -> B::Device {
1611 K::device(&self.primitive)
1612 }
1613
1614 /// Move the tensor to the given device.
1615 pub fn to_device(self, device: &B::Device) -> Self {
1616 Self::new(K::to_device(self.primitive, device))
1617 }
1618
1619 /// Select tensor elements along the given dimension corresponding to the given indices.
1620 ///
1621 /// # Arguments
1622 ///
1623 /// * `dim` - The dimension to select from. Supports negative indexing.
1624 /// * `indices` - The indices of the elements to select.
1625 ///
1626 /// # Example
1627 ///
1628 /// ```rust
1629 /// use burn_tensor::backend::Backend;
1630 /// use burn_tensor::{Tensor, Int};
1631 ///
1632 /// fn example<B: Backend>() {
1633 /// let device = B::Device::default();
1634 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);
1635 /// let indices = Tensor::<B, 1, Int>::from_data([0], &device);
1636 /// let tensor = tensor.select(0, indices);
1637 /// println!("{tensor}");
1638 /// // [[1.0, -2.0, 3.0]]
1639 /// }
1640 /// ```
1641 pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {
1642 let dim = dim.expect_dim_index(D);
1643 check!(TensorCheck::select::<D>(dim));
1644 Self::new(K::select(self.primitive, dim, indices.primitive))
1645 }
1646
1647 /// Assign the selected elements along the given dimension corresponding to the given indices
1648 /// from the value tensor to the original tensor using sum reduction.
1649 ///
1650 /// # Note
1651 /// For booleans, the sum operator is logical or.
1652 ///
1653 /// # Arguments
1654 ///
1655 /// * `dim` - The dimension along which to select. Supports negative indexing.
1656 /// * `indices` - The indices to select from the tensor.
1657 /// * `values` - The values to assign to the selected indices.
1658 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1659 ///
1660 /// # Example
1661 ///
1662 /// Example using a 3D tensor:
1663 ///
1664 /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
1665 /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
1666 /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
1667 /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`
1668 ///
1669 /// # Warning
1670 ///
1671 /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1672 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1673 pub fn select_assign(
1674 self,
1675 dim: impl AsIndex,
1676 indices: Tensor<B, 1, Int>,
1677 values: Tensor<B, D, K>,
1678 update: IndexingUpdateOp,
1679 ) -> Self {
1680 let dim = dim.expect_dim_index(D);
1681 check!(TensorCheck::select_assign::<D>(
1682 dim,
1683 &indices.shape(),
1684 &values.shape()
1685 ));
1686
1687 Self::new(K::select_assign(
1688 self.primitive,
1689 dim,
1690 indices.primitive,
1691 values.primitive,
1692 update,
1693 ))
1694 }
1695
1696 /// Update the given tensor with the value tensor where the mask is true.
1697 ///
1698 /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
1699 /// a scalar.
1700 ///
1701 /// # Example
1702 ///
1703 /// ```rust
1704 /// use burn_tensor::backend::Backend;
1705 /// use burn_tensor::{Tensor, Shape, Bool};
1706 ///
1707 /// fn example<B: Backend>() {
1708 /// let device = B::Device::default();
1709 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1710 /// let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1711 /// let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1712 /// let tensor = tensor.mask_where(mask, value);
1713 /// println!("{tensor}");
1714 /// // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
1715 /// }
1716 /// ```
1717 pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
1718 Self::new(K::mask_where(
1719 self.primitive,
1720 mask.primitive,
1721 value.primitive,
1722 ))
1723 }
1724
1725 /// Update the given tensor with the value where the mask is true.
1726 ///
1727 /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
1728 /// a tensor.
1729 ///
1730 /// # Example
1731 ///
1732 /// ```rust
1733 /// use burn_tensor::backend::Backend;
1734 /// use burn_tensor::{Tensor, Shape, Bool};
1735 ///
1736 /// fn example<B: Backend>() {
1737 /// let device = B::Device::default();
1738 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1739 /// let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1740 /// let tensor = tensor.mask_fill(mask, 3.0);
1741 /// println!("{tensor}");
1742 /// // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
1743 /// }
1744 /// ```
1745 pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
1746 let value = Scalar::new(value, &self.dtype());
1747 Self::new(K::mask_fill(self.primitive, mask.primitive, value))
1748 }
1749
1750 /// Gather tensor elements corresponding to the given indices from the specified dim.
1751 ///
1752 /// Example using a 3D tensor:
1753 ///
1754 /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
1755 /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
1756 /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
1757 ///
1758 /// # Notes
1759 ///
1760 /// The index tensor should have the same shape as the original tensor except for the dim
1761 /// specified.
1762 ///
1763 /// # Warning
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1765 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1766 pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
1767 check!(TensorCheck::gather::<D>(
1768 dim,
1769 &self.shape(),
1770 &indices.shape()
1771 ));
1772
1773 Self::new(K::gather(dim, self.primitive, indices.primitive))
1774 }
1775
1776 /// Assign the gathered elements corresponding to the given indices along the specified dimension
1777 /// from the value tensor to the original tensor using sum reduction.
1778 ///
1779 /// Example using a 3D tensor:
1780 ///
1781 /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
1782 /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
1783 /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
1784 ///
1785 /// # Arguments
1786 /// * `dim` - The axis along which to scatter elements.
1787 /// * `indices` - The indices of the elements to scatter.
1788 /// * `values` - The values to scatter into the tensor.
1789 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1790 ///
1791 /// # Notes
1792 ///
1793 /// The index tensor should have the same shape as the original tensor except for the specified
1794 /// dimension. The value and index tensors should have the same shape.
1795 ///
1796 /// Other references to the input tensor will not be modified by this operation.
1797 ///
1798 /// # Warning
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1800 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1801 ///
1802 /// # Panics
1803 /// If the `update` is not `IndexingUpdateOp::Add`. Other operations are currently not implemented.
1804 pub fn scatter(
1805 self,
1806 dim: usize,
1807 indices: Tensor<B, D, Int>,
1808 values: Self,
1809 update: IndexingUpdateOp,
1810 ) -> Self {
1811 check!(TensorCheck::scatter::<D>(
1812 dim,
1813 &self.shape(),
1814 &indices.shape(),
1815 &values.shape()
1816 ));
1817
1818 Self::new(K::scatter(
1819 dim,
1820 self.primitive,
1821 indices.primitive,
1822 values.primitive,
1823 update,
1824 ))
1825 }
1826
1827 /// Multi-dimensional scatter: update the tensor at locations given by `indices` using the specified `update` operation.
1828 ///
1829 /// The size of `indices`'s last axis (call it `K`) indexes the leading `K` dims of `self`;
1830 /// the batch shape `indices.shape[0..M-1]` is preserved. `values` has shape
1831 /// `indices.shape[0..M-1] ++ self.shape[K..D]`. Constraints: `K <= D` and `M >= 1`.
1832 ///
1833 /// # Arguments
1834 /// * `indices` - The indices of the elements to scatter.
1835 /// * `values` - The values to scatter into the tensor.
1836 /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).
1837 ///
1838 /// # Note
1839 ///
1840 /// When `indices` contains duplicate entries, behavior varies by operation:
1841 /// - For `Add`, accumulation is supported, though results may be non-deterministic on GPU
1842 /// backends.
1843 /// - For other operations (`Assign`, `Mul`, `Min`, `Max`), duplicate indices result in
1844 /// undefined behavior for both the forward result and the backward gradients.
1845 ///
1846 /// For deterministic results and correct gradient calculation across all operations,
1847 /// `indices` should contain unique entries.
1848 ///
1849 /// # Warning
1850 ///
1851 /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1852 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1853 pub fn scatter_nd<const M: usize, const DV: usize>(
1854 self,
1855 indices: Tensor<B, M, Int>,
1856 values: Tensor<B, DV, K>,
1857 update: IndexingUpdateOp,
1858 ) -> Self {
1859 check!(TensorCheck::scatter_nd::<D, M, DV>(
1860 &self.shape(),
1861 &indices.shape(),
1862 &values.shape()
1863 ));
1864 Self::new(K::scatter_nd(
1865 self.primitive,
1866 indices.primitive,
1867 values.primitive,
1868 update,
1869 ))
1870 }
1871
1872 /// Multi-dimensional gather: collect slices from `self` at multi-index locations
1873 /// specified by `indices`.
1874 ///
1875 /// The size of `indices`'s last axis (call it `K`) indexes the leading `K` dims of `self`;
1876 /// the batch shape `indices.shape[0..M-1]` is preserved. The output has shape
1877 /// `indices.shape[0..M-1] ++ self.shape[K..D]`. Constraints: `K <= D` and `M >= 1`.
1878 ///
1879 /// # Warning
1880 ///
1881 /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
1882 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1883 pub fn gather_nd<const M: usize, const DV: usize>(
1884 self,
1885 indices: Tensor<B, M, Int>,
1886 ) -> Tensor<B, DV, K> {
1887 check!(TensorCheck::gather_nd::<D, M, DV>(&indices.shape()));
1888 Tensor::new(K::gather_nd(self.primitive, indices.primitive))
1889 }
1890
1891 /// Converts the data of the current tensor.
1892 ///
1893 /// # Note
1894 ///
1895 /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1896 /// tensors at once. This may improve laziness, especially if executed on a different
1897 /// thread in native environments.
1898 pub fn into_data(self) -> TensorData {
1899 self.try_into_data().expect(
1900 "Error while reading data: use `try_into_data` instead to catch the error at runtime",
1901 )
1902 }
1903
1904 /// Converts the data of the current tensor and returns any error that might have occurred since the
1905 /// last time the device was synchronized.
1906 ///
1907 /// # Note
1908 ///
1909 /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1910 /// tensors at once. This may improve laziness, especially if executed on a different
1911 /// thread in native environments.
1912 pub fn try_into_data(self) -> Result<TensorData, ExecutionError> {
1913 crate::try_read_sync(self.into_data_async()).expect(
1914 "Failed to read tensor data synchronously.
1915 This can happen on platforms that don't support blocking futures like WASM.
1916 If possible, try using into_data_async instead.",
1917 )
1918 }
1919
1920 /// Converts the data of the current tensor.
1921 ///
1922 /// # Note
1923 ///
1924 /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple
1925 /// tensors at once. This may improve laziness, especially if executed on a different
1926 /// thread in native environments.
1927 pub fn to_data(&self) -> TensorData {
1928 self.clone().into_data()
1929 }
1930
1931 /// Returns the data of the current tensor.
1932 pub async fn into_data_async(self) -> Result<TensorData, ExecutionError> {
1933 K::into_data_async(self.primitive).await
1934 }
1935
1936 /// Returns the data of the current tensor.
1937 pub async fn to_data_async(&self) -> Result<TensorData, ExecutionError> {
1938 self.clone().into_data_async().await
1939 }
1940
1941 /// Create a tensor from the given data on the given device.
1942 pub fn from_data<T>(data: T, options: impl Into<TensorCreationOptions<B>>) -> Self
1943 where
1944 T: Into<TensorData>,
1945 {
1946 let data = data.into();
1947 check!(TensorCheck::creation_ops::<D>(
1948 "From Data",
1949 data.shape.as_slice()
1950 ));
1951
1952 // Use the given dtype when provided, otherwise default device dtype
1953 let opt = options.into();
1954 let dtype = opt.resolve_dtype::<K>();
1955 Self::new(K::from_data(data, &opt.device, dtype))
1956 }
1957
1958 /// Repeat the tensor along the given dimension.
1959 ///
1960 /// The output tensor has the same shape, except along the given dimension.
1961 ///
1962 /// # Arguments
1963 /// - `dim`: The dimension to repeat.
1964 /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1965 ///
1966 /// # Returns
1967 ///
1968 /// A new tensor with the given dimension repeated `times` times.
1969 ///
1970 /// # Example
1971 ///
1972 /// ```rust
1973 /// use burn_tensor::backend::Backend;
1974 /// use burn_tensor::Tensor;
1975 ///
1976 /// fn example<B: Backend>() {
1977 /// let device = Default::default();
1978 /// // Create a 2D tensor with dimensions [3, 2]
1979 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1980 ///
1981 /// // Repeat the tensor along the dimension 0 twice.
1982 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1983 /// // The resulting tensor will have dimensions [6, 2].
1984 /// let repeated = tensor.repeat_dim(0, 2);
1985 /// println!("{repeated}");
1986 /// }
1987 /// ```
1988 pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1989 if times > 0 {
1990 Self::new(K::repeat_dim(self.primitive, dim, times))
1991 } else {
1992 let shape = self.shape().repeat(dim, times).unwrap();
1993 Self::empty(shape, &self.device())
1994 }
1995 }
1996
1997 /// Repeat the tensor along the given dimensions.
1998 /// # Arguments
1999 /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
2000 ///
2001 /// # Returns
2002 ///
2003 /// A new tensor with the given dimensions repeated `times` times.
2004 ///
2005 /// # Panics
2006 ///
2007 /// If `sizes` contains more elements than the number of dimensions.
2008 ///
2009 /// # Example
2010 ///
2011 /// ```rust
2012 ///
2013 /// use burn_tensor::backend::Backend;
2014 /// use burn_tensor::Tensor;
2015 ///
2016 /// fn example<B: Backend>() {
2017 /// let device = Default::default();
2018 /// // Create a 2D tensor with dimensions [3, 2]
2019 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2020 ///
2021 /// // Repeat the tensor along the dimension 0 twice and the dimension 0 once.
2022 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
2023 /// // The resulting tensor will have dimensions [6, 2].
2024 /// let repeated = tensor.repeat(&[2, 1]);
2025 /// }
2026 /// ```
2027 pub fn repeat(self, sizes: &[usize]) -> Self {
2028 if sizes.contains(&0) {
2029 let mut shape = self.shape();
2030 for (dim, ×) in sizes.iter().enumerate() {
2031 shape = shape.repeat(dim, times).unwrap();
2032 }
2033
2034 return Self::empty(shape, &self.device());
2035 }
2036
2037 let mut tensor = self;
2038 for (dim, ×) in sizes.iter().enumerate() {
2039 if times > 1 {
2040 tensor = tensor.repeat_dim(dim, times);
2041 }
2042 }
2043 tensor
2044 }
2045
2046 /// Applies element-wise equal comparison.
2047 ///
2048 /// # Returns
2049 /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
2050 ///
2051 /// # Panics
2052 ///
2053 /// If the two tensors don't have the same shape.
2054 ///
2055 /// # Example
2056 ///
2057 /// ```rust
2058 /// use burn_tensor::backend::Backend;
2059 /// use burn_tensor::Tensor;
2060 ///
2061 /// fn example<B: Backend>() {
2062 /// let device = Default::default();
2063 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2064 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2065 /// // Compare the elements of the two 2D tensors with dimensions [3, 2].
2066 /// // [[false, true], [true, true], [true, true]]
2067 /// let equal = t1.equal(t2);
2068 /// println!("{equal}");
2069 /// }
2070 /// ```
2071 pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
2072 check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
2073 Tensor::new(K::equal(self.primitive, other.primitive))
2074 }
2075
2076 /// Applies element-wise non-equality comparison.
2077 ///
2078 /// # Returns
2079 /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
2080 ///
2081 /// # Panics
2082 ///
2083 /// If the two tensors don't have the same shape.
2084 ///
2085 /// # Example
2086 ///
2087 /// ```rust
2088 /// use burn_tensor::backend::Backend;
2089 /// use burn_tensor::Tensor;
2090 ///
2091 /// fn example<B: Backend>() {
2092 /// let device = Default::default();
2093 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2094 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
2095 /// // Compare the elements of the two 2D tensors for inequality.
2096 /// // [[true, false], [false, false], [false, false]]
2097 /// let not_equal = t1.not_equal(t2);
2098 /// println!("{not_equal}");
2099 /// }
2100 /// ```
2101 pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
2102 check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
2103 Tensor::new(K::not_equal(self.primitive, other.primitive))
2104 }
2105
2106 /// Applies element wise equal comparison and returns a boolean tensor.
2107 ///
2108 /// # Arguments
2109 ///
2110 /// * `other` - The element to compare.
2111 ///
2112 /// # Example
2113 ///
2114 /// ```rust
2115 /// use burn_tensor::backend::Backend;
2116 /// use burn_tensor::{Tensor, Shape};
2117 ///
2118 /// fn example<B: Backend>() {
2119 /// let device = B::Device::default();
2120 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2121 /// let tensor = tensor.equal_elem(3.0);
2122 /// println!("{tensor}");
2123 /// // [[false, false, true], [false, false, false]]
2124 /// }
2125 /// ```
2126 pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2127 let other = Scalar::new(other, &self.dtype());
2128 Tensor::new(K::equal_elem(self.primitive, other))
2129 }
2130
2131 /// Applies element wise non-equality comparison and returns a boolean tensor.
2132 ///
2133 /// # Arguments
2134 ///
2135 /// * `other` - The element to compare.
2136 ///
2137 /// # Example
2138 ///
2139 /// ```rust
2140 /// use burn_tensor::backend::Backend;
2141 /// use burn_tensor::{Tensor, Shape};
2142 ///
2143 /// fn example<B: Backend>() {
2144 /// let device = B::Device::default();
2145 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2146 /// let tensor = tensor.not_equal_elem(3.0);
2147 /// println!("{tensor}");
2148 /// // [[true, true, false], [true, true, true]]
2149 /// }
2150 /// ```
2151 pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
2152 let other = Scalar::new(other, &self.dtype());
2153 Tensor::new(K::not_equal_elem(self.primitive, other))
2154 }
2155
2156 /// Concatenates all tensors into a new one along the given dimension.
2157 ///
2158 /// # Panics
2159 ///
2160 /// - If `dim` is higher than the rank.
2161 /// - If `tensors` is an empty vector.
2162 /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
2163 ///
2164 /// # Example
2165 ///
2166 /// ```rust
2167 /// use burn_tensor::backend::Backend;
2168 /// use burn_tensor::Tensor;
2169 ///
2170 /// fn example<B: Backend>() {
2171 /// let device = Default::default();
2172 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
2173 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2174 ///
2175 /// // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
2176 /// // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
2177 /// // The resulting tensor will have shape [2, 7].
2178 /// let concat = Tensor::cat(vec![t1, t2], 1);
2179 /// println!("{concat}");
2180 /// }
2181 /// ```
2182 pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
2183 check!(TensorCheck::cat(&tensors, dim));
2184
2185 // Filter out tensors with size 0 along the concatenation dimension.
2186 // Empty tensors don't contribute to the output and would cause issues
2187 // in backend implementations (e.g., division by zero in slice_assign).
2188 // Safety: TensorCheck::cat ensures tensors is non-empty
2189 let first_tensor = tensors.first().unwrap();
2190 let device = first_tensor.device();
2191 let mut shape = first_tensor.shape();
2192
2193 let non_empty_primitives: Vec<_> = tensors
2194 .into_iter()
2195 .filter(|t| t.shape()[dim] > 0)
2196 .map(|t| t.primitive)
2197 .collect();
2198
2199 // If all tensors were empty, return an empty tensor with size 0 on concat dim
2200 if non_empty_primitives.is_empty() {
2201 shape[dim] = 0;
2202 return Self::empty(shape, &device);
2203 }
2204
2205 Self::new(K::cat(non_empty_primitives, dim))
2206 }
2207
2208 /// Concatenates all tensors into a new one along a new dimension.
2209 ///
2210 /// # Panics
2211 ///
2212 /// - If all tensors don't have the same shape.
2213 /// - If given dimension is not with range of 0..D2
2214 ///
2215 /// # Example
2216 ///
2217 /// ```rust
2218 /// use burn_tensor::backend::Backend;
2219 /// use burn_tensor::Tensor;
2220 ///
2221 /// fn example<B: Backend>() {
2222 /// let device = Default::default();
2223 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2224 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2225 /// let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
2226 ///
2227 /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
2228 /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
2229 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
2230 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
2231 /// // The resulting tensor will have shape [3, 2, 3].
2232 /// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
2233 /// println!("{stacked}");
2234 /// }
2235 /// ```
2236 pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
2237 check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
2238 let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
2239 Tensor::<B, D2, K>::cat(tensors, dim)
2240 }
2241
2242 /// Iterate over slices of tensors alongside a given dimension.
2243 ///
2244 /// # Panics
2245 ///
2246 /// If given dimension is greater than or equal to tensor rank.
2247 ///
2248 /// # Returns
2249 ///
2250 /// A tensor iterator.
2251 ///
2252 /// # Example
2253 ///
2254 /// ```rust
2255 /// use burn_tensor::backend::Backend;
2256 /// use burn_tensor::Tensor;
2257 /// fn example<B: Backend>() {
2258 /// let device = Default::default();
2259 /// let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
2260 /// // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
2261 /// let iter = tensor.iter_dim(0);
2262 /// for (i,tensor) in iter.enumerate() {
2263 /// println!("Tensor {}: {}", i, tensor);
2264 /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
2265 /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
2266 /// }
2267 /// }
2268 /// ```
2269 pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
2270 check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
2271 DimIter::new(self, dim)
2272 }
2273
2274 /// Returns a new tensor with the given dimension narrowed to the given range.
2275 ///
2276 /// # Panics
2277 ///
2278 /// - If the dimension is greater than the number of dimensions of the tensor.
2279 /// - If the given range exceeds the number of elements on the given dimension.
2280 ///
2281 /// # Returns
2282 ///
2283 /// A new tensor with the given dimension narrowed to the given range.
2284 ///
2285 /// # Example
2286 ///
2287 /// ```rust
2288 /// use burn_tensor::backend::Backend;
2289 /// use burn_tensor::Tensor;
2290 ///
2291 /// fn example<B: Backend>() {
2292 /// let device = Default::default();
2293 /// // Create a 2D tensor with dimensions [4, 3]
2294 /// let tensor = Tensor::<B, 2>::from_data(
2295 /// [
2296 /// [3.0, 4.9, 2.0],
2297 /// [2.0, 1.9, 3.0],
2298 /// [6.0, 1.5, 7.0],
2299 /// [3.0, 4.9, 9.0],
2300 /// ],
2301 /// &device,
2302 /// );
2303 /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
2304 /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
2305 /// // The resulting tensor will have dimensions [3, 3].
2306 /// let narrowed = tensor.narrow(0, 1, 3);
2307 /// println!("{narrowed}");
2308 /// }
2309 /// ```
2310 pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
2311 check!(TensorCheck::dim_ops::<D>("narrow", dim));
2312 check!(TensorCheck::narrow(&self, dim, start, length));
2313 let dims = self.dims();
2314
2315 let ranges: [Range<usize>; D] = dims
2316 .iter()
2317 .enumerate()
2318 .map(|(i, d)| {
2319 if i == dim {
2320 start..(start + length)
2321 } else {
2322 0..*d
2323 }
2324 })
2325 .collect::<Vec<_>>()
2326 .try_into()
2327 .unwrap();
2328
2329 Self::slice(self, ranges)
2330 }
2331
2332 /// Attempts to split the tensor into a specified number of chunks along a given dimension.
2333 /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
2334 ///
2335 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2336 /// Otherwise all chunks will be of equal size except for the last one.
2337 ///
2338 /// # Panics
2339 ///
2340 /// If the dimension is greater than the number of dimensions of the tensor.
2341 ///
2342 /// # Returns
2343 /// A vector of tensors.
2344 ///
2345 /// # Example
2346 ///
2347 /// ```rust
2348 /// use burn_tensor::backend::Backend;
2349 /// use burn_tensor::Tensor;
2350 ///
2351 /// fn example<B: Backend>() {
2352 /// let device = Default::default();
2353 /// // Create a 2D tensor with dimensions [4, 3]
2354 /// let tensor = Tensor::<B, 2>::from_data(
2355 /// [
2356 /// [3.0, 4.9, 2.0],
2357 /// [2.0, 1.9, 3.0],
2358 /// [6.0, 1.5, 7.0],
2359 /// [3.0, 4.9, 9.0],
2360 /// ],
2361 /// &device,
2362 /// );
2363 /// // Split the tensor along the dimension 1 into 2 chunks.
2364 /// // The first chuck will have shape [4, 2]:
2365 /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
2366 /// // The second chunk will have shape [4, 1]:
2367 /// // [[2.0], [3.0], [7.0], [9.0]]
2368 /// let chunks = tensor.chunk(2, 1);
2369 /// println!("{chunks:?}");
2370 /// }
2371 /// ```
2372 pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
2373 check!(TensorCheck::dim_ops::<D>("chunk", dim));
2374 let size = self.shape()[dim];
2375 if size < chunks {
2376 return (0..size)
2377 .map(|i| Self::narrow(self.clone(), dim, i, 1))
2378 .collect();
2379 }
2380
2381 let mut tensors = Vec::with_capacity(chunks);
2382 let mut sum_chunk_size = 0;
2383 if size.is_multiple_of(chunks) {
2384 let chunk_size = size / chunks;
2385 for _ in 0..chunks {
2386 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2387 sum_chunk_size += chunk_size;
2388 }
2389 } else {
2390 let chunk_size = (size / chunks) + 1; // assumes not divisible
2391 for _ in 0..chunks - 1 {
2392 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));
2393 sum_chunk_size += chunk_size;
2394 }
2395 let remainder = size % chunk_size;
2396 tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder));
2397 }
2398
2399 tensors
2400 }
2401
2402 /// Splits the tensor into chunks of a specified size along a given dimension.
2403 /// Each chunk is a view of the original tensor.
2404 ///
2405 /// If the tensor size along the given dimension is not divisible by `split_size`,
2406 /// then the last chunk will be smaller.
2407 ///
2408 /// # Panics
2409 ///
2410 /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
2411 ///
2412 /// # Returns
2413 ///
2414 /// A vector of tensors.
2415 ///
2416 /// # Example
2417 /// ```rust
2418 /// use burn_tensor::backend::Backend;
2419 /// use burn_tensor::Tensor;
2420 ///
2421 /// fn example<B: Backend>() {
2422 /// let device = Default::default();
2423 /// // Create a 1D tensor with 5 elements
2424 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2425 /// // Split the tensor into chunks of size 2 along dimension 0
2426 /// let chunks = tensor.split(2, 0);
2427 /// // The result is a vector of tensors:
2428 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
2429 /// println!("{:?}", chunks);
2430 /// }
2431 /// ```
2432 pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
2433 check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));
2434 let size = self.shape()[dim];
2435 let mut tensors = Vec::new();
2436
2437 let mut start = 0;
2438 while start < size {
2439 let length = usize::min(split_size, size - start);
2440 tensors.push(Self::narrow(self.clone(), dim, start, length));
2441 start += length;
2442 }
2443
2444 tensors
2445 }
2446
2447 /// Splits the tensor into chunks with the specified sizes along a given dimension.
2448 /// Each chunk is a view of the original tensor.
2449 ///
2450 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2451 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2452 ///
2453 /// # Panics
2454 ///
2455 /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
2456 /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
2457 ///
2458 /// # Returns
2459 ///
2460 /// A vector of tensors.
2461 ///
2462 /// # Example
2463 /// ```rust
2464 /// use burn_tensor::backend::Backend;
2465 /// use burn_tensor::Tensor;
2466 ///
2467 /// fn example<B: Backend>() {
2468 /// let device = Default::default();
2469 /// // Create a 1D tensor with 5 elements
2470 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
2471 /// // Split the tensor into chunks with sizes [2, 3] along dimension 0
2472 /// let chunks = tensor.split_with_sizes(vec![2, 3], 0);
2473 /// // The result is a vector of tensors:
2474 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
2475 /// println!("{:?}", chunks);
2476 /// }
2477 /// ```
2478 pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
2479 check!(TensorCheck::split_with_sizes::<D>(
2480 &self.shape(),
2481 &split_sizes,
2482 dim
2483 ));
2484 let mut tensors = Vec::new();
2485
2486 let mut start = 0;
2487 for length in split_sizes {
2488 if length == 0 {
2489 continue;
2490 }
2491 tensors.push(Self::narrow(self.clone(), dim, start, length));
2492 start += length;
2493 }
2494
2495 tensors
2496 }
2497
2498 /// Tests if any element in the `tensor` evaluates to True.
2499 ///
2500 /// # Arguments
2501 ///
2502 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2503 ///
2504 /// # Returns
2505 ///
2506 /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
2507 /// evaluates to True, False otherwise.
2508 ///
2509 /// # Example
2510 ///
2511 /// ```rust
2512 /// use burn_tensor::backend::Backend;
2513 /// use burn_tensor::{Tensor, Bool};
2514 ///
2515 /// fn example<B: Backend>() {
2516 /// let device = Default::default();
2517 /// let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
2518 /// let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
2519 ///
2520 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2521 /// let any_tensor = tensor.any();
2522 /// println!("{}", any_tensor);
2523 /// // Tensor { data: [true], ... }
2524 ///
2525 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
2526 /// let any_tensor_two = tensor_two.any();
2527 /// println!("{}", any_tensor_two);
2528 /// // Tensor { data: [false], ... }
2529 /// }
2530 /// ```
2531 pub fn any(self) -> Tensor<B, 1, Bool> {
2532 Tensor::new(K::any(self.primitive))
2533 }
2534
2535 /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
2536 ///
2537 /// # Arguments
2538 ///
2539 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2540 /// * `dim` - The axis along which to test.
2541 ///
2542 /// # Returns
2543 ///
2544 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2545 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
2546 /// evaluates to True, False otherwise.
2547 ///
2548 /// # Example
2549 ///
2550 /// ```rust
2551 /// use burn_tensor::backend::Backend;
2552 /// use burn_tensor::{Tensor, Bool};
2553 ///
2554 /// fn example<B: Backend>() {
2555 /// let device = Default::default();
2556 /// let tensor =
2557 /// Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
2558 /// // Check if any element in the tensor evaluates to True along the dimension 1.
2559 /// // [[true], [true]],
2560 /// let any_dim = tensor.clone().any_dim(1);
2561 /// println!("{any_dim}");
2562 /// }
2563 /// ```
2564 pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2565 Tensor::new(K::any_dim(self.primitive, dim))
2566 }
2567
2568 /// Tests if all elements in the `tensor` evaluate to True.
2569 ///
2570 /// # Arguments
2571 ///
2572 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2573 ///
2574 /// # Returns
2575 ///
2576 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
2577 /// evaluate to True, False otherwise.
2578 ///
2579 /// # Example
2580 ///
2581 /// ```rust
2582 /// use burn_tensor::backend::Backend;
2583 /// use burn_tensor::{Tensor, Bool};
2584 ///
2585 /// fn example<B: Backend>() {
2586 /// let device = Default::default();
2587 /// let tensor =
2588 /// Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
2589 /// // Check if all elements in the tensor evaluate to True (which is not the case).
2590 /// // [false]
2591 /// let all = tensor.all();
2592 /// println!("{all}");
2593 /// }
2594 /// ```
2595 pub fn all(self) -> Tensor<B, 1, Bool> {
2596 Tensor::new(K::all(self.primitive))
2597 }
2598
2599 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2600 ///
2601 /// # Arguments
2602 ///
2603 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
2604 /// * `dim` - The axis along which to test.
2605 ///
2606 /// # Returns
2607 ///
2608 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
2609 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
2610 /// evaluates to True, False otherwise.
2611 ///
2612 /// # Example
2613 ///
2614 /// ```rust
2615 /// use burn_tensor::backend::Backend;
2616 /// use burn_tensor::{Tensor, Bool};
2617 ///
2618 /// fn example<B: Backend>() {
2619 /// let device = Default::default();
2620 /// let tensor =
2621 /// Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
2622 /// // Check if all elements in the tensor evaluate to True along the dimension 1.
2623 /// // [[true, true, false]]
2624 /// let all_dim = tensor.clone().all_dim(0);
2625 /// println!("{all_dim}");
2626 /// }
2627 /// ```
2628 pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
2629 Tensor::new(K::all_dim(self.primitive, dim))
2630 }
2631
2632 /// Convert the tensor into a scalar.
2633 ///
2634 /// # Panics
2635 ///
2636 /// - If the tensor doesn't have one element.
2637 /// - If the backend fails to read the tensor data synchronously.
2638 ///
2639 /// # Returns
2640 ///
2641 /// The scalar value of the tensor.
2642 ///
2643 /// # Example
2644 ///
2645 /// ```rust
2646 /// use burn_tensor::backend::Backend;
2647 /// use burn_tensor::Tensor;
2648 ///
2649 /// fn example<B: Backend>() {
2650 /// let device = Default::default();
2651 /// let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
2652 /// // Convert the tensor with a single element into a scalar.
2653 /// let scalar = tensor.into_scalar();
2654 /// println!("{scalar}");
2655 /// }
2656 /// ```
2657 pub fn into_scalar(self) -> K::Elem {
2658 crate::try_read_sync(self.into_scalar_async())
2659 .expect(
2660 "Failed to read tensor data synchronously. This can happen on platforms
2661 that don't support blocking futures like WASM. Try into_scalar_async instead.",
2662 )
2663 .expect("Error while reading data: use `try_into_scalar` instead to catch the error at runtime")
2664 }
2665
2666 /// Convert the tensor into a scalar and returns any error that might have occurred since the
2667 /// last time the device was synchronized.
2668 ///
2669 /// # Panics
2670 ///
2671 /// - If the tensor doesn't have one element.
2672 /// - If the backend fails to read the tensor data synchronously.
2673 ///
2674 /// # Returns
2675 ///
2676 /// The scalar value of the tensor.
2677 pub fn try_into_scalar(self) -> Result<K::Elem, ExecutionError> {
2678 crate::try_read_sync(self.into_scalar_async()).expect(
2679 "Failed to read tensor data synchronously. This can happen on platforms
2680 that don't support blocking futures like WASM. Try into_scalar_async instead.",
2681 )
2682 }
2683
2684 /// Convert the tensor into a scalar.
2685 ///
2686 /// # Panics
2687 ///
2688 /// If the tensor doesn't have one element.
2689 pub async fn into_scalar_async(self) -> Result<K::Elem, ExecutionError> {
2690 check!(TensorCheck::into_scalar::<D>(&self.shape()));
2691
2692 Ok(self.into_data_async().await?.iter().next().unwrap())
2693 }
2694
2695 /// Broadcast the tensor to the given shape.
2696 ///
2697 /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
2698 /// (which can be inferred with `-1`).
2699 ///
2700 /// # Arguments
2701 ///
2702 /// * `shape` - The shape to broadcast the tensor to.
2703 /// Can contain -1 for dimensions that should be inferred.
2704 /// The number of elements in the shape must be greater or equal as
2705 /// the number of dimensions of the tensor.
2706 ///
2707 /// # Panics
2708 ///
2709 /// If the tensor cannot be broadcasted to the given shape.
2710 ///
2711 /// # Returns
2712 ///
2713 /// A new tensor with the given shape.
2714 ///
2715 /// # Example
2716 ///
2717 /// ```rust
2718 /// use burn_tensor::backend::Backend;
2719 /// use burn_tensor::Tensor;
2720 ///
2721 /// fn example<B: Backend>() {
2722 /// let device = Default::default();
2723 /// // Create a 2D tensor with dimensions [3, 1]
2724 /// let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
2725 /// // Expand the tensor to a new shape [3, 4]
2726 /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
2727 /// let expanded = tensor.expand([3, 4]);
2728 /// println!("{}", expanded);
2729 /// }
2730 /// ```
2731 pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
2732 let shape = shape.into_shape(&self.shape());
2733 check!(TensorCheck::expand::<D, D2>(
2734 "expand",
2735 &self.shape(),
2736 &shape,
2737 ));
2738
2739 Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
2740 }
2741
2742 /// Unfold windows along a dimension.
2743 ///
2744 /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;
2745 /// where windows are advanced by `step` at each index.
2746 ///
2747 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
2748 ///
2749 /// The new view will have the unfolded dimension replaced by two dimensions;
2750 /// one in the position of the original dimension, with size equal to the number of windows,
2751 /// and one appended to the right-most position, with size equal to `size`.
2752 ///
2753 /// # Warning
2754 ///
2755 /// For the `ndarray` backend; this is not a view but a copy
2756 /// with duplicated data.
2757 ///
2758 /// # Arguments
2759 ///
2760 /// * `dim` - the dimension to unfold.
2761 /// * `size` - the size of each unfolded window.
2762 /// * `step` - the step between each window.
2763 ///
2764 /// # Returns
2765 ///
2766 /// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
2767 pub fn unfold<const D2: usize, I: AsIndex>(
2768 self,
2769 dim: I,
2770 size: usize,
2771 step: usize,
2772 ) -> Tensor<B, D2, K> {
2773 let dim = dim.expect_dim_index(D);
2774 check!(TensorCheck::unfold::<D, D2>(
2775 "unfold",
2776 &self.shape(),
2777 dim,
2778 size,
2779 step,
2780 ));
2781 Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))
2782 }
2783}
2784
/// Iterator given by [`Tensor::iter_dim`].
///
/// Yields one slice of the tensor per index along `dim`; each yielded tensor
/// keeps rank `D`, with the iterated dimension narrowed to size 1.
pub struct DimIter<B, const D: usize, K>
where
    B: Backend,
    K: BasicOps<B>,
{
    // Next index to yield from the front.
    start: usize,
    // One past the last index still to be yielded from the back.
    end: usize,
    // The dimension being iterated over.
    dim: usize,
    // Full-range slice template; `ranges[dim]` is narrowed per step.
    ranges: [Range<usize>; D],
    tensor: Tensor<B, D, K>,
}
2797
2798impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
2799 type Item = Tensor<B, D, K>;
2800
2801 fn next(&mut self) -> Option<Self::Item> {
2802 if self.start >= self.end {
2803 return None;
2804 }
2805
2806 let mut ranges = self.ranges.clone();
2807 ranges[self.dim] = self.start..(self.start + 1);
2808
2809 let slice = self.tensor.clone().slice(ranges);
2810 self.start += 1;
2811
2812 Some(slice)
2813 }
2814}
2815
2816impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
2817 fn next_back(&mut self) -> Option<Self::Item> {
2818 if self.start >= self.end {
2819 return None;
2820 }
2821
2822 let mut ranges = self.ranges.clone();
2823 ranges[self.dim] = (self.end - 1)..self.end;
2824
2825 let slice = self.tensor.clone().slice(ranges);
2826 self.end = self.end.saturating_sub(1);
2827
2828 Some(slice)
2829 }
2830}
2831
2832impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
2833 fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
2834 let dims = tensor.dims();
2835 let ranges = dims
2836 .iter()
2837 .map(|&dim| 0..dim)
2838 .collect::<Vec<Range<usize>>>();
2839 let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
2840 Self {
2841 end: dims[dim],
2842 ranges,
2843 start: 0,
2844 dim,
2845 tensor,
2846 }
2847 }
2848}
2849
impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    // Appends a newline followed by `indent` spaces to `acc`.
    #[inline]
    fn push_newline_indent(acc: &mut String, indent: usize) {
        acc.push('\n');
        for _ in 0..indent {
            acc.push(' ');
        }
    }
    // Formats the scalar elements with indices in `[range.0, range.1)` along the
    // innermost dimension and appends them to `acc`, separated by `", "`.
    //
    // Each element is fetched by slicing the tensor down to a single value and
    // reading it back synchronously. `precision` is only applied to the "Float"
    // kind; "Bool" elements are printed via `to_bool`. When the data cannot be
    // read synchronously, a placeholder string is appended instead.
    fn fmt_inner_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        range: (usize, usize),
        precision: Option<usize>,
    ) {
        let (start, end) = range;
        for i in start..end {
            // `i > 0` (not `i > start`) so the trailing edge block of a
            // summarized row also gets its separator right after ", ...".
            if i > 0 {
                acc.push_str(", ");
            }
            multi_index[depth] = i;
            // One-element slice selecting exactly the position in `multi_index`.
            let range: [Range<usize>; D] =
                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);

            let data = burn_std::reader::try_read_sync(self.clone().slice(range).into_data_async());

            if let Some(Ok(data)) = data {
                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
                match (precision, K::name()) {
                    (Some(p), "Float") => acc.push_str(&format!("{elem:.p$}")),
                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
                    _ => acc.push_str(&format!("{elem:?}")),
                }
            } else {
                acc.push_str("<Tensor data not available>");
            }
        }
    }

    // Formats the sub-tensors with indices in `[range.0, range.1)` along
    // dimension `depth`: each is wrapped in `[` `]` and rendered by recursing
    // into `display_recursive` one level deeper.
    fn fmt_outer_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
        range: (usize, usize),
    ) {
        let (start, end) = range;
        for i in start..end {
            if i > start {
                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
            }
            acc.push('[');
            multi_index[depth] = i;
            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
            acc.push(']');
        }
    }

    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
    ///
    /// This function is designed to work with tensors of any dimensionality.
    /// It traverses the tensor dimensions recursively, converting the elements
    /// to strings and appending them to the accumulator string with the
    /// appropriate formatting.
    ///
    /// # Arguments
    ///
    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
    /// * `depth` - The current depth of the tensor dimensions being processed.
    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
    /// * `print_options` - Edge-item count and float precision used for the printout.
    /// * `summarize` - When true, dimensions larger than `2 * edge_items` are elided
    ///   with `...`, showing only the leading and trailing `edge_items` entries.
    fn display_recursive(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
    ) {
        let edge_items = print_options.edge_items;

        if depth == 0 {
            acc.push('[');
        }

        if depth == self.dims().len() - 1 {
            // if we are at the innermost dimension, just push its elements into the accumulator
            if summarize && self.dims()[depth] > 2 * edge_items {
                // print the starting `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, edge_items),
                    print_options.precision,
                );
                acc.push_str(", ...");
                // print the last `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                    print_options.precision,
                );
            } else {
                // print all the elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, self.dims()[depth]),
                    print_options.precision,
                );
            }
        } else {
            // otherwise, iterate through the current dimension and recursively display the inner tensors
            if summarize && self.dims()[depth] > 2 * edge_items {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, edge_items),
                );

                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
                acc.push_str("...");
                Self::push_newline_indent(acc, depth + 1);

                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                );
            } else {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, self.dims()[depth]),
                );
            }
        }

        if depth == 0 {
            acc.push(']');
        }
    }
}
3015
#[derive(Clone, Debug)]
/// Options for Tensor pretty printing
pub struct PrintOptions {
    /// Number of elements above which the printout is summarized with `...`
    pub threshold: usize,

    /// Number of starting elements and ending elements to display when summarizing
    pub edge_items: usize,

    /// Precision for floating point numbers
    pub precision: Option<usize>,
}
3028
3029static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
3030
3031impl PrintOptions {
3032 /// Print options with default values
3033 pub const fn const_default() -> Self {
3034 Self {
3035 threshold: 1000,
3036 edge_items: 3,
3037 precision: None,
3038 }
3039 }
3040}
3041
impl Default for PrintOptions {
    // Delegates to the `const` constructor so both paths share one definition.
    fn default() -> Self {
        Self::const_default()
    }
}
3047
3048/// Set print options
3049pub fn set_print_options(options: PrintOptions) {
3050 let mut print_opts = PRINT_OPTS.write().unwrap();
3051 *print_opts = options;
3052}
3053
/// Pretty print tensors
impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
where
    B: Backend,
    B::IntElem: core::fmt::Display,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "Tensor {{")?;

        {
            // Do not lock the mutex for the whole function
            let mut po = { PRINT_OPTS.read().unwrap().clone() };

            // Override the precision if it is set from the formatter
            // This will be possible when the tensor is printed using the `{:.*}` syntax
            if let Some(precision) = f.precision() {
                po.precision = Some(precision);
            }

            let mut acc = String::new();
            let mut multi_index = vec![0; D];
            // Summarize with `...` once the element count exceeds the threshold.
            let summarize = self.shape().num_elements() > po.threshold;

            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);

            writeln!(f, " data:")?;
            write!(f, "{acc}")?;
            writeln!(f, ",")?;
        }

        // Metadata footer: shape, device, backend, kind, and dtype.
        writeln!(f, " shape: {:?},", self.dims())?;
        writeln!(f, " device: {:?},", self.device())?;
        writeln!(f, " backend: {:?},", B::name(&self.device()))?;
        writeln!(f, " kind: {:?},", K::name())?;

        let dtype = self.primitive.dtype();

        writeln!(f, " dtype: {:?},", dtype.name())?;
        write!(f, "}}")
    }
}
3097
/// Trait used for movedim arguments
pub trait MovedimArgs {
    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function.
    ///
    /// Signed inputs may be negative, in which case they index from the back
    /// (`-1` maps to `D - 1`).
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
3103
3104impl MovedimArgs for Vec<i32> {
3105 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3106 let set = self
3107 .iter()
3108 .map(|&dim| {
3109 if dim < 0 {
3110 (D as i32 + dim) as usize
3111 } else {
3112 dim as usize
3113 }
3114 })
3115 .collect::<Vec<usize>>();
3116 check!(TensorCheck::movedim_args_vec::<D>(&set));
3117
3118 set
3119 }
3120}
3121
impl MovedimArgs for Vec<usize> {
    // Dimensions are already unsigned; only validate them against `D`.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
        check!(TensorCheck::movedim_args_vec::<D>(&self));
        self
    }
}
3128
3129impl MovedimArgs for usize {
3130 #[allow(clippy::vec_init_then_push)]
3131 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3132 check!(TensorCheck::movedim_args_usize::<D>(self));
3133
3134 let mut set = Vec::with_capacity(1);
3135 set.push(self);
3136
3137 set
3138 }
3139}
3140
3141impl MovedimArgs for i32 {
3142 #[allow(clippy::vec_init_then_push)]
3143 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
3144 check!(TensorCheck::movedim_args_i32::<D>(self));
3145
3146 let dim = if self < 0 {
3147 (D as i32 + self) as usize
3148 } else {
3149 self as usize
3150 };
3151
3152 let mut set = Vec::with_capacity(1);
3153 set.push(dim);
3154
3155 set
3156 }
3157}
3158
/// Trait used for reshape arguments.
pub trait ReshapeArgs<const D2: usize>: Debug {
    /// Converts to a shape.
    ///
    /// `source` is the current shape of the tensor being reshaped; the
    /// conversion is delegated to `Shape::reshape`.
    fn into_shape<const D: usize>(self, source: Shape) -> Shape;
}
3164
impl<const D2: usize, I: AsIndex> ReshapeArgs<D2> for [I; D2] {
    // Delegates to `Shape::reshape`; `unwrap_shape_reshape` surfaces any
    // reshape error (e.g. incompatible element counts) from that call.
    fn into_shape<const D: usize>(self, source: Shape) -> Shape {
        unwrap_shape_reshape(source.reshape(self))
    }
}
3170
impl<const D2: usize> ReshapeArgs<D2> for Shape {
    // An explicit `Shape` target is still routed through `Shape::reshape`
    // so it gets the same validation as array arguments.
    fn into_shape<const D: usize>(self, source: Shape) -> Shape {
        unwrap_shape_reshape(source.reshape(self))
    }
}
3176
/// Trait used for broadcast arguments.
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    ///
    /// `shape` is the current shape of the tensor being broadcast.
    fn into_shape(self, shape: &Shape) -> Shape;
}
3182
impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
    // An explicit `Shape` is taken as-is; no `-1` inference is performed.
    fn into_shape(self, _shape: &Shape) -> Shape {
        self
    }
}
3188
3189impl<const D1: usize, const D2: usize, E: AsIndex> BroadcastArgs<D1, D2> for [E; D2] {
3190 // Passing -1 as the size for a dimension means not changing the size of that dimension.
3191 fn into_shape(self, shape: &Shape) -> Shape {
3192 if self.len() < shape.num_dims() {
3193 panic!("Broadcast arguments must be greater than the number of dimensions");
3194 }
3195
3196 // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3197 let new_shape: Vec<_> = self
3198 .iter()
3199 .rev()
3200 .map(|x| {
3201 let primitive = x.as_index();
3202 if primitive < -1 || primitive == 0 {
3203 panic!("Broadcast arguments must be positive or -1");
3204 }
3205 primitive
3206 })
3207 .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3208 .map(|(x, &y)| if x == -1 { y } else { x as usize })
3209 .collect::<Vec<_>>()
3210 .into_iter()
3211 .rev()
3212 .collect();
3213
3214 if new_shape.contains(&0) {
3215 panic!("Cannot substitute -1 for a non-existing dimension");
3216 }
3217
3218 let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3219
3220 Shape::from(new_shape)
3221 }
3222}
3223
impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Debug + Copy + Serialize,
{
    // Serializes the tensor through its `TensorData` representation.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // `to_data` reads the values back from the backend synchronously.
        let data = self.to_data();
        data.serialize(serializer)
    }
}
3235
impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    K::Elem: Debug + Copy + Deserialize<'de>,
{
    // Deserializes `TensorData` and materializes it on the backend's default
    // device — serde offers no way to thread a target device through.
    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
        let tensor = Tensor::from_data(
            TensorData::deserialize(deserializer)?,
            &<B::Device as Default>::default(),
        );
        Ok(tensor)
    }
}
3250
#[cfg(test)]
mod tests {
    use burn_std::SliceOps;

    use crate::{Shape, s};

    #[test]
    fn slice_range_single_dim_leading() {
        let shape = Shape::new([8, 4]);

        // Half-open range
        let slices = shape.clone().into_slices([0..5]);
        assert_eq!(slices[0].to_range(8), 0..5);
        let slices = shape.clone().into_slices([-3..-1]);
        assert_eq!(slices[0].to_range(8), 5..7);

        // Inclusive range
        let slices = shape.clone().into_slices([0..=4]);
        assert_eq!(slices[0].to_range(8), 0..5);
        let slices = shape.clone().into_slices([-2..=-1]);
        assert_eq!(slices[0].to_range(8), 6..8);

        // Unbounded start
        let slices = shape.clone().into_slices([..3]);
        assert_eq!(slices[0].to_range(8), 0..3);
        let slices = shape.clone().into_slices([..-5]);
        assert_eq!(slices[0].to_range(8), 0..3);

        // Unbounded end
        let slices = shape.clone().into_slices([5..]);
        assert_eq!(slices[0].to_range(8), 5..8);
        let slices = shape.clone().into_slices([-3..]);
        assert_eq!(slices[0].to_range(8), 5..8);

        // Full range
        let slices = shape.into_slices([..]);
        assert_eq!(slices[0].to_range(8), 0..8);
    }

    #[test]
    fn test_negative_slice_indices() {
        use crate::Slice;

        // Test negative indices conversion
        let slice: Slice = (-3..-1).into();
        assert_eq!(slice.start, -3);
        assert_eq!(slice.end, Some(-1));

        // Test to_range conversion with size 8
        let range = slice.to_range(8);
        assert_eq!(range, 5..7);

        // Test with shape slice
        let shape = Shape::new([8, 4]);
        let result = shape.clone().into_slices([-3..-1]);
        assert_eq!(result[0].to_range(8), 5..7);

        // Test more negative index cases
        let slice2: Slice = (-5..).into();
        assert_eq!(slice2.to_range(10), 5..10);

        let slice3: Slice = (..-2).into();
        assert_eq!(slice3.to_range(10), 0..8);

        // Test with s! macro - single dimension returns Slice directly
        let slice4 = s![-3..-1];
        assert_eq!(slice4.start, -3);
        assert_eq!(slice4.end, Some(-1));
    }

    #[test]
    fn slice_range_multi_dim() {
        let shape = Shape::new([8, 4]);

        // Multiple ways to provide ranges
        let slices = shape.clone().into_slices([0..5, 0..4]);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0.., 0..]);
        assert_eq!(slices[0].to_range(8), 0..8);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0..=7, 0..=3]);
        assert_eq!(slices[0].to_range(8), 0..8);
        assert_eq!(slices[1].to_range(4), 0..4);

        let slices = shape.clone().into_slices([0..5, 0..3]);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..3);

        let slices = shape.into_slices([0.., 0..]);
        assert_eq!(slices[0].to_range(8), 0..8);
        assert_eq!(slices[1].to_range(4), 0..4);
    }

    #[test]
    fn slice_range_multi_dim_index() {
        let shape = Shape::new([8, 4]);

        // Indices (single integer) should also convert to correct range
        let slices = shape.clone().into_slices([0, 2]);
        assert_eq!(slices[0].to_range(8), 0..1);
        assert_eq!(slices[1].to_range(4), 2..3);

        // Negative indices select from the back of each dimension.
        let slices = shape.into_slices([-1, -1]);
        assert_eq!(slices[0].to_range(8), 7..8);
        assert_eq!(slices[1].to_range(4), 3..4);
    }

    #[test]
    fn slice_range_multi_dim_heterogeneous() {
        // Slice macro `s![]` can be used to provide different range types
        let shape = Shape::new([8, 4, 2]);
        let slice = s![0..5, .., -1];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);
        assert_eq!(slices[2].to_range(2), 1..2);

        let shape = Shape::new([8, 4, 2, 3]);
        let slice = s![..=4, 0..=3, .., -2..];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(8), 0..5);
        assert_eq!(slices[1].to_range(4), 0..4);
        assert_eq!(slices[2].to_range(2), 0..2);
        assert_eq!(slices[3].to_range(3), 1..3);

        let shape = Shape::new([3, 4]);
        let slice = s![1..-1, ..];
        let slices = shape.into_slices(slice);
        assert_eq!(slices[0].to_range(3), 1..2);
        assert_eq!(slices[1].to_range(4), 0..4);
    }
}
3385}