1use rand::distributions::{Distribution, Standard};
16use rand::Rng;
17use rand_distr::{Normal, StandardNormal, Uniform};
18
19use axonml_core::dtype::{Float, Numeric, Scalar};
20
21use crate::tensor::Tensor;
22
23#[must_use] pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
38 let numel: usize = shape.iter().product();
39 let data = vec![T::zeroed(); numel];
40 Tensor::from_vec(data, shape).unwrap()
41}
42
43#[must_use] pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
48 full(shape, T::one())
49}
50
51pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
57 let numel: usize = shape.iter().product();
58 let data = vec![value; numel];
59 Tensor::from_vec(data, shape).unwrap()
60}
61
62#[must_use] pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
64 zeros(other.shape())
65}
66
67#[must_use] pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
69 ones(other.shape())
70}
71
72pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
74 full(other.shape(), value)
75}
76
77#[must_use] pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
86 let mut data = vec![T::zero(); n * n];
87 for i in 0..n {
88 data[i * n + i] = T::one();
89 }
90 Tensor::from_vec(data, &[n, n]).unwrap()
91}
92
93pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
98 let n = diag.len();
99 let mut data = vec![T::zero(); n * n];
100 for (i, &val) in diag.iter().enumerate() {
101 data[i * n + i] = val;
102 }
103 Tensor::from_vec(data, &[n, n]).unwrap()
104}
105
106#[must_use] pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
115where
116 Standard: Distribution<T>,
117{
118 let numel: usize = shape.iter().product();
119 let mut rng = rand::thread_rng();
120 let data: Vec<T> = (0..numel).map(|_| rng.gen()).collect();
121 Tensor::from_vec(data, shape).unwrap()
122}
123
124#[must_use] pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
129where
130 StandardNormal: Distribution<T>,
131{
132 let numel: usize = shape.iter().product();
133 let mut rng = rand::thread_rng();
134 let normal = StandardNormal;
135 let data: Vec<T> = (0..numel).map(|_| normal.sample(&mut rng)).collect();
136 Tensor::from_vec(data, shape).unwrap()
137}
138
139pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
146where
147 T: rand::distributions::uniform::SampleUniform,
148{
149 let numel: usize = shape.iter().product();
150 let mut rng = rand::thread_rng();
151 let dist = Uniform::new(low, high);
152 let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
153 Tensor::from_vec(data, shape).unwrap()
154}
155
156pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
163where
164 T: rand::distributions::uniform::SampleUniform,
165 StandardNormal: Distribution<T>,
166{
167 let numel: usize = shape.iter().product();
168 let mut rng = rand::thread_rng();
169 let dist = Normal::new(mean, std).unwrap();
170 let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
171 Tensor::from_vec(data, shape).unwrap()
172}
173
174#[must_use] pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
181where
182 T: num_traits::NumCast,
183{
184 let numel: usize = shape.iter().product();
185 let mut rng = rand::thread_rng();
186 let dist = Uniform::new(low, high);
187 let data: Vec<T> = (0..numel)
188 .map(|_| T::from(dist.sample(&mut rng)).unwrap())
189 .collect();
190 Tensor::from_vec(data, shape).unwrap()
191}
192
193pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
204where
205 T: num_traits::NumCast + PartialOrd,
206{
207 let mut data = Vec::new();
208 let mut current = start;
209
210 if step > T::zero() {
211 while current < end {
212 data.push(current);
213 current = current + step;
214 }
215 } else if step < T::zero() {
216 while current > end {
217 data.push(current);
218 current = current + step;
219 }
220 }
221
222 let len = data.len();
223 Tensor::from_vec(data, &[len]).unwrap()
224}
225
226pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
233 if num == 0 {
234 return Tensor::from_vec(vec![], &[0]).unwrap();
235 }
236
237 if num == 1 {
238 return Tensor::from_vec(vec![start], &[1]).unwrap();
239 }
240
241 let step = (end - start) / T::from(num - 1).unwrap();
242 let data: Vec<T> = (0..num)
243 .map(|i| start + step * T::from(i).unwrap())
244 .collect();
245
246 Tensor::from_vec(data, &[num]).unwrap()
247}
248
249pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
257 if num == 0 {
258 return Tensor::from_vec(vec![], &[0]).unwrap();
259 }
260
261 let lin = linspace(start, end, num);
262 let data: Vec<T> = lin.to_vec().iter().map(|&x| base.pow_value(x)).collect();
263
264 Tensor::from_vec(data, &[num]).unwrap()
265}
266
267#[must_use] pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
279 zeros(shape)
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = zeros::<f32>(&[2, 3]);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);
        for val in t.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_ones() {
        let t = ones::<f32>(&[2, 3]);
        for val in t.to_vec() {
            assert_eq!(val, 1.0);
        }
    }

    #[test]
    fn test_full() {
        let t = full::<f32>(&[2, 3], 42.0);
        for val in t.to_vec() {
            assert_eq!(val, 42.0);
        }
    }

    #[test]
    fn test_full_like_and_ones_like() {
        let a = zeros::<f32>(&[3, 2]);

        let b = full_like(&a, 7.0);
        assert_eq!(b.shape(), &[3, 2]);
        for val in b.to_vec() {
            assert_eq!(val, 7.0);
        }

        let c = ones_like(&a);
        assert_eq!(c.shape(), &[3, 2]);
        for val in c.to_vec() {
            assert_eq!(val, 1.0);
        }
    }

    #[test]
    fn test_eye() {
        let t = eye::<f32>(3);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 1.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 1.0);
        assert_eq!(t.get(&[0, 1]).unwrap(), 0.0);
    }

    #[test]
    fn test_diag() {
        let t = diag::<f32>(&[1.0, 2.0, 3.0]);
        assert_eq!(t.shape(), &[3, 3]);
        assert_eq!(t.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(t.get(&[1, 1]).unwrap(), 2.0);
        assert_eq!(t.get(&[2, 2]).unwrap(), 3.0);
        assert_eq!(t.get(&[0, 2]).unwrap(), 0.0);
        assert_eq!(t.get(&[2, 1]).unwrap(), 0.0);
    }

    #[test]
    fn test_rand() {
        let t = rand::<f32>(&[100]);
        for val in t.to_vec() {
            assert!((0.0..1.0).contains(&val));
        }
    }

    #[test]
    fn test_randn_shape() {
        let t = randn::<f32>(&[4, 5]);
        assert_eq!(t.shape(), &[4, 5]);
        for val in t.to_vec() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_uniform_range() {
        let t = uniform::<f32>(&[100], -2.0, 2.0);
        assert_eq!(t.shape(), &[100]);
        for val in t.to_vec() {
            assert!((-2.0..2.0).contains(&val));
        }
    }

    #[test]
    fn test_randint_range() {
        let t = randint::<f32>(&[100], -3, 3);
        assert_eq!(t.shape(), &[100]);
        for val in t.to_vec() {
            assert!((-3.0..3.0).contains(&val));
            assert_eq!(val.fract(), 0.0);
        }
    }

    #[test]
    fn test_arange() {
        let t = arange::<f32>(0.0, 5.0, 1.0);
        assert_eq!(t.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let t = arange::<f32>(0.0, 1.0, 0.2);
        assert_eq!(t.numel(), 5);
    }

    #[test]
    fn test_arange_negative_and_zero_step() {
        let t = arange::<f32>(5.0, 0.0, -1.0);
        assert_eq!(t.to_vec(), vec![5.0, 4.0, 3.0, 2.0, 1.0]);

        // A zero step would never make progress, so it yields nothing.
        assert_eq!(arange::<f32>(0.0, 5.0, 0.0).numel(), 0);
        // A step pointing away from `end` also yields nothing.
        assert_eq!(arange::<f32>(0.0, 5.0, -1.0).numel(), 0);
    }

    #[test]
    fn test_linspace() {
        let t = linspace::<f32>(0.0, 1.0, 5);
        let data = t.to_vec();
        assert_eq!(data.len(), 5);
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[4] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_linspace_edge_cases() {
        assert_eq!(linspace::<f32>(0.0, 1.0, 0).numel(), 0);
        assert_eq!(linspace::<f32>(2.0, 3.0, 1).to_vec(), vec![2.0]);
    }

    #[test]
    fn test_zeros_like() {
        let a = ones::<f32>(&[2, 3]);
        let b = zeros_like(&a);
        assert_eq!(b.shape(), &[2, 3]);
        for val in b.to_vec() {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_empty_shape() {
        let t = empty::<f32>(&[2, 2]);
        assert_eq!(t.shape(), &[2, 2]);
        assert_eq!(t.numel(), 4);
    }
}