use scirs2_cluster::{
affinity_propagation,
birch,
dbscan_clustering,
density::optics::optics,
density::{dbscan, DistanceMetric},
gaussian_mixture,
hdbscan,
hierarchy::{fcluster, linkage, ClusterCriterion, LinkageMethod, Metric},
meanshift::{mean_shift, MeanShiftOptions},
metrics::silhouette_score,
spectral_clustering,
vq::{kmeans2, MinitMethod, MissingMethod},
AffinityMode,
AffinityPropagationOptions,
BirchOptions,
CovarianceType,
GMMOptions,
HDBSCANOptions,
SpectralClusteringOptions,
};
use scirs2_core::ndarray::{array, Array2};
/// Runs ten clustering algorithms from `scirs2_cluster` over one small,
/// hard-coded 2-D data set and prints a summary of each result.
///
/// The data set holds fifteen points arranged as three groups of five.
/// Every algorithm's labels are handed to `print_results`, which expects
/// `usize` labels; results produced in another integer type are converted
/// with `as usize` first.
///
/// # Errors
///
/// Returns the first error reported by any of the clustering calls.
#[allow(dead_code)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Comprehensive Clustering Algorithm Test");
    println!("======================================\n");

    // Three groups of five points each, around (0, 0), (3, 3) and (6, 0).
    let samples = array![
        [0.0, 0.0],
        [0.1, 0.1],
        [0.2, 0.0],
        [0.0, 0.2],
        [0.15, 0.15],
        [3.0, 3.0],
        [3.1, 3.1],
        [3.2, 3.0],
        [3.0, 3.2],
        [3.15, 3.15],
        [6.0, 0.0],
        [6.1, 0.1],
        [6.2, 0.0],
        [6.0, 0.2],
        [6.15, 0.15],
    ];
    println!("Test data: 3 clusters with 5 points each\n");

    // ---- 1. K-means -------------------------------------------------------
    println!("1. K-means Clustering");
    let (_, kmeans_ids) = kmeans2(
        samples.view(),
        3,
        Some(10),
        None,
        Some(MinitMethod::Random),
        Some(MissingMethod::Warn),
        Some(true),
        Some(42),
    )?;
    print_results(&kmeans_ids, &samples);

    // ---- 2. DBSCAN --------------------------------------------------------
    println!("\n2. DBSCAN Clustering");
    // NOTE(review): presumably noise is labelled -1 here, which `as usize`
    // turns into usize::MAX (the noise marker in print_results) — confirm
    // against the dbscan documentation.
    let dbscan_ids = dbscan(samples.view(), 0.5, 2, Some(DistanceMetric::Euclidean))?
        .mapv(|label| label as usize);
    print_results(&dbscan_ids, &samples);

    // ---- 3. HDBSCAN -------------------------------------------------------
    println!("\n3. HDBSCAN Clustering");
    let hdbscan_config = HDBSCANOptions {
        min_cluster_size: 2,
        minsamples: Some(2),
        ..Default::default()
    };
    let hdbscan_out = hdbscan(samples.view(), Some(hdbscan_config))?;
    let every_point_is_noise = hdbscan_out.labels.iter().all(|&label| label == -1);
    // A label of -1 becomes usize::MAX under `as usize`; print_results
    // reports that value as noise.
    let hdbscan_ids = if every_point_is_noise {
        // Fall back to a DBSCAN-style extraction from the HDBSCAN result.
        println!(" HDBSCAN found all noise, extracting DBSCAN with cut_distance=1.0");
        dbscan_clustering(&hdbscan_out, 1.0)?.mapv(|label| label as usize)
    } else {
        hdbscan_out.labels.mapv(|label| label as usize)
    };
    print_results(&hdbscan_ids, &samples);

    // ---- 4. OPTICS --------------------------------------------------------
    println!("\n4. OPTICS Clustering");
    let optics_out = optics(samples.view(), 2, None, Some(DistanceMetric::Euclidean))?;
    let optics_ids =
        scirs2_cluster::density::optics::extract_dbscan_clustering(&optics_out, 0.5)
            .mapv(|label| label as usize);
    print_results(&optics_ids, &samples);

    // ---- 5. Mean Shift ----------------------------------------------------
    println!("\n5. Mean Shift Clustering");
    let mean_shift_config = MeanShiftOptions {
        bandwidth: Some(1.0),
        ..Default::default()
    };
    let (_, mean_shift_raw) = mean_shift(&samples.view(), mean_shift_config)?;
    let mean_shift_ids = mean_shift_raw.mapv(|label| label as usize);
    print_results(&mean_shift_ids, &samples);

    // ---- 6. Gaussian Mixture Model ----------------------------------------
    println!("\n6. Gaussian Mixture Model");
    let gmm_config = GMMOptions {
        n_components: 3,
        covariance_type: CovarianceType::Full,
        ..Default::default()
    };
    let gmm_ids = gaussian_mixture(samples.view(), gmm_config)?.mapv(|label| label as usize);
    print_results(&gmm_ids, &samples);

    // ---- 7. Hierarchical --------------------------------------------------
    println!("\n7. Hierarchical Clustering");
    let merge_tree = linkage(samples.view(), LinkageMethod::Complete, Metric::Euclidean)?;
    let hierarchy_ids = fcluster(&merge_tree, 3, Some(ClusterCriterion::MaxClust))?;
    print_results(&hierarchy_ids, &samples);

    // ---- 8. Spectral ------------------------------------------------------
    println!("\n8. Spectral Clustering");
    let spectral_config = SpectralClusteringOptions {
        affinity: AffinityMode::RBF,
        gamma: 1.0,
        ..Default::default()
    };
    let (_, spectral_ids) = spectral_clustering(samples.view(), 3, Some(spectral_config))?;
    print_results(&spectral_ids, &samples);

    // ---- 9. BIRCH ---------------------------------------------------------
    println!("\n9. BIRCH Clustering");
    let birch_config = BirchOptions {
        threshold: 1.0,
        n_clusters: Some(3),
        ..Default::default()
    };
    let (_, birch_raw) = birch(samples.view(), birch_config)?;
    let birch_ids = birch_raw.mapv(|label| label as usize);
    print_results(&birch_ids, &samples);

    // ---- 10. Affinity Propagation -----------------------------------------
    println!("\n10. Affinity Propagation");
    let affinity_config = AffinityPropagationOptions {
        damping: 0.7,
        preference: Some(-5.0),
        ..Default::default()
    };
    let (_, affinity_raw) = affinity_propagation(samples.view(), false, Some(affinity_config))?;
    let affinity_ids = affinity_raw.mapv(|label| label as usize);
    print_results(&affinity_ids, &samples);

    Ok(())
}
/// Prints a short summary of one clustering result to stdout.
///
/// Three things are reported:
/// 1. the number of distinct labels, not counting `usize::MAX`;
/// 2. the silhouette score, but only when more than one such label exists
///    (`N/A` is printed if the score computation returns an error);
/// 3. how many points carry each label, in ascending label order.
///
/// The label `usize::MAX` is treated as the noise marker: it is excluded
/// from the cluster count and shown as `Noise` in the distribution.
///
/// * `labels` - one label per data point.
/// * `data` - the points that were clustered; only used for the silhouette
///   score.
#[allow(dead_code)]
fn print_results(labels: &scirs2_core::ndarray::Array1<usize>, data: &Array2<f64>) {
    // One pass over the labels; the ordered map keeps them sorted by label,
    // so `usize::MAX` (noise) always ends up last.
    let mut tally = std::collections::BTreeMap::<usize, usize>::new();
    for &id in labels.iter() {
        *tally.entry(id).or_insert(0) += 1;
    }

    let n_clusters = tally.keys().filter(|&&id| id != usize::MAX).count();
    println!(" Clusters found: {}", n_clusters);

    if n_clusters > 1 {
        // The score function is given i32 labels; usize::MAX truncates to -1.
        let signed_labels = labels.mapv(|id| id as i32);
        match silhouette_score(data.view(), signed_labels.view()) {
            Ok(score) => println!(" Silhouette score: {:.3}", score),
            Err(_) => println!(" Silhouette score: N/A"),
        }
    }

    let entries: Vec<String> = tally
        .iter()
        .map(|(&id, &count)| {
            if id == usize::MAX {
                format!("Noise: {}", count)
            } else {
                format!("C{}: {}", id, count)
            }
        })
        .collect();
    print!(" Label distribution: ");
    println!("{}", entries.join(", "));
}