hpt_traits/ops/
creation.rs

1use hpt_common::{error::base::TensorError, shape::shape::Shape};
2use hpt_types::arch_simd as simd;
3use hpt_types::type_promote::FloatOutBinary;
4use hpt_types::{dtype::TypeCommon, into_scalar::Cast, type_promote::NormalOut};
5#[cfg(target_feature = "avx2")]
6type BoolVector = simd::_256bit::boolx32;
7#[cfg(any(
8    all(not(target_feature = "avx2"), target_feature = "sse"),
9    target_arch = "arm",
10    target_arch = "aarch64",
11    target_feature = "neon"
12))]
13type BoolVector = simd::_128bit::boolx16;
14
15/// A trait defines a set of functions to create tensors.
16pub trait TensorCreator
17where
18    Self: Sized,
19{
20    /// the output type of the tensor
21    type Output;
22    /// the meta type of the tensor
23    type Meta;
24    /// Creates a new uninitialized tensor with the specified shape. The tensor's values will be whatever was in memory at the time of allocation.
25    ///
26    /// ## Parameters:
27    /// `shape`: The desired shape for the tensor.
28    ///
29    /// ## Example:
30    /// ```rust
31    /// let a = Tensor::<f32>::empty(&[2, 3])?; // Shape: [2, 3]
32    /// ```
33    #[track_caller]
34    fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
35
36    /// Creates a new tensor of the specified shape, filled with zeros.
37    ///
38    /// ## Parameters:
39    /// `shape`: The desired shape for the tensor.
40    ///
41    /// ## Example:
42    /// ```rust
43    /// let a = Tensor::<f32>::zeros(&[2, 3])?; // Shape: [2, 3]
44    /// ```
45    #[track_caller]
46    fn zeros<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
47
48    /// Creates a new tensor of the specified shape, filled with ones.
49    ///
50    /// ## Parameters:
51    /// `shape`: The desired shape for the tensor.
52    ///
53    /// ## Example:
54    /// ```rust
55    /// let a = Tensor::<f32>::ones(&[2, 3])?; // Shape: [2, 3]
56    /// ```
57    #[track_caller]
58    fn ones<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>
59    where
60        u8: Cast<Self::Meta>;
61
62    /// Creates a new uninitialized tensor with the same shape as the input tensor.
63    ///
64    /// ## Example:
65    /// ```rust
66    /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0]).reshape(&[3, 1])?;
67    /// let b = a.empty_like()?; // Shape: [3, 1]
68    /// ```
69    #[track_caller]
70    fn empty_like(&self) -> Result<Self::Output, TensorError>;
71
72    /// Creates a new zeroed tensor with the same shape as the input tensor.
73    ///
74    /// ## Example:
75    /// ```rust
76    /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0]).reshape(&[3, 1])?;
77    /// let b = a.zeros_like()?; // Shape: [3, 1]
78    /// ```
79    #[track_caller]
80    fn zeros_like(&self) -> Result<Self::Output, TensorError>;
81
82    /// Creates a new tensor with all ones with the same shape as the input tensor.
83    ///
84    /// ## Example:
85    /// ```rust
86    /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0]).reshape(&[3, 1])?;
87    /// let b = a.ones_like()?; // Shape: [3, 1]
88    /// ```
89    #[track_caller]
90    fn ones_like(&self) -> Result<Self::Output, TensorError>
91    where
92        u8: Cast<Self::Meta>;
93
94    /// Creates a new tensor of the specified shape, filled with a specified value.
95    ///
96    /// ## Parameters:
97    /// `val`: The value to fill the tensor with.
98    ///
99    /// `shape`: The desired shape for the tensor.
100    ///
101    /// ## Example:
102    /// ```rust
103    /// let a = Tensor::<f32>::full(5.0, &[2, 3])?;
104    /// // [[5, 5, 5],
105    /// //  [5, 5, 5]]
106    /// ```
107    #[track_caller]
108    fn full<S: Into<Shape>>(val: Self::Meta, shape: S) -> Result<Self::Output, TensorError>;
109
110    /// Creates a new tensor filled with a specified value with the same shape as the input tensor.
111    ///
112    /// ## Parameters:
113    /// `val`: The value to fill the tensor with.
114    ///
115    /// ## Example:
116    /// ```rust
117    /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2])?;
118    /// let b = a.full_like(7.0)?;
119    /// // [[7, 7],
120    /// //  [7, 7]]
121    /// ```
122    #[track_caller]
123    fn full_like(&self, val: Self::Meta) -> Result<Self::Output, TensorError>;
124
125    /// Creates a 1-D tensor with evenly spaced values within a given interval `[start, end)`.
126    ///
127    /// ## Parameters:
128    /// `start`: Start of interval (inclusive)
129    ///
130    /// `end`: End of interval (exclusive)
131    ///
132    /// ## Example:
133    /// ```rust
134    /// let a = Tensor::<f32>::arange(0, 5)?; // [0, 1, 2, 3, 4]
135    /// let b = Tensor::<f32>::arange(1.5, 5.5)?; // [1.5, 2.5, 3.5, 4.5]
136    /// ```
137    #[track_caller]
138    fn arange<U>(start: U, end: U) -> Result<Self::Output, TensorError>
139    where
140        usize: Cast<Self::Meta>,
141        U: Cast<i64> + Cast<Self::Meta> + Copy;
142
143    /// Creates a 1-D tensor with evenly spaced values within a given interval `[start, end)` with a specified step size.
144    ///
145    /// ## Parameters:
146    /// `start`: Start of interval (inclusive)
147    ///
148    /// `end`: End of interval (exclusive)
149    ///
150    /// `step`: Size of spacing between values
151    ///
152    /// ## Example:
153    /// ```rust
154    /// let a = Tensor::<f32>::arange_step(0.0, 5.0, 2.0)?; // [0, 2, 4]
155    /// let b = Tensor::<f32>::arange_step(5.0, 0.0, -1.5)?; // [5, 3.5, 2, 0.5]
156    /// ```
157    #[track_caller]
158    fn arange_step(
159        start: Self::Meta,
160        end: Self::Meta,
161        step: Self::Meta,
162    ) -> Result<Self::Output, TensorError>
163    where
164        Self::Meta: Cast<f64> + Cast<f64>,
165        f64: Cast<Self::Meta>,
166        usize: Cast<Self::Meta>;
167
168    /// Creates a 2-D tensor with ones on the k-th diagonal and zeros elsewhere.
169    ///
170    /// ## Parameters:
171    /// `n`: Number of rows
172    ///
173    /// `m`: Number of columns
174    ///
175    /// `k`: Index of the diagonal (0 represents the main diagonal, positive values are above the main diagonal)
176    ///
177    /// ## Example:
178    /// ```rust
179    /// let a = Tensor::<f32>::eye(3, 4, 0)?;
180    /// // [[1, 0, 0, 0],
181    /// //  [0, 1, 0, 0],
182    /// //  [0, 0, 1, 0]]
183    /// let b = Tensor::<f32>::eye(3, 4, 1)?;
184    /// // [[0, 1, 0, 0],
185    /// //  [0, 0, 1, 0],
186    /// //  [0, 0, 0, 1]]
187    /// ```
188    #[track_caller]
189    fn eye(n: usize, m: usize, k: usize) -> Result<Self::Output, TensorError>;
190
191    /// Creates a 1-D tensor of `num` evenly spaced values between `start` and `end`.
192    ///
193    /// ## Parameters:
194    /// `start`: The starting value of the sequence
195    ///
196    /// `end`: The end value of the sequence
197    ///
198    /// `num`: Number of samples to generate
199    ///
200    /// `include_end`: Whether to include the end value in the sequence
201    ///
202    /// ## Example:
203    /// ```rust
204    /// let a = Tensor::<f32>::linspace(0.0, 1.0, 5, true)?;
205    /// // [0.0, 0.25, 0.5, 0.75, 1.0]
206    /// let b = Tensor::<f32>::linspace(0.0, 1.0, 5, false)?;
207    /// // [0.0, 0.2, 0.4, 0.6, 0.8]
208    /// let c = Tensor::<f32>::linspace(0, 10, 6, true)?;
209    /// // [0.0, 2.0, 4.0, 6.0, 8.0, 10.0]
210    /// ```
211    #[track_caller]
212    fn linspace<U>(
213        start: U,
214        end: U,
215        num: usize,
216        include_end: bool,
217    ) -> Result<Self::Output, TensorError>
218    where
219        U: Cast<f64> + Cast<Self::Meta> + Copy,
220        usize: Cast<Self::Meta>,
221        f64: Cast<Self::Meta>;
222
223    /// Creates a 1-D tensor with `num` numbers logarithmically spaced between `base^start` and `base^end`.
224    ///
225    /// ## Parameters:
226    /// `start`: The starting value of the sequence (power of base)
227    ///
228    /// `end`: The end value of the sequence (power of base)
229    ///
230    /// `num`: Number of samples to generate
231    ///
232    /// `include_end`: Whether to include the end value in the sequence
233    ///
234    /// `base`: The base of the log space (default is 10.0)
235    ///
236    /// ## Example:
237    /// ```rust
238    /// let a = Tensor::<f32>::logspace(0.0, 3.0, 4, true, 10.0)?;
239    /// // [1.0, 10.0, 100.0, 1000.0]
240    /// let b = Tensor::<f32>::logspace(0.0, 3.0, 4, true, 2.0)?;
241    /// // [1.0, 2.0, 4.0, 8.0]
242    /// let c = Tensor::<f32>::logspace(0.0, 2.0, 4, false, 10.0)?;
243    /// // [1.0, 3.1623, 10.0, 31.6228]
244    /// ```
245    #[track_caller]
246    fn logspace<V: Cast<Self::Meta>>(
247        start: V,
248        end: V,
249        num: usize,
250        include_end: bool,
251        base: V,
252    ) -> Result<Self::Output, TensorError>
253    where
254        Self::Meta: Cast<f64> + num::Float + FloatOutBinary<Self::Meta, Output = Self::Meta>,
255        usize: Cast<Self::Meta>,
256        f64: Cast<Self::Meta>;
257
258    /// Creates a 1-D tensor with `n` numbers geometrically spaced between `start` and `end`.
259    ///
260    /// ## Parameters:
261    /// `start`: The starting value of the sequence
262    ///
263    /// `end`: The end value of the sequence
264    ///
265    /// `num`: Number of samples to generate
266    ///
267    /// `include_end`: Whether to include the end value in the sequence
268    ///
269    /// ## Example:
270    /// ```rust
271    /// let a = Tensor::<f32>::geomspace(1.0, 1000.0, 4, true)?;
272    /// // [1.0, 10.0, 100.0, 1000.0]
273    /// let b = Tensor::<f32>::geomspace(1.0, 100.0, 3, false)?;
274    /// // [1.0, 4.6416, 21.5443]
275    /// let c = Tensor::<f32>::geomspace(1.0, 32.0, 5, true)?;
276    /// // [1.0, 2.3784, 5.6569, 13.4543, 32.0000]
277    /// ```
278    #[track_caller]
279    fn geomspace<V: Cast<Self::Meta>>(
280        start: V,
281        end: V,
282        n: usize,
283        include_end: bool,
284    ) -> Result<Self::Output, TensorError>
285    where
286        f64: Cast<Self::Meta>,
287        usize: Cast<Self::Meta>,
288        Self::Meta: Cast<f64> + FloatOutBinary<Self::Meta, Output = Self::Meta>;
289
290    /// Creates a tensor with ones at and below (or above) the k-th diagonal.
291    ///
292    /// ## Parameters:
293    /// `n`: Number of rows
294    ///
295    /// `m`: Number of columns
296    ///
297    /// `k`: The diagonal above or below which to fill with ones (0 represents the main diagonal)
298    ///
299    /// `low_triangle`: If true, fill with ones below and on the k-th diagonal; if false, fill with ones above the k-th diagonal
300    ///
301    /// ## Example:
302    /// ```rust
303    /// let a = Tensor::<f32>::tri(3, 3, 0, true)?;
304    /// // [[1, 0, 0],
305    /// //  [1, 1, 0],
306    /// //  [1, 1, 1]]
307    /// let b = Tensor::<f32>::tri(3, 3, 0, false)?;
308    /// // [[1, 1, 1],
309    /// //  [0, 1, 1],
310    /// //  [0, 0, 1]]
311    /// let c = Tensor::<f32>::tri(3, 4, 1, true)?;
312    /// // [[1, 1, 0, 0],
313    /// //  [1, 1, 1, 0],
314    /// //  [1, 1, 1, 1]]
315    /// ```
316    #[track_caller]
317    fn tri(n: usize, m: usize, k: i64, low_triangle: bool) -> Result<Self::Output, TensorError>
318    where
319        u8: Cast<Self::Meta>;
320
321    /// Returns a copy of the tensor with elements above the k-th diagonal zeroed.
322    ///
323    /// ## Parameters:
324    /// `k`: Diagonal above which to zero elements. k=0 is the main diagonal, k>0 is above and k<0 is below the main diagonal
325    ///
326    /// ## Example:
327    /// ```rust
328    /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).reshape(&[3, 3])?;
329    /// let b = a.tril(0)?;
330    /// // [[1, 0, 0],
331    /// //  [4, 5, 0],
332    /// //  [7, 8, 9]]
333    /// let c = a.tril(-1)?;
334    /// // [[0, 0, 0],
335    /// //  [4, 0, 0],
336    /// //  [7, 8, 0]]
337    /// ```
338    #[track_caller]
339    fn tril(&self, k: i64) -> Result<Self::Output, TensorError>
340    where
341        Self::Meta: NormalOut<bool, Output = Self::Meta> + Cast<Self::Meta> + TypeCommon,
342        <Self::Meta as TypeCommon>::Vec:
343            NormalOut<BoolVector, Output = <Self::Meta as TypeCommon>::Vec>;
344
345    /// Returns a copy of the tensor with elements below the k-th diagonal zeroed.
346    ///
347    /// ## Parameters:
348    /// `k`: Diagonal below which to zero elements. k=0 is the main diagonal, k>0 is above and k<0 is below the main diagonal
349    ///
350    /// ## Example:
351    /// ```rust
352    /// let a = Tensor::<f32>::new(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).reshape(&[3, 3])?;
353    /// let b = a.triu(0)?;
354    /// // [[1, 2, 3],
355    /// //  [0, 5, 6],
356    /// //  [0, 0, 9]]
357    /// let c = a.triu(1)?;
358    /// // [[0, 2, 3],
359    /// //  [0, 0, 6],
360    /// //  [0, 0, 0]]
361    /// ```
362    #[track_caller]
363    fn triu(&self, k: i64) -> Result<Self::Output, TensorError>
364    where
365        Self::Meta: NormalOut<bool, Output = Self::Meta> + Cast<Self::Meta> + TypeCommon,
366        <Self::Meta as TypeCommon>::Vec:
367            NormalOut<BoolVector, Output = <Self::Meta as TypeCommon>::Vec>;
368
369    /// Creates a 2-D identity tensor (1's on the main diagonal and 0's elsewhere).
370    ///
371    /// ## Parameters:
372    /// `n`: Number of rows and columns
373    ///
374    /// ## Example:
375    /// ```rust
376    /// let a = Tensor::<f32>::identity(3)?;
377    /// // [[1, 0, 0],
378    /// //  [0, 1, 0],
379    /// //  [0, 0, 1]]
380    /// ```
381    #[track_caller]
382    fn identity(n: usize) -> Result<Self::Output, TensorError>
383    where
384        u8: Cast<Self::Meta>;
385}