hpt_traits/ops/creation.rs
1use hpt_common::{error::base::TensorError, shape::shape::Shape};
2use hpt_types::arch_simd as simd;
3use hpt_types::type_promote::FloatOutBinary;
4use hpt_types::{dtype::TypeCommon, into_scalar::Cast, type_promote::NormalOut};
5#[cfg(target_feature = "avx2")]
6type BoolVector = simd::_256bit::boolx32;
7#[cfg(any(
8 all(not(target_feature = "avx2"), target_feature = "sse"),
9 target_arch = "arm",
10 target_arch = "aarch64",
11 target_feature = "neon"
12))]
13type BoolVector = simd::_128bit::boolx16;
14
15/// A trait defines a set of functions to create tensors.
16pub trait TensorCreator
17where
18 Self: Sized,
19{
20 /// the output type of the tensor
21 type Output;
22 /// the meta type of the tensor
23 type Meta;
24 /// Creates a new uninitialized tensor with the specified shape. The tensor's values will be whatever was in memory at the time of allocation.
25 ///
26 /// ## Parameters:
27 /// `shape`: The desired shape for the tensor.
28 ///
29 /// ## Example:
30 /// ```rust
31 /// let a = Tensor::<f32>::empty(&[2, 3])?; // Shape: [2, 3]
32 /// ```
33 #[track_caller]
34 fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
35
36 /// Creates a new tensor of the specified shape, filled with zeros.
37 ///
38 /// ## Parameters:
39 /// `shape`: The desired shape for the tensor.
40 ///
41 /// ## Example:
42 /// ```rust
43 /// let a = Tensor::<f32>::zeros(&[2, 3])?; // Shape: [2, 3]
44 /// ```
45 #[track_caller]
46 fn zeros<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
47
48 /// Creates a new tensor of the specified shape, filled with ones.
49 ///
50 /// ## Parameters:
51 /// `shape`: The desired shape for the tensor.
52 ///
53 /// ## Example:
54 /// ```rust
55 /// let a = Tensor::<f32>::ones(&[2, 3])?; // Shape: [2, 3]
56 /// ```
57 #[track_caller]
58 fn ones<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>
59 where
60 u8: Cast<Self::Meta>;
61
62 /// Creates a new uninitialized tensor with the same shape as the input tensor.
63 ///
64 /// ## Example:
65 /// ```rust
66 /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0]).reshape(&[3, 1])?;
67 /// let b = a.empty_like()?; // Shape: [3, 1]
68 /// ```
69 #[track_caller]
70 fn empty_like(&self) -> Result<Self::Output, TensorError>;
71
72 /// Creates a new zeroed tensor with the same shape as the input tensor.
73 ///
74 /// ## Example:
75 /// ```rust
76 /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0]).reshape(&[3, 1])?;
77 /// let b = a.zeros_like()?; // Shape: [3, 1]
78 /// ```
79 #[track_caller]
80 fn zeros_like(&self) -> Result<Self::Output, TensorError>;
81
82 /// Creates a new tensor with all ones with the same shape as the input tensor.
83 ///
84 /// ## Example:
85 /// ```rust
86 /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0]).reshape(&[3, 1])?;
87 /// let b = a.ones_like()?; // Shape: [3, 1]
88 /// ```
89 #[track_caller]
90 fn ones_like(&self) -> Result<Self::Output, TensorError>
91 where
92 u8: Cast<Self::Meta>;
93
94 /// Creates a new tensor of the specified shape, filled with a specified value.
95 ///
96 /// ## Parameters:
97 /// `val`: The value to fill the tensor with.
98 ///
99 /// `shape`: The desired shape for the tensor.
100 ///
101 /// ## Example:
102 /// ```rust
103 /// let a = Tensor::<f32>::full(5.0, &[2, 3])?;
104 /// // [[5, 5, 5],
105 /// // [5, 5, 5]]
106 /// ```
107 #[track_caller]
108 fn full<S: Into<Shape>>(val: Self::Meta, shape: S) -> Result<Self::Output, TensorError>;
109
110 /// Creates a new tensor filled with a specified value with the same shape as the input tensor.
111 ///
112 /// ## Parameters:
113 /// `val`: The value to fill the tensor with.
114 ///
115 /// ## Example:
116 /// ```rust
117 /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2])?;
118 /// let b = a.full_like(7.0)?;
119 /// // [[7, 7],
120 /// // [7, 7]]
121 /// ```
122 #[track_caller]
123 fn full_like(&self, val: Self::Meta) -> Result<Self::Output, TensorError>;
124
125 /// Creates a 1-D tensor with evenly spaced values within a given interval `[start, end)`.
126 ///
127 /// ## Parameters:
128 /// `start`: Start of interval (inclusive)
129 ///
130 /// `end`: End of interval (exclusive)
131 ///
132 /// ## Example:
133 /// ```rust
134 /// let a = Tensor::<f32>::arange(0, 5)?; // [0, 1, 2, 3, 4]
135 /// let b = Tensor::<f32>::arange(1.5, 5.5)?; // [1.5, 2.5, 3.5, 4.5]
136 /// ```
137 #[track_caller]
138 fn arange<U>(start: U, end: U) -> Result<Self::Output, TensorError>
139 where
140 usize: Cast<Self::Meta>,
141 U: Cast<i64> + Cast<Self::Meta> + Copy;
142
143 /// Creates a 1-D tensor with evenly spaced values within a given interval `[start, end)` with a specified step size.
144 ///
145 /// ## Parameters:
146 /// `start`: Start of interval (inclusive)
147 ///
148 /// `end`: End of interval (exclusive)
149 ///
150 /// `step`: Size of spacing between values
151 ///
152 /// ## Example:
153 /// ```rust
154 /// let a = Tensor::<f32>::arange_step(0.0, 5.0, 2.0)?; // [0, 2, 4]
155 /// let b = Tensor::<f32>::arange_step(5.0, 0.0, -1.5)?; // [5, 3.5, 2, 0.5]
156 /// ```
157 #[track_caller]
158 fn arange_step(
159 start: Self::Meta,
160 end: Self::Meta,
161 step: Self::Meta,
162 ) -> Result<Self::Output, TensorError>
163 where
164 Self::Meta: Cast<f64> + Cast<f64>,
165 f64: Cast<Self::Meta>,
166 usize: Cast<Self::Meta>;
167
168 /// Creates a 2-D tensor with ones on the k-th diagonal and zeros elsewhere.
169 ///
170 /// ## Parameters:
171 /// `n`: Number of rows
172 ///
173 /// `m`: Number of columns
174 ///
175 /// `k`: Index of the diagonal (0 represents the main diagonal, positive values are above the main diagonal)
176 ///
177 /// ## Example:
178 /// ```rust
179 /// let a = Tensor::<f32>::eye(3, 4, 0)?;
180 /// // [[1, 0, 0, 0],
181 /// // [0, 1, 0, 0],
182 /// // [0, 0, 1, 0]]
183 /// let b = Tensor::<f32>::eye(3, 4, 1)?;
184 /// // [[0, 1, 0, 0],
185 /// // [0, 0, 1, 0],
186 /// // [0, 0, 0, 1]]
187 /// ```
188 #[track_caller]
189 fn eye(n: usize, m: usize, k: usize) -> Result<Self::Output, TensorError>;
190
191 /// Creates a 1-D tensor of `num` evenly spaced values between `start` and `end`.
192 ///
193 /// ## Parameters:
194 /// `start`: The starting value of the sequence
195 ///
196 /// `end`: The end value of the sequence
197 ///
198 /// `num`: Number of samples to generate
199 ///
200 /// `include_end`: Whether to include the end value in the sequence
201 ///
202 /// ## Example:
203 /// ```rust
204 /// let a = Tensor::<f32>::linspace(0.0, 1.0, 5, true)?;
205 /// // [0.0, 0.25, 0.5, 0.75, 1.0]
206 /// let b = Tensor::<f32>::linspace(0.0, 1.0, 5, false)?;
207 /// // [0.0, 0.2, 0.4, 0.6, 0.8]
208 /// let c = Tensor::<f32>::linspace(0, 10, 6, true)?;
209 /// // [0.0, 2.0, 4.0, 6.0, 8.0, 10.0]
210 /// ```
211 #[track_caller]
212 fn linspace<U>(
213 start: U,
214 end: U,
215 num: usize,
216 include_end: bool,
217 ) -> Result<Self::Output, TensorError>
218 where
219 U: Cast<f64> + Cast<Self::Meta> + Copy,
220 usize: Cast<Self::Meta>,
221 f64: Cast<Self::Meta>;
222
223 /// Creates a 1-D tensor with `num` numbers logarithmically spaced between `base^start` and `base^end`.
224 ///
225 /// ## Parameters:
226 /// `start`: The starting value of the sequence (power of base)
227 ///
228 /// `end`: The end value of the sequence (power of base)
229 ///
230 /// `num`: Number of samples to generate
231 ///
232 /// `include_end`: Whether to include the end value in the sequence
233 ///
234 /// `base`: The base of the log space (default is 10.0)
235 ///
236 /// ## Example:
237 /// ```rust
238 /// let a = Tensor::<f32>::logspace(0.0, 3.0, 4, true, 10.0)?;
239 /// // [1.0, 10.0, 100.0, 1000.0]
240 /// let b = Tensor::<f32>::logspace(0.0, 3.0, 4, true, 2.0)?;
241 /// // [1.0, 2.0, 4.0, 8.0]
242 /// let c = Tensor::<f32>::logspace(0.0, 2.0, 4, false, 10.0)?;
243 /// // [1.0, 3.1623, 10.0, 31.6228]
244 /// ```
245 #[track_caller]
246 fn logspace<V: Cast<Self::Meta>>(
247 start: V,
248 end: V,
249 num: usize,
250 include_end: bool,
251 base: V,
252 ) -> Result<Self::Output, TensorError>
253 where
254 Self::Meta: Cast<f64> + num::Float + FloatOutBinary<Self::Meta, Output = Self::Meta>,
255 usize: Cast<Self::Meta>,
256 f64: Cast<Self::Meta>;
257
258 /// Creates a 1-D tensor with `n` numbers geometrically spaced between `start` and `end`.
259 ///
260 /// ## Parameters:
261 /// `start`: The starting value of the sequence
262 ///
263 /// `end`: The end value of the sequence
264 ///
265 /// `num`: Number of samples to generate
266 ///
267 /// `include_end`: Whether to include the end value in the sequence
268 ///
269 /// ## Example:
270 /// ```rust
271 /// let a = Tensor::<f32>::geomspace(1.0, 1000.0, 4, true)?;
272 /// // [1.0, 10.0, 100.0, 1000.0]
273 /// let b = Tensor::<f32>::geomspace(1.0, 100.0, 3, false)?;
274 /// // [1.0, 4.6416, 21.5443]
275 /// let c = Tensor::<f32>::geomspace(1.0, 32.0, 5, true)?;
276 /// // [1.0, 2.3784, 5.6569, 13.4543, 32.0000]
277 /// ```
278 #[track_caller]
279 fn geomspace<V: Cast<Self::Meta>>(
280 start: V,
281 end: V,
282 n: usize,
283 include_end: bool,
284 ) -> Result<Self::Output, TensorError>
285 where
286 f64: Cast<Self::Meta>,
287 usize: Cast<Self::Meta>,
288 Self::Meta: Cast<f64> + FloatOutBinary<Self::Meta, Output = Self::Meta>;
289
290 /// Creates a tensor with ones at and below (or above) the k-th diagonal.
291 ///
292 /// ## Parameters:
293 /// `n`: Number of rows
294 ///
295 /// `m`: Number of columns
296 ///
297 /// `k`: The diagonal above or below which to fill with ones (0 represents the main diagonal)
298 ///
299 /// `low_triangle`: If true, fill with ones below and on the k-th diagonal; if false, fill with ones above the k-th diagonal
300 ///
301 /// ## Example:
302 /// ```rust
303 /// let a = Tensor::<f32>::tri(3, 3, 0, true)?;
304 /// // [[1, 0, 0],
305 /// // [1, 1, 0],
306 /// // [1, 1, 1]]
307 /// let b = Tensor::<f32>::tri(3, 3, 0, false)?;
308 /// // [[1, 1, 1],
309 /// // [0, 1, 1],
310 /// // [0, 0, 1]]
311 /// let c = Tensor::<f32>::tri(3, 4, 1, true)?;
312 /// // [[1, 1, 0, 0],
313 /// // [1, 1, 1, 0],
314 /// // [1, 1, 1, 1]]
315 /// ```
316 #[track_caller]
317 fn tri(n: usize, m: usize, k: i64, low_triangle: bool) -> Result<Self::Output, TensorError>
318 where
319 u8: Cast<Self::Meta>;
320
321 /// Returns a copy of the tensor with elements above the k-th diagonal zeroed.
322 ///
323 /// ## Parameters:
324 /// `k`: Diagonal above which to zero elements. k=0 is the main diagonal, k>0 is above and k<0 is below the main diagonal
325 ///
326 /// ## Example:
327 /// ```rust
328 /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).reshape(&[3, 3])?;
329 /// let b = a.tril(0)?;
330 /// // [[1, 0, 0],
331 /// // [4, 5, 0],
332 /// // [7, 8, 9]]
333 /// let c = a.tril(-1)?;
334 /// // [[0, 0, 0],
335 /// // [4, 0, 0],
336 /// // [7, 8, 0]]
337 /// ```
338 #[track_caller]
339 fn tril(&self, k: i64) -> Result<Self::Output, TensorError>
340 where
341 Self::Meta: NormalOut<bool, Output = Self::Meta> + Cast<Self::Meta> + TypeCommon,
342 <Self::Meta as TypeCommon>::Vec:
343 NormalOut<BoolVector, Output = <Self::Meta as TypeCommon>::Vec>;
344
345 /// Returns a copy of the tensor with elements below the k-th diagonal zeroed.
346 ///
347 /// ## Parameters:
348 /// `k`: Diagonal below which to zero elements. k=0 is the main diagonal, k>0 is above and k<0 is below the main diagonal
349 ///
350 /// ## Example:
351 /// ```rust
352 /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).reshape(&[3, 3])?;
353 /// let b = a.triu(0)?;
354 /// // [[1, 2, 3],
355 /// // [0, 5, 6],
356 /// // [0, 0, 9]]
357 /// let c = a.triu(1)?;
358 /// // [[0, 2, 3],
359 /// // [0, 0, 6],
360 /// // [0, 0, 0]]
361 /// ```
362 #[track_caller]
363 fn triu(&self, k: i64) -> Result<Self::Output, TensorError>
364 where
365 Self::Meta: NormalOut<bool, Output = Self::Meta> + Cast<Self::Meta> + TypeCommon,
366 <Self::Meta as TypeCommon>::Vec:
367 NormalOut<BoolVector, Output = <Self::Meta as TypeCommon>::Vec>;
368
369 /// Creates a 2-D identity tensor (1's on the main diagonal and 0's elsewhere).
370 ///
371 /// ## Parameters:
372 /// `n`: Number of rows and columns
373 ///
374 /// ## Example:
375 /// ```rust
376 /// let a = Tensor::<f32>::identity(3)?;
377 /// // [[1, 0, 0],
378 /// // [0, 1, 0],
379 /// // [0, 0, 1]]
380 /// ```
381 #[track_caller]
382 fn identity(n: usize) -> Result<Self::Output, TensorError>
383 where
384 u8: Cast<Self::Meta>;
385}