Skip to main content

axonml_tensor/
creation.rs

1//! Tensor Creation Functions
2//!
3//! Provides convenient functions for creating tensors with various initializations
4//! including zeros, ones, random values, ranges, and more.
5//!
6//! # Key Features
7//! - Factory functions for common tensor initializations
8//! - Random tensor generation with various distributions
9//! - Range and linspace functions
10//!
11
12//! @version 0.1.0
13//! @author `AutomataNexus` Development Team
14
15use rand::distributions::{Distribution, Standard};
16use rand::Rng;
17use rand_distr::{Normal, StandardNormal, Uniform};
18
19use axonml_core::dtype::{Float, Numeric, Scalar};
20
21use crate::tensor::Tensor;
22
23// =============================================================================
24// Zero and One Initialization
25// =============================================================================
26
27/// Creates a tensor filled with zeros.
28///
29/// # Arguments
30/// * `shape` - Shape of the tensor
31///
32/// # Example
33/// ```rust,ignore
34/// use axonml_tensor::zeros;
35/// let t = zeros::<f32>(&[2, 3]);
36/// ```
37#[must_use]
38pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
39    let numel: usize = shape.iter().product();
40    let data = vec![T::zeroed(); numel];
41    Tensor::from_vec(data, shape).unwrap()
42}
43
44/// Creates a tensor filled with ones.
45///
46/// # Arguments
47/// * `shape` - Shape of the tensor
48#[must_use]
49pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
50    full(shape, T::one())
51}
52
53/// Creates a tensor filled with a specific value.
54///
55/// # Arguments
56/// * `shape` - Shape of the tensor
57/// * `value` - Fill value
58pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
59    let numel: usize = shape.iter().product();
60    let data = vec![value; numel];
61    Tensor::from_vec(data, shape).unwrap()
62}
63
64/// Creates a tensor with the same shape as another, filled with zeros.
65#[must_use]
66pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
67    zeros(other.shape())
68}
69
70/// Creates a tensor with the same shape as another, filled with ones.
71#[must_use]
72pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
73    ones(other.shape())
74}
75
76/// Creates a tensor with the same shape as another, filled with a value.
77pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
78    full(other.shape(), value)
79}
80
81// =============================================================================
82// Identity and Diagonal
83// =============================================================================
84
85/// Creates a 2D identity matrix.
86///
87/// # Arguments
88/// * `n` - Size of the matrix (n x n)
89#[must_use]
90pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
91    let mut data = vec![T::zero(); n * n];
92    for i in 0..n {
93        data[i * n + i] = T::one();
94    }
95    Tensor::from_vec(data, &[n, n]).unwrap()
96}
97
98/// Creates a 2D tensor with the given diagonal values.
99///
100/// # Arguments
101/// * `diag` - Values for the diagonal
102pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
103    let n = diag.len();
104    let mut data = vec![T::zero(); n * n];
105    for (i, &val) in diag.iter().enumerate() {
106        data[i * n + i] = val;
107    }
108    Tensor::from_vec(data, &[n, n]).unwrap()
109}
110
111// =============================================================================
112// Random Initialization
113// =============================================================================
114
115/// Creates a tensor with uniformly distributed random values in [0, 1).
116///
117/// # Arguments
118/// * `shape` - Shape of the tensor
119#[must_use]
120pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
121where
122    Standard: Distribution<T>,
123{
124    let numel: usize = shape.iter().product();
125    let mut rng = rand::thread_rng();
126    let data: Vec<T> = (0..numel).map(|_| rng.gen()).collect();
127    Tensor::from_vec(data, shape).unwrap()
128}
129
130/// Creates a tensor with normally distributed random values (mean=0, std=1).
131///
132/// # Arguments
133/// * `shape` - Shape of the tensor
134#[must_use]
135pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
136where
137    StandardNormal: Distribution<T>,
138{
139    let numel: usize = shape.iter().product();
140    let mut rng = rand::thread_rng();
141    let normal = StandardNormal;
142    let data: Vec<T> = (0..numel).map(|_| normal.sample(&mut rng)).collect();
143    Tensor::from_vec(data, shape).unwrap()
144}
145
146/// Creates a tensor with uniformly distributed random values in [low, high).
147///
148/// # Arguments
149/// * `shape` - Shape of the tensor
150/// * `low` - Lower bound (inclusive)
151/// * `high` - Upper bound (exclusive)
152pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
153where
154    T: rand::distributions::uniform::SampleUniform,
155{
156    let numel: usize = shape.iter().product();
157    let mut rng = rand::thread_rng();
158    let dist = Uniform::new(low, high);
159    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
160    Tensor::from_vec(data, shape).unwrap()
161}
162
163/// Creates a tensor with normally distributed random values.
164///
165/// # Arguments
166/// * `shape` - Shape of the tensor
167/// * `mean` - Mean of the distribution
168/// * `std` - Standard deviation of the distribution
169pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
170where
171    T: rand::distributions::uniform::SampleUniform,
172    StandardNormal: Distribution<T>,
173{
174    let numel: usize = shape.iter().product();
175    let mut rng = rand::thread_rng();
176    let dist = Normal::new(mean, std).unwrap();
177    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
178    Tensor::from_vec(data, shape).unwrap()
179}
180
181/// Creates a tensor with random integers in [low, high).
182///
183/// # Arguments
184/// * `shape` - Shape of the tensor
185/// * `low` - Lower bound (inclusive)
186/// * `high` - Upper bound (exclusive)
187#[must_use]
188pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
189where
190    T: num_traits::NumCast,
191{
192    let numel: usize = shape.iter().product();
193    let mut rng = rand::thread_rng();
194    let dist = Uniform::new(low, high);
195    let data: Vec<T> = (0..numel)
196        .map(|_| T::from(dist.sample(&mut rng)).unwrap())
197        .collect();
198    Tensor::from_vec(data, shape).unwrap()
199}
200
201// =============================================================================
202// Range Functions
203// =============================================================================
204
205/// Creates a 1D tensor with values from start to end (exclusive) with step.
206///
207/// # Arguments
208/// * `start` - Start value
209/// * `end` - End value (exclusive)
210/// * `step` - Step size
211pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
212where
213    T: num_traits::NumCast + PartialOrd,
214{
215    let mut data = Vec::new();
216    let mut current = start;
217
218    if step > T::zero() {
219        while current < end {
220            data.push(current);
221            current = current + step;
222        }
223    } else if step < T::zero() {
224        while current > end {
225            data.push(current);
226            current = current + step;
227        }
228    }
229
230    let len = data.len();
231    Tensor::from_vec(data, &[len]).unwrap()
232}
233
234/// Creates a 1D tensor with `num` evenly spaced values from start to end.
235///
236/// # Arguments
237/// * `start` - Start value
238/// * `end` - End value (inclusive)
239/// * `num` - Number of values
240pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
241    if num == 0 {
242        return Tensor::from_vec(vec![], &[0]).unwrap();
243    }
244
245    if num == 1 {
246        return Tensor::from_vec(vec![start], &[1]).unwrap();
247    }
248
249    let step = (end - start) / T::from(num - 1).unwrap();
250    let data: Vec<T> = (0..num)
251        .map(|i| start + step * T::from(i).unwrap())
252        .collect();
253
254    Tensor::from_vec(data, &[num]).unwrap()
255}
256
257/// Creates a 1D tensor with `num` logarithmically spaced values.
258///
259/// # Arguments
260/// * `start` - Start exponent (base^start)
261/// * `end` - End exponent (base^end)
262/// * `num` - Number of values
263/// * `base` - Base of the logarithm
264pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
265    if num == 0 {
266        return Tensor::from_vec(vec![], &[0]).unwrap();
267    }
268
269    let lin = linspace(start, end, num);
270    let data: Vec<T> = lin.to_vec().iter().map(|&x| base.pow_value(x)).collect();
271
272    Tensor::from_vec(data, &[num]).unwrap()
273}
274
275// =============================================================================
276// Empty Tensor
277// =============================================================================
278
279/// Creates an uninitialized tensor (values are undefined).
280///
281/// # Safety
282/// The tensor contents are uninitialized. Reading before writing is undefined.
283///
284/// # Arguments
285/// * `shape` - Shape of the tensor
286#[must_use]
287pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
288    zeros(shape)
289}
290
291// =============================================================================
292// Tests
293// =============================================================================
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless every value in `values` equals `expected`.
    fn assert_all(values: Vec<f32>, expected: f32) {
        for v in values {
            assert_eq!(v, expected);
        }
    }

    #[test]
    fn test_zeros() {
        let zeros_2x3 = zeros::<f32>(&[2, 3]);
        assert_eq!(zeros_2x3.shape(), &[2, 3]);
        assert_eq!(zeros_2x3.numel(), 6);
        assert_all(zeros_2x3.to_vec(), 0.0);
    }

    #[test]
    fn test_ones() {
        assert_all(ones::<f32>(&[2, 3]).to_vec(), 1.0);
    }

    #[test]
    fn test_full() {
        assert_all(full::<f32>(&[2, 3], 42.0).to_vec(), 42.0);
    }

    #[test]
    fn test_eye() {
        let identity = eye::<f32>(3);
        assert_eq!(identity.shape(), &[3, 3]);
        for i in 0..3 {
            assert_eq!(identity.get(&[i, i]).unwrap(), 1.0);
        }
        assert_eq!(identity.get(&[0, 1]).unwrap(), 0.0);
    }

    #[test]
    fn test_rand() {
        let samples = rand::<f32>(&[100]).to_vec();
        assert!(samples.iter().all(|v| (0.0..1.0).contains(v)));
    }

    #[test]
    fn test_arange() {
        let unit_steps = arange::<f32>(0.0, 5.0, 1.0);
        assert_eq!(unit_steps.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let fractional_steps = arange::<f32>(0.0, 1.0, 0.2);
        assert_eq!(fractional_steps.numel(), 5);
    }

    #[test]
    fn test_linspace() {
        let points = linspace::<f32>(0.0, 1.0, 5).to_vec();
        assert_eq!(points.len(), 5);
        let (first, last) = (points[0], points[4]);
        assert!(first.abs() < 1e-6);
        assert!((last - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_zeros_like() {
        let template = ones::<f32>(&[2, 3]);
        let result = zeros_like(&template);
        assert_eq!(result.shape(), &[2, 3]);
        assert_all(result.to_vec(), 0.0);
    }
}