use std::collections::BTreeSet;
use enum_dispatch::enum_dispatch;
use log::{debug, error, trace, warn};
use ndarray::prelude::*;
use rand::Rng;
use rand_chacha::ChaCha8Rng;
use thiserror::Error;
#[derive(Error, Debug, PartialEq)]
pub enum ParamsError {
#[error("Unsupported method")]
UnsupportedMethod(String),
#[error("Paramiters not initialized")]
ParametersNotInitialized(String),
#[error("Invalid cim for parameter")]
InvalidCIM(String),
}
/// A concrete value taken by a node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StateType {
    /// A discrete state, identified by its index into the node's domain.
    Discrete(usize),
}
/// Behaviour shared by every node parameterization.
///
/// `#[enum_dispatch(Params)]` lets the `Params` enum implement this trait by
/// forwarding each call to the variant it wraps.
#[enum_dispatch(Params)]
pub trait ParamsTrait {
    /// Discards any stored parameters.
    fn reset_params(&mut self);
    /// Draws a state uniformly at random from the node's domain.
    fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType;
    /// Samples how long the node remains in `state` before its next transition.
    ///
    /// `u` selects which conditional distribution is used (presumably the index
    /// of the current parent configuration — confirm against callers).
    ///
    /// # Errors
    /// Returns a `ParamsError` if the required parameters are unavailable.
    fn get_random_residence_time(
        &self,
        state: usize,
        u: usize,
        rng: &mut ChaCha8Rng,
    ) -> Result<f64, ParamsError>;
    /// Samples the state the node jumps to when it leaves `state`.
    ///
    /// `u` has the same meaning as in `get_random_residence_time`.
    ///
    /// # Errors
    /// Returns a `ParamsError` if the required parameters are unavailable.
    fn get_random_state(
        &self,
        state: usize,
        u: usize,
        rng: &mut ChaCha8Rng,
    ) -> Result<StateType, ParamsError>;
    /// Number of slots this node occupies when it acts as a parent
    /// (the discrete implementation returns its domain size).
    fn get_reserved_space_as_parent(&self) -> usize;
    /// Maps a state to its integer index.
    fn state_to_index(&self, state: &StateType) -> usize;
    /// Checks that the stored parameters are present and well-formed.
    ///
    /// # Errors
    /// Returns a `ParamsError` describing the problem found.
    fn validate_params(&self) -> Result<(), ParamsError>;
    /// Returns the node's label.
    fn get_label(&self) -> &String;
}
/// The set of parameter representations a node can hold.
///
/// `#[enum_dispatch]`, paired with `#[enum_dispatch(Params)]` on `ParamsTrait`,
/// makes this enum implement `ParamsTrait` by forwarding to the wrapped value.
#[derive(Clone)]
#[enum_dispatch]
pub enum Params {
    /// Discrete states evolving in continuous time (CIM-based parameters).
    DiscreteStatesContinousTime(DiscreteStatesContinousTimeParams),
}
/// Parameters of a node whose finite set of states evolves in continuous time.
#[derive(Clone)]
pub struct DiscreteStatesContinousTimeParams {
    /// Node name, used in log messages.
    label: String,
    /// Names of the node's possible states; `domain.len()` is the number of states.
    domain: BTreeSet<String>,
    /// Conditional intensity matrices, indexed `[u, from_state, to_state]` by the
    /// sampling code (`u` presumably the parent configuration — confirm).
    /// `None` until set.
    cim: Option<Array3<f64>>,
    /// Transition counts (presumably sufficient statistics from learning — confirm).
    /// `None` until set.
    transitions: Option<Array3<usize>>,
    /// Residence times (presumably accumulated per `[u, state]` — confirm).
    /// `None` until set.
    residence_time: Option<Array2<f64>>,
}
impl DiscreteStatesContinousTimeParams {
pub fn new(label: String, domain: BTreeSet<String>) -> DiscreteStatesContinousTimeParams {
debug!("Creation of node {}", label);
DiscreteStatesContinousTimeParams {
label,
domain,
cim: Option::None,
transitions: Option::None,
residence_time: Option::None,
}
}
pub fn get_cim(&self) -> &Option<Array3<f64>> {
debug!("Getting cim from node {}", self.label);
&self.cim
}
pub fn set_cim(&mut self, cim: Array3<f64>) -> Result<(), ParamsError> {
debug!("Setting cim for node {}", self.label);
self.cim = Some(cim);
match self.validate_params() {
Ok(()) => Ok(()),
Err(e) => {
warn!("Validation cim faild for node {}", self.label);
self.cim = None;
Err(e)
}
}
}
pub fn set_cim_unchecked(&mut self, cim: Array3<f64>) {
debug!("Setting cim (unchecked) for node {}", self.label);
self.cim = Some(cim);
}
pub fn get_transitions(&self) -> &Option<Array3<usize>> {
debug!("Get transitions from node {}", self.label);
&self.transitions
}
pub fn set_transitions(&mut self, transitions: Array3<usize>) {
debug!("Set transitions for node {}", self.label);
self.transitions = Some(transitions);
}
pub fn get_residence_time(&self) -> &Option<Array2<f64>> {
debug!("Get residence time from node {}", self.label);
&self.residence_time
}
pub fn set_residence_time(&mut self, residence_time: Array2<f64>) {
debug!("Set residence time for node {}", self.label);
self.residence_time = Some(residence_time);
}
}
impl ParamsTrait for DiscreteStatesContinousTimeParams {
fn reset_params(&mut self) {
debug!(
"Setting cim, transitions and residence_time to None for node {}",
self.label
);
self.cim = Option::None;
self.transitions = Option::None;
self.residence_time = Option::None;
}
fn get_random_state_uniform(&self, rng: &mut ChaCha8Rng) -> StateType {
let state = StateType::Discrete(rng.gen_range(0..(self.domain.len())));
trace!(
"Generate random state uniform. Node: {} - State: {:?}",
self.get_label(),
&state
);
return state;
}
fn get_random_residence_time(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<f64, ParamsError> {
match &self.cim {
Option::Some(cim) => {
let lambda = cim[[u, state, state]] * -1.0;
let x: f64 = rng.gen_range(0.0..=1.0);
let ret = -x.ln() / lambda;
trace!(
"Generate random residence time. Node: {} - Time: {}",
self.get_label(),
ret
);
Ok(ret)
}
Option::None => {
warn!("Cim not initialized for node {}", self.get_label());
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
)))
}
}
}
fn get_random_state(
&self,
state: usize,
u: usize,
rng: &mut ChaCha8Rng,
) -> Result<StateType, ParamsError> {
match &self.cim {
Option::Some(cim) => {
let lambda = cim[[u, state, state]] * -1.0;
let urand: f64 = rng.gen_range(0.0..=1.0);
let next_state = cim.slice(s![u, state, ..]).map(|x| x / lambda).iter().fold(
(0, 0.0),
|mut acc, ele| {
if &acc.1 + ele < urand && ele > &0.0 {
acc.0 += 1;
}
if ele > &0.0 {
acc.1 += ele;
}
acc
},
);
let next_state = if next_state.0 < state {
next_state.0
} else {
next_state.0 + 1
};
let next_state = StateType::Discrete(next_state);
trace!(
"Generate random state. Node: {} - State: {:?}",
self.get_label(),
next_state
);
Ok(next_state)
}
Option::None => {
warn!("Cim not initialized for node {}", self.get_label());
Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
)))
}
}
}
fn get_reserved_space_as_parent(&self) -> usize {
self.domain.len()
}
fn state_to_index(&self, state: &StateType) -> usize {
match state {
StateType::Discrete(val) => val.clone() as usize,
}
}
fn validate_params(&self) -> Result<(), ParamsError> {
let domain_size = self.domain.len();
if let None = self.cim {
warn!("Cim not initialized for node {}", self.get_label());
return Err(ParamsError::ParametersNotInitialized(String::from(
"CIM not initialized",
)));
}
let cim = self.cim.as_ref().unwrap();
if cim.shape()[1] != domain_size || cim.shape()[2] != domain_size {
let message = format!(
"Incompatible shape {:?} with domain {:?}",
cim.shape(),
domain_size
);
warn!("{}", message);
return Err(ParamsError::InvalidCIM(message));
}
if cim
.axis_iter(Axis(0))
.any(|x| x.diag().iter().any(|x| x >= &0.0))
{
warn!(
"The diagonal of each cim for node {} must be non-positive",
self.get_label()
);
return Err(ParamsError::InvalidCIM(String::from(
"The diagonal of each cim must be non-positive",
)));
}
if cim
.sum_axis(Axis(2))
.iter()
.any(|x| f64::abs(x.clone()) > f64::EPSILON.sqrt())
{
warn!(
"The sum of each row of the cim for node {} must be 0",
self.get_label()
);
return Err(ParamsError::InvalidCIM(String::from(
"The sum of each row must be 0",
)));
}
return Ok(());
}
fn get_label(&self) -> &String {
&self.label
}
}