use crate::util::to_vec::ToVec;
use crate::{Constructors, FloatDataType, NdArray, NumericDataType, RawDataType, Tensor, TensorDataType};
use num::{Float, NumCast};
use rand::distributions::{Distribution, Uniform};
use rand::thread_rng;
use rand_distr::Normal;
/// Constructors that fill a new array with pseudo-random values.
///
/// Every method draws from `rand::thread_rng()`, so results differ from run to run
/// (no seed can be supplied through this trait). Each method generates exactly
/// `shape.iter().product()` elements and hands them over as one contiguous buffer.
pub trait RandomConstructors<T: RawDataType>: Constructors<T> {
    /// Creates an array of the given `shape` whose elements are sampled from the
    /// standard normal distribution (mean `0`, standard deviation `1`).
    ///
    /// Samples are drawn as `f64` and converted to `T` with [`NumCast`].
    ///
    /// # Panics
    ///
    /// Panics if an `f64` sample cannot be converted to `T`.
    fn randn(shape: impl ToVec<usize>) -> Self
    where
        T: FloatDataType,
    {
        let mut rng = thread_rng();
        let shape = shape.to_vec();
        let n: usize = shape.iter().product();
        // The parameters are constant and valid (finite mean, non-negative std dev),
        // so construction cannot fail.
        let normal = Normal::new(0.0, 1.0).unwrap();
        let random_numbers: Vec<T> = (0..n)
            .map(|_| <T as NumCast>::from(normal.sample(&mut rng)).unwrap())
            .collect();
        // SAFETY: `random_numbers` holds exactly `shape.iter().product()` elements. This is
        // presumably the length/shape invariant `from_contiguous_owned_buffer` requires;
        // confirm against `Constructors`.
        unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
    }

    /// Creates an array of the given `shape` whose elements are sampled uniformly from
    /// the half-open interval `[0, 1)`.
    ///
    /// Samples are drawn as `f64` and converted to `T` with [`NumCast`].
    ///
    /// # Panics
    ///
    /// Panics if an `f64` sample cannot be converted to `T`.
    fn rand(shape: impl ToVec<usize>) -> Self
    where
        T: FloatDataType,
    {
        let mut rng = thread_rng();
        let shape = shape.to_vec();
        let n: usize = shape.iter().product();
        let uniform = Uniform::new(0.0, 1.0);
        let random_numbers: Vec<T> = (0..n)
            .map(|_| <T as NumCast>::from(uniform.sample(&mut rng)).unwrap())
            .collect();
        // SAFETY: `random_numbers` holds exactly `shape.iter().product()` elements. This is
        // presumably the length/shape invariant `from_contiguous_owned_buffer` requires;
        // confirm against `Constructors`.
        unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
    }

    /// Creates an array of the given `shape` whose elements are sampled uniformly from
    /// the half-open interval `[low, high)`.
    ///
    /// # Panics
    ///
    /// Panics (inside `rand`'s `Uniform::new`) if `low >= high` or if the range is not
    /// finite.
    fn uniform(shape: impl ToVec<usize>, low: T, high: T) -> Self
    where
        T: FloatDataType,
    {
        let mut rng = thread_rng();
        let shape = shape.to_vec();
        let n: usize = shape.iter().product();
        let uniform = Uniform::new(low, high);
        // `Uniform<T>` already yields values of type `T`; no `NumCast` round-trip is needed.
        let random_numbers: Vec<T> = (0..n).map(|_| uniform.sample(&mut rng)).collect();
        // SAFETY: `random_numbers` holds exactly `shape.iter().product()` elements. This is
        // presumably the length/shape invariant `from_contiguous_owned_buffer` requires;
        // confirm against `Constructors`.
        unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
    }

    /// Creates an array of the given `shape` filled with values drawn uniformly from the
    /// half-open range `[low, high)`.
    ///
    /// The method samples a float uniformly from `[low, high)` and floors it. For integer
    /// `T`, each value `low, low + 1, …, high - 1` therefore covers a unit-width interval
    /// and is equally likely. `high` itself is never produced.
    ///
    /// # Panics
    ///
    /// Panics if `low >= high`, or if a floored sample cannot be converted back to `T`.
    fn randint(shape: impl ToVec<usize>, low: T, high: T) -> Self
    where
        T: NumericDataType,
    {
        assert!(low < high, "randint: low must be less than high");
        let mut rng = thread_rng();
        let shape = shape.to_vec();
        let n: usize = shape.iter().product();
        let uniform = Uniform::new(low.to_float(), high.to_float());
        let random_numbers: Vec<T> = (0..n)
            // `floor`, not `round`. Rounding would give `low` and `high` only half the
            // weight of the interior values, and it could return `high`, which lies outside
            // the half-open range.
            .map(|_| <T as NumCast>::from(uniform.sample(&mut rng).floor()).unwrap())
            .collect();
        // SAFETY: `random_numbers` holds exactly `shape.iter().product()` elements. This is
        // presumably the length/shape invariant `from_contiguous_owned_buffer` requires;
        // confirm against `Constructors`.
        unsafe { Self::from_contiguous_owned_buffer(shape, random_numbers) }
    }
}
/// `NdArray` gets every random constructor through the trait's default methods.
impl<'a, T> RandomConstructors<T> for NdArray<'a, T>
where
    T: RawDataType,
{
}
/// `Tensor` gets every random constructor through the trait's default methods.
impl<'a, T> RandomConstructors<T> for Tensor<'a, T>
where
    T: TensorDataType,
{
}