use std::collections::BTreeSet;
use log::info;
use ndarray::prelude::*;
use crate::params::{DiscreteStatesContinousTimeParams, Params, ParamsTrait, StateType};
use crate::process;
use super::ctmp::CtmpProcess;
use super::{NetworkProcess, NetworkProcessState};
pub struct CtbnNetwork {
adj_matrix: Option<Array2<u16>>,
nodes: Vec<Params>,
}
impl CtbnNetwork {
pub fn new() -> CtbnNetwork {
CtbnNetwork {
adj_matrix: None,
nodes: Vec::new(),
}
}
pub fn amalgamation(&self) -> CtmpProcess {
info!("Network Amalgamation Started");
let variables_domain =
Array1::from_iter(self.nodes.iter().map(|x| x.get_reserved_space_as_parent()));
let state_space = variables_domain.product();
let variables_set = BTreeSet::from_iter(self.get_node_indices());
let mut amalgamated_cim: Array3<f64> = Array::zeros((1, state_space, state_space));
for idx_current_state in 0..state_space {
let current_state = CtbnNetwork::idx_to_state(&variables_domain, idx_current_state);
let current_state_statetype: NetworkProcessState = current_state
.iter()
.map(|x| StateType::Discrete(*x))
.collect();
for idx_node in 0..self.nodes.len() {
let p = match self.get_node(idx_node) {
Params::DiscreteStatesContinousTime(p) => p,
};
for next_node_state in 0..variables_domain[idx_node] {
let mut next_state = current_state.clone();
next_state[idx_node] = next_node_state;
let next_state_statetype: NetworkProcessState =
next_state.iter().map(|x| StateType::Discrete(*x)).collect();
let idx_next_state = self.get_param_index_from_custom_parent_set(
&next_state_statetype,
&variables_set,
);
amalgamated_cim[[0, idx_current_state, idx_next_state]] +=
p.get_cim().as_ref().unwrap()[[
self.get_param_index_network(idx_node, ¤t_state_statetype),
current_state[idx_node],
next_node_state,
]];
}
}
}
let mut amalgamated_param = DiscreteStatesContinousTimeParams::new(
"ctmp".to_string(),
BTreeSet::from_iter((0..state_space).map(|x| x.to_string())),
);
amalgamated_param.set_cim(amalgamated_cim).unwrap();
let mut ctmp = CtmpProcess::new();
ctmp.add_node(Params::DiscreteStatesContinousTime(amalgamated_param))
.unwrap();
return ctmp;
}
pub fn idx_to_state(variables_domain: &Array1<usize>, state: usize) -> Array1<usize> {
let mut state = state;
let mut array_state = Array1::zeros(variables_domain.shape()[0]);
for (idx, var) in variables_domain.indexed_iter() {
array_state[idx] = state % var;
state = state / var;
}
return array_state;
}
pub fn get_adj_matrix(&self) -> Option<&Array2<u16>> {
self.adj_matrix.as_ref()
}
}
impl process::NetworkProcess for CtbnNetwork {
fn initialize_adj_matrix(&mut self) {
self.adj_matrix = Some(Array2::<u16>::zeros(
(self.nodes.len(), self.nodes.len()).f(),
));
}
fn add_node(&mut self, mut n: Params) -> Result<usize, process::NetworkError> {
n.reset_params();
self.adj_matrix = Option::None;
self.nodes.push(n);
Ok(self.nodes.len() - 1)
}
fn add_edge(&mut self, parent: usize, child: usize) {
if let None = self.adj_matrix {
self.initialize_adj_matrix();
}
if let Some(network) = &mut self.adj_matrix {
network[[parent, child]] = 1;
self.nodes[child].reset_params();
}
}
fn get_node_indices(&self) -> std::ops::Range<usize> {
0..self.nodes.len()
}
fn get_number_of_nodes(&self) -> usize {
self.nodes.len()
}
fn get_node(&self, node_idx: usize) -> &Params {
&self.nodes[node_idx]
}
fn get_node_mut(&mut self, node_idx: usize) -> &mut Params {
&mut self.nodes[node_idx]
}
fn get_param_index_network(&self, node: usize, current_state: &NetworkProcessState) -> usize {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.fold((0, 1), |mut acc, x| {
if x.1 > &0 {
acc.0 += self.nodes[x.0].state_to_index(¤t_state[x.0]) * acc.1;
acc.1 *= self.nodes[x.0].get_reserved_space_as_parent();
}
acc
})
.0
}
fn get_param_index_from_custom_parent_set(
&self,
current_state: &NetworkProcessState,
parent_set: &BTreeSet<usize>,
) -> usize {
parent_set
.iter()
.fold((0, 1), |mut acc, x| {
acc.0 += self.nodes[*x].state_to_index(¤t_state[*x]) * acc.1;
acc.1 *= self.nodes[*x].get_reserved_space_as_parent();
acc
})
.0
}
fn get_parent_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.column(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
fn get_children_set(&self, node: usize) -> BTreeSet<usize> {
self.adj_matrix
.as_ref()
.unwrap()
.row(node)
.iter()
.enumerate()
.filter_map(|(idx, x)| if x > &0 { Some(idx) } else { None })
.collect()
}
}