1use rand::distributions::{Distribution, Standard};
16use rand::Rng;
17use rand_distr::{Normal, StandardNormal, Uniform};
18
19use axonml_core::dtype::{Float, Numeric, Scalar};
20
21use crate::tensor::Tensor;
22
23#[must_use]
38pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
39 let numel: usize = shape.iter().product();
40 let data = vec![T::zeroed(); numel];
41 Tensor::from_vec(data, shape).unwrap()
42}
43
44#[must_use]
49pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
50 full(shape, T::one())
51}
52
53pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
59 let numel: usize = shape.iter().product();
60 let data = vec![value; numel];
61 Tensor::from_vec(data, shape).unwrap()
62}
63
64#[must_use]
66pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
67 zeros(other.shape())
68}
69
70#[must_use]
72pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
73 ones(other.shape())
74}
75
76pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
78 full(other.shape(), value)
79}
80
81#[must_use]
90pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
91 let mut data = vec![T::zero(); n * n];
92 for i in 0..n {
93 data[i * n + i] = T::one();
94 }
95 Tensor::from_vec(data, &[n, n]).unwrap()
96}
97
98pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
103 let n = diag.len();
104 let mut data = vec![T::zero(); n * n];
105 for (i, &val) in diag.iter().enumerate() {
106 data[i * n + i] = val;
107 }
108 Tensor::from_vec(data, &[n, n]).unwrap()
109}
110
111#[must_use]
120pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
121where
122 Standard: Distribution<T>,
123{
124 let numel: usize = shape.iter().product();
125 let mut rng = rand::thread_rng();
126 let data: Vec<T> = (0..numel).map(|_| rng.gen()).collect();
127 Tensor::from_vec(data, shape).unwrap()
128}
129
130#[must_use]
135pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
136where
137 StandardNormal: Distribution<T>,
138{
139 let numel: usize = shape.iter().product();
140 let mut rng = rand::thread_rng();
141 let normal = StandardNormal;
142 let data: Vec<T> = (0..numel).map(|_| normal.sample(&mut rng)).collect();
143 Tensor::from_vec(data, shape).unwrap()
144}
145
146pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
153where
154 T: rand::distributions::uniform::SampleUniform,
155{
156 let numel: usize = shape.iter().product();
157 let mut rng = rand::thread_rng();
158 let dist = Uniform::new(low, high);
159 let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
160 Tensor::from_vec(data, shape).unwrap()
161}
162
163pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
170where
171 T: rand::distributions::uniform::SampleUniform,
172 StandardNormal: Distribution<T>,
173{
174 let numel: usize = shape.iter().product();
175 let mut rng = rand::thread_rng();
176 let dist = Normal::new(mean, std).unwrap();
177 let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
178 Tensor::from_vec(data, shape).unwrap()
179}
180
181#[must_use]
188pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
189where
190 T: num_traits::NumCast,
191{
192 let numel: usize = shape.iter().product();
193 let mut rng = rand::thread_rng();
194 let dist = Uniform::new(low, high);
195 let data: Vec<T> = (0..numel)
196 .map(|_| T::from(dist.sample(&mut rng)).unwrap())
197 .collect();
198 Tensor::from_vec(data, shape).unwrap()
199}
200
201pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
212where
213 T: num_traits::NumCast + PartialOrd,
214{
215 let mut data = Vec::new();
216 let mut current = start;
217
218 if step > T::zero() {
219 while current < end {
220 data.push(current);
221 current = current + step;
222 }
223 } else if step < T::zero() {
224 while current > end {
225 data.push(current);
226 current = current + step;
227 }
228 }
229
230 let len = data.len();
231 Tensor::from_vec(data, &[len]).unwrap()
232}
233
234pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
241 if num == 0 {
242 return Tensor::from_vec(vec![], &[0]).unwrap();
243 }
244
245 if num == 1 {
246 return Tensor::from_vec(vec![start], &[1]).unwrap();
247 }
248
249 let step = (end - start) / T::from(num - 1).unwrap();
250 let data: Vec<T> = (0..num)
251 .map(|i| start + step * T::from(i).unwrap())
252 .collect();
253
254 Tensor::from_vec(data, &[num]).unwrap()
255}
256
257pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
265 if num == 0 {
266 return Tensor::from_vec(vec![], &[0]).unwrap();
267 }
268
269 let lin = linspace(start, end, num);
270 let data: Vec<T> = lin.to_vec().iter().map(|&x| base.pow_value(x)).collect();
271
272 Tensor::from_vec(data, &[num]).unwrap()
273}
274
275#[must_use]
287pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
288 zeros(shape)
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = zeros::<f32>(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        for val in t.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_ones() {
        let t = ones::<f32>(&[2, 3]);
        for val in t.to_vec() {
            assert_eq!(val, 1.0);
        }
    }

    #[test]
    fn test_full() {
        let t = full::<f32>(&[2, 3], 42.0);
        for val in t.to_vec() {
            assert_eq!(val, 42.0);
        }
    }

    #[test]
    fn test_eye() {
        let t = eye::<f32>(3);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 1.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 0.0);
    }

    #[test]
    fn test_diag() {
        let t = diag(&[1.0f32, 2.0, 3.0]);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 3.0);
        assert_eq!(t.get(&[0, 2]).unwrap(), 0.0);
        assert_eq!(t.get(&[2, 1]).unwrap(), 0.0);
    }

    #[test]
    fn test_rand() {
        let t = rand::<f32>(&[100]);
        for val in t.to_vec() {
            assert!((0.0..1.0).contains(&val));
        }
    }

    #[test]
    fn test_randn_and_normal_shapes() {
        assert_eq!(randn::<f32>(&[4, 5]).shape(), &[4, 5]);
        assert_eq!(normal::<f32>(&[2, 3], 1.0, 0.5).shape(), &[2, 3]);
    }

    #[test]
    fn test_uniform_range() {
        let t = uniform::<f32>(&[200], -2.0, 3.0);
        assert_eq!(t.numel(), 200);
        for val in t.to_vec() {
            assert!((-2.0..3.0).contains(&val));
        }
    }

    #[test]
    fn test_randint_range() {
        let t = randint::<f32>(&[200], -3, 4);
        for val in t.to_vec() {
            assert!((-3.0..4.0).contains(&val));
            assert_eq!(val.fract(), 0.0);
        }
    }

    #[test]
    fn test_arange() {
        let t = arange::<f32>(0.0, 5.0, 1.0);
        assert_eq!(t.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let t = arange::<f32>(0.0, 1.0, 0.2);
        assert_eq!(t.numel(), 5);
    }

    #[test]
    fn test_arange_descending_and_zero_step() {
        let t = arange::<f32>(5.0, 0.0, -1.0);
        assert_eq!(t.to_vec(), vec![5.0, 4.0, 3.0, 2.0, 1.0]);

        assert_eq!(arange::<f32>(0.0, 5.0, 0.0).numel(), 0);
    }

    #[test]
    fn test_linspace() {
        let t = linspace::<f32>(0.0, 1.0, 5);
        let data = t.to_vec();
        assert_eq!(data.len(), 5);
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[4] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_linspace_degenerate_counts() {
        assert_eq!(linspace::<f32>(0.0, 1.0, 0).numel(), 0);
        assert_eq!(linspace::<f32>(3.0, 9.0, 1).to_vec(), vec![3.0]);
    }

    #[test]
    fn test_zeros_like() {
        let a = ones::<f32>(&[2, 3]);
        let b = zeros_like(&a);
        assert_eq!(b.shape(), &[2, 3]);
        for val in b.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_ones_like_and_full_like() {
        let a = zeros::<f32>(&[3, 2]);

        let b = ones_like(&a);
        assert_eq!(b.shape(), &[3, 2]);
        for val in b.to_vec() {
            assert_eq!(val, 1.0);
        }

        let c = full_like(&a, -7.5);
        assert_eq!(c.shape(), &[3, 2]);
        for val in c.to_vec() {
            assert_eq!(val, -7.5);
        }
    }
}