Skip to main content

burn_tensor/tensor/api/
int.rs

1use burn_backend::{Scalar, get_device_settings};
2
3use crate::{
4    Cast, Float, Int, Shape, Tensor, TensorCreationOptions, TensorData, TensorPrimitive,
5    backend::Backend, cartesian_grid,
6};
7
8use core::ops::Range;
9
10impl<B> Tensor<B, 1, Int>
11where
12    B: Backend,
13{
14    /// Returns a new integer tensor on the specified device.
15    ///
16    /// # Arguments
17    ///
18    /// * `range` - The range of values to generate.
19    /// * `device` - The device to create the tensor on.
20    pub fn arange(range: Range<i64>, options: impl Into<TensorCreationOptions<B>>) -> Self {
21        let opt = options.into();
22        let dtype = opt.resolve_dtype::<Int>();
23        Tensor::new(B::int_arange(range, &opt.device, dtype.into()))
24    }
25
26    /// Returns a new integer tensor on the specified device.
27    ///
28    /// # Arguments
29    ///
30    /// * `range` - The range of values to generate.
31    /// * `step` - The step between each value.
32    pub fn arange_step(
33        range: Range<i64>,
34        step: usize,
35        options: impl Into<TensorCreationOptions<B>>,
36    ) -> Self {
37        let opt = options.into();
38        let dtype = opt.resolve_dtype::<Int>();
39        Tensor::new(B::int_arange_step(range, step, &opt.device, dtype.into()))
40    }
41}
42
43impl<const D: usize, B> Tensor<B, D, Int>
44where
45    B: Backend,
46{
47    /// Create a tensor from integers (i32), placing it on a given device.
48    ///
49    /// # Example
50    ///
51    /// ```rust
52    /// use burn_tensor::backend::Backend;
53    /// use burn_tensor::{Tensor, Int};
54    ///
55    /// fn example<B: Backend>() {
56    ///     let device = B::Device::default();
57    ///     let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);
58    ///     let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);
59    /// }
60    /// ```
61    pub fn from_ints<A: Into<TensorData>>(ints: A, device: &B::Device) -> Self {
62        Self::from_data(ints.into().convert::<i32>(), device)
63    }
64
65    /// Returns a new tensor with the same shape and device as the current tensor and the data
66    /// cast to Float.
67    ///
68    /// # Example
69    ///
70    /// ```rust
71    /// use burn_tensor::backend::Backend;
72    /// use burn_tensor::{Int, Tensor};
73    ///
74    /// fn example<B: Backend>() {
75    ///     let device = Default::default();
76    ///     let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
77    ///     let float_tensor = int_tensor.float();
78    /// }
79    /// ```
80    pub fn float(self) -> Tensor<B, D, Float> {
81        let out_dtype = get_device_settings::<B>(&self.device()).float_dtype;
82        Tensor::new(TensorPrimitive::Float(B::int_into_float(
83            self.primitive,
84            out_dtype,
85        )))
86    }
87
88    /// Generates a cartesian grid for the given tensor shape on the specified device.
89    /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.
90    ///
91    /// # Arguments
92    ///
93    /// * `shape` - The shape specifying the dimensions of the tensor.
94    /// * `device` - The device to create the tensor on.
95    ///
96    /// # Panics
97    ///
98    /// Panics if `D2` is not equal to `D+1`.
99    ///
100    /// # Examples
101    ///
102    /// ```rust
103    ///    use burn_tensor::Int;
104    ///    use burn_tensor::{backend::Backend, Shape, Tensor};
105    ///    fn example<B: Backend>() {
106    ///        let device = Default::default();
107    ///        let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);
108    ///        println!("{}", result);
109    ///    }
110    /// ```
111    pub fn cartesian_grid<S: Into<Shape>, const D2: usize>(
112        shape: S,
113        device: &B::Device,
114    ) -> Tensor<B, D2, Int> {
115        cartesian_grid::<B, S, D, D2>(shape, device)
116    }
117
118    /// Applies the bitwise logical and operation with each bit representing the integer.
119    pub fn bitwise_and(self, other: Self) -> Self {
120        Self::new(B::bitwise_and(self.primitive, other.primitive))
121    }
122
123    /// Applies the bitwise logical or operation with another tensor.
124    pub fn bitwise_or(self, other: Self) -> Self {
125        Self::new(B::bitwise_or(self.primitive, other.primitive))
126    }
127
128    /// Applies the bitwise logical xor operation with another tensor.
129    pub fn bitwise_xor(self, other: Self) -> Self {
130        Self::new(B::bitwise_xor(self.primitive, other.primitive))
131    }
132
133    /// Applies the bitwise logical not operation.
134    pub fn bitwise_not(self) -> Self {
135        Self::new(B::bitwise_not(self.primitive))
136    }
137
138    /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor.
139    pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {
140        let other = Scalar::new(other, &self.dtype());
141        Self::new(B::bitwise_and_scalar(self.primitive, other))
142    }
143
144    /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor.
145    pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {
146        let other = Scalar::new(other, &self.dtype());
147        Self::new(B::bitwise_or_scalar(self.primitive, other))
148    }
149
150    /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor.
151    pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {
152        let other = Scalar::new(other, &self.dtype());
153        Self::new(B::bitwise_xor_scalar(self.primitive, other))
154    }
155
156    /// Applies the bitwise left shift operation with the integers in the tensor.
157    pub fn bitwise_left_shift(self, other: Self) -> Self {
158        Self::new(B::bitwise_left_shift(self.primitive, other.primitive))
159    }
160
161    /// Applies the bitwise right shift operation with the integers in the tensor.
162    pub fn bitwise_right_shift(self, other: Self) -> Self {
163        Self::new(B::bitwise_right_shift(self.primitive, other.primitive))
164    }
165
166    /// Applies the bitwise left shift operation with the scalar.
167    pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self {
168        let other = Scalar::new(other, &self.dtype());
169        Self::new(B::bitwise_left_shift_scalar(self.primitive, other))
170    }
171
172    /// Applies the bitwise right shift operation with the scalar.
173    pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self {
174        let other = Scalar::new(other, &self.dtype());
175        Self::new(B::bitwise_right_shift_scalar(self.primitive, other))
176    }
177
178    /// Converts a tensor to the specified data type.
179    ///
180    /// Supports both within-kind casting (e.g., `IntDType::I64`) and cross-kind casting
181    /// (e.g., `FloatDType::F32` to produce a float tensor).
182    ///
183    /// This is a no-op when casting to the current dtype within the same kind.
184    ///
185    /// # Example
186    ///
187    /// ```rust
188    /// use burn_tensor::backend::Backend;
189    /// use burn_tensor::{Tensor, Int, IntDType, FloatDType};
190    ///
191    /// fn example<B: Backend>() {
192    ///     let device = Default::default();
193    ///     let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);
194    ///
195    ///     // Within-kind cast (int to int)
196    ///     let i64_tensor = int_tensor.clone().cast(IntDType::I64);
197    ///
198    ///     // Cross-kind cast (int to float)
199    ///     let float_tensor = int_tensor.cast(FloatDType::F32);
200    /// }
201    /// ```
202    #[must_use]
203    pub fn cast<T: Cast<B, Int>>(self, dtype: T) -> Tensor<B, D, T::OutputKind> {
204        Tensor::new(T::cast(self.primitive, dtype))
205    }
206}