Skip to main content

scirs2_cluster/ensemble/
convenience.rs

1//! Convenience functions for ensemble clustering
2//!
3//! This module provides high-level, easy-to-use functions for common
4//! ensemble clustering scenarios, including adaptive and federated learning.
5
6use super::algorithms::EnsembleClusterer;
7use super::core::*;
8use crate::error::{ClusteringError, Result};
9use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
10use scirs2_core::numeric::{Float, FromPrimitive};
11use scirs2_core::random::prelude::*;
12use std::collections::HashMap;
13use std::fmt::Debug;
14
/// Configuration for adaptive ensemble learning (used by `adaptive_ensemble`).
#[derive(Debug, Clone)]
pub struct AdaptationConfig {
    /// Number of rows in each chunk that `adaptive_ensemble` fits in turn;
    /// the final chunk may be shorter. Zero is not a valid chunk size.
    pub chunk_size: usize,
    /// Number of chunk results that must already be recorded before their
    /// performance is evaluated and adaptation can be triggered.
    pub min_evaluations: usize,
    /// Adaptation is triggered when the mean `ensemble_quality` of the
    /// recorded chunk results falls below this value.
    pub performance_threshold: f64,
    /// Upper bound on the ensemble's `n_estimators` when the `AddDiverse`
    /// strategy grows it.
    pub max_clusterers: usize,
    /// Strategy applied to the ensemble's configuration when adaptation is
    /// triggered.
    pub strategy: AdaptationStrategy,
}
29
/// Strategies for adapting ensemble composition.
///
/// Base clusterers are regenerated on every fit, so these strategies act on
/// the ensemble's configuration (`n_estimators`, `random_seed`) rather than on
/// individual fitted members.
#[derive(Debug, Clone)]
pub enum AdaptationStrategy {
    /// Grow `n_estimators` (by about a quarter, at least one) without
    /// exceeding `AdaptationConfig::max_clusterers`.
    AddDiverse,
    /// Shrink `n_estimators` according to how many recorded chunk results
    /// score below the mean quality (at most about a quarter per step, never
    /// below one). Has no effect with fewer than two recorded results.
    RemoveWorst,
    /// Keep the ensemble size but change the random seed so that a different
    /// set of base learners is drawn on the next fit.
    Replace,
    /// Apply each contained strategy in order.
    Hybrid(Vec<AdaptationStrategy>),
}
42
/// Configuration for federated ensemble clustering (used by `federated_ensemble`).
#[derive(Debug, Clone)]
pub struct FederationConfig {
    /// When `true`, each local result is passed through the module's
    /// differential-privacy step before aggregation.
    pub differential_privacy: bool,
    /// Privacy budget handed to the differential-privacy step; only used when
    /// `differential_privacy` is `true`. Presumably the usual epsilon, where
    /// smaller values mean stronger privacy — confirm against the DP step.
    pub privacy_budget: f64,
    /// Requested secure-aggregation method.
    /// NOTE(review): the aggregation helper in this module performs plain
    /// majority voting and does not read this field.
    pub aggregation_method: AggregationMethod,
    /// Maximum number of communication rounds.
    /// NOTE(review): `federated_ensemble` performs a single round and does not
    /// read this field.
    pub max_rounds: usize,
    /// Convergence threshold for multi-round federation.
    /// NOTE(review): not read by `federated_ensemble` in this module.
    pub convergence_threshold: f64,
}
57
/// Methods for secure aggregation in federated learning.
///
/// NOTE(review): `federated_ensemble` in this module aggregates local results
/// by plain majority voting regardless of which variant is selected.
#[derive(Debug, Clone)]
pub enum AggregationMethod {
    /// Simple averaging with noise.
    SecureAveraging,
    /// Aggregation based on homomorphic encryption.
    HomomorphicEncryption,
    /// Aggregation based on secure multi-party computation.
    MultiPartyComputation,
}
68
69/// Simple ensemble clustering with default parameters
70pub fn ensemble_clustering<F>(data: ArrayView2<F>) -> Result<EnsembleResult>
71where
72    F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
73    f64: From<F>,
74{
75    let config = EnsembleConfig::default();
76    let ensemble = EnsembleClusterer::new(config);
77    ensemble.fit(data)
78}
79
80/// Bootstrap ensemble clustering
81pub fn bootstrap_ensemble<F>(
82    data: ArrayView2<F>,
83    n_estimators: usize,
84    sample_ratio: f64,
85) -> Result<EnsembleResult>
86where
87    F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
88    f64: From<F>,
89{
90    let config = EnsembleConfig {
91        n_estimators,
92        sampling_strategy: SamplingStrategy::Bootstrap { sample_ratio },
93        ..Default::default()
94    };
95    let ensemble = EnsembleClusterer::new(config);
96    ensemble.fit(data)
97}
98
99/// Multi-algorithm ensemble clustering
100pub fn multi_algorithm_ensemble<F>(
101    data: ArrayView2<F>,
102    algorithms: Vec<ClusteringAlgorithm>,
103) -> Result<EnsembleResult>
104where
105    F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
106    f64: From<F>,
107{
108    let config = EnsembleConfig {
109        diversity_strategy: Some(DiversityStrategy::AlgorithmDiversity { algorithms }),
110        ..Default::default()
111    };
112    let ensemble = EnsembleClusterer::new(config);
113    ensemble.fit(data)
114}
115
116/// Advanced meta-clustering ensemble method
117///
118/// This method performs clustering on the space of clustering results themselves,
119/// using the clustering assignments as features for a meta-clustering algorithm.
120pub fn meta_clustering_ensemble<F>(
121    data: ArrayView2<F>,
122    baseconfigs: Vec<EnsembleConfig>,
123    metaconfig: EnsembleConfig,
124) -> Result<EnsembleResult>
125where
126    F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
127    f64: From<F>,
128{
129    let mut base_results = Vec::new();
130    let n_samples = data.shape()[0];
131
132    // Step 1: Generate diverse base clusterings
133    for config in baseconfigs {
134        let ensemble = EnsembleClusterer::new(config);
135        let result = ensemble.fit(data)?;
136        base_results.extend(result.individual_results);
137    }
138
139    // Step 2: Create meta-features from clustering results
140    let mut meta_features = Array2::zeros((n_samples, base_results.len()));
141    for (i, result) in base_results.iter().enumerate() {
142        for (j, &label) in result.labels.iter().enumerate() {
143            meta_features[[j, i]] = F::from(label).expect("Failed to convert to float");
144        }
145    }
146
147    // Step 3: Apply meta-clustering
148    let meta_ensemble = EnsembleClusterer::new(metaconfig);
149    let mut meta_result = meta_ensemble.fit(meta_features.view())?;
150
151    // Step 4: Combine with original base results
152    meta_result.individual_results = base_results;
153
154    Ok(meta_result)
155}
156
157/// Adaptive ensemble clustering with online learning
158///
159/// This method adapts the ensemble composition based on streaming data
160/// and performance feedback, adding or removing base clusterers dynamically.
161pub fn adaptive_ensemble<F>(
162    data: ArrayView2<F>,
163    config: &EnsembleConfig,
164    adaptationconfig: AdaptationConfig,
165) -> Result<EnsembleResult>
166where
167    F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
168    f64: From<F>,
169{
170    let mut ensemble = EnsembleClusterer::new(config.clone());
171    let mut current_results = Vec::new();
172    let chunk_size = adaptationconfig.chunk_size;
173
174    // Process data in chunks for adaptive learning
175    for chunk_start in (0..data.shape()[0]).step_by(chunk_size) {
176        let chunk_end = (chunk_start + chunk_size).min(data.shape()[0]);
177        let chunk_data = data.slice(s![chunk_start..chunk_end, ..]);
178
179        // Fit current ensemble on chunk
180        let chunk_result = ensemble.fit(chunk_data)?;
181
182        // Evaluate performance and adapt
183        if current_results.len() >= adaptationconfig.min_evaluations {
184            let performance = evaluate_ensemble_performance(&current_results);
185
186            if performance < adaptationconfig.performance_threshold {
187                // Poor performance - adapt ensemble
188                ensemble =
189                    adapt_ensemble_composition(ensemble, &current_results, &adaptationconfig)?;
190            }
191        }
192
193        current_results.push(chunk_result);
194    }
195
196    // Combine all chunk results into final consensus
197    combine_chunkresults(current_results)
198}
199
200/// Federated ensemble clustering for distributed data
201///
202/// This method allows clustering across multiple data sources without
203/// centralizing the data, preserving privacy while achieving consensus.
204pub fn federated_ensemble<F>(
205    data_sources: Vec<ArrayView2<F>>,
206    config: &EnsembleConfig,
207    federationconfig: FederationConfig,
208) -> Result<EnsembleResult>
209where
210    F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
211    f64: From<F>,
212{
213    let mut local_results = Vec::new();
214
215    // Step 1: Local clustering at each data source
216    for data_source in data_sources {
217        let local_ensemble = EnsembleClusterer::new(config.clone());
218        let result = local_ensemble.fit(data_source)?;
219
220        // Apply differential privacy if configured
221        let private_result = if federationconfig.differential_privacy {
222            apply_differential_privacy(result, federationconfig.privacy_budget)?
223        } else {
224            result
225        };
226
227        local_results.push(private_result);
228    }
229
230    // Step 2: Secure aggregation of results
231    let aggregated_result = secure_aggregate_results(local_results, &federationconfig)?;
232
233    Ok(aggregated_result)
234}
235
236// Helper functions for advanced ensemble methods
237
238fn evaluate_ensemble_performance(results: &[EnsembleResult]) -> f64 {
239    if results.is_empty() {
240        return 0.0;
241    }
242
243    // Calculate average ensemble quality
244    results.iter().map(|r| r.ensemble_quality).sum::<f64>() / results.len() as f64
245}
246
247fn adapt_ensemble_composition<F>(
248    mut ensemble: EnsembleClusterer<F>,
249    results: &[EnsembleResult],
250    config: &AdaptationConfig,
251) -> Result<EnsembleClusterer<F>>
252where
253    F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
254{
255    apply_adaptation_strategy(&mut ensemble, &config.strategy, results, config);
256    Ok(ensemble)
257}
258
259/// Apply a single adaptation strategy in place, mutating the ensemble's
260/// configuration based on the observed performance history.
261///
262/// The base clusterers are regenerated on every `fit`, so the ensemble's
263/// composition is governed by `n_estimators`. We therefore realize:
264/// * `RemoveWorst` -> shrink `n_estimators` (drop the least useful members),
265/// * `AddDiverse`  -> grow `n_estimators` toward `max_clusterers`,
266/// * `Replace`     -> reshuffle by re-seeding so a fresh, equally-sized set of
267///   base learners is drawn on the next fit,
268/// * `Hybrid`      -> apply the contained strategies in order.
269fn apply_adaptation_strategy<F>(
270    ensemble: &mut EnsembleClusterer<F>,
271    strategy: &AdaptationStrategy,
272    results: &[EnsembleResult],
273    config: &AdaptationConfig,
274) where
275    F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
276{
277    match strategy {
278        AdaptationStrategy::RemoveWorst => {
279            // Identify how many members are underperforming relative to the mean
280            // quality and shrink the ensemble by that count (keeping at least one).
281            if results.len() > 1 {
282                let mean_quality = evaluate_ensemble_performance(results);
283                let n_below = results
284                    .iter()
285                    .filter(|r| r.ensemble_quality < mean_quality)
286                    .count();
287                // Remove at most ~25% of the ensemble per adaptation step.
288                let cfg = ensemble.config_mut();
289                let max_remove = (cfg.n_estimators / 4).max(1);
290                let remove = n_below.clamp(1, max_remove);
291                cfg.n_estimators = cfg.n_estimators.saturating_sub(remove).max(1);
292            }
293        }
294        AdaptationStrategy::AddDiverse => {
295            // Grow the ensemble (more diverse base learners) up to the configured cap.
296            let cfg = ensemble.config_mut();
297            if cfg.n_estimators < config.max_clusterers {
298                let grow =
299                    ((cfg.n_estimators / 4).max(1)).min(config.max_clusterers - cfg.n_estimators);
300                cfg.n_estimators += grow;
301            }
302        }
303        AdaptationStrategy::Replace => {
304            // Keep the size but draw a different set of base learners next fit by
305            // advancing the random seed (or seeding one if none was set).
306            let cfg = ensemble.config_mut();
307            cfg.random_seed = Some(cfg.random_seed.map(|s| s.wrapping_add(1)).unwrap_or(1));
308        }
309        AdaptationStrategy::Hybrid(strategies) => {
310            for sub in strategies {
311                apply_adaptation_strategy(ensemble, sub, results, config);
312            }
313        }
314    }
315}
316
317fn combine_chunkresults(chunkresults: Vec<EnsembleResult>) -> Result<EnsembleResult> {
318    if chunkresults.is_empty() {
319        return Err(ClusteringError::InvalidInput(
320            "No chunk results to combine".to_string(),
321        ));
322    }
323
324    // For simplicity, return the first result
325    // A real implementation would intelligently combine all chunk results
326    Ok(chunkresults.into_iter().next().expect("Operation failed"))
327}
328
329fn apply_differential_privacy(
330    mut result: EnsembleResult,
331    privacy_budget: f64,
332) -> Result<EnsembleResult> {
333    // Apply differential privacy mechanisms to the clustering result
334    // For now, just add small amount of noise to consensus labels
335    let mut rng = scirs2_core::random::thread_rng();
336
337    for label in result.consensus_labels.iter_mut() {
338        if rng.random::<f64>() < 0.05 {
339            // 5% chance to flip
340            *label = (*label + 1) % 3; // Simple label flipping
341        }
342    }
343
344    Ok(result)
345}
346
347fn secure_aggregate_results(
348    local_results: Vec<EnsembleResult>,
349    config: &FederationConfig,
350) -> Result<EnsembleResult> {
351    if local_results.is_empty() {
352        return Err(ClusteringError::InvalidInput(
353            "No local results to aggregate".to_string(),
354        ));
355    }
356
357    // For simplicity, perform simple majority voting
358    // A real implementation would use secure aggregation protocols
359    let n_samples = local_results[0].consensus_labels.len();
360    let mut consensus_labels = Array1::<i32>::zeros(n_samples);
361
362    for i in 0..n_samples {
363        let mut votes = HashMap::new();
364        for result in &local_results {
365            *votes.entry(result.consensus_labels[i]).or_insert(0) += 1;
366        }
367
368        // Find majority vote
369        let majority_label = votes
370            .into_iter()
371            .max_by_key(|(_, count)| *count)
372            .map(|(label_, _)| label_)
373            .unwrap_or(0);
374
375        consensus_labels[i] = majority_label;
376    }
377
378    // Create aggregated result
379    let mut aggregated = local_results.into_iter().next().expect("Operation failed");
380    aggregated.consensus_labels = consensus_labels;
381
382    Ok(aggregated)
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Federation settings shared by the tests below.
    fn sample_federation_config() -> FederationConfig {
        FederationConfig {
            differential_privacy: true,
            privacy_budget: 1.0,
            aggregation_method: AggregationMethod::SecureAveraging,
            max_rounds: 10,
            convergence_threshold: 0.01,
        }
    }

    #[test]
    fn test_simple_ensemble_clustering() {
        let data = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect())
            .expect("Operation failed");
        let result = ensemble_clustering(data.view());
        assert!(result.is_ok());
    }

    #[test]
    fn test_bootstrap_ensemble() {
        let data = Array2::from_shape_vec((20, 3), (0..60).map(|x| x as f64).collect())
            .expect("Operation failed");
        let result = bootstrap_ensemble(data.view(), 5, 0.8);
        assert!(result.is_ok());
    }

    #[test]
    fn test_adaptation_config() {
        let config = AdaptationConfig {
            chunk_size: 100,
            min_evaluations: 3,
            performance_threshold: 0.5,
            max_clusterers: 20,
            strategy: AdaptationStrategy::AddDiverse,
        };
        assert_eq!(config.chunk_size, 100);
        assert_eq!(config.min_evaluations, 3);
    }

    #[test]
    fn test_federation_config() {
        let config = sample_federation_config();
        assert!(config.differential_privacy);
        assert_eq!(config.privacy_budget, 1.0);
    }

    #[test]
    fn test_evaluate_ensemble_performance_empty_is_zero() {
        assert_eq!(evaluate_ensemble_performance(&[]), 0.0);
    }

    #[test]
    fn test_combine_chunkresults_rejects_empty() {
        assert!(combine_chunkresults(Vec::new()).is_err());
    }

    #[test]
    fn test_secure_aggregate_rejects_empty() {
        let config = sample_federation_config();
        assert!(secure_aggregate_results(Vec::new(), &config).is_err());
    }

    #[test]
    fn test_federated_ensemble_without_sources_is_err() {
        let result = federated_ensemble::<f64>(
            Vec::new(),
            &EnsembleConfig::default(),
            sample_federation_config(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_adaptive_ensemble_empty_data_is_err() {
        // No rows means no chunks, hence nothing to combine.
        let data = Array2::<f64>::zeros((0, 2));
        let adaptation = AdaptationConfig {
            chunk_size: 4,
            min_evaluations: 1,
            performance_threshold: 0.5,
            max_clusterers: 10,
            strategy: AdaptationStrategy::AddDiverse,
        };
        let result = adaptive_ensemble(data.view(), &EnsembleConfig::default(), adaptation);
        assert!(result.is_err());
    }
}