use ndarray::{
Array3, ArrayBase, ArrayView1, ArrayViewMut3, AsArray, Axis, Ix1, Ix3, ViewRepr, Zip,
};
use rand::SeedableRng;
use rand::prelude::*;
use rand::rngs::StdRng;
use rand_distr::{Distribution, Poisson};
use rayon::prelude::*;
use crate::error::ImgalError;
use crate::traits::numeric::AsNumeric;
/// Apply Poisson (shot) noise to a 1-dimensional signal.
///
/// Each element `d` is replaced by a sample drawn from `Poisson(d * scale)`.
/// Elements with `d <= 0.0` (or NaN), and elements whose mean `d * scale` is
/// not strictly positive (`scale <= 0.0`, or a product that underflows to
/// zero), are set to `0.0`, the limit of `Poisson(λ)` as `λ → 0`.
///
/// # Arguments
///
/// * `data` - The input 1-dimensional signal.
/// * `scale` - Factor applied to each element to obtain its Poisson mean.
/// * `seed` - Seed for the random number generator. `None` falls back to seed
///   `0`, so the output is deterministic in both cases.
///
/// # Returns
///
/// * `Vec<f64>` - The noisy signal, with the same length as `data`.
///
/// # Panics
///
/// Panics if `Poisson::new` rejects a positive mean `d * scale`, e.g. when it
/// is infinite.
pub fn poisson_noise_1d<'a, T, A>(data: A, scale: f64, seed: Option<u64>) -> Vec<f64>
where
    A: AsArray<'a, T, Ix1>,
    T: 'a + AsNumeric,
{
    let view: ArrayBase<ViewRepr<&'a T>, Ix1> = data.into();
    let mut rng = StdRng::seed_from_u64(seed.unwrap_or(0));
    view.iter()
        .map(|&d| -> f64 {
            let value = d.to_f64();
            let lambda = value * scale;
            // `Poisson::new` rejects a zero/negative mean, so such elements are
            // mapped to zero counts directly instead of panicking.
            if value > 0.0 && lambda > 0.0 {
                Poisson::new(lambda)
                    .expect("Poisson mean must be finite and accepted by `Poisson::new`")
                    .sample(&mut rng)
            } else {
                0.0
            }
        })
        .collect()
}
/// Apply Poisson (shot) noise to a 1-dimensional signal in place.
///
/// Each element `x` is replaced by a sample drawn from `Poisson(x * scale)`.
/// Elements with `x <= 0.0` (or NaN), and elements whose mean `x * scale` is
/// not strictly positive (`scale <= 0.0`, or a product that underflows to
/// zero), are set to `0.0`, the limit of `Poisson(λ)` as `λ → 0`.
///
/// # Arguments
///
/// * `data` - The signal to modify in place.
/// * `scale` - Factor applied to each element to obtain its Poisson mean.
/// * `seed` - Seed for the random number generator. `None` falls back to seed
///   `0`, so the output is deterministic in both cases.
///
/// # Panics
///
/// Panics if `Poisson::new` rejects a positive mean `x * scale`, e.g. when it
/// is infinite.
pub fn poisson_noise_1d_mut(data: &mut [f64], scale: f64, seed: Option<u64>) {
    let mut rng = StdRng::seed_from_u64(seed.unwrap_or(0));
    for x in data.iter_mut() {
        let lambda = *x * scale;
        // `Poisson::new` rejects a zero/negative mean, so such elements are
        // mapped to zero counts directly instead of panicking.
        *x = if *x > 0.0 && lambda > 0.0 {
            Poisson::new(lambda)
                .expect("Poisson mean must be finite and accepted by `Poisson::new`")
                .sample(&mut rng)
        } else {
            0.0
        };
    }
}
/// Apply Poisson (shot) noise to a 3-dimensional array.
///
/// The array is split into 1-dimensional lanes along `axis`, and the lanes are
/// processed in parallel. Every lane gets its own random number generator,
/// seeded from a base seed combined with the lane's row-major position. The
/// result is therefore reproducible for a fixed `seed` regardless of thread
/// scheduling, while different lanes receive independent noise.
///
/// Each element `d` is replaced by a sample drawn from `Poisson(d * scale)`.
/// Elements with `d <= 0.0` (or NaN), and elements whose mean `d * scale` is
/// not strictly positive, are set to `0.0`.
///
/// # Arguments
///
/// * `data` - The input 3-dimensional array.
/// * `scale` - Factor applied to each element to obtain its Poisson mean.
/// * `seed` - Base seed for the per-lane generators. `None` draws a random base
///   seed from the thread-local generator, so results vary between calls.
/// * `axis` - Axis along which lanes are taken. Defaults to `2`.
///
/// # Returns
///
/// * `Ok(Array3<f64>)` - The noisy array, with the same shape as `data`.
///
/// # Errors
///
/// Returns `ImgalError::InvalidAxis` if `axis >= 3`.
///
/// # Panics
///
/// Panics if `Poisson::new` rejects a positive mean `d * scale`, e.g. when it
/// is infinite.
pub fn poisson_noise_3d<'a, T, A>(
    data: A,
    scale: f64,
    seed: Option<u64>,
    axis: Option<usize>,
) -> Result<Array3<f64>, ImgalError>
where
    A: AsArray<'a, T, Ix3>,
    T: 'a + AsNumeric,
{
    // Mix a base seed with a lane index (SplitMix64 finalizer) so that every
    // lane draws from its own, decorrelated random stream instead of all lanes
    // replaying the same sequence.
    fn lane_seed(base: u64, lane: u64) -> u64 {
        let mut z = base.wrapping_add(lane.wrapping_add(1).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    let a = axis.unwrap_or(2);
    if a >= 3 {
        return Err(ImgalError::InvalidAxis {
            axis_idx: a,
            dim_len: 3,
        });
    }
    let view: ArrayBase<ViewRepr<&'a T>, Ix3> = data.into();
    let shape = view.dim();
    // Lanes along `a` form a 2-D grid over the two remaining axes (kept in
    // their original order); `lane_cols` turns a (row, col) lane index into
    // the lane's row-major position.
    let lane_cols = if a == 2 { shape.1 } else { shape.2 };
    // One base seed per call; without a user seed, pick a random one.
    let base = seed.unwrap_or_else(|| rand::rng().next_u64());
    let mut n_data = Array3::<f64>::zeros(shape);
    Zip::indexed(view.lanes(Axis(a)))
        .and(n_data.lanes_mut(Axis(a)))
        .par_for_each(|(row, col), s_ln, d_ln| {
            let lane = (row * lane_cols + col) as u64;
            let mut rng = StdRng::seed_from_u64(lane_seed(base, lane));
            Zip::from(s_ln).and(d_ln).for_each(|s, d| {
                let value = (*s).to_f64();
                let lambda = value * scale;
                // `Poisson::new` rejects a zero/negative mean, so such elements
                // are mapped to zero counts directly instead of panicking.
                *d = if value > 0.0 && lambda > 0.0 {
                    Poisson::new(lambda)
                        .expect("Poisson mean must be finite and accepted by `Poisson::new`")
                        .sample(&mut rng)
                } else {
                    0.0
                };
            });
        });
    Ok(n_data)
}
/// Apply Poisson (shot) noise to a 3-dimensional array in place.
///
/// The array is split into 1-dimensional lanes along `axis`, which are
/// processed in parallel with [`poisson_noise_1d_mut`]. Every lane receives
/// its own seed, derived from a base seed and the lane's position in iteration
/// order. The result is therefore reproducible for a fixed `seed` regardless
/// of thread scheduling, while different lanes receive independent noise.
///
/// # Arguments
///
/// * `data` - The 3-dimensional array to modify in place.
/// * `scale` - Factor applied to each element to obtain its Poisson mean.
/// * `seed` - Base seed for the per-lane generators. `None` draws a random base
///   seed from the thread-local generator, so results vary between calls.
/// * `axis` - Axis along which lanes are taken. Defaults to `2`.
///
/// # Panics
///
/// Panics if `axis >= 3`, or if [`poisson_noise_1d_mut`] panics for a lane.
pub fn poisson_noise_3d_mut(
    mut data: ArrayViewMut3<f64>,
    scale: f64,
    seed: Option<u64>,
    axis: Option<usize>,
) {
    // Mix a base seed with a lane index (SplitMix64 finalizer) so that every
    // lane draws from its own, decorrelated random stream instead of all lanes
    // replaying the same sequence.
    fn lane_seed(base: u64, lane: u64) -> u64 {
        let mut z = base.wrapping_add(lane.wrapping_add(1).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    let a = axis.unwrap_or(2);
    assert!(a < 3, "axis {a} is out of bounds for a 3-dimensional array");
    // One base seed per call; without a user seed, pick a random one.
    let base = seed.unwrap_or_else(|| rand::rng().next_u64());
    data.lanes_mut(Axis(a))
        .into_iter()
        // Enumerate before bridging so each lane's index (and thus its seed)
        // is fixed by iteration order, not by which thread picks it up.
        .enumerate()
        .par_bridge()
        .for_each(|(idx, mut ln)| {
            let rng_seed = Some(lane_seed(base, idx as u64));
            if let Some(l) = ln.as_slice_mut() {
                poisson_noise_1d_mut(l, scale, rng_seed);
            } else {
                // Non-contiguous lane: work on a contiguous copy, then write it back.
                let mut buf = ln.to_vec();
                poisson_noise_1d_mut(&mut buf, scale, rng_seed);
                ln.assign(&ArrayView1::from(&buf));
            }
        });
}