1use rand::Rng;
25use rand::distributions::{Distribution, Standard};
26use rand_distr::{Normal, StandardNormal, Uniform};
27
28use axonml_core::dtype::{Float, Numeric, Scalar};
29
30use crate::tensor::Tensor;
31
32#[must_use]
47pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
48 let numel: usize = shape.iter().product();
49 let data = vec![T::zeroed(); numel];
50 Tensor::from_vec(data, shape).expect("tensor creation failed")
51}
52
53#[must_use]
58pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
59 full(shape, T::one())
60}
61
62pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
68 let numel: usize = shape.iter().product();
69 let data = vec![value; numel];
70 Tensor::from_vec(data, shape).expect("tensor creation failed")
71}
72
73#[must_use]
75pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
76 zeros(other.shape())
77}
78
79#[must_use]
81pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
82 ones(other.shape())
83}
84
85pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
87 full(other.shape(), value)
88}
89
90#[must_use]
99pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
100 let mut data = vec![T::zero(); n * n];
101 for i in 0..n {
102 data[i * n + i] = T::one();
103 }
104 Tensor::from_vec(data, &[n, n]).expect("tensor creation failed")
105}
106
107pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
112 let n = diag.len();
113 let mut data = vec![T::zero(); n * n];
114 for (i, &val) in diag.iter().enumerate() {
115 data[i * n + i] = val;
116 }
117 Tensor::from_vec(data, &[n, n]).expect("tensor creation failed")
118}
119
120#[must_use]
129pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
130where
131 Standard: Distribution<T>,
132{
133 let numel: usize = shape.iter().product();
134 let mut rng = rand::thread_rng();
135 let data: Vec<T> = (0..numel).map(|_| rng.r#gen()).collect();
136 Tensor::from_vec(data, shape).expect("tensor creation failed")
137}
138
139#[must_use]
144pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
145where
146 StandardNormal: Distribution<T>,
147{
148 let numel: usize = shape.iter().product();
149 let mut rng = rand::thread_rng();
150 let normal = StandardNormal;
151 let data: Vec<T> = (0..numel).map(|_| normal.sample(&mut rng)).collect();
152 Tensor::from_vec(data, shape).expect("tensor creation failed")
153}
154
155pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
162where
163 T: rand::distributions::uniform::SampleUniform,
164{
165 let numel: usize = shape.iter().product();
166 let mut rng = rand::thread_rng();
167 let dist = Uniform::new(low, high);
168 let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
169 Tensor::from_vec(data, shape).expect("tensor creation failed")
170}
171
172pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
179where
180 T: rand::distributions::uniform::SampleUniform,
181 StandardNormal: Distribution<T>,
182{
183 let numel: usize = shape.iter().product();
184 let mut rng = rand::thread_rng();
185 let dist = Normal::new(mean, std).unwrap();
186 let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
187 Tensor::from_vec(data, shape).expect("tensor creation failed")
188}
189
190#[must_use]
197pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
198where
199 T: num_traits::NumCast,
200{
201 let numel: usize = shape.iter().product();
202 let mut rng = rand::thread_rng();
203 let dist = Uniform::new(low, high);
204 let data: Vec<T> = (0..numel)
205 .map(|_| T::from(dist.sample(&mut rng)).unwrap())
206 .collect();
207 Tensor::from_vec(data, shape).expect("tensor creation failed")
208}
209
210pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
221where
222 T: num_traits::NumCast + PartialOrd,
223{
224 let mut data = Vec::new();
225 let mut current = start;
226
227 if step > T::zero() {
228 while current < end {
229 data.push(current);
230 current = current + step;
231 }
232 } else if step < T::zero() {
233 while current > end {
234 data.push(current);
235 current = current + step;
236 }
237 }
238
239 let len = data.len();
240 Tensor::from_vec(data, &[len]).expect("tensor creation failed")
241}
242
243pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
250 if num == 0 {
251 return Tensor::from_vec(vec![], &[0]).expect("tensor creation failed");
252 }
253
254 if num == 1 {
255 return Tensor::from_vec(vec![start], &[1]).expect("tensor creation failed");
256 }
257
258 let step = (end - start) / T::from(num - 1).unwrap();
259 let data: Vec<T> = (0..num)
260 .map(|i| start + step * T::from(i).unwrap())
261 .collect();
262
263 Tensor::from_vec(data, &[num]).expect("tensor creation failed")
264}
265
266pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
274 if num == 0 {
275 return Tensor::from_vec(vec![], &[0]).expect("tensor creation failed");
276 }
277
278 let lin = linspace(start, end, num);
279 let data: Vec<T> = lin.to_vec().iter().map(|&x| base.pow_value(x)).collect();
280
281 Tensor::from_vec(data, &[num]).expect("tensor creation failed")
282}
283
284#[must_use]
296pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
297 zeros(shape)
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = zeros::<f32>(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        for val in t.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_ones() {
        let t = ones::<f32>(&[2, 3]);
        for val in t.to_vec() {
            assert_eq!(val, 1.0);
        }
    }

    #[test]
    fn test_full() {
        let t = full::<f32>(&[2, 3], 42.0);
        for val in t.to_vec() {
            assert_eq!(val, 42.0);
        }
    }

    #[test]
    fn test_eye() {
        let t = eye::<f32>(3);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 1.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 0.0);
    }

    #[test]
    fn test_diag() {
        let t = diag(&[1.0f32, 2.0, 3.0]);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 3.0);
        assert_eq!(t.get(&[0, 2]).unwrap(), 0.0);
        assert_eq!(t.get(&[2, 0]).unwrap(), 0.0);
    }

    #[test]
    fn test_rand() {
        let t = rand::<f32>(&[100]);
        for val in t.to_vec() {
            assert!((0.0..1.0).contains(&val));
        }
    }

    #[test]
    fn test_randn_and_normal_shapes() {
        assert_eq!(randn::<f32>(&[2, 3]).shape(), &[2, 3]);
        assert_eq!(normal::<f32>(&[4, 5], 0.0, 1.0).shape(), &[4, 5]);
    }

    #[test]
    fn test_uniform_range() {
        let t = uniform::<f32>(&[100], -1.0, 1.0);
        for val in t.to_vec() {
            assert!((-1.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_randint_range() {
        let t = randint::<f32>(&[100], 0, 10);
        for val in t.to_vec() {
            assert!((0.0..10.0).contains(&val));
            assert_eq!(val.fract(), 0.0);
        }
    }

    #[test]
    fn test_arange() {
        let t = arange::<f32>(0.0, 5.0, 1.0);
        assert_eq!(t.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let t = arange::<f32>(0.0, 1.0, 0.2);
        assert_eq!(t.numel(), 5);
    }

    #[test]
    fn test_arange_descending() {
        let t = arange::<f32>(5.0, 0.0, -1.0);
        assert_eq!(t.to_vec(), vec![5.0, 4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_linspace() {
        let t = linspace::<f32>(0.0, 1.0, 5);
        let data = t.to_vec();
        assert_eq!(data.len(), 5);
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.5).abs() < 1e-6);
        assert!((data[4] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_linspace_single_point() {
        let t = linspace::<f32>(3.0, 7.0, 1);
        assert_eq!(t.to_vec(), vec![3.0]);
    }

    #[test]
    fn test_logspace_len() {
        let t = logspace::<f32>(0.0, 2.0, 3, 10.0);
        assert_eq!(t.numel(), 3);
    }

    #[test]
    fn test_zeros_like() {
        let a = ones::<f32>(&[2, 3]);
        let b = zeros_like(&a);
        assert_eq!(b.shape(), &[2, 3]);
        for val in b.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_ones_like_and_full_like() {
        let a = zeros::<f32>(&[2, 3]);

        let b = ones_like(&a);
        assert_eq!(b.shape(), &[2, 3]);
        assert!(b.to_vec().iter().all(|&v| v == 1.0));

        let c = full_like(&a, 7.0);
        assert_eq!(c.shape(), &[2, 3]);
        assert!(c.to_vec().iter().all(|&v| v == 7.0));
    }

    #[test]
    fn test_empty_is_zero_filled() {
        let t = empty::<f32>(&[4]);
        assert_eq!(t.shape(), &[4]);
        assert!(t.to_vec().iter().all(|&v| v == 0.0));
    }
}