1use std::sync::{Arc, RwLock};
2
3use crate::SamplingMethod;
4use linfa::Float;
5use ndarray::{Array, Array2, ArrayBase, Data, Ix2};
6use ndarray_rand::{RandomExt, rand::Rng, rand::SeedableRng, rand_distr::Uniform};
7use rand_xoshiro::Xoshiro256Plus;
8
9#[cfg(feature = "serializable")]
10use serde::{Deserialize, Serialize};
11
/// Shared, lock-protected handle to a random number generator.
///
/// Cloning the handle shares the same generator instead of copying its state.
type RngRef<R> = Arc<RwLock<R>>;
/// A [`SamplingMethod`] that draws points at random inside the bounds given by `xlimits`.
///
/// Cloning a `Random` shares its generator (see `RngRef`), so drawing from one
/// clone advances the generator seen by the others.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub struct Random<F: Float, R: Rng> {
    /// Sampling space: one row per input dimension, each row holding the
    /// `[lower, upper]` bounds of that dimension (exactly 2 columns, enforced
    /// by `new_with_rng`).
    xlimits: Array2<F>,
    /// Generator used to draw samples, behind a shared lock so `&self`
    /// sampling methods can advance it.
    rng: RngRef<R>,
}
23
24impl<F: Float> Random<F, Xoshiro256Plus> {
25 pub fn new(xlimits: &ArrayBase<impl Data<Elem = F>, Ix2>) -> Self {
34 Self::new_with_rng(xlimits, Xoshiro256Plus::from_entropy())
35 }
36}
37
38impl<F: Float, R: Rng> Random<F, R> {
39 pub fn new_with_rng(xlimits: &ArrayBase<impl Data<Elem = F>, Ix2>, rng: R) -> Self {
44 if xlimits.ncols() != 2 {
45 panic!("xlimits must have 2 columns (lower, upper)");
46 }
47 Random {
48 xlimits: xlimits.to_owned(),
49 rng: Arc::new(RwLock::new(rng)),
50 }
51 }
52
53 pub fn with_rng<R2: Rng>(self, rng: R2) -> Random<F, R2> {
55 Random {
56 xlimits: self.xlimits,
57 rng: Arc::new(RwLock::new(rng)),
58 }
59 }
60}
61
62impl<F: Float, R: Rng> SamplingMethod<F> for Random<F, R> {
63 fn sampling_space(&self) -> &Array2<F> {
64 &self.xlimits
65 }
66
67 fn normalized_sample(&self, ns: usize) -> Array2<F> {
68 let mut rng = self.rng.write().unwrap();
69 let nx = self.xlimits.nrows();
70 Array::random_using((ns, nx), Uniform::new(0., 1.), &mut *rng).mapv(|v| F::cast(v))
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{arr2, array};

    /// A fixed seed must reproduce the reference points (within 1e-6).
    #[test]
    fn test_random() {
        let seeded = Xoshiro256Plus::seed_from_u64(42);
        let doe = Random::new(&arr2(&[[5., 10.], [0., 1.]])).with_rng(seeded);

        let reference = array![
            [5.4287779764773045, 0.31041139572710486],
            [5.31284890781607, 0.306461322653673],
            [5.0002147942961885, 0.3030653113049855],
            [5.438048037018622, 0.2270337387265695],
            [9.31397733563812, 0.5232539513550647],
            [6.0549173955055435, 0.8198009346946455],
            [8.303444344933911, 0.8588635290560207],
            [5.721154177502889, 0.3516459308028457],
            [5.457086177138239, 0.11691074717669259]
        ];

        assert_abs_diff_eq!(reference, doe.sample(9), epsilon = 1e-6);
    }
}