#[cfg(test)]
#[path = "../../../tests/unit/algorithms/gsom/state_test.rs"]
mod state_test;
use super::*;
use crate::algorithms::gsom::Coordinate;
use std::fmt::Write;
use std::ops::Range;
/// Snapshot of a whole network, as produced by `get_network_state`.
#[derive(Clone, Debug)]
pub struct NetworkState {
    /// `(x range, y range, weight dimension)`. The ranges are built as `min..max` over node
    /// coordinates (half-open); the dimension is the weight count of the first node, or 0
    /// when there are no nodes.
    pub shape: (Range<i32>, Range<i32>, usize),
    /// Value reported by the network's `mean_distance()` at snapshot time.
    pub mean_distance: Float,
    /// Per-node snapshots, in the order yielded by the network's `get_nodes()`.
    pub nodes: Vec<NodeState>,
}

/// Snapshot of a single network node.
#[derive(Clone, Debug)]
pub struct NodeState {
    /// Node coordinate as `(x, y)`.
    pub coordinate: (i32, i32),
    /// Result of the node's `unified_distance(network, 1)`.
    /// NOTE(review): meaning of the `1` argument is defined elsewhere — confirm.
    pub unified_distance: Float,
    /// Result of the node's `node_distance()`; `None` when the node reports no value.
    pub node_distance: Option<Float>,
    /// Copy of the node's weight vector.
    pub weights: Vec<Float>,
    /// Node's accumulated `total_hits` counter.
    pub total_hits: usize,
    /// Result of `get_last_hits` evaluated at the network's current time.
    pub last_hits: usize,
    /// Text produced by the node storage's `Display` implementation.
    pub dump: String,
}
/// Builds a [`NetworkState`] snapshot of `network`.
///
/// The snapshot holds the coordinate bounds (from `get_network_shape`), the weight dimension
/// (taken from the first node, 0 if there are none), the network's mean distance and one
/// [`NodeState`] per node yielded by `get_nodes()`.
///
/// NOTE(review): the shape ranges are built as `min..max`, which is half-open, so the maximum
/// coordinate itself is not `contains`-ed by the range — confirm consumers expect this.
///
/// # Panics
/// Panics only if a node storage's `Display` implementation reports an error.
pub fn get_network_state<I, S, F>(network: &Network<I, S, F>) -> NetworkState
where
    I: Input,
    S: Storage<Item = I>,
    F: StorageFactory<I, S>,
{
    let ((lo_x, hi_x), (lo_y, hi_y)) = get_network_shape(network);
    let mean_distance = network.mean_distance();

    let mut snapshots = Vec::new();
    for node in network.get_nodes() {
        // Formatting into a `String` only fails if the `Display` impl itself errors.
        let mut rendered = String::new();
        write!(rendered, "{}", node.storage).unwrap();

        snapshots.push(NodeState {
            coordinate: (node.coordinate.0, node.coordinate.1),
            unified_distance: node.unified_distance(network, 1),
            node_distance: node.node_distance(),
            weights: node.weights.clone(),
            total_hits: node.total_hits,
            last_hits: node.get_last_hits(network.get_current_time()),
            dump: rendered,
        });
    }

    // All nodes are expected to share one weight length; the first one is representative.
    let dim = match snapshots.first() {
        Some(first) => first.weights.len(),
        None => 0,
    };

    NetworkState { shape: (lo_x..hi_x, lo_y..hi_y, dim), nodes: snapshots, mean_distance }
}
/// Returns the coordinate bounds of `network` as `((x_min, x_max), (y_min, y_max))`.
///
/// Both bounds are inclusive extremes of the coordinates yielded by `get_coordinates()`.
/// For a network with no coordinates the result is `((i32::MAX, i32::MIN), (i32::MAX, i32::MIN))`,
/// i.e. each minimum is greater than its maximum.
pub fn get_network_shape<I, S, F>(network: &Network<I, S, F>) -> ((i32, i32), (i32, i32))
where
    I: Input,
    S: Storage<Item = I>,
    F: StorageFactory<I, S>,
{
    // Start from inverted extremes so the first coordinate replaces both bounds.
    let (mut lowest_x, mut highest_x) = (i32::MAX, i32::MIN);
    let (mut lowest_y, mut highest_y) = (i32::MAX, i32::MIN);

    for Coordinate(x, y) in network.get_coordinates() {
        lowest_x = lowest_x.min(x);
        highest_x = highest_x.max(x);
        lowest_y = lowest_y.min(y);
        highest_y = highest_y.max(y);
    }

    ((lowest_x, highest_x), (lowest_y, highest_y))
}