burn_tensor/tensor/api/int.rs
1use burn_backend::Scalar;
2
3use crate::{
4 Float, Int, IntDType, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend,
5 cartesian_grid,
6};
7
8use core::ops::Range;
9
10impl<B> Tensor<B, 1, Int>
11where
12 B: Backend,
13{
14 /// Returns a new integer tensor on the specified device.
15 ///
16 /// # Arguments
17 ///
18 /// * `range` - The range of values to generate.
19 /// * `device` - The device to create the tensor on.
20 pub fn arange(range: Range<i64>, device: &B::Device) -> Self {
21 Tensor::new(B::int_arange(range, device))
22 }
23
24 /// Returns a new integer tensor on the specified device.
25 ///
26 /// # Arguments
27 ///
28 /// * `range` - The range of values to generate.
29 /// * `step` - The step between each value.
30 pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
31 Tensor::new(B::int_arange_step(range, step, device))
32 }
33}
34
35impl<const D: usize, B> Tensor<B, D, Int>
36where
37 B: Backend,
38{
39 /// Create a tensor from integers (i32), placing it on a given device.
40 ///
41 /// # Example
42 ///
43 /// ```rust
44 /// use burn_tensor::backend::Backend;
45 /// use burn_tensor::{Tensor, Int};
46 ///
47 /// fn example<B: Backend>() {
48 /// let device = B::Device::default();
49 /// let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);
50 /// let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);
51 /// }
52 /// ```
53 pub fn from_ints<A: Into<TensorData>>(ints: A, device: &B::Device) -> Self {
54 Self::from_data(ints.into().convert::<i32>(), device)
55 }
56
57 /// Returns a new tensor with the same shape and device as the current tensor and the data
58 /// cast to Float.
59 ///
60 /// # Example
61 ///
62 /// ```rust
63 /// use burn_tensor::backend::Backend;
64 /// use burn_tensor::{Int, Tensor};
65 ///
66 /// fn example<B: Backend>() {
67 /// let device = Default::default();
68 /// let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
69 /// let float_tensor = int_tensor.float();
70 /// }
71 /// ```
72 pub fn float(self) -> Tensor<B, D, Float> {
73 Tensor::new(TensorPrimitive::Float(B::int_into_float(self.primitive)))
74 }
75
76 /// Generates a cartesian grid for the given tensor shape on the specified device.
77 /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
78 ///
79 /// # Arguments
80 ///
81 /// * `shape` - The shape specifying the dimensions of the tensor.
82 /// * `device` - The device to create the tensor on.
83 ///
84 /// # Panics
85 ///
86 /// Panics if `D2` is not equal to `D+1`.
87 ///
88 /// # Examples
89 ///
90 /// ```rust
91 /// use burn_tensor::Int;
92 /// use burn_tensor::{backend::Backend, Shape, Tensor};
93 /// fn example<B: Backend>() {
94 /// let device = Default::default();
95 /// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
96 /// println!("{}", result);
97 /// }
98 /// ```
99 pub fn cartesian_grid<S: Into<Shape>, const D2: usize>(
100 shape: S,
101 device: &B::Device,
102 ) -> Tensor<B, D2, Int> {
103 cartesian_grid::<B, S, D, D2>(shape, device)
104 }
105
106 /// Applies the bitwise logical and operation with each bit representing the integer.
107 pub fn bitwise_and(self, other: Self) -> Self {
108 Self::new(B::bitwise_and(self.primitive, other.primitive))
109 }
110
111 /// Applies the bitwise logical or operation with another tensor.
112 pub fn bitwise_or(self, other: Self) -> Self {
113 Self::new(B::bitwise_or(self.primitive, other.primitive))
114 }
115
116 /// Applies the bitwise logical xor operation with another tensor.
117 pub fn bitwise_xor(self, other: Self) -> Self {
118 Self::new(B::bitwise_xor(self.primitive, other.primitive))
119 }
120
121 /// Applies the bitwise logical not operation.
122 pub fn bitwise_not(self) -> Self {
123 Self::new(B::bitwise_not(self.primitive))
124 }
125
126 /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor.
127 pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {
128 let other = Scalar::new(other, &self.dtype());
129 Self::new(B::bitwise_and_scalar(self.primitive, other))
130 }
131
132 /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor.
133 pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {
134 let other = Scalar::new(other, &self.dtype());
135 Self::new(B::bitwise_or_scalar(self.primitive, other))
136 }
137
138 /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor.
139 pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {
140 let other = Scalar::new(other, &self.dtype());
141 Self::new(B::bitwise_xor_scalar(self.primitive, other))
142 }
143
144 /// Applies the bitwise left shift operation with the integers in the tensor.
145 pub fn bitwise_left_shift(self, other: Self) -> Self {
146 Self::new(B::bitwise_left_shift(self.primitive, other.primitive))
147 }
148
149 /// Applies the bitwise right shift operation with the integers in the tensor.
150 pub fn bitwise_right_shift(self, other: Self) -> Self {
151 Self::new(B::bitwise_right_shift(self.primitive, other.primitive))
152 }
153
154 /// Applies the bitwise left shift operation with the scalar.
155 pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self {
156 let other = Scalar::new(other, &self.dtype());
157 Self::new(B::bitwise_left_shift_scalar(self.primitive, other))
158 }
159
160 /// Applies the bitwise right shift operation with the scalar.
161 pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self {
162 let other = Scalar::new(other, &self.dtype());
163 Self::new(B::bitwise_right_shift_scalar(self.primitive, other))
164 }
165
166 /// Converts a tensor to the specified integer data type.
167 ///
168 /// This is always a no-op when casting to the current dtype.
169 ///
170 /// # Warning
171 /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
172 /// have the same integer data type for operations multiple input tensors (e.g., binary ops).
173 pub fn cast<F: Into<IntDType>>(self, dtype: F) -> Tensor<B, D, Int> {
174 let dtype = dtype.into();
175 let self_dtype: IntDType = self.dtype().into();
176 if dtype == self_dtype {
177 // no-op.
178 return self;
179 }
180 Tensor::new(B::int_cast(self.primitive, dtype))
181 }
182}