burn_tensor/tensor/api/base.rs
1#![allow(clippy::single_range_in_vec_init)]
2
3use alloc::vec::Vec;
4
5use alloc::format;
6use alloc::string::String;
7use alloc::vec;
8
9use burn_common::stub::RwLock;
10use core::future::Future;
11use core::iter::repeat;
12use core::{fmt::Debug, ops::Range};
13use serde::{Deserialize, Deserializer};
14
15use serde::{Serialize, Serializer};
16
17use crate::tensor::api::narrow::narrow;
18use crate::{
19 Bool, Float, Int, Shape, TensorData, TensorKind, backend::Backend, check, ops::Device,
20};
21use crate::{DType, Element, TensorPrimitive};
22use crate::{cast::ToElement, check::TensorCheck};
23
24use super::{Slice, TensorMetadata, Transaction};
25
26/// A tensor with a given backend, shape and data type.
27///
28/// # Indexing
29/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
30/// or [`select`](Tensor::select) for numeric types.
31///
32/// ## Example
33///
34/// ```rust
35/// use burn_tensor::backend::Backend;
36/// use burn_tensor::Tensor;
37/// use burn_tensor::Int;
38///
39/// fn example<B: Backend>() {
40/// let device = Default::default();
41///
42/// let tensor = Tensor::<B, 2>::from_data(
43/// [
44/// [3.0, 4.9, 2.0],
45/// [2.0, 1.9, 3.0],
46/// [6.0, 1.5, 7.0],
47/// [3.0, 4.9, 9.0],
48/// ],
49/// &device,
50/// );
51///
52/// // Slice the tensor to get the second and third rows:
53/// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
54/// // The resulting tensor will have dimensions [2, 3].
55/// let slice = tensor.clone().slice([1..3]);
56/// println!("{slice}");
57///
58/// // Slice the tensor to get the first two rows and the first 2 columns:
59/// // [[3.0, 4.9], [2.0, 1.9]]
60/// // The resulting tensor will have dimensions [2, 2].
61/// let slice = tensor.clone().slice([0..2, 0..2]);
62/// println!("{slice}");
63///
64/// // Index the tensor along the dimension 1 to get the elements 0 and 2:
65/// // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
66/// // The resulting tensor will have dimensions [4, 2]
67/// let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
68/// let indexed = tensor.select(1, indices);
69/// println!("{indexed}");
70/// }
71/// ```
72#[derive(new, Clone, Debug)]
73pub struct Tensor<B, const D: usize, K = Float>
74where
75 B: Backend,
76 K: TensorKind<B>,
77{
78 pub(crate) primitive: K::Primitive,
79}
80
81impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
82where
83 B: Backend,
84 K: BasicOps<B>,
85 T: Into<TensorData>,
86{
87 fn from(value: T) -> Self {
88 Tensor::from_data(value.into(), &Default::default())
89 }
90}
91
92impl<B, const D: usize, K> Tensor<B, D, K>
93where
94 B: Backend,
95 K: BasicOps<B>,
96{
97 /// Converts the tensor into a primitive tensor.
98 pub fn into_primitive(self) -> K::Primitive {
99 self.primitive
100 }
101
102 /// Converts from a primitive tensor into a tensor.
103 pub fn from_primitive(tensor: K::Primitive) -> Self {
104 Self::new(tensor)
105 }
106
107 /// Returns the tensor primitive data type.
108 ///
109 /// # Note
110 /// Some element types are encoded in different primitive types depending on the backend
111 /// (e.g., bool could be encoded as `u8` or `u32`).
112 pub fn dtype(&self) -> DType {
113 self.primitive.dtype()
114 }
115
116 /// Create an empty tensor of the given shape.
117 ///
118 /// # Arguments
119 ///
120 /// - shape: The shape of the tensor.
121 /// - device: The device where the tensor will be created.
122 ///
123 /// # Example
124 /// ```rust
125 /// use burn_tensor::backend::Backend;
126 /// use burn_tensor::Tensor;
127 ///
128 /// fn example<B: Backend>() {
129 /// let device = Default::default();
130 /// // Create an empty tensor with dimensions [2, 3, 4].
131 /// let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
132 /// }
133 /// ```
134 pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
135 let shape = shape.into();
136 check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
137 Self::new(K::empty(shape, device))
138 }
139
140 /// Returns the dimensions of the current tensor.
141 ///
142 /// # Example
143 /// ```rust
144 /// use burn_tensor::backend::Backend;
145 /// use burn_tensor::Tensor;
146 ///
147 /// fn example<B: Backend>() {
148 /// let device = Default::default();
149 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
150 /// let dims = tensor.dims(); // [2, 3, 4]
151 /// println!("{dims:?}");
152 /// }
153 /// ```
154 pub fn dims(&self) -> [usize; D] {
155 Self::shape(self).dims()
156 }
157
158 /// Returns the shape of the current tensor.
159 ///
160 /// # Example
161 /// ```rust
162 /// use burn_tensor::backend::Backend;
163 /// use burn_tensor::Tensor;
164 ///
165 /// fn example<B: Backend>() {
166 /// let device = Default::default();
167 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
168 /// // Shape { dims: [2, 3, 4] }
169 /// let shape = tensor.shape();
170 /// }
171 /// ```
172 pub fn shape(&self) -> Shape {
173 self.primitive.shape()
174 }
175
176 /// Reshape the tensor to have the given shape.
177 ///
178 /// The tensor has the same data and number of elements as the input.
179 ///
180 /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
181 /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
182 ///
183 /// A `0` in the shape instructs to keep the current dimension from the original tensor,
184 /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
185 /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
186 /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
187 /// with [1, 3, 4] dimensions to [1, 12].
188 ///
189 /// # Arguments
190 /// - `shape`: The new shape of the tensor.
191 ///
192 /// # Panics
193 /// - If the tensor contains more than one `-1` in the shape.
194 /// - If the tensor contains values that are not positive (other than -1).
195 /// - If the shape does not match the number of elements of the original shape.
196 ///
197 /// # Example
198 ///
199 /// ```rust
200 /// use burn_tensor::backend::Backend;
201 /// use burn_tensor::Tensor;
202 ///
203 /// fn example<B: Backend>() {
204 /// let device = Default::default();
205 /// // Create a tensor with dimensions [2, 3, 4]
206 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
207 /// // Reshape it to [2, 12], where 12 is inferred from the number of elements.
208 /// let reshaped = tensor.reshape([2, -1]);
209 /// println!("{reshaped}");
210 /// }
211 /// ```
212 pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
213 // Convert reshape args to shape
214 let shape = shape.into_shape(&self);
215 Tensor::new(K::reshape(self.primitive, shape))
216 }
217
218 /// Transpose the tensor.
219 ///
220 /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is
221 /// applied on the last two dimensions. For example, the transpose of a tensor with shape
222 /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.
223 ///
224 /// See also [`permute`](Tensor::permute).
225 ///
226 /// # Arguments
227 ///
228 /// * `tensor` - The tensor to transpose.
229 ///
230 /// # Returns
231 ///
232 /// The transposed tensor.
233 ///
234 /// # Example
235 ///
236 /// ```rust
237 /// use burn_tensor::backend::Backend;
238 /// use burn_tensor::Tensor;
239 ///
240 /// fn example<B: Backend>() {
241 /// let device = Default::default();
242 /// // Create a 2D tensor of shape [2, 3]
243 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
244 ///
245 /// // Transpose the tensor:
246 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
247 /// // The resulting tensor will have dimensions [3, 2].
248 /// let transposed = tensor.transpose();
249 /// println!("{transposed}");
250 /// }
251 /// ```
252 pub fn transpose(self) -> Tensor<B, D, K> {
253 Tensor::new(K::transpose(self.primitive))
254 }
255
256 /// Swaps two dimensions of a tensor.
257 ///
258 /// # Arguments
259 ///
260 /// * `tensor` - The tensor to swap the dimensions of.
261 /// * `dim1` - The first dimension to swap.
262 /// * `dim2` - The second dimension to swap.
263 ///
264 /// # Returns
265 ///
266 /// The tensor with the dimensions swapped.
267 ///
268 /// # Example
269 ///
270 /// ```rust
271 /// use burn_tensor::backend::Backend;
272 /// use burn_tensor::Tensor;
273 ///
274 /// fn example<B: Backend>() {
275 /// let device = Default::default();
276 /// // Create a 2D tensor of shape [2, 3]
277 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
278 ///
279 /// // Swap the dimensions 0 and 1 (equivalent to `tensor.transpose()`):
280 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
281 /// // The resulting tensor will have dimensions [3, 2].
282 /// let swapped = tensor.swap_dims(0, 1);
283 /// println!("{swapped}");
284 /// }
285 /// ```
286 pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor<B, D, K> {
287 check!(TensorCheck::swap_dims::<D>(dim1, dim2));
288 Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
289 }
290
291 /// Permute the dimensions of the tensor.
292 ///
293 /// # Arguments
294 ///
295 /// * `axes` - The new order of the dimensions. The length of the axes
296 /// must be equal to the number of dimensions of the tensor.
297 /// The values must be unique and in the range of the number of dimensions.
298 /// The values can be negative, in which case they are used as an offset from the end.
299 ///
300 /// # Returns
301 ///
302 /// The tensor with the dimensions permuted.
303 ///
304 /// # Example
305 ///
306 /// ```rust
307 /// use burn_tensor::backend::Backend;
308 /// use burn_tensor::Tensor;
309 ///
310 /// fn example<B: Backend>() {
311 /// let device = Default::default();
312 /// // Create a 2D tensor of shape [3, 2]
313 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
314 ///
315 /// // Permute the dimensions 1 and 0:
316 /// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
317 /// // The resulting tensor will have dimensions [3, 2].
318 /// let permuted = tensor.permute([1, 0]);
319 /// println!("{permuted}");
320 /// }
321 /// ```
322 pub fn permute(self, axes: [isize; D]) -> Tensor<B, D, K> {
323 // Convert the axes to usize and handle negative values without using vector
324 let mut transformed_axes: [usize; D] = [0; D];
325 for (i, &x) in axes.iter().enumerate() {
326 transformed_axes[i] = if x < 0 {
327 (D as isize + x) as usize
328 } else {
329 x as usize
330 };
331 }
332
333 // Check if the axes are valid after the transformation
334 check!(TensorCheck::permute(transformed_axes));
335
336 Tensor::new(K::permute(self.primitive, &transformed_axes))
337 }
338
339 /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
340 ///
341 /// Other dimensions of input that are not explicitly moved remain in their original order and appear
342 /// at the positions not specified in destination.
343 ///
344 /// # Arguments
345 ///
346 /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
347 /// The values can be negative, in which case they are used as an offset from the end.
348 ///
349 /// * `dst` - Destination positions for each of the original dims. These must also be unique.
350 ///
351 /// # Panics
352 ///
353 /// - If the source and destination dimensions are not of the same length.
354 /// - If the source and destination vectors contain duplicate values.
355 /// - If the source and destination vectors contain values that are out of bounds.
356 ///
357 /// # Returns
358 ///
359 /// The tensor with the dimensions moved.
360 ///
361 /// # Example
362 ///
363 /// ```rust
364 /// use burn_tensor::backend::Backend;
365 /// use burn_tensor::Tensor;
366 ///
367 /// fn example<B: Backend>() {
368 /// let device = Default::default();
369 /// // Create a 3D tensor of shape [3, 2, 1]
370 /// let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
371 ///
372 /// // Move the dimensions 0 and 1:
373 /// // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
374 /// // The resulting tensor will have dimensions [2, 3, 1].
375 /// let moved = tensor.movedim(1, 0);
376 /// println!("{moved}");
377 /// }
378 /// ```
379 ///
380 /// # Note
381 ///
382 /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
383 /// for it
384 pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
385 let source_dims = src.into_dim_vec::<D>();
386 let destination_dims = dst.into_dim_vec::<D>();
387
388 check!(TensorCheck::movedim_args_length(
389 &source_dims,
390 &destination_dims
391 ));
392
393 let mut m = [-1; D];
394 for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
395 m[d] = s as isize;
396 }
397 let mut axes: [isize; D] = [0; D];
398 let mut source_i = 0;
399 for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
400 *item = if m[dest_i] != -1 {
401 m[dest_i]
402 } else {
403 while source_dims.contains(&source_i) {
404 source_i += 1;
405 }
406 let result = source_i as isize;
407 source_i += 1;
408 result
409 };
410 }
411
412 self.permute(axes)
413 }
414
415 /// Reverse the order of elements in the tensor along the given dimensions.
416 ///
417 /// # Arguments
418 ///
419 /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
420 /// The values can be negative, in which case they are used as an offset from the end.
421 ///
422 /// # Returns
423 ///
424 /// The tensor with the axes flipped.
425 ///
426 /// # Example
427 ///
428 /// ```rust
429 /// use burn_tensor::backend::Backend;
430 /// use burn_tensor::Tensor;
431 ///
432 /// fn example<B: Backend>() {
433 /// let device = Default::default();
434 /// // Create a 2D tensor with dimensions [4, 3]
435 /// let tensor = Tensor::<B, 2>::from_data(
436 /// [
437 /// [3.0, 4.9, 2.0],
438 /// [2.0, 1.9, 3.0],
439 /// [4.0, 5.9, 8.0],
440 /// [1.4, 5.8, 6.0],
441 /// ],
442 /// &device,
443 /// );
444 ///
445 /// // Flip the elements in dimensions 0 and 1:
446 /// // [[6.0, 5.8, 1.4],
447 /// // [8.0, 5.9, 4.0],
448 /// // [3.0, 1.9, 2.0],
449 /// // [2.0, 4.9, 3.0]]
450 /// // The resulting tensor will have dimensions [4, 3].
451 /// let flipped = tensor.flip([0, 1]);
452 /// println!("{flipped}");
453 /// }
454 /// ```
455 pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
456 // Convert the axes to usize and handle negative values without using vector
457 let mut transformed_axes: [usize; N] = [0; N];
458 for (i, &x) in axes.iter().enumerate() {
459 transformed_axes[i] = if x < 0 {
460 (D as isize + x) as usize
461 } else {
462 x as usize
463 };
464 }
465
466 // Check if the axes are valid
467 check!(TensorCheck::flip(D, &transformed_axes));
468
469 Tensor::new(K::flip(self.primitive, &transformed_axes))
470 }
471
472 /// Flatten the tensor along a given range of dimensions.
473 ///
474 /// This function collapses the specified range of dimensions into a single dimension,
475 /// effectively flattening the tensor in that range.
476 ///
477 /// # Arguments
478 ///
479 /// - `start_dim`: The starting dimension of the range to be flattened.
480 /// - `end_dim`: The ending dimension of the range to be flattened (inclusive).
481 ///
482 /// # Type Parameters
483 ///
484 /// - `D2`: The resulting number of dimensions in the flattened tensor.
485 ///
486 /// # Returns
487 ///
488 /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
489 ///
490 /// # Example
491 ///
492 /// ```rust
493 ///
494 /// use burn_tensor::backend::Backend;
495 /// use burn_tensor::{Tensor, Shape};
496 ///
497 /// fn example<B: Backend>() {
498 /// let device = Default::default();
499 /// // Create a 3D tensor with dimensions [2, 3, 4]
500 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
501 ///
502 /// // Flatten the tensor from dimensions 1 to 2 (inclusive).
503 /// // The resulting tensor will have dimensions [2, 12]
504 /// let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
505 /// println!("{flattened}");
506 /// }
507 /// ```
508 pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
509 check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
510
511 let current_dims = self.shape().dims;
512 let mut new_dims: [usize; D2] = [0; D2];
513 let mut flatten_dims = 1;
514
515 for i in current_dims[start_dim..=end_dim].iter() {
516 flatten_dims *= i;
517 }
518
519 new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]);
520 new_dims[start_dim] = flatten_dims;
521 new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]);
522
523 Tensor::new(K::reshape(self.primitive, new_dims.into()))
524 }
525
526 /// Squeeze the tensor along the given dimension, removing the specified dimension
527 /// of size one, and effectively reducing the rank of the tensor by one.
528 ///
529 /// # Arguments
530 ///
531 /// - `dim`: The dimension to be squeezed.
532 ///
533 /// # Type Parameters
534 ///
535 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
536 ///
537 /// # Panics
538 ///
539 /// If the size in the squeezed dimension is not 1.
540 ///
541 /// # Returns
542 ///
543 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
544 ///
545 /// # Example
546 ///
547 /// ```rust
548 ///
549 /// use burn_tensor::backend::Backend;
550 /// use burn_tensor::{Tensor, Shape};
551 ///
552 /// fn example<B: Backend>() {
553 /// let device = Default::default();
554 /// // Create a 3D tensor with dimensions [3, 1, 3]
555 /// let tensor = Tensor::<B, 3>::from_data(
556 /// [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
557 /// &device,
558 /// );
559 ///
560 /// // Squeeze the dimension 1.
561 /// // The resulting tensor will have dimensions [3, 3].
562 /// let squeezed = tensor.squeeze::<2>(1);
563 /// println!("{squeezed}");
564 /// }
565 /// ```
566 pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
567 check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));
568
569 let current_dims = self.shape().dims;
570 let mut new_dims: [usize; D2] = [0; D2];
571
572 new_dims[..dim].copy_from_slice(¤t_dims[..dim]);
573 new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]);
574
575 Tensor::new(K::reshape(self.primitive, new_dims.into()))
576 }
577
578 /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
579 /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
580 /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
581 /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
582 /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
583 /// from the back.
584 ///
585 /// # Arguments
586 ///
587 /// - `dims`: The dimension(s) to be squeezed.
588 ///
589 /// # Type Parameters
590 ///
591 /// - `D2`: The resulting number of dimensions in the squeezed tensor.
592 ///
593 /// # Returns
594 ///
595 /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
596 ///
597 /// # Example
598 ///
599 /// ```rust
600 ///
601 /// use burn_tensor::backend::Backend;
602 /// use burn_tensor::{Tensor, Shape};
603 ///
604 /// fn example<B: Backend>() {
605 /// let device = Default::default();
606 /// // Create a 4D tensor with dimensions [2, 1, 4, 1]
607 /// let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
608 ///
609 /// // Squeeze the dimensions 1 and 3.
610 /// // The resulting tensor will have dimensions [2, 4].
611 /// let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
612 /// println!("{squeezed}");
613 /// }
614 /// ```
615 pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
616 let current_dims = self.shape().dims;
617 let mut dim_indices: Vec<usize>;
618
619 // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
620 if dims.is_empty() {
621 dim_indices = current_dims
622 .iter()
623 .enumerate()
624 .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
625 .collect();
626 } else {
627 // If negative dims, count from the back
628 dim_indices = dims
629 .iter()
630 .map(|&d| {
631 if d < 0 {
632 (current_dims.len() as isize + d) as usize
633 } else {
634 d as usize
635 }
636 })
637 .collect();
638 }
639
640 // Sort indices and remove duplicates
641 dim_indices.sort_unstable();
642 dim_indices.dedup();
643
644 // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
645 check!(TensorCheck::squeeze_dims_input::<D2>(
646 &dim_indices,
647 ¤t_dims
648 ));
649
650 // Calculate new dimensions
651 let mut new_dims = Vec::new();
652 for (index, &dim_size) in current_dims.iter().enumerate() {
653 // Exclude the dimension if it's explicitly marked for squeezing
654 if dim_indices.contains(&index) {
655 check!(TensorCheck::squeeze::<D2>(index, ¤t_dims));
656 continue;
657 }
658 new_dims.push(dim_size);
659 }
660
661 // Check that after squeezing, we still respect the D2 size
662 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
663
664 Tensor::new(K::reshape(self.primitive, new_dims.into()))
665 }
666
667 /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.
668 ///
669 /// # Type Parameters
670 ///
671 /// - `D2`: The resulting number of dimensions in the unsqueezed tensor.
672 ///
673 /// # Panics
674 ///
675 /// If the output size `D2` is smaller than the current number of dimensions.
676 ///
677 /// # Returns
678 ///
679 /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.
680 ///
681 /// # Example
682 ///
683 /// ```rust
684 /// use burn_tensor::backend::Backend;
685 /// use burn_tensor::{Tensor, Shape};
686 ///
687 /// fn example<B: Backend>() {
688 /// let device = Default::default();
689 /// // Create a 2D tensor with dimensions [3, 3]
690 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
691 /// // Unsqueeze the tensor up to 4 dimensions.
692 /// // The resulting tensor will have dimensions [1, 1, 3, 3].
693 /// let unsqueezed = tensor.unsqueeze::<4>();
694 /// println!("{unsqueezed}");
695 /// }
696 /// ```
697 pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
698 check!(TensorCheck::unsqueeze::<D, D2>());
699
700 let mut dims = [1; D2];
701 let num_ones = D2 - D;
702 let shape = self.shape();
703
704 dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);
705
706 let shape = Shape::new(dims);
707 self.reshape(shape)
708 }
709
710 /// Creates a new tensor with a dimension of size one inserted at the specified position.
711 ///
712 /// # Example
713 ///
714 /// ```rust
715 /// use burn_tensor::backend::Backend;
716 /// use burn_tensor::{Tensor, Shape};
717 ///
718 /// fn example<B: Backend>() {
719 /// let device = Default::default();
720 /// // Create a 2D tensor with dimensions [3, 3]
721 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
722 /// // Unsqueeze the dimension 1.
723 /// // The resulting tensor will have dimensions [3, 1, 3].
724 /// let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
725 /// println!("{unsqueezed}");
726 /// }
727 /// ```
728 pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
729 check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));
730
731 let mut dims = [1; D2];
732 let shape = self.shape();
733
734 dims[0..dim].copy_from_slice(&shape.dims[0..dim]);
735
736 if dim < D {
737 dims[dim] = 1;
738 dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
739 } else {
740 dims[dim] = 1;
741 }
742
743 let shape = Shape::new(dims);
744 self.reshape(shape)
745 }
746
747 /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
748 /// The indices can be negative, in which case they are counted from the last to the first dimension.
749 /// the axes can contain duplicates, in which case the number of dimensions inserted at the index
750 /// is the number of duplicates.
751 /// # Example
752 ///
753 /// ```rust
754 /// use burn_tensor::backend::Backend;
755 /// use burn_tensor::{Tensor, Shape};
756 ///
757 /// fn example<B: Backend>() {
758 /// let device = Default::default();
759 /// // Create a 3D tensor with dimensions [3, 4, 5]
760 /// let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
761 /// // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
762 /// // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
763 /// let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
764 /// println!("{unsqueezed}");
765 /// }
766 /// ```
767 pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
768 let mut new_dims = [1; D2];
769 let old_dims = self.shape().dims;
770 //for checking if the dimension is in the acceptable range
771
772 //part 1: convert the negative indices to positive
773 let mut neg_offset = D2;
774 let mut dim_indices = axes
775 .iter()
776 .map(|d| {
777 // check if the dimension is in the acceptable range
778 check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
779 (if *d < 0 {
780 neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
781 d + neg_offset as isize + 1
782 } else {
783 *d
784 }) as usize
785 })
786 .collect::<Vec<usize>>();
787
788 //sort the indices
789 dim_indices.sort_unstable();
790
791 //Now use this to copy the chunks of the dims
792 let mut prev_idx: usize = 0;
793 let mut current_left_b: usize = 0;
794 let mut current_right_b: usize = 0;
795 let mut offset: usize = 0;
796 dim_indices.iter().for_each(|d| {
797 //check if there is space for at least one dimension
798 if prev_idx < *d {
799 current_right_b = *d - offset;
800 //copy the chunks of the dims
801 if current_right_b < D {
802 new_dims[prev_idx..*d]
803 .copy_from_slice(&old_dims[current_left_b..current_right_b]);
804 } else {
805 new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
806 }
807 prev_idx = *d + 1;
808 //offset is equal to the number of extracted elements from the original shape
809 offset += current_right_b - current_left_b;
810 current_left_b = current_right_b;
811 } else {
812 //it's sorted so the only reason this would happen
813 //is if multiple indices are the same
814 prev_idx += 1;
815 }
816 });
817 //copy over anything past the index of the last new dimension
818 if current_left_b < D {
819 new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
820 }
821
822 //lastly, create the shape and reshape
823 let shape = Shape::new(new_dims);
824 self.reshape(shape)
825 }
826
827 /// Returns a tensor containing the elements selected from the given ranges.
828 ///
829 /// For more complex indexing with different slice ranges, see also the slice
830 /// macro [`s!`](crate::s).
831 ///
832 /// # Arguments
833 ///
834 /// * `ranges` - A type implementing the `RangesArg` trait, which can be:
835 /// - A single range (slice the first dimension)
836 /// - A single index (slice the first dimension)
837 /// - An array of ranges
838 ///
839 /// # Behavior
840 ///
841 /// - Supports partial and full slicing in any number of dimensions.
842 /// - Missing ranges are treated as full slices if D > D2.
843 /// - Handles negative indices by wrapping around from the end of the dimension.
844 /// - Clamps ranges to the tensor's dimensions if they exceed the bounds.
845 ///
846 /// # Panics
847 ///
848 /// - If the number of ranges provided exceeds the tensor's dimensions.
849 /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1).
850 ///
851 /// # Examples
852 ///
853 /// ```rust
854 /// use burn_tensor::backend::Backend;
855 /// use burn_tensor::{Tensor, Shape, s};
856 ///
857 /// fn example<B: Backend>() {
858 /// let device = B::Device::default();
859 ///
860 /// // 1D slicing
861 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
862 /// let slice = tensor.slice([1..4]);
863 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
864 ///
865 /// // 2D slicing
866 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 4]), &device);
867 /// let slice = tensor.slice([1..3, 0..2]);
868 /// assert_eq!(slice.dims(), [2, 2]);
869 ///
870 /// // Using negative indices
871 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
872 /// let slice = tensor.slice([1..-1]); // Equivalent to 1..4
873 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
874 ///
875 /// // Using the slice macro to select different ranges
876 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device).reshape([3, 4]);
877 /// let slice = tensor.slice(s![1.., ..]); // Select rows 1 and 2, all columns
878 /// assert_eq!(slice.dims(), [2, 4]);
879 ///
880 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..16, &device).reshape([2, 4, 2]);
881 /// let slice = tensor.slice(s![1.., 1..=3, -1]);
882 /// assert_eq!(slice.dims(), [1, 3, 1]);
883 /// }
884 /// ```
885 ///
886 /// # Note
887 ///
888 /// This function uses the `RangesArg` trait for flexible range specification. The trait
889 /// handles the conversion of various range formats and applies clamping and negative
890 /// index handling internally.
891 pub fn slice<const D2: usize, R: RangesArg<D2>>(self, ranges: R) -> Self {
892 let ranges = ranges.into_ranges(self.shape());
893
894 check!(TensorCheck::slice::<D, D2>(&self.shape(), &ranges));
895 Self::new(K::slice(self.primitive, &ranges))
896 }
897
898 /// Returns a copy of the current tensor with the selected elements changed to the new ones at
899 /// the selected indices.
900 ///
901 /// # Panics
902 ///
903 /// - If a range exceeds the number of elements on a dimension.
904 /// - If the given values don't match the given ranges.
905 ///
906 /// # Example
907 ///
908 /// ```rust
909 /// use burn_tensor::backend::Backend;
910 /// use burn_tensor::Tensor;
911 ///
912 /// fn example<B: Backend>() {
913 /// let device = B::Device::default();
914 /// let tensor = Tensor::<B, 3>::ones([2, 3, 3], &device);
915 /// let values = Tensor::<B, 3>::zeros([1, 1, 1], &device);
916 /// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
917 /// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
918 /// }
919 /// ```
920 pub fn slice_assign<const D2: usize>(self, ranges: [Range<usize>; D2], values: Self) -> Self {
921 check!(TensorCheck::slice_assign::<D, D2>(
922 &self.shape(),
923 &values.shape(),
924 &ranges
925 ));
926 Self::new(K::slice_assign(self.primitive, &ranges, values.primitive))
927 }
928
929 /// Returns the device of the current tensor.
930 pub fn device(&self) -> B::Device {
931 K::device(&self.primitive)
932 }
933
934 /// Move the tensor to the given device.
935 pub fn to_device(self, device: &B::Device) -> Self {
936 Self::new(K::to_device(self.primitive, device))
937 }
938
939 /// Converts the data of the current tensor.
940 ///
941 /// # Note
942 ///
943 /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
944 /// tensors at once. This may improve laziness, especially if executed on a different
945 /// thread in native environments.
946 pub fn into_data(self) -> TensorData {
947 crate::try_read_sync(self.into_data_async()).expect(
948 "Failed to read tensor data synchronously.
949 This can happen on platforms that don't support blocking futures like WASM.
950 If possible, try using into_data_async instead.",
951 )
952 }
953
954 /// Converts the data of the current tensor.
955 ///
956 /// # Note
957 ///
958 /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
959 /// tensors at once. This may improve laziness, especially if executed on a different
960 /// thread in native environments.
961 pub fn to_data(&self) -> TensorData {
962 self.clone().into_data()
963 }
964
965 /// Returns the data of the current tensor.
966 pub async fn into_data_async(self) -> TensorData {
967 K::into_data_async(self.primitive).await
968 }
969
970 /// Returns the data of the current tensor.
971 pub async fn to_data_async(&self) -> TensorData {
972 self.clone().into_data_async().await
973 }
974
975 /// Create a tensor from the given data on the given device.
976 pub fn from_data<T>(data: T, device: &B::Device) -> Self
977 where
978 T: Into<TensorData>,
979 {
980 let data = data.into();
981 check!(TensorCheck::creation_ops::<D>(
982 "From Data",
983 data.shape.as_slice()
984 ));
985 Self::new(K::from_data(data, device))
986 }
987
988 /// Create a tensor from the given data on the given device enforcing the given data type.
989 pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self
990 where
991 T: Into<TensorData>,
992 {
993 let data = data.into();
994 check!(TensorCheck::creation_ops::<D>(
995 "From Data",
996 data.shape.as_slice()
997 ));
998 Self::new(K::from_data_dtype(data, device, dtype))
999 }
1000
1001 /// Repeat the tensor along the given dimension.
1002 ///
1003 /// The output tensor has the same shape, except along the given dimension.
1004 ///
1005 /// # Arguments
1006 /// - `dim`: The dimension to repeat.
1007 /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
1008 ///
1009 /// # Returns
1010 ///
1011 /// A new tensor with the given dimension repeated `times` times.
1012 ///
1013 /// # Example
1014 ///
1015 /// ```rust
1016 /// use burn_tensor::backend::Backend;
1017 /// use burn_tensor::Tensor;
1018 ///
1019 /// fn example<B: Backend>() {
1020 /// let device = Default::default();
1021 /// // Create a 2D tensor with dimensions [3, 2]
1022 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1023 ///
1024 /// // Repeat the tensor along the dimension 0 twice.
1025 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1026 /// // The resulting tensor will have dimensions [6, 2].
1027 /// let repeated = tensor.repeat_dim(0, 2);
1028 /// println!("{repeated}");
1029 /// }
1030 /// ```
1031 pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
1032 Self::new(K::repeat_dim(self.primitive, dim, times))
1033 }
1034
1035 /// Repeat the tensor along the given dimensions.
1036 /// # Arguments
1037 /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
1038 ///
1039 /// # Returns
1040 ///
1041 /// A new tensor with the given dimensions repeated `times` times.
1042 ///
1043 /// # Panics
1044 ///
1045 /// If `sizes` contains more elements than the number of dimensions.
1046 ///
1047 /// # Example
1048 ///
1049 /// ```rust
1050 ///
1051 /// use burn_tensor::backend::Backend;
1052 /// use burn_tensor::Tensor;
1053 ///
1054 /// fn example<B: Backend>() {
1055 /// let device = Default::default();
1056 /// // Create a 2D tensor with dimensions [3, 2]
1057 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1058 ///
1059 /// // Repeat the tensor along the dimension 0 twice and the dimension 0 once.
1060 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1061 /// // The resulting tensor will have dimensions [6, 2].
1062 /// let repeated = tensor.repeat(&[2, 1]);
1063 /// }
1064 /// ```
1065 pub fn repeat(self, sizes: &[usize]) -> Self {
1066 let mut tensor = self;
1067 for (dim, ×) in sizes.iter().enumerate() {
1068 if times > 1 {
1069 tensor = tensor.repeat_dim(dim, times);
1070 }
1071 }
1072 tensor
1073 }
1074
1075 /// Applies element-wise equal comparison.
1076 ///
1077 /// # Returns
1078 /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.
1079 ///
1080 /// # Panics
1081 ///
1082 /// If the two tensors don't have the same shape.
1083 ///
1084 /// # Example
1085 ///
1086 /// ```rust
1087 /// use burn_tensor::backend::Backend;
1088 /// use burn_tensor::Tensor;
1089 ///
1090 /// fn example<B: Backend>() {
1091 /// let device = Default::default();
1092 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1093 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1094 /// // Compare the elements of the two 2D tensors with dimensions [3, 2].
1095 /// // [[false, true], [true, true], [true, true]]
1096 /// let equal = t1.equal(t2);
1097 /// println!("{equal}");
1098 /// }
1099 /// ```
1100 pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
1101 check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
1102 Tensor::new(K::equal(self.primitive, other.primitive))
1103 }
1104
1105 /// Applies element-wise non-equality comparison.
1106 ///
1107 /// # Returns
1108 /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.
1109 ///
1110 /// # Panics
1111 ///
1112 /// If the two tensors don't have the same shape.
1113 ///
1114 /// # Example
1115 ///
1116 /// ```rust
1117 /// use burn_tensor::backend::Backend;
1118 /// use burn_tensor::Tensor;
1119 ///
1120 /// fn example<B: Backend>() {
1121 /// let device = Default::default();
1122 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1123 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1124 /// // Compare the elements of the two 2D tensors for inequality.
1125 /// // [[true, false], [false, false], [false, false]]
1126 /// let not_equal = t1.not_equal(t2);
1127 /// println!("{not_equal}");
1128 /// }
1129 /// ```
1130 pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
1131 check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
1132 Tensor::new(K::not_equal(self.primitive, other.primitive))
1133 }
1134
1135 /// Concatenates all tensors into a new one along the given dimension.
1136 ///
1137 /// # Panics
1138 ///
1139 /// - If `dim` is higher than the rank.
1140 /// - If `tensors` is an empty vector.
1141 /// - If all tensors don't have the same shape (the dimension `dim` is ignored).
1142 ///
1143 /// # Example
1144 ///
1145 /// ```rust
1146 /// use burn_tensor::backend::Backend;
1147 /// use burn_tensor::Tensor;
1148 ///
1149 /// fn example<B: Backend>() {
1150 /// let device = Default::default();
1151 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);
1152 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1153 ///
1154 /// // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.
1155 /// // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]
1156 /// // The resulting tensor will have shape [2, 7].
1157 /// let concat = Tensor::cat(vec![t1, t2], 1);
1158 /// println!("{concat}");
1159 /// }
1160 /// ```
1161 pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
1162 check!(TensorCheck::cat(&tensors, dim));
1163
1164 Self::new(K::cat(
1165 tensors.into_iter().map(|vector| vector.primitive).collect(),
1166 dim,
1167 ))
1168 }
1169
1170 /// Concatenates all tensors into a new one along a new dimension.
1171 ///
1172 /// # Panics
1173 ///
1174 /// - If all tensors don't have the same shape.
1175 /// - If given dimension is not with range of 0..D2
1176 ///
1177 /// # Example
1178 ///
1179 /// ```rust
1180 /// use burn_tensor::backend::Backend;
1181 /// use burn_tensor::Tensor;
1182 ///
1183 /// fn example<B: Backend>() {
1184 /// let device = Default::default();
1185 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1186 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1187 /// let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1188 ///
1189 /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
1190 /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
1191 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
1192 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
1193 /// // The resulting tensor will have shape [3, 2, 3].
1194 /// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
1195 /// println!("{stacked}");
1196 /// }
1197 /// ```
1198 pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
1199 check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
1200 let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
1201 Tensor::<B, D2, K>::cat(tensors, dim)
1202 }
1203
1204 /// Iterate over slices of tensors alongside a given dimension.
1205 ///
1206 /// # Panics
1207 ///
1208 /// If given dimension is greater than or equal to tensor rank.
1209 ///
1210 /// # Returns
1211 ///
1212 /// A tensor iterator.
1213 ///
1214 /// # Example
1215 ///
1216 /// ```rust
1217 /// use burn_tensor::backend::Backend;
1218 /// use burn_tensor::Tensor;
1219 /// fn example<B: Backend>() {
1220 /// let device = Default::default();
1221 /// let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1222 /// // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.
1223 /// let iter = tensor.iter_dim(0);
1224 /// for (i,tensor) in iter.enumerate() {
1225 /// println!("Tensor {}: {}", i, tensor);
1226 /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
1227 /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
1228 /// }
1229 /// }
1230 /// ```
1231 pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
1232 check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
1233 DimIter::new(self, dim)
1234 }
1235
1236 /// Returns a new tensor with the given dimension narrowed to the given range.
1237 ///
1238 /// # Panics
1239 ///
1240 /// - If the dimension is greater than the number of dimensions of the tensor.
1241 /// - If the given range exceeds the number of elements on the given dimension.
1242 ///
1243 /// # Returns
1244 ///
1245 /// A new tensor with the given dimension narrowed to the given range.
1246 ///
1247 /// # Example
1248 ///
1249 /// ```rust
1250 /// use burn_tensor::backend::Backend;
1251 /// use burn_tensor::Tensor;
1252 ///
1253 /// fn example<B: Backend>() {
1254 /// let device = Default::default();
1255 /// // Create a 2D tensor with dimensions [4, 3]
1256 /// let tensor = Tensor::<B, 2>::from_data(
1257 /// [
1258 /// [3.0, 4.9, 2.0],
1259 /// [2.0, 1.9, 3.0],
1260 /// [6.0, 1.5, 7.0],
1261 /// [3.0, 4.9, 9.0],
1262 /// ],
1263 /// &device,
1264 /// );
1265 /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
1266 /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
1267 /// // The resulting tensor will have dimensions [3, 3].
1268 /// let narrowed = tensor.narrow(0, 1, 3);
1269 /// println!("{narrowed}");
1270 /// }
1271 /// ```
1272 pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
1273 check!(TensorCheck::dim_ops::<D>("narrow", dim));
1274 check!(TensorCheck::narrow(&self, dim, start, length));
1275 Self::new(narrow::<B, K>(self.primitive, dim, start, length))
1276 }
1277
1278 /// Attempts to split the tensor into a specified number of chunks along a given dimension.
1279 /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
1280 ///
1281 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
1282 /// Otherwise all chunks will be of equal size except for the last one.
1283 ///
1284 /// # Panics
1285 ///
1286 /// If the dimension is greater than the number of dimensions of the tensor.
1287 ///
1288 /// # Returns
1289 /// A vector of tensors.
1290 ///
1291 /// # Example
1292 ///
1293 /// ```rust
1294 /// use burn_tensor::backend::Backend;
1295 /// use burn_tensor::Tensor;
1296 ///
1297 /// fn example<B: Backend>() {
1298 /// let device = Default::default();
1299 /// // Create a 2D tensor with dimensions [4, 3]
1300 /// let tensor = Tensor::<B, 2>::from_data(
1301 /// [
1302 /// [3.0, 4.9, 2.0],
1303 /// [2.0, 1.9, 3.0],
1304 /// [6.0, 1.5, 7.0],
1305 /// [3.0, 4.9, 9.0],
1306 /// ],
1307 /// &device,
1308 /// );
1309 /// // Split the tensor along the dimension 1 into 2 chunks.
1310 /// // The first chuck will have shape [4, 2]:
1311 /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
1312 /// // The second chunk will have shape [4, 1]:
1313 /// // [[2.0], [3.0], [7.0], [9.0]]
1314 /// let chunks = tensor.chunk(2, 1);
1315 /// println!("{chunks:?}");
1316 /// }
1317 /// ```
1318 pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
1319 check!(TensorCheck::dim_ops::<D>("chunk", dim));
1320 K::chunk(self.primitive, chunks, dim)
1321 .into_iter()
1322 .map(Self::new)
1323 .collect()
1324 }
1325
1326 /// Splits the tensor into chunks of a specified size along a given dimension.
1327 /// Each chunk is a view of the original tensor.
1328 ///
1329 /// If the tensor size along the given dimension is not divisible by `split_size`,
1330 /// then the last chunk will be smaller.
1331 ///
1332 /// # Panics
1333 ///
1334 /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
1335 ///
1336 /// # Returns
1337 ///
1338 /// A vector of tensors.
1339 ///
1340 /// # Example
1341 /// ```rust
1342 /// use burn_tensor::backend::Backend;
1343 /// use burn_tensor::Tensor;
1344 ///
1345 /// fn example<B: Backend>() {
1346 /// let device = Default::default();
1347 /// // Create a 1D tensor with 5 elements
1348 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1349 /// // Split the tensor into chunks of size 2 along dimension 0
1350 /// let chunks = tensor.split(2, 0);
1351 /// // The result is a vector of tensors:
1352 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
1353 /// println!("{:?}", chunks);
1354 /// }
1355 /// ```
1356 pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
1357 check!(TensorCheck::split::<D>(
1358 self.shape().dims.as_ref(),
1359 split_size,
1360 dim
1361 ));
1362 K::split(self.primitive, split_size, dim)
1363 .into_iter()
1364 .map(Self::new)
1365 .collect()
1366 }
1367
1368 /// Splits the tensor into chunks with the specified sizes along a given dimension.
1369 /// Each chunk is a view of the original tensor.
1370 ///
1371 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
1372 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
1373 ///
1374 /// # Panics
1375 ///
1376 /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
1377 /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
1378 ///
1379 /// # Returns
1380 ///
1381 /// A vector of tensors.
1382 ///
1383 /// # Example
1384 /// ```rust
1385 /// use burn_tensor::backend::Backend;
1386 /// use burn_tensor::Tensor;
1387 ///
1388 /// fn example<B: Backend>() {
1389 /// let device = Default::default();
1390 /// // Create a 1D tensor with 5 elements
1391 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1392 /// // Split the tensor into chunks with sizes [2, 3] along dimension 0
1393 /// let chunks = tensor.split_with_sizes(vec![2, 3], 0);
1394 /// // The result is a vector of tensors:
1395 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
1396 /// println!("{:?}", chunks);
1397 /// }
1398 /// ```
1399 pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
1400 check!(TensorCheck::split_with_sizes::<D>(
1401 self.shape().dims.as_ref(),
1402 &split_sizes,
1403 dim
1404 ));
1405 K::split_with_sizes(self.primitive, split_sizes, dim)
1406 .into_iter()
1407 .map(Self::new)
1408 .collect()
1409 }
1410
1411 /// Tests if any element in the `tensor` evaluates to True.
1412 ///
1413 /// # Arguments
1414 ///
1415 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1416 ///
1417 /// # Returns
1418 ///
1419 /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
1420 /// evaluates to True, False otherwise.
1421 ///
1422 /// # Example
1423 ///
1424 /// ```rust
1425 /// use burn_tensor::backend::Backend;
1426 /// use burn_tensor::{Tensor, Bool};
1427 ///
1428 /// fn example<B: Backend>() {
1429 /// let device = Default::default();
1430 /// let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
1431 /// let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
1432 ///
1433 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
1434 /// let any_tensor = tensor.any();
1435 /// println!("{}", any_tensor);
1436 /// // Tensor { data: [true], ... }
1437 ///
1438 /// // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.
1439 /// let any_tensor_two = tensor_two.any();
1440 /// println!("{}", any_tensor_two);
1441 /// // Tensor { data: [false], ... }
1442 /// }
1443 /// ```
1444 pub fn any(self) -> Tensor<B, 1, Bool> {
1445 Tensor::new(K::any(self.primitive))
1446 }
1447
1448 /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
1449 ///
1450 /// # Arguments
1451 ///
1452 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1453 /// * `dim` - The axis along which to test.
1454 ///
1455 /// # Returns
1456 ///
1457 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
1458 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1459 /// evaluates to True, False otherwise.
1460 ///
1461 /// # Example
1462 ///
1463 /// ```rust
1464 /// use burn_tensor::backend::Backend;
1465 /// use burn_tensor::{Tensor, Bool};
1466 ///
1467 /// fn example<B: Backend>() {
1468 /// let device = Default::default();
1469 /// let tensor =
1470 /// Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
1471 /// // Check if any element in the tensor evaluates to True along the dimension 1.
1472 /// // [[true], [true]],
1473 /// let any_dim = tensor.clone().any_dim(1);
1474 /// println!("{any_dim}");
1475 /// }
1476 /// ```
1477 pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1478 Tensor::new(K::any_dim(self.primitive, dim))
1479 }
1480
1481 /// Tests if all elements in the `tensor` evaluate to True.
1482 ///
1483 /// # Arguments
1484 ///
1485 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1486 ///
1487 /// # Returns
1488 ///
1489 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1490 /// evaluate to True, False otherwise.
1491 ///
1492 /// # Example
1493 ///
1494 /// ```rust
1495 /// use burn_tensor::backend::Backend;
1496 /// use burn_tensor::{Tensor, Bool};
1497 ///
1498 /// fn example<B: Backend>() {
1499 /// let device = Default::default();
1500 /// let tensor =
1501 /// Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
1502 /// // Check if all elements in the tensor evaluate to True (which is not the case).
1503 /// // [false]
1504 /// let all = tensor.all();
1505 /// println!("{all}");
1506 /// }
1507 /// ```
1508 pub fn all(self) -> Tensor<B, 1, Bool> {
1509 Tensor::new(K::all(self.primitive))
1510 }
1511
1512 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1513 ///
1514 /// # Arguments
1515 ///
1516 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1517 /// * `dim` - The axis along which to test.
1518 ///
1519 /// # Returns
1520 ///
1521 /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis
1522 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1523 /// evaluates to True, False otherwise.
1524 ///
1525 /// # Example
1526 ///
1527 /// ```rust
1528 /// use burn_tensor::backend::Backend;
1529 /// use burn_tensor::{Tensor, Bool};
1530 ///
1531 /// fn example<B: Backend>() {
1532 /// let device = Default::default();
1533 /// let tensor =
1534 /// Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
1535 /// // Check if all elements in the tensor evaluate to True along the dimension 1.
1536 /// // [[true, true, false]]
1537 /// let all_dim = tensor.clone().all_dim(0);
1538 /// println!("{all_dim}");
1539 /// }
1540 /// ```
1541 pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1542 Tensor::new(K::all_dim(self.primitive, dim))
1543 }
1544
    /// Convert the tensor into a scalar.
    ///
    /// This synchronously waits on [`into_scalar_async`](Tensor::into_scalar_async) through
    /// `try_read_sync`.
    ///
    /// # Panics
    ///
    /// - If the tensor doesn't have one element.
    /// - If the backend fails to read the tensor data synchronously (e.g. on platforms that
    ///   don't support blocking futures, such as WASM).
    ///
    /// # Returns
    ///
    /// The scalar value of the tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
    ///     // Convert the tensor with a single element into a scalar.
    ///     let scalar = tensor.into_scalar();
    ///     println!("{scalar}");
    /// }
    /// ```
    pub fn into_scalar(self) -> K::Elem {
        crate::try_read_sync(self.into_scalar_async()).expect(
            "Failed to read tensor data synchronously. This can happen on platforms
that don't support blocking futures like WASM. Try into_scalar_async instead.",
        )
    }
1576
1577 /// Convert the tensor into a scalar.
1578 ///
1579 /// # Panics
1580 ///
1581 /// If the tensor doesn't have one element.
1582 pub async fn into_scalar_async(self) -> K::Elem {
1583 check!(TensorCheck::into_scalar::<D>(&self.shape()));
1584 let x = self.into_data_async().await.iter().next().unwrap();
1585 x
1586 }
1587
1588 /// Broadcast the tensor to the given shape.
1589 ///
1590 /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size
1591 /// (which can be inferred with `-1`).
1592 ///
1593 /// # Arguments
1594 ///
1595 /// * `shape` - The shape to broadcast the tensor to.
1596 /// Can contain -1 for dimensions that should be inferred.
1597 /// The number of elements in the shape must be greater or equal as
1598 /// the number of dimensions of the tensor.
1599 ///
1600 /// # Panics
1601 ///
1602 /// If the tensor cannot be broadcasted to the given shape.
1603 ///
1604 /// # Returns
1605 ///
1606 /// A new tensor with the given shape.
1607 ///
1608 /// # Example
1609 ///
1610 /// ```rust
1611 /// use burn_tensor::backend::Backend;
1612 /// use burn_tensor::Tensor;
1613 ///
1614 /// fn example<B: Backend>() {
1615 /// let device = Default::default();
1616 /// // Create a 2D tensor with dimensions [3, 1]
1617 /// let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
1618 /// // Expand the tensor to a new shape [3, 4]
1619 /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
1620 /// let expanded = tensor.expand([3, 4]);
1621 /// println!("{}", expanded);
1622 /// }
1623 /// ```
1624 pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
1625 let shape = shape.into_shape(&self.shape());
1626 check!(TensorCheck::expand::<D, D2>(
1627 "expand",
1628 &self.shape(),
1629 &shape,
1630 ));
1631
1632 Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
1633 }
1634}
1635
/// Iterator given by [`Tensor::iter_dim`].
///
/// Yields one slice per index along `dim`, either from the front (`next`) or from the back
/// (`next_back`). Each yielded tensor keeps rank `D` and has size 1 along `dim`.
pub struct DimIter<B, const D: usize, K>
where
    B: Backend,
    K: BasicOps<B>,
{
    /// Index along `dim` of the next slice yielded from the front.
    start: usize,
    /// One past the index along `dim` of the next slice yielded from the back.
    end: usize,
    /// The dimension being iterated over.
    dim: usize,
    /// Full-extent range for every dimension; only the entry at `dim` is overridden per slice.
    ranges: [Range<usize>; D],
    /// The iterated tensor; it is cloned and sliced for every yielded item.
    tensor: Tensor<B, D, K>,
}
1648
1649impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
1650 type Item = Tensor<B, D, K>;
1651
1652 fn next(&mut self) -> Option<Self::Item> {
1653 if self.start >= self.end {
1654 return None;
1655 }
1656
1657 let mut ranges = self.ranges.clone();
1658 ranges[self.dim] = self.start..(self.start + 1);
1659
1660 let slice = self.tensor.clone().slice(ranges);
1661 self.start += 1;
1662
1663 Some(slice)
1664 }
1665}
1666
1667impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
1668 fn next_back(&mut self) -> Option<Self::Item> {
1669 if self.start >= self.end {
1670 return None;
1671 }
1672
1673 let mut ranges = self.ranges.clone();
1674 ranges[self.dim] = (self.end - 1)..self.end;
1675
1676 let slice = self.tensor.clone().slice(ranges);
1677 self.end = self.end.saturating_sub(1);
1678
1679 Some(slice)
1680 }
1681}
1682
1683impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
1684 fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
1685 let dims = tensor.dims();
1686 let ranges = dims
1687 .iter()
1688 .map(|&dim| 0..dim)
1689 .collect::<Vec<Range<usize>>>();
1690 let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
1691 Self {
1692 end: dims[dim],
1693 ranges,
1694 start: 0,
1695 dim,
1696 tensor,
1697 }
1698 }
1699}
1700
// Helpers used by the `Display` implementation to render tensor elements as text.
// Elements without a float precision override are printed with `{:?}`, hence the `Debug` bound.
impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    /// Appends a newline followed by `indent` spaces to `acc`.
    #[inline]
    fn push_newline_indent(acc: &mut String, indent: usize) {
        acc.push('\n');
        for _ in 0..indent {
            acc.push(' ');
        }
    }
    /// Appends the elements at indices `range.0..range.1` of the innermost dimension
    /// (`depth`) to `acc`, separated by `", "`.
    ///
    /// The other coordinates come from `multi_index`, and `multi_index[depth]` is overwritten
    /// for each element. Every element is fetched with its own one-element `slice` followed by
    /// a synchronous data read, so this issues one read per printed element. When that read is
    /// not available synchronously, `<Tensor data not available>` is written instead.
    ///
    /// `precision` only affects `Float` tensors. `Bool` tensors print `to_bool()`, and every
    /// other case uses `Debug` formatting.
    fn fmt_inner_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        range: (usize, usize),
        precision: Option<usize>,
    ) {
        let (start, end) = range;
        for i in start..end {
            // Compared against 0 rather than `start`: after the caller writes ", ..." the
            // trailing edge items still need a leading ", " separator.
            if i > 0 {
                acc.push_str(", ");
            }
            multi_index[depth] = i;
            // One-element range at the current coordinate in every dimension.
            let range: [Range<usize>; D] =
                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);

            let data =
                burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());

            if let Some(data) = data {
                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
                match (precision, K::name()) {
                    (Some(p), "Float") => acc.push_str(&format!("{:.1$}", elem, p)),
                    (_, "Bool") => acc.push_str(&format!("{}", elem.to_bool())),
                    _ => acc.push_str(&format!("{:?}", elem)),
                }
            } else {
                acc.push_str("<Tensor data not available>");
            }
        }
    }

    /// Appends the sub-tensors at indices `range.0..range.1` of dimension `depth`, each
    /// wrapped in `[` `]` and rendered recursively by `display_recursive`.
    ///
    /// Consecutive sub-tensors are separated by `,` and a newline indented by `depth + 1`
    /// spaces. No separator is written before the first entry of `range`; for the trailing
    /// summarized range the caller has already written the separator around `...`.
    fn fmt_outer_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
        range: (usize, usize),
    ) {
        let (start, end) = range;
        for i in start..end {
            if i > start {
                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
            }
            acc.push('[');
            multi_index[depth] = i;
            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
            acc.push(']');
        }
    }

    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
    ///
    /// This function is designed to work with tensors of any dimensionality.
    /// It traverses the tensor dimensions recursively, converting the elements
    /// to strings and appending them to the accumulator string with the
    /// appropriate formatting.
    ///
    /// When `summarize` is set and a dimension has more than `2 * edge_items` entries, only
    /// the first and last `edge_items` entries of that dimension are printed, with `...`
    /// in between.
    ///
    /// # Arguments
    ///
    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
    /// * `depth` - The current depth of the tensor dimensions being processed.
    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
    /// * `print_options` - Supplies `edge_items` and the float `precision`.
    /// * `summarize` - Whether long dimensions should be elided.
    fn display_recursive(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
    ) {
        let edge_items = print_options.edge_items;

        // The outermost bracket pair is only written once, at the top of the recursion.
        if depth == 0 {
            acc.push('[');
        }

        if depth == self.dims().len() - 1 {
            // if we are at the innermost dimension, just push its elements into the accumulator
            if summarize && self.dims()[depth] > 2 * edge_items {
                // print the starting `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, edge_items),
                    print_options.precision,
                );
                acc.push_str(", ...");
                // print the last `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                    print_options.precision,
                );
            } else {
                // print all the elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, self.dims()[depth]),
                    print_options.precision,
                );
            }
        } else {
            // otherwise, iterate through the current dimension and recursively display the inner tensors
            if summarize && self.dims()[depth] > 2 * edge_items {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, edge_items),
                );

                // Elision marker on its own line, indented like the sibling sub-tensors.
                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
                acc.push_str("...");
                Self::push_newline_indent(acc, depth + 1);

                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                );
            } else {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, self.dims()[depth]),
                );
            }
        }

        if depth == 0 {
            acc.push(']');
        }
    }
}
1867
#[derive(Clone, Debug)]
/// Options for Tensor pretty printing.
///
/// The active options are global and can be replaced with [`set_print_options`]. A precision
/// given through the formatter (e.g. `{:.2}`) takes priority over [`PrintOptions::precision`].
pub struct PrintOptions {
    /// Element count above which a printed tensor is summarized: when the total number of
    /// elements exceeds this value, long dimensions are elided with `...`.
    pub threshold: usize,

    /// Number of leading and trailing entries shown for each summarized dimension.
    pub edge_items: usize,

    /// Number of decimal places for floating point elements. `None` prints them with `Debug`
    /// formatting. Only used for `Float` tensors.
    pub precision: Option<usize>,
}
1880
/// Global print options read by the tensor `Display` implementation and replaced by
/// [`set_print_options`]. Initialized from [`PrintOptions::const_default`].
static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
1882
1883impl PrintOptions {
1884 /// Print options with default values
1885 pub const fn const_default() -> Self {
1886 Self {
1887 threshold: 1000,
1888 edge_items: 3,
1889 precision: None,
1890 }
1891 }
1892}
1893
1894impl Default for PrintOptions {
1895 fn default() -> Self {
1896 Self::const_default()
1897 }
1898}
1899
1900/// Set print options
1901pub fn set_print_options(options: PrintOptions) {
1902 let mut print_opts = PRINT_OPTS.write().unwrap();
1903 *print_opts = options;
1904}
1905
/// Pretty print tensors.
///
/// Writes a `Tensor { ... }` block containing the (possibly summarized) data, shape, device,
/// backend name, kind and dtype. An explicit formatter precision (e.g. `{:.3}`) overrides the
/// globally configured precision for this call only.
impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
where
    B: Backend,
    B::IntElem: core::fmt::Display,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "Tensor {{")?;

        {
            // Clone the global options so the read lock is released immediately instead of
            // being held for the whole function.
            let mut po = { PRINT_OPTS.read().unwrap().clone() };

            // Override the precision if it is set from the formatter
            // This will be possible when the tensor is printed using the `{:.*}` syntax
            if let Some(precision) = f.precision() {
                po.precision = Some(precision);
            }

            let mut acc = String::new();
            // Current coordinate in each of the `D` dimensions, updated while recursing.
            let mut multi_index = vec![0; D];
            // Only elide entries when the total element count exceeds the threshold.
            let summarize = self.shape().num_elements() > po.threshold;

            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);

            writeln!(f, " data:")?;
            write!(f, "{acc}")?;
            writeln!(f, ",")?;
        }

        writeln!(f, " shape: {:?},", self.dims())?;
        writeln!(f, " device: {:?},", self.device())?;
        writeln!(f, " backend: {:?},", B::name(&self.device()))?;
        writeln!(f, " kind: {:?},", K::name())?;

        let dtype = self.primitive.dtype();

        writeln!(f, " dtype: {:?},", dtype.name())?;
        write!(f, "}}")
    }
}
1949
/// Zero-sized marker enabling the `tensor^T` shorthand for transposing a tensor.
///
/// `tensor^T` is equivalent to calling `tensor.transpose()`:
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, T};
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
///     let transposed = tensor^T;
/// }
/// ```
pub struct T;
1962
1963impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
1964 type Output = Self;
1965 fn bitxor(self, _: T) -> Self::Output {
1966 self.transpose()
1967 }
1968}
1969
1970/// Trait that list all operations that can be applied on all tensors.
1971///
1972/// # Warnings
1973///
1974/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
1975pub trait BasicOps<B: Backend>: TensorKind<B> {
1976 /// The type of the tensor elements.
1977 type Elem: Element;
1978
1979 /// Creates an empty tensor with the given shape.
1980 ///
1981 /// # Arguments
1982 ///
1983 /// * `shape` - The shape of the tensor.
1984 /// * `device` - The device on which the tensor will be allocated.
1985 ///
1986 /// # Returns
1987 ///
1988 /// The empty tensor.
1989 ///
1990 /// # Remarks
1991 ///
1992 /// This is a low-level function used internally by the library to call different backend functions
1993 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1994 /// or use this function directly.
1995 ///
1996 /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
1997 /// which is more high-level and designed for public use.
1998 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive;
1999
2000 /// Reshapes the tensor.
2001 ///
2002 /// # Arguments
2003 ///
2004 /// * `tensor` - The tensor.
2005 /// * `shape` - The new shape of the tensor.
2006 ///
2007 /// # Returns
2008 ///
2009 /// The reshaped tensor.
2010 ///
2011 /// # Remarks
2012 ///
2013 /// This is a low-level function used internally by the library to call different backend functions
2014 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2015 /// or use this function directly.
2016 ///
2017 /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
2018 /// which is more high-level and designed for public use.
2019 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2020
2021 /// Transposes a tensor.
2022 ///
2023 /// # Arguments
2024 ///
2025 /// * `tensor` - The tensor to transpose.
2026 ///
2027 /// # Returns
2028 ///
2029 /// The transposed tensor.
2030 fn transpose(tensor: Self::Primitive) -> Self::Primitive;
2031
2032 /// Swaps two dimensions of a tensor.
2033 ///
2034 /// # Arguments
2035 ///
2036 /// * `tensor` - The tensor to swap the dimensions of.
2037 /// * `dim1` - The first dimension to swap.
2038 /// * `dim2` - The second dimension to swap.
2039 ///
2040 /// # Returns
2041 ///
2042 /// The tensor with the dimensions swapped.
2043 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
2044
2045 /// Permutes the dimensions of a tensor.
2046 ///
2047 /// # Arguments
2048 ///
2049 /// * `tensor` - The tensor to permute the dimensions of.
2050 /// * `axes` - The new order of the dimensions.
2051 ///
2052 /// # Returns
2053 ///
2054 /// The tensor with the dimensions permuted.
2055 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2056
2057 /// Flips the tensor along the given axes.
2058 ///
2059 /// # Arguments
2060 ///
2061 /// * `tensor` - The tensor to flip.
2062 /// * `axes` - The axes to flip the tensor along.
2063 ///
2064 /// # Returns
2065 ///
2066 /// The tensor with the axes flipped.
2067 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2068
2069 /// Select tensor elements corresponding for the given ranges.
2070 ///
2071 /// # Arguments
2072 ///
2073 /// * `tensor` - The tensor.
2074 /// * `ranges` - The ranges of the elements to select.
2075 ///
2076 /// # Returns
2077 ///
2078 /// The selected elements.
2079 ///
2080 /// # Remarks
2081 ///
2082 /// This is a low-level function used internally by the library to call different backend functions
2083 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2084 /// or use this function directly.
2085 ///
2086 /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
2087 /// which is more high-level and designed for public use.
2088 fn slice(tensor: Self::Primitive, range: &[Range<usize>]) -> Self::Primitive;
2089
2090 /// Assigns the given value to the tensor elements corresponding for the given ranges.
2091 ///
2092 /// # Arguments
2093 ///
2094 /// * `tensor` - The tensor.
2095 /// * `ranges` - The ranges of the elements to select.
2096 /// * `value` - The value to assign.
2097 ///
2098 /// # Returns
2099 ///
2100 /// The tensor with the assigned values.
2101 ///
2102 /// # Remarks
2103 ///
2104 /// This is a low-level function used internally by the library to call different backend functions
2105 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2106 /// or use this function directly.
2107 ///
2108 /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
2109 /// which is more high-level and designed for public use.
2110 fn slice_assign(
2111 tensor: Self::Primitive,
2112 ranges: &[Range<usize>],
2113 value: Self::Primitive,
2114 ) -> Self::Primitive;
2115
2116 /// Returns the device on which the tensor is allocated.
2117 ///
2118 /// # Arguments
2119 ///
2120 /// * `tensor` - The tensor.
2121 ///
2122 /// # Returns
2123 ///
2124 /// The device on which the tensor is allocated.
2125 ///
2126 /// # Remarks
2127 ///
2128 /// This is a low-level function used internally by the library to call different backend functions
2129 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2130 /// or use this function directly.
2131 ///
2132 /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
2133 /// which is more high-level and designed for public use.
2134 fn device(tensor: &Self::Primitive) -> B::Device;
2135
2136 /// Moves the tensor to the given device.
2137 ///
2138 /// # Arguments
2139 ///
2140 /// * `tensor` - The tensor.
2141 /// * `device` - The device on which the tensor will be moved.
2142 ///
2143 /// # Returns
2144 ///
2145 /// The tensor on the given device.
2146 ///
2147 /// # Remarks
2148 ///
2149 /// This is a low-level function used internally by the library to call different backend functions
2150 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2151 /// or use this function directly.
2152 ///
2153 /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
2154 /// which is more high-level and designed for public use.
2155 fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
2156
2157 /// Extracts the data from the tensor asynchronously.
2158 ///
2159 /// # Arguments
2160 ///
2161 /// * `tensor` - The tensor.
2162 ///
2163 /// # Returns
2164 ///
2165 /// The data of the tensor.
2166 ///
2167 /// # Remarks
2168 ///
2169 /// This is a low-level function used internally by the library to call different backend functions
2170 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2171 /// or use this function directly.
2172 ///
2173 /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
2174 /// which is more high-level and designed for public use.
2175 fn into_data_async(
2176 tensor: Self::Primitive,
2177 ) -> impl Future<Output = TensorData> + 'static + Send;
2178
2179 /// Read the data from the tensor using a transaction.
2180 ///
2181 /// # Remarks
2182 ///
2183 /// This is a low-level function used internally by the library to call different backend functions
2184 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2185 /// or use this function directly.
2186 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive);
2187
2188 /// Creates a tensor from the given data.
2189 ///
2190 /// # Arguments
2191 ///
2192 /// * `data` - The data of the tensor.
2193 /// * `device` - The device on which the tensor will be allocated.
2194 ///
2195 /// # Returns
2196 ///
2197 /// The tensor.
2198 ///
2199 /// # Remarks
2200 ///
2201 /// This is a low-level function used internally by the library to call different backend functions
2202 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2203 /// or use this function directly.
2204 ///
2205 /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
2206 /// which is more high-level and designed for public use.
2207 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;
2208 /// Creates a tensor from the given data enforcing the given data type.
2209 ///
2210 /// # Remarks
2211 ///
2212 /// This is a low-level function used internally by the library to call different backend functions
2213 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2214 /// or use this function directly.
2215 ///
2216 /// For creating a tensor from data, users should prefer the [Tensor::from_data_dtype](Tensor::from_data_dtype)
2217 /// function, which is more high-level and designed for public use.
2218 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;
2219
2220 /// Repeat the tensor along the given dimension.
2221 ///
2222 /// # Arguments
2223 ///
2224 /// * `tensor` - The tensor.
2225 /// * `dim` - The dimension along which the tensor will be repeated.
2226 /// * `times` - The number of times the tensor will be repeated.
2227 ///
2228 /// # Returns
2229 ///
2230 /// The repeated tensor.
2231 ///
2232 /// # Remarks
2233 ///
2234 /// This is a low-level function used internally by the library to call different backend functions
2235 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2236 /// or use this function directly.
2237 ///
2238 /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
2239 /// which is more high-level and designed for public use.
2240 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
2241
2242 /// Concatenates the given tensors along the given dimension.
2243 ///
2244 /// # Arguments
2245 ///
2246 /// * `vectors` - The tensors to concatenate.
2247 /// * `dim` - The dimension along which the tensors will be concatenated.
2248 ///
2249 /// # Returns
2250 ///
2251 /// The concatenated tensor.
2252 ///
2253 /// # Remarks
2254 ///
2255 /// This is a low-level function used internally by the library to call different backend functions
2256 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2257 /// or use this function directly.
2258 ///
2259 /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
2260 /// which is more high-level and designed for public use.
2261 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
2262
2263 /// Attempts to split the tensor along the given dimension into chunks.
2264 /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
2265 ///
2266 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2267 /// Otherwise all chunks will be of equal size except for the last one.
2268 ///
2269 /// # Panics
2270 ///
2271 /// If the dimension is greater than the number of dimensions of the tensor.
2272 ///
2273 /// # Returns
2274 /// A vector of tensors.
2275 ///
2276 /// # Remarks
2277 ///
2278 /// This is a low-level function used internally by the library to call different backend functions
2279 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2280 /// or use this function directly.
2281 ///
2282 /// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
2283 /// which is more high-level and designed for public use.
2284 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive>;
2285
2286 /// Splits the tensor into chunks of a specified size along a given dimension.
2287 /// Each chunk is a view of the original tensor.
2288 ///
2289 /// # Panics
2290 ///
2291 /// If the dimension to split along is greater than the number of dimensions of the tensor.
2292 ///
2293 /// # Returns
2294 ///
2295 /// A vector of tensors.
2296 ///
2297 /// # Remarks
2298 /// This is a low-level function used internally by the library to call different backend functions
2299 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2300 /// or use this function directly.
2301 ///
2302 /// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function,
2303 /// which is more high-level and designed for public use.
2304 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive>;
2305
2306 /// Splits the tensor into chunks with the specified sizes along a given dimension.
2307 /// Each chunk is a view of the original tensor.
2308 ///
2309 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2310 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2311 ///
2312 /// # Panics
2313 ///
2314 /// If the dimension to split along is greater than the number of dimensions of the tensor or
2315 /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
2316 ///
2317 /// # Returns
2318 ///
2319 /// A vector of tensors.
2320 ///
2321 /// # Remarks
2322 /// This is a low-level function used internally by the library to call different backend functions
2323 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2324 /// or use this function directly.
2325 ///
2326 /// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function,
2327 /// which is more high-level and designed for public use.
2328 fn split_with_sizes(
2329 tensor: Self::Primitive,
2330 split_sizes: Vec<usize>,
2331 dim: usize,
2332 ) -> Vec<Self::Primitive>;
2333
2334 /// Equates the given tensors.
2335 ///
2336 /// # Arguments
2337 ///
2338 /// * `lhs` - The left hand side tensor.
2339 /// * `rhs` - The right hand side tensor.
2340 ///
2341 /// # Returns
2342 ///
2343 /// The tensor of booleans indicating whether the corresponding elements are equal.
2344 ///
2345 /// # Remarks
2346 ///
2347 /// This is a low-level function used internally by the library to call different backend functions
2348 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2349 /// or use this function directly.
2350 ///
2351 /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
2352 /// which is more high-level and designed for public use.
2353 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2354
2355 /// Applies element-wise non-equality comparison between the given tensors.
2356 ///
2357 /// # Arguments
2358 ///
2359 /// * `lhs` - The left hand side tensor.
2360 /// * `rhs` - The right hand side tensor.
2361 ///
2362 /// # Returns
2363 ///
2364 /// The tensor of booleans indicating whether the corresponding elements are equal.
2365 ///
2366 /// # Remarks
2367 ///
2368 /// This is a low-level function used internally by the library to call different backend functions
2369 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2370 /// or use this function directly.
2371 ///
2372 /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal)
2373 /// function, which is more high-level and designed for public use.
2374 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2375
2376 /// Returns the name of the element type.
2377 fn elem_type_name() -> &'static str {
2378 core::any::type_name::<Self::Elem>()
2379 }
2380
2381 /// Returns the tensor data type.
2382 fn dtype(tensor: &Self::Primitive) -> DType {
2383 tensor.dtype()
2384 }
2385
2386 /// Tests if any element in the `tensor` evaluates to True.
2387 ///
2388 /// # Arguments
2389 ///
2390 /// * `tensor` - The tensor to test.
2391 ///
2392 /// # Returns
2393 ///
2394 /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
2395 ///
2396 /// # Remarks
2397 ///
2398 /// This is a low-level function used internally by the library to call different backend functions
2399 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2400 /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function
2401 /// which is more high-level and designed for public use.
2402 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2403
2404 /// Tests if any element in the tensor evaluates to True along a given dimension dim.
2405 ///
2406 /// # Arguments
2407 ///
2408 /// * tensor - The tensor to test.
2409 /// * dim - The axis along which to test.
2410 ///
2411 /// # Returns
2412 ///
2413 /// A boolean tensor with the same size as input tensor, except in the dim axis where the size is 1.
2414 /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
2415 ///
2416 /// # Remarks
2417 ///
2418 /// This is a low-level function used internally by the library to call different backend functions
2419 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2420 /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function,
2421 /// which is more high-level and designed for public use.
2422 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2423
2424 /// Tests if all elements in the `tensor` evaluate to True.
2425 ///
2426 /// # Arguments
2427 ///
2428 /// * `tensor` - The tensor to test.
2429 ///
2430 /// # Returns
2431 ///
2432 /// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise.
2433 ///
2434 /// # Remarks
2435 ///
2436 /// This is a low-level function used internally by the library to call different backend functions
2437 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2438 /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
2439 /// which is more high-level and designed for public use.
2440 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2441
2442 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2443 ///
2444 /// # Arguments
2445 ///
2446 /// * `tensor` - The tensor to test.
2447 ///
2448 /// # Returns
2449 ///
2450 /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
2451 /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
2452 ///
2453 /// # Remarks
2454 ///
2455 /// This is a low-level function used internally by the library to call different backend functions
2456 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2457 /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
2458 /// which is more high-level and designed for public use.
2459 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2460
2461 /// Broadcasts the given tensor to the specified shape.
2462 ///
2463 /// # Arguments
2464 ///
2465 /// * `tensor` - The tensor to broadcast.
2466 /// * `shape` - The shape to broadcast to.
2467 ///
2468 /// # Returns
2469 ///
2470 /// The broadcasted tensor.
2471 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2472}
2473
2474impl<B: Backend> BasicOps<B> for Float {
2475 type Elem = B::FloatElem;
2476
2477 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2478 TensorPrimitive::Float(B::float_empty(shape, device))
2479 }
2480
2481 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2482 tr.register_float(tensor);
2483 }
2484
2485 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2486 match tensor {
2487 TensorPrimitive::Float(tensor) => {
2488 TensorPrimitive::Float(B::float_reshape(tensor, shape))
2489 }
2490 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
2491 }
2492 }
2493
2494 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2495 match tensor {
2496 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
2497 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
2498 }
2499 }
2500
2501 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2502 match tensor {
2503 TensorPrimitive::Float(tensor) => {
2504 TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
2505 }
2506 TensorPrimitive::QFloat(tensor) => {
2507 TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
2508 }
2509 }
2510 }
2511
2512 fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2513 match tensor {
2514 TensorPrimitive::Float(tensor) => {
2515 TensorPrimitive::Float(B::float_slice(tensor, ranges))
2516 }
2517 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, ranges)),
2518 }
2519 }
2520
2521 fn slice_assign(
2522 tensor: Self::Primitive,
2523 ranges: &[Range<usize>],
2524 value: Self::Primitive,
2525 ) -> Self::Primitive {
2526 match (tensor, value) {
2527 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(value)) => {
2528 TensorPrimitive::Float(B::float_slice_assign(tensor, ranges, value))
2529 }
2530 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(value)) => {
2531 TensorPrimitive::QFloat(B::q_slice_assign(tensor, ranges, value))
2532 }
2533 _ => panic!("Primitive type mismatch for tensor and value"),
2534 }
2535 }
2536
2537 fn device(tensor: &Self::Primitive) -> Device<B> {
2538 match tensor {
2539 TensorPrimitive::Float(tensor) => B::float_device(tensor),
2540 TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
2541 }
2542 }
2543
2544 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2545 match tensor {
2546 TensorPrimitive::Float(tensor) => {
2547 TensorPrimitive::Float(B::float_to_device(tensor, device))
2548 }
2549 TensorPrimitive::QFloat(tensor) => {
2550 TensorPrimitive::QFloat(B::q_to_device(tensor, device))
2551 }
2552 }
2553 }
2554
2555 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2556 match tensor {
2557 TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
2558 TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
2559 }
2560 }
2561
2562 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2563 match data.dtype {
2564 DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
2565 _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),
2566 }
2567 }
2568
2569 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2570 match dtype {
2571 DType::QFloat(_strategy) => {
2572 TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))
2573 }
2574 _ if dtype.is_float() => {
2575 TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))
2576 }
2577 _ => panic!("Expected float dtype, got {dtype:?}"),
2578 }
2579 }
2580
2581 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2582 match tensor {
2583 TensorPrimitive::Float(tensor) => {
2584 TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
2585 }
2586 TensorPrimitive::QFloat(tensor) => {
2587 TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
2588 }
2589 }
2590 }
2591
2592 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2593 match vectors.first().unwrap() {
2594 TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
2595 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
2596 dim,
2597 )),
2598 TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
2599 vectors
2600 .into_iter()
2601 .map(|tensor| {
2602 if let TensorPrimitive::QFloat(t) = tensor {
2603 t
2604 } else {
2605 panic!("Concatenation only works with vector of QFloat")
2606 }
2607 })
2608 .collect(),
2609 dim,
2610 )),
2611 }
2612 }
2613
2614 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2615 B::float_equal(lhs.tensor(), rhs.tensor())
2616 }
2617
2618 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2619 B::float_not_equal(lhs.tensor(), rhs.tensor())
2620 }
2621
2622 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2623 B::float_any(tensor.tensor())
2624 }
2625
2626 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2627 B::float_any_dim(tensor.tensor(), dim)
2628 }
2629
2630 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2631 B::float_all(tensor.tensor())
2632 }
2633
2634 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2635 B::float_all_dim(tensor.tensor(), dim)
2636 }
2637
2638 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2639 match tensor {
2640 TensorPrimitive::Float(tensor) => {
2641 TensorPrimitive::Float(B::float_permute(tensor, axes))
2642 }
2643 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
2644 }
2645 }
2646
2647 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2648 TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
2649 }
2650
2651 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2652 match tensor {
2653 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
2654 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
2655 }
2656 }
2657
2658 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2659 match tensor {
2660 TensorPrimitive::Float(tensor) => B::float_chunk(tensor, chunks, dim)
2661 .into_iter()
2662 .map(TensorPrimitive::Float)
2663 .collect(),
2664 TensorPrimitive::QFloat(tensor) => B::q_chunk(tensor, chunks, dim)
2665 .into_iter()
2666 .map(TensorPrimitive::QFloat)
2667 .collect(),
2668 }
2669 }
2670
2671 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2672 match tensor {
2673 TensorPrimitive::Float(tensor) => B::float_split(tensor, split_size, dim)
2674 .into_iter()
2675 .map(TensorPrimitive::Float)
2676 .collect(),
2677 TensorPrimitive::QFloat(tensor) => B::q_split(tensor, split_size, dim)
2678 .into_iter()
2679 .map(TensorPrimitive::QFloat)
2680 .collect(),
2681 }
2682 }
2683
2684 fn split_with_sizes(
2685 tensor: Self::Primitive,
2686 split_sizes: Vec<usize>,
2687 dim: usize,
2688 ) -> Vec<Self::Primitive> {
2689 match tensor {
2690 TensorPrimitive::Float(tensor) => B::float_split_with_sizes(tensor, split_sizes, dim)
2691 .into_iter()
2692 .map(TensorPrimitive::Float)
2693 .collect(),
2694 TensorPrimitive::QFloat(tensor) => B::q_split_with_sizes(tensor, split_sizes, dim)
2695 .into_iter()
2696 .map(TensorPrimitive::QFloat)
2697 .collect(),
2698 }
2699 }
2700}
2701
2702impl<B: Backend> BasicOps<B> for Int {
2703 type Elem = B::IntElem;
2704
2705 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2706 B::int_empty(shape, device)
2707 }
2708
2709 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2710 tr.register_int(tensor);
2711 }
2712
2713 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2714 B::int_reshape(tensor, shape)
2715 }
2716
2717 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2718 B::int_transpose(tensor)
2719 }
2720
2721 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2722 B::int_swap_dims(tensor, dim1, dim2)
2723 }
2724
2725 fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2726 B::int_slice(tensor, ranges)
2727 }
2728
2729 fn slice_assign(
2730 tensor: Self::Primitive,
2731 ranges: &[Range<usize>],
2732 value: Self::Primitive,
2733 ) -> Self::Primitive {
2734 B::int_slice_assign(tensor, ranges, value)
2735 }
2736
2737 fn device(tensor: &Self::Primitive) -> Device<B> {
2738 B::int_device(tensor)
2739 }
2740
2741 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2742 B::int_to_device(tensor, device)
2743 }
2744
2745 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2746 B::int_into_data(tensor).await
2747 }
2748
2749 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2750 B::int_from_data(data.convert::<B::IntElem>(), device)
2751 }
2752
2753 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2754 if !dtype.is_int() {
2755 panic!("Expected int dtype, got {dtype:?}")
2756 }
2757
2758 B::int_from_data(data.convert_dtype(dtype), device)
2759 }
2760
2761 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2762 B::int_repeat_dim(tensor, dim, times)
2763 }
2764
2765 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2766 B::int_equal(lhs, rhs)
2767 }
2768
2769 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2770 B::int_not_equal(lhs, rhs)
2771 }
2772
2773 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2774 B::int_cat(vectors, dim)
2775 }
2776
2777 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2778 B::int_any(tensor)
2779 }
2780
2781 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2782 B::int_any_dim(tensor, dim)
2783 }
2784
2785 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2786 B::int_all(tensor)
2787 }
2788
2789 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2790 B::int_all_dim(tensor, dim)
2791 }
2792
2793 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2794 B::int_permute(tensor, axes)
2795 }
2796
2797 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2798 B::int_expand(tensor, shape)
2799 }
2800
2801 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2802 B::int_flip(tensor, axes)
2803 }
2804
2805 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2806 B::int_chunk(tensor, chunks, dim)
2807 }
2808
2809 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2810 B::int_split(tensor, split_size, dim)
2811 }
2812
2813 fn split_with_sizes(
2814 tensor: Self::Primitive,
2815 split_sizes: Vec<usize>,
2816 dim: usize,
2817 ) -> Vec<Self::Primitive> {
2818 B::int_split_with_sizes(tensor, split_sizes, dim)
2819 }
2820}
2821
2822impl<B: Backend> BasicOps<B> for Bool {
2823 type Elem = B::BoolElem;
2824
2825 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2826 B::bool_empty(shape, device)
2827 }
2828
2829 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2830 tr.register_bool(tensor);
2831 }
2832
2833 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2834 B::bool_reshape(tensor, shape)
2835 }
2836
2837 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2838 B::bool_transpose(tensor)
2839 }
2840
2841 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2842 B::bool_swap_dims(tensor, dim1, dim2)
2843 }
2844
2845 fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2846 B::bool_slice(tensor, ranges)
2847 }
2848
2849 fn slice_assign(
2850 tensor: Self::Primitive,
2851 ranges: &[Range<usize>],
2852 value: Self::Primitive,
2853 ) -> Self::Primitive {
2854 B::bool_slice_assign(tensor, ranges, value)
2855 }
2856
2857 fn device(tensor: &Self::Primitive) -> Device<B> {
2858 B::bool_device(tensor)
2859 }
2860
2861 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2862 B::bool_to_device(tensor, device)
2863 }
2864
2865 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2866 B::bool_into_data(tensor).await
2867 }
2868
2869 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2870 B::bool_from_data(data.convert::<B::BoolElem>(), device)
2871 }
2872
2873 fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive {
2874 // Backends only use one bool representation dtype
2875 if dtype != B::BoolElem::dtype() {
2876 panic!("Expected bool dtype, got {dtype:?}")
2877 }
2878 B::bool_from_data(data.convert_dtype(dtype), device)
2879 }
2880
2881 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2882 B::bool_repeat_dim(tensor, dim, times)
2883 }
2884
2885 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2886 B::bool_equal(lhs, rhs)
2887 }
2888
2889 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2890 B::bool_not_equal(lhs, rhs)
2891 }
2892
2893 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2894 B::bool_cat(vectors, dim)
2895 }
2896
2897 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2898 B::bool_any(tensor)
2899 }
2900
2901 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2902 B::bool_any_dim(tensor, dim)
2903 }
2904
2905 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2906 B::bool_all(tensor)
2907 }
2908
2909 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2910 B::bool_all_dim(tensor, dim)
2911 }
2912
2913 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2914 B::bool_permute(tensor, axes)
2915 }
2916
2917 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2918 B::bool_expand(tensor, shape)
2919 }
2920
2921 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2922 B::bool_flip(tensor, axes)
2923 }
2924
2925 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2926 B::bool_chunk(tensor, chunks, dim)
2927 }
2928
2929 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2930 B::bool_split(tensor, split_size, dim)
2931 }
2932
2933 fn split_with_sizes(
2934 tensor: Self::Primitive,
2935 split_sizes: Vec<usize>,
2936 dim: usize,
2937 ) -> Vec<Self::Primitive> {
2938 B::bool_split_with_sizes(tensor, split_sizes, dim)
2939 }
2940}
2941
/// Trait used for movedim arguments.
///
/// Implemented here for a single dimension (`usize`, `i32`) and for lists of
/// dimensions (`Vec<usize>`, `Vec<i32>`); the signed variants accept negative
/// indices, which are counted back from `D`.
pub trait MovedimArgs {
    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function.
    ///
    /// `D` is used to resolve negative indices and to validate the result
    /// (presumably the rank of the tensor being permuted).
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
2947
2948impl MovedimArgs for Vec<i32> {
2949 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2950 let set = self
2951 .iter()
2952 .map(|&dim| {
2953 if dim < 0 {
2954 (D as i32 + dim) as usize
2955 } else {
2956 dim as usize
2957 }
2958 })
2959 .collect::<Vec<usize>>();
2960 check!(TensorCheck::movedim_args_vec::<D>(&set));
2961
2962 set
2963 }
2964}
2965
2966impl MovedimArgs for Vec<usize> {
2967 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2968 check!(TensorCheck::movedim_args_vec::<D>(&self));
2969 self
2970 }
2971}
2972
2973impl MovedimArgs for usize {
2974 #[allow(clippy::vec_init_then_push)]
2975 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2976 check!(TensorCheck::movedim_args_usize::<D>(self));
2977
2978 let mut set = Vec::with_capacity(1);
2979 set.push(self);
2980
2981 set
2982 }
2983}
2984
2985impl MovedimArgs for i32 {
2986 #[allow(clippy::vec_init_then_push)]
2987 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2988 check!(TensorCheck::movedim_args_i32::<D>(self));
2989
2990 let dim = if self < 0 {
2991 (D as i32 + self) as usize
2992 } else {
2993 self as usize
2994 };
2995
2996 let mut set = Vec::with_capacity(1);
2997 set.push(dim);
2998
2999 set
3000 }
3001}
3002
/// Trait used for slice arguments.
///
/// Implemented for any single value convertible into a [`Slice`] (producing one
/// range) and for arrays of such values (one range per leading dimension).
pub trait RangesArg<const D2: usize> {
    /// Converts into a set of ranges to `[Range<usize>; D2]` for the `tensor.slice()` function.
    ///
    /// Each slice is resolved against the size of its dimension in `shape`; for
    /// example, with a dimension of size 8, `-3..-1` resolves to `5..7`.
    fn into_ranges(self, shape: Shape) -> [Range<usize>; D2];
}
3008
3009impl<const D2: usize, T: Into<Slice>> RangesArg<D2> for [T; D2] {
3010 fn into_ranges(self, shape: Shape) -> [Range<usize>; D2] {
3011 // clamp the ranges to the shape dimensions
3012 let ranges = self
3013 .into_iter()
3014 .enumerate()
3015 .map(|(i, range)| range.into().into_range(shape.dims[i]))
3016 .collect::<Vec<_>>();
3017 ranges.try_into().unwrap()
3018 }
3019}
3020
3021impl<T: Into<Slice>> RangesArg<1> for T {
3022 fn into_ranges(self, shape: Shape) -> [Range<usize>; 1] {
3023 [self.into().into_range(shape.dims[0])]
3024 }
3025}
3026
/// Trait used for reshape arguments.
///
/// Implemented for [`Shape`] and `[usize; D2]` (explicit sizes) and for `[i32; D2]`,
/// which additionally accepts `0` (keep a dimension) and `-1` (infer a dimension).
pub trait ReshapeArgs<const D2: usize> {
    /// Converts to a shape.
    ///
    /// `tensor` is the tensor being reshaped. Implementations use it to validate the
    /// target shape and, for `[i32; D2]`, to resolve `0` and `-1` entries.
    fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
        self,
        tensor: &Tensor<B, D, K>,
    ) -> Shape;
}
3035
3036impl<const D2: usize> ReshapeArgs<D2> for Shape {
3037 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3038 self,
3039 tensor: &Tensor<B, D, K>,
3040 ) -> Shape {
3041 check!(TensorCheck::reshape_args_usize::<D, D2>(
3042 &tensor.shape(),
3043 &self
3044 ));
3045
3046 self
3047 }
3048}
3049impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
3050 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3051 self,
3052 tensor: &Tensor<B, D, K>,
3053 ) -> Shape {
3054 let shape = Shape::from(self);
3055
3056 check!(TensorCheck::reshape_args_usize::<D, D2>(
3057 &tensor.shape(),
3058 &shape
3059 ));
3060
3061 shape
3062 }
3063}
3064
3065impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
3066 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3067 self,
3068 tensor: &Tensor<B, D, K>,
3069 ) -> Shape {
3070 // Validate the reshape arguments
3071 check!(TensorCheck::reshape_args_i32(&self));
3072
3073 // Temporary shape
3074 let mut new_shape: [i32; D2] = [1; D2];
3075
3076 // We need to find the index of the 0 dimension and
3077 // replace it with the actual dimension value.
3078 for (i, &s) in self.iter().enumerate() {
3079 if s != 0 {
3080 new_shape[i] = s;
3081 } else {
3082 new_shape[i] = tensor.dims()[i] as i32;
3083 }
3084 }
3085
3086 // Find the index of the inferred dimension (-1)
3087 let infer_index = new_shape.iter().position(|x| x == &-1);
3088
3089 // Handle the case where the dimension is inferred (via -1)
3090 if let Some(index) = infer_index {
3091 // Handle the case where the dimension is inferred
3092 let mut product = 1;
3093 for (i, &s) in new_shape.iter().enumerate() {
3094 if i != index {
3095 product *= s;
3096 }
3097 }
3098 let product_current = tensor.shape().num_elements() as i32;
3099
3100 new_shape[index] = product_current / product;
3101
3102 // Check if the reshape is valid
3103 if product_current % product != 0 {
3104 panic!(
3105 "Cannot reshape tensor of shape {:?} to shape {:?}",
3106 tensor.shape(),
3107 new_shape
3108 );
3109 }
3110 };
3111
3112 // Convert each element to usize
3113 let new_shape: [usize; D2] = new_shape.map(|x| x as usize);
3114
3115 Shape::from(new_shape)
3116 }
3117}
3118
/// Trait used for broadcast arguments.
///
/// Implemented for [`Shape`] and `[usize; D2]` (used as-is) and for arrays of
/// [`Element`] values, where `-1` keeps the size of the matching existing dimension.
pub trait BroadcastArgs<const D1: usize, const D2: usize> {
    /// Converts to a shape.
    ///
    /// `shape` is the current shape; only the `[E; D2]` implementation reads it.
    /// NOTE(review): `D1` is not used by the implementations here. It is presumably
    /// the source rank; confirm at the call site.
    fn into_shape(self, shape: &Shape) -> Shape;
}
3124
3125impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
3126 fn into_shape(self, _shape: &Shape) -> Shape {
3127 self
3128 }
3129}
3130impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
3131 fn into_shape(self, _shape: &Shape) -> Shape {
3132 Shape::from(self)
3133 }
3134}
3135
3136impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
3137 // Passing -1 as the size for a dimension means not changing the size of that dimension.
3138 fn into_shape(self, shape: &Shape) -> Shape {
3139 if self.len() < shape.num_dims() {
3140 panic!("Broadcast arguments must be greater than the number of dimensions");
3141 }
3142
3143 // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3144 let new_shape: Vec<_> = self
3145 .iter()
3146 .rev()
3147 .map(|x| {
3148 let primitive = x.to_i64();
3149 if primitive < -1 || primitive == 0 {
3150 panic!("Broadcast arguments must be positive or -1");
3151 }
3152 primitive
3153 })
3154 .zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3155 .map(|(x, &y)| if x == -1 { y } else { x as usize })
3156 .collect::<Vec<_>>()
3157 .into_iter()
3158 .rev()
3159 .collect();
3160
3161 if new_shape.iter().any(|&x| x == 0) {
3162 panic!("Cannot substitute -1 for a non-existing dimension");
3163 }
3164
3165 let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3166
3167 Shape::from(new_shape)
3168 }
3169}
3170
3171impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3172where
3173 B: Backend,
3174 K: BasicOps<B>,
3175 K::Elem: Debug + Copy + Serialize,
3176{
3177 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3178 let data = self.to_data();
3179 data.serialize(serializer)
3180 }
3181}
3182
3183impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3184where
3185 B: Backend,
3186 K: BasicOps<B>,
3187 K::Elem: Debug + Copy + Deserialize<'de>,
3188{
3189 fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3190 let tensor = Tensor::from_data(
3191 TensorData::deserialize(deserializer)?,
3192 &<B::Device as Default>::default(),
3193 );
3194 Ok(tensor)
3195 }
3196}
3197
#[cfg(test)]
mod tests {
    // Unit tests for the `RangesArg` conversions: every supported slice form is
    // resolved against a concrete `Shape` and compared with the expected ranges.
    use crate::Shape;
    use crate::s;

    use super::*;

    // Single-range arguments (bare ranges and one-element arrays) are resolved against
    // the first dimension (size 8), including negative and unbounded bounds.
    #[test]
    fn slice_range_single_dim_leading() {
        let shape = Shape::new([8, 4]);

        // Half-open range
        assert_eq!([0..5], (0..5).into_ranges(shape.clone()));
        assert_eq!([0..5], [0..5].into_ranges(shape.clone()));
        assert_eq!([5..7], [-3..-1].into_ranges(shape.clone()));

        // Inclusive range
        assert_eq!([0..5], (0..=4).into_ranges(shape.clone()));
        assert_eq!([0..5], [0..=4].into_ranges(shape.clone()));
        assert_eq!([6..8], [-2..=-1].into_ranges(shape.clone()));

        // Unbounded start
        assert_eq!([0..3], (..3).into_ranges(shape.clone()));
        assert_eq!([0..3], [..3].into_ranges(shape.clone()));
        assert_eq!([0..3], [..-5].into_ranges(shape.clone()));

        // Unbounded end
        assert_eq!([5..8], (5..).into_ranges(shape.clone()));
        assert_eq!([5..8], [5..].into_ranges(shape.clone()));
        assert_eq!([5..8], [-3..].into_ranges(shape.clone()));

        // Full range
        assert_eq!([0..8], [..].into_ranges(shape));
    }

    // Arrays of same-typed ranges produce one resolved range per dimension.
    #[test]
    fn slice_range_multi_dim() {
        let shape = Shape::new([8, 4]);

        // Multiple ways to provide ranges
        assert_eq!([0..5, 0..4], [0..5, 0..4].into_ranges(shape.clone()));
        assert_eq!([0..8, 0..4], [0.., 0..].into_ranges(shape.clone()));
        assert_eq!([0..8, 0..4], [0..=7, 0..=3].into_ranges(shape.clone()));

        assert_eq!([0..5, 0..3], [0..5, 0..3].into_ranges(shape.clone()));

        assert_eq!([0..8, 0..4], [0.., 0..].into_ranges(shape));
    }

    // Single integers select a one-element range; negative integers count from the end.
    #[test]
    fn slice_range_multi_dim_index() {
        let shape = Shape::new([8, 4]);

        // Indices (single integer) should also convert to correct range
        assert_eq!([0..1, 2..3], [0, 2].into_ranges(shape.clone()));
        assert_eq!([7..8, 3..4], [-1, -1].into_ranges(shape.clone()));
        assert_eq!([7..8], (-1).into_ranges(shape.clone()));
        assert_eq!([7..8], 7.into_ranges(shape));
    }

    // The `s![]` macro mixes range kinds and indices within one argument list.
    #[test]
    fn slice_range_multi_dim_heterogeneous() {
        // Slice macro `s![]` can be used to provide different range types
        let shape = Shape::new([8, 4, 2]);
        let slice = s![0..5, .., -1];
        assert_eq!([0..5, 0..4, 1..2], slice.into_ranges(shape));

        let shape = Shape::new([8, 4, 2, 3]);
        let slice = s![..=4, 0..=3, .., -2..];
        assert_eq!([0..5, 0..4, 0..2, 1..3], slice.into_ranges(shape));

        let shape = Shape::new([3, 4]);
        let slice = s![1..-1, ..];
        assert_eq!([1..2, 0..4], slice.into_ranges(shape));
    }
}