/// Additive summary of a set of points: the point count plus, per dimension,
/// the sum and the sum of squares of the coordinates.
///
/// Centre and radius are derived from these three statistics, and two
/// summaries can be combined by plain addition (see [`ClusterFeature::merge`]).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-json", derive(serde::Serialize, serde::Deserialize))]
pub struct ClusterFeature {
    /// Number of points absorbed so far.
    pub n: u64,
    /// Per-dimension sum of the absorbed coordinates.
    pub linear_sum: Vec<f64>,
    /// Per-dimension sum of the squared absorbed coordinates.
    pub squared_sum: Vec<f64>,
}

impl ClusterFeature {
    /// Creates an empty feature for points with `n_features` dimensions.
    pub fn new(n_features: usize) -> Self {
        Self {
            n: 0,
            linear_sum: vec![0.0; n_features],
            squared_sum: vec![0.0; n_features],
        }
    }

    /// Adds one point to the summary.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `point.len()` differs from the feature's
    /// dimensionality. In release builds a point that is too long panics on
    /// the out-of-bounds index, while a point that is too short only updates
    /// its leading dimensions.
    pub fn absorb(&mut self, point: &[f64]) {
        debug_assert_eq!(
            point.len(),
            self.linear_sum.len(),
            "point dimensionality mismatch: expected {}, got {}",
            self.linear_sum.len(),
            point.len(),
        );
        self.n += 1;
        for (i, &v) in point.iter().enumerate() {
            self.linear_sum[i] += v;
            self.squared_sum[i] += v * v;
        }
    }

    /// Returns the per-dimension mean of the absorbed points, or an all-zero
    /// vector when nothing has been absorbed yet.
    pub fn center(&self) -> Vec<f64> {
        if self.n == 0 {
            return vec![0.0; self.linear_sum.len()];
        }
        let n = self.n as f64;
        self.linear_sum.iter().map(|&ls| ls / n).collect()
    }

    /// Returns the square root of the per-dimension variance averaged over
    /// all dimensions.
    ///
    /// Yields `0.0` when fewer than two points were absorbed, when the
    /// feature has no dimensions, or when the averaged variance is not
    /// positive.
    pub fn radius(&self) -> f64 {
        // Without the dimensionality guard the average below would be
        // `0.0 / 0.0`, i.e. NaN, which slips past the `<= 0.0` check.
        if self.n < 2 || self.linear_sum.is_empty() {
            return 0.0;
        }
        let n = self.n as f64;
        let d = self.linear_sum.len() as f64;
        let sum_var: f64 = self
            .linear_sum
            .iter()
            .zip(self.squared_sum.iter())
            .map(|(&ls, &ss)| {
                // Var[x] = E[x^2] - E[x]^2, computed from the running sums.
                let mean = ls / n;
                ss / n - mean * mean
            })
            .sum();
        let avg_var = sum_var / d;
        // Floating-point cancellation can leave a tiny negative variance.
        if avg_var <= 0.0 {
            0.0
        } else {
            avg_var.sqrt()
        }
    }

    /// Folds `other` into `self` by adding counts and both running sums.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the two features differ in dimensionality.
    pub fn merge(&mut self, other: &ClusterFeature) {
        debug_assert_eq!(
            self.linear_sum.len(),
            other.linear_sum.len(),
            "cannot merge CFs with different dimensionality",
        );
        self.n += other.n;
        for (i, &v) in other.linear_sum.iter().enumerate() {
            self.linear_sum[i] += v;
        }
        for (i, &v) in other.squared_sum.iter().enumerate() {
            self.squared_sum[i] += v;
        }
    }
}
/// Settings consumed by `CluStream`.
///
/// Usually obtained through `CluStreamConfig::builder(..).build()`, which
/// rejects a `max_micro_clusters` below 2.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-json", derive(serde::Serialize, serde::Deserialize))]
pub struct CluStreamConfig {
/// Micro-cluster budget: once this many exist, training merges the two
/// closest micro-clusters before adding a new one.
pub max_micro_clusters: usize,
/// Multiplier applied to the nearest micro-cluster's radius to obtain the
/// distance threshold below which a sample is absorbed (builder default 2.0).
pub max_radius_factor: f64,
/// Expected sample dimensionality; `0` makes `CluStream` take it from the
/// first training sample.
pub n_features: usize,
}
/// Builder for `CluStreamConfig`, created by `CluStreamConfig::builder`.
///
/// Holds the same three settings as the config; they are checked and copied
/// over by `build`.
#[derive(Debug, Clone)]
pub struct CluStreamConfigBuilder {
// Micro-cluster budget supplied to `CluStreamConfig::builder`.
max_micro_clusters: usize,
// Radius multiplier; starts at 2.0 unless overridden.
max_radius_factor: f64,
// Sample dimensionality; starts at 0 unless overridden.
n_features: usize,
}
impl CluStreamConfig {
    /// Starts a builder with the given micro-cluster budget.
    ///
    /// The remaining settings begin at their defaults: a radius factor of
    /// `2.0` and `n_features` of `0`.
    pub fn builder(max_micro_clusters: usize) -> CluStreamConfigBuilder {
        let max_radius_factor = 2.0;
        let n_features = 0;
        CluStreamConfigBuilder {
            max_micro_clusters,
            max_radius_factor,
            n_features,
        }
    }
}
impl CluStreamConfigBuilder {
pub fn max_radius_factor(mut self, f: f64) -> Self {
self.max_radius_factor = f;
self
}
pub fn n_features(mut self, d: usize) -> Self {
self.n_features = d;
self
}
pub fn build(self) -> Result<CluStreamConfig, irithyll_core::error::ConfigError> {
if self.max_micro_clusters < 2 {
return Err(irithyll_core::error::ConfigError::out_of_range(
"max_micro_clusters",
"must be >= 2",
self.max_micro_clusters,
));
}
Ok(CluStreamConfig {
max_micro_clusters: self.max_micro_clusters,
max_radius_factor: self.max_radius_factor,
n_features: self.n_features,
})
}
}
/// Streaming clusterer that maintains a bounded set of micro-clusters
/// (each a [`ClusterFeature`]) and can group them into macro-clusters on
/// demand.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-json", derive(serde::Serialize, serde::Deserialize))]
pub struct CluStream {
    /// Settings supplied at construction.
    config: CluStreamConfig,
    /// Current micro-cluster summaries.
    micro_clusters: Vec<ClusterFeature>,
    /// Working dimensionality; `0` until taken from the first sample.
    n_features: usize,
    /// Number of samples passed to `train_one` since creation or last reset.
    n_samples: u64,
}

impl CluStream {
    /// Creates an empty model from `config`.
    pub fn new(config: CluStreamConfig) -> Self {
        let n_features = config.n_features;
        Self {
            config,
            micro_clusters: Vec::new(),
            n_features,
            n_samples: 0,
        }
    }

    /// Feeds one sample into the model.
    ///
    /// The sample is absorbed by its nearest micro-cluster when that cluster
    /// holds fewer than three points or the sample lies strictly within
    /// `max_radius_factor * radius` of its centre. Otherwise the sample
    /// starts a new micro-cluster; if the budget is already used up, the two
    /// closest micro-clusters are merged first to make room.
    ///
    /// # Panics
    ///
    /// Panics when a merge is required but fewer than two micro-clusters
    /// exist, which can only happen if `max_micro_clusters` is below 2.
    pub fn train_one(&mut self, features: &[f64]) {
        if self.n_features == 0 {
            // Dimensionality was left unspecified: take it from this sample.
            self.n_features = features.len();
        }
        self.n_samples += 1;

        if self.micro_clusters.is_empty() {
            self.push_singleton(features);
            return;
        }

        let (nearest_idx, nearest_dist) = self.nearest_mc(features);
        let nearest = &self.micro_clusters[nearest_idx];
        let threshold = self.config.max_radius_factor * nearest.radius();
        // Clusters with fewer than three points absorb unconditionally.
        if nearest_dist < threshold || nearest.n < 3 {
            self.micro_clusters[nearest_idx].absorb(features);
            return;
        }

        if self.micro_clusters.len() >= self.config.max_micro_clusters {
            // Budget exhausted: fold the closest pair together. Removing the
            // higher index keeps the lower one valid.
            let (i, j) = self.closest_mc_pair();
            let (lo, hi) = if i < j { (i, j) } else { (j, i) };
            let removed = self.micro_clusters.remove(hi);
            self.micro_clusters[lo].merge(&removed);
        }
        self.push_singleton(features);
    }

    /// Returns the index of the micro-cluster whose centre is nearest to
    /// `features`.
    ///
    /// # Panics
    ///
    /// Panics if the model holds no micro-clusters yet.
    pub fn predict(&self, features: &[f64]) -> usize {
        assert!(
            !self.micro_clusters.is_empty(),
            "cannot predict with no micro-clusters -- call train_one first"
        );
        let (idx, _) = self.nearest_mc(features);
        idx
    }

    /// Returns the current micro-cluster summaries.
    pub fn micro_clusters(&self) -> &[ClusterFeature] {
        &self.micro_clusters
    }

    /// Returns the number of micro-clusters currently held.
    pub fn n_micro_clusters(&self) -> usize {
        self.micro_clusters.len()
    }

    /// Returns how many samples have been trained on since creation or the
    /// last reset.
    pub fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }

    /// Drops all micro-clusters and zeroes the sample counter.
    ///
    /// The working dimensionality is cleared as well when the configuration
    /// left it unspecified (`0`), so it is inferred again from the next
    /// sample.
    pub fn reset(&mut self) {
        self.micro_clusters.clear();
        self.n_samples = 0;
        if self.config.n_features == 0 {
            self.n_features = 0;
        }
    }

    /// Groups the micro-clusters into at most `k` macro-clusters with a
    /// weighted k-means over their centres, weighting by point count.
    ///
    /// Each returned group lists micro-cluster indices; empty groups are
    /// dropped. Returns an empty vector when there are no micro-clusters or
    /// when `k` is zero.
    pub fn macro_clusters(&self, k: usize) -> Vec<Vec<usize>> {
        let n = self.micro_clusters.len();
        // With `k == 0` there is no group to index into, so bail out early
        // instead of panicking on `groups[0]`.
        if n == 0 || k == 0 {
            return Vec::new();
        }
        let effective_k = k.min(n);
        let centers: Vec<Vec<f64>> = self.micro_clusters.iter().map(|mc| mc.center()).collect();
        let weights: Vec<u64> = self.micro_clusters.iter().map(|mc| mc.n).collect();
        let assignments = weighted_kmeans(&centers, &weights, effective_k, 100);

        let mut groups: Vec<Vec<usize>> = vec![Vec::new(); effective_k];
        for (mc_idx, &cluster_id) in assignments.iter().enumerate() {
            groups[cluster_id].push(mc_idx);
        }
        groups.retain(|g| !g.is_empty());
        groups
    }

    /// Appends a fresh micro-cluster that contains only `features`.
    fn push_singleton(&mut self, features: &[f64]) {
        let mut cf = ClusterFeature::new(self.n_features);
        cf.absorb(features);
        self.micro_clusters.push(cf);
    }

    /// Returns `(index, distance)` of the micro-cluster centre closest to
    /// `point`; ties go to the lowest index. With no micro-clusters the
    /// result is `(0, f64::MAX)`.
    fn nearest_mc(&self, point: &[f64]) -> (usize, f64) {
        let mut best_idx = 0;
        let mut best_dist = f64::MAX;
        for (i, mc) in self.micro_clusters.iter().enumerate() {
            let d = euclidean_distance(point, &mc.center());
            if d < best_dist {
                best_dist = d;
                best_idx = i;
            }
        }
        (best_idx, best_dist)
    }

    /// Returns the indices `(i, j)` with `i < j` of the two micro-clusters
    /// whose centres are closest; defaults to `(0, 1)` if no pair compares
    /// smaller than `f64::MAX`.
    fn closest_mc_pair(&self) -> (usize, usize) {
        let n = self.micro_clusters.len();
        debug_assert!(n >= 2, "need at least 2 micro-clusters to find a pair");
        let mut best_i = 0;
        let mut best_j = 1;
        let mut best_dist = f64::MAX;
        // Compute every centre once rather than once per pair.
        let centers: Vec<Vec<f64>> = self.micro_clusters.iter().map(|mc| mc.center()).collect();
        for i in 0..n {
            for j in (i + 1)..n {
                let d = euclidean_distance(&centers[i], &centers[j]);
                if d < best_dist {
                    best_dist = d;
                    best_i = i;
                    best_j = j;
                }
            }
        }
        (best_i, best_j)
    }
}
fn weighted_kmeans(centers: &[Vec<f64>], weights: &[u64], k: usize, max_iter: usize) -> Vec<usize> {
let n = centers.len();
if n == 0 || k == 0 {
return vec![0; n];
}
let effective_k = k.min(n);
let d = centers[0].len();
let mut order: Vec<usize> = (0..n).collect();
order.sort_by(|&a, &b| weights[b].cmp(&weights[a]));
let mut centroids: Vec<Vec<f64>> = order[..effective_k]
.iter()
.map(|&i| centers[i].clone())
.collect();
let mut assignments = vec![0usize; n];
for _ in 0..max_iter {
let mut changed = false;
for i in 0..n {
let mut best_c = 0;
let mut best_dist = f64::MAX;
for (c, centroid) in centroids.iter().enumerate() {
let dist = euclidean_distance(¢ers[i], centroid);
if dist < best_dist {
best_dist = dist;
best_c = c;
}
}
if assignments[i] != best_c {
assignments[i] = best_c;
changed = true;
}
}
if !changed {
break;
}
let mut new_centroids = vec![vec![0.0; d]; effective_k];
let mut total_weight = vec![0.0_f64; effective_k];
for i in 0..n {
let c = assignments[i];
let w = weights[i] as f64;
total_weight[c] += w;
for (j, &v) in centers[i].iter().enumerate() {
new_centroids[c][j] += w * v;
}
}
for c in 0..effective_k {
if total_weight[c] > 0.0 {
for val in new_centroids[c].iter_mut().take(d) {
*val /= total_weight[c];
}
}
if total_weight[c] == 0.0 {
new_centroids[c] = centroids[c].clone();
}
}
centroids = new_centroids;
}
assignments
}
/// Returns the Euclidean (L2) distance between `a` and `b`.
///
/// Coordinates are paired positionally; debug builds assert that both
/// slices have the same length.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len(), "dimension mismatch in euclidean_distance");
    let squared_norm: f64 = a
        .iter()
        .zip(b)
        .map(|(p, q)| {
            let delta = p - q;
            delta * delta
        })
        .sum();
    squared_norm.sqrt()
}
// Unit tests for the cluster-feature arithmetic, the streaming update rules,
// macro-clustering, prediction, reset and config validation.
#[cfg(test)]
mod tests {
use super::*;
// Absolute tolerance used by `approx_eq`.
const EPS: f64 = 1e-10;
// True when `a` and `b` differ by less than `EPS` in absolute value.
fn approx_eq(a: f64, b: f64) -> bool {
(a - b).abs() < EPS
}
// Small deterministic generator: advances `state` with a 13/7/17 xorshift
// step and returns the new state divided by `u64::MAX`.
fn xorshift64(state: &mut u64) -> f64 {
let mut x = *state;
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
*state = x;
(x as f64) / (u64::MAX as f64)
}
// The very first sample must open exactly one micro-cluster holding it.
#[test]
fn single_point_creates_micro_cluster() {
let config = CluStreamConfig::builder(5).build().unwrap();
let mut cs = CluStream::new(config);
cs.train_one(&[1.0, 2.0, 3.0]);
assert_eq!(cs.n_micro_clusters(), 1);
assert_eq!(cs.micro_clusters()[0].n, 1);
assert_eq!(cs.n_samples_seen(), 1);
}
// Absorbing three points yields the right count and per-dimension mean.
#[test]
fn cluster_feature_absorb() {
let mut cf = ClusterFeature::new(2);
cf.absorb(&[1.0, 4.0]);
cf.absorb(&[3.0, 6.0]);
cf.absorb(&[2.0, 5.0]);
assert_eq!(cf.n, 3);
let center = cf.center();
assert!(approx_eq(center[0], 2.0));
assert!(approx_eq(center[1], 5.0));
}
// Merging adds counts, linear sums and squared sums component-wise.
#[test]
fn cluster_feature_merge() {
let mut cf1 = ClusterFeature::new(2);
cf1.absorb(&[1.0, 2.0]);
cf1.absorb(&[3.0, 4.0]);
let mut cf2 = ClusterFeature::new(2);
cf2.absorb(&[5.0, 6.0]);
cf1.merge(&cf2);
assert_eq!(cf1.n, 3);
assert!(approx_eq(cf1.linear_sum[0], 9.0));
assert!(approx_eq(cf1.linear_sum[1], 12.0));
assert!(approx_eq(cf1.squared_sum[0], 35.0));
assert!(approx_eq(cf1.squared_sum[1], 56.0));
let center = cf1.center();
assert!(approx_eq(center[0], 3.0));
assert!(approx_eq(center[1], 4.0));
}
// Four tightly packed points should all land in the same micro-cluster.
#[test]
fn nearby_points_absorbed() {
let config = CluStreamConfig::builder(5)
.max_radius_factor(4.0)
.build()
.unwrap();
let mut cs = CluStream::new(config);
cs.train_one(&[1.0, 1.0]);
cs.train_one(&[1.01, 1.01]);
cs.train_one(&[0.99, 0.99]);
cs.train_one(&[1.02, 1.02]);
assert_eq!(cs.n_micro_clusters(), 1);
assert_eq!(cs.micro_clusters()[0].n, 4);
}
// Once a cluster has three points, a far-away sample opens a second one.
#[test]
fn distant_point_creates_new_mc() {
let config = CluStreamConfig::builder(5)
.max_radius_factor(2.0)
.build()
.unwrap();
let mut cs = CluStream::new(config);
cs.train_one(&[0.0, 0.0]);
cs.train_one(&[0.01, 0.01]);
cs.train_one(&[0.02, 0.02]);
cs.train_one(&[100.0, 100.0]);
assert_eq!(cs.n_micro_clusters(), 2);
}
// With the budget of three full, a fourth distant sample must not raise
// the micro-cluster count above the budget.
#[test]
fn max_micro_clusters_triggers_merge() {
let config = CluStreamConfig::builder(3)
.max_radius_factor(2.0)
.build()
.unwrap();
let mut cs = CluStream::new(config);
for _ in 0..3 {
cs.train_one(&[0.0, 0.0]);
}
for _ in 0..3 {
cs.train_one(&[50.0, 50.0]);
}
for _ in 0..3 {
cs.train_one(&[100.0, 100.0]);
}
assert_eq!(cs.n_micro_clusters(), 3);
cs.train_one(&[200.0, 200.0]);
assert_eq!(cs.n_micro_clusters(), 3);
}
// Two blobs (within +/-1 of 0 and of 100) must end up in two macro-clusters
// that never mix micro-clusters from both regions.
#[test]
fn macro_clusters_separates_groups() {
let config = CluStreamConfig::builder(20)
.max_radius_factor(2.0)
.build()
.unwrap();
let mut cs = CluStream::new(config);
let mut rng_state: u64 = 12345;
for _ in 0..50 {
let x = xorshift64(&mut rng_state) * 2.0 - 1.0; let y = xorshift64(&mut rng_state) * 2.0 - 1.0;
cs.train_one(&[x, y]);
}
for _ in 0..50 {
let x = 100.0 + xorshift64(&mut rng_state) * 2.0 - 1.0;
let y = 100.0 + xorshift64(&mut rng_state) * 2.0 - 1.0;
cs.train_one(&[x, y]);
}
let groups = cs.macro_clusters(2);
assert_eq!(groups.len(), 2);
for group in &groups {
let centers: Vec<Vec<f64>> = group
.iter()
.map(|&idx| cs.micro_clusters()[idx].center())
.collect();
// Split the plane at x = 50: every centre of a group must sit on one side.
let all_low = centers.iter().all(|c| c[0] < 50.0);
let all_high = centers.iter().all(|c| c[0] >= 50.0);
assert!(
all_low || all_high,
"macro-cluster group mixes centers from both regions"
);
}
}
// `predict` must return a micro-cluster on the same side as the query.
#[test]
fn predict_nearest_micro_cluster() {
let config = CluStreamConfig::builder(10)
.max_radius_factor(2.0)
.build()
.unwrap();
let mut cs = CluStream::new(config);
for _ in 0..5 {
cs.train_one(&[0.0, 0.0]);
}
for _ in 0..5 {
cs.train_one(&[100.0, 100.0]);
}
let idx_near_origin = cs.predict(&[0.1, 0.1]);
let center_origin = cs.micro_clusters()[idx_near_origin].center();
assert!(
center_origin[0] < 50.0,
"expected prediction near origin, got center {:?}",
center_origin
);
let idx_near_far = cs.predict(&[99.9, 99.9]);
let center_far = cs.micro_clusters()[idx_near_far].center();
assert!(
center_far[0] >= 50.0,
"expected prediction near (100,100), got center {:?}",
center_far
);
}
// `reset` drops every micro-cluster and zeroes the sample counter.
#[test]
fn reset_clears_state() {
let config = CluStreamConfig::builder(5).build().unwrap();
let mut cs = CluStream::new(config);
cs.train_one(&[1.0, 2.0]);
cs.train_one(&[3.0, 4.0]);
assert_eq!(cs.n_micro_clusters(), 1);
assert_eq!(cs.n_samples_seen(), 2);
cs.reset();
assert_eq!(cs.n_micro_clusters(), 0);
assert_eq!(cs.n_samples_seen(), 0);
}
// A budget below 2 is rejected with an out-of-range error naming the
// parameter; a budget of exactly 2 is accepted.
#[test]
fn config_builder_validates() {
use irithyll_core::error::ConfigError;
let result = CluStreamConfig::builder(1).build();
assert!(result.is_err());
assert!(
matches!(&result.unwrap_err(), ConfigError::OutOfRange { param, .. } if *param == "max_micro_clusters"),
"expected OutOfRange for max_micro_clusters"
);
let result = CluStreamConfig::builder(0).build();
assert!(result.is_err());
let result = CluStreamConfig::builder(2).build();
assert!(result.is_ok());
}
}