use crate::core::error::{PureCvError, Result};
use crate::core::types::Scalar;
use crate::core::Matrix;
use num_traits::{FromPrimitive, ToPrimitive};
use std::cell::RefCell;
/// Xoshiro256** pseudo-random generator with a 256-bit state.
///
/// The algorithm is fast and statistically strong for simulation work, but it is not
/// cryptographically secure.
#[derive(Clone)]
struct Xoshiro256 {
    /// Generator state; `from_seed` guarantees it is never all zero.
    s: [u64; 4],
}

impl Xoshiro256 {
    /// Builds a generator whose four state words are consecutive SplitMix64 outputs of `seed`.
    ///
    /// An all-zero state would make xoshiro emit zeros forever, so that (practically
    /// unreachable) case is patched with a fixed non-zero word.
    fn from_seed(seed: u64) -> Self {
        let mut mixer = seed;
        let mut draw = || Self::splitmix64(&mut mixer);
        // Array elements are evaluated left to right, so the words come out in sequence.
        let mut state = [draw(), draw(), draw(), draw()];
        if state.iter().all(|&word| word == 0) {
            state[0] = 0x853c49e6748fea9b;
        }
        Self { s: state }
    }

    /// Advances a SplitMix64 state by one step and returns the mixed output.
    #[inline]
    fn splitmix64(state: &mut u64) -> u64 {
        *state = state.wrapping_add(0x9e3779b97f4a7c15);
        let z = *state;
        let z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
        let z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
        z ^ (z >> 31)
    }

    /// Returns the next 64-bit output and advances the state.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        let [s0, s1, s2, s3] = self.s;
        let output = s1.wrapping_mul(5).rotate_left(7).wrapping_mul(9);
        let shifted = s1 << 17;
        // Each step uses the words already updated by the previous steps, in this order.
        let s2 = s2 ^ s0;
        let s3 = s3 ^ s1;
        let s1 = s1 ^ s2;
        let s0 = s0 ^ s3;
        let s2 = s2 ^ shifted;
        let s3 = s3.rotate_left(45);
        self.s = [s0, s1, s2, s3];
        output
    }

    /// Returns a uniform `f64` in `[0, 1)` built from the top 53 bits of the next output.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        const INV_TWO_POW_53: f64 = 1.0 / (1u64 << 53) as f64;
        let mantissa_bits = self.next_u64() >> 11;
        mantissa_bits as f64 * INV_TWO_POW_53
    }

    /// Returns two independent standard-normal deviates using the Box–Muller transform.
    #[inline]
    fn next_gaussian_pair(&mut self) -> (f64, f64) {
        let (radius, angle) = loop {
            let u1 = self.next_f64();
            let u2 = self.next_f64();
            // Reject u1 at or near zero so that ln(u1) stays finite.
            if u1 > f64::EPSILON {
                break ((-2.0 * u1.ln()).sqrt(), std::f64::consts::TAU * u2);
            }
        };
        (radius * angle.cos(), radius * angle.sin())
    }
}
/// Seed that every thread's generator starts from until [`set_rng_seed`] is called on it.
const DEFAULT_RNG_SEED: u64 = 0;

thread_local! {
    /// Per-thread generator backing [`randu`] and [`randn`].
    ///
    /// Every thread starts from the same default seed, so threads that never reseed
    /// produce identical sequences.
    static THREAD_RNG: RefCell<Xoshiro256> =
        RefCell::new(Xoshiro256::from_seed(DEFAULT_RNG_SEED));
}
/// Reseeds the calling thread's generator used by [`randu`] and [`randn`].
///
/// Only the current thread is affected. After this call, the values produced by
/// `randu`/`randn` on this thread are fully determined by `seed`.
pub fn set_rng_seed(seed: u64) {
    let reseeded = Xoshiro256::from_seed(seed);
    THREAD_RNG.with(|cell| cell.replace(reseeded));
}
/// Fills `dst` with uniformly distributed values from the calling thread's generator.
///
/// Channel `c` of every pixel is drawn from `[low.v[c], high.v[c])`, up to floating-point
/// rounding. Each value is converted with `T::from_f64`; values that `T` cannot represent
/// are stored as `T::default()` rather than clamped. Call [`set_rng_seed`] first for
/// reproducible output.
///
/// # Errors
///
/// Returns [`PureCvError::InvalidDimensions`] if `dst` has no data, has zero channels, or
/// has more channels than `low`/`high` have components.
pub fn randu<T>(dst: &mut Matrix<T>, low: Scalar<f64>, high: Scalar<f64>) -> Result<()>
where
    T: Default + Clone + FromPrimitive + ToPrimitive + Send + Sync,
{
    if dst.data.is_empty() {
        return Err(PureCvError::InvalidDimensions(
            "destination matrix is empty".into(),
        ));
    }
    let channels = dst.channels;
    // Indexing the scalar bounds past their length would panic (and a zero chunk size
    // panics in `chunks_exact_mut`), so report these cases as errors instead.
    if channels == 0 || channels > low.v.len() || channels > high.v.len() {
        return Err(PureCvError::InvalidDimensions(
            "channel count must be between 1 and the number of scalar components".into(),
        ));
    }
    // Per-channel (offset, span), computed once instead of once per element.
    let bounds: Vec<(f64, f64)> = (0..channels)
        .map(|c| (low.v[c], high.v[c] - low.v[c]))
        .collect();
    THREAD_RNG.with(|rng| {
        let mut rng = rng.borrow_mut();
        for pixel in dst.data.chunks_exact_mut(channels) {
            for (elem, &(lo, span)) in pixel.iter_mut().zip(&bounds) {
                *elem = T::from_f64(lo + rng.next_f64() * span).unwrap_or_default();
            }
        }
    });
    Ok(())
}
/// Fills `dst` with normally distributed values from the calling thread's generator.
///
/// Channel `c` of every element is `mean.v[c] + g * std_dev.v[c]`, where `g` is a
/// standard-normal deviate from the Box–Muller transform. Each value is converted with
/// `T::from_f64`; values that `T` cannot represent are stored as `T::default()` rather
/// than clamped. Call [`set_rng_seed`] first for reproducible output.
///
/// # Errors
///
/// Returns [`PureCvError::InvalidDimensions`] if `dst` has no data, has zero channels, or
/// has more channels than `mean`/`std_dev` have components.
pub fn randn<T>(dst: &mut Matrix<T>, mean: Scalar<f64>, std_dev: Scalar<f64>) -> Result<()>
where
    T: Default + Clone + FromPrimitive + ToPrimitive + Send + Sync,
{
    if dst.data.is_empty() {
        return Err(PureCvError::InvalidDimensions(
            "destination matrix is empty".into(),
        ));
    }
    let channels = dst.channels;
    // Indexing the scalar parameters past their length would panic (and `% 0` would
    // divide by zero), so report these cases as errors instead.
    if channels == 0 || channels > mean.v.len() || channels > std_dev.v.len() {
        return Err(PureCvError::InvalidDimensions(
            "channel count must be between 1 and the number of scalar components".into(),
        ));
    }
    // Per-channel (mean, standard deviation), cycled to match element index % channels.
    let params: Vec<(f64, f64)> = (0..channels)
        .map(|c| (mean.v[c], std_dev.v[c]))
        .collect();
    THREAD_RNG.with(|rng| {
        let mut rng = rng.borrow_mut();
        // Box–Muller yields two deviates per call; flatten the pairs into one stream.
        let gaussians =
            std::iter::repeat_with(|| rng.next_gaussian_pair()).flat_map(|(g0, g1)| [g0, g1]);
        // `zip` polls the element side first and stops as soon as it is exhausted, so no
        // pair is drawn past the last element; an odd element count discards one deviate.
        let targets = dst.data.iter_mut().zip(params.iter().cycle());
        for ((elem, &(mu, sigma)), g) in targets.zip(gaussians) {
            *elem = T::from_f64(mu + g * sigma).unwrap_or_default();
        }
    });
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reseeds the thread RNG with `seed`, then runs `fill` on a fresh 10x10 single-channel
    /// matrix and returns it.
    fn fill_after_seed(seed: u64, fill: impl Fn(&mut Matrix<f64>)) -> Matrix<f64> {
        set_rng_seed(seed);
        let mut mat = Matrix::<f64>::new(10, 10, 1);
        fill(&mut mat);
        mat
    }

    #[test]
    fn test_set_rng_seed_determinism() {
        let first = fill_after_seed(42, |m| {
            randu(m, Scalar::all(0.0), Scalar::all(1.0)).unwrap()
        });
        let second = fill_after_seed(42, |m| {
            randu(m, Scalar::all(0.0), Scalar::all(1.0)).unwrap()
        });
        assert_eq!(first.data, second.data);
    }

    #[test]
    fn test_randu_bounds() {
        set_rng_seed(123);
        let mut mat = Matrix::<f64>::new(50, 50, 1);
        randu(&mut mat, Scalar::all(-5.0), Scalar::all(5.0)).unwrap();
        mat.data.iter().copied().for_each(|v| {
            assert!((-5.0..5.0).contains(&v), "value {} out of range [-5, 5)", v);
        });
    }

    #[test]
    fn test_randu_per_channel_bounds() {
        set_rng_seed(99);
        let mut mat = Matrix::<f64>::new(20, 20, 3);
        let low = Scalar::new(0.0, 10.0, 100.0, 0.0);
        let high = Scalar::new(1.0, 20.0, 200.0, 0.0);
        randu(&mut mat, low, high).unwrap();
        for px in mat.data.chunks_exact(3) {
            let (c0, c1, c2) = (px[0], px[1], px[2]);
            assert!((0.0..1.0).contains(&c0), "ch0: {} not in [0,1)", c0);
            assert!((10.0..20.0).contains(&c1), "ch1: {} not in [10,20)", c1);
            assert!((100.0..200.0).contains(&c2), "ch2: {} not in [100,200)", c2);
        }
    }

    #[test]
    fn test_randu_u8() {
        set_rng_seed(0);
        let mut mat = Matrix::<u8>::new(100, 100, 1);
        randu(&mut mat, Scalar::all(0.0), Scalar::all(256.0)).unwrap();
        let lowest = mat.data.iter().copied().min().unwrap();
        let highest = mat.data.iter().copied().max().unwrap();
        assert!(highest > lowest, "randu produced no variation");
    }

    #[test]
    fn test_randn_statistics() {
        set_rng_seed(7);
        let n = 100_000;
        let mut mat = Matrix::<f64>::new(1, n, 1);
        randn(&mut mat, Scalar::all(50.0), Scalar::all(10.0)).unwrap();
        let count = n as f64;
        let mean_val = mat.data.iter().sum::<f64>() / count;
        let var = mat
            .data
            .iter()
            .map(|v| (v - mean_val).powi(2))
            .sum::<f64>()
            / count;
        let std_val = var.sqrt();
        assert!(
            (mean_val - 50.0).abs() < 1.0,
            "mean {} too far from 50",
            mean_val
        );
        assert!(
            (std_val - 10.0).abs() < 1.0,
            "std {} too far from 10",
            std_val
        );
    }

    #[test]
    fn test_randn_determinism() {
        let first = fill_after_seed(77, |m| {
            randn(m, Scalar::all(0.0), Scalar::all(1.0)).unwrap()
        });
        let second = fill_after_seed(77, |m| {
            randn(m, Scalar::all(0.0), Scalar::all(1.0)).unwrap()
        });
        assert_eq!(first.data, second.data);
    }

    #[test]
    fn test_randu_empty_matrix_error() {
        let mut mat = Matrix::<f64>::new(0, 0, 1);
        assert!(randu(&mut mat, Scalar::all(0.0), Scalar::all(1.0)).is_err());
    }

    #[test]
    fn test_randn_empty_matrix_error() {
        let mut mat = Matrix::<f64>::new(0, 0, 1);
        assert!(randn(&mut mat, Scalar::all(0.0), Scalar::all(1.0)).is_err());
    }
}