burn_tensor/tensor/api/base.rs
1#![allow(clippy::single_range_in_vec_init)]
2
3use alloc::vec::Vec;
4
5use alloc::format;
6use alloc::string::String;
7use alloc::vec;
8
9use burn_common::stub::RwLock;
10use core::any::TypeId;
11use core::future::Future;
12use core::iter::repeat;
13use core::{fmt::Debug, ops::Range};
14use serde::{Deserialize, Deserializer};
15
16use serde::{Serialize, Serializer};
17
18use crate::check::TensorCheck;
19use crate::tensor::api::narrow::narrow;
20use crate::{
21 backend::Backend, check, ops::Device, Bool, Float, Int, Shape, TensorData, TensorKind,
22};
23use crate::{DType, Element, TensorPrimitive};
24
25use super::{TensorMetadata, Transaction};
26
/// A tensor with a given backend, shape and data type.
///
/// # Type Parameters
///
/// - `B`: the [`Backend`] that stores and computes on the tensor.
/// - `D`: the number of dimensions (rank) of the tensor; shapes are reported as `[usize; D]`.
/// - `K`: the tensor kind (e.g. [`Float`], [`Int`], [`Bool`]); defaults to [`Float`].
///
/// # Indexing
/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types
/// or [`select`](Tensor::select) for numeric types.
///
/// ## Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
/// use burn_tensor::Int;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///
///     let tensor = Tensor::<B, 2>::from_data(
///         [
///             [3.0, 4.9, 2.0],
///             [2.0, 1.9, 3.0],
///             [6.0, 1.5, 7.0],
///             [3.0, 4.9, 9.0],
///         ],
///         &device,
///     );
///
///     // Slice the tensor to get the second and third rows:
///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
///     // The resulting tensor will have dimensions [2, 3].
///     let slice = tensor.clone().slice([1..3]);
///     println!("{slice}");
///
///     // Slice the tensor to get the first two rows and the first 2 columns:
///     // [[3.0, 4.9], [2.0, 1.9]]
///     // The resulting tensor will have dimensions [2, 2].
///     let slice = tensor.clone().slice([0..2, 0..2]);
///     println!("{slice}");
///
///     // Index the tensor along the dimension 1 to get the elements 0 and 2:
///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
///     // The resulting tensor will have dimensions [4, 2]
///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
///     let indexed = tensor.select(1, indices);
///     println!("{indexed}");
/// }
/// ```
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
    B: Backend,
    K: TensorKind<B>,
{
    /// The kind-specific primitive (`K::Primitive`) that actually holds the tensor;
    /// exposed through `into_primitive` / `from_primitive`.
    pub(crate) primitive: K::Primitive,
}
81
82impl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>
83where
84 B: Backend,
85 K: BasicOps<B>,
86 T: Into<TensorData>,
87{
88 fn from(value: T) -> Self {
89 Tensor::from_data(value.into(), &Default::default())
90 }
91}
92
93impl<B, const D: usize, K> Tensor<B, D, K>
94where
95 B: Backend,
96 K: BasicOps<B>,
97{
98 /// Converts the tensor into a primitive tensor.
99 pub fn into_primitive(self) -> K::Primitive {
100 self.primitive
101 }
102
103 /// Converts from a primitive tensor into a tensor.
104 pub fn from_primitive(tensor: K::Primitive) -> Self {
105 Self::new(tensor)
106 }
107
108 /// Returns the tensor primitive data type.
109 ///
110 /// # Note
111 /// Some element types are encoded in different primitive types depending on the backend
112 /// (e.g., bool could be encoded as `u8` or `u32`).
113 pub fn dtype(&self) -> DType {
114 self.primitive.dtype()
115 }
116
117 /// Create an empty tensor of the given shape.
118 ///
119 /// # Arguments
120 ///
121 /// - shape: The shape of the tensor.
122 /// - device: The device where the tensor will be created.
123 ///
124 /// # Example
125 /// ```rust
126 /// use burn_tensor::backend::Backend;
127 /// use burn_tensor::Tensor;
128 ///
129 /// fn example<B: Backend>() {
130 /// let device = Default::default();
131 /// // Create an empty tensor with dimensions [2, 3, 4].
132 /// let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);
133 /// }
134 /// ```
135 pub fn empty<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
136 let shape = shape.into();
137 check!(TensorCheck::creation_ops::<D>("Empty", &shape.dims));
138 Self::new(K::empty(shape, device))
139 }
140
141 /// Returns the dimensions of the current tensor.
142 ///
143 /// # Example
144 /// ```rust
145 /// use burn_tensor::backend::Backend;
146 /// use burn_tensor::Tensor;
147 ///
148 /// fn example<B: Backend>() {
149 /// let device = Default::default();
150 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
151 /// let dims = tensor.dims(); // [2, 3, 4]
152 /// println!("{dims:?}");
153 /// }
154 /// ```
155 pub fn dims(&self) -> [usize; D] {
156 Self::shape(self).dims()
157 }
158
159 /// Returns the shape of the current tensor.
160 ///
161 /// # Example
162 /// ```rust
163 /// use burn_tensor::backend::Backend;
164 /// use burn_tensor::Tensor;
165 ///
166 /// fn example<B: Backend>() {
167 /// let device = Default::default();
168 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
169 /// // Shape { dims: [2, 3, 4] }
170 /// let shape = tensor.shape();
171 /// }
172 /// ```
173 pub fn shape(&self) -> Shape {
174 self.primitive.shape()
175 }
176
177 /// Reshape the tensor to have the given shape.
178 ///
179 /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`
180 /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].
181 ///
182 /// A `0` in the shape instructs to keep the current dimension from the original tensor,
183 /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].
184 /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`
185 /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor
186 /// with [1, 3, 4] dimensions to [1, 12].
187 ///
188 /// # Arguments
189 /// - `shape`: The new shape of the tensor.
190 ///
191 /// # Panics
192 /// - If the tensor contains more than one `-1` in the shape.
193 /// - If the tensor contains values that are not positive (other than -1).
194 /// - If the shape does not match the number of elements of the original shape.
195 ///
196 /// # Example
197 ///
198 /// ```rust
199 /// use burn_tensor::backend::Backend;
200 /// use burn_tensor::Tensor;
201 ///
202 /// fn example<B: Backend>() {
203 /// let device = Default::default();
204 /// // Create a tensor with dimensions [2, 3, 4]
205 /// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
206 /// // Reshape it to [2, 12], where 12 is inferred from the number of elements.
207 /// let reshaped = tensor.reshape([2, -1]);
208 /// println!("{reshaped}");
209 /// }
210 /// ```
211 pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
212 // Convert reshape args to shape
213 let shape = shape.into_shape(&self);
214 Tensor::new(K::reshape(self.primitive, shape))
215 }
216
217 /// Transpose the tensor.
218 ///
219 /// # Arguments
220 ///
221 /// * `tensor` - The tensor to transpose.
222 ///
223 /// # Returns
224 ///
225 /// The transposed tensor.
226 ///
227 /// # Example
228 ///
229 /// ```rust
230 /// use burn_tensor::backend::Backend;
231 /// use burn_tensor::Tensor;
232 ///
233 /// fn example<B: Backend>() {
234 /// let device = Default::default();
235 /// // Create a 2D tensor of shape [2, 3]
236 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
237 ///
238 /// // Transpose the tensor:
239 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
240 /// // The resulting tensor will have dimensions [3, 2].
241 /// let transposed = tensor.transpose();
242 /// println!("{transposed}");
243 /// }
244 /// ```
245 pub fn transpose(self) -> Tensor<B, D, K> {
246 Tensor::new(K::transpose(self.primitive))
247 }
248
249 /// Swaps two dimensions of a tensor.
250 ///
251 /// # Arguments
252 ///
253 /// * `tensor` - The tensor to swap the dimensions of.
254 /// * `dim1` - The first dimension to swap.
255 /// * `dim2` - The second dimension to swap.
256 ///
257 /// # Returns
258 ///
259 /// The tensor with the dimensions swapped.
260 ///
261 /// # Example
262 ///
263 /// ```rust
264 /// use burn_tensor::backend::Backend;
265 /// use burn_tensor::Tensor;
266 ///
267 /// fn example<B: Backend>() {
268 /// let device = Default::default();
269 /// // Create a 2D tensor of shape [2, 3]
270 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
271 ///
272 /// // Swap the dimensions 0 and 1 (equivalent to `tensor.transpose()`):
273 /// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
274 /// // The resulting tensor will have dimensions [3, 2].
275 /// let swapped = tensor.swap_dims(0, 1);
276 /// println!("{swapped}");
277 /// }
278 /// ```
279 pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor<B, D, K> {
280 check!(TensorCheck::swap_dims::<D>(dim1, dim2));
281 Tensor::new(K::swap_dims(self.primitive, dim1, dim2))
282 }
283
284 /// Permute the dimensions of the tensor.
285 ///
286 /// # Arguments
287 ///
288 /// * `axes` - The new order of the dimensions. The length of the axes
289 /// must be equal to the number of dimensions of the tensor.
290 /// The values must be unique and in the range of the number of dimensions.
291 /// The values can be negative, in which case they are used as an offset from the end.
292 ///
293 /// # Returns
294 ///
295 /// The tensor with the dimensions permuted.
296 ///
297 /// # Example
298 ///
299 /// ```rust
300 /// use burn_tensor::backend::Backend;
301 /// use burn_tensor::Tensor;
302 ///
303 /// fn example<B: Backend>() {
304 /// let device = Default::default();
305 /// // Create a 2D tensor of shape [3, 2]
306 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);
307 ///
308 /// // Permute the dimensions 1 and 0:
309 /// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
310 /// // The resulting tensor will have dimensions [3, 2].
311 /// let permuted = tensor.permute([1, 0]);
312 /// println!("{permuted}");
313 /// }
314 /// ```
315 pub fn permute(self, axes: [isize; D]) -> Tensor<B, D, K> {
316 // Convert the axes to usize and handle negative values without using vector
317 let mut transformed_axes: [usize; D] = [0; D];
318 for (i, &x) in axes.iter().enumerate() {
319 transformed_axes[i] = if x < 0 {
320 (D as isize + x) as usize
321 } else {
322 x as usize
323 };
324 }
325
326 // Check if the axes are valid after the transformation
327 check!(TensorCheck::permute(transformed_axes));
328
329 Tensor::new(K::permute(self.primitive, &transformed_axes))
330 }
331
/// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.
///
/// Other dimensions of input that are not explicitly moved remain in their original order and appear
/// at the positions not specified in destination.
///
/// # Arguments
///
/// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.
///   The values can be negative, in which case they are used as an offset from the end.
///
/// * `dst` - Destination positions for each of the original dims. These must also be unique.
///
/// # Panics
///
/// - If the source and destination dimensions are not of the same length.
/// - If the source and destination vectors contain duplicate values.
/// - If the source and destination vectors contain values that are out of bounds.
///
/// # Returns
///
/// The tensor with the dimensions moved.
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = Default::default();
///     // Create a 3D tensor of shape [3, 2, 1]
///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);
///
///     // Move the dimensions 0 and 1:
///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
///     // The resulting tensor will have dimensions [2, 3, 1].
///     let moved = tensor.movedim(1, 0);
///     println!("{moved}");
/// }
/// ```
// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
// for it
pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {
    // Turn both arguments into lists of `usize` dimension indices (negative values are
    // presumably resolved inside `into_dim_vec` -- see `MovedimArgs`, not visible here).
    let source_dims = src.into_dim_vec::<D>();
    let destination_dims = dst.into_dim_vec::<D>();

    check!(TensorCheck::movedim_args_length(
        &source_dims,
        &destination_dims
    ));

    // m[dest] holds the source dim that must land at output position `dest`;
    // -1 marks an output slot that is not pinned by `dst`.
    let mut m = [-1; D];
    for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {
        m[d] = s as isize;
    }
    // Build the permutation: pinned slots take their source dim, and the remaining slots
    // are filled in increasing order with the dims that do not appear in `src`.
    let mut axes: [isize; D] = [0; D];
    let mut source_i = 0;
    for (dest_i, item) in axes.iter_mut().enumerate().take(D) {
        *item = if m[dest_i] != -1 {
            m[dest_i]
        } else {
            // Skip dims that are already being moved explicitly.
            while source_dims.contains(&source_i) {
                source_i += 1;
            }
            let result = source_i as isize;
            source_i += 1;
            result
        };
    }

    // `permute` runs the permutation validity check on the final axes.
    self.permute(axes)
}
404
405 /// Reverse the order of elements in the tensor along the given dimensions.
406 ///
407 /// # Arguments
408 ///
409 /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
410 /// The values can be negative, in which case they are used as an offset from the end.
411 ///
412 /// # Returns
413 ///
414 /// The tensor with the axes flipped.
415 ///
416 /// # Example
417 ///
418 /// ```rust
419 /// use burn_tensor::backend::Backend;
420 /// use burn_tensor::Tensor;
421 ///
422 /// fn example<B: Backend>() {
423 /// let device = Default::default();
424 /// // Create a 2D tensor with dimensions [4, 3]
425 /// let tensor = Tensor::<B, 2>::from_data(
426 /// [
427 /// [3.0, 4.9, 2.0],
428 /// [2.0, 1.9, 3.0],
429 /// [4.0, 5.9, 8.0],
430 /// [1.4, 5.8, 6.0],
431 /// ],
432 /// &device,
433 /// );
434 ///
435 /// // Flip the elements in dimensions 0 and 1:
436 /// // [[6.0, 5.8, 1.4],
437 /// // [8.0, 5.9, 4.0],
438 /// // [3.0, 1.9, 2.0],
439 /// // [2.0, 4.9, 3.0]]
440 /// // The resulting tensor will have dimensions [4, 3].
441 /// let flipped = tensor.flip([0, 1]);
442 /// println!("{flipped}");
443 /// }
444 /// ```
445 pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
446 // Convert the axes to usize and handle negative values without using vector
447 let mut transformed_axes: [usize; N] = [0; N];
448 for (i, &x) in axes.iter().enumerate() {
449 transformed_axes[i] = if x < 0 {
450 (D as isize + x) as usize
451 } else {
452 x as usize
453 };
454 }
455
456 // Check if the axes are valid
457 check!(TensorCheck::flip(D, &transformed_axes));
458
459 Tensor::new(K::flip(self.primitive, &transformed_axes))
460 }
461
462 /// Flatten the tensor along a given range of dimensions.
463 ///
464 /// This function collapses the specified range of dimensions into a single dimension,
465 /// effectively flattening the tensor in that range.
466 ///
467 /// # Arguments
468 ///
469 /// - `start_dim`: The starting dimension of the range to be flattened.
470 /// - `end_dim`: The ending dimension of the range to be flattened (inclusive).
471 ///
472 /// # Type Parameters
473 ///
474 /// - `D2`: The resulting number of dimensions in the flattened tensor.
475 ///
476 /// # Returns
477 ///
478 /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.
479 ///
480 /// # Example
481 ///
482 /// ```rust
483 ///
484 /// use burn_tensor::backend::Backend;
485 /// use burn_tensor::{Tensor, Shape};
486 ///
487 /// fn example<B: Backend>() {
488 /// let device = Default::default();
489 /// // Create a 3D tensor with dimensions [2, 3, 4]
490 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);
491 ///
492 /// // Flatten the tensor from dimensions 1 to 2 (inclusive).
493 /// // The resulting tensor will have dimensions [2, 12]
494 /// let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
495 /// println!("{flattened}");
496 /// }
497 /// ```
498 pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
499 check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));
500
501 let current_dims = self.shape().dims;
502 let mut new_dims: [usize; D2] = [0; D2];
503 let mut flatten_dims = 1;
504
505 for i in current_dims[start_dim..=end_dim].iter() {
506 flatten_dims *= i;
507 }
508
509 new_dims[..start_dim].copy_from_slice(¤t_dims[..start_dim]);
510 new_dims[start_dim] = flatten_dims;
511 new_dims[start_dim + 1..].copy_from_slice(¤t_dims[end_dim + 1..]);
512
513 Tensor::new(K::reshape(self.primitive, new_dims.into()))
514 }
515
516 /// Squeeze the tensor along the given dimension, removing the specified dimension
517 /// of size one, and effectively reducing the rank of the tensor by one.
518 ///
519 /// # Arguments
520 ///
521 /// - `dim`: The dimension to be squeezed.
522 ///
523 /// # Type Parameters
524 ///
525 /// - 'D2': The resulting number of dimensions in the squeezed tensor.
526 ///
527 /// # Returns
528 ///
529 /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
530 ///
531 /// # Example
532 ///
533 /// ```rust
534 ///
535 /// use burn_tensor::backend::Backend;
536 /// use burn_tensor::{Tensor, Shape};
537 ///
538 /// fn example<B: Backend>() {
539 /// let device = Default::default();
540 /// // Create a 3D tensor with dimensions [3, 1, 3]
541 /// let tensor = Tensor::<B, 3>::from_data(
542 /// [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],
543 /// &device,
544 /// );
545 ///
546 /// // Squeeze the dimension 1.
547 /// // The resulting tensor will have dimensions [3, 3].
548 /// let squeezed = tensor.squeeze::<2>(1);
549 /// println!("{squeezed}");
550 /// }
551 /// ```
552 pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
553 check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));
554
555 let current_dims = self.shape().dims;
556 let mut new_dims: [usize; D2] = [0; D2];
557
558 new_dims[..dim].copy_from_slice(¤t_dims[..dim]);
559 new_dims[dim..].copy_from_slice(¤t_dims[dim + 1..]);
560
561 Tensor::new(K::reshape(self.primitive, new_dims.into()))
562 }
563
564 /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and
565 /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions
566 /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1
567 /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries
568 /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted
569 /// from the back.
570 ///
571 /// # Arguments
572 ///
573 /// - `dims`: The dimension(s) to be squeezed.
574 ///
575 /// # Type Parameters
576 ///
577 /// - 'D2': The resulting number of dimensions in the squeezed tensor.
578 ///
579 /// # Returns
580 ///
581 /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.
582 ///
583 /// # Example
584 ///
585 /// ```rust
586 ///
587 /// use burn_tensor::backend::Backend;
588 /// use burn_tensor::{Tensor, Shape};
589 ///
590 /// fn example<B: Backend>() {
591 /// let device = Default::default();
592 /// // Create a 4D tensor with dimensions [2, 1, 4, 1]
593 /// let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);
594 ///
595 /// // Squeeze the dimensions 1 and 3.
596 /// // The resulting tensor will have dimensions [2, 4].
597 /// let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
598 /// println!("{squeezed}");
599 /// }
600 /// ```
601 pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
602 let current_dims = self.shape().dims;
603 let mut dim_indices: Vec<usize>;
604
605 // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries
606 if dims.is_empty() {
607 dim_indices = current_dims
608 .iter()
609 .enumerate()
610 .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })
611 .collect();
612 } else {
613 // If negative dims, count from the back
614 dim_indices = dims
615 .iter()
616 .map(|&d| {
617 if d < 0 {
618 (current_dims.len() as isize + d) as usize
619 } else {
620 d as usize
621 }
622 })
623 .collect();
624 }
625
626 // Sort indices and remove duplicates
627 dim_indices.sort_unstable();
628 dim_indices.dedup();
629
630 // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions
631 check!(TensorCheck::squeeze_dims_input::<D2>(
632 &dim_indices,
633 ¤t_dims
634 ));
635
636 // Calculate new dimensions
637 let mut new_dims = Vec::new();
638 for (index, &dim_size) in current_dims.iter().enumerate() {
639 // Exclude the dimension if it's explicitly marked for squeezing
640 if dim_indices.contains(&index) {
641 check!(TensorCheck::squeeze::<D2>(index, ¤t_dims));
642 continue;
643 }
644 new_dims.push(dim_size);
645 }
646
647 // Check that after squeezing, we still respect the D2 size
648 check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));
649
650 Tensor::new(K::reshape(self.primitive, new_dims.into()))
651 }
652
653 /// Unsqueeze the current tensor. Create new dimensions to fit the given size.
654 ///
655 /// If the output size is higher than the current tensor.
656 ///
657 /// # Example
658 ///
659 /// ```rust
660 /// use burn_tensor::backend::Backend;
661 /// use burn_tensor::{Tensor, Shape};
662 ///
663 /// fn example<B: Backend>() {
664 /// let device = Default::default();
665 /// // Create a 2D tensor with dimensions [3, 3]
666 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
667 /// // Unsqueeze the tensor up to 4 dimensions.
668 /// // The resulting tensor will have dimensions [1, 1, 3, 3].
669 /// let unsqueezed = tensor.unsqueeze::<4>();
670 /// println!("{unsqueezed}");
671 /// }
672 /// ```
673 pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
674 check!(TensorCheck::unsqueeze::<D, D2>());
675
676 let mut dims = [1; D2];
677 let num_ones = D2 - D;
678 let shape = self.shape();
679
680 dims[num_ones..(D + num_ones)].copy_from_slice(&shape.dims[..D]);
681
682 let shape = Shape::new(dims);
683 self.reshape(shape)
684 }
685
686 /// Creates a new tensor with a dimension of size one inserted at the specified position.
687 ///
688 /// # Example
689 ///
690 /// ```rust
691 /// use burn_tensor::backend::Backend;
692 /// use burn_tensor::{Tensor, Shape};
693 ///
694 /// fn example<B: Backend>() {
695 /// let device = Default::default();
696 /// // Create a 2D tensor with dimensions [3, 3]
697 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);
698 /// // Unsqueeze the dimension 1.
699 /// // The resulting tensor will have dimensions [3, 1, 3].
700 /// let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
701 /// println!("{unsqueezed}");
702 /// }
703 /// ```
704 pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
705 check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));
706
707 let mut dims = [1; D2];
708 let shape = self.shape();
709
710 dims[0..dim].copy_from_slice(&shape.dims[0..dim]);
711
712 if dim < D {
713 dims[dim] = 1;
714 dims[(dim + 1)..].copy_from_slice(&shape.dims[dim..]);
715 } else {
716 dims[dim] = 1;
717 }
718
719 let shape = Shape::new(dims);
720 self.reshape(shape)
721 }
722
723 /// Creates a new tensor with added dimensions of size one inserted at the specified indices.
724 /// The indices can be negative, in which case they are counted from the last to the first dimension.
725 /// the axes can contain duplicates, in which case the number of dimensions inserted at the index
726 /// is the number of duplicates.
727 /// # Example
728 ///
729 /// ```rust
730 /// use burn_tensor::backend::Backend;
731 /// use burn_tensor::{Tensor, Shape};
732 ///
733 /// fn example<B: Backend>() {
734 /// let device = Default::default();
735 /// // Create a 3D tensor with dimensions [3, 4, 5]
736 /// let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);
737 /// // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
738 /// // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
739 /// let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
740 /// println!("{unsqueezed}");
741 /// }
742 /// ```
743 pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
744 let mut new_dims = [1; D2];
745 let old_dims = self.shape().dims;
746 //for checking if the dimension is in the acceptable range
747
748 //part 1: convert the negative indices to positive
749 let mut neg_offset = D2;
750 let mut dim_indices = axes
751 .iter()
752 .map(|d| {
753 // check if the dimension is in the acceptable range
754 check!(TensorCheck::unsqueeze_dims::<{ D2 }>(*d));
755 (if *d < 0 {
756 neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)
757 d + neg_offset as isize + 1
758 } else {
759 *d
760 }) as usize
761 })
762 .collect::<Vec<usize>>();
763
764 //sort the indices
765 dim_indices.sort_unstable();
766
767 //Now use this to copy the chunks of the dims
768 let mut prev_idx: usize = 0;
769 let mut current_left_b: usize = 0;
770 let mut current_right_b: usize = 0;
771 let mut offset: usize = 0;
772 dim_indices.iter().for_each(|d| {
773 //check if there is space for at least one dimension
774 if prev_idx < *d {
775 current_right_b = *d - offset;
776 //copy the chunks of the dims
777 if current_right_b < D {
778 new_dims[prev_idx..*d]
779 .copy_from_slice(&old_dims[current_left_b..current_right_b]);
780 } else {
781 new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
782 }
783 prev_idx = *d + 1;
784 //offset is equal to the number of extracted elements from the original shape
785 offset += current_right_b - current_left_b;
786 current_left_b = current_right_b;
787 } else {
788 //it's sorted so the only reason this would happen
789 //is if multiple indices are the same
790 prev_idx += 1;
791 }
792 });
793 //copy over anything past the index of the last new dimension
794 if current_left_b < D {
795 new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
796 }
797
798 //lastly, create the shape and reshape
799 let shape = Shape::new(new_dims);
800 self.reshape(shape)
801 }
802
803 /// Returns a tensor containing the elements selected from the given ranges.
804 ///
805 /// # Arguments
806 ///
807 /// * `ranges` - A type implementing the `RangesArg` trait, which can be:
808 /// - An array of `core::ops::Range<usize>`
809 /// - An array of `Option<(i64, i64)>`
810 /// - An array of `(i64, i64)` tuples
811 ///
812 /// # Behavior
813 ///
814 /// - Supports partial and full slicing in any number of dimensions.
815 /// - Missing ranges are treated as full slices if D > D2.
816 /// - Handles negative indices by wrapping around from the end of the dimension.
817 /// - Clamps ranges to the tensor's dimensions if they exceed the bounds.
818 /// - For `Option<(i64, i64)>` ranges, `None` selects the full range of that dimension.
819 ///
820 /// # Panics
821 ///
822 /// - If the number of ranges provided exceeds the tensor's dimensions.
823 /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1).
824 ///
825 /// # Examples
826 ///
827 /// ```rust
828 /// use burn_tensor::backend::Backend;
829 /// use burn_tensor::{Tensor, Shape};
830 ///
831 /// fn example<B: Backend>() {
832 /// let device = B::Device::default();
833 ///
834 /// // 1D slicing
835 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
836 /// let slice = tensor.slice([1..4]);
837 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
838 ///
839 /// // 2D slicing
840 /// let tensor = Tensor::<B, 2>::ones(Shape::new([3, 4]), &device);
841 /// let slice = tensor.slice([1..3, 0..2]);
842 /// assert_eq!(slice.dims(), [2, 2]);
843 ///
844 /// // Using negative indices
845 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..5, &device);
846 /// let slice = tensor.slice([(1, -1)]); // Equivalent to 1..4
847 /// assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![1i32, 2, 3]);
848 ///
849 /// // Using Option<(i64, i64)>
850 /// let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..12, &device).reshape([3, 4]);
851 /// let slice = tensor.slice([Some((1, -1)), None]); // Select rows 1 and 2, all columns
852 /// assert_eq!(slice.dims(), [2, 4]);
853 /// }
854 /// ```
855 ///
856 /// # Note
857 ///
858 /// This function uses the `RangesArg` trait for flexible range specification. The trait
859 /// handles the conversion of various range formats and applies clamping and negative
860 /// index handling internally.
861 pub fn slice<const D2: usize, R: RangesArg<D2>>(self, ranges: R) -> Self {
862 let ranges = ranges.into_ranges(self.shape());
863
864 check!(TensorCheck::slice::<D, D2>(&self.shape(), &ranges));
865 Self::new(K::slice(self.primitive, &ranges))
866 }
867
868 /// Returns a copy of the current tensor with the selected elements changed to the new ones at
869 /// the selected indices.
870 ///
871 /// # Panics
872 ///
873 /// - If a range exceeds the number of elements on a dimension.
874 /// - If the given values don't match the given ranges.
875 ///
876 /// # Example
877 ///
878 /// ```rust
879 /// use burn_tensor::backend::Backend;
880 /// use burn_tensor::Tensor;
881 ///
882 /// fn example<B: Backend>() {
883 /// let device = B::Device::default();
884 /// let tensor = Tensor::<B, 3>::ones([2, 3, 3], &device);
885 /// let values = Tensor::<B, 3>::zeros([1, 1, 1], &device);
886 /// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
887 /// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
888 /// }
889 /// ```
890 pub fn slice_assign<const D2: usize>(
891 self,
892 ranges: [core::ops::Range<usize>; D2],
893 values: Self,
894 ) -> Self {
895 check!(TensorCheck::slice_assign::<D, D2>(
896 &self.shape(),
897 &values.shape(),
898 &ranges
899 ));
900 Self::new(K::slice_assign(self.primitive, &ranges, values.primitive))
901 }
902
903 /// Returns the device of the current tensor.
904 pub fn device(&self) -> B::Device {
905 K::device(&self.primitive)
906 }
907
908 /// Returns a new tensor on the given device.
909 pub fn to_device(self, device: &B::Device) -> Self {
910 Self::new(K::to_device(self.primitive, device))
911 }
912
913 /// Converts the data of the current tensor.
914 ///
915 /// # Note
916 ///
917 /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
918 /// tensors at once. This may improve laziness, especially if executed on a different
919 /// thread in native environments.
920 pub fn into_data(self) -> TensorData {
921 crate::try_read_sync(self.into_data_async()).expect(
922 "Failed to read tensor data synchronously.
923 This can happen on platforms that don't support blocking futures like WASM.
924 If possible, try using into_data_async instead.",
925 )
926 }
927
928 /// Converts the data of the current tensor.
929 ///
930 /// # Note
931 ///
932 /// For better performance, prefer using a [Transaction](Transaction) when reading multiple
933 /// tensors at once. This may improve laziness, especially if executed on a different
934 /// thread in native environments.
935 pub fn to_data(&self) -> TensorData {
936 self.clone().into_data()
937 }
938
939 /// Returns the data of the current tensor.
940 pub async fn into_data_async(self) -> TensorData {
941 K::into_data_async(self.primitive).await
942 }
943
944 /// Returns the data of the current tensor.
945 pub async fn to_data_async(&self) -> TensorData {
946 self.clone().into_data_async().await
947 }
948
949 /// Create a tensor from the given data on the given device.
950 pub fn from_data<T>(data: T, device: &B::Device) -> Self
951 where
952 T: Into<TensorData>,
953 {
954 let data = data.into();
955 check!(TensorCheck::creation_ops::<D>(
956 "From Data",
957 data.shape.as_slice()
958 ));
959 Self::new(K::from_data(data, device))
960 }
961
962 /// Repeat the tensor along the given dimension.
963 ///
964 ///
965 /// # Arguments
966 /// - `dim`: The dimension to repeat.
967 /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.
968 ///
969 /// # Returns
970 ///
971 /// A new tensor with the given dimension repeated `times` times.
972 ///
973 /// # Example
974 ///
975 /// ```rust
976 /// use burn_tensor::backend::Backend;
977 /// use burn_tensor::Tensor;
978 ///
979 /// fn example<B: Backend>() {
980 /// let device = Default::default();
981 /// // Create a 2D tensor with dimensions [3, 2]
982 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
983 ///
984 /// // Repeat the tensor along the dimension 0 twice.
985 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
986 /// // The resulting tensor will have dimensions [6, 2].
987 /// let repeated = tensor.repeat_dim(0, 2);
988 /// println!("{repeated}");
989 /// }
990 /// ```
991 pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
992 Self::new(K::repeat_dim(self.primitive, dim, times))
993 }
994
995 /// Repeat the tensor along the given dimensions.
996 /// # Arguments
997 /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
998 ///
999 /// # Returns
1000 ///
1001 /// A new tensor with the given dimensions repeated `times` times.
1002 ///
1003 /// # Example
1004 ///
1005 /// ```rust
1006 ///
1007 /// use burn_tensor::backend::Backend;
1008 /// use burn_tensor::Tensor;
1009 ///
1010 /// fn example<B: Backend>() {
1011 /// let device = Default::default();
1012 /// // Create a 2D tensor with dimensions [3, 2]
1013 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1014 ///
1015 /// // Repeat the tensor along the dimension 0 twice and the dimension 0 once.
1016 /// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
1017 /// // The resulting tensor will have dimensions [6, 2].
1018 /// let repeated = tensor.repeat(&[2, 1]);
1019 /// }
1020 /// ```
1021 pub fn repeat(self, sizes: &[usize]) -> Self {
1022 let mut tensor = self;
1023 for (dim, ×) in sizes.iter().enumerate() {
1024 if times > 1 {
1025 tensor = tensor.repeat_dim(dim, times);
1026 }
1027 }
1028 tensor
1029 }
1030
1031 /// Applies element-wise equal comparison and returns a boolean tensor.
1032 ///
1033 /// # Panics
1034 ///
1035 /// If the two tensors don't have the same shape.
1036 ///
1037 /// # Example
1038 ///
1039 /// ```rust
1040 /// use burn_tensor::backend::Backend;
1041 /// use burn_tensor::Tensor;
1042 ///
1043 /// fn example<B: Backend>() {
1044 /// let device = Default::default();
1045 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1046 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1047 /// // Compare the elements of the two 2D tensors with dimensions [3, 2].
1048 /// // [[false, true], [true, true], [true, true]]
1049 /// let equal = t1.equal(t2);
1050 /// println!("{equal}");
1051 /// }
1052 /// ```
1053 pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
1054 check!(TensorCheck::binary_ops_ew("Equal", &self, &other));
1055 Tensor::new(K::equal(self.primitive, other.primitive))
1056 }
1057
1058 /// Applies element-wise non-equality comparison and returns a boolean tensor.
1059 ///
1060 /// # Panics
1061 ///
1062 /// If the two tensors don't have the same shape.
1063 ///
1064 /// # Example
1065 ///
1066 /// ```rust
1067 /// use burn_tensor::backend::Backend;
1068 /// use burn_tensor::Tensor;
1069 ///
1070 /// fn example<B: Backend>() {
1071 /// let device = Default::default();
1072 /// let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1073 /// let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);
1074 /// // Compare the elements of the two 2D tensors for inequality.
1075 /// // [[true, false], [false, false], [false, false]]
1076 /// let not_equal = t1.not_equal(t2);
1077 /// println!("{not_equal}");
1078 /// }
1079 /// ```
1080 pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
1081 check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other));
1082 Tensor::new(K::not_equal(self.primitive, other.primitive))
1083 }
1084
1085 /// Concatenates all tensors into a new one along the given dimension.
1086 ///
1087 /// # Panics
1088 ///
1089 /// If all tensors don't have the same shape.
1090 ///
1091 /// # Example
1092 ///
1093 /// ```rust
1094 /// use burn_tensor::backend::Backend;
1095 /// use burn_tensor::Tensor;
1096 ///
1097 /// fn example<B: Backend>() {
1098 /// let device = Default::default();
1099 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1100 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1101 ///
1102 /// // Concatenate the two tensors with shape [2, 3] along the dimension 1.
1103 /// // [[3.0, 4.9, 2.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.4, 5.8, 6.0]]
1104 /// // The resulting tensor will have shape [2, 6].
1105 /// let concat = Tensor::cat(vec![t1, t2], 1);
1106 /// println!("{concat}");
1107 /// }
1108 /// ```
1109 pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
1110 check!(TensorCheck::cat(&tensors, dim));
1111
1112 Self::new(K::cat(
1113 tensors.into_iter().map(|vector| vector.primitive).collect(),
1114 dim,
1115 ))
1116 }
1117
1118 /// Concatenates all tensors into a new one along a new dimension.
1119 ///
1120 /// # Panics
1121 ///
1122 /// If all tensors don't have the same shape.
1123 /// Given dimension is not with range of 0..D2
1124 ///
1125 /// # Example
1126 ///
1127 /// ```rust
1128 /// use burn_tensor::backend::Backend;
1129 /// use burn_tensor::Tensor;
1130 ///
1131 /// fn example<B: Backend>() {
1132 /// let device = Default::default();
1133 /// let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1134 /// let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1135 /// let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);
1136 ///
1137 /// // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.
1138 /// // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],
1139 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],
1140 /// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
1141 /// // The resulting tensor will have shape [3, 2, 3].
1142 /// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
1143 /// println!("{stacked}");
1144 /// }
1145 /// ```
1146 pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
1147 check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));
1148 let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();
1149 Tensor::<B, D2, K>::cat(tensors, dim)
1150 }
1151
1152 /// Iterate over slices of tensors alongside a given dimension.
1153 ///
1154 /// # Panics
1155 ///
1156 /// Given dimension is less than tensor rank.
1157 ///
1158 /// # Returns
1159 ///
1160 /// A tensor iterator.
1161 ///
1162 /// # Example
1163 ///
1164 /// ```rust
1165 /// use burn_tensor::backend::Backend;
1166 /// use burn_tensor::Tensor;
1167 /// fn example<B: Backend>() {
1168 /// let device = Default::default();
1169 /// let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);
1170 /// // Given a 2D tensor with dimensions (2, 3), iterate over slices of tensors along the dimension 0.
1171 /// let iter = tensor.iter_dim(0);
1172 /// for (i,tensor) in iter.enumerate() {
1173 /// println!("Tensor {}: {}", i, tensor);
1174 /// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
1175 /// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
1176 /// }
1177 /// }
1178 /// ```
1179 pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {
1180 check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
1181 DimIter::new(self, dim)
1182 }
1183
1184 /// Returns a new tensor with the given dimension narrowed to the given range.
1185 ///
1186 /// # Panics
1187 ///
1188 /// - If the dimension is greater than the number of dimensions of the tensor.
1189 /// - If the given range exceeds the number of elements on the given dimension.
1190 ///
1191 /// # Returns
1192 ///
1193 /// A new tensor with the given dimension narrowed to the given range.
1194 ///
1195 /// # Example
1196 ///
1197 /// ```rust
1198 /// use burn_tensor::backend::Backend;
1199 /// use burn_tensor::Tensor;
1200 ///
1201 /// fn example<B: Backend>() {
1202 /// let device = Default::default();
1203 /// // Create a 2D tensor with dimensions [4, 3]
1204 /// let tensor = Tensor::<B, 2>::from_data(
1205 /// [
1206 /// [3.0, 4.9, 2.0],
1207 /// [2.0, 1.9, 3.0],
1208 /// [6.0, 1.5, 7.0],
1209 /// [3.0, 4.9, 9.0],
1210 /// ],
1211 /// &device,
1212 /// );
1213 /// // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.
1214 /// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
1215 /// // The resulting tensor will have dimensions [3, 3].
1216 /// let narrowed = tensor.narrow(0, 1, 3);
1217 /// println!("{narrowed}");
1218 /// }
1219 /// ```
1220 pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
1221 check!(TensorCheck::dim_ops::<D>("narrow", dim));
1222 check!(TensorCheck::narrow(&self, dim, start, length));
1223 Self::new(narrow::<B, K>(self.primitive, dim, start, length))
1224 }
1225
1226 /// Attempts to split the tensor into a specified number of chunks along a given dimension.
1227 /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
1228 ///
1229 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
1230 /// Otherwise all chunks will be of equal size except for the last one.
1231 ///
1232 /// # Panics
1233 ///
1234 /// If the dimension is greater than the number of dimensions of the tensor.
1235 ///
1236 /// # Returns
1237 /// A vector of tensors.
1238 ///
1239 /// # Example
1240 ///
1241 /// ```rust
1242 /// use burn_tensor::backend::Backend;
1243 /// use burn_tensor::Tensor;
1244 ///
1245 /// fn example<B: Backend>() {
1246 /// let device = Default::default();
1247 /// // Create a 2D tensor with dimensions [4, 3]
1248 /// let tensor = Tensor::<B, 2>::from_data(
1249 /// [
1250 /// [3.0, 4.9, 2.0],
1251 /// [2.0, 1.9, 3.0],
1252 /// [6.0, 1.5, 7.0],
1253 /// [3.0, 4.9, 9.0],
1254 /// ],
1255 /// &device,
1256 /// );
1257 /// // Split the tensor along the dimension 1 into 2 chunks.
1258 /// // The first chuck will have shape [4, 2]:
1259 /// // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]
1260 /// // The second chunk will have shape [4, 1]:
1261 /// // [[2.0], [3.0], [7.0], [9.0]]
1262 /// let chunks = tensor.chunk(2, 1);
1263 /// println!("{chunks:?}");
1264 /// }
1265 /// ```
1266 pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
1267 check!(TensorCheck::dim_ops::<D>("chunk", dim));
1268 K::chunk(self.primitive, chunks, dim)
1269 .into_iter()
1270 .map(Self::new)
1271 .collect()
1272 }
1273
1274 /// Splits the tensor into chunks of a specified size along a given dimension.
1275 /// Each chunk is a view of the original tensor.
1276 ///
1277 /// If the tensor size along the given dimension is not divisible by `split_size`,
1278 /// then the last chunk will be smaller.
1279 ///
1280 /// # Panics
1281 ///
1282 /// If the specified dimension to split along is greater than the number of dimensions of the tensor.
1283 ///
1284 /// # Returns
1285 ///
1286 /// A vector of tensors.
1287 ///
1288 /// # Example
1289 /// ```rust
1290 /// use burn_tensor::backend::Backend;
1291 /// use burn_tensor::Tensor;
1292 ///
1293 /// fn example<B: Backend>() {
1294 /// let device = Default::default();
1295 /// // Create a 1D tensor with 5 elements
1296 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1297 /// // Split the tensor into chunks of size 2 along dimension 0
1298 /// let chunks = tensor.split(2, 0);
1299 /// // The result is a vector of tensors:
1300 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]
1301 /// println!("{:?}", chunks);
1302 /// }
1303 /// ```
1304 pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {
1305 check!(TensorCheck::split::<D>(
1306 self.shape().dims.as_ref(),
1307 split_size,
1308 dim
1309 ));
1310 K::split(self.primitive, split_size, dim)
1311 .into_iter()
1312 .map(Self::new)
1313 .collect()
1314 }
1315
1316 /// Splits the tensor into chunks with the specified sizes along a given dimension.
1317 /// Each chunk is a view of the original tensor.
1318 ///
1319 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
1320 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
1321 ///
1322 /// # Panics
1323 ///
1324 /// If the specified dimension to split along is greater than the number of dimensions of the tensor or
1325 /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
1326 ///
1327 /// # Returns
1328 ///
1329 /// A vector of tensors.
1330 ///
1331 /// # Example
1332 /// ```rust
1333 /// use burn_tensor::backend::Backend;
1334 /// use burn_tensor::Tensor;
1335 ///
1336 /// fn example<B: Backend>() {
1337 /// let device = Default::default();
1338 /// // Create a 1D tensor with 5 elements
1339 /// let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);
1340 /// // Split the tensor into chunks with sizes [2, 3] along dimension 0
1341 /// let chunks = tensor.split_with_sizes(vec![2, 3], 0);
1342 /// // The result is a vector of tensors:
1343 /// // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]
1344 /// println!("{:?}", chunks);
1345 /// }
1346 /// ```
1347 pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {
1348 check!(TensorCheck::split_with_sizes::<D>(
1349 self.shape().dims.as_ref(),
1350 &split_sizes,
1351 dim
1352 ));
1353 K::split_with_sizes(self.primitive, split_sizes, dim)
1354 .into_iter()
1355 .map(Self::new)
1356 .collect()
1357 }
1358
1359 /// Tests if any element in the `tensor` evaluates to True.
1360 ///
1361 /// # Arguments
1362 ///
1363 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1364 ///
1365 /// # Returns
1366 ///
1367 /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor
1368 /// evaluates to True, False otherwise.
1369 ///
1370 /// # Example
1371 ///
1372 /// ```rust
1373 /// use burn_tensor::backend::Backend;
1374 /// use burn_tensor::{Tensor, Bool};
1375 ///
1376 /// fn example<B: Backend>() {
1377 /// let device = Default::default();
1378 /// let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);
1379 /// let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);
1380 ///
1381 /// // Given a 2D tensor with dimensions (2, 3), test if any element in the tensor evaluates to True.
1382 /// let any_tensor = tensor.any();
1383 /// println!("{}", any_tensor);
1384 /// // Tensor { data: [true], ... }
1385 ///
1386 /// // Given a 2D tensor with dimensions (2, 3), test if any element in the tensor evaluates to True.
1387 /// let any_tensor_two = tensor_two.any();
1388 /// println!("{}", any_tensor_two);
1389 /// // Tensor { data: [false], ... }
1390 /// }
1391 /// ```
1392 pub fn any(self) -> Tensor<B, 1, Bool> {
1393 Tensor::new(K::any(self.primitive))
1394 }
1395
1396 /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.
1397 ///
1398 /// # Arguments
1399 ///
1400 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1401 /// * `dim` - The axis along which to test.
1402 ///
1403 /// # Returns
1404 ///
1405 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1406 /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input
1407 /// evaluates to True, False otherwise.
1408 ///
1409 /// # Example
1410 ///
1411 /// ```rust
1412 /// use burn_tensor::backend::Backend;
1413 /// use burn_tensor::{Tensor, Bool};
1414 ///
1415 /// fn example<B: Backend>() {
1416 /// let device = Default::default();
1417 /// let tensor =
1418 /// Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);
1419 /// // Check if any element in the tensor evaluates to True along the dimension 1.
1420 /// // [[true], [true]],
1421 /// let any_dim = tensor.clone().any_dim(1);
1422 /// println!("{any_dim}");
1423 /// }
1424 /// ```
1425 pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1426 Tensor::new(K::any_dim(self.primitive, dim))
1427 }
1428
1429 /// Tests if all elements in the `tensor` evaluate to True.
1430 ///
1431 /// # Arguments
1432 ///
1433 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1434 ///
1435 /// # Returns
1436 ///
1437 /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor
1438 /// evaluate to True, False otherwise.
1439 ///
1440 /// # Example
1441 ///
1442 /// ```rust
1443 /// use burn_tensor::backend::Backend;
1444 /// use burn_tensor::{Tensor, Bool};
1445 ///
1446 /// fn example<B: Backend>() {
1447 /// let device = Default::default();
1448 /// let tensor =
1449 /// Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);
1450 /// // Check if all elements in the tensor evaluate to True (which is not the case).
1451 /// // [false]
1452 /// let all = tensor.all();
1453 /// println!("{all}");
1454 /// }
1455 /// ```
1456 pub fn all(self) -> Tensor<B, 1, Bool> {
1457 Tensor::new(K::all(self.primitive))
1458 }
1459
1460 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
1461 ///
1462 /// # Arguments
1463 ///
1464 /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.
1465 /// * `dim` - The axis along which to test.
1466 ///
1467 /// # Returns
1468 ///
1469 /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis
1470 /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input
1471 /// evaluates to True, False otherwise.
1472 ///
1473 /// # Example
1474 ///
1475 /// ```rust
1476 /// use burn_tensor::backend::Backend;
1477 /// use burn_tensor::{Tensor, Bool};
1478 ///
1479 /// fn example<B: Backend>() {
1480 /// let device = Default::default();
1481 /// let tensor =
1482 /// Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);
1483 /// // Check if all elements in the tensor evaluate to True along the dimension 1.
1484 /// // [[true, true, false]]
1485 /// let all_dim = tensor.clone().all_dim(0);
1486 /// println!("{all_dim}");
1487 /// }
1488 /// ```
1489 pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
1490 Tensor::new(K::all_dim(self.primitive, dim))
1491 }
1492
1493 /// Convert the tensor into a scalar.
1494 ///
1495 /// # Panics
1496 ///
1497 /// If the tensor doesn't have one element.
1498 /// If the backend fails to read the tensor data synchronously.
1499 ///
1500 /// # Returns
1501 ///
1502 /// The scalar value of the tensor.
1503 ///
1504 /// # Example
1505 ///
1506 /// ```rust
1507 /// use burn_tensor::backend::Backend;
1508 /// use burn_tensor::Tensor;
1509 ///
1510 /// fn example<B: Backend>() {
1511 /// let device = Default::default();
1512 /// let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
1513 /// // Convert the tensor with a single element into a scalar.
1514 /// let scalar = tensor.into_scalar();
1515 /// println!("{scalar}");
1516 /// }
1517 /// ```
1518 pub fn into_scalar(self) -> K::Elem {
1519 crate::try_read_sync(self.into_scalar_async()).expect(
1520 "Failed to read tensor data synchronously. This can happen on platforms
1521 that don't support blocking futures like WASM. Try into_scalar_async instead.",
1522 )
1523 }
1524
1525 /// Convert the tensor into a scalar.
1526 ///
1527 /// # Panics
1528 ///
1529 /// If the tensor doesn't have one element.
1530 pub async fn into_scalar_async(self) -> K::Elem {
1531 check!(TensorCheck::into_scalar::<D>(&self.shape()));
1532 let x = self.into_data_async().await.iter().next().unwrap();
1533 x
1534 }
1535
1536 /// Broadcast the tensor to the given shape.
1537 ///
1538 /// # Arguments
1539 ///
1540 /// * `shape` - The shape to broadcast the tensor to.
1541 /// Can contain -1 for dimensions that should be inferred.
1542 /// The number of elements in the shape must be greater or equal as
1543 /// the number of dimensions of the tensor.
1544 ///
1545 /// # Panics
1546 ///
1547 /// If the tensor cannot be broadcasted to the given shape.
1548 ///
1549 /// # Returns
1550 ///
1551 /// A new tensor with the given shape.
1552 ///
1553 /// # Example
1554 ///
1555 /// ```rust
1556 /// use burn_tensor::backend::Backend;
1557 /// use burn_tensor::Tensor;
1558 ///
1559 /// fn example<B: Backend>() {
1560 /// let device = Default::default();
1561 /// // Create a 2D tensor with dimensions [3, 1]
1562 /// let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);
1563 /// // Expand the tensor to a new shape [3, 4]
1564 /// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
1565 /// let expanded = tensor.expand([3, 4]);
1566 /// println!("{}", expanded);
1567 /// }
1568 /// ```
1569 pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
1570 let shape = shape.into_shape(&self.shape());
1571 check!(TensorCheck::expand::<D, D2>(
1572 "expand",
1573 &self.shape(),
1574 &shape,
1575 ));
1576
1577 Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))
1578 }
1579}
1580
1581/// Iterator given by (Tensor::iter_dim).
1582pub struct DimIter<B, const D: usize, K>
1583where
1584 B: Backend,
1585 K: BasicOps<B>,
1586{
1587 start: usize,
1588 end: usize,
1589 dim: usize,
1590 ranges: [Range<usize>; D],
1591 tensor: Tensor<B, D, K>,
1592}
1593
1594impl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {
1595 type Item = Tensor<B, D, K>;
1596
1597 fn next(&mut self) -> Option<Self::Item> {
1598 if self.start >= self.end {
1599 return None;
1600 }
1601
1602 let mut ranges = self.ranges.clone();
1603 ranges[self.dim] = self.start..(self.start + 1);
1604
1605 let slice = self.tensor.clone().slice(ranges);
1606 self.start += 1;
1607
1608 Some(slice)
1609 }
1610}
1611
1612impl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {
1613 fn next_back(&mut self) -> Option<Self::Item> {
1614 if self.start >= self.end {
1615 return None;
1616 }
1617
1618 let mut ranges = self.ranges.clone();
1619 ranges[self.dim] = (self.end - 1)..self.end;
1620
1621 let slice = self.tensor.clone().slice(ranges);
1622 self.end = self.end.saturating_sub(1);
1623
1624 Some(slice)
1625 }
1626}
1627
1628impl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {
1629 fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {
1630 let dims = tensor.dims();
1631 let ranges = dims
1632 .iter()
1633 .map(|&dim| 0..dim)
1634 .collect::<Vec<Range<usize>>>();
1635 let ranges: [Range<usize>; D] = ranges.try_into().unwrap();
1636 Self {
1637 end: dims[dim],
1638 ranges,
1639 start: 0,
1640 dim,
1641 tensor,
1642 }
1643 }
1644}
1645
impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: BasicOps<B>,
    <K as BasicOps<B>>::Elem: Debug,
{
    /// Appends a newline to `acc`, followed by `indent` spaces.
    #[inline]
    fn push_newline_indent(acc: &mut String, indent: usize) {
        acc.push('\n');
        for _ in 0..indent {
            acc.push(' ');
        }
    }
    /// Formats the elements with indices `range.0..range.1` along the innermost dimension
    /// `depth` and appends them to `acc`, separated by `", "`.
    ///
    /// The indices of all outer dimensions are taken from `multi_index`. Each element is read
    /// individually: a single-element slice is taken and its data read with `try_read_sync`.
    /// When that yields nothing, the placeholder `<Tensor data not available>` is written
    /// instead. As a result, printing `n` elements issues `n` separate reads.
    ///
    /// `precision` is honored only when `K::name()` is `"Float"`. Every other case uses the
    /// element's `Debug` formatting.
    fn fmt_inner_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        range: (usize, usize),
        precision: Option<usize>,
    ) {
        let (start, end) = range;
        for i in start..end {
            // The separator test is `i > 0`, not `i > start`. When the trailing edge of a
            // summarized row is printed, the caller has already written ", ..." and relies on
            // this branch to insert ", " before the first trailing element.
            if i > 0 {
                acc.push_str(", ");
            }
            multi_index[depth] = i;
            // One-element range per dimension. Inside the closure, `i` is the dimension index;
            // it shadows the loop's element index.
            let range: [core::ops::Range<usize>; D] =
                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);

            let data =
                burn_common::reader::try_read_sync(self.clone().slice(range).into_data_async());

            if let Some(data) = data {
                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();
                match (precision, K::name()) {
                    (Some(p), "Float") => acc.push_str(&format!("{:.1$}", elem, p)),
                    _ => acc.push_str(&format!("{:?}", elem)),
                }
            } else {
                acc.push_str("<Tensor data not available>");
            }
        }
    }

    /// Formats the sub-tensors with indices `range.0..range.1` along the (non-innermost)
    /// dimension `depth` and appends them to `acc`.
    ///
    /// Each sub-tensor is wrapped in `[` ... `]` and rendered by recursing into
    /// [`display_recursive`](Self::display_recursive) at `depth + 1`. Consecutive sub-tensors
    /// within the range are separated by `,` plus a newline indented by `depth + 1` spaces.
    /// Unlike the inner formatter, no separator precedes the first item of the range (`i > start`).
    fn fmt_outer_tensor(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
        range: (usize, usize),
    ) {
        let (start, end) = range;
        for i in start..end {
            if i > start {
                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
            }
            acc.push('[');
            multi_index[depth] = i;
            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);
            acc.push(']');
        }
    }

    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.
    ///
    /// This function is designed to work with tensors of any dimensionality.
    /// It traverses the tensor dimensions recursively, converting the elements
    /// to strings and appending them to the accumulator string with the
    /// appropriate formatting.
    ///
    /// # Arguments
    ///
    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.
    /// * `depth` - The current depth of the tensor dimensions being processed.
    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.
    /// * `print_options` - Supplies `edge_items` (elements kept at each end when summarizing) and
    ///   `precision` (forwarded to the innermost formatter).
    /// * `summarize` - When `true`, any dimension longer than `2 * edge_items` is printed as its
    ///   first and last `edge_items` entries with `...` in between.
    ///
    /// The outermost call (`depth == 0`) also writes the enclosing `[` and `]`.
    fn display_recursive(
        &self,
        acc: &mut String,
        depth: usize,
        multi_index: &mut [usize],
        print_options: &PrintOptions,
        summarize: bool,
    ) {
        let edge_items = print_options.edge_items;

        if depth == 0 {
            acc.push('[');
        }

        if depth == self.dims().len() - 1 {
            // if we are at the innermost dimension, just push its elements into the accumulator
            // NOTE(review): with `edge_items == 0` a summarized row renders as ", ..." with
            // nothing before it — confirm whether that configuration should be supported.
            if summarize && self.dims()[depth] > 2 * edge_items {
                // print the starting `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, edge_items),
                    print_options.precision,
                );
                acc.push_str(", ...");
                // print the last `edge_items` elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                    print_options.precision,
                );
            } else {
                // print all the elements
                self.fmt_inner_tensor(
                    acc,
                    depth,
                    multi_index,
                    (0, self.dims()[depth]),
                    print_options.precision,
                );
            }
        } else {
            // otherwise, iterate through the current dimension and recursively display the inner tensors
            if summarize && self.dims()[depth] > 2 * edge_items {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, edge_items),
                );

                // The outer formatter never emits a separator before the first item of a range,
                // so the ",\n...\n" between the head and the tail is written here.
                acc.push(',');
                Self::push_newline_indent(acc, depth + 1);
                acc.push_str("...");
                Self::push_newline_indent(acc, depth + 1);

                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (self.dims()[depth] - edge_items, self.dims()[depth]),
                );
            } else {
                self.fmt_outer_tensor(
                    acc,
                    depth,
                    multi_index,
                    print_options,
                    summarize,
                    (0, self.dims()[depth]),
                );
            }
        }

        if depth == 0 {
            acc.push(']');
        }
    }
}
1811
/// Options for Tensor pretty printing, consumed by the tensor `Display` implementation.
#[derive(Clone, Debug)]
pub struct PrintOptions {
    /// Element count above which the printed tensor is summarized with `...`.
    pub threshold: usize,

    /// Number of leading and of trailing entries kept per dimension when summarizing.
    pub edge_items: usize,

    /// Precision for floating point numbers; `None` falls back to `Debug` formatting.
    pub precision: Option<usize>,
}
1824
/// Global print options. The tensor `Display` implementation reads them and
/// [`set_print_options`] replaces them; they start as [`PrintOptions::const_default`], which is
/// a `const fn` so the static can be initialized at compile time.
static PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());
1826
1827impl PrintOptions {
1828 /// Print options with default values
1829 pub const fn const_default() -> Self {
1830 Self {
1831 threshold: 1000,
1832 edge_items: 3,
1833 precision: None,
1834 }
1835 }
1836}
1837
1838impl Default for PrintOptions {
1839 fn default() -> Self {
1840 Self::const_default()
1841 }
1842}
1843
1844/// Set print options
1845pub fn set_print_options(options: PrintOptions) {
1846 let mut print_opts = PRINT_OPTS.write().unwrap();
1847 *print_opts = options;
1848}
1849
1850/// Pretty print tensors
1851impl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>
1852where
1853 B: Backend,
1854 B::IntElem: core::fmt::Display,
1855 K: BasicOps<B>,
1856 <K as BasicOps<B>>::Elem: Debug,
1857{
1858 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1859 writeln!(f, "Tensor {{")?;
1860
1861 {
1862 // Do not lock the mutex for the whole function
1863 let mut po = { PRINT_OPTS.read().unwrap().clone() };
1864
1865 // Override the precision if it is set from the formatter
1866 // This will be possible when the tensor is printed using the `{:.*}` syntax
1867 if let Some(precision) = f.precision() {
1868 po.precision = Some(precision);
1869 }
1870
1871 let mut acc = String::new();
1872 let mut multi_index = vec![0; D];
1873 let summarize = self.shape().num_elements() > po.threshold;
1874
1875 self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);
1876
1877 writeln!(f, " data:")?;
1878 write!(f, "{acc}")?;
1879 writeln!(f, ",")?;
1880 }
1881
1882 writeln!(f, " shape: {:?},", self.dims())?;
1883 writeln!(f, " device: {:?},", self.device())?;
1884 writeln!(f, " backend: {:?},", B::name())?;
1885 writeln!(f, " kind: {:?},", K::name())?;
1886
1887 // Bool tensors might be encoded in a different type, which we abstract for the display
1888 let dtype = if TypeId::of::<K::Elem>() == TypeId::of::<bool>() {
1889 DType::Bool
1890 } else {
1891 self.primitive.dtype()
1892 };
1893
1894 writeln!(f, " dtype: {:?},", dtype.name())?;
1895 write!(f, "}}")
1896 }
1897}
1898
1899/// Transpose marker (zero-size type). Used to sugar the transpose of a tensor, e.g.
1900/// ```rust
1901/// use burn_tensor::backend::Backend;
1902/// use burn_tensor::{Tensor, T};
1903///
1904/// fn example<B: Backend>() {
1905/// let device = Default::default();
1906/// let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
1907/// let transposed = tensor^T;
1908/// }
1909/// ```
1910pub struct T;
1911
1912impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
1913 type Output = Self;
1914 fn bitxor(self, _: T) -> Self::Output {
1915 self.transpose()
1916 }
1917}
1918
1919/// Trait that list all operations that can be applied on all tensors.
1920///
1921/// # Warnings
1922///
1923/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
1924pub trait BasicOps<B: Backend>: TensorKind<B> {
1925 /// The type of the tensor elements.
1926 type Elem: Element;
1927
1928 /// Creates an empty tensor with the given shape.
1929 ///
1930 /// # Arguments
1931 ///
1932 /// * `shape` - The shape of the tensor.
1933 /// * `device` - The device on which the tensor will be allocated.
1934 ///
1935 /// # Returns
1936 ///
1937 /// The empty tensor.
1938 ///
1939 /// # Remarks
1940 ///
1941 /// This is a low-level function used internally by the library to call different backend functions
1942 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1943 /// or use this function directly.
1944 ///
1945 /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
1946 /// which is more high-level and designed for public use.
1947 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive;
1948
1949 /// Reshapes the tensor.
1950 ///
1951 /// # Arguments
1952 ///
1953 /// * `tensor` - The tensor.
1954 /// * `shape` - The new shape of the tensor.
1955 ///
1956 /// # Returns
1957 ///
1958 /// The reshaped tensor.
1959 ///
1960 /// # Remarks
1961 ///
1962 /// This is a low-level function used internally by the library to call different backend functions
1963 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
1964 /// or use this function directly.
1965 ///
1966 /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
1967 /// which is more high-level and designed for public use.
1968 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
1969
1970 /// Transposes a tensor.
1971 ///
1972 /// # Arguments
1973 ///
1974 /// * `tensor` - The tensor to transpose.
1975 ///
1976 /// # Returns
1977 ///
1978 /// The transposed tensor.
1979 fn transpose(tensor: Self::Primitive) -> Self::Primitive;
1980
1981 /// Swaps two dimensions of a tensor.
1982 ///
1983 /// # Arguments
1984 ///
1985 /// * `tensor` - The tensor to swap the dimensions of.
1986 /// * `dim1` - The first dimension to swap.
1987 /// * `dim2` - The second dimension to swap.
1988 ///
1989 /// # Returns
1990 ///
1991 /// The tensor with the dimensions swapped.
1992 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;
1993
1994 /// Permutes the dimensions of a tensor.
1995 ///
1996 /// # Arguments
1997 ///
1998 /// * `tensor` - The tensor to permute the dimensions of.
1999 /// * `axes` - The new order of the dimensions.
2000 ///
2001 /// # Returns
2002 ///
2003 /// The tensor with the dimensions permuted.
2004 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2005
2006 /// Flips the tensor along the given axes.
2007 ///
2008 /// # Arguments
2009 ///
2010 /// * `tensor` - The tensor to flip.
2011 /// * `axes` - The axes to flip the tensor along.
2012 ///
2013 /// # Returns
2014 ///
2015 /// The tensor with the axes flipped.
2016 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;
2017
2018 /// Select tensor elements corresponding for the given ranges.
2019 ///
2020 /// # Arguments
2021 ///
2022 /// * `tensor` - The tensor.
2023 /// * `ranges` - The ranges of the elements to select.
2024 ///
2025 /// # Returns
2026 ///
2027 /// The selected elements.
2028 ///
2029 /// # Remarks
2030 ///
2031 /// This is a low-level function used internally by the library to call different backend functions
2032 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2033 /// or use this function directly.
2034 ///
2035 /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
2036 /// which is more high-level and designed for public use.
2037 fn slice(tensor: Self::Primitive, range: &[Range<usize>]) -> Self::Primitive;
2038
2039 /// Assigns the given value to the tensor elements corresponding for the given ranges.
2040 ///
2041 /// # Arguments
2042 ///
2043 /// * `tensor` - The tensor.
2044 /// * `ranges` - The ranges of the elements to select.
2045 /// * `value` - The value to assign.
2046 ///
2047 /// # Returns
2048 ///
2049 /// The tensor with the assigned values.
2050 ///
2051 /// # Remarks
2052 ///
2053 /// This is a low-level function used internally by the library to call different backend functions
2054 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2055 /// or use this function directly.
2056 ///
2057 /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
2058 /// which is more high-level and designed for public use.
2059 fn slice_assign(
2060 tensor: Self::Primitive,
2061 ranges: &[Range<usize>],
2062 value: Self::Primitive,
2063 ) -> Self::Primitive;
2064
2065 /// Returns the device on which the tensor is allocated.
2066 ///
2067 /// # Arguments
2068 ///
2069 /// * `tensor` - The tensor.
2070 ///
2071 /// # Returns
2072 ///
2073 /// The device on which the tensor is allocated.
2074 ///
2075 /// # Remarks
2076 ///
2077 /// This is a low-level function used internally by the library to call different backend functions
2078 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2079 /// or use this function directly.
2080 ///
2081 /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
2082 /// which is more high-level and designed for public use.
2083 fn device(tensor: &Self::Primitive) -> B::Device;
2084
2085 /// Moves the tensor to the given device.
2086 ///
2087 /// # Arguments
2088 ///
2089 /// * `tensor` - The tensor.
2090 /// * `device` - The device on which the tensor will be moved.
2091 ///
2092 /// # Returns
2093 ///
2094 /// The tensor on the given device.
2095 ///
2096 /// # Remarks
2097 ///
2098 /// This is a low-level function used internally by the library to call different backend functions
2099 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2100 /// or use this function directly.
2101 ///
2102 /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
2103 /// which is more high-level and designed for public use.
2104 fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;
2105
2106 /// Extracts the data from the tensor asynchronously.
2107 ///
2108 /// # Arguments
2109 ///
2110 /// * `tensor` - The tensor.
2111 ///
2112 /// # Returns
2113 ///
2114 /// The data of the tensor.
2115 ///
2116 /// # Remarks
2117 ///
2118 /// This is a low-level function used internally by the library to call different backend functions
2119 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2120 /// or use this function directly.
2121 ///
2122 /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
2123 /// which is more high-level and designed for public use.
2124 fn into_data_async(
2125 tensor: Self::Primitive,
2126 ) -> impl Future<Output = TensorData> + 'static + Send;
2127
2128 /// Read the data from the tensor using a transaction.
2129 ///
2130 /// # Remarks
2131 ///
2132 /// This is a low-level function used internally by the library to call different backend functions
2133 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2134 /// or use this function directly.
2135 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive);
2136
2137 /// Creates a tensor from the given data.
2138 ///
2139 /// # Arguments
2140 ///
2141 /// * `data` - The data of the tensor.
2142 /// * `device` - The device on which the tensor will be allocated.
2143 ///
2144 /// # Returns
2145 ///
2146 /// The tensor.
2147 ///
2148 /// # Remarks
2149 ///
2150 /// This is a low-level function used internally by the library to call different backend functions
2151 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2152 /// or use this function directly.
2153 ///
2154 /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
2155 /// which is more high-level and designed for public use.
2156 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;
2157
2158 /// Repeat the tensor along the given dimension.
2159 ///
2160 /// # Arguments
2161 ///
2162 /// * `tensor` - The tensor.
2163 /// * `dim` - The dimension along which the tensor will be repeated.
2164 /// * `times` - The number of times the tensor will be repeated.
2165 ///
2166 /// # Returns
2167 ///
2168 /// The repeated tensor.
2169 ///
2170 /// # Remarks
2171 ///
2172 /// This is a low-level function used internally by the library to call different backend functions
2173 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2174 /// or use this function directly.
2175 ///
2176 /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
2177 /// which is more high-level and designed for public use.
2178 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;
2179
2180 /// Concatenates the given tensors along the given dimension.
2181 ///
2182 /// # Arguments
2183 ///
2184 /// * `vectors` - The tensors to concatenate.
2185 /// * `dim` - The dimension along which the tensors will be concatenated.
2186 ///
2187 /// # Returns
2188 ///
2189 /// The concatenated tensor.
2190 ///
2191 /// # Remarks
2192 ///
2193 /// This is a low-level function used internally by the library to call different backend functions
2194 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2195 /// or use this function directly.
2196 ///
2197 /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
2198 /// which is more high-level and designed for public use.
2199 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;
2200
2201 /// Attempts to split the tensor along the given dimension into chunks.
2202 /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
2203 ///
2204 /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
2205 /// Otherwise all chunks will be of equal size except for the last one.
2206 ///
2207 /// # Panics
2208 ///
2209 /// If the dimension is greater than the number of dimensions of the tensor.
2210 ///
2211 /// # Returns
2212 /// A vector of tensors.
2213 ///
2214 /// # Remarks
2215 ///
2216 /// This is a low-level function used internally by the library to call different backend functions
2217 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2218 /// or use this function directly.
2219 ///
2220 /// To chunk a tensor, users should prefer the [Tensor::chunk](Tensor::chunk) function,
2221 /// which is more high-level and designed for public use.
2222 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive>;
2223
2224 /// Splits the tensor into chunks of a specified size along a given dimension.
2225 /// Each chunk is a view of the original tensor.
2226 ///
2227 /// # Panics
2228 ///
2229 /// If the dimension to split along is greater than the number of dimensions of the tensor.
2230 ///
2231 /// # Returns
2232 ///
2233 /// A vector of tensors.
2234 ///
2235 /// # Remarks
2236 /// This is a low-level function used internally by the library to call different backend functions
2237 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2238 /// or use this function directly.
2239 ///
2240 /// To split a tensor, users should prefer the [Tensor::split](Tensor::split) function,
2241 /// which is more high-level and designed for public use.
2242 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive>;
2243
2244 /// Splits the tensor into chunks with the specified sizes along a given dimension.
2245 /// Each chunk is a view of the original tensor.
2246 ///
2247 /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes
2248 /// in `split_sizes` must equal the size of the tensor along the specified dimension.
2249 ///
2250 /// # Panics
2251 ///
2252 /// If the dimension to split along is greater than the number of dimensions of the tensor or
2253 /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.
2254 ///
2255 /// # Returns
2256 ///
2257 /// A vector of tensors.
2258 ///
2259 /// # Remarks
2260 /// This is a low-level function used internally by the library to call different backend functions
2261 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2262 /// or use this function directly.
2263 ///
2264 /// To split a tensor, users should prefer the [Tensor::split_with_sizes](Tensor::split_with_sizes) function,
2265 /// which is more high-level and designed for public use.
2266 fn split_with_sizes(
2267 tensor: Self::Primitive,
2268 split_sizes: Vec<usize>,
2269 dim: usize,
2270 ) -> Vec<Self::Primitive>;
2271
2272 /// Equates the given tensors.
2273 ///
2274 /// # Arguments
2275 ///
2276 /// * `lhs` - The left hand side tensor.
2277 /// * `rhs` - The right hand side tensor.
2278 ///
2279 /// # Returns
2280 ///
2281 /// The tensor of booleans indicating whether the corresponding elements are equal.
2282 ///
2283 /// # Remarks
2284 ///
2285 /// This is a low-level function used internally by the library to call different backend functions
2286 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2287 /// or use this function directly.
2288 ///
2289 /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
2290 /// which is more high-level and designed for public use.
2291 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2292
2293 /// Applies element-wise non-equality comparison between the given tensors.
2294 ///
2295 /// # Arguments
2296 ///
2297 /// * `lhs` - The left hand side tensor.
2298 /// * `rhs` - The right hand side tensor.
2299 ///
2300 /// # Returns
2301 ///
2302 /// The tensor of booleans indicating whether the corresponding elements are equal.
2303 ///
2304 /// # Remarks
2305 ///
2306 /// This is a low-level function used internally by the library to call different backend functions
2307 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2308 /// or use this function directly.
2309 ///
2310 /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal)
2311 /// function, which is more high-level and designed for public use.
2312 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2313
2314 /// Returns the name of the element type.
2315 fn elem_type_name() -> &'static str {
2316 core::any::type_name::<Self::Elem>()
2317 }
2318
2319 /// Returns the tensor data type.
2320 fn dtype(tensor: &Self::Primitive) -> DType {
2321 tensor.dtype()
2322 }
2323
2324 /// Tests if any element in the `tensor` evaluates to True.
2325 ///
2326 /// # Arguments
2327 ///
2328 /// * `tensor` - The tensor to test.
2329 ///
2330 /// # Returns
2331 ///
2332 /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.
2333 ///
2334 /// # Remarks
2335 ///
2336 /// This is a low-level function used internally by the library to call different backend functions
2337 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2338 /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function
2339 /// which is more high-level and designed for public use.
2340 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2341
2342 /// Tests if any element in the tensor evaluates to True along a given dimension dim.
2343 ///
2344 /// # Arguments
2345 ///
2346 /// * tensor - The tensor to test.
2347 /// * dim - The axis along which to test.
2348 ///
2349 /// # Returns
2350 ///
2351 /// A boolean tensor with the same size as input tensor, except in the dim axis where the size is 1.
2352 /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.
2353 ///
2354 /// # Remarks
2355 ///
2356 /// This is a low-level function used internally by the library to call different backend functions
2357 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2358 /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function,
2359 /// which is more high-level and designed for public use.
2360 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2361
2362 /// Tests if all elements in the `tensor` evaluate to True.
2363 ///
2364 /// # Arguments
2365 ///
2366 /// * `tensor` - The tensor to test.
2367 ///
2368 /// # Returns
2369 ///
2370 /// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise.
2371 ///
2372 /// # Remarks
2373 ///
2374 /// This is a low-level function used internally by the library to call different backend functions
2375 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2376 /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function,
2377 /// which is more high-level and designed for public use.
2378 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;
2379
2380 /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.
2381 ///
2382 /// # Arguments
2383 ///
2384 /// * `tensor` - The tensor to test.
2385 ///
2386 /// # Returns
2387 ///
2388 /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.
2389 /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.
2390 ///
2391 /// # Remarks
2392 ///
2393 /// This is a low-level function used internally by the library to call different backend functions
2394 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2395 /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function,
2396 /// which is more high-level and designed for public use.
2397 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;
2398
2399 /// Broadcasts the given tensor to the specified shape.
2400 ///
2401 /// # Arguments
2402 ///
2403 /// * `tensor` - The tensor to broadcast.
2404 /// * `shape` - The shape to broadcast to.
2405 ///
2406 /// # Returns
2407 ///
2408 /// The broadcasted tensor.
2409 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;
2410}
2411
2412impl<B: Backend> BasicOps<B> for Float {
2413 type Elem = B::FloatElem;
2414
2415 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2416 TensorPrimitive::Float(B::float_empty(shape, device))
2417 }
2418
2419 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2420 tr.register_float(tensor);
2421 }
2422
2423 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2424 match tensor {
2425 TensorPrimitive::Float(tensor) => {
2426 TensorPrimitive::Float(B::float_reshape(tensor, shape))
2427 }
2428 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),
2429 }
2430 }
2431
2432 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2433 match tensor {
2434 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),
2435 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),
2436 }
2437 }
2438
2439 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2440 match tensor {
2441 TensorPrimitive::Float(tensor) => {
2442 TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))
2443 }
2444 TensorPrimitive::QFloat(tensor) => {
2445 TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))
2446 }
2447 }
2448 }
2449
2450 fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2451 match tensor {
2452 TensorPrimitive::Float(tensor) => {
2453 TensorPrimitive::Float(B::float_slice(tensor, ranges))
2454 }
2455 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, ranges)),
2456 }
2457 }
2458
2459 fn slice_assign(
2460 tensor: Self::Primitive,
2461 ranges: &[Range<usize>],
2462 value: Self::Primitive,
2463 ) -> Self::Primitive {
2464 match (tensor, value) {
2465 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(value)) => {
2466 TensorPrimitive::Float(B::float_slice_assign(tensor, ranges, value))
2467 }
2468 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(value)) => {
2469 TensorPrimitive::QFloat(B::q_slice_assign(tensor, ranges, value))
2470 }
2471 _ => panic!("Primitive type mismatch for tensor and value"),
2472 }
2473 }
2474
2475 fn device(tensor: &Self::Primitive) -> Device<B> {
2476 match tensor {
2477 TensorPrimitive::Float(tensor) => B::float_device(tensor),
2478 TensorPrimitive::QFloat(tensor) => B::q_device(tensor),
2479 }
2480 }
2481
2482 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2483 match tensor {
2484 TensorPrimitive::Float(tensor) => {
2485 TensorPrimitive::Float(B::float_to_device(tensor, device))
2486 }
2487 TensorPrimitive::QFloat(tensor) => {
2488 TensorPrimitive::QFloat(B::q_to_device(tensor, device))
2489 }
2490 }
2491 }
2492
2493 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2494 match tensor {
2495 TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,
2496 TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,
2497 }
2498 }
2499
2500 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2501 match data.dtype {
2502 DType::QFloat(_strategy) => TensorPrimitive::QFloat(B::q_from_data(data, device)),
2503 _ => TensorPrimitive::Float(B::float_from_data(data, device)),
2504 }
2505 }
2506
2507 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2508 match tensor {
2509 TensorPrimitive::Float(tensor) => {
2510 TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))
2511 }
2512 TensorPrimitive::QFloat(tensor) => {
2513 TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))
2514 }
2515 }
2516 }
2517
2518 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2519 match vectors.first().unwrap() {
2520 TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(
2521 vectors.into_iter().map(|tensor| tensor.tensor()).collect(),
2522 dim,
2523 )),
2524 TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(
2525 vectors
2526 .into_iter()
2527 .map(|tensor| {
2528 if let TensorPrimitive::QFloat(t) = tensor {
2529 t
2530 } else {
2531 panic!("Concatenation only works with vector of QFloat")
2532 }
2533 })
2534 .collect(),
2535 dim,
2536 )),
2537 }
2538 }
2539
2540 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2541 B::float_equal(lhs.tensor(), rhs.tensor())
2542 }
2543
2544 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2545 B::float_not_equal(lhs.tensor(), rhs.tensor())
2546 }
2547
2548 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2549 B::float_any(tensor.tensor())
2550 }
2551
2552 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2553 B::float_any_dim(tensor.tensor(), dim)
2554 }
2555
2556 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2557 B::float_all(tensor.tensor())
2558 }
2559
2560 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2561 B::float_all_dim(tensor.tensor(), dim)
2562 }
2563
2564 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2565 match tensor {
2566 TensorPrimitive::Float(tensor) => {
2567 TensorPrimitive::Float(B::float_permute(tensor, axes))
2568 }
2569 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),
2570 }
2571 }
2572
2573 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2574 TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))
2575 }
2576
2577 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2578 match tensor {
2579 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),
2580 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),
2581 }
2582 }
2583
2584 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2585 match tensor {
2586 TensorPrimitive::Float(tensor) => B::float_chunk(tensor, chunks, dim)
2587 .into_iter()
2588 .map(TensorPrimitive::Float)
2589 .collect(),
2590 TensorPrimitive::QFloat(tensor) => B::q_chunk(tensor, chunks, dim)
2591 .into_iter()
2592 .map(TensorPrimitive::QFloat)
2593 .collect(),
2594 }
2595 }
2596
2597 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2598 match tensor {
2599 TensorPrimitive::Float(tensor) => B::float_split(tensor, split_size, dim)
2600 .into_iter()
2601 .map(TensorPrimitive::Float)
2602 .collect(),
2603 TensorPrimitive::QFloat(tensor) => B::q_split(tensor, split_size, dim)
2604 .into_iter()
2605 .map(TensorPrimitive::QFloat)
2606 .collect(),
2607 }
2608 }
2609
2610 fn split_with_sizes(
2611 tensor: Self::Primitive,
2612 split_sizes: Vec<usize>,
2613 dim: usize,
2614 ) -> Vec<Self::Primitive> {
2615 match tensor {
2616 TensorPrimitive::Float(tensor) => B::float_split_with_sizes(tensor, split_sizes, dim)
2617 .into_iter()
2618 .map(TensorPrimitive::Float)
2619 .collect(),
2620 TensorPrimitive::QFloat(tensor) => B::q_split_with_sizes(tensor, split_sizes, dim)
2621 .into_iter()
2622 .map(TensorPrimitive::QFloat)
2623 .collect(),
2624 }
2625 }
2626}
2627
2628impl<B: Backend> BasicOps<B> for Int {
2629 type Elem = B::IntElem;
2630
2631 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2632 B::int_empty(shape, device)
2633 }
2634
2635 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2636 tr.register_int(tensor);
2637 }
2638
2639 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2640 B::int_reshape(tensor, shape)
2641 }
2642
2643 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2644 B::int_transpose(tensor)
2645 }
2646
2647 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2648 B::int_swap_dims(tensor, dim1, dim2)
2649 }
2650
2651 fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2652 B::int_slice(tensor, ranges)
2653 }
2654
2655 fn slice_assign(
2656 tensor: Self::Primitive,
2657 ranges: &[Range<usize>],
2658 value: Self::Primitive,
2659 ) -> Self::Primitive {
2660 B::int_slice_assign(tensor, ranges, value)
2661 }
2662
2663 fn device(tensor: &Self::Primitive) -> Device<B> {
2664 B::int_device(tensor)
2665 }
2666
2667 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2668 B::int_to_device(tensor, device)
2669 }
2670
2671 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2672 B::int_into_data(tensor).await
2673 }
2674
2675 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2676 B::int_from_data(data, device)
2677 }
2678
2679 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2680 B::int_repeat_dim(tensor, dim, times)
2681 }
2682
2683 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2684 B::int_equal(lhs, rhs)
2685 }
2686
2687 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2688 B::int_not_equal(lhs, rhs)
2689 }
2690
2691 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2692 B::int_cat(vectors, dim)
2693 }
2694
2695 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2696 B::int_any(tensor)
2697 }
2698
2699 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2700 B::int_any_dim(tensor, dim)
2701 }
2702
2703 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2704 B::int_all(tensor)
2705 }
2706
2707 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2708 B::int_all_dim(tensor, dim)
2709 }
2710
2711 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2712 B::int_permute(tensor, axes)
2713 }
2714
2715 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2716 B::int_expand(tensor, shape)
2717 }
2718
2719 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2720 B::int_flip(tensor, axes)
2721 }
2722
2723 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2724 B::int_chunk(tensor, chunks, dim)
2725 }
2726
2727 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2728 B::int_split(tensor, split_size, dim)
2729 }
2730
2731 fn split_with_sizes(
2732 tensor: Self::Primitive,
2733 split_sizes: Vec<usize>,
2734 dim: usize,
2735 ) -> Vec<Self::Primitive> {
2736 B::int_split_with_sizes(tensor, split_sizes, dim)
2737 }
2738}
2739
2740impl<B: Backend> BasicOps<B> for Bool {
2741 type Elem = bool;
2742
2743 fn empty(shape: Shape, device: &B::Device) -> Self::Primitive {
2744 B::bool_empty(shape, device)
2745 }
2746
2747 fn register_transaction(tr: &mut Transaction<B>, tensor: Self::Primitive) {
2748 tr.register_bool(tensor);
2749 }
2750
2751 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2752 B::bool_reshape(tensor, shape)
2753 }
2754
2755 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
2756 B::bool_transpose(tensor)
2757 }
2758
2759 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
2760 B::bool_swap_dims(tensor, dim1, dim2)
2761 }
2762
2763 fn slice(tensor: Self::Primitive, ranges: &[Range<usize>]) -> Self::Primitive {
2764 B::bool_slice(tensor, ranges)
2765 }
2766
2767 fn slice_assign(
2768 tensor: Self::Primitive,
2769 ranges: &[Range<usize>],
2770 value: Self::Primitive,
2771 ) -> Self::Primitive {
2772 B::bool_slice_assign(tensor, ranges, value)
2773 }
2774
2775 fn device(tensor: &Self::Primitive) -> Device<B> {
2776 B::bool_device(tensor)
2777 }
2778
2779 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
2780 B::bool_to_device(tensor, device)
2781 }
2782
2783 async fn into_data_async(tensor: Self::Primitive) -> TensorData {
2784 B::bool_into_data(tensor).await
2785 }
2786
2787 fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive {
2788 B::bool_from_data(data, device)
2789 }
2790
2791 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
2792 B::bool_repeat_dim(tensor, dim, times)
2793 }
2794
2795 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2796 B::bool_equal(lhs, rhs)
2797 }
2798
2799 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
2800 B::bool_not_equal(lhs, rhs)
2801 }
2802
2803 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
2804 B::bool_cat(vectors, dim)
2805 }
2806
2807 fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2808 B::bool_any(tensor)
2809 }
2810
2811 fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2812 B::bool_any_dim(tensor, dim)
2813 }
2814
2815 fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {
2816 B::bool_all(tensor)
2817 }
2818
2819 fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {
2820 B::bool_all_dim(tensor, dim)
2821 }
2822
2823 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2824 B::bool_permute(tensor, axes)
2825 }
2826
2827 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
2828 B::bool_expand(tensor, shape)
2829 }
2830
2831 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
2832 B::bool_flip(tensor, axes)
2833 }
2834
2835 fn chunk(tensor: Self::Primitive, chunks: usize, dim: usize) -> Vec<Self::Primitive> {
2836 B::bool_chunk(tensor, chunks, dim)
2837 }
2838
2839 fn split(tensor: Self::Primitive, split_size: usize, dim: usize) -> Vec<Self::Primitive> {
2840 B::bool_split(tensor, split_size, dim)
2841 }
2842
2843 fn split_with_sizes(
2844 tensor: Self::Primitive,
2845 split_sizes: Vec<usize>,
2846 dim: usize,
2847 ) -> Vec<Self::Primitive> {
2848 B::bool_split_with_sizes(tensor, split_sizes, dim)
2849 }
2850}
2851
/// Argument types that `tensor.movedim()` accepts to name one or several dimensions.
pub trait MovedimArgs {
    /// Turns the argument into the list of dimension indices (`Vec<usize>`) that
    /// `tensor.movedim()` operates on, for a tensor of rank `D`.
    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;
}
2857
2858impl MovedimArgs for Vec<i32> {
2859 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2860 let set = self
2861 .iter()
2862 .map(|&dim| {
2863 if dim < 0 {
2864 (D as i32 + dim) as usize
2865 } else {
2866 dim as usize
2867 }
2868 })
2869 .collect::<Vec<usize>>();
2870 check!(TensorCheck::movedim_args_vec::<D>(&set));
2871
2872 set
2873 }
2874}
2875
2876impl MovedimArgs for Vec<usize> {
2877 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2878 check!(TensorCheck::movedim_args_vec::<D>(&self));
2879 self
2880 }
2881}
2882
2883impl MovedimArgs for usize {
2884 #[allow(clippy::vec_init_then_push)]
2885 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2886 check!(TensorCheck::movedim_args_usize::<D>(self));
2887
2888 let mut set = Vec::with_capacity(1);
2889 set.push(self);
2890
2891 set
2892 }
2893}
2894
2895impl MovedimArgs for i32 {
2896 #[allow(clippy::vec_init_then_push)]
2897 fn into_dim_vec<const D: usize>(self) -> Vec<usize> {
2898 check!(TensorCheck::movedim_args_i32::<D>(self));
2899
2900 let dim = if self < 0 {
2901 (D as i32 + self) as usize
2902 } else {
2903 self as usize
2904 };
2905
2906 let mut set = Vec::with_capacity(1);
2907 set.push(dim);
2908
2909 set
2910 }
2911}
2912
2913/// Trait used for slice arguments
2914pub trait RangesArg<const D2: usize> {
2915 /// Converts into a set of ranges to `[core::ops::Range<usize>; D2]` for the `tensor.slice()` function
2916 fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; D2];
2917
2918 /// Handles negative index values
2919 fn handle_negative_index(start: i64, end: i64, dim: usize) -> (usize, usize) {
2920 let start = if start < 0 {
2921 (dim as i64 + start) as usize
2922 } else {
2923 start as usize
2924 };
2925 let end = if end < 0 {
2926 (dim as i64 + end) as usize
2927 } else {
2928 end as usize
2929 };
2930 (start, end)
2931 }
2932
2933 /// Clamps the range to the shape dimensions
2934 fn clamp_range(start: usize, end: usize, dim: usize) -> (usize, usize) {
2935 let start = start.clamp(0, dim);
2936 let end = end.clamp(0, dim);
2937 (start, end)
2938 }
2939}
2940
2941impl<const D2: usize> RangesArg<D2> for [core::ops::Range<usize>; D2] {
2942 fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; D2] {
2943 // clamp the ranges to the shape dimensions
2944 let ranges = self
2945 .iter()
2946 .enumerate()
2947 .map(|(i, range)| {
2948 let (start, end) = Self::clamp_range(range.start, range.end, shape.dims[i]);
2949 start..end
2950 })
2951 .collect::<Vec<_>>();
2952 ranges.try_into().unwrap()
2953 }
2954}
2955
2956impl<const D2: usize> RangesArg<D2> for [Option<(i64, i64)>; D2] {
2957 fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; D2] {
2958 let ranges = self
2959 .iter()
2960 .enumerate()
2961 .map(|(i, range)| match range {
2962 Some((start, end)) => {
2963 let (start, end) = Self::handle_negative_index(*start, *end, shape.dims[i]);
2964 let (start, end) = Self::clamp_range(start, end, shape.dims[i]);
2965 start..end
2966 }
2967 None => 0..shape.dims[i], // if None, use the full range
2968 })
2969 .collect::<Vec<_>>();
2970
2971 ranges.try_into().unwrap()
2972 }
2973}
2974
2975impl<const D2: usize> RangesArg<D2> for [(i64, i64); D2] {
2976 fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; D2] {
2977 let ranges = self
2978 .iter()
2979 .enumerate()
2980 .map(|(i, &(start, end))| {
2981 let (start, end) = Self::handle_negative_index(start, end, shape.dims[i]);
2982 let (start, end) = Self::clamp_range(start, end, shape.dims[i]);
2983 start..end
2984 })
2985 .collect::<Vec<_>>();
2986
2987 ranges.try_into().unwrap()
2988 }
2989}
2990
2991/// Trait used for reshape arguments.
2992pub trait ReshapeArgs<const D2: usize> {
2993 /// Converts to a shape.
2994 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
2995 self,
2996 tensor: &Tensor<B, D, K>,
2997 ) -> Shape;
2998}
2999
3000impl<const D2: usize> ReshapeArgs<D2> for Shape {
3001 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3002 self,
3003 tensor: &Tensor<B, D, K>,
3004 ) -> Shape {
3005 check!(TensorCheck::reshape_args_usize::<D, D2>(
3006 &tensor.shape(),
3007 &self
3008 ));
3009
3010 self
3011 }
3012}
3013impl<const D2: usize> ReshapeArgs<D2> for [usize; D2] {
3014 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3015 self,
3016 tensor: &Tensor<B, D, K>,
3017 ) -> Shape {
3018 let shape = Shape::from(self);
3019
3020 check!(TensorCheck::reshape_args_usize::<D, D2>(
3021 &tensor.shape(),
3022 &shape
3023 ));
3024
3025 shape
3026 }
3027}
3028
3029impl<const D2: usize> ReshapeArgs<D2> for [i32; D2] {
3030 fn into_shape<B: Backend, const D: usize, K: BasicOps<B>>(
3031 self,
3032 tensor: &Tensor<B, D, K>,
3033 ) -> Shape {
3034 // Validate the reshape arguments
3035 check!(TensorCheck::reshape_args_i32(&self));
3036
3037 // Temporary shape
3038 let mut new_shape: [i32; D2] = [1; D2];
3039
3040 // We need to find the index of the 0 dimension and
3041 // replace it with the actual dimension value.
3042 for (i, &s) in self.iter().enumerate() {
3043 if s != 0 {
3044 new_shape[i] = s;
3045 } else {
3046 new_shape[i] = tensor.dims()[i] as i32;
3047 }
3048 }
3049
3050 // Find the index of the inferred dimension (-1)
3051 let infer_index = new_shape.iter().position(|x| x == &-1);
3052
3053 // Handle the case where the dimension is inferred (via -1)
3054 if let Some(index) = infer_index {
3055 // Handle the case where the dimension is inferred
3056 let mut product = 1;
3057 for (i, &s) in new_shape.iter().enumerate() {
3058 if i != index {
3059 product *= s;
3060 }
3061 }
3062 let product_current = tensor.shape().num_elements() as i32;
3063
3064 new_shape[index] = product_current / product;
3065
3066 // Check if the reshape is valid
3067 if product_current % product != 0 {
3068 panic!(
3069 "Cannot reshape tensor of shape {:?} to shape {:?}",
3070 tensor.shape(),
3071 new_shape
3072 );
3073 }
3074 };
3075
3076 // Convert each element to usize
3077 let new_shape: [usize; D2] = new_shape.map(|x| x as usize);
3078
3079 Shape::from(new_shape)
3080 }
3081}
3082
3083/// Trait used for broadcast arguments.
3084pub trait BroadcastArgs<const D1: usize, const D2: usize> {
3085 /// Converts to a shape.
3086 fn into_shape(self, shape: &Shape) -> Shape;
3087}
3088
3089impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {
3090 fn into_shape(self, _shape: &Shape) -> Shape {
3091 self
3092 }
3093}
3094impl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for [usize; D2] {
3095 fn into_shape(self, _shape: &Shape) -> Shape {
3096 Shape::from(self)
3097 }
3098}
3099
3100impl<const D1: usize, const D2: usize, E: Element> BroadcastArgs<D1, D2> for [E; D2] {
3101 // Passing -1 as the size for a dimension means not changing the size of that dimension.
3102 fn into_shape(self, shape: &Shape) -> Shape {
3103 if self.len() < shape.num_dims() {
3104 panic!("Broadcast arguments must be greater than the number of dimensions");
3105 }
3106
3107 // Zip the two shapes in reverse order and replace -1 with the actual dimension value.
3108 let new_shape: Vec<_> = self
3109 .iter()
3110 .rev()
3111 .map(|x| {
3112 let primitive = x.to_i64();
3113 if primitive < -1 || primitive == 0 {
3114 panic!("Broadcast arguments must be positive or -1");
3115 }
3116 primitive
3117 })
3118 .zip(shape.dims.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s
3119 .map(|(x, &y)| if x == -1 { y } else { x as usize })
3120 .collect::<Vec<_>>()
3121 .into_iter()
3122 .rev()
3123 .collect();
3124
3125 if new_shape.iter().any(|&x| x == 0) {
3126 panic!("Cannot substitute -1 for a non-existing dimension");
3127 }
3128
3129 let new_shape: [usize; D2] = new_shape.try_into().unwrap();
3130
3131 Shape::from(new_shape)
3132 }
3133}
3134
3135impl<B, const D: usize, K> Serialize for Tensor<B, D, K>
3136where
3137 B: Backend,
3138 K: BasicOps<B>,
3139 K::Elem: Debug + Copy + Serialize,
3140{
3141 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
3142 let data = self.to_data();
3143 data.serialize(serializer)
3144 }
3145}
3146
3147impl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>
3148where
3149 B: Backend,
3150 K: BasicOps<B>,
3151 K::Elem: Debug + Copy + Deserialize<'de>,
3152{
3153 fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {
3154 let tensor = Tensor::from_data(
3155 TensorData::deserialize(deserializer)?,
3156 &<B::Device as Default>::default(),
3157 );
3158 Ok(tensor)
3159 }
3160}