#[cfg(test)]
#[path = "../../../tests/unit/algorithms/gsom/node_test.rs"]
mod node_test;
use super::*;
use crate::utils::Float;
use std::collections::VecDeque;
use std::fmt::Formatter;
/// A single node of the GSOM grid: a weight vector placed at a grid [`Coordinate`],
/// together with hit statistics and a storage of associated items.
pub struct Node<I, S>
where
    I: Input,
    S: Storage<Item = I>,
{
    /// Weight vector of the node; moved by `Node::adjust` and used for distance computations.
    pub weights: Vec<Float>,
    /// Error value associated with the node; its initial value is supplied to `Node::new`.
    pub error: Float,
    /// Total number of hits counted by `Node::new_hit`.
    pub total_hits: usize,
    /// Recent distinct hit times, newest first; `Node::new_hit` caps it at `hit_memory_size` entries.
    pub last_hits: VecDeque<usize>,
    /// Position of the node on the grid.
    pub coordinate: Coordinate,
    /// Items associated with the node; also supplies the distance function used by the node.
    pub storage: S,
    /// Maximum length of `last_hits` and the width of the window used by `Node::get_last_hits`.
    hit_memory_size: usize,
}
/// Integer `(x, y)` position of a node on the two-dimensional grid.
///
/// The derived ordering is lexicographic: the first component is compared first, then the second.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct Coordinate(pub i32, pub i32);
impl<I, S> Node<I, S>
where
    I: Input,
    S: Storage<Item = I>,
{
    /// Creates a node placed at `coordinate` with a copy of `weights` and the given `error`.
    ///
    /// Hit statistics start empty: `total_hits` is zero and `last_hits` holds no entries.
    /// `hit_memory_size` limits how many hit times [`Node::new_hit`] keeps in `last_hits`.
    pub fn new(coordinate: Coordinate, weights: &[Float], error: Float, hit_memory_size: usize, storage: S) -> Self {
        let weights = weights.to_vec();
        let last_hits = VecDeque::with_capacity(hit_memory_size + 1);

        Self { weights, error, total_hits: 0, last_hits, coordinate, storage, hit_memory_size }
    }

    /// Moves every weight towards the matching value in `target`: `w += learning_rate * (t - w)`.
    ///
    /// `target` is expected to be as long as `weights`; this is only checked in debug builds.
    /// Without that check, a longer `target` panics on indexing and a shorter one leaves the
    /// trailing weights untouched.
    pub fn adjust(&mut self, target: &[Float], learning_rate: Float) {
        debug_assert!(self.weights.len() == target.len());

        target.iter().copied().enumerate().for_each(|(idx, goal)| {
            let weight = &mut self.weights[idx];
            *weight += learning_rate * (goal - *weight);
        });
    }

    /// Returns the distance between this node's weights and `weights`, as defined by its storage.
    pub fn distance(&self, weights: &[Float]) -> Float {
        let own = self.weights.as_slice();
        self.storage.distance(own, weights)
    }

    /// Registers a hit at `time`.
    ///
    /// `total_hits` always grows by one. `time` is pushed to the front of `last_hits` only when it
    /// differs from the most recently recorded time; the history is then trimmed to at most
    /// `hit_memory_size` entries (newest first).
    pub fn new_hit(&mut self, time: usize) {
        self.total_hits += 1;

        let repeated = matches!(self.last_hits.front(), Some(&latest) if latest == time);
        if repeated {
            return;
        }

        self.last_hits.push_front(time);
        self.last_hits.truncate(self.hit_memory_size);
    }

    /// Counts recorded hits that fall inside the memory window ending at `current_time`.
    ///
    /// A hit counts when it is strictly greater than `current_time - hit_memory_size`. While
    /// `current_time` does not exceed `hit_memory_size`, every recorded hit counts.
    pub fn get_last_hits(&self, current_time: usize) -> usize {
        if current_time <= self.hit_memory_size {
            return self.last_hits.len();
        }

        let window_start = current_time - self.hit_memory_size;
        self.last_hits.iter().filter(|&&hit| hit > window_start).count()
    }

    /// Returns `true` when at least one of the four cells at Manhattan distance 1 from this node
    /// has no node in `network`. Diagonal cells are not considered.
    pub fn is_boundary<F: StorageFactory<I, S>>(&self, network: &Network<I, S, F>) -> bool {
        self.neighbours(network, 1)
            .any(|(neighbour, (dx, dy))| dx.abs() + dy.abs() < 2 && neighbour.is_none())
    }

    /// Iterates over every grid cell whose x and y offsets from this node both lie in
    /// `-radius..=radius`, excluding the node's own cell.
    ///
    /// Cells are visited with the x offset in the outer loop and the y offset in the inner loop.
    /// Each item is the coordinate of the node found in that cell (`None` when the cell is empty)
    /// paired with the `(dx, dy)` offset of the cell.
    pub fn neighbours<'a, F: StorageFactory<I, S>>(
        &self,
        network: &'a Network<I, S, F>,
        radius: usize,
    ) -> impl Iterator<Item = (Option<Coordinate>, (i32, i32))> + 'a {
        let radius = radius as i32;
        let Coordinate(center_x, center_y) = self.coordinate;

        (-radius..=radius)
            .flat_map(move |dx| (-radius..=radius).map(move |dy| (dx, dy)))
            .filter(|&(dx, dy)| dx != 0 || dy != 0)
            .map(move |(dx, dy)| {
                let found = network.find(&Coordinate(center_x + dx, center_y + dy));
                (found.map(|node| node.coordinate), (dx, dy))
            })
    }

    /// Returns the mean storage distance between this node and the existing neighbours within
    /// `radius` (as enumerated by [`Node::neighbours`]), or `0` when no neighbour exists.
    pub fn unified_distance<F: StorageFactory<I, S>>(&self, network: &Network<I, S, F>, radius: usize) -> Float {
        let mut total: Float = 0.;
        let mut found = 0;

        let present = self.neighbours(network, radius).filter_map(|(coordinate, _)| coordinate);
        for coordinate in present {
            if let Some(node) = network.find(&coordinate) {
                total += self.storage.distance(self.weights.as_slice(), node.weights.as_slice());
                found += 1;
            }
        }

        if found == 0 {
            0.
        } else {
            total / found as Float
        }
    }

    /// Returns the storage distance between this node's weights and the first stored item,
    /// or `None` when the storage is empty.
    pub fn node_distance(&self) -> Option<Float> {
        let first = self.storage.iter().next()?;
        Some(self.storage.distance(self.weights.as_slice(), first.weights()))
    }

    /// Returns the squared error between this node's weights and the first stored item,
    /// divided by the number of weights. Returns `0` when the storage is empty.
    ///
    /// NOTE(review): only the first stored item is sampled (`take(1)`), so despite the name this
    /// is not a mean over all stored items. Confirm this is intended. With an empty weight vector
    /// and a non-empty storage the division yields NaN.
    pub fn mse(&self) -> Float {
        let dimension = self.weights.len() as Float;

        let (count, sum) = self
            .storage
            .iter()
            .take(1)
            .map(|data| {
                let squared: Float = data
                    .weights()
                    .iter()
                    .zip(self.weights.iter())
                    .map(|(&lhs, &rhs)| (lhs - rhs) * (lhs - rhs))
                    .sum();
                squared / dimension
            })
            .fold((0, 0.), |(count, sum), err| (count + 1, sum + err));

        match count {
            0 => sum,
            _ => sum / count as Float,
        }
    }
}
impl Display for Coordinate {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("({},{})", self.0, self.1))
}
}
impl From<(i32, i32)> for Coordinate {
fn from(value: (i32, i32)) -> Self {
Coordinate(value.0, value.1)
}
}