Skip to main content

axonml_tensor/
creation.rs

1//! Tensor Creation Functions
2//!
3//! # File
4//! `crates/axonml-tensor/src/creation.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use rand::Rng;
18use rand::distributions::{Distribution, Standard};
19use rand_distr::{Normal, StandardNormal, Uniform};
20
21use axonml_core::dtype::{Float, Numeric, Scalar};
22
23use crate::tensor::Tensor;
24
25// =============================================================================
26// Zero and One Initialization
27// =============================================================================
28
29/// Creates a tensor filled with zeros.
30///
31/// # Arguments
32/// * `shape` - Shape of the tensor
33///
34/// # Example
35/// ```rust,ignore
36/// use axonml_tensor::zeros;
37/// let t = zeros::<f32>(&[2, 3]);
38/// ```
39#[must_use]
40pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
41    let numel: usize = shape.iter().product();
42    let data = vec![T::zeroed(); numel];
43    Tensor::from_vec(data, shape).unwrap()
44}
45
46/// Creates a tensor filled with ones.
47///
48/// # Arguments
49/// * `shape` - Shape of the tensor
50#[must_use]
51pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
52    full(shape, T::one())
53}
54
55/// Creates a tensor filled with a specific value.
56///
57/// # Arguments
58/// * `shape` - Shape of the tensor
59/// * `value` - Fill value
60pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
61    let numel: usize = shape.iter().product();
62    let data = vec![value; numel];
63    Tensor::from_vec(data, shape).unwrap()
64}
65
66/// Creates a tensor with the same shape as another, filled with zeros.
67#[must_use]
68pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
69    zeros(other.shape())
70}
71
72/// Creates a tensor with the same shape as another, filled with ones.
73#[must_use]
74pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
75    ones(other.shape())
76}
77
78/// Creates a tensor with the same shape as another, filled with a value.
79pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
80    full(other.shape(), value)
81}
82
83// =============================================================================
84// Identity and Diagonal
85// =============================================================================
86
87/// Creates a 2D identity matrix.
88///
89/// # Arguments
90/// * `n` - Size of the matrix (n x n)
91#[must_use]
92pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
93    let mut data = vec![T::zero(); n * n];
94    for i in 0..n {
95        data[i * n + i] = T::one();
96    }
97    Tensor::from_vec(data, &[n, n]).unwrap()
98}
99
100/// Creates a 2D tensor with the given diagonal values.
101///
102/// # Arguments
103/// * `diag` - Values for the diagonal
104pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
105    let n = diag.len();
106    let mut data = vec![T::zero(); n * n];
107    for (i, &val) in diag.iter().enumerate() {
108        data[i * n + i] = val;
109    }
110    Tensor::from_vec(data, &[n, n]).unwrap()
111}
112
113// =============================================================================
114// Random Initialization
115// =============================================================================
116
117/// Creates a tensor with uniformly distributed random values in [0, 1).
118///
119/// # Arguments
120/// * `shape` - Shape of the tensor
121#[must_use]
122pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
123where
124    Standard: Distribution<T>,
125{
126    let numel: usize = shape.iter().product();
127    let mut rng = rand::thread_rng();
128    let data: Vec<T> = (0..numel).map(|_| rng.r#gen()).collect();
129    Tensor::from_vec(data, shape).unwrap()
130}
131
132/// Creates a tensor with normally distributed random values (mean=0, std=1).
133///
134/// # Arguments
135/// * `shape` - Shape of the tensor
136#[must_use]
137pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
138where
139    StandardNormal: Distribution<T>,
140{
141    let numel: usize = shape.iter().product();
142    let mut rng = rand::thread_rng();
143    let normal = StandardNormal;
144    let data: Vec<T> = (0..numel).map(|_| normal.sample(&mut rng)).collect();
145    Tensor::from_vec(data, shape).unwrap()
146}
147
148/// Creates a tensor with uniformly distributed random values in [low, high).
149///
150/// # Arguments
151/// * `shape` - Shape of the tensor
152/// * `low` - Lower bound (inclusive)
153/// * `high` - Upper bound (exclusive)
154pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
155where
156    T: rand::distributions::uniform::SampleUniform,
157{
158    let numel: usize = shape.iter().product();
159    let mut rng = rand::thread_rng();
160    let dist = Uniform::new(low, high);
161    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
162    Tensor::from_vec(data, shape).unwrap()
163}
164
165/// Creates a tensor with normally distributed random values.
166///
167/// # Arguments
168/// * `shape` - Shape of the tensor
169/// * `mean` - Mean of the distribution
170/// * `std` - Standard deviation of the distribution
171pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
172where
173    T: rand::distributions::uniform::SampleUniform,
174    StandardNormal: Distribution<T>,
175{
176    let numel: usize = shape.iter().product();
177    let mut rng = rand::thread_rng();
178    let dist = Normal::new(mean, std).unwrap();
179    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
180    Tensor::from_vec(data, shape).unwrap()
181}
182
183/// Creates a tensor with random integers in [low, high).
184///
185/// # Arguments
186/// * `shape` - Shape of the tensor
187/// * `low` - Lower bound (inclusive)
188/// * `high` - Upper bound (exclusive)
189#[must_use]
190pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
191where
192    T: num_traits::NumCast,
193{
194    let numel: usize = shape.iter().product();
195    let mut rng = rand::thread_rng();
196    let dist = Uniform::new(low, high);
197    let data: Vec<T> = (0..numel)
198        .map(|_| T::from(dist.sample(&mut rng)).unwrap())
199        .collect();
200    Tensor::from_vec(data, shape).unwrap()
201}
202
203// =============================================================================
204// Range Functions
205// =============================================================================
206
207/// Creates a 1D tensor with values from start to end (exclusive) with step.
208///
209/// # Arguments
210/// * `start` - Start value
211/// * `end` - End value (exclusive)
212/// * `step` - Step size
213pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
214where
215    T: num_traits::NumCast + PartialOrd,
216{
217    let mut data = Vec::new();
218    let mut current = start;
219
220    if step > T::zero() {
221        while current < end {
222            data.push(current);
223            current = current + step;
224        }
225    } else if step < T::zero() {
226        while current > end {
227            data.push(current);
228            current = current + step;
229        }
230    }
231
232    let len = data.len();
233    Tensor::from_vec(data, &[len]).unwrap()
234}
235
236/// Creates a 1D tensor with `num` evenly spaced values from start to end.
237///
238/// # Arguments
239/// * `start` - Start value
240/// * `end` - End value (inclusive)
241/// * `num` - Number of values
242pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
243    if num == 0 {
244        return Tensor::from_vec(vec![], &[0]).unwrap();
245    }
246
247    if num == 1 {
248        return Tensor::from_vec(vec![start], &[1]).unwrap();
249    }
250
251    let step = (end - start) / T::from(num - 1).unwrap();
252    let data: Vec<T> = (0..num)
253        .map(|i| start + step * T::from(i).unwrap())
254        .collect();
255
256    Tensor::from_vec(data, &[num]).unwrap()
257}
258
259/// Creates a 1D tensor with `num` logarithmically spaced values.
260///
261/// # Arguments
262/// * `start` - Start exponent (base^start)
263/// * `end` - End exponent (base^end)
264/// * `num` - Number of values
265/// * `base` - Base of the logarithm
266pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
267    if num == 0 {
268        return Tensor::from_vec(vec![], &[0]).unwrap();
269    }
270
271    let lin = linspace(start, end, num);
272    let data: Vec<T> = lin.to_vec().iter().map(|&x| base.pow_value(x)).collect();
273
274    Tensor::from_vec(data, &[num]).unwrap()
275}
276
277// =============================================================================
278// Empty Tensor
279// =============================================================================
280
281/// Creates an uninitialized tensor (values are undefined).
282///
283/// # Safety
284/// The tensor contents are uninitialized. Reading before writing is undefined.
285///
286/// # Arguments
287/// * `shape` - Shape of the tensor
288#[must_use]
289pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
290    zeros(shape)
291}
292
293// =============================================================================
294// Tests
295// =============================================================================
296
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = zeros::<f32>(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        for val in t.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_ones() {
        let t = ones::<f32>(&[2, 3]);
        for val in t.to_vec() {
            assert_eq!(val, 1.0);
        }
    }

    #[test]
    fn test_full() {
        let t = full::<f32>(&[2, 3], 42.0);
        for val in t.to_vec() {
            assert_eq!(val, 42.0);
        }
    }

    #[test]
    fn test_full_like() {
        let base = zeros::<f32>(&[2, 2]);
        let t = full_like(&base, 7.5);
        assert_eq!(t.shape(), &[2, 2]);
        for val in t.to_vec() {
            assert_eq!(val, 7.5);
        }
    }

    #[test]
    fn test_eye() {
        let t = eye::<f32>(3);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 1.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 0.0);
    }

    #[test]
    fn test_diag() {
        let t = diag::<f32>(&[1.0, 2.0, 3.0]);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 3.0);
        assert_eq!(t.get(&[1, 0]).unwrap(), 0.0);
    }

    #[test]
    fn test_rand() {
        let t = rand::<f32>(&[100]);
        for val in t.to_vec() {
            assert!((0.0..1.0).contains(&val));
        }
    }

    #[test]
    fn test_randn_shape() {
        let t = randn::<f32>(&[4, 5]);
        assert_eq!(t.shape(), &[4, 5]);
        assert_eq!(t.numel(), 20);
    }

    #[test]
    fn test_uniform_bounds() {
        let t = uniform::<f32>(&[100], -2.0, 3.0);
        for val in t.to_vec() {
            assert!((-2.0..3.0).contains(&val));
        }
    }

    #[test]
    fn test_normal_shape() {
        let t = normal::<f32>(&[10], 0.0, 1.0);
        assert_eq!(t.shape(), &[10]);
        for val in t.to_vec() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_randint_range() {
        let t = randint::<f32>(&[64], -3, 3);
        for val in t.to_vec() {
            assert!((-3.0..3.0).contains(&val));
            assert_eq!(val.fract(), 0.0);
        }
    }

    #[test]
    fn test_arange() {
        let t = arange::<f32>(0.0, 5.0, 1.0);
        assert_eq!(t.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let t = arange::<f32>(0.0, 1.0, 0.2);
        assert_eq!(t.numel(), 5);
    }

    #[test]
    fn test_arange_descending_and_zero_step() {
        let t = arange::<f32>(3.0, 0.0, -1.0);
        assert_eq!(t.to_vec(), vec![3.0, 2.0, 1.0]);

        let t = arange::<f32>(0.0, 5.0, 0.0);
        assert_eq!(t.numel(), 0);
    }

    #[test]
    fn test_linspace() {
        let t = linspace::<f32>(0.0, 1.0, 5);
        let data = t.to_vec();
        assert_eq!(data.len(), 5);
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[4] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_linspace_edge_counts() {
        assert_eq!(linspace::<f32>(0.0, 1.0, 0).numel(), 0);
        assert_eq!(linspace::<f32>(2.0, 7.0, 1).to_vec(), vec![2.0]);
    }

    #[test]
    fn test_logspace() {
        let data = logspace::<f32>(0.0, 2.0, 3, 10.0).to_vec();
        let expected = [1.0f32, 10.0, 100.0];
        assert_eq!(data.len(), expected.len());
        for (got, want) in data.iter().zip(expected.iter()) {
            assert!((got - want).abs() <= want * 1e-4);
        }
    }

    #[test]
    fn test_zeros_like() {
        let a = ones::<f32>(&[2, 3]);
        let b = zeros_like(&a);
        assert_eq!(b.shape(), &[2, 3]);
        for val in b.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_empty_shape() {
        let t = empty::<f32>(&[3, 4]);
        assert_eq!(t.shape(), &[3, 4]);
        assert_eq!(t.numel(), 12);
    }
}