use std::collections::HashMap;
/// Builder for [`DBStreamConfig`]; obtained via `DBStreamConfig::builder`
/// and finalized by `build`, which range-checks every field.
#[derive(Debug, Clone)]
pub struct DBStreamConfigBuilder {
    /// Micro-cluster assignment radius; `build()` requires > 0.
    radius: f64,
    /// Per-sample exponential decay exponent; `build()` requires > 0.
    decay_rate: f64,
    /// Minimum weight a micro-cluster needs to survive cleanup; `build()` requires >= 0.
    min_weight: f64,
    /// Number of samples between cleanup sweeps; `build()` requires > 0.
    cleanup_interval: usize,
    /// Macro-cluster merge threshold; `build()` requires a value in [0, 1].
    shared_density_threshold: f64,
}
impl DBStreamConfigBuilder {
    /// Sets the per-sample decay exponent (weights shrink by `2^(-decay_rate)`
    /// each sample); must be > 0.
    pub fn decay_rate(mut self, d: f64) -> Self {
        self.decay_rate = d;
        self
    }

    /// Sets the minimum weight a micro-cluster needs to survive a cleanup
    /// sweep; must be >= 0.
    pub fn min_weight(mut self, w: f64) -> Self {
        self.min_weight = w;
        self
    }

    /// Sets how many samples pass between cleanup sweeps; must be > 0.
    pub fn cleanup_interval(mut self, n: usize) -> Self {
        self.cleanup_interval = n;
        self
    }

    /// Sets the fraction of two clusters' combined weight their shared
    /// density must exceed for them to merge into one macro-cluster; must
    /// lie in [0, 1].
    pub fn shared_density_threshold(mut self, t: f64) -> Self {
        self.shared_density_threshold = t;
        self
    }

    /// Validates every parameter and produces the final configuration.
    ///
    /// # Errors
    /// Returns `ConfigError::out_of_range` naming the first parameter found
    /// outside its documented range. NaN values are explicitly rejected:
    /// every comparison with NaN is `false`, so checks of the form
    /// `x <= 0.0` alone would silently let a NaN through validation.
    pub fn build(self) -> Result<DBStreamConfig, irithyll_core::error::ConfigError> {
        use irithyll_core::error::ConfigError;
        if self.radius.is_nan() || self.radius <= 0.0 {
            return Err(ConfigError::out_of_range(
                "radius",
                "must be > 0",
                self.radius,
            ));
        }
        if self.decay_rate.is_nan() || self.decay_rate <= 0.0 {
            return Err(ConfigError::out_of_range(
                "decay_rate",
                "must be > 0",
                self.decay_rate,
            ));
        }
        if self.min_weight.is_nan() || self.min_weight < 0.0 {
            return Err(ConfigError::out_of_range(
                "min_weight",
                "must be >= 0",
                self.min_weight,
            ));
        }
        if self.cleanup_interval == 0 {
            return Err(ConfigError::out_of_range(
                "cleanup_interval",
                "must be > 0",
                self.cleanup_interval,
            ));
        }
        // `contains` is false for NaN, so NaN thresholds are rejected here too.
        if !(0.0..=1.0).contains(&self.shared_density_threshold) {
            return Err(ConfigError::out_of_range(
                "shared_density_threshold",
                "must be in [0, 1]",
                self.shared_density_threshold,
            ));
        }
        Ok(DBStreamConfig {
            radius: self.radius,
            decay_rate: self.decay_rate,
            min_weight: self.min_weight,
            cleanup_interval: self.cleanup_interval,
            shared_density_threshold: self.shared_density_threshold,
        })
    }
}
/// Validated DBStream hyper-parameters (construct via `DBStreamConfig::builder`).
#[derive(Debug, Clone)]
pub struct DBStreamConfig {
    /// Micro-cluster assignment radius (> 0).
    pub radius: f64,
    /// Per-sample decay exponent: weights are multiplied by `2^(-decay_rate)`
    /// after every sample (> 0).
    pub decay_rate: f64,
    /// Weight below which a micro-cluster is removed during cleanup (>= 0).
    pub min_weight: f64,
    /// Number of samples between cleanup sweeps (> 0).
    pub cleanup_interval: usize,
    /// Fraction of combined weight two clusters' shared density must exceed
    /// for them to merge into one macro-cluster (in [0, 1]).
    pub shared_density_threshold: f64,
}
impl DBStreamConfig {
    /// Starts a builder seeded with the given `radius` and the library
    /// defaults for every other parameter. Validation happens in
    /// `DBStreamConfigBuilder::build`, not here.
    pub fn builder(radius: f64) -> DBStreamConfigBuilder {
        const DEFAULT_DECAY_RATE: f64 = 0.001;
        const DEFAULT_MIN_WEIGHT: f64 = 1.0;
        const DEFAULT_CLEANUP_INTERVAL: usize = 100;
        const DEFAULT_SHARED_DENSITY_THRESHOLD: f64 = 0.3;
        DBStreamConfigBuilder {
            radius,
            decay_rate: DEFAULT_DECAY_RATE,
            min_weight: DEFAULT_MIN_WEIGHT,
            cleanup_interval: DEFAULT_CLEANUP_INTERVAL,
            shared_density_threshold: DEFAULT_SHARED_DENSITY_THRESHOLD,
        }
    }
}
/// A weighted center summarizing nearby stream points.
#[derive(Debug, Clone)]
pub struct MicroCluster {
    /// Running weighted mean of the points this cluster has absorbed.
    pub center: Vec<f64>,
    /// Decayed point count; multiplied by `2^(-decay_rate)` every sample.
    pub weight: f64,
    /// 1-based sample index at which this cluster was created.
    pub creation_time: u64,
}
/// Streaming density-based clusterer (DBSTREAM-style): maintains decaying
/// micro-clusters plus pairwise shared densities used to group them into
/// macro-clusters.
#[derive(Debug, Clone)]
pub struct DBStream {
    /// Validated hyper-parameters.
    config: DBStreamConfig,
    /// Live micro-clusters; their indices are the keys used in `shared_density`.
    micro_clusters: Vec<MicroCluster>,
    /// Decayed co-assignment mass keyed by ordered `(low, high)` index pairs.
    shared_density: HashMap<(usize, usize), f64>,
    /// Total samples processed since construction or the last `reset`.
    n_samples: u64,
}
impl DBStream {
    /// Creates an empty model from an already-validated configuration.
    pub fn new(config: DBStreamConfig) -> Self {
        Self {
            config,
            micro_clusters: Vec::new(),
            shared_density: HashMap::new(),
            n_samples: 0,
        }
    }

    /// Incrementally absorbs one sample.
    ///
    /// If at least one micro-cluster center lies within `radius`, the nearest
    /// one is pulled toward the point (weighted incremental mean) and every
    /// pair of in-range clusters has its shared density incremented;
    /// otherwise a fresh micro-cluster is created at the point. All weights
    /// and shared densities then decay by `2^(-decay_rate)`, and a cleanup
    /// sweep runs every `cleanup_interval` samples.
    pub fn train_one(&mut self, features: &[f64]) {
        self.n_samples += 1;
        // Indices and distances of all micro-clusters within `radius`.
        let mut in_range: Vec<(usize, f64)> = Vec::new();
        for (i, mc) in self.micro_clusters.iter().enumerate() {
            let d = euclidean_distance(&mc.center, features);
            if d <= self.config.radius {
                in_range.push((i, d));
            }
        }
        if in_range.is_empty() {
            // Outlier w.r.t. every existing cluster: seed a new micro-cluster.
            self.micro_clusters.push(MicroCluster {
                center: features.to_vec(),
                weight: 1.0,
                creation_time: self.n_samples,
            });
        } else {
            // Only the nearest in-range cluster absorbs the point.
            let nearest_idx = in_range
                .iter()
                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                .unwrap()
                .0;
            let mc = &mut self.micro_clusters[nearest_idx];
            let new_weight = mc.weight + 1.0;
            for (c, f) in mc.center.iter_mut().zip(features.iter()) {
                // Weighted incremental mean over the (decayed) prior weight.
                *c = (*c * mc.weight + f) / new_weight;
            }
            mc.weight = new_weight;
            // Every pair of in-range clusters shares this point; the shared
            // density drives macro-cluster merging in `macro_clusters`.
            for i in 0..in_range.len() {
                for j in (i + 1)..in_range.len() {
                    let key = make_pair_key(in_range[i].0, in_range[j].0);
                    *self.shared_density.entry(key).or_insert(0.0) += 1.0;
                }
            }
        }
        // Exponential forgetting, applied after the structural update (so a
        // brand-new cluster is decayed once on its very first sample, which
        // matches the established behavior of this model).
        let decay_factor = 2.0_f64.powf(-self.config.decay_rate);
        for mc in &mut self.micro_clusters {
            mc.weight *= decay_factor;
        }
        for sd in self.shared_density.values_mut() {
            *sd *= decay_factor;
        }
        if self.n_samples % self.config.cleanup_interval as u64 == 0 {
            self.cleanup();
        }
    }

    /// Returns `(index, distance)` of the micro-cluster nearest to
    /// `features`, or `None` when no micro-cluster exists. Each distance is
    /// computed exactly once per cluster (shared by `predict` and
    /// `predict_or_noise`).
    fn nearest(&self, features: &[f64]) -> Option<(usize, f64)> {
        let mut best: Option<(usize, f64)> = None;
        for (i, mc) in self.micro_clusters.iter().enumerate() {
            let d = euclidean_distance(&mc.center, features);
            if best.map_or(true, |(_, best_d)| d < best_d) {
                best = Some((i, d));
            }
        }
        best
    }

    /// Index of the micro-cluster nearest to `features`.
    ///
    /// # Panics
    /// Panics when the model has no micro-clusters yet.
    pub fn predict(&self, features: &[f64]) -> usize {
        self.nearest(features)
            .expect("cannot predict with no micro-clusters")
            .0
    }

    /// Like [`DBStream::predict`], but returns `None` (noise) when the
    /// nearest micro-cluster is farther than `noise_radius`, or when the
    /// model is empty.
    pub fn predict_or_noise(&self, features: &[f64], noise_radius: f64) -> Option<usize> {
        match self.nearest(features) {
            Some((idx, d)) if d <= noise_radius => Some(idx),
            _ => None,
        }
    }

    /// Read-only view of the current micro-clusters.
    pub fn micro_clusters(&self) -> &[MicroCluster] {
        &self.micro_clusters
    }

    /// Number of live micro-clusters.
    pub fn n_micro_clusters(&self) -> usize {
        self.micro_clusters.len()
    }

    /// Groups micro-clusters into macro-clusters.
    ///
    /// Two micro-clusters are connected when their shared density exceeds
    /// `shared_density_threshold` times their combined weight; each returned
    /// group is one connected component, as sorted indices into
    /// [`DBStream::micro_clusters`].
    pub fn macro_clusters(&self) -> Vec<Vec<usize>> {
        let n = self.micro_clusters.len();
        if n == 0 {
            return Vec::new();
        }
        // Build adjacency lists of the shared-density graph.
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
        for (&(i, j), &sd) in &self.shared_density {
            // Defensive: ignore any key referencing an out-of-range index.
            if i >= n || j >= n {
                continue;
            }
            let combined_weight = self.micro_clusters[i].weight + self.micro_clusters[j].weight;
            if sd > self.config.shared_density_threshold * combined_weight {
                adj[i].push(j);
                adj[j].push(i);
            }
        }
        // Iterative depth-first search over connected components.
        let mut visited = vec![false; n];
        let mut components: Vec<Vec<usize>> = Vec::new();
        for start in 0..n {
            if visited[start] {
                continue;
            }
            let mut component = Vec::new();
            let mut stack = vec![start];
            while let Some(node) = stack.pop() {
                if visited[node] {
                    continue;
                }
                visited[node] = true;
                component.push(node);
                for &neighbor in &adj[node] {
                    if !visited[neighbor] {
                        stack.push(neighbor);
                    }
                }
            }
            component.sort_unstable();
            components.push(component);
        }
        components
    }

    /// Number of macro-clusters (connected components).
    pub fn n_clusters(&self) -> usize {
        self.macro_clusters().len()
    }

    /// Total samples seen since construction or the last `reset`.
    pub fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }

    /// Clears all learned state; the configuration is kept.
    pub fn reset(&mut self) {
        self.micro_clusters.clear();
        self.shared_density.clear();
        self.n_samples = 0;
    }

    /// Drops micro-clusters whose weight decayed below `min_weight` and
    /// remaps shared-density keys onto the survivors' compacted indices.
    fn cleanup(&mut self) {
        let keep_indices: Vec<usize> = self
            .micro_clusters
            .iter()
            .enumerate()
            .filter(|(_, mc)| mc.weight >= self.config.min_weight)
            .map(|(i, _)| i)
            .collect();
        if keep_indices.len() == self.micro_clusters.len() {
            return; // nothing decayed away
        }
        // old index -> new compacted index
        let index_map: HashMap<usize, usize> = keep_indices
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();
        let survivors: Vec<MicroCluster> = keep_indices
            .iter()
            .map(|&i| self.micro_clusters[i].clone())
            .collect();
        self.micro_clusters = survivors;
        // Shared-density entries touching a removed cluster are discarded.
        let mut new_sd: HashMap<(usize, usize), f64> = HashMap::new();
        for (&(old_a, old_b), &val) in &self.shared_density {
            if let (Some(&new_a), Some(&new_b)) = (index_map.get(&old_a), index_map.get(&old_b)) {
                new_sd.insert(make_pair_key(new_a, new_b), val);
            }
        }
        self.shared_density = new_sd;
    }
}
/// Euclidean (L2) distance between two vectors. `zip` stops at the shorter
/// slice, so trailing elements of a longer slice are ignored.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let sum_sq: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    sum_sq.sqrt()
}
/// Canonical `(smaller, larger)` ordering for an index pair so the
/// shared-density map treats `(a, b)` and `(b, a)` as the same edge.
fn make_pair_key(a: usize, b: usize) -> (usize, usize) {
    (a.min(b), a.max(b))
}
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-6;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    /// Slow decay, zero min-weight, no cleanup pressure: keeps every
    /// micro-cluster alive so tests can inspect structure directly.
    fn default_config(radius: f64) -> DBStreamConfig {
        DBStreamConfig::builder(radius)
            .decay_rate(0.001)
            .min_weight(0.0)
            .cleanup_interval(1000)
            .build()
            .unwrap()
    }

    #[test]
    fn single_point_creates_micro_cluster() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);
        db.train_one(&[5.0, 5.0]);
        assert_eq!(db.n_micro_clusters(), 1);
        let mc = &db.micro_clusters()[0];
        assert!(approx_eq(mc.center[0], 5.0));
        assert!(approx_eq(mc.center[1], 5.0));
        assert_eq!(db.n_samples_seen(), 1);
    }

    #[test]
    fn nearby_points_merge() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);
        db.train_one(&[0.0, 0.0]);
        db.train_one(&[0.1, 0.1]);
        db.train_one(&[0.2, 0.2]);
        assert_eq!(db.n_micro_clusters(), 1);
        assert_eq!(db.n_samples_seen(), 3);
    }

    #[test]
    fn distant_points_separate() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);
        db.train_one(&[0.0, 0.0]);
        db.train_one(&[10.0, 10.0]);
        assert_eq!(db.n_micro_clusters(), 2);
    }

    #[test]
    fn decay_reduces_weights() {
        let config = DBStreamConfig::builder(1.0)
            .decay_rate(0.1)
            .min_weight(0.0)
            .cleanup_interval(10_000)
            .build()
            .unwrap();
        let mut db = DBStream::new(config);
        db.train_one(&[0.0, 0.0]);
        let initial_weight = db.micro_clusters()[0].weight;
        // Far-away points never merge with the first cluster, so its weight
        // only ever decays.
        for i in 1..20 {
            db.train_one(&[100.0 * i as f64, 100.0 * i as f64]);
        }
        let final_weight = db.micro_clusters()[0].weight;
        assert!(
            final_weight < initial_weight,
            "expected weight to decay: initial={}, final={}",
            initial_weight,
            final_weight
        );
    }

    #[test]
    fn cleanup_removes_light_clusters() {
        // Aggressive decay plus a short cleanup interval so the first cluster
        // falls below min_weight and gets swept.
        let config = DBStreamConfig::builder(1.0)
            .decay_rate(0.5)
            .min_weight(0.1)
            .cleanup_interval(5)
            .build()
            .unwrap();
        let mut db = DBStream::new(config);
        db.train_one(&[0.0, 0.0]);
        let initial_count = db.n_micro_clusters();
        assert_eq!(initial_count, 1);
        for i in 1..=20 {
            db.train_one(&[1000.0 * i as f64, 1000.0 * i as f64]);
        }
        let has_origin = db
            .micro_clusters()
            .iter()
            .any(|mc| approx_eq(mc.center[0], 0.0) && approx_eq(mc.center[1], 0.0));
        assert!(
            !has_origin,
            "expected the origin MC to be removed after decay and cleanup"
        );
    }

    #[test]
    fn macro_clusters_merge_shared_density() {
        let config = DBStreamConfig::builder(1.0)
            .decay_rate(0.0001)
            .min_weight(0.0)
            .cleanup_interval(10_000)
            .shared_density_threshold(0.1)
            .build()
            .unwrap();
        let mut db = DBStream::new(config);
        // Two dense groups far apart: each group builds internal shared
        // density, but the groups should never merge with each other.
        for _ in 0..10 {
            db.train_one(&[0.0, 0.0]);
            db.train_one(&[0.5, 0.5]);
        }
        for _ in 0..10 {
            db.train_one(&[10.0, 10.0]);
            db.train_one(&[10.5, 10.5]);
        }
        let macros = db.macro_clusters();
        assert!(
            macros.len() >= 2,
            "expected at least 2 macro-clusters, got {}",
            macros.len()
        );
    }

    #[test]
    fn predict_returns_nearest() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);
        db.train_one(&[0.0, 0.0]);
        db.train_one(&[10.0, 10.0]);
        let idx = db.predict(&[0.1, 0.1]);
        let nearest_center = &db.micro_clusters()[idx].center;
        let d_origin = euclidean_distance(nearest_center, &[0.0, 0.0]);
        let d_far = euclidean_distance(nearest_center, &[10.0, 10.0]);
        assert!(
            d_origin < d_far,
            "predicted MC should be closer to origin than to (10,10)"
        );
        let idx2 = db.predict(&[9.9, 9.9]);
        let nearest_center2 = &db.micro_clusters()[idx2].center;
        let d_origin2 = euclidean_distance(nearest_center2, &[0.0, 0.0]);
        let d_far2 = euclidean_distance(nearest_center2, &[10.0, 10.0]);
        assert!(
            d_far2 < d_origin2,
            "predicted MC should be closer to (10,10) than to origin"
        );
    }

    #[test]
    fn predict_or_noise_returns_none() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);
        db.train_one(&[0.0, 0.0]);
        assert!(db.predict_or_noise(&[0.5, 0.5], 2.0).is_some());
        assert!(db.predict_or_noise(&[100.0, 100.0], 1.0).is_none());
    }

    #[test]
    fn reset_clears_state() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);
        db.train_one(&[1.0, 2.0]);
        db.train_one(&[3.0, 4.0]);
        assert!(db.n_micro_clusters() > 0);
        assert!(db.n_samples_seen() > 0);
        db.reset();
        assert_eq!(db.n_micro_clusters(), 0);
        assert_eq!(db.n_samples_seen(), 0);
        assert!(db.macro_clusters().is_empty());
    }

    #[test]
    fn config_builder_validates() {
        assert!(DBStreamConfig::builder(0.0).build().is_err());
        assert!(DBStreamConfig::builder(-1.0).build().is_err());
        assert!(DBStreamConfig::builder(1.0)
            .decay_rate(0.0)
            .build()
            .is_err());
        assert!(DBStreamConfig::builder(1.0)
            .decay_rate(-1.0)
            .build()
            .is_err());
        assert!(DBStreamConfig::builder(1.0)
            .min_weight(-1.0)
            .build()
            .is_err());
        assert!(DBStreamConfig::builder(1.0)
            .cleanup_interval(0)
            .build()
            .is_err());
        assert!(DBStreamConfig::builder(1.0)
            .shared_density_threshold(-0.1)
            .build()
            .is_err());
        assert!(DBStreamConfig::builder(1.0)
            .shared_density_threshold(1.1)
            .build()
            .is_err());
        assert!(DBStreamConfig::builder(1.0).build().is_ok());
    }
}