redstone_ml/common/
random.rs

1use crate::util::to_vec::ToVec;
2use crate::{Constructors, FloatDataType, NdArray, NumericDataType, RawDataType, Tensor, TensorDataType};
3use num::{Float, NumCast};
4use rand::distributions::{Distribution, Uniform};
5use rand::thread_rng;
6use rand_distr::Normal;
7
8pub trait RandomConstructors<T: RawDataType>: Constructors<T> {
9    /// Samples an `NdArray` with the specified shape
10    /// from a standard normal distribution (0 mean, unit standard deviation).
11    ///
12    /// # Examples
13    /// ```
14    /// # use redstone_ml::*;
15    ///
16    /// let ndarray = NdArray::<f64>::randn([2, 3]);
17    /// println!("{:?}", ndarray);
18    /// ```
19    fn randn(shape: impl ToVec<usize>) -> Self
20    where
21        T: FloatDataType
22    {
23        let mut rng = thread_rng();
24        let shape = shape.to_vec();
25        let n = shape.iter().product();
26
27        let normal = Normal::new(0.0, 1.0).unwrap();
28
29        let random_numbers: Vec<T> = (0..n)
30            .map(|_| <T as NumCast>::from(normal.sample(&mut rng)).unwrap())
31            .collect();
32
33        unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
34    }
35
36    /// Samples an `NdArray` with the specified shape
37    /// with values uniformly distributed in [0, 1).
38    ///
39    /// # Examples
40    /// ```
41    /// # use redstone_ml::*;
42    ///
43    /// let ndarray = NdArray::<f64>::rand([2, 3]);
44    /// println!("{:?}", ndarray);
45    /// ```
46    fn rand(shape: impl ToVec<usize>) -> Self
47    where
48        T: FloatDataType
49    {
50        let mut rng = thread_rng();
51        let shape = shape.to_vec();
52        let n = shape.iter().product();
53
54        let uniform = Uniform::new(0.0, 1.0);
55        let random_numbers = (0..n)
56            .map(|_| <T as NumCast>::from(uniform.sample(&mut rng)).unwrap())
57            .collect();
58
59        unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
60    }
61
62    /// Samples an `NdArray` with the specified shape
63    /// with values uniformly distributed in [`low`, `high`).
64    ///
65    /// # Examples
66    /// ```
67    /// # use redstone_ml::*;
68    ///
69    /// let ndarray = NdArray::<f64>::uniform([2, 3], -5.0, 3.0);
70    /// println!("{:?}", ndarray);
71    /// ```
72    fn uniform(shape: impl ToVec<usize>, low: T, high: T) -> Self
73    where
74        T: FloatDataType
75    {
76        let mut rng = thread_rng();
77        let shape = shape.to_vec();
78        let n = shape.iter().product();
79
80        let uniform = Uniform::new(low, high);
81        let random_numbers = (0..n)
82            .map(|_| <T as NumCast>::from(uniform.sample(&mut rng)).unwrap())
83            .collect();
84
85        unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
86    }
87
88    /// Samples an `NdArray` with the specified shape
89    /// with integer values uniformly distributed between `low` (inclusive) and `high` (exclusive).
90    ///
91    /// # Examples
92    /// ```
93    /// # use redstone_ml::*;
94    ///
95    /// let ndarray = NdArray::<isize>::randint([2, 3], -5, 3);
96    /// println!("{:?}", ndarray);
97    /// ```
98    fn randint(shape: impl ToVec<usize>, low: T, high: T) -> Self
99    where
100        T: NumericDataType
101    {
102        assert!(low < high, "randint: low must be less than high");
103
104        let mut rng = thread_rng();
105        let shape = shape.to_vec();
106        let n = shape.iter().product();
107
108        let uniform = Uniform::new(low.to_float(), high.to_float());
109        let random_numbers = (0..n)
110            .map(|_| <T as NumCast>::from(uniform.sample(&mut rng).round()).unwrap())
111            .collect();
112
113        unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
114    }
115}
116
117impl<'a, T: RawDataType> RandomConstructors<T> for NdArray<'a, T> {}
118impl<'a, T: TensorDataType> RandomConstructors<T> for Tensor<'a, T> {}