burn_tensor/tensor/api/int.rs
1use crate::{
2 Float, Int, IntDType, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend,
3 cartesian_grid,
4};
5
6use core::ops::Range;
7
8impl<B> Tensor<B, 1, Int>
9where
10 B: Backend,
11{
12 /// Returns a new integer tensor on the specified device.
13 ///
14 /// # Arguments
15 ///
16 /// * `range` - The range of values to generate.
17 /// * `device` - The device to create the tensor on.
18 pub fn arange(range: Range<i64>, device: &B::Device) -> Self {
19 Tensor::new(B::int_arange(range, device))
20 }
21
22 /// Returns a new integer tensor on the specified device.
23 ///
24 /// # Arguments
25 ///
26 /// * `range` - The range of values to generate.
27 /// * `step` - The step between each value.
28 pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
29 Tensor::new(B::int_arange_step(range, step, device))
30 }
31}
32
33impl<const D: usize, B> Tensor<B, D, Int>
34where
35 B: Backend,
36{
37 /// Create a tensor from integers (i32), placing it on a given device.
38 ///
39 /// # Example
40 ///
41 /// ```rust
42 /// use burn_tensor::backend::Backend;
43 /// use burn_tensor::{Tensor, Int};
44 ///
45 /// fn example<B: Backend>() {
46 /// let device = B::Device::default();
47 /// let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);
48 /// let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);
49 /// }
50 /// ```
51 pub fn from_ints<A: Into<TensorData>>(ints: A, device: &B::Device) -> Self {
52 Self::from_data(ints.into().convert::<i32>(), device)
53 }
54
55 /// Returns a new tensor with the same shape and device as the current tensor and the data
56 /// cast to Float.
57 ///
58 /// # Example
59 ///
60 /// ```rust
61 /// use burn_tensor::backend::Backend;
62 /// use burn_tensor::{Int, Tensor};
63 ///
64 /// fn example<B: Backend>() {
65 /// let device = Default::default();
66 /// let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
67 /// let float_tensor = int_tensor.float();
68 /// }
69 /// ```
70 pub fn float(self) -> Tensor<B, D, Float> {
71 Tensor::new(TensorPrimitive::Float(B::int_into_float(self.primitive)))
72 }
73
74 /// Generates a cartesian grid for the given tensor shape on the specified device.
75 /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
76 ///
77 /// # Arguments
78 ///
79 /// * `shape` - The shape specifying the dimensions of the tensor.
80 /// * `device` - The device to create the tensor on.
81 ///
82 /// # Panics
83 ///
84 /// Panics if `D2` is not equal to `D+1`.
85 ///
86 /// # Examples
87 ///
88 /// ```rust
89 /// use burn_tensor::Int;
90 /// use burn_tensor::{backend::Backend, Shape, Tensor};
91 /// fn example<B: Backend>() {
92 /// let device = Default::default();
93 /// let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
94 /// println!("{}", result);
95 /// }
96 /// ```
97 pub fn cartesian_grid<S: Into<Shape>, const D2: usize>(
98 shape: S,
99 device: &B::Device,
100 ) -> Tensor<B, D2, Int> {
101 cartesian_grid::<B, S, D, D2>(shape, device)
102 }
103
104 /// Applies the bitwise logical and operation with each bit representing the integer.
105 pub fn bitwise_and(self, other: Self) -> Self {
106 Self::new(B::bitwise_and(self.primitive, other.primitive))
107 }
108
109 /// Applies the bitwise logical or operation with another tensor.
110 pub fn bitwise_or(self, other: Self) -> Self {
111 Self::new(B::bitwise_or(self.primitive, other.primitive))
112 }
113
114 /// Applies the bitwise logical xor operation with another tensor.
115 pub fn bitwise_xor(self, other: Self) -> Self {
116 Self::new(B::bitwise_xor(self.primitive, other.primitive))
117 }
118
119 /// Applies the bitwise logical not operation.
120 pub fn bitwise_not(self) -> Self {
121 Self::new(B::bitwise_not(self.primitive))
122 }
123
124 /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor.
125 pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {
126 Self::new(B::bitwise_and_scalar(self.primitive, other))
127 }
128
129 /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor.
130 pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {
131 Self::new(B::bitwise_or_scalar(self.primitive, other))
132 }
133
134 /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor.
135 pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {
136 Self::new(B::bitwise_xor_scalar(self.primitive, other))
137 }
138
139 /// Applies the bitwise left shift operation with the integers in the tensor.
140 pub fn bitwise_left_shift(self, other: Self) -> Self {
141 Self::new(B::bitwise_left_shift(self.primitive, other.primitive))
142 }
143
144 /// Applies the bitwise right shift operation with the integers in the tensor.
145 pub fn bitwise_right_shift(self, other: Self) -> Self {
146 Self::new(B::bitwise_right_shift(self.primitive, other.primitive))
147 }
148
149 /// Applies the bitwise left shift operation with the scalar.
150 pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self {
151 Self::new(B::bitwise_left_shift_scalar(self.primitive, other))
152 }
153
154 /// Applies the bitwise right shift operation with the scalar.
155 pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self {
156 Self::new(B::bitwise_right_shift_scalar(self.primitive, other))
157 }
158
159 /// Converts a tensor to the specified integer data type.
160 ///
161 /// This is always a no-op when casting to the current dtype.
162 ///
163 /// # Warning
164 /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
165 /// have the same integer data type for operations multiple input tensors (e.g., binary ops).
166 pub fn cast<F: Into<IntDType>>(self, dtype: F) -> Tensor<B, D, Int> {
167 let dtype = dtype.into();
168 let self_dtype: IntDType = self.dtype().into();
169 if dtype == self_dtype {
170 // no-op.
171 return self;
172 }
173 Tensor::new(B::int_cast(self.primitive, dtype))
174 }
175}