redstone_ml/common/
random.rs1use crate::util::to_vec::ToVec;
2use crate::{Constructors, FloatDataType, NdArray, NumericDataType, RawDataType, Tensor, TensorDataType};
3use num::{Float, NumCast};
4use rand::distributions::{Distribution, Uniform};
5use rand::thread_rng;
6use rand_distr::Normal;
7
8pub trait RandomConstructors<T: RawDataType>: Constructors<T> {
9 fn randn(shape: impl ToVec<usize>) -> Self
20 where
21 T: FloatDataType
22 {
23 let mut rng = thread_rng();
24 let shape = shape.to_vec();
25 let n = shape.iter().product();
26
27 let normal = Normal::new(0.0, 1.0).unwrap();
28
29 let random_numbers: Vec<T> = (0..n)
30 .map(|_| <T as NumCast>::from(normal.sample(&mut rng)).unwrap())
31 .collect();
32
33 unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
34 }
35
36 fn rand(shape: impl ToVec<usize>) -> Self
47 where
48 T: FloatDataType
49 {
50 let mut rng = thread_rng();
51 let shape = shape.to_vec();
52 let n = shape.iter().product();
53
54 let uniform = Uniform::new(0.0, 1.0);
55 let random_numbers = (0..n)
56 .map(|_| <T as NumCast>::from(uniform.sample(&mut rng)).unwrap())
57 .collect();
58
59 unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
60 }
61
62 fn uniform(shape: impl ToVec<usize>, low: T, high: T) -> Self
73 where
74 T: FloatDataType
75 {
76 let mut rng = thread_rng();
77 let shape = shape.to_vec();
78 let n = shape.iter().product();
79
80 let uniform = Uniform::new(low, high);
81 let random_numbers = (0..n)
82 .map(|_| <T as NumCast>::from(uniform.sample(&mut rng)).unwrap())
83 .collect();
84
85 unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
86 }
87
88 fn randint(shape: impl ToVec<usize>, low: T, high: T) -> Self
99 where
100 T: NumericDataType
101 {
102 assert!(low < high, "randint: low must be less than high");
103
104 let mut rng = thread_rng();
105 let shape = shape.to_vec();
106 let n = shape.iter().product();
107
108 let uniform = Uniform::new(low.to_float(), high.to_float());
109 let random_numbers = (0..n)
110 .map(|_| <T as NumCast>::from(uniform.sample(&mut rng).round()).unwrap())
111 .collect();
112
113 unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
114 }
115}
116
117impl<'a, T: RawDataType> RandomConstructors<T> for NdArray<'a, T> {}
118impl<'a, T: TensorDataType> RandomConstructors<T> for Tensor<'a, T> {}