Skip to main content

axonml_tensor/
creation.rs

//! Tensor factory functions.
//!
//! `zeros`, `ones`, `full`, `zeros_like`, `ones_like`, `full_like`,
//! `eye`, `diag`, `rand` (uniform [0,1)), `randn` (standard normal), `uniform`
//! (range), `normal` (parameterised), `randint`, `arange` (start/end/step),
//! `linspace` (inclusive endpoints), `logspace`, and `empty`. All return CPU
//! `Tensor<T>`; use `.to_device()` afterward for GPU.
//!
//! # File
//! `crates/axonml-tensor/src/creation.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
23
24use rand::Rng;
25use rand::distributions::{Distribution, Standard};
26use rand_distr::{Normal, StandardNormal, Uniform};
27
28use axonml_core::dtype::{Float, Numeric, Scalar};
29
30use crate::tensor::Tensor;
31
32// =============================================================================
33// Zero and One Initialization
34// =============================================================================
35
36/// Creates a tensor filled with zeros.
37///
38/// # Arguments
39/// * `shape` - Shape of the tensor
40///
41/// # Example
42/// ```rust,ignore
43/// use axonml_tensor::zeros;
44/// let t = zeros::<f32>(&[2, 3]);
45/// ```
46#[must_use]
47pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
48    let numel: usize = shape.iter().product();
49    let data = vec![T::zeroed(); numel];
50    Tensor::from_vec(data, shape).expect("tensor creation failed")
51}
52
53/// Creates a tensor filled with ones.
54///
55/// # Arguments
56/// * `shape` - Shape of the tensor
57#[must_use]
58pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
59    full(shape, T::one())
60}
61
62/// Creates a tensor filled with a specific value.
63///
64/// # Arguments
65/// * `shape` - Shape of the tensor
66/// * `value` - Fill value
67pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
68    let numel: usize = shape.iter().product();
69    let data = vec![value; numel];
70    Tensor::from_vec(data, shape).expect("tensor creation failed")
71}
72
73/// Creates a tensor with the same shape as another, filled with zeros.
74#[must_use]
75pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
76    zeros(other.shape())
77}
78
79/// Creates a tensor with the same shape as another, filled with ones.
80#[must_use]
81pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
82    ones(other.shape())
83}
84
85/// Creates a tensor with the same shape as another, filled with a value.
86pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
87    full(other.shape(), value)
88}
89
90// =============================================================================
91// Identity and Diagonal
92// =============================================================================
93
94/// Creates a 2D identity matrix.
95///
96/// # Arguments
97/// * `n` - Size of the matrix (n x n)
98#[must_use]
99pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
100    let mut data = vec![T::zero(); n * n];
101    for i in 0..n {
102        data[i * n + i] = T::one();
103    }
104    Tensor::from_vec(data, &[n, n]).expect("tensor creation failed")
105}
106
107/// Creates a 2D tensor with the given diagonal values.
108///
109/// # Arguments
110/// * `diag` - Values for the diagonal
111pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
112    let n = diag.len();
113    let mut data = vec![T::zero(); n * n];
114    for (i, &val) in diag.iter().enumerate() {
115        data[i * n + i] = val;
116    }
117    Tensor::from_vec(data, &[n, n]).expect("tensor creation failed")
118}
119
120// =============================================================================
121// Random Initialization
122// =============================================================================
123
124/// Creates a tensor with uniformly distributed random values in [0, 1).
125///
126/// # Arguments
127/// * `shape` - Shape of the tensor
128#[must_use]
129pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
130where
131    Standard: Distribution<T>,
132{
133    let numel: usize = shape.iter().product();
134    let mut rng = rand::thread_rng();
135    let data: Vec<T> = (0..numel).map(|_| rng.r#gen()).collect();
136    Tensor::from_vec(data, shape).expect("tensor creation failed")
137}
138
139/// Creates a tensor with normally distributed random values (mean=0, std=1).
140///
141/// # Arguments
142/// * `shape` - Shape of the tensor
143#[must_use]
144pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
145where
146    StandardNormal: Distribution<T>,
147{
148    let numel: usize = shape.iter().product();
149    let mut rng = rand::thread_rng();
150    let normal = StandardNormal;
151    let data: Vec<T> = (0..numel).map(|_| normal.sample(&mut rng)).collect();
152    Tensor::from_vec(data, shape).expect("tensor creation failed")
153}
154
155/// Creates a tensor with uniformly distributed random values in [low, high).
156///
157/// # Arguments
158/// * `shape` - Shape of the tensor
159/// * `low` - Lower bound (inclusive)
160/// * `high` - Upper bound (exclusive)
161pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
162where
163    T: rand::distributions::uniform::SampleUniform,
164{
165    let numel: usize = shape.iter().product();
166    let mut rng = rand::thread_rng();
167    let dist = Uniform::new(low, high);
168    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
169    Tensor::from_vec(data, shape).expect("tensor creation failed")
170}
171
172/// Creates a tensor with normally distributed random values.
173///
174/// # Arguments
175/// * `shape` - Shape of the tensor
176/// * `mean` - Mean of the distribution
177/// * `std` - Standard deviation of the distribution
178pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
179where
180    T: rand::distributions::uniform::SampleUniform,
181    StandardNormal: Distribution<T>,
182{
183    let numel: usize = shape.iter().product();
184    let mut rng = rand::thread_rng();
185    let dist = Normal::new(mean, std).unwrap();
186    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
187    Tensor::from_vec(data, shape).expect("tensor creation failed")
188}
189
190/// Creates a tensor with random integers in [low, high).
191///
192/// # Arguments
193/// * `shape` - Shape of the tensor
194/// * `low` - Lower bound (inclusive)
195/// * `high` - Upper bound (exclusive)
196#[must_use]
197pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
198where
199    T: num_traits::NumCast,
200{
201    let numel: usize = shape.iter().product();
202    let mut rng = rand::thread_rng();
203    let dist = Uniform::new(low, high);
204    let data: Vec<T> = (0..numel)
205        .map(|_| T::from(dist.sample(&mut rng)).unwrap())
206        .collect();
207    Tensor::from_vec(data, shape).expect("tensor creation failed")
208}
209
210// =============================================================================
211// Range Functions
212// =============================================================================
213
214/// Creates a 1D tensor with values from start to end (exclusive) with step.
215///
216/// # Arguments
217/// * `start` - Start value
218/// * `end` - End value (exclusive)
219/// * `step` - Step size
220pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
221where
222    T: num_traits::NumCast + PartialOrd,
223{
224    let mut data = Vec::new();
225    let mut current = start;
226
227    if step > T::zero() {
228        while current < end {
229            data.push(current);
230            current = current + step;
231        }
232    } else if step < T::zero() {
233        while current > end {
234            data.push(current);
235            current = current + step;
236        }
237    }
238
239    let len = data.len();
240    Tensor::from_vec(data, &[len]).expect("tensor creation failed")
241}
242
243/// Creates a 1D tensor with `num` evenly spaced values from start to end.
244///
245/// # Arguments
246/// * `start` - Start value
247/// * `end` - End value (inclusive)
248/// * `num` - Number of values
249pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
250    if num == 0 {
251        return Tensor::from_vec(vec![], &[0]).expect("tensor creation failed");
252    }
253
254    if num == 1 {
255        return Tensor::from_vec(vec![start], &[1]).expect("tensor creation failed");
256    }
257
258    let step = (end - start) / T::from(num - 1).unwrap();
259    let data: Vec<T> = (0..num)
260        .map(|i| start + step * T::from(i).unwrap())
261        .collect();
262
263    Tensor::from_vec(data, &[num]).expect("tensor creation failed")
264}
265
266/// Creates a 1D tensor with `num` logarithmically spaced values.
267///
268/// # Arguments
269/// * `start` - Start exponent (base^start)
270/// * `end` - End exponent (base^end)
271/// * `num` - Number of values
272/// * `base` - Base of the logarithm
273pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
274    if num == 0 {
275        return Tensor::from_vec(vec![], &[0]).expect("tensor creation failed");
276    }
277
278    let lin = linspace(start, end, num);
279    let data: Vec<T> = lin.to_vec().iter().map(|&x| base.pow_value(x)).collect();
280
281    Tensor::from_vec(data, &[num]).expect("tensor creation failed")
282}
283
284// =============================================================================
285// Empty Tensor
286// =============================================================================
287
288/// Creates an uninitialized tensor (values are undefined).
289///
290/// # Safety
291/// The tensor contents are uninitialized. Reading before writing is undefined.
292///
293/// # Arguments
294/// * `shape` - Shape of the tensor
295#[must_use]
296pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
297    zeros(shape)
298}
299
300// =============================================================================
301// Tests
302// =============================================================================
303
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let zs = zeros::<f32>(&[2, 3]);
        assert_eq!(zs.shape(), &[2, 3]);
        assert_eq!(zs.numel(), 6);
        assert!(zs.to_vec().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_ones() {
        let os = ones::<f32>(&[2, 3]);
        assert!(os.to_vec().iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_full() {
        let filled = full::<f32>(&[2, 3], 42.0);
        assert!(filled.to_vec().iter().all(|&v| v == 42.0));
    }

    #[test]
    fn test_eye() {
        let id = eye::<f32>(3);
        assert_eq!(id.shape(), &[3, 3]);
        let at = |r: usize, c: usize| id.get(&[r, c]).unwrap();
        for k in 0..3 {
            assert_eq!(at(k, k), 1.0);
        }
        assert_eq!(at(0, 1), 0.0);
    }

    #[test]
    fn test_rand() {
        let samples = rand::<f32>(&[100]).to_vec();
        assert!(samples.iter().all(|v| (0.0..1.0).contains(v)));
    }

    #[test]
    fn test_arange() {
        let unit_steps = arange::<f32>(0.0, 5.0, 1.0);
        assert_eq!(unit_steps.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let fifths = arange::<f32>(0.0, 1.0, 0.2);
        assert_eq!(fifths.numel(), 5);
    }

    #[test]
    fn test_linspace() {
        let values = linspace::<f32>(0.0, 1.0, 5).to_vec();
        assert_eq!(values.len(), 5);
        let close = |a: f32, b: f32| (a - b).abs() < 1e-6;
        assert!(close(values[0], 0.0));
        assert!(close(values[4], 1.0));
    }

    #[test]
    fn test_zeros_like() {
        let template = ones::<f32>(&[2, 3]);
        let zs = zeros_like(&template);
        assert_eq!(zs.shape(), &[2, 3]);
        assert!(zs.to_vec().iter().all(|&v| v == 0.0));
    }
}