use rand::prelude::*;
use rand::distributions::Uniform;
use rand::distributions::Poisson;
use ndarray::stack;
use ndarray::array;
use ndarray::prelude::*;
use ndarray_parallel::prelude::*;
use rayon::prelude::*;
/// Simulates a homogeneous Poisson process with rate `lambda` on `[0, tmax)`.
///
/// The number of events is drawn from `Poisson(tmax * lambda)`. Each event time is then
/// drawn independently and uniformly from `[0, tmax)`, filling the array in parallel.
///
/// # Returns
///
/// A 1-D array of event times. The times are **not** sorted.
///
/// # Panics
///
/// Panics if `lambda` or `tmax` is not strictly positive (this includes NaN).
pub fn poisson_process(tmax: f64, lambda: f64) -> Array1<f64> {
    assert!(lambda > 0.0, "poisson_process: lambda must be positive");
    // A non-positive horizon would otherwise panic inside the Poisson/Uniform constructors
    // with a much less helpful message.
    assert!(tmax > 0.0, "poisson_process: tmax must be positive");
    let mut rng = thread_rng();
    let num_events = Poisson::new(tmax * lambda).sample(&mut rng) as usize;
    // The distribution is immutable and shareable across threads, so build it once
    // instead of reconstructing it for every element inside the parallel closure.
    let uniform = Uniform::new(0.0, tmax);
    let mut events = Array1::<f64>::zeros((num_events,));
    events.par_mapv_inplace(|_| {
        // Each rayon worker uses its own thread-local generator.
        let mut rng = thread_rng();
        uniform.sample(&mut rng)
    });
    events
}
/// Simulates an inhomogeneous Poisson process with intensity `lambda` on `[0, tmax)` by thinning.
///
/// The number of candidate points is drawn from `Poisson(tmax * max_lambda)`. Each candidate
/// `(t, y)` is drawn uniformly from the rectangle `[0, tmax) x [0, max_lambda)`. A candidate is
/// kept when `y < lambda(t)`, and candidates are tested in parallel.
///
/// `max_lambda` must be an upper bound of `lambda` on `[0, tmax)`. Wherever `lambda(t)`
/// exceeds it, the simulated intensity is silently capped at `max_lambda`.
///
/// # Returns
///
/// An `(n, 2)` array with one accepted event per row:
/// - column 0 is the event time;
/// - column 1 is the accepted uniform height `y` (not `lambda(t)` itself).
///
/// Rows are not sorted by time. If no candidate is accepted, the shape is `(0, 2)`.
///
/// # Panics
///
/// Panics if `tmax` or `max_lambda` is not strictly positive (this includes NaN).
pub fn variable_poisson<F>(tmax: f64, lambda: &F, max_lambda: f64) -> Array2<f64>
where
    F: Fn(f64) -> f64 + Send + Sync,
{
    assert!(tmax > 0.0, "variable_poisson: tmax must be positive");
    assert!(max_lambda > 0.0, "variable_poisson: max_lambda must be positive");
    let mut rng = thread_rng();
    let num_candidates = Poisson::new(tmax * max_lambda).sample(&mut rng);
    // `&F` is already `Send + Sync` because `F: Sync`, so it is shared with rayon's workers
    // directly; wrapping it in an `Arc` would only add a heap allocation.
    let accepted: Vec<(f64, f64)> = (0..num_candidates)
        .into_par_iter()
        .filter_map(|_| {
            let mut rng = thread_rng();
            let timestamp = rng.gen::<f64>() * tmax;
            let height = rng.gen::<f64>() * max_lambda;
            if height < lambda(timestamp) {
                Some((timestamp, height))
            } else {
                None
            }
        })
        .collect();
    // Build the result in a single allocation rather than stacking one 1x2 array per event.
    // This also covers the empty case, yielding shape (0, 2).
    Array2::from_shape_fn((accepted.len(), 2), |(row, col)| {
        let (timestamp, height) = accepted[row];
        if col == 0 {
            timestamp
        } else {
            height
        }
    })
}
/// Simulates a self-exciting (Hawkes) process with an exponential kernel on `[0, tmax]`
/// using Ogata's thinning algorithm.
///
/// The conditional intensity is
/// `lambda(t) = lambda0 + sum_{t_i < t} alpha * exp(-decay * (t - t_i))`.
/// With `alpha >= 0` and `decay >= 0`, the intensity never increases between events.
/// The intensity at the current time is therefore a valid upper bound for proposing the
/// next candidate.
///
/// # Returns
///
/// An `(n, 3)` array with one event per row, in increasing time order:
/// - column 0 is the event time;
/// - column 1 is the intensity immediately after the event, including its own jump `alpha`;
/// - column 2 is `alpha`, which is the same for every row.
///
/// If no event falls in `[0, tmax]`, the shape is `(0, 3)`.
///
/// # Panics
///
/// Panics if `lambda0`, `alpha` or `decay` is negative or NaN.
pub fn hawkes_exponential(tmax: f64, decay: f64, lambda0: f64, alpha: f64) -> Array2<f64> {
    assert!(lambda0 >= 0.0, "hawkes_exponential: lambda0 must be non-negative");
    // Negative jumps or growth between events would break the thinning upper bound.
    assert!(alpha >= 0.0, "hawkes_exponential: alpha must be non-negative");
    assert!(decay >= 0.0, "hawkes_exponential: decay must be non-negative");

    let mut rng = thread_rng();
    // Row-major buffer of (time, intensity, alpha) triples, turned into an array at the end.
    let mut rows: Vec<f64> = Vec::new();

    // Before any event the intensity is the constant baseline, so the first waiting time is
    // exponential with rate `lambda0`.
    let mut s = -1.0 / lambda0 * rng.gen::<f64>().ln();
    if s > tmax {
        // The first event already falls outside the horizon (this includes `s == inf` when
        // `lambda0 == 0` or `ln(0)` is drawn), so no events occur in [0, tmax].
        return Array2::<f64>::zeros((0, 3));
    }
    let mut cur_lambda = lambda0 + alpha;
    rows.extend_from_slice(&[s, cur_lambda, alpha]);
    let mut lbda_max = cur_lambda;

    while s < tmax {
        // Propose the next candidate from a homogeneous process with rate `lbda_max`.
        let u: f64 = rng.gen();
        let ds = -1.0 / lbda_max * u.ln();
        // Decay the excitation over the elapsed interval.
        cur_lambda = lambda0 + (cur_lambda - lambda0) * (-decay * ds).exp();
        s += ds;
        if s > tmax {
            break;
        }
        // Accept the candidate with probability cur_lambda / lbda_max.
        let d: f64 = rng.gen();
        if d < cur_lambda / lbda_max {
            cur_lambda += alpha;
            rows.extend_from_slice(&[s, cur_lambda, alpha]);
        }
        // The intensity cannot rise before the next event, so its current value bounds it.
        lbda_max = cur_lambda;
    }

    let n = rows.len() / 3;
    Array2::from_shape_vec((n, 3), rows).expect("buffer always holds whole rows of 3 values")
}