use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
use crate::vq::euclidean_distance;
/// Clustering feature (CF): a compact summary of a group of points.
///
/// Stores the point count, the element-wise sum of the points and the sum of
/// their squared norms; the centroid, radius and diameter are derived from
/// these three values without keeping the points themselves.
#[derive(Debug, Clone)]
pub struct ClusteringFeature<F: Float> {
/// Number of points summarised by this feature.
n: usize,
/// Element-wise sum of all summarised points.
linear_sum: Array1<F>,
/// Sum of the squared Euclidean norms of all summarised points.
squared_sum: F,
}
impl<F: Float + FromPrimitive + ScalarOperand> ClusteringFeature<F> {
    /// Builds the feature that summarises a single data point.
    fn new(datapoint: ArrayView1<F>) -> Self {
        let squared_sum = datapoint.dot(&datapoint);
        Self {
            n: 1,
            linear_sum: datapoint.to_owned(),
            squared_sum,
        }
    }

    /// Builds a feature that summarises no points in `n_features` dimensions.
    fn empty(n_features: usize) -> Self {
        Self {
            n: 0,
            linear_sum: Array1::zeros(n_features),
            squared_sum: F::zero(),
        }
    }

    /// Folds `other` into `self` by adding counts, linear sums and squared sums.
    fn add(&mut self, other: &Self) {
        self.n += other.n;
        self.linear_sum = &self.linear_sum + &other.linear_sum;
        self.squared_sum = self.squared_sum + other.squared_sum;
    }

    /// Returns the feature summarising the points of both `self` and `other`.
    fn merge(&self, other: &Self) -> Self {
        // Built straight from the component sums so `self` is not cloned first.
        Self {
            n: self.n + other.n,
            linear_sum: &self.linear_sum + &other.linear_sum,
            squared_sum: self.squared_sum + other.squared_sum,
        }
    }

    /// Returns the mean of the summarised points (all zeros when `n == 0`).
    fn centroid(&self) -> Array1<F> {
        if self.n == 0 {
            return Array1::zeros(self.linear_sum.len());
        }
        let n_f = F::from(self.n).unwrap_or(F::one());
        &self.linear_sum / n_f
    }

    /// Returns `sqrt(squared_sum / n - |centroid|^2)`, clamped at zero.
    ///
    /// Features with at most one point have radius zero.
    fn radius(&self) -> F {
        if self.n <= 1 {
            return F::zero();
        }
        let n_f = F::from(self.n).unwrap_or(F::one());
        let centroid = self.centroid();
        let centroid_ss = centroid.dot(&centroid);
        let variance = (self.squared_sum / n_f) - centroid_ss;
        // Rounding can push the difference slightly below zero.
        variance.max(F::zero()).sqrt()
    }

    /// Returns `sqrt(2 * (n * squared_sum - |linear_sum|^2) / (n * (n - 1)))`,
    /// clamped at zero.
    ///
    /// Features with at most one point have diameter zero.
    fn diameter(&self) -> F {
        if self.n <= 1 {
            return F::zero();
        }
        let n_f = F::from(self.n).unwrap_or(F::one());
        let ls_dot = self.linear_sum.dot(&self.linear_sum);
        let numerator = n_f * self.squared_sum - ls_dot;
        let denominator = n_f * (n_f - F::one());
        if denominator <= F::zero() {
            return F::zero();
        }
        let two = F::from(2.0).unwrap_or(F::one() + F::one());
        (two * numerator / denominator).max(F::zero()).sqrt()
    }

    /// Euclidean distance between the centroids of `self` and `other`.
    fn centroid_distance(&self, other: &Self) -> F {
        let c1 = self.centroid();
        let c2 = other.centroid();
        let mut dist = F::zero();
        for i in 0..c1.len() {
            let diff = c1[i] - c2[i];
            dist = dist + diff * diff;
        }
        dist.sqrt()
    }

    /// D0 distance: the distance between the two centroids.
    fn d0_distance(&self, other: &Self) -> F {
        self.centroid_distance(other)
    }

    /// D2 distance: the diameter of the feature obtained by merging both.
    fn d2_distance(&self, other: &Self) -> F {
        self.merge(other).diameter()
    }
}
/// Node of the CF tree.
///
/// A leaf keeps its entries in `cfs` and has no children. In a non-leaf node
/// `cfs[i]` is refreshed from `children[i].get_cf()` after every insertion
/// into that child, so it summarises the child's subtree.
#[derive(Debug)]
struct CFNode<F: Float> {
/// `true` when this node stores leaf entries rather than child summaries.
is_leaf: bool,
/// Leaf entries (leaf node) or per-child summaries (non-leaf node).
cfs: Vec<ClusteringFeature<F>>,
/// Child nodes; left empty by `new_leaf`.
children: Vec<CFNode<F>>,
}
impl<F: Float + FromPrimitive + ScalarOperand> CFNode<F> {
fn new_leaf() -> Self {
Self {
is_leaf: true,
cfs: Vec::new(),
children: Vec::new(),
}
}
fn new_non_leaf() -> Self {
Self {
is_leaf: false,
cfs: Vec::new(),
children: Vec::new(),
}
}
/// Returns the sum of every feature stored in this node.
///
/// A node without features yields a zero-dimensional empty feature.
fn get_cf(&self) -> ClusteringFeature<F> {
    match self.cfs.split_first() {
        None => ClusteringFeature::empty(0),
        Some((first, rest)) => rest.iter().fold(first.clone(), |mut total, cf| {
            total.add(cf);
            total
        }),
    }
}
/// Inserts `new_cf` into the subtree rooted at this node.
///
/// Non-leaf nodes delegate to `insert_cf_nonleaf`. A leaf first tries to
/// absorb the feature into its closest entry, which succeeds when the merged
/// radius stays within `threshold`; otherwise the feature becomes a new
/// entry. Returns `Ok(Some(node))` with the new sibling when the leaf already
/// held `branching_factor` entries and had to split, `Ok(None)` otherwise.
fn insert_cf(
    &mut self,
    new_cf: ClusteringFeature<F>,
    branching_factor: usize,
    threshold: F,
) -> Result<Option<CFNode<F>>> {
    if !self.is_leaf {
        return self.insert_cf_nonleaf(new_cf, branching_factor, threshold);
    }

    if !self.cfs.is_empty() {
        let (nearest, _) = self.find_closest_cf(&new_cf);
        let absorbed = self.cfs[nearest].merge(&new_cf);
        if absorbed.radius() <= threshold {
            self.cfs[nearest] = absorbed;
            return Ok(None);
        }
    }

    // An empty leaf always accepts its first entry without splitting.
    let overflow = !self.cfs.is_empty() && self.cfs.len() >= branching_factor;
    self.cfs.push(new_cf);
    if overflow {
        Ok(Some(self.split_leaf(branching_factor)))
    } else {
        Ok(None)
    }
}
/// Routes `new_cf` to the child whose summary centroid is closest.
///
/// A non-leaf node without children turns itself into a leaf and stores the
/// feature directly. After the recursive insert the child's summary is
/// refreshed; if the child split, the new sibling is adopted, and this node
/// splits in turn when it already held `branching_factor` children.
fn insert_cf_nonleaf(
    &mut self,
    new_cf: ClusteringFeature<F>,
    branching_factor: usize,
    threshold: F,
) -> Result<Option<CFNode<F>>> {
    if self.children.is_empty() {
        self.is_leaf = true;
        self.cfs.push(new_cf);
        return Ok(None);
    }

    let last_child = self.children.len() - 1;
    // Clamp in case there are more summaries than children.
    let target = self.find_closest_cf(&new_cf).0.min(last_child);

    let overflow_child = self.children[target].insert_cf(new_cf, branching_factor, threshold)?;
    self.cfs[target] = self.children[target].get_cf();

    match overflow_child {
        None => Ok(None),
        Some(new_child) => {
            let had_room = self.children.len() < branching_factor;
            self.cfs.push(new_child.get_cf());
            self.children.push(new_child);
            if had_room {
                Ok(None)
            } else {
                Ok(Some(self.split_nonleaf(branching_factor)))
            }
        }
    }
}
/// Finds the stored feature whose centroid is nearest to `target`'s.
///
/// Returns `(index, distance)`; the first minimum wins on ties, and a node
/// without features yields `(0, infinity)`.
fn find_closest_cf(&self, target: &ClusteringFeature<F>) -> (usize, F) {
    self.cfs
        .iter()
        .enumerate()
        .fold((0, F::infinity()), |(best_idx, best_dist), (idx, cf)| {
            let dist = cf.centroid_distance(target);
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_idx, best_dist)
            }
        })
}
/// Splits this leaf in two and returns the new sibling leaf.
///
/// The two entries with the farthest-apart centroids act as seeds. Every
/// other entry joins the seed whose centroid is closer (ties go to the first
/// seed). Entries are compared against the seeds themselves rather than
/// against groups that are still being filled, so the outcome does not depend
/// on the order in which the entries are stored. `self` keeps the first
/// group; the returned leaf holds the second. A leaf with at most one entry
/// is left untouched and an empty leaf is returned.
fn split_leaf(&mut self, _branching_factor: usize) -> CFNode<F> {
    if self.cfs.len() <= 1 {
        return CFNode::new_leaf();
    }

    let (seed1, seed2) = find_farthest_pair(&self.cfs);
    let seed1_cf = self.cfs[seed1].clone();
    let seed2_cf = self.cfs[seed2].clone();

    let mut group1 = Vec::new();
    let mut group2 = Vec::new();
    for (i, cf) in std::mem::take(&mut self.cfs).into_iter().enumerate() {
        let joins_first = if i == seed1 {
            true
        } else if i == seed2 {
            false
        } else {
            cf.centroid_distance(&seed1_cf) <= cf.centroid_distance(&seed2_cf)
        };
        if joins_first {
            group1.push(cf);
        } else {
            group2.push(cf);
        }
    }

    self.cfs = group1;
    let mut new_node = CFNode::new_leaf();
    new_node.cfs = group2;
    new_node
}
/// Splits this non-leaf node in two and returns the new sibling node.
///
/// The two child summaries with the farthest-apart centroids act as seeds.
/// Every other (summary, child) pair joins the seed whose centroid is closer
/// (ties go to the first seed). Pairs are compared against the seeds
/// themselves rather than against groups that are still being filled, so the
/// outcome does not depend on storage order. `self` keeps the first group;
/// the returned node holds the second. A node with at most one summary is
/// left untouched and an empty non-leaf node is returned.
fn split_nonleaf(&mut self, _branching_factor: usize) -> CFNode<F> {
    if self.cfs.len() <= 1 {
        return CFNode::new_non_leaf();
    }

    let (seed1, seed2) = find_farthest_pair(&self.cfs);
    let seed1_cf = self.cfs[seed1].clone();
    let seed2_cf = self.cfs[seed2].clone();

    let all_cfs = std::mem::take(&mut self.cfs);
    let all_children = std::mem::take(&mut self.children);

    let mut group1_cfs = Vec::new();
    let mut group1_children = Vec::new();
    let mut group2_cfs = Vec::new();
    let mut group2_children = Vec::new();

    for (i, (cf, child)) in all_cfs.into_iter().zip(all_children).enumerate() {
        let joins_first = if i == seed1 {
            true
        } else if i == seed2 {
            false
        } else {
            cf.centroid_distance(&seed1_cf) <= cf.centroid_distance(&seed2_cf)
        };
        if joins_first {
            group1_cfs.push(cf);
            group1_children.push(child);
        } else {
            group2_cfs.push(cf);
            group2_children.push(child);
        }
    }

    self.cfs = group1_cfs;
    self.children = group1_children;
    let mut new_node = CFNode::new_non_leaf();
    new_node.cfs = group2_cfs;
    new_node.children = group2_children;
    new_node
}
/// Forwards to `distance_to_group_centroid`: the distance from `cf`'s centroid
/// to the centroid of the summed `group` (infinity when `group` is empty).
fn dist_to_seed(&self, cf: &ClusteringFeature<F>, group: &[ClusteringFeature<F>]) -> F {
distance_to_group_centroid(cf, group)
}
/// Appends clones of every leaf entry below this node to `out`, visiting
/// children in stored order.
fn collect_leaf_entries(&self, out: &mut Vec<ClusteringFeature<F>>) {
    if self.is_leaf {
        out.extend(self.cfs.iter().cloned());
        return;
    }
    for child in self.children.iter() {
        child.collect_leaf_entries(out);
    }
}
}
/// Returns the indices `(i, j)` with `i < j` of the two features whose
/// centroids are farthest apart.
///
/// Fewer than two features yield `(0, 0)`. When no pair has a distance above
/// zero the result is `(0, 1)`; on ties the first pair found wins.
fn find_farthest_pair<F: Float + FromPrimitive + ScalarOperand>(
    cfs: &[ClusteringFeature<F>],
) -> (usize, usize) {
    if cfs.len() < 2 {
        return (0, 0);
    }

    let mut best_pair = (0, 1);
    let mut best_dist = F::zero();
    for (i, first) in cfs.iter().enumerate() {
        for (offset, second) in cfs[i + 1..].iter().enumerate() {
            let dist = first.centroid_distance(second);
            if dist > best_dist {
                best_dist = dist;
                best_pair = (i, i + 1 + offset);
            }
        }
    }
    best_pair
}
/// Distance from `cf`'s centroid to the centroid of all features in `group`
/// combined; infinity when `group` is empty.
fn distance_to_group_centroid<F: Float + FromPrimitive + ScalarOperand>(
    cf: &ClusteringFeature<F>,
    group: &[ClusteringFeature<F>],
) -> F {
    if group.is_empty() {
        return F::infinity();
    }
    let dims = cf.linear_sum.len();
    let combined = group
        .iter()
        .fold(ClusteringFeature::empty(dims), |mut acc, member| {
            acc.add(member);
            acc
        });
    cf.centroid_distance(&combined)
}
/// Configuration for [`Birch`].
#[derive(Debug, Clone)]
pub struct BirchOptions<F: Float> {
/// Number of entries a tree node may hold before an insertion splits it.
pub branching_factor: usize,
/// Largest radius a leaf entry may have after absorbing a new point.
pub threshold: F,
/// Number of final clusters; `None` keeps one cluster per leaf entry.
pub n_clusters: Option<usize>,
/// When set, `fit` rebuilds the tree with a larger threshold if it ends up
/// with more leaf entries than this.
pub max_leaf_entries: Option<usize>,
/// Number of refinement iterations used when grouping leaf entries into
/// `n_clusters` clusters.
pub n_refinement_iter: usize,
}
impl<F: Float + FromPrimitive> Default for BirchOptions<F> {
fn default() -> Self {
Self {
branching_factor: 50,
threshold: F::from(0.5).unwrap_or(F::one()),
n_clusters: None,
max_leaf_entries: None,
n_refinement_iter: 5,
}
}
}
/// BIRCH clustering model built on a CF tree.
pub struct Birch<F: Float> {
/// User-supplied configuration.
options: BirchOptions<F>,
/// Root of the CF tree; `None` until `fit` or `partial_fit` has run.
root: Option<Box<CFNode<F>>>,
/// Leaf entries collected from the tree after the last fit.
leaf_entries: Vec<ClusteringFeature<F>>,
/// Number of features of the fitted data, once known.
n_features: Option<usize>,
/// Threshold used to build the tree; can exceed `options.threshold` after a
/// rebuild triggered by `max_leaf_entries`.
effective_threshold: Option<F>,
}
impl<F: Float + FromPrimitive + Debug + ScalarOperand> Birch<F> {
pub fn new(options: BirchOptions<F>) -> Self {
Self {
options,
root: None,
leaf_entries: Vec::new(),
n_features: None,
effective_threshold: None,
}
}
/// Builds a fresh CF tree from the rows of `data`, replacing any earlier fit.
///
/// When `max_leaf_entries` is set and the tree ends up with more leaf entries
/// than allowed, the tree is rebuilt with a larger threshold.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `data` has no rows, and
/// propagates any error from inserting into the tree.
pub fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    if n_samples == 0 {
        return Err(ClusteringError::InvalidInput(
            "Input data is empty".to_string(),
        ));
    }
    self.n_features = Some(n_features);

    let threshold = self.options.threshold;
    let branching_factor = self.options.branching_factor;
    self.effective_threshold = Some(threshold);

    let mut root = Box::new(CFNode::new_leaf());
    for i in 0..n_samples {
        let new_cf = ClusteringFeature::new(data.slice(s![i, ..]));
        if let Some(sibling) = root.insert_cf(new_cf, branching_factor, threshold)? {
            // The root split: grow the tree by one level.
            let mut new_root = Box::new(CFNode::new_non_leaf());
            new_root.cfs.push(root.get_cf());
            new_root.cfs.push(sibling.get_cf());
            new_root.children.push(*root);
            new_root.children.push(sibling);
            root = new_root;
        }
    }

    self.leaf_entries.clear();
    root.collect_leaf_entries(&mut self.leaf_entries);
    // Store the tree before a possible rebuild so that the rebuilt tree is
    // not overwritten by this one afterwards.
    self.root = Some(root);

    if let Some(max_entries) = self.options.max_leaf_entries {
        if self.leaf_entries.len() > max_entries {
            self.rebuild_with_larger_threshold(&data, max_entries)?;
        }
    }
    Ok(())
}
/// Rebuilds the tree from `data` with a growing threshold until it holds at
/// most `max_entries` leaf entries.
///
/// Starting from `options.threshold`, the threshold is multiplied by 1.5
/// before each of up to ten attempts. The tree, its leaf entries and the
/// effective threshold are stored together: either from the first attempt
/// that satisfies the limit, or from the last attempt when none does, so the
/// three always describe the same tree.
fn rebuild_with_larger_threshold(
    &mut self,
    data: &ArrayView2<F>,
    max_entries: usize,
) -> Result<()> {
    const MAX_ATTEMPTS: usize = 10;
    let increase_factor = F::from(1.5).unwrap_or(F::one() + F::one());
    let branching_factor = self.options.branching_factor;
    let mut threshold = self.options.threshold;

    for attempt in 1..=MAX_ATTEMPTS {
        threshold = threshold * increase_factor;

        let mut root = Box::new(CFNode::new_leaf());
        for i in 0..data.shape()[0] {
            let new_cf = ClusteringFeature::new(data.slice(s![i, ..]));
            if let Some(sibling) = root.insert_cf(new_cf, branching_factor, threshold)? {
                // The root split: grow the tree by one level.
                let mut new_root = Box::new(CFNode::new_non_leaf());
                new_root.cfs.push(root.get_cf());
                new_root.cfs.push(sibling.get_cf());
                new_root.children.push(*root);
                new_root.children.push(sibling);
                root = new_root;
            }
        }

        self.leaf_entries.clear();
        root.collect_leaf_entries(&mut self.leaf_entries);

        // Keep the last attempt even if it still exceeds the limit, so that
        // `root` matches `leaf_entries` and `effective_threshold`.
        if self.leaf_entries.len() <= max_entries || attempt == MAX_ATTEMPTS {
            self.effective_threshold = Some(threshold);
            self.root = Some(root);
            return Ok(());
        }
    }
    Ok(())
}
/// Inserts the rows of `data` into the existing CF tree, creating the tree
/// first when the model has not been fitted yet.
///
/// Uses the effective threshold of the current tree when there is one and
/// `options.threshold` otherwise. The leaf entries are re-collected at the
/// end.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when the number of columns differs
/// from the feature count the model was fitted with, and propagates any error
/// from inserting into the tree.
pub fn partial_fit(&mut self, data: ArrayView2<F>) -> Result<()> {
    let n_features = data.shape()[1];
    match self.n_features {
        Some(expected) if expected != n_features => {
            return Err(ClusteringError::InvalidInput(format!(
                "Expected {} features, got {}",
                expected, n_features
            )));
        }
        Some(_) => {}
        None => self.n_features = Some(n_features),
    }

    let threshold = self.effective_threshold.unwrap_or(self.options.threshold);
    let branching_factor = self.options.branching_factor;

    // Take the root out once; it is put back on every exit path.
    let mut root = self
        .root
        .take()
        .unwrap_or_else(|| Box::new(CFNode::new_leaf()));

    for i in 0..data.shape()[0] {
        let new_cf = ClusteringFeature::new(data.slice(s![i, ..]));
        let split_result = match root.insert_cf(new_cf, branching_factor, threshold) {
            Ok(result) => result,
            Err(err) => {
                self.root = Some(root);
                return Err(err);
            }
        };
        if let Some(sibling) = split_result {
            // The root split: grow the tree by one level.
            let mut new_root = Box::new(CFNode::new_non_leaf());
            new_root.cfs.push(root.get_cf());
            new_root.cfs.push(sibling.get_cf());
            new_root.children.push(*root);
            new_root.children.push(sibling);
            root = new_root;
        }
    }

    self.leaf_entries.clear();
    root.collect_leaf_entries(&mut self.leaf_entries);
    self.root = Some(root);
    Ok(())
}
/// Assigns every row of `data` to its nearest leaf entry.
///
/// The returned labels are indices into the model's leaf entries, chosen by
/// Euclidean distance to each entry's centroid (first minimum wins on ties).
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when the model holds no leaf
/// entries, or when the number of columns differs from the feature count the
/// model was fitted with.
pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<i32>> {
    if self.leaf_entries.is_empty() {
        return Err(ClusteringError::InvalidInput(
            "Model has not been fitted yet".to_string(),
        ));
    }
    if let Some(expected) = self.n_features {
        let got = data.shape()[1];
        if got != expected {
            return Err(ClusteringError::InvalidInput(format!(
                "Expected {} features, got {}",
                expected, got
            )));
        }
    }

    // Centroids do not depend on the sample, so compute them once.
    let centroids: Vec<Array1<F>> = self.leaf_entries.iter().map(|cf| cf.centroid()).collect();

    let n_samples = data.shape()[0];
    let mut labels = Array1::zeros(n_samples);
    for i in 0..n_samples {
        let point = data.slice(s![i, ..]);
        let mut min_dist = F::infinity();
        let mut closest_cf = 0;
        for (j, centroid) in centroids.iter().enumerate() {
            let dist = euclidean_distance(point, centroid.view());
            if dist < min_dist {
                min_dist = dist;
                closest_cf = j;
            }
        }
        labels[i] = closest_cf as i32;
    }
    Ok(labels)
}
/// Returns the cluster centroids and a cluster label for every leaf entry.
///
/// When `n_clusters` is unset or not smaller than the number of leaf entries,
/// every leaf entry is its own cluster: row `i` of the centroids is the
/// centroid of entry `i` and its label is `i`. Otherwise the leaf entries are
/// grouped into `n_clusters` clusters by `cluster_cf_entries_refined`.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when no data has been processed or
/// when `n_clusters` is `Some(0)`, and `ClusteringError::InvalidState` when
/// the feature count is unknown.
pub fn extract_clusters(&self) -> Result<(Array2<F>, Array1<i32>)> {
    if self.leaf_entries.is_empty() {
        return Err(ClusteringError::InvalidInput(
            "No data has been processed".to_string(),
        ));
    }
    let n_features = self
        .n_features
        .ok_or_else(|| ClusteringError::InvalidState("n_features not set".into()))?;

    let n_cf_entries = self.leaf_entries.len();
    let n_clusters = self.options.n_clusters.unwrap_or(n_cf_entries);
    if n_clusters == 0 {
        // Zero clusters would index row 0 of an empty center matrix below.
        return Err(ClusteringError::InvalidInput(
            "n_clusters must be at least 1".to_string(),
        ));
    }

    if n_clusters < n_cf_entries {
        return self.cluster_cf_entries_refined(n_clusters, n_features);
    }

    let mut centroids = Array2::zeros((n_cf_entries, n_features));
    let mut labels = Array1::zeros(n_cf_entries);
    for (i, cf) in self.leaf_entries.iter().enumerate() {
        let centroid = cf.centroid();
        centroids.slice_mut(s![i, ..]).assign(&centroid);
        labels[i] = i as i32;
    }
    Ok((centroids, labels))
}
fn cluster_cf_entries_refined(
&self,
n_clusters: usize,
n_features: usize,
) -> Result<(Array2<F>, Array1<i32>)> {
let n_cfs = self.leaf_entries.len();
let mut cf_centroids = Array2::zeros((n_cfs, n_features));
for (i, cf) in self.leaf_entries.iter().enumerate() {
let centroid = cf.centroid();
cf_centroids.slice_mut(s![i, ..]).assign(¢roid);
}
let mut cluster_centers = Array2::zeros((n_clusters, n_features));
cluster_centers
.slice_mut(s![0, ..])
.assign(&cf_centroids.slice(s![0, ..]));
for c in 1..n_clusters {
let mut max_min_dist = F::zero();
let mut best_idx = c % n_cfs;
for i in 0..n_cfs {
let mut min_dist = F::infinity();
for j in 0..c {
let dist = euclidean_distance(
cf_centroids.slice(s![i, ..]),
cluster_centers.slice(s![j, ..]),
);
if dist < min_dist {
min_dist = dist;
}
}
if min_dist > max_min_dist {
max_min_dist = min_dist;
best_idx = i;
}
}
cluster_centers
.slice_mut(s![c, ..])
.assign(&cf_centroids.slice(s![best_idx, ..]));
}
let mut assignments = Array1::zeros(n_cfs);
for _iter in 0..self.options.n_refinement_iter {
for i in 0..n_cfs {
let mut min_dist = F::infinity();
let mut closest_cluster = 0;
for j in 0..n_clusters {
let dist = euclidean_distance(
cf_centroids.slice(s![i, ..]),
cluster_centers.slice(s![j, ..]),
);
if dist < min_dist {
min_dist = dist;
closest_cluster = j;
}
}
assignments[i] = closest_cluster as i32;
}
let mut new_centers = Array2::zeros((n_clusters, n_features));
let mut weights = vec![F::zero(); n_clusters];
for (cf_idx, &cluster_id) in assignments.iter().enumerate() {
let cid = cluster_id as usize;
let cf = &self.leaf_entries[cf_idx];
let cf_weight = F::from(cf.n).unwrap_or(F::one());
let centroid = cf.centroid();
for f in 0..n_features {
new_centers[[cid, f]] = new_centers[[cid, f]] + centroid[f] * cf_weight;
}
weights[cid] = weights[cid] + cf_weight;
}
for c in 0..n_clusters {
if weights[c] > F::zero() {
for f in 0..n_features {
new_centers[[c, f]] = new_centers[[c, f]] / weights[c];
}
}
}
cluster_centers = new_centers;
}
Ok((cluster_centers, assignments))
}
/// Summarises the current leaf entries.
///
/// Reports the number of entries, the total number of summarised points, the
/// mean points per entry, the mean entry radius, the threshold in effect and
/// the branching factor. Both means are zero when there are no entries.
pub fn get_statistics(&self) -> BirchStatistics<F> {
    let entries = &self.leaf_entries;
    let num_cf_entries = entries.len();
    let total_points: usize = entries.iter().map(|cf| cf.n).sum();

    let (avg_cf_size, avg_radius) = if entries.is_empty() {
        (0.0, F::zero())
    } else {
        let radius_sum = entries
            .iter()
            .fold(F::zero(), |acc, cf| acc + cf.radius());
        let count = F::from(num_cf_entries).unwrap_or_else(F::one);
        (
            total_points as f64 / num_cf_entries as f64,
            radius_sum / count,
        )
    };

    BirchStatistics {
        num_cf_entries,
        total_points,
        avg_cf_size,
        avg_radius,
        threshold: self.effective_threshold.unwrap_or(self.options.threshold),
        branching_factor: self.options.branching_factor,
    }
}
}
/// Summary of a fitted [`Birch`] model, as produced by `get_statistics`.
#[derive(Debug)]
pub struct BirchStatistics<F: Float> {
/// Number of leaf entries in the tree.
pub num_cf_entries: usize,
/// Total number of points summarised by all leaf entries.
pub total_points: usize,
/// Mean number of points per leaf entry (0 when there are none).
pub avg_cf_size: f64,
/// Mean radius of the leaf entries (0 when there are none).
pub avg_radius: F,
/// Threshold in effect: the effective threshold when set, otherwise the
/// configured one.
pub threshold: F,
/// Configured branching factor.
pub branching_factor: usize,
}
/// Fits a BIRCH model on `data` and returns `(centroids, labels)`.
///
/// `centroids` has one row per final cluster and `labels[i]` is the row of
/// `centroids` that sample `i` belongs to. Each sample is first matched to
/// its nearest leaf entry and then takes that entry's cluster label, so the
/// labels always index into the returned centroids.
///
/// # Errors
///
/// Propagates any error from fitting, cluster extraction or prediction, for
/// example `ClusteringError::InvalidInput` for empty input.
pub fn birch<F>(data: ArrayView2<F>, options: BirchOptions<F>) -> Result<(Array2<F>, Array1<i32>)>
where
    F: Float + FromPrimitive + Debug + ScalarOperand,
{
    let mut model = Birch::new(options);
    model.fit(data)?;
    let (centroids, cf_labels) = model.extract_clusters()?;
    // `predict` yields leaf-entry indices; translate them to cluster ids.
    let leaf_indices = model.predict(data)?;
    let labels = leaf_indices.mapv(|leaf| cf_labels[leaf as usize]);
    Ok((centroids, labels))
}
// Unit tests for the clustering-feature arithmetic and the BIRCH entry points.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
// Six 2-D points: three near (1, 2) and three near (4, 5).
fn make_two_cluster_data() -> Array2<f64> {
Array2::from_shape_vec(
(6, 2),
vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
)
.expect("Failed to create test data")
}
// A single-point feature has n = 1, linear sum = the point, squared sum = |p|^2.
#[test]
fn test_clustering_feature() {
let point = Array1::from_vec(vec![1.0, 2.0]);
let cf = ClusteringFeature::<f64>::new(point.view());
assert_eq!(cf.n, 1);
assert_eq!(cf.linear_sum, point);
assert_eq!(cf.squared_sum, 5.0);
let centroid = cf.centroid();
assert_eq!(centroid, point);
}
// Merging adds counts and sums; the centroid is the mean of both points.
#[test]
fn test_cf_merge() {
let point1 = Array1::from_vec(vec![1.0, 2.0]);
let point2 = Array1::from_vec(vec![3.0, 4.0]);
let cf1 = ClusteringFeature::new(point1.view());
let cf2 = ClusteringFeature::new(point2.view());
let merged = cf1.merge(&cf2);
assert_eq!(merged.n, 2);
assert_eq!(merged.linear_sum, Array1::from_vec(vec![4.0, 6.0]));
assert_eq!(merged.squared_sum, 30.0);
let centroid = merged.centroid();
assert_eq!(centroid, Array1::from_vec(vec![2.0, 3.0]));
}
// Two points 2.0 apart give a positive radius that does not exceed 2.0.
#[test]
fn test_cf_radius() {
let p1 = Array1::from_vec(vec![0.0, 0.0]);
let p2 = Array1::from_vec(vec![2.0, 0.0]);
let cf1 = ClusteringFeature::<f64>::new(p1.view());
let cf2 = ClusteringFeature::new(p2.view());
let merged = cf1.merge(&cf2);
let radius = merged.radius();
assert!(radius > 0.0, "Radius should be positive");
assert!(radius <= 2.0, "Radius should be <= diameter");
}
// Two points 2.0 apart give a diameter of about 2.0.
#[test]
fn test_cf_diameter() {
let p1 = Array1::from_vec(vec![0.0, 0.0]);
let p2 = Array1::from_vec(vec![2.0, 0.0]);
let cf1 = ClusteringFeature::<f64>::new(p1.view());
let cf2 = ClusteringFeature::new(p2.view());
let merged = cf1.merge(&cf2);
let diameter = merged.diameter();
assert!(diameter > 0.0, "Diameter should be positive");
assert!(
(diameter - 2.0).abs() < 0.1,
"Diameter should be ~2.0, got {}",
diameter
);
}
// Requesting two clusters yields two centroid rows and one label per sample.
#[test]
fn test_birch_simple() {
let data = make_two_cluster_data();
let options = BirchOptions {
n_clusters: Some(2),
threshold: 1.0,
..Default::default()
};
let result = birch(data.view(), options);
assert!(result.is_ok());
let (centroids, labels) = result.expect("Should succeed");
assert_eq!(centroids.shape()[0], 2);
assert_eq!(labels.len(), 6);
}
// Default options run successfully and label every sample.
#[test]
fn test_birch_default_options() {
let data = make_two_cluster_data();
let options = BirchOptions::default();
let result = birch(data.view(), options);
assert!(result.is_ok());
let (_centroids, labels) = result.expect("Should succeed");
assert_eq!(labels.len(), 6);
}
// Input without rows is rejected with an error.
#[test]
fn test_birch_empty_data() {
let data = Array2::<f64>::zeros((0, 2));
let options = BirchOptions::default();
let result = birch(data.view(), options);
assert!(result.is_err());
}
// A single point forms one cluster with label 0.
#[test]
fn test_birch_single_point() {
let data = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("Failed to create data");
let options = BirchOptions {
n_clusters: Some(1),
..Default::default()
};
let result = birch(data.view(), options);
assert!(result.is_ok());
let (centroids, labels) = result.expect("Should succeed");
assert_eq!(centroids.shape()[0], 1);
assert_eq!(labels.len(), 1);
assert_eq!(labels[0], 0);
}
// Statistics account for every fitted point and echo the branching factor.
#[test]
fn test_birch_statistics() {
let data = make_two_cluster_data();
let mut model = Birch::new(BirchOptions {
threshold: 1.0,
..Default::default()
});
model.fit(data.view()).expect("Should fit");
let stats = model.get_statistics();
assert_eq!(stats.total_points, 6);
assert!(stats.num_cf_entries > 0);
assert!(stats.avg_cf_size > 0.0);
assert!(stats.branching_factor == 50);
}
// `fit` followed by `partial_fit` accumulates the points of both batches.
#[test]
fn test_birch_incremental() {
let data1 = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9])
.expect("Failed to create data");
let data2 = Array2::from_shape_vec((3, 2), vec![4.0, 5.0, 4.2, 4.8, 3.9, 5.1])
.expect("Failed to create data");
let mut model = Birch::new(BirchOptions {
threshold: 1.0,
n_clusters: Some(2),
..Default::default()
});
model.fit(data1.view()).expect("Should fit first batch");
model
.partial_fit(data2.view())
.expect("Should fit second batch");
let stats = model.get_statistics();
assert_eq!(stats.total_points, 6);
}
// A very small threshold still produces the requested two centroids.
#[test]
fn test_birch_small_threshold() {
let data = make_two_cluster_data();
let options = BirchOptions {
threshold: 0.01, n_clusters: Some(2),
..Default::default()
};
let result = birch(data.view(), options);
assert!(result.is_ok());
let (centroids, labels) = result.expect("Should succeed");
assert_eq!(centroids.shape()[0], 2);
assert_eq!(labels.len(), 6);
}
// A very large threshold runs successfully and labels every sample.
#[test]
fn test_birch_large_threshold() {
let data = make_two_cluster_data();
let options = BirchOptions {
threshold: 100.0, ..Default::default()
};
let result = birch(data.view(), options);
assert!(result.is_ok());
let (_centroids, labels) = result.expect("Should succeed");
assert_eq!(labels.len(), 6);
}
// A branching factor of 2 forces node splits on this data set.
#[test]
fn test_birch_branching_factor() {
let data = make_two_cluster_data();
let options = BirchOptions {
branching_factor: 2, threshold: 0.5,
n_clusters: Some(2),
..Default::default()
};
let result = birch(data.view(), options);
assert!(result.is_ok());
let (_centroids, labels) = result.expect("Should succeed");
assert_eq!(labels.len(), 6);
}
// Fifty points with a tiny threshold and `max_leaf_entries` set exercise the
// threshold-rebuild path of `fit`.
#[test]
fn test_birch_threshold_auto_adjustment() {
let mut data_vec = Vec::new();
for i in 0..50 {
data_vec.push(i as f64 * 0.1);
data_vec.push(i as f64 * 0.1 + 0.5);
}
let data = Array2::from_shape_vec((50, 2), data_vec).expect("Failed to create data");
let options = BirchOptions {
threshold: 0.01,
max_leaf_entries: Some(10), n_clusters: Some(3),
..Default::default()
};
let mut model = Birch::new(options);
model.fit(data.view()).expect("Should fit");
let stats = model.get_statistics();
assert!(stats.num_cf_entries <= 50, "Should have compacted entries");
}
}