1use std::collections::HashMap;
10
/// Step-by-step builder for a [`DBStreamConfig`].
///
/// The setters store values exactly as given; nothing is checked until
/// `build` is called.
#[derive(Debug, Clone)]
pub struct DBStreamConfigBuilder {
    /// Maximum distance at which a sample is considered inside a micro-cluster.
    radius: f64,
    /// Decay exponent: weights are multiplied by `2^(-decay_rate)` per sample.
    decay_rate: f64,
    /// Micro-clusters whose weight falls below this value are dropped on cleanup.
    min_weight: f64,
    /// Number of samples between two cleanup passes.
    cleanup_interval: usize,
    /// Fraction of the combined weight of two micro-clusters that their shared
    /// density must exceed for them to be linked into one macro-cluster.
    shared_density_threshold: f64,
}
26
27impl DBStreamConfigBuilder {
28 pub fn decay_rate(mut self, d: f64) -> Self {
30 self.decay_rate = d;
31 self
32 }
33
34 pub fn min_weight(mut self, w: f64) -> Self {
36 self.min_weight = w;
37 self
38 }
39
40 pub fn cleanup_interval(mut self, n: usize) -> Self {
42 self.cleanup_interval = n;
43 self
44 }
45
46 pub fn shared_density_threshold(mut self, t: f64) -> Self {
49 self.shared_density_threshold = t;
50 self
51 }
52
53 pub fn build(self) -> Result<DBStreamConfig, String> {
57 if self.radius <= 0.0 {
58 return Err("radius must be > 0".to_string());
59 }
60 if self.decay_rate <= 0.0 {
61 return Err("decay_rate must be > 0".to_string());
62 }
63 if self.min_weight < 0.0 {
64 return Err("min_weight must be >= 0".to_string());
65 }
66 if self.cleanup_interval == 0 {
67 return Err("cleanup_interval must be > 0".to_string());
68 }
69 if self.shared_density_threshold < 0.0 || self.shared_density_threshold > 1.0 {
70 return Err("shared_density_threshold must be in [0, 1]".to_string());
71 }
72 Ok(DBStreamConfig {
73 radius: self.radius,
74 decay_rate: self.decay_rate,
75 min_weight: self.min_weight,
76 cleanup_interval: self.cleanup_interval,
77 shared_density_threshold: self.shared_density_threshold,
78 })
79 }
80}
81
/// Hyperparameters consumed by [`DBStream`].
///
/// `DBStreamConfig::builder` produces a validated instance. The fields are
/// public, so a value assembled by hand bypasses that validation.
#[derive(Debug, Clone)]
pub struct DBStreamConfig {
    /// Maximum distance at which a sample is considered inside a micro-cluster.
    pub radius: f64,
    /// Decay exponent: weights are multiplied by `2^(-decay_rate)` per sample.
    pub decay_rate: f64,
    /// Micro-clusters whose weight falls below this value are dropped on cleanup.
    pub min_weight: f64,
    /// Number of samples between two cleanup passes.
    pub cleanup_interval: usize,
    /// Fraction of the combined weight of two micro-clusters that their shared
    /// density must exceed for them to be linked into one macro-cluster.
    pub shared_density_threshold: f64,
}
112
113impl DBStreamConfig {
114 pub fn builder(radius: f64) -> DBStreamConfigBuilder {
116 DBStreamConfigBuilder {
117 radius,
118 decay_rate: 0.001,
119 min_weight: 1.0,
120 cleanup_interval: 100,
121 shared_density_threshold: 0.3,
122 }
123 }
124}
125
/// A weighted summary of nearby samples maintained by [`DBStream`].
#[derive(Debug, Clone)]
pub struct MicroCluster {
    /// Weighted running mean of the samples this micro-cluster has absorbed.
    pub center: Vec<f64>,
    /// Decayed sample count: `+1` per absorbed sample, scaled down after
    /// every training step.
    pub weight: f64,
    /// Value of the sample counter when this micro-cluster was created.
    pub creation_time: u64,
}
144
145#[derive(Debug, Clone)]
173pub struct DBStream {
174 config: DBStreamConfig,
175 micro_clusters: Vec<MicroCluster>,
176 shared_density: HashMap<(usize, usize), f64>,
179 n_samples: u64,
180}
181
182impl DBStream {
183 pub fn new(config: DBStreamConfig) -> Self {
185 Self {
186 config,
187 micro_clusters: Vec::new(),
188 shared_density: HashMap::new(),
189 n_samples: 0,
190 }
191 }
192
193 pub fn train_one(&mut self, features: &[f64]) {
203 self.n_samples += 1;
204
205 let mut in_range: Vec<(usize, f64)> = Vec::new();
207 for (i, mc) in self.micro_clusters.iter().enumerate() {
208 let d = euclidean_distance(&mc.center, features);
209 if d <= self.config.radius {
210 in_range.push((i, d));
211 }
212 }
213
214 if !in_range.is_empty() {
215 let nearest_idx = in_range
217 .iter()
218 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
219 .unwrap()
220 .0;
221
222 let mc = &mut self.micro_clusters[nearest_idx];
223 let new_weight = mc.weight + 1.0;
224 for (c, f) in mc.center.iter_mut().zip(features.iter()) {
225 *c = (*c * mc.weight + f) / new_weight;
226 }
227 mc.weight = new_weight;
228
229 for i in 0..in_range.len() {
231 for j in (i + 1)..in_range.len() {
232 let a = in_range[i].0;
233 let b = in_range[j].0;
234 let key = make_pair_key(a, b);
235 *self.shared_density.entry(key).or_insert(0.0) += 1.0;
236 }
237 }
238 } else {
239 self.micro_clusters.push(MicroCluster {
241 center: features.to_vec(),
242 weight: 1.0,
243 creation_time: self.n_samples,
244 });
245 }
246
247 let decay_factor = 2.0_f64.powf(-self.config.decay_rate);
249 for mc in &mut self.micro_clusters {
250 mc.weight *= decay_factor;
251 }
252 for sd in self.shared_density.values_mut() {
253 *sd *= decay_factor;
254 }
255
256 if self.n_samples % self.config.cleanup_interval as u64 == 0 {
258 self.cleanup();
259 }
260 }
261
262 pub fn predict(&self, features: &[f64]) -> usize {
270 assert!(
271 !self.micro_clusters.is_empty(),
272 "cannot predict with no micro-clusters"
273 );
274 self.micro_clusters
275 .iter()
276 .enumerate()
277 .min_by(|(_, a), (_, b)| {
278 let da = euclidean_distance(&a.center, features);
279 let db = euclidean_distance(&b.center, features);
280 da.partial_cmp(&db).unwrap()
281 })
282 .unwrap()
283 .0
284 }
285
286 pub fn predict_or_noise(&self, features: &[f64], noise_radius: f64) -> Option<usize> {
289 let mut best_idx = None;
290 let mut best_dist = f64::INFINITY;
291
292 for (i, mc) in self.micro_clusters.iter().enumerate() {
293 let d = euclidean_distance(&mc.center, features);
294 if d < best_dist {
295 best_dist = d;
296 best_idx = Some(i);
297 }
298 }
299
300 if best_dist <= noise_radius {
301 best_idx
302 } else {
303 None
304 }
305 }
306
307 pub fn micro_clusters(&self) -> &[MicroCluster] {
309 &self.micro_clusters
310 }
311
312 pub fn n_micro_clusters(&self) -> usize {
314 self.micro_clusters.len()
315 }
316
317 pub fn macro_clusters(&self) -> Vec<Vec<usize>> {
325 let n = self.micro_clusters.len();
326 if n == 0 {
327 return Vec::new();
328 }
329
330 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
332 for (&(i, j), &sd) in &self.shared_density {
333 if i >= n || j >= n {
335 continue;
336 }
337 let combined_weight = self.micro_clusters[i].weight + self.micro_clusters[j].weight;
338 if sd > self.config.shared_density_threshold * combined_weight {
339 adj[i].push(j);
340 adj[j].push(i);
341 }
342 }
343
344 let mut visited = vec![false; n];
346 let mut components: Vec<Vec<usize>> = Vec::new();
347
348 for start in 0..n {
349 if visited[start] {
350 continue;
351 }
352 let mut component = Vec::new();
353 let mut stack = vec![start];
354 while let Some(node) = stack.pop() {
355 if visited[node] {
356 continue;
357 }
358 visited[node] = true;
359 component.push(node);
360 for &neighbor in &adj[node] {
361 if !visited[neighbor] {
362 stack.push(neighbor);
363 }
364 }
365 }
366 component.sort_unstable();
367 components.push(component);
368 }
369
370 components
371 }
372
373 pub fn n_clusters(&self) -> usize {
376 self.macro_clusters().len()
377 }
378
379 pub fn n_samples_seen(&self) -> u64 {
381 self.n_samples
382 }
383
384 pub fn reset(&mut self) {
386 self.micro_clusters.clear();
387 self.shared_density.clear();
388 self.n_samples = 0;
389 }
390
391 fn cleanup(&mut self) {
394 let mut keep_indices: Vec<usize> = Vec::new();
396 for (i, mc) in self.micro_clusters.iter().enumerate() {
397 if mc.weight >= self.config.min_weight {
398 keep_indices.push(i);
399 }
400 }
401
402 if keep_indices.len() == self.micro_clusters.len() {
404 return;
405 }
406
407 let mut index_map: HashMap<usize, usize> = HashMap::new();
409 for (new_idx, &old_idx) in keep_indices.iter().enumerate() {
410 index_map.insert(old_idx, new_idx);
411 }
412
413 let new_mcs: Vec<MicroCluster> = keep_indices
415 .iter()
416 .map(|&i| self.micro_clusters[i].clone())
417 .collect();
418 self.micro_clusters = new_mcs;
419
420 let mut new_sd: HashMap<(usize, usize), f64> = HashMap::new();
423 for (&(old_a, old_b), &val) in &self.shared_density {
424 if let (Some(&new_a), Some(&new_b)) = (index_map.get(&old_a), index_map.get(&old_b)) {
425 let key = make_pair_key(new_a, new_b);
426 new_sd.insert(key, val);
427 }
428 }
429 self.shared_density = new_sd;
430 }
431}
432
/// Euclidean (L2) distance between two points.
///
/// Coordinates are paired positionally; if the slices differ in length the
/// surplus coordinates of the longer one are ignored, and two empty slices
/// are at distance zero.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f64>()
        .sqrt()
}
445
/// Returns the two indices as a `(smaller, larger)` tuple, so that a pair
/// maps to the same key regardless of argument order.
fn make_pair_key(a: usize, b: usize) -> (usize, usize) {
    (a.min(b), a.max(b))
}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by `approx_eq`.
    const EPS: f64 = 1e-6;

    /// Returns `true` when `a` and `b` differ by less than `EPS`.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    /// Config with slow decay, no minimum weight and a cleanup interval long
    /// enough that cleanup never runs in these short tests.
    fn default_config(radius: f64) -> DBStreamConfig {
        DBStreamConfig::builder(radius)
            .decay_rate(0.001)
            .min_weight(0.0)
            .cleanup_interval(1000)
            .build()
            .unwrap()
    }

    #[test]
    fn single_point_creates_micro_cluster() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);
        db.train_one(&[5.0, 5.0]);

        assert_eq!(db.n_micro_clusters(), 1);
        let mc = &db.micro_clusters()[0];
        assert!(approx_eq(mc.center[0], 5.0));
        assert!(approx_eq(mc.center[1], 5.0));
        assert_eq!(db.n_samples_seen(), 1);
    }

    #[test]
    fn nearby_points_merge() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);

        db.train_one(&[0.0, 0.0]);
        db.train_one(&[0.1, 0.1]);
        db.train_one(&[0.2, 0.2]);

        assert_eq!(db.n_micro_clusters(), 1);
        assert_eq!(db.n_samples_seen(), 3);
    }

    #[test]
    fn distant_points_separate() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);

        db.train_one(&[0.0, 0.0]);
        db.train_one(&[10.0, 10.0]);

        assert_eq!(db.n_micro_clusters(), 2);
    }

    #[test]
    fn decay_reduces_weights() {
        let config = DBStreamConfig::builder(1.0)
            .decay_rate(0.1)
            .min_weight(0.0)
            .cleanup_interval(10_000)
            .build()
            .unwrap();
        let mut db = DBStream::new(config);

        db.train_one(&[0.0, 0.0]);
        let initial_weight = db.micro_clusters()[0].weight;

        // Far-away samples never touch the first micro-cluster, so only
        // decay acts on it.
        for i in 1..20 {
            db.train_one(&[100.0 * i as f64, 100.0 * i as f64]);
        }

        let final_weight = db.micro_clusters()[0].weight;
        assert!(
            final_weight < initial_weight,
            "expected weight to decay: initial={}, final={}",
            initial_weight,
            final_weight
        );
    }

    #[test]
    fn cleanup_removes_light_clusters() {
        let config = DBStreamConfig::builder(1.0)
            .decay_rate(0.5)
            .min_weight(0.1)
            .cleanup_interval(5)
            .build()
            .unwrap();
        let mut db = DBStream::new(config);

        db.train_one(&[0.0, 0.0]);
        let initial_count = db.n_micro_clusters();
        assert_eq!(initial_count, 1);

        // Each sample lands far from every existing center, so the origin
        // micro-cluster only decays until a cleanup pass removes it.
        for i in 1..=20 {
            db.train_one(&[1000.0 * i as f64, 1000.0 * i as f64]);
        }

        let has_origin = db
            .micro_clusters()
            .iter()
            .any(|mc| approx_eq(mc.center[0], 0.0) && approx_eq(mc.center[1], 0.0));
        assert!(
            !has_origin,
            "expected the origin MC to be removed after decay and cleanup"
        );
    }

    #[test]
    fn macro_clusters_merge_shared_density() {
        let config = DBStreamConfig::builder(1.0)
            .decay_rate(0.0001)
            .min_weight(0.0)
            .cleanup_interval(10_000)
            .shared_density_threshold(0.1)
            .build()
            .unwrap();
        let mut db = DBStream::new(config);

        for _ in 0..10 {
            db.train_one(&[0.0, 0.0]);
            db.train_one(&[0.5, 0.5]);
        }

        for _ in 0..10 {
            db.train_one(&[10.0, 10.0]);
            db.train_one(&[10.5, 10.5]);
        }

        let macros = db.macro_clusters();
        assert!(
            macros.len() >= 2,
            "expected at least 2 macro-clusters, got {}",
            macros.len()
        );
    }

    #[test]
    fn predict_returns_nearest() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);

        db.train_one(&[0.0, 0.0]);
        db.train_one(&[10.0, 10.0]);

        let idx = db.predict(&[0.1, 0.1]);
        let nearest_center = &db.micro_clusters()[idx].center;
        let d_origin = euclidean_distance(nearest_center, &[0.0, 0.0]);
        let d_far = euclidean_distance(nearest_center, &[10.0, 10.0]);
        assert!(
            d_origin < d_far,
            "predicted MC should be closer to origin than to (10,10)"
        );

        let idx2 = db.predict(&[9.9, 9.9]);
        let nearest_center2 = &db.micro_clusters()[idx2].center;
        let d_origin2 = euclidean_distance(nearest_center2, &[0.0, 0.0]);
        let d_far2 = euclidean_distance(nearest_center2, &[10.0, 10.0]);
        assert!(
            d_far2 < d_origin2,
            "predicted MC should be closer to (10,10) than to origin"
        );
    }

    #[test]
    fn predict_or_noise_returns_none() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);

        db.train_one(&[0.0, 0.0]);

        assert!(db.predict_or_noise(&[0.5, 0.5], 2.0).is_some());

        assert!(db.predict_or_noise(&[100.0, 100.0], 1.0).is_none());
    }

    #[test]
    fn reset_clears_state() {
        let config = default_config(1.0);
        let mut db = DBStream::new(config);

        db.train_one(&[1.0, 2.0]);
        db.train_one(&[3.0, 4.0]);
        assert!(db.n_micro_clusters() > 0);
        assert!(db.n_samples_seen() > 0);

        db.reset();

        assert_eq!(db.n_micro_clusters(), 0);
        assert_eq!(db.n_samples_seen(), 0);
        assert!(db.macro_clusters().is_empty());
    }

    #[test]
    fn config_builder_validates() {
        assert!(DBStreamConfig::builder(0.0).build().is_err());
        assert!(DBStreamConfig::builder(-1.0).build().is_err());

        assert!(DBStreamConfig::builder(1.0)
            .decay_rate(0.0)
            .build()
            .is_err());
        assert!(DBStreamConfig::builder(1.0)
            .decay_rate(-1.0)
            .build()
            .is_err());

        assert!(DBStreamConfig::builder(1.0)
            .min_weight(-1.0)
            .build()
            .is_err());

        assert!(DBStreamConfig::builder(1.0)
            .cleanup_interval(0)
            .build()
            .is_err());

        assert!(DBStreamConfig::builder(1.0)
            .shared_density_threshold(-0.1)
            .build()
            .is_err());
        assert!(DBStreamConfig::builder(1.0)
            .shared_density_threshold(1.1)
            .build()
            .is_err());

        assert!(DBStreamConfig::builder(1.0).build().is_ok());
    }
}