/// Squared Euclidean distance between two vectors.
///
/// NOTE: `zip` truncates to the shorter slice, so mismatched lengths are
/// silently ignored — callers are expected to pass equal dimensions.
fn euclidean_distance_squared(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = x - y;
        acc += d * d;
    }
    acc
}
/// Builder for [`StreamingKMeansConfig`]; all validation happens in `build`.
#[derive(Debug, Clone)]
pub struct StreamingKMeansConfigBuilder {
    // Target number of clusters; must be >= 1 (checked in `build`).
    k: usize,
    // Per-update count decay, must lie in (0, 1]; 1.0 disables forgetting.
    forgetting_factor: f64,
    // Optional RNG seed. NOTE(review): not read by any code visible in this
    // file — presumably reserved for a randomized init path; confirm.
    seed: Option<u64>,
}
impl StreamingKMeansConfigBuilder {
    /// Sets the forgetting factor (range-checked later, in [`build`](Self::build)).
    pub fn forgetting_factor(mut self, f: f64) -> Self {
        self.forgetting_factor = f;
        self
    }

    /// Sets the RNG seed.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = Some(s);
        self
    }

    /// Validates the accumulated parameters and produces the final config.
    ///
    /// # Errors
    /// Returns `ConfigError::out_of_range` when `k < 1` or when the
    /// forgetting factor lies outside `(0, 1]`.
    pub fn build(self) -> Result<StreamingKMeansConfig, irithyll_core::error::ConfigError> {
        use irithyll_core::error::ConfigError;
        let Self {
            k,
            forgetting_factor,
            seed,
        } = self;
        if k < 1 {
            return Err(ConfigError::out_of_range("k", "must be >= 1", k));
        }
        // Comparison kept in this exact form on purpose: a NaN factor fails
        // both `<= 0.0` and `> 1.0` and therefore passes through, matching
        // the original behavior.
        if forgetting_factor <= 0.0 || forgetting_factor > 1.0 {
            return Err(ConfigError::out_of_range(
                "forgetting_factor",
                "must be in (0, 1]",
                forgetting_factor,
            ));
        }
        Ok(StreamingKMeansConfig {
            k,
            forgetting_factor,
            seed,
        })
    }
}
/// Validated configuration for [`StreamingKMeans`].
///
/// Construct via [`StreamingKMeansConfig::builder`], which enforces the
/// invariants documented on each field.
#[derive(Debug, Clone)]
pub struct StreamingKMeansConfig {
    /// Number of clusters; guaranteed >= 1 by the builder.
    pub k: usize,
    /// Per-update count decay in (0, 1]; 1.0 means no forgetting.
    pub forgetting_factor: f64,
    /// Optional RNG seed. NOTE(review): not read by any code visible in
    /// this file — confirm intended use before relying on it.
    pub seed: Option<u64>,
}
impl StreamingKMeansConfig {
    /// Starts a builder targeting `k` clusters with defaults:
    /// forgetting factor 1.0 (no forgetting) and no seed.
    pub fn builder(k: usize) -> StreamingKMeansConfigBuilder {
        StreamingKMeansConfigBuilder {
            seed: None,
            forgetting_factor: 1.0,
            k,
        }
    }
}
/// Online (streaming) k-means clusterer with optional count forgetting.
#[derive(Debug, Clone)]
pub struct StreamingKMeans {
    /// Validated hyperparameters.
    config: StreamingKMeansConfig,
    /// Current cluster centers; grows up to `config.k` during init.
    centroids: Vec<Vec<f64>>,
    /// Per-centroid assignment counts (decayed when forgetting is on).
    counts: Vec<f64>,
    /// Total samples ever passed to `train_one`.
    n_samples: u64,
    /// True once `config.k` distinct centroids have been seeded.
    initialized: bool,
}
impl StreamingKMeans {
    /// Creates an untrained model from a validated config.
    ///
    /// Centroids are seeded lazily: the first `k` distinct samples passed to
    /// [`train_one`](Self::train_one) become the initial centroids.
    pub fn new(config: StreamingKMeansConfig) -> Self {
        Self {
            centroids: Vec::with_capacity(config.k),
            counts: Vec::with_capacity(config.k),
            n_samples: 0,
            initialized: false,
            config,
        }
    }

    /// Consumes one sample, updating the nearest centroid online.
    ///
    /// During the initialization phase (fewer than `k` centroids) each
    /// distinct sample becomes a new centroid; a sample that exactly
    /// duplicates an existing centroid is instead folded into it via the
    /// regular online update.
    pub fn train_one(&mut self, features: &[f64]) {
        self.n_samples += 1;
        if !self.initialized {
            // A sample is a duplicate if it matches an existing centroid
            // component-wise within f64::EPSILON (and has the same length).
            let is_duplicate = self.centroids.iter().any(|c| {
                c.len() == features.len()
                    && c.iter()
                        .zip(features)
                        .all(|(ci, fi)| (ci - fi).abs() < f64::EPSILON)
            });
            if !is_duplicate {
                self.centroids.push(features.to_vec());
                self.counts.push(1.0);
                if self.centroids.len() == self.config.k {
                    self.initialized = true;
                }
                return;
            }
            // `is_duplicate == true` implies at least one centroid exists,
            // so the update below is always safe. (An empty-centroid guard
            // that previously sat here was unreachable and was removed.)
        }
        let (nearest, _) = self.nearest_with_distance(features);
        self.counts[nearest] += 1.0;
        // Per-cluster learning rate 1/count: with no forgetting, each
        // centroid is exactly the running mean of its assigned samples.
        let eta = 1.0 / self.counts[nearest];
        let centroid = &mut self.centroids[nearest];
        for (ci, fi) in centroid.iter_mut().zip(features) {
            *ci += eta * (fi - *ci);
        }
        // Exponentially decay all counts so the effective learning rate
        // stays bounded away from zero, letting centroids track drift.
        if self.config.forgetting_factor < 1.0 {
            for count in &mut self.counts {
                *count *= self.config.forgetting_factor;
            }
        }
    }

    /// Returns the index of the nearest centroid to `features`.
    ///
    /// # Panics
    /// Panics if no centroid has been initialized yet.
    pub fn predict(&self, features: &[f64]) -> usize {
        assert!(
            !self.centroids.is_empty(),
            "cannot predict: no centroids initialized"
        );
        self.nearest_with_distance(features).0
    }

    /// Returns the nearest centroid index and the squared distance to it.
    ///
    /// # Panics
    /// Panics if no centroid has been initialized yet.
    pub fn predict_distance(&self, features: &[f64]) -> (usize, f64) {
        assert!(
            !self.centroids.is_empty(),
            "cannot predict: no centroids initialized"
        );
        self.nearest_with_distance(features)
    }

    /// Current centroid positions.
    pub fn centroids(&self) -> &[Vec<f64>] {
        &self.centroids
    }

    /// Per-cluster assignment counts, rounded to integers. Values may be
    /// lower than raw assignment tallies when forgetting is enabled.
    pub fn cluster_counts(&self) -> Vec<u64> {
        self.counts.iter().map(|c| c.round() as u64).collect()
    }

    /// Number of centroids created so far (< `k` during initialization).
    pub fn n_clusters(&self) -> usize {
        self.centroids.len()
    }

    /// Total samples ever passed to `train_one`, including init samples.
    pub fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }

    /// True once all `k` centroids have been seeded from distinct samples.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Sum over `data` of the squared distance from each sample to its
    /// nearest centroid; lower means a tighter clustering.
    ///
    /// # Panics
    /// Panics if no centroid has been initialized yet.
    pub fn inertia(&self, data: &[&[f64]]) -> f64 {
        assert!(
            !self.centroids.is_empty(),
            "cannot compute inertia: no centroids initialized"
        );
        data.iter()
            .map(|sample| self.nearest_with_distance(sample).1)
            .sum()
    }

    /// Clears all learned state, returning the model to its freshly
    /// constructed (uninitialized) condition. The config is retained.
    pub fn reset(&mut self) {
        self.centroids.clear();
        self.counts.clear();
        self.n_samples = 0;
        self.initialized = false;
    }

    /// Nearest centroid index plus its squared distance. Shared by
    /// training, prediction, and inertia so the tie-breaking rule (first
    /// minimum wins) is identical everywhere. Assumes >= 1 centroid;
    /// returns index 0 with `f64::MAX` distance when empty.
    fn nearest_with_distance(&self, features: &[f64]) -> (usize, f64) {
        let mut best_idx = 0;
        let mut best_dist = f64::MAX;
        for (i, c) in self.centroids.iter().enumerate() {
            let d = euclidean_distance_squared(features, c);
            if d < best_dist {
                best_dist = d;
                best_idx = i;
            }
        }
        (best_idx, best_dist)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for exact-value floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Deterministic xorshift64 PRNG mapped to roughly [0, 1); keeps the
    /// tests reproducible without an external `rand` dependency.
    fn xorshift64(state: &mut u64) -> f64 {
        let mut x = *state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        *state = x;
        (x as f64) / (u64::MAX as f64)
    }

    // The first k distinct samples must become the centroids verbatim.
    #[test]
    fn initialization_from_first_k_samples() {
        let config = StreamingKMeansConfig::builder(3).build().unwrap();
        let mut km = StreamingKMeans::new(config);
        let points = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        for p in &points {
            km.train_one(p);
        }
        assert!(km.is_initialized());
        assert_eq!(km.n_clusters(), 3);
        for (i, p) in points.iter().enumerate() {
            assert!(
                (km.centroids()[i][0] - p[0]).abs() < EPS,
                "centroid {} x mismatch: expected {}, got {}",
                i,
                p[0],
                km.centroids()[i][0]
            );
            assert!(
                (km.centroids()[i][1] - p[1]).abs() < EPS,
                "centroid {} y mismatch: expected {}, got {}",
                i,
                p[1],
                km.centroids()[i][1]
            );
        }
    }

    // predict() must return the index of the nearest (squared-distance)
    // centroid; centroid order follows insertion order during init.
    #[test]
    fn predict_nearest_centroid() {
        let config = StreamingKMeansConfig::builder(3).build().unwrap();
        let mut km = StreamingKMeans::new(config);
        km.train_one(&[0.0, 0.0]);
        km.train_one(&[10.0, 0.0]);
        km.train_one(&[0.0, 10.0]);
        assert!(km.is_initialized());
        assert_eq!(km.predict(&[0.1, 0.1]), 0);
        assert_eq!(km.predict(&[9.9, 0.1]), 1);
        assert_eq!(km.predict(&[0.1, 9.9]), 2);
    }

    // Two well-separated blobs: after interleaved training, each centroid
    // should sit near one blob center (label assignment may be swapped,
    // hence the two symmetric branches below).
    #[test]
    fn centroids_converge_on_clusters() {
        let config = StreamingKMeansConfig::builder(2).build().unwrap();
        let mut km = StreamingKMeans::new(config);
        let mut state = 42u64;
        let mut samples = Vec::with_capacity(200);
        // Blob 1: unit square centered on the origin.
        for _ in 0..100 {
            let x = (xorshift64(&mut state) - 0.5) * 1.0;
            let y = (xorshift64(&mut state) - 0.5) * 1.0;
            samples.push([x, y]);
        }
        // Blob 2: unit square centered on (10, 10).
        for _ in 0..100 {
            let x = 10.0 + (xorshift64(&mut state) - 0.5) * 1.0;
            let y = 10.0 + (xorshift64(&mut state) - 0.5) * 1.0;
            samples.push([x, y]);
        }
        // Interleave the blobs so both centroids keep receiving updates.
        for i in 0..100 {
            km.train_one(&samples[i]);
            km.train_one(&samples[100 + i]);
        }
        assert!(km.is_initialized());
        let c0 = &km.centroids()[0];
        let c1 = &km.centroids()[1];
        let d0_origin = euclidean_distance_squared(c0, &[0.0, 0.0]);
        let d0_far = euclidean_distance_squared(c0, &[10.0, 10.0]);
        // Generous squared-distance tolerance (= within 2.0 of the center).
        let tolerance = 4.0;
        if d0_origin < d0_far {
            assert!(
                d0_origin < tolerance,
                "centroid 0 too far from [0,0]: squared distance = {}",
                d0_origin
            );
            let d1_far = euclidean_distance_squared(c1, &[10.0, 10.0]);
            assert!(
                d1_far < tolerance,
                "centroid 1 too far from [10,10]: squared distance = {}",
                d1_far
            );
        } else {
            assert!(
                d0_far < tolerance,
                "centroid 0 too far from [10,10]: squared distance = {}",
                d0_far
            );
            let d1_origin = euclidean_distance_squared(c1, &[0.0, 0.0]);
            assert!(
                d1_origin < tolerance,
                "centroid 1 too far from [0,0]: squared distance = {}",
                d1_origin
            );
        }
    }

    // With k=1 and an abrupt distribution shift, a forgetting model must
    // track the new location more closely than a plain running mean.
    #[test]
    fn forgetting_factor_adapts_to_drift() {
        let config_forget = StreamingKMeansConfig::builder(1)
            .forgetting_factor(0.9)
            .build()
            .unwrap();
        let config_no_forget = StreamingKMeansConfig::builder(1).build().unwrap();
        let mut km_forget = StreamingKMeans::new(config_forget);
        let mut km_no_forget = StreamingKMeans::new(config_no_forget);
        for _ in 0..50 {
            km_forget.train_one(&[0.0, 0.0]);
            km_no_forget.train_one(&[0.0, 0.0]);
        }
        // Abrupt drift: the data jumps from (0,0) to (10,10).
        for _ in 0..50 {
            km_forget.train_one(&[10.0, 10.0]);
            km_no_forget.train_one(&[10.0, 10.0]);
        }
        let dist_forget = euclidean_distance_squared(&km_forget.centroids()[0], &[10.0, 10.0]);
        let dist_no_forget =
            euclidean_distance_squared(&km_no_forget.centroids()[0], &[10.0, 10.0]);
        assert!(
            dist_forget < dist_no_forget,
            "forgetting model should be closer to [10,10]: forget dist^2 = {}, no-forget dist^2 = {}",
            dist_forget,
            dist_no_forget
        );
    }

    // predict_distance returns the SQUARED distance (3-4-5 triangle -> 25).
    #[test]
    fn predict_distance_returns_correct_distance() {
        let config = StreamingKMeansConfig::builder(2).build().unwrap();
        let mut km = StreamingKMeans::new(config);
        km.train_one(&[0.0, 0.0]);
        km.train_one(&[10.0, 0.0]);
        let (idx, dist) = km.predict_distance(&[3.0, 4.0]);
        assert_eq!(idx, 0);
        assert!((dist - 25.0).abs() < EPS, "expected 25.0, got {}", dist);
    }

    // Counts include the initial seeding sample plus later assignments.
    #[test]
    fn cluster_counts_track_assignments() {
        let config = StreamingKMeansConfig::builder(2).build().unwrap();
        let mut km = StreamingKMeans::new(config);
        km.train_one(&[0.0, 0.0]);
        km.train_one(&[10.0, 10.0]);
        for _ in 0..5 {
            km.train_one(&[0.1, 0.1]);
        }
        for _ in 0..3 {
            km.train_one(&[9.9, 9.9]);
        }
        let counts = km.cluster_counts();
        assert_eq!(counts[0], 6, "cluster 0 count mismatch: {:?}", counts);
        assert_eq!(counts[1], 4, "cluster 1 count mismatch: {:?}", counts);
        assert_eq!(km.n_samples_seen(), 10);
    }

    // Training on the full dataset should not increase inertia relative to
    // the freshly seeded model (non-strict inequality by design).
    #[test]
    fn inertia_decreases_with_training() {
        let config = StreamingKMeansConfig::builder(2).build().unwrap();
        let mut km = StreamingKMeans::new(config);
        let mut state = 123u64;
        let mut test_data_vecs = Vec::new();
        for _ in 0..50 {
            let x = (xorshift64(&mut state) - 0.5) * 1.0;
            let y = (xorshift64(&mut state) - 0.5) * 1.0;
            test_data_vecs.push(vec![x, y]);
        }
        for _ in 0..50 {
            let x = 10.0 + (xorshift64(&mut state) - 0.5) * 1.0;
            let y = 10.0 + (xorshift64(&mut state) - 0.5) * 1.0;
            test_data_vecs.push(vec![x, y]);
        }
        let test_data: Vec<&[f64]> = test_data_vecs.iter().map(|v| v.as_slice()).collect();
        // Seed one centroid per blob, then measure inertia before/after.
        km.train_one(test_data[0]);
        km.train_one(test_data[50]);
        let inertia_before = km.inertia(&test_data);
        for sample in &test_data {
            km.train_one(sample);
        }
        let inertia_after = km.inertia(&test_data);
        assert!(
            inertia_after <= inertia_before,
            "inertia should decrease with training: before = {}, after = {}",
            inertia_before,
            inertia_after
        );
    }

    // reset() must clear every piece of learned state (config is kept).
    #[test]
    fn reset_clears_all_state() {
        let config = StreamingKMeansConfig::builder(2).build().unwrap();
        let mut km = StreamingKMeans::new(config);
        km.train_one(&[1.0, 2.0]);
        km.train_one(&[3.0, 4.0]);
        km.train_one(&[5.0, 6.0]);
        km.reset();
        assert_eq!(km.n_samples_seen(), 0);
        assert_eq!(km.n_clusters(), 0);
        assert!(!km.is_initialized());
        assert!(km.centroids().is_empty());
        assert!(km.cluster_counts().is_empty());
    }

    // Builder validation: k and forgetting_factor bounds, plus a fully
    // specified valid configuration round-trip.
    #[test]
    fn config_builder_validates() {
        use irithyll_core::error::ConfigError;
        let result = StreamingKMeansConfig::builder(0).build();
        assert!(result.is_err(), "k=0 should fail validation");
        assert!(
            matches!(&result.unwrap_err(), ConfigError::OutOfRange { param, .. } if *param == "k"),
            "expected OutOfRange for k"
        );
        let result = StreamingKMeansConfig::builder(3)
            .forgetting_factor(0.0)
            .build();
        assert!(result.is_err(), "forgetting_factor=0.0 should fail");
        let result = StreamingKMeansConfig::builder(3)
            .forgetting_factor(1.5)
            .build();
        assert!(result.is_err(), "forgetting_factor=1.5 should fail");
        let result = StreamingKMeansConfig::builder(3)
            .forgetting_factor(-0.1)
            .build();
        assert!(result.is_err(), "forgetting_factor=-0.1 should fail");
        let result = StreamingKMeansConfig::builder(5)
            .forgetting_factor(0.95)
            .seed(42)
            .build();
        assert!(result.is_ok(), "valid config should build successfully");
        let config = result.unwrap();
        assert_eq!(config.k, 5);
        assert!((config.forgetting_factor - 0.95).abs() < EPS);
        assert_eq!(config.seed, Some(42));
    }
}