use std::collections::HashSet;
use crate::math::{point::Point, FloatNumber};
/// A group of points, identified by their indices, together with the mean
/// (centroid) of the points that have been added to it.
#[derive(Debug, PartialEq, Clone)]
pub struct Cluster<T: FloatNumber, const N: usize> {
    /// Indices of the points that belong to this cluster; each index appears once.
    members: HashSet<usize>,
    /// Mean of the points added so far, one coordinate per dimension.
    centroid: Point<T, N>,
}
impl<T, const N: usize> Cluster<T, N>
where
    T: FloatNumber,
{
    /// Creates an empty cluster: no members and a centroid at the origin.
    #[must_use]
    pub fn new() -> Self {
        Self {
            members: HashSet::new(),
            centroid: [T::zero(); N],
        }
    }

    /// Returns the number of distinct member indices in this cluster.
    #[must_use]
    pub fn len(&self) -> usize {
        self.members.len()
    }

    /// Returns `true` if no member has been added to this cluster yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.members.is_empty()
    }

    /// Returns an iterator over the member indices, in unspecified order.
    pub fn members(&self) -> impl Iterator<Item = &usize> {
        self.members.iter()
    }

    /// Adds the point identified by `index` and updates the centroid so that it
    /// remains the mean of every point added so far.
    ///
    /// # Arguments
    /// * `index` - Identifier of the point; each index is counted at most once.
    /// * `point` - Coordinates of the point being added.
    ///
    /// # Returns
    /// `true` if `index` was not yet a member and has been added; `false` if it
    /// was already present, in which case neither the members nor the centroid
    /// change and `point` is ignored.
    pub fn add_member(&mut self, index: usize, point: &Point<T, N>) -> bool {
        if !self.members.insert(index) {
            return false;
        }
        // Incremental mean with n = member count after insertion:
        //   new = old * (n - 1) / n + x / n
        // The retained-weight factor is the same for every dimension, so it is
        // computed once instead of per coordinate.
        let size = T::from_usize(self.members.len());
        let retained = (size - T::one()) / size;
        for (coordinate, &value) in self.centroid.iter_mut().zip(point.iter()) {
            *coordinate *= retained;
            *coordinate += value / size;
        }
        true
    }
}

impl<T, const N: usize> Default for Cluster<T, N>
where
    T: FloatNumber,
{
    /// Equivalent to [`Cluster::new`]: an empty cluster centred at the origin.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the cluster state a test expects from its member indices and centroid.
    fn expected(indices: &[usize], centroid: [f32; 2]) -> Cluster<f32, 2> {
        Cluster {
            members: indices.iter().copied().collect(),
            centroid,
        }
    }

    #[test]
    fn test_new() {
        let fresh: Cluster<f32, 2> = Cluster::new();
        assert_eq!(fresh, expected(&[], [0.0, 0.0]));
    }

    #[test]
    fn test_add_member() {
        let mut cluster: Cluster<f32, 2> = Cluster::new();

        // (index to add, point, expected member indices, expected centroid)
        let steps: [(usize, [f32; 2], &[usize], [f32; 2]); 3] = [
            (0, [1.0, 2.0], &[0], [1.0, 2.0]),
            (1, [2.0, 4.0], &[0, 1], [1.5, 3.0]),
            (2, [3.0, 6.0], &[0, 1, 2], [2.0, 4.0]),
        ];
        for &(index, point, indices, centroid) in &steps {
            assert!(cluster.add_member(index, &point));
            assert_eq!(cluster.len(), indices.len());
            assert_eq!(cluster, expected(indices, centroid));
        }

        // A duplicate index is rejected and leaves the cluster untouched.
        let repeated = [3.0, 6.0];
        assert!(!cluster.add_member(2, &repeated));
        assert_eq!(cluster.len(), 3);
        assert_eq!(cluster, expected(&[0, 1, 2], [2.0, 4.0]));
    }
}