burn_tensor/tensor/api/int.rs
1use crate::check;
2use crate::check::TensorCheck;
3use crate::{
4 backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive,
5};
6
7use core::ops::Range;
8
9impl<B> Tensor<B, 1, Int>
10where
11 B: Backend,
12{
13 /// Returns a new integer tensor on the specified device.
14 ///
15 /// # Arguments
16 ///
17 /// * `range` - The range of values to generate.
18 /// * `device` - The device to create the tensor on.
19 pub fn arange(range: Range<i64>, device: &B::Device) -> Self {
20 Tensor::new(B::int_arange(range, device))
21 }
22
23 /// Returns a new integer tensor on the specified device.
24 ///
25 /// # Arguments
26 ///
27 /// * `range` - The range of values to generate.
28 /// * `step` - The step between each value.
29 pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
30 Tensor::new(B::int_arange_step(range, step, device))
31 }
32
33 /// Create a one hot tensor from an index tensor.
34 ///
35 /// # Arguments
36 ///
37 /// * `num_classes` - The number of classes to use in encoding.
38 ///
39 /// # Example
40 ///
41 /// ```rust
42 /// use burn_tensor::backend::Backend;
43 /// use burn_tensor::{Tensor, Int};
44 ///
45 /// fn example<B: Backend>() {
46 /// let device = B::Device::default();
47 /// let indices: Tensor<B, 1, Int> = Tensor::from_ints([0, 1, 2, 3], &device);
48 /// let one_hot = indices.one_hot(4);
49 /// println!("{}", one_hot.to_data());
50 /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
51 /// }
52 /// ```
53 pub fn one_hot(self, num_classes: usize) -> Tensor<B, 2, Int> {
54 check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
55 let [num_samples] = self.dims();
56 let indices = self.unsqueeze_dim(1);
57 let values = indices.ones_like();
58 Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values)
59 }
60}
61
62impl<const D: usize, B> Tensor<B, D, Int>
63where
64 B: Backend,
65{
66 /// Create a tensor from integers (i32), placing it on a given device.
67 ///
68 /// # Example
69 ///
70 /// ```rust
71 /// use burn_tensor::backend::Backend;
72 /// use burn_tensor::{Tensor, Int};
73 ///
74 /// fn example<B: Backend>() {
75 /// let device = B::Device::default();
76 /// let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);
77 /// let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);
78 /// }
79 /// ```
80 pub fn from_ints<A: Into<TensorData>>(ints: A, device: &B::Device) -> Self {
81 Self::from_data(ints.into().convert::<i32>(), device)
82 }
83
84 /// Returns a new tensor with the same shape and device as the current tensor and the data
85 /// cast to Float.
86 ///
87 /// # Example
88 ///
89 /// ```rust
90 /// use burn_tensor::backend::Backend;
91 /// use burn_tensor::{Int, Tensor};
92 ///
93 /// fn example<B: Backend>() {
94 /// let device = Default::default();
95 /// let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
96 /// let float_tensor = int_tensor.float();
97 /// }
98 /// ```
99 pub fn float(self) -> Tensor<B, D, Float> {
100 Tensor::new(TensorPrimitive::Float(B::int_into_float(self.primitive)))
101 }
102
103 /// Generates a cartesian grid for the given tensor shape on the specified device.
104 /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
105 ///
106 /// # Arguments
107 ///
108 /// * `shape` - The shape specifying the dimensions of the tensor.
109 /// * `device` - The device to create the tensor on.
110 ///
111 /// # Panics
112 ///
113 /// Panics if `D2` is not equal to `D+1`.
114 ///
115 /// # Examples
116 ///
117 /// ```rust
118 /// use burn_tensor::Int;
119 /// use burn_tensor::{backend::Backend, Shape, Tensor};
120 /// fn example<B: Backend>() {
121 /// let device = Default::default();
122 /// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
123 /// println!("{}", result);
124 /// }
125 /// ```
126 pub fn cartesian_grid<S: Into<Shape>, const D2: usize>(
127 shape: S,
128 device: &B::Device,
129 ) -> Tensor<B, D2, Int> {
130 cartesian_grid::<B, S, D, D2>(shape, device)
131 }
132}