1use rand::Rng;
18use rand::distributions::{Distribution, Standard};
19use rand_distr::{Normal, StandardNormal, Uniform};
20
21use axonml_core::dtype::{Float, Numeric, Scalar};
22
23use crate::tensor::Tensor;
24
25#[must_use]
40pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
41 let numel: usize = shape.iter().product();
42 let data = vec![T::zeroed(); numel];
43 Tensor::from_vec(data, shape).unwrap()
44}
45
46#[must_use]
51pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
52 full(shape, T::one())
53}
54
55pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
61 let numel: usize = shape.iter().product();
62 let data = vec![value; numel];
63 Tensor::from_vec(data, shape).unwrap()
64}
65
66#[must_use]
68pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
69 zeros(other.shape())
70}
71
72#[must_use]
74pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
75 ones(other.shape())
76}
77
78pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
80 full(other.shape(), value)
81}
82
83#[must_use]
92pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
93 let mut data = vec![T::zero(); n * n];
94 for i in 0..n {
95 data[i * n + i] = T::one();
96 }
97 Tensor::from_vec(data, &[n, n]).unwrap()
98}
99
100pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
105 let n = diag.len();
106 let mut data = vec![T::zero(); n * n];
107 for (i, &val) in diag.iter().enumerate() {
108 data[i * n + i] = val;
109 }
110 Tensor::from_vec(data, &[n, n]).unwrap()
111}
112
113#[must_use]
122pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
123where
124 Standard: Distribution<T>,
125{
126 let numel: usize = shape.iter().product();
127 let mut rng = rand::thread_rng();
128 let data: Vec<T> = (0..numel).map(|_| rng.r#gen()).collect();
129 Tensor::from_vec(data, shape).unwrap()
130}
131
132#[must_use]
137pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
138where
139 StandardNormal: Distribution<T>,
140{
141 let numel: usize = shape.iter().product();
142 let mut rng = rand::thread_rng();
143 let normal = StandardNormal;
144 let data: Vec<T> = (0..numel).map(|_| normal.sample(&mut rng)).collect();
145 Tensor::from_vec(data, shape).unwrap()
146}
147
148pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
155where
156 T: rand::distributions::uniform::SampleUniform,
157{
158 let numel: usize = shape.iter().product();
159 let mut rng = rand::thread_rng();
160 let dist = Uniform::new(low, high);
161 let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
162 Tensor::from_vec(data, shape).unwrap()
163}
164
165pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
172where
173 T: rand::distributions::uniform::SampleUniform,
174 StandardNormal: Distribution<T>,
175{
176 let numel: usize = shape.iter().product();
177 let mut rng = rand::thread_rng();
178 let dist = Normal::new(mean, std).unwrap();
179 let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
180 Tensor::from_vec(data, shape).unwrap()
181}
182
183#[must_use]
190pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
191where
192 T: num_traits::NumCast,
193{
194 let numel: usize = shape.iter().product();
195 let mut rng = rand::thread_rng();
196 let dist = Uniform::new(low, high);
197 let data: Vec<T> = (0..numel)
198 .map(|_| T::from(dist.sample(&mut rng)).unwrap())
199 .collect();
200 Tensor::from_vec(data, shape).unwrap()
201}
202
203pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
214where
215 T: num_traits::NumCast + PartialOrd,
216{
217 let mut data = Vec::new();
218 let mut current = start;
219
220 if step > T::zero() {
221 while current < end {
222 data.push(current);
223 current = current + step;
224 }
225 } else if step < T::zero() {
226 while current > end {
227 data.push(current);
228 current = current + step;
229 }
230 }
231
232 let len = data.len();
233 Tensor::from_vec(data, &[len]).unwrap()
234}
235
236pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
243 if num == 0 {
244 return Tensor::from_vec(vec![], &[0]).unwrap();
245 }
246
247 if num == 1 {
248 return Tensor::from_vec(vec![start], &[1]).unwrap();
249 }
250
251 let step = (end - start) / T::from(num - 1).unwrap();
252 let data: Vec<T> = (0..num)
253 .map(|i| start + step * T::from(i).unwrap())
254 .collect();
255
256 Tensor::from_vec(data, &[num]).unwrap()
257}
258
259pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
267 if num == 0 {
268 return Tensor::from_vec(vec![], &[0]).unwrap();
269 }
270
271 let lin = linspace(start, end, num);
272 let data: Vec<T> = lin.to_vec().iter().map(|&x| base.pow_value(x)).collect();
273
274 Tensor::from_vec(data, &[num]).unwrap()
275}
276
277#[must_use]
289pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
290 zeros(shape)
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = zeros::<f32>(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        for val in t.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_ones() {
        let t = ones::<f32>(&[2, 3]);
        for val in t.to_vec() {
            assert_eq!(val, 1.0);
        }
    }

    #[test]
    fn test_full() {
        let t = full::<f32>(&[2, 3], 42.0);
        for val in t.to_vec() {
            assert_eq!(val, 42.0);
        }
    }

    #[test]
    fn test_eye() {
        let t = eye::<f32>(3);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 1.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 0.0);
    }

    #[test]
    fn test_diag() {
        let t = diag(&[1.0f32, 2.0, 3.0]);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(
            t.to_vec(),
            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]
        );
    }

    #[test]
    fn test_rand() {
        let t = rand::<f32>(&[100]);
        for val in t.to_vec() {
            assert!((0.0..1.0).contains(&val));
        }
    }

    #[test]
    fn test_randn_and_normal() {
        let t = randn::<f32>(&[3, 4]);
        assert_eq!(t.shape(), &[3, 4]);
        assert!(t.to_vec().iter().all(|v| v.is_finite()));

        // A zero standard deviation collapses every sample onto the mean.
        let n = normal::<f32>(&[10], 5.0, 0.0);
        assert!(n.to_vec().iter().all(|&v| v == 5.0));
    }

    #[test]
    fn test_uniform_range() {
        let t = uniform::<f32>(&[200], -2.0, 3.0);
        assert_eq!(t.shape(), &[200]);
        assert!(t.to_vec().iter().all(|v| (-2.0..=3.0).contains(v)));
    }

    #[test]
    fn test_randint_range() {
        let t = randint::<f32>(&[200], -3, 4);
        assert_eq!(t.numel(), 200);
        assert!(t
            .to_vec()
            .iter()
            .all(|&v| (-3.0..4.0).contains(&v) && v.fract() == 0.0));
    }

    #[test]
    fn test_arange() {
        let t = arange::<f32>(0.0, 5.0, 1.0);
        assert_eq!(t.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let t = arange::<f32>(0.0, 1.0, 0.2);
        assert_eq!(t.numel(), 5);
    }

    #[test]
    fn test_arange_descending_and_empty() {
        let t = arange::<f32>(5.0, 0.0, -2.0);
        assert_eq!(t.to_vec(), vec![5.0, 3.0, 1.0]);

        // A zero step, or a step pointing away from `end`, yields nothing.
        assert_eq!(arange::<f32>(0.0, 5.0, 0.0).numel(), 0);
        assert_eq!(arange::<f32>(5.0, 0.0, 1.0).numel(), 0);
    }

    #[test]
    fn test_linspace() {
        let t = linspace::<f32>(0.0, 1.0, 5);
        let data = t.to_vec();
        assert_eq!(data.len(), 5);
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[4] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_logspace() {
        let data = logspace::<f32>(0.0, 2.0, 3, 10.0).to_vec();
        let expected = [1.0f32, 10.0, 100.0];
        assert_eq!(data.len(), expected.len());
        for (got, want) in data.iter().zip(expected.iter()) {
            assert!((got - want).abs() <= 1e-3 * want);
        }
    }

    #[test]
    fn test_zeros_like() {
        let a = ones::<f32>(&[2, 3]);
        let b = zeros_like(&a);
        assert_eq!(b.shape(), &[2, 3]);
        for val in b.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_ones_like_and_full_like() {
        let a = zeros::<f32>(&[4]);
        assert_eq!(ones_like(&a).to_vec(), vec![1.0; 4]);

        let b = full_like(&a, 7.5);
        assert_eq!(b.shape(), &[4]);
        assert_eq!(b.to_vec(), vec![7.5; 4]);
    }
}