use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Zip};
use scirs2_core::numeric::{Float, FromPrimitive};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
use crate::vq::euclidean_distance;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveOnlineConfig {
pub initial_learning_rate: f64,
pub min_learning_rate: f64,
pub learning_rate_decay: f64,
pub forgetting_factor: f64,
pub cluster_creation_threshold: f64,
pub max_clusters: usize,
pub min_cluster_size: usize,
pub merge_threshold: f64,
pub concept_drift_window: usize,
pub drift_detection_threshold: f64,
}
impl Default for AdaptiveOnlineConfig {
fn default() -> Self {
Self {
initial_learning_rate: 0.1,
min_learning_rate: 0.001,
learning_rate_decay: 0.999,
forgetting_factor: 0.95,
cluster_creation_threshold: 2.0,
max_clusters: 50,
min_cluster_size: 10,
merge_threshold: 0.5,
concept_drift_window: 1000,
drift_detection_threshold: 0.3,
}
}
}
pub struct AdaptiveOnlineClustering<F: Float> {
config: AdaptiveOnlineConfig,
clusters: Vec<OnlineCluster<F>>,
learning_rate: f64,
samples_processed: usize,
recent_distances: VecDeque<f64>,
drift_detector: ConceptDriftDetector,
}
#[derive(Debug, Clone)]
struct OnlineCluster<F: Float> {
centroid: Array1<F>,
weight: f64,
last_update: usize,
variance: f64,
age: usize,
recent_assignments: VecDeque<usize>,
}
#[derive(Debug, Clone)]
struct ConceptDriftDetector {
recent_errors: VecDeque<f64>,
baseline_error: f64,
window_size: usize,
}
impl<F: Float + FromPrimitive + Debug> AdaptiveOnlineClustering<F> {
pub fn new(config: AdaptiveOnlineConfig) -> Self {
Self {
config: config.clone(),
clusters: Vec::new(),
learning_rate: config.initial_learning_rate,
samples_processed: 0,
recent_distances: VecDeque::with_capacity(config.concept_drift_window),
drift_detector: ConceptDriftDetector {
recent_errors: VecDeque::with_capacity(config.concept_drift_window),
baseline_error: 1.0,
window_size: config.concept_drift_window,
},
}
}
pub fn partial_fit(&mut self, point: ArrayView1<F>) -> Result<usize> {
self.samples_processed += 1;
let (nearest_cluster_idx, nearest_distance) = self.find_nearest_cluster(point);
let assigned_cluster = if let Some(cluster_idx) = nearest_cluster_idx {
let distance_threshold = F::from(self.config.cluster_creation_threshold)
.expect("Failed to convert to float");
if nearest_distance <= distance_threshold {
self.update_cluster(cluster_idx, point)?;
cluster_idx
} else if self.clusters.len() < self.config.max_clusters {
self.create_new_cluster(point)?
} else {
self.update_cluster(cluster_idx, point)?;
cluster_idx
}
} else {
self.create_new_cluster(point)?
};
self.learning_rate = (self.learning_rate * self.config.learning_rate_decay)
.max(self.config.min_learning_rate);
self.recent_distances
.push_back(nearest_distance.to_f64().unwrap_or(0.0));
if self.recent_distances.len() > self.config.concept_drift_window {
self.recent_distances.pop_front();
}
if self.samples_processed.is_multiple_of(100) {
self.detect_concept_drift()?;
}
if self.samples_processed.is_multiple_of(1000) {
self.merge_similar_clusters()?;
self.remove_old_clusters()?;
}
Ok(assigned_cluster)
}
fn find_nearest_cluster(&self, point: ArrayView1<F>) -> (Option<usize>, F) {
if self.clusters.is_empty() {
return (None, F::infinity());
}
let mut min_distance = F::infinity();
let mut nearest_idx = 0;
for (i, cluster) in self.clusters.iter().enumerate() {
let distance = euclidean_distance(point, cluster.centroid.view());
if distance < min_distance {
min_distance = distance;
nearest_idx = i;
}
}
(Some(nearest_idx), min_distance)
}
fn update_cluster(&mut self, clusteridx: usize, point: ArrayView1<F>) -> Result<()> {
let cluster = &mut self.clusters[clusteridx];
cluster.weight = cluster.weight * self.config.forgetting_factor + 1.0;
let learning_rate =
F::from(self.learning_rate / cluster.weight).expect("Failed to convert to float");
Zip::from(&mut cluster.centroid)
.and(point)
.for_each(|centroid_val, &point_val| {
let diff = point_val - *centroid_val;
*centroid_val = *centroid_val + learning_rate * diff;
});
let distance = euclidean_distance(point, cluster.centroid.view());
let distance_squared = distance * distance;
cluster.variance = cluster.variance * 0.9 + distance_squared.to_f64().unwrap_or(0.0) * 0.1;
cluster.last_update = self.samples_processed;
cluster.age += 1;
cluster.recent_assignments.push_back(self.samples_processed);
if cluster.recent_assignments.len() > 100 {
cluster.recent_assignments.pop_front();
}
Ok(())
}
fn create_new_cluster(&mut self, point: ArrayView1<F>) -> Result<usize> {
let new_cluster = OnlineCluster {
centroid: point.to_owned(),
weight: 1.0,
last_update: self.samples_processed,
variance: 0.0,
age: 0,
recent_assignments: VecDeque::new(),
};
self.clusters.push(new_cluster);
Ok(self.clusters.len() - 1)
}
fn detect_concept_drift(&mut self) -> Result<()> {
if self.recent_distances.len() < self.config.concept_drift_window / 2 {
return Ok(());
}
let recent_mean: f64 =
self.recent_distances.iter().sum::<f64>() / self.recent_distances.len() as f64;
self.drift_detector.recent_errors.push_back(recent_mean);
if self.drift_detector.recent_errors.len() > self.drift_detector.window_size {
self.drift_detector.recent_errors.pop_front();
}
let current_error: f64 = self.drift_detector.recent_errors.iter().sum::<f64>()
/ self.drift_detector.recent_errors.len() as f64;
if current_error
> self.drift_detector.baseline_error * (1.0 + self.config.drift_detection_threshold)
{
self.learning_rate = (self.learning_rate * 2.0).min(0.5);
self.drift_detector.baseline_error = current_error;
} else {
self.drift_detector.baseline_error =
self.drift_detector.baseline_error * 0.99 + current_error * 0.01;
}
Ok(())
}
fn merge_similar_clusters(&mut self) -> Result<()> {
let mut to_merge = Vec::new();
let merge_threshold =
F::from(self.config.merge_threshold).expect("Failed to convert to float");
for i in 0..self.clusters.len() {
for j in (i + 1)..self.clusters.len() {
let distance = euclidean_distance(
self.clusters[i].centroid.view(),
self.clusters[j].centroid.view(),
);
if distance <= merge_threshold {
to_merge.push((i, j));
}
}
}
for (i, j) in to_merge.into_iter().rev() {
self.merge_clusters(i, j)?;
}
Ok(())
}
fn merge_clusters(&mut self, i: usize, j: usize) -> Result<()> {
if i >= self.clusters.len() || j >= self.clusters.len() || i == j {
return Ok(());
}
let (cluster_i, cluster_j) = if i < j {
let (left, right) = self.clusters.split_at_mut(j);
(&mut left[i], &mut right[0])
} else {
let (left, right) = self.clusters.split_at_mut(i);
(&mut right[0], &mut left[j])
};
let total_weight = cluster_i.weight + cluster_j.weight;
let weight_i =
F::from(cluster_i.weight / total_weight).expect("Failed to convert to float");
let weight_j =
F::from(cluster_j.weight / total_weight).expect("Failed to convert to float");
Zip::from(&mut cluster_i.centroid)
.and(&cluster_j.centroid)
.for_each(|cent_i, ¢_j| {
*cent_i = *cent_i * weight_i + cent_j * weight_j;
});
cluster_i.weight = total_weight;
cluster_i.variance = (cluster_i.variance + cluster_j.variance) / 2.0;
cluster_i.age = cluster_i.age.max(cluster_j.age);
cluster_i.last_update = cluster_i.last_update.max(cluster_j.last_update);
let remove_idx = if i < j { j } else { i };
self.clusters.remove(remove_idx);
Ok(())
}
fn remove_old_clusters(&mut self) -> Result<()> {
let current_time = self.samples_processed;
let max_age = 10000;
self.clusters.retain(|cluster| {
let age_ok = cluster.age < max_age;
let recent_activity = current_time - cluster.last_update < 5000;
let sufficient_size = cluster.weight >= self.config.min_cluster_size as f64;
age_ok && (recent_activity || sufficient_size)
});
Ok(())
}
pub fn predict(&self, point: ArrayView1<F>) -> Result<usize> {
let (nearest_cluster_idx_, _distance) = self.find_nearest_cluster(point);
nearest_cluster_idx_.ok_or_else(|| {
ClusteringError::InvalidInput("No clusters available for prediction".to_string())
})
}
pub fn cluster_centers(&self) -> Array2<F> {
if self.clusters.is_empty() {
return Array2::zeros((0, 0));
}
let n_features = self.clusters[0].centroid.len();
let mut centers = Array2::zeros((self.clusters.len(), n_features));
for (i, cluster) in self.clusters.iter().enumerate() {
centers.row_mut(i).assign(&cluster.centroid);
}
centers
}
pub fn cluster_info(&self) -> Vec<(f64, f64, usize)> {
self.clusters
.iter()
.map(|cluster| (cluster.weight, cluster.variance, cluster.age))
.collect()
}
pub fn n_clusters(&self) -> usize {
self.clusters.len()
}
}
pub fn adaptive_online_clustering<F: Float + FromPrimitive + Debug>(
data: ArrayView2<F>,
config: Option<AdaptiveOnlineConfig>,
) -> Result<(Array2<F>, Array1<usize>)> {
let config = config.unwrap_or_default();
let mut clusterer = AdaptiveOnlineClustering::new(config);
let n_samples = data.nrows();
let mut labels = Array1::zeros(n_samples);
for (i, point) in data.rows().into_iter().enumerate() {
labels[i] = clusterer.partial_fit(point)?;
}
let centers = clusterer.cluster_centers();
Ok((centers, labels))
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
#[test]
fn test_adaptive_online_config_default() {
let config = AdaptiveOnlineConfig::default();
assert_eq!(config.initial_learning_rate, 0.1);
assert_eq!(config.max_clusters, 50);
assert_eq!(config.concept_drift_window, 1000);
}
#[test]
fn test_adaptive_online_clustering_creation() {
let config = AdaptiveOnlineConfig::default();
let clusterer = AdaptiveOnlineClustering::<f64>::new(config);
assert_eq!(clusterer.n_clusters(), 0);
}
#[test]
fn test_adaptive_online_clustering_simple() {
let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0])
.expect("Operation failed");
let config = AdaptiveOnlineConfig {
cluster_creation_threshold: 2.0,
max_clusters: 10,
..Default::default()
};
let result = adaptive_online_clustering(data.view(), Some(config));
assert!(result.is_ok());
let (centers, labels) = result.expect("Operation failed");
assert_eq!(labels.len(), 4);
assert!(centers.nrows() <= 4); }
#[test]
fn test_online_cluster_creation() {
let config = AdaptiveOnlineConfig::default();
let mut clusterer = AdaptiveOnlineClustering::<f64>::new(config);
let point = Array1::from_vec(vec![1.0, 2.0]);
let cluster_id = clusterer
.partial_fit(point.view())
.expect("Operation failed");
assert_eq!(cluster_id, 0);
assert_eq!(clusterer.n_clusters(), 1);
}
#[test]
fn test_concept_drift_detection() {
let config = AdaptiveOnlineConfig {
concept_drift_window: 10,
drift_detection_threshold: 0.1,
..Default::default()
};
let mut clusterer = AdaptiveOnlineClustering::<f64>::new(config);
for i in 0..5 {
let point = Array1::from_vec(vec![i as f64, i as f64]);
clusterer
.partial_fit(point.view())
.expect("Operation failed");
}
assert!(clusterer.detect_concept_drift().is_ok());
}
#[test]
fn test_cluster_merging() {
let config = AdaptiveOnlineConfig {
merge_threshold: 1.0,
cluster_creation_threshold: 0.5,
..Default::default()
};
let mut clusterer = AdaptiveOnlineClustering::<f64>::new(config);
let point1 = Array1::from_vec(vec![0.0, 0.0]);
let point2 = Array1::from_vec(vec![0.3, 0.3]);
clusterer
.partial_fit(point1.view())
.expect("Operation failed");
clusterer
.partial_fit(point2.view())
.expect("Operation failed");
let initial_clusters = clusterer.n_clusters();
clusterer
.merge_similar_clusters()
.expect("Operation failed");
assert!(clusterer.n_clusters() <= initial_clusters);
}
}