use std::f32::consts::PI;
use ndarray::{Array, ArrayBase, ArrayViewMutD, AsArray, Dimension, ViewRepr, Zip};
use rayon::prelude::*;
use crate::constants::RNG_SEED;
use crate::prelude::*;
use crate::simulation::rng::Pcg;
#[inline]
pub fn poisson_noise<'a, T, A, D>(
data: A,
scale: f64,
seed: Option<u64>,
threads: Option<usize>,
) -> Array<T, D>
where
A: AsArray<'a, T, D>,
D: Dimension,
T: 'a + AsNumeric,
{
let data: ArrayBase<ViewRepr<&'a T>, D> = data.into();
let seed = seed.unwrap_or(RNG_SEED);
let mut prng = Pcg::new(seed);
let mut noise_data: Array<T, D> = Array::from_elem(data.dim(), T::default());
par!(threads,
seq_exp: Zip::from(data.view()).and(noise_data.view_mut())
.for_each(|a, b| {
let a = a.to_f64();
let s = if a < 0.0 { -1.0 } else { 1.0 };
let l = a.abs() * scale;
*b = T::from_f64(get_poisson(&mut prng, l as f32) * s);
}),
par_exp: Zip::from(data.view()).and(noise_data.view_mut())
.into_par_iter()
.for_each_with(prng.fork(), |g, (a, b)| {
let a = a.to_f64();
let s = if a < 0.0 { -1.0 } else { 1.0 };
let l = a.abs() * scale;
*b = T::from_f64(get_poisson(g, l as f32) * s);
}));
noise_data
}
#[inline]
pub fn poisson_noise_mut<T>(
mut data: ArrayViewMutD<T>,
scale: f64,
seed: Option<u64>,
threads: Option<usize>,
) where
T: AsNumeric,
{
let seed = seed.unwrap_or(RNG_SEED);
let mut prng = Pcg::new(seed);
par!(threads,
seq_exp: data.iter_mut().for_each(|v| {
let a = v.to_f64();
let s = if a < 0.0 { -1.0 } else { 1.0 };
let l = a.abs() * scale;
*v = T::from_f64(get_poisson(&mut prng, l as f32) * s);
}),
par_exp: data.into_par_iter().for_each_with(prng.fork(), |g, v| {
let a = v.to_f64();
let s = if a < 0.0 { -1.0 } else { 1.0 };
let l = a.abs() * scale;
*v = T::from_f64(get_poisson(g, l as f32) * s);
}))
}
/// Draws one sample from a Poisson distribution with mean `lambda`.
///
/// * `lambda >= 30`: normal approximation `N(lambda, lambda)` from a single
///   Box–Muller transform (two uniform draws), rounded and clamped at zero.
/// * otherwise: Knuth's multiplication method. Uniforms are multiplied until the
///   running product drops below `exp(-lambda)`; the number of multiplications
///   before that happens is the sample. For `lambda <= 0` this returns 0 after a
///   single draw (assuming `next_f32() < 1`).
///
/// Special inputs:
/// * NaN `lambda` returns NaN without drawing. Previously the Knuth loop never
///   terminated, because `prod < NaN` is always false.
/// * Infinite `lambda` returns `+inf`. It still consumes the same two draws as
///   any other large `lambda`, so the stream stays aligned.
///
/// Assumes `prng.next_f32()` yields values in the unit interval; TODO confirm
/// the exact endpoints against `Pcg`.
fn get_poisson(prng: &mut Pcg, lambda: f32) -> f64 {
    if lambda.is_nan() {
        return f64::NAN;
    }
    if lambda >= 30.0 {
        // An exact 0 would make `ln` return -inf, giving z = ±inf and an
        // inf/0 sample. Clamping changes nothing for any u1 > 0.
        let u1 = prng.next_f32().max(f32::MIN_POSITIVE);
        let u2 = prng.next_f32();
        if lambda.is_infinite() {
            // `inf + inf * z` would be inf or NaN depending on the sign of z.
            return f64::INFINITY;
        }
        let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
        let sample = (lambda + lambda.sqrt() * z).round().max(0.0);
        return sample as f64;
    }
    // Knuth: count how many uniforms can be multiplied in before the product
    // falls below exp(-lambda). exp(-30) is far above f32's underflow range.
    let thres = (-lambda).exp();
    let mut prod: f32 = 1.0;
    let mut count: u64 = 0;
    loop {
        prod *= prng.next_f32();
        if prod < thres {
            return count as f64;
        }
        count += 1;
    }
}