axonml_tensor/
creation.rs

1//! Tensor Creation Functions
2//!
3//! Provides convenient functions for creating tensors with various initializations
4//! including zeros, ones, random values, ranges, and more.
5//!
6//! # Key Features
7//! - Factory functions for common tensor initializations
8//! - Random tensor generation with various distributions
9//! - Range and linspace functions
10//!
11
12//! @version 0.1.0
13//! @author `AutomataNexus` Development Team
14
15use rand::distributions::{Distribution, Standard};
16use rand::Rng;
17use rand_distr::{Normal, StandardNormal, Uniform};
18
19use axonml_core::dtype::{Float, Numeric, Scalar};
20
21use crate::tensor::Tensor;
22
23// =============================================================================
24// Zero and One Initialization
25// =============================================================================
26
27/// Creates a tensor filled with zeros.
28///
29/// # Arguments
30/// * `shape` - Shape of the tensor
31///
32/// # Example
33/// ```rust,ignore
34/// use axonml_tensor::zeros;
35/// let t = zeros::<f32>(&[2, 3]);
36/// ```
37#[must_use] pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
38    let numel: usize = shape.iter().product();
39    let data = vec![T::zeroed(); numel];
40    Tensor::from_vec(data, shape).unwrap()
41}
42
43/// Creates a tensor filled with ones.
44///
45/// # Arguments
46/// * `shape` - Shape of the tensor
47#[must_use] pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
48    full(shape, T::one())
49}
50
51/// Creates a tensor filled with a specific value.
52///
53/// # Arguments
54/// * `shape` - Shape of the tensor
55/// * `value` - Fill value
56pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
57    let numel: usize = shape.iter().product();
58    let data = vec![value; numel];
59    Tensor::from_vec(data, shape).unwrap()
60}
61
62/// Creates a tensor with the same shape as another, filled with zeros.
63#[must_use] pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
64    zeros(other.shape())
65}
66
67/// Creates a tensor with the same shape as another, filled with ones.
68#[must_use] pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
69    ones(other.shape())
70}
71
72/// Creates a tensor with the same shape as another, filled with a value.
73pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
74    full(other.shape(), value)
75}
76
77// =============================================================================
78// Identity and Diagonal
79// =============================================================================
80
81/// Creates a 2D identity matrix.
82///
83/// # Arguments
84/// * `n` - Size of the matrix (n x n)
85#[must_use] pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
86    let mut data = vec![T::zero(); n * n];
87    for i in 0..n {
88        data[i * n + i] = T::one();
89    }
90    Tensor::from_vec(data, &[n, n]).unwrap()
91}
92
93/// Creates a 2D tensor with the given diagonal values.
94///
95/// # Arguments
96/// * `diag` - Values for the diagonal
97pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
98    let n = diag.len();
99    let mut data = vec![T::zero(); n * n];
100    for (i, &val) in diag.iter().enumerate() {
101        data[i * n + i] = val;
102    }
103    Tensor::from_vec(data, &[n, n]).unwrap()
104}
105
106// =============================================================================
107// Random Initialization
108// =============================================================================
109
110/// Creates a tensor with uniformly distributed random values in [0, 1).
111///
112/// # Arguments
113/// * `shape` - Shape of the tensor
114#[must_use] pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
115where
116    Standard: Distribution<T>,
117{
118    let numel: usize = shape.iter().product();
119    let mut rng = rand::thread_rng();
120    let data: Vec<T> = (0..numel).map(|_| rng.gen()).collect();
121    Tensor::from_vec(data, shape).unwrap()
122}
123
124/// Creates a tensor with normally distributed random values (mean=0, std=1).
125///
126/// # Arguments
127/// * `shape` - Shape of the tensor
128#[must_use] pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
129where
130    StandardNormal: Distribution<T>,
131{
132    let numel: usize = shape.iter().product();
133    let mut rng = rand::thread_rng();
134    let normal = StandardNormal;
135    let data: Vec<T> = (0..numel).map(|_| normal.sample(&mut rng)).collect();
136    Tensor::from_vec(data, shape).unwrap()
137}
138
139/// Creates a tensor with uniformly distributed random values in [low, high).
140///
141/// # Arguments
142/// * `shape` - Shape of the tensor
143/// * `low` - Lower bound (inclusive)
144/// * `high` - Upper bound (exclusive)
145pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
146where
147    T: rand::distributions::uniform::SampleUniform,
148{
149    let numel: usize = shape.iter().product();
150    let mut rng = rand::thread_rng();
151    let dist = Uniform::new(low, high);
152    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
153    Tensor::from_vec(data, shape).unwrap()
154}
155
156/// Creates a tensor with normally distributed random values.
157///
158/// # Arguments
159/// * `shape` - Shape of the tensor
160/// * `mean` - Mean of the distribution
161/// * `std` - Standard deviation of the distribution
162pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
163where
164    T: rand::distributions::uniform::SampleUniform,
165    StandardNormal: Distribution<T>,
166{
167    let numel: usize = shape.iter().product();
168    let mut rng = rand::thread_rng();
169    let dist = Normal::new(mean, std).unwrap();
170    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
171    Tensor::from_vec(data, shape).unwrap()
172}
173
174/// Creates a tensor with random integers in [low, high).
175///
176/// # Arguments
177/// * `shape` - Shape of the tensor
178/// * `low` - Lower bound (inclusive)
179/// * `high` - Upper bound (exclusive)
180#[must_use] pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
181where
182    T: num_traits::NumCast,
183{
184    let numel: usize = shape.iter().product();
185    let mut rng = rand::thread_rng();
186    let dist = Uniform::new(low, high);
187    let data: Vec<T> = (0..numel)
188        .map(|_| T::from(dist.sample(&mut rng)).unwrap())
189        .collect();
190    Tensor::from_vec(data, shape).unwrap()
191}
192
193// =============================================================================
194// Range Functions
195// =============================================================================
196
197/// Creates a 1D tensor with values from start to end (exclusive) with step.
198///
199/// # Arguments
200/// * `start` - Start value
201/// * `end` - End value (exclusive)
202/// * `step` - Step size
203pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
204where
205    T: num_traits::NumCast + PartialOrd,
206{
207    let mut data = Vec::new();
208    let mut current = start;
209
210    if step > T::zero() {
211        while current < end {
212            data.push(current);
213            current = current + step;
214        }
215    } else if step < T::zero() {
216        while current > end {
217            data.push(current);
218            current = current + step;
219        }
220    }
221
222    let len = data.len();
223    Tensor::from_vec(data, &[len]).unwrap()
224}
225
226/// Creates a 1D tensor with `num` evenly spaced values from start to end.
227///
228/// # Arguments
229/// * `start` - Start value
230/// * `end` - End value (inclusive)
231/// * `num` - Number of values
232pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
233    if num == 0 {
234        return Tensor::from_vec(vec![], &[0]).unwrap();
235    }
236
237    if num == 1 {
238        return Tensor::from_vec(vec![start], &[1]).unwrap();
239    }
240
241    let step = (end - start) / T::from(num - 1).unwrap();
242    let data: Vec<T> = (0..num)
243        .map(|i| start + step * T::from(i).unwrap())
244        .collect();
245
246    Tensor::from_vec(data, &[num]).unwrap()
247}
248
249/// Creates a 1D tensor with `num` logarithmically spaced values.
250///
251/// # Arguments
252/// * `start` - Start exponent (base^start)
253/// * `end` - End exponent (base^end)
254/// * `num` - Number of values
255/// * `base` - Base of the logarithm
256pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
257    if num == 0 {
258        return Tensor::from_vec(vec![], &[0]).unwrap();
259    }
260
261    let lin = linspace(start, end, num);
262    let data: Vec<T> = lin.to_vec().iter().map(|&x| base.pow_value(x)).collect();
263
264    Tensor::from_vec(data, &[num]).unwrap()
265}
266
267// =============================================================================
268// Empty Tensor
269// =============================================================================
270
271/// Creates an uninitialized tensor (values are undefined).
272///
273/// # Safety
274/// The tensor contents are uninitialized. Reading before writing is undefined.
275///
276/// # Arguments
277/// * `shape` - Shape of the tensor
278#[must_use] pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
279    zeros(shape)
280}
281
282// =============================================================================
283// Tests
284// =============================================================================
285
#[cfg(test)]
mod tests {
    use super::*;

    /// `zeros` returns the requested shape and element count, all elements zero.
    #[test]
    fn test_zeros() {
        let t = zeros::<f32>(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        assert!(t.to_vec().iter().all(|&v| v == 0.0));
    }

    /// `ones` fills every element with 1.
    #[test]
    fn test_ones() {
        let values = ones::<f32>(&[2, 3]).to_vec();
        assert!(values.iter().all(|&v| v == 1.0));
    }

    /// `full` fills every element with the supplied value.
    #[test]
    fn test_full() {
        let values = full::<f32>(&[2, 3], 42.0).to_vec();
        assert!(values.iter().all(|&v| v == 42.0));
    }

    /// `eye` puts ones on the diagonal and zeros off it.
    #[test]
    fn test_eye() {
        let t = eye::<f32>(3);
        assert_eq!(t.shape(), &[3, 3]);
        for i in 0..3 {
            assert_eq!(t.get(&[i, i]).unwrap(), 1.0);
        }
        assert_eq!(t.get(&[0, 1]).unwrap(), 0.0);
    }

    /// `rand` samples stay inside the half-open unit interval.
    #[test]
    fn test_rand() {
        let samples = rand::<f32>(&[100]).to_vec();
        assert!(samples.iter().all(|v| (0.0..1.0).contains(v)));
    }

    /// `arange` excludes the end bound for both unit and fractional steps.
    #[test]
    fn test_arange() {
        let unit_steps = arange::<f32>(0.0, 5.0, 1.0);
        assert_eq!(unit_steps.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let fractional_steps = arange::<f32>(0.0, 1.0, 0.2);
        assert_eq!(fractional_steps.numel(), 5);
    }

    /// `linspace` produces `num` values and hits both endpoints.
    #[test]
    fn test_linspace() {
        let data = linspace::<f32>(0.0, 1.0, 5).to_vec();
        assert_eq!(data.len(), 5);

        let (first, last) = (data[0], data[4]);
        assert!((first - 0.0).abs() < 1e-6);
        assert!((last - 1.0).abs() < 1e-6);
    }

    /// `zeros_like` copies the shape of its argument and zero-fills it.
    #[test]
    fn test_zeros_like() {
        let template = ones::<f32>(&[2, 3]);
        let zeroed = zeros_like(&template);
        assert_eq!(zeroed.shape(), &[2, 3]);
        assert!(zeroed.to_vec().iter().all(|&v| v == 0.0));
    }
}