use scirs2_core::ndarray::ArrayStatCompat;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleConfig {
pub n_estimators: usize,
pub sampling_strategy: SamplingStrategy,
pub consensus_method: ConsensusMethod,
pub random_seed: Option<u64>,
pub diversity_strategy: Option<DiversityStrategy>,
pub quality_threshold: Option<f64>,
pub max_clusters: Option<usize>,
}
/// Sampling strategies selectable through `EnsembleConfig::sampling_strategy`.
///
/// NOTE(review): only the variant shapes are declared here; how each one is
/// applied is decided by the code that consumes this enum.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamplingStrategy {
/// Bootstrap sampling; `sample_ratio` is presumably the fraction of samples drawn — confirm against the sampler.
Bootstrap { sample_ratio: f64 },
/// Random feature subspace; `feature_ratio` is presumably the fraction of features kept — confirm.
RandomSubspace { feature_ratio: f64 },
/// Carries both a sample ratio and a feature ratio.
BootstrapSubspace {
sample_ratio: f64,
feature_ratio: f64,
},
/// Random projection parameterised by a target dimensionality.
RandomProjection { target_dimensions: usize },
/// Noise injection described by a level and a [`NoiseType`].
NoiseInjection {
noise_level: f64,
noise_type: NoiseType,
},
/// Variant without any sampling parameters.
None,
}
/// Kind of noise referenced by `SamplingStrategy::NoiseInjection`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NoiseType {
/// Gaussian noise.
Gaussian,
/// Uniform noise.
Uniform,
/// Outlier noise; `outlier_ratio` is presumably the fraction of affected points — confirm with the consumer.
Outliers { outlier_ratio: f64 },
}
/// Methods for combining individual clusterings, selectable through
/// `EnsembleConfig::consensus_method`.
///
/// NOTE(review): the algorithms themselves are implemented elsewhere; only
/// their parameters are declared here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConsensusMethod {
/// Majority voting (no parameters).
MajorityVoting,
/// Weighted consensus (no parameters).
WeightedConsensus,
/// Graph-based consensus parameterised by a similarity threshold.
GraphBased { similarity_threshold: f64 },
/// Hierarchical consensus; the linkage method is passed as a free-form string.
Hierarchical { linkage_method: String },
/// Co-association consensus parameterised by a threshold.
CoAssociation { threshold: f64 },
/// Evidence accumulation (no parameters).
EvidenceAccumulation,
}
/// Diversity strategies selectable through `EnsembleConfig::diversity_strategy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DiversityStrategy {
/// Diversity described by a list of clustering algorithms.
AlgorithmDiversity {
algorithms: Vec<ClusteringAlgorithm>,
},
/// Diversity described by one algorithm plus named parameter ranges.
ParameterDiversity {
algorithm: ClusteringAlgorithm,
// Keys are parameter names; their valid spellings are defined by the consumer.
parameter_ranges: HashMap<String, ParameterRange>,
},
/// Diversity described by a list of sampling strategies.
DataDiversity {
sampling_strategies: Vec<SamplingStrategy>,
},
/// A collection of nested strategies (the type is recursive).
Combined { strategies: Vec<DiversityStrategy> },
}
/// Clustering algorithms together with the parameter ranges they may use.
///
/// NOTE(review): each `*_range` tuple is presumably `(min, max)`; whether the
/// bounds are inclusive is not established here — confirm with the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusteringAlgorithm {
/// K-means with a range for `k`.
KMeans { k_range: (usize, usize) },
/// DBSCAN with ranges for `eps` and `min_samples`.
DBSCAN {
eps_range: (f64, f64),
min_samples_range: (usize, usize),
},
/// Mean shift with a bandwidth range.
MeanShift { bandwidth_range: (f64, f64) },
/// Hierarchical clustering with a list of method names (free-form strings).
Hierarchical { methods: Vec<String> },
/// Spectral clustering with a range for `k`.
Spectral { k_range: (usize, usize) },
/// Affinity propagation with a damping range.
AffinityPropagation { damping_range: (f64, f64) },
}
/// Value range of a single named parameter, used as the map value in
/// `DiversityStrategy::ParameterDiversity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterRange {
/// Two integer bounds (presumably `(min, max)` — confirm inclusivity with the consumer).
Integer(i64, i64),
/// Two floating-point bounds (presumably `(min, max)` — confirm inclusivity with the consumer).
Float(f64, f64),
/// A list of allowed string values.
Categorical(Vec<String>),
/// A boolean parameter; carries no data.
Boolean,
}
/// Outcome of a single clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
/// Cluster labels; the methods of this type treat negative labels as noise.
pub labels: Array1<i32>,
/// Name of the algorithm that produced the labels.
pub algorithm: String,
/// Parameters of the run as string key/value pairs.
pub parameters: HashMap<String, String>,
/// Quality score of the run (the metric is chosen by the caller).
pub quality_score: f64,
/// Optional stability score; `None` until set via `with_stability_score`.
pub stability_score: Option<f64>,
/// Cluster count; `new` sets it to the highest non-negative label plus one.
pub n_clusters: usize,
/// Runtime of the run (unit not established here — presumably seconds, confirm).
pub runtime: f64,
}
impl ClusteringResult {
pub fn new(
labels: Array1<i32>,
algorithm: String,
parameters: HashMap<String, String>,
quality_score: f64,
runtime: f64,
) -> Self {
let n_clusters = labels
.iter()
.copied()
.filter(|&x| x >= 0)
.max()
.map(|x| x as usize + 1)
.unwrap_or(0);
Self {
labels,
algorithm,
parameters,
quality_score,
stability_score: None,
n_clusters,
runtime,
}
}
pub fn with_stability_score(mut self, score: f64) -> Self {
self.stability_score = Some(score);
self
}
pub fn has_noise(&self) -> bool {
self.labels.iter().any(|&x| x < 0)
}
pub fn noise_count(&self) -> usize {
self.labels.iter().filter(|&&x| x < 0).count()
}
pub fn cluster_sizes(&self) -> Vec<usize> {
let mut sizes = vec![0; self.n_clusters];
for &label in self.labels.iter() {
if label >= 0 {
let cluster_id = label as usize;
if cluster_id < sizes.len() {
sizes[cluster_id] += 1;
}
}
}
sizes
}
}
/// Combined outcome of an ensemble clustering run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleResult {
/// Consensus labels; the methods of this type treat negative labels as unassigned.
pub consensus_labels: Array1<i32>,
/// Results of the individual clustering runs.
pub individual_results: Vec<ClusteringResult>,
/// Statistics describing the consensus.
pub consensus_stats: ConsensusStatistics,
/// Diversity metrics of the ensemble.
pub diversity_metrics: DiversityMetrics,
/// Quality score of the ensemble (the metric is chosen by the caller).
pub ensemble_quality: f64,
/// Stability score of the ensemble.
pub stability_score: f64,
}
impl EnsembleResult {
    /// Bundles the consensus labels, the individual results and the derived
    /// statistics into one value. No validation is performed.
    pub fn new(
        consensus_labels: Array1<i32>,
        individual_results: Vec<ClusteringResult>,
        consensus_stats: ConsensusStatistics,
        diversity_metrics: DiversityMetrics,
        ensemble_quality: f64,
        stability_score: f64,
    ) -> Self {
        Self {
            consensus_labels,
            individual_results,
            consensus_stats,
            diversity_metrics,
            ensemble_quality,
            stability_score,
        }
    }

    /// Returns the largest non-negative consensus label plus one, or zero
    /// when no label is non-negative.
    pub fn n_consensus_clusters(&self) -> usize {
        self.consensus_labels
            .iter()
            .copied()
            .filter(|&x| x >= 0)
            .max()
            .map(|x| x as usize + 1)
            .unwrap_or(0)
    }

    /// Returns the number of points per consensus cluster, indexed by label.
    ///
    /// Negative labels are not counted.
    pub fn consensus_cluster_sizes(&self) -> Vec<usize> {
        let mut sizes = vec![0; self.n_consensus_clusters()];
        for &label in self.consensus_labels.iter() {
            if label < 0 {
                continue;
            }
            if let Some(size) = sizes.get_mut(label as usize) {
                *size += 1;
            }
        }
        sizes
    }

    /// Returns the mean `quality_score` of the individual results, or `0.0`
    /// when there are none.
    pub fn average_individual_quality(&self) -> f64 {
        if self.individual_results.is_empty() {
            return 0.0;
        }
        let total: f64 = self.individual_results.iter().map(|r| r.quality_score).sum();
        total / self.individual_results.len() as f64
    }

    /// Returns the individual result with the highest `quality_score`, or
    /// `None` when there are no individual results.
    ///
    /// A NaN score ranks below every real score, so a NaN entry can no longer
    /// hide a better result. Ties resolve to the last of the equal entries.
    pub fn best_individual_result(&self) -> Option<&ClusteringResult> {
        use std::cmp::Ordering;

        self.individual_results.iter().max_by(|a, b| {
            match (a.quality_score.is_nan(), b.quality_score.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Less,
                (false, true) => Ordering::Greater,
                // Both are real numbers here, so `partial_cmp` always yields
                // `Some`; the fallback only keeps the expression total.
                (false, false) => a
                    .quality_score
                    .partial_cmp(&b.quality_score)
                    .unwrap_or(Ordering::Equal),
            }
        })
    }

    /// Counts how many individual results each algorithm name produced.
    pub fn algorithm_distribution(&self) -> HashMap<String, usize> {
        let mut distribution = HashMap::new();
        for result in &self.individual_results {
            *distribution.entry(result.algorithm.clone()).or_insert(0) += 1;
        }
        distribution
    }
}
/// Statistics describing how strongly the ensemble members agree.
///
/// NOTE(review): the shapes and index meaning of the arrays are not
/// established by this definition — confirm with the code that fills them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsensusStatistics {
/// Agreement matrix (presumably pairwise between samples — confirm).
pub agreement_matrix: Array2<f64>,
/// Consensus strength values; summarised by the average/min/max helpers.
pub consensus_strength: Array1<f64>,
/// Stability values (presumably one per cluster — confirm).
pub cluster_stability: Vec<f64>,
/// Agreement counts.
pub agreement_counts: Array1<usize>,
}
impl ConsensusStatistics {
pub fn new(
agreement_matrix: Array2<f64>,
consensus_strength: Array1<f64>,
cluster_stability: Vec<f64>,
agreement_counts: Array1<usize>,
) -> Self {
Self {
agreement_matrix,
consensus_strength,
cluster_stability,
agreement_counts,
}
}
pub fn average_consensus_strength(&self) -> f64 {
self.consensus_strength.mean_or(0.0)
}
pub fn min_consensus_strength(&self) -> f64 {
self.consensus_strength
.iter()
.cloned()
.fold(f64::INFINITY, f64::min)
}
pub fn max_consensus_strength(&self) -> f64 {
self.consensus_strength
.iter()
.cloned()
.fold(f64::NEG_INFINITY, f64::max)
}
pub fn average_cluster_stability(&self) -> f64 {
if self.cluster_stability.is_empty() {
0.0
} else {
self.cluster_stability.iter().sum::<f64>() / self.cluster_stability.len() as f64
}
}
}
/// Diversity measurements of an ensemble.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiversityMetrics {
/// Average diversity; `diversity_variance` uses it as the mean and
/// `has_good_diversity` compares it against a threshold.
pub average_diversity: f64,
/// Diversity matrix (presumably pairwise between ensemble members — confirm).
pub diversity_matrix: Array2<f64>,
/// Number of results per algorithm name.
pub algorithm_distribution: HashMap<String, usize>,
/// Diversity value per parameter name.
pub parameter_diversity: HashMap<String, f64>,
}
impl DiversityMetrics {
    /// Stores the given values without validating them.
    ///
    /// `average_diversity` is taken as supplied; it is not recomputed from
    /// `diversity_matrix`.
    pub fn new(
        average_diversity: f64,
        diversity_matrix: Array2<f64>,
        algorithm_distribution: HashMap<String, usize>,
        parameter_diversity: HashMap<String, f64>,
    ) -> Self {
        Self {
            average_diversity,
            diversity_matrix,
            algorithm_distribution,
            parameter_diversity,
        }
    }

    /// Returns the largest entry of `diversity_matrix`.
    ///
    /// Yields `f64::NEG_INFINITY` when the matrix is empty.
    pub fn max_diversity(&self) -> f64 {
        self.diversity_matrix
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max)
    }

    /// Returns the smallest entry of `diversity_matrix`.
    ///
    /// Yields `f64::INFINITY` when the matrix is empty.
    pub fn min_diversity(&self) -> f64 {
        self.diversity_matrix
            .iter()
            .cloned()
            .fold(f64::INFINITY, f64::min)
    }

    /// Returns the mean squared deviation of all matrix entries from
    /// `average_diversity`.
    ///
    /// Every entry of the matrix contributes, and the stored
    /// `average_diversity` is used as the centre. An empty matrix yields
    /// `0.0` rather than the NaN a `0.0 / 0.0` division would produce.
    pub fn diversity_variance(&self) -> f64 {
        let count = self.diversity_matrix.len();
        if count == 0 {
            return 0.0;
        }
        let center = self.average_diversity;
        let squared_deviation: f64 = self
            .diversity_matrix
            .iter()
            .map(|&x| (x - center).powi(2))
            .sum();
        squared_deviation / count as f64
    }

    /// Returns `true` when `average_diversity` is at least `threshold`.
    pub fn has_good_diversity(&self, threshold: f64) -> bool {
        self.average_diversity >= threshold
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    /// Tolerance for comparing floating-point values produced by arithmetic.
    const EPS: f64 = 1e-10;

    /// The default configuration uses 10 estimators, bootstrap sampling and
    /// majority voting.
    #[test]
    fn test_ensemble_config_default() {
        let config = EnsembleConfig::default();
        assert_eq!(config.n_estimators, 10);
        assert!(matches!(
            config.sampling_strategy,
            SamplingStrategy::Bootstrap { .. }
        ));
        assert!(matches!(
            config.consensus_method,
            ConsensusMethod::MajorityVoting
        ));
    }

    /// Cluster count, noise detection and per-cluster sizes are derived from
    /// the labels, with negative labels treated as noise.
    #[test]
    fn test_clustering_result_creation() {
        let labels = arr1(&[0, 0, 1, 1, -1]);
        let mut params = HashMap::new();
        params.insert("k".to_string(), "2".to_string());
        let result = ClusteringResult::new(labels, "kmeans".to_string(), params, 0.8, 1.5);
        assert_eq!(result.n_clusters, 2);
        assert!(result.has_noise());
        assert_eq!(result.noise_count(), 1);
        assert_eq!(result.cluster_sizes(), vec![2, 2]);
    }

    /// Aggregate accessors of `EnsembleResult` over two individual results.
    #[test]
    fn test_ensemble_result_metrics() {
        let consensus_labels = arr1(&[0, 0, 1, 1]);
        let individual_results = vec![
            ClusteringResult::new(
                arr1(&[0, 0, 1, 1]),
                "kmeans".to_string(),
                HashMap::new(),
                0.8,
                1.0,
            ),
            ClusteringResult::new(
                arr1(&[1, 1, 0, 0]),
                "dbscan".to_string(),
                HashMap::new(),
                0.7,
                1.5,
            ),
        ];
        let consensus_stats = ConsensusStatistics::new(
            Array2::zeros((2, 2)),
            arr1(&[0.9, 0.9, 0.8, 0.8]),
            vec![0.9, 0.8],
            arr1(&[2, 2, 2, 2]),
        );
        let diversity_metrics =
            DiversityMetrics::new(0.5, Array2::zeros((2, 2)), HashMap::new(), HashMap::new());
        let result = EnsembleResult::new(
            consensus_labels,
            individual_results,
            consensus_stats,
            diversity_metrics,
            0.85,
            0.9,
        );
        assert_eq!(result.n_consensus_clusters(), 2);
        assert_eq!(result.consensus_cluster_sizes(), vec![2, 2]);
        // The average is the result of float arithmetic, so compare with a tolerance.
        assert!((result.average_individual_quality() - 0.75).abs() < EPS);

        let best = result
            .best_individual_result()
            .expect("two individual results were supplied");
        assert_eq!(best.algorithm, "kmeans");

        let distribution = result.algorithm_distribution();
        assert_eq!(distribution.get("kmeans"), Some(&1));
        assert_eq!(distribution.get("dbscan"), Some(&1));
    }

    /// Summary helpers of `ConsensusStatistics`.
    #[test]
    fn test_consensus_statistics() {
        let stats = ConsensusStatistics::new(
            Array2::zeros((3, 3)),
            arr1(&[0.8, 0.9, 0.7]),
            vec![0.9, 0.8, 0.85],
            arr1(&[3, 2, 3]),
        );
        assert!((stats.average_consensus_strength() - 0.8).abs() < EPS);
        // Min and max return stored values unchanged, so exact equality is safe.
        assert_eq!(stats.min_consensus_strength(), 0.7);
        assert_eq!(stats.max_consensus_strength(), 0.9);
        assert!((stats.average_cluster_stability() - 0.85).abs() < EPS);
    }

    /// Summary helpers of `DiversityMetrics`.
    #[test]
    fn test_diversity_metrics() {
        let metrics = DiversityMetrics::new(
            0.6,
            Array2::from_shape_vec((2, 2), vec![0.0, 0.8, 0.8, 0.0]).expect("Operation failed"),
            HashMap::new(),
            HashMap::new(),
        );
        assert_eq!(metrics.max_diversity(), 0.8);
        assert_eq!(metrics.min_diversity(), 0.0);
        // (0.36 + 0.04 + 0.04 + 0.36) / 4 around the stored mean of 0.6.
        assert!((metrics.diversity_variance() - 0.2).abs() < EPS);
        assert!(metrics.has_good_diversity(0.5));
        assert!(!metrics.has_good_diversity(0.7));
    }
}