// burn_tensor/tensor/api/int.rs

1use crate::{
2    Float, Int, IntDType, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend,
3    cartesian_grid,
4};
5
6use core::ops::Range;
7
8impl<B> Tensor<B, 1, Int>
9where
10    B: Backend,
11{
12    /// Returns a new integer tensor on the specified device.
13    ///
14    /// # Arguments
15    ///
16    /// * `range` - The range of values to generate.
17    /// * `device` - The device to create the tensor on.
18    pub fn arange(range: Range<i64>, device: &B::Device) -> Self {
19        Tensor::new(B::int_arange(range, device))
20    }
21
22    /// Returns a new integer tensor on the specified device.
23    ///
24    /// # Arguments
25    ///
26    /// * `range` - The range of values to generate.
27    /// * `step` - The step between each value.
28    pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
29        Tensor::new(B::int_arange_step(range, step, device))
30    }
31}
32
33impl<const D: usize, B> Tensor<B, D, Int>
34where
35    B: Backend,
36{
37    /// Create a tensor from integers (i32), placing it on a given device.
38    ///
39    /// # Example
40    ///
41    /// ```rust
42    /// use burn_tensor::backend::Backend;
43    /// use burn_tensor::{Tensor, Int};
44    ///
45    /// fn example<B: Backend>() {
46    ///     let device = B::Device::default();
47    ///     let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);
48    ///     let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);
49    /// }
50    /// ```
51    pub fn from_ints<A: Into<TensorData>>(ints: A, device: &B::Device) -> Self {
52        Self::from_data(ints.into().convert::<i32>(), device)
53    }
54
55    /// Returns a new tensor with the same shape and device as the current tensor and the data
56    /// cast to Float.
57    ///
58    /// # Example
59    ///
60    /// ```rust
61    /// use burn_tensor::backend::Backend;
62    /// use burn_tensor::{Int, Tensor};
63    ///
64    /// fn example<B: Backend>() {
65    ///     let device = Default::default();
66    ///     let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
67    ///     let float_tensor = int_tensor.float();
68    /// }
69    /// ```
70    pub fn float(self) -> Tensor<B, D, Float> {
71        Tensor::new(TensorPrimitive::Float(B::int_into_float(self.primitive)))
72    }
73
74    /// Generates a cartesian grid for the given tensor shape on the specified device.
75    /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
76    ///
77    /// # Arguments
78    ///
79    /// * `shape` - The shape specifying the dimensions of the tensor.
80    /// * `device` - The device to create the tensor on.
81    ///
82    /// # Panics
83    ///
84    /// Panics if `D2` is not equal to `D+1`.
85    ///
86    /// # Examples
87    ///
88    /// ```rust
89    ///    use burn_tensor::Int;
90    ///    use burn_tensor::{backend::Backend, Shape, Tensor};
91    ///    fn example<B: Backend>() {
92    ///        let device = Default::default();
93    ///        let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
94    ///        println!("{}", result);
95    ///    }
96    /// ```
97    pub fn cartesian_grid<S: Into<Shape>, const D2: usize>(
98        shape: S,
99        device: &B::Device,
100    ) -> Tensor<B, D2, Int> {
101        cartesian_grid::<B, S, D, D2>(shape, device)
102    }
103
104    /// Applies the bitwise logical and operation with each bit representing the integer.
105    pub fn bitwise_and(self, other: Self) -> Self {
106        Self::new(B::bitwise_and(self.primitive, other.primitive))
107    }
108
109    /// Applies the bitwise logical or operation with another tensor.
110    pub fn bitwise_or(self, other: Self) -> Self {
111        Self::new(B::bitwise_or(self.primitive, other.primitive))
112    }
113
114    /// Applies the bitwise logical xor operation with another tensor.
115    pub fn bitwise_xor(self, other: Self) -> Self {
116        Self::new(B::bitwise_xor(self.primitive, other.primitive))
117    }
118
119    /// Applies the bitwise logical not operation.
120    pub fn bitwise_not(self) -> Self {
121        Self::new(B::bitwise_not(self.primitive))
122    }
123
124    /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor.
125    pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {
126        Self::new(B::bitwise_and_scalar(self.primitive, other))
127    }
128
129    /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor.
130    pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {
131        Self::new(B::bitwise_or_scalar(self.primitive, other))
132    }
133
134    /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor.
135    pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {
136        Self::new(B::bitwise_xor_scalar(self.primitive, other))
137    }
138
139    /// Applies the bitwise left shift operation with the integers in the tensor.
140    pub fn bitwise_left_shift(self, other: Self) -> Self {
141        Self::new(B::bitwise_left_shift(self.primitive, other.primitive))
142    }
143
144    /// Applies the bitwise right shift operation with the integers in the tensor.
145    pub fn bitwise_right_shift(self, other: Self) -> Self {
146        Self::new(B::bitwise_right_shift(self.primitive, other.primitive))
147    }
148
149    /// Applies the bitwise left shift operation with the scalar.
150    pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self {
151        Self::new(B::bitwise_left_shift_scalar(self.primitive, other))
152    }
153
154    /// Applies the bitwise right shift operation with the scalar.
155    pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self {
156        Self::new(B::bitwise_right_shift_scalar(self.primitive, other))
157    }
158
159    /// Converts a tensor to the specified integer data type.
160    ///
161    /// This is always a no-op when casting to the current dtype.
162    ///
163    /// # Warning
164    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors
165    /// have the same integer data type for operations multiple input tensors (e.g., binary ops).
166    pub fn cast<F: Into<IntDType>>(self, dtype: F) -> Tensor<B, D, Int> {
167        let dtype = dtype.into();
168        let self_dtype: IntDType = self.dtype().into();
169        if dtype == self_dtype {
170            // no-op.
171            return self;
172        }
173        Tensor::new(B::int_cast(self.primitive, dtype))
174    }
175}