1use super::algorithms::EnsembleClusterer;
7use super::core::*;
8use crate::error::{ClusteringError, Result};
9use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
10use scirs2_core::numeric::{Float, FromPrimitive};
11use scirs2_core::random::prelude::*;
12use std::collections::HashMap;
13use std::fmt::Debug;
14
15#[derive(Debug, Clone)]
17pub struct AdaptationConfig {
18 pub chunk_size: usize,
20 pub min_evaluations: usize,
22 pub performance_threshold: f64,
24 pub max_clusterers: usize,
26 pub strategy: AdaptationStrategy,
28}
29
/// How an ensemble is modified when its measured performance is too low.
///
/// The per-variant notes describe what the adaptation routine in this module
/// does to the ensemble configuration.
#[derive(Debug, Clone)]
pub enum AdaptationStrategy {
    /// Raise the estimator count by a quarter of its current value (at least
    /// one), never exceeding `AdaptationConfig::max_clusterers`.
    AddDiverse,
    /// Lower the estimator count, based on how many past results scored below
    /// the mean quality; at least one estimator is always kept.
    RemoveWorst,
    /// Advance the configured random seed so subsequent fits differ.
    Replace,
    /// Apply each contained strategy in order.
    Hybrid(Vec<Self>),
}
42
43#[derive(Debug, Clone)]
45pub struct FederationConfig {
46 pub differential_privacy: bool,
48 pub privacy_budget: f64,
50 pub aggregation_method: AggregationMethod,
52 pub max_rounds: usize,
54 pub convergence_threshold: f64,
56}
57
/// Names the scheme used to merge results from several data sources.
///
/// NOTE(review): the aggregation code visible in this file performs a plain
/// per-sample majority vote and never branches on this value — confirm whether
/// the variants are implemented elsewhere.
#[derive(Debug, Clone)]
pub enum AggregationMethod {
    /// Secure averaging.
    SecureAveraging,
    /// Homomorphic encryption.
    HomomorphicEncryption,
    /// Multi-party computation.
    MultiPartyComputation,
}
68
69pub fn ensemble_clustering<F>(data: ArrayView2<F>) -> Result<EnsembleResult>
71where
72 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
73 f64: From<F>,
74{
75 let config = EnsembleConfig::default();
76 let ensemble = EnsembleClusterer::new(config);
77 ensemble.fit(data)
78}
79
80pub fn bootstrap_ensemble<F>(
82 data: ArrayView2<F>,
83 n_estimators: usize,
84 sample_ratio: f64,
85) -> Result<EnsembleResult>
86where
87 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
88 f64: From<F>,
89{
90 let config = EnsembleConfig {
91 n_estimators,
92 sampling_strategy: SamplingStrategy::Bootstrap { sample_ratio },
93 ..Default::default()
94 };
95 let ensemble = EnsembleClusterer::new(config);
96 ensemble.fit(data)
97}
98
99pub fn multi_algorithm_ensemble<F>(
101 data: ArrayView2<F>,
102 algorithms: Vec<ClusteringAlgorithm>,
103) -> Result<EnsembleResult>
104where
105 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
106 f64: From<F>,
107{
108 let config = EnsembleConfig {
109 diversity_strategy: Some(DiversityStrategy::AlgorithmDiversity { algorithms }),
110 ..Default::default()
111 };
112 let ensemble = EnsembleClusterer::new(config);
113 ensemble.fit(data)
114}
115
116pub fn meta_clustering_ensemble<F>(
121 data: ArrayView2<F>,
122 baseconfigs: Vec<EnsembleConfig>,
123 metaconfig: EnsembleConfig,
124) -> Result<EnsembleResult>
125where
126 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
127 f64: From<F>,
128{
129 let mut base_results = Vec::new();
130 let n_samples = data.shape()[0];
131
132 for config in baseconfigs {
134 let ensemble = EnsembleClusterer::new(config);
135 let result = ensemble.fit(data)?;
136 base_results.extend(result.individual_results);
137 }
138
139 let mut meta_features = Array2::zeros((n_samples, base_results.len()));
141 for (i, result) in base_results.iter().enumerate() {
142 for (j, &label) in result.labels.iter().enumerate() {
143 meta_features[[j, i]] = F::from(label).expect("Failed to convert to float");
144 }
145 }
146
147 let meta_ensemble = EnsembleClusterer::new(metaconfig);
149 let mut meta_result = meta_ensemble.fit(meta_features.view())?;
150
151 meta_result.individual_results = base_results;
153
154 Ok(meta_result)
155}
156
157pub fn adaptive_ensemble<F>(
162 data: ArrayView2<F>,
163 config: &EnsembleConfig,
164 adaptationconfig: AdaptationConfig,
165) -> Result<EnsembleResult>
166where
167 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
168 f64: From<F>,
169{
170 let mut ensemble = EnsembleClusterer::new(config.clone());
171 let mut current_results = Vec::new();
172 let chunk_size = adaptationconfig.chunk_size;
173
174 for chunk_start in (0..data.shape()[0]).step_by(chunk_size) {
176 let chunk_end = (chunk_start + chunk_size).min(data.shape()[0]);
177 let chunk_data = data.slice(s![chunk_start..chunk_end, ..]);
178
179 let chunk_result = ensemble.fit(chunk_data)?;
181
182 if current_results.len() >= adaptationconfig.min_evaluations {
184 let performance = evaluate_ensemble_performance(¤t_results);
185
186 if performance < adaptationconfig.performance_threshold {
187 ensemble =
189 adapt_ensemble_composition(ensemble, ¤t_results, &adaptationconfig)?;
190 }
191 }
192
193 current_results.push(chunk_result);
194 }
195
196 combine_chunkresults(current_results)
198}
199
200pub fn federated_ensemble<F>(
205 data_sources: Vec<ArrayView2<F>>,
206 config: &EnsembleConfig,
207 federationconfig: FederationConfig,
208) -> Result<EnsembleResult>
209where
210 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
211 f64: From<F>,
212{
213 let mut local_results = Vec::new();
214
215 for data_source in data_sources {
217 let local_ensemble = EnsembleClusterer::new(config.clone());
218 let result = local_ensemble.fit(data_source)?;
219
220 let private_result = if federationconfig.differential_privacy {
222 apply_differential_privacy(result, federationconfig.privacy_budget)?
223 } else {
224 result
225 };
226
227 local_results.push(private_result);
228 }
229
230 let aggregated_result = secure_aggregate_results(local_results, &federationconfig)?;
232
233 Ok(aggregated_result)
234}
235
236fn evaluate_ensemble_performance(results: &[EnsembleResult]) -> f64 {
239 if results.is_empty() {
240 return 0.0;
241 }
242
243 results.iter().map(|r| r.ensemble_quality).sum::<f64>() / results.len() as f64
245}
246
247fn adapt_ensemble_composition<F>(
248 mut ensemble: EnsembleClusterer<F>,
249 results: &[EnsembleResult],
250 config: &AdaptationConfig,
251) -> Result<EnsembleClusterer<F>>
252where
253 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
254{
255 apply_adaptation_strategy(&mut ensemble, &config.strategy, results, config);
256 Ok(ensemble)
257}
258
259fn apply_adaptation_strategy<F>(
270 ensemble: &mut EnsembleClusterer<F>,
271 strategy: &AdaptationStrategy,
272 results: &[EnsembleResult],
273 config: &AdaptationConfig,
274) where
275 F: Float + FromPrimitive + Debug + 'static + std::iter::Sum + std::fmt::Display + Send + Sync,
276{
277 match strategy {
278 AdaptationStrategy::RemoveWorst => {
279 if results.len() > 1 {
282 let mean_quality = evaluate_ensemble_performance(results);
283 let n_below = results
284 .iter()
285 .filter(|r| r.ensemble_quality < mean_quality)
286 .count();
287 let cfg = ensemble.config_mut();
289 let max_remove = (cfg.n_estimators / 4).max(1);
290 let remove = n_below.clamp(1, max_remove);
291 cfg.n_estimators = cfg.n_estimators.saturating_sub(remove).max(1);
292 }
293 }
294 AdaptationStrategy::AddDiverse => {
295 let cfg = ensemble.config_mut();
297 if cfg.n_estimators < config.max_clusterers {
298 let grow =
299 ((cfg.n_estimators / 4).max(1)).min(config.max_clusterers - cfg.n_estimators);
300 cfg.n_estimators += grow;
301 }
302 }
303 AdaptationStrategy::Replace => {
304 let cfg = ensemble.config_mut();
307 cfg.random_seed = Some(cfg.random_seed.map(|s| s.wrapping_add(1)).unwrap_or(1));
308 }
309 AdaptationStrategy::Hybrid(strategies) => {
310 for sub in strategies {
311 apply_adaptation_strategy(ensemble, sub, results, config);
312 }
313 }
314 }
315}
316
317fn combine_chunkresults(chunkresults: Vec<EnsembleResult>) -> Result<EnsembleResult> {
318 if chunkresults.is_empty() {
319 return Err(ClusteringError::InvalidInput(
320 "No chunk results to combine".to_string(),
321 ));
322 }
323
324 Ok(chunkresults.into_iter().next().expect("Operation failed"))
327}
328
329fn apply_differential_privacy(
330 mut result: EnsembleResult,
331 privacy_budget: f64,
332) -> Result<EnsembleResult> {
333 let mut rng = scirs2_core::random::thread_rng();
336
337 for label in result.consensus_labels.iter_mut() {
338 if rng.random::<f64>() < 0.05 {
339 *label = (*label + 1) % 3; }
342 }
343
344 Ok(result)
345}
346
347fn secure_aggregate_results(
348 local_results: Vec<EnsembleResult>,
349 config: &FederationConfig,
350) -> Result<EnsembleResult> {
351 if local_results.is_empty() {
352 return Err(ClusteringError::InvalidInput(
353 "No local results to aggregate".to_string(),
354 ));
355 }
356
357 let n_samples = local_results[0].consensus_labels.len();
360 let mut consensus_labels = Array1::<i32>::zeros(n_samples);
361
362 for i in 0..n_samples {
363 let mut votes = HashMap::new();
364 for result in &local_results {
365 *votes.entry(result.consensus_labels[i]).or_insert(0) += 1;
366 }
367
368 let majority_label = votes
370 .into_iter()
371 .max_by_key(|(_, count)| *count)
372 .map(|(label_, _)| label_)
373 .unwrap_or(0);
374
375 consensus_labels[i] = majority_label;
376 }
377
378 let mut aggregated = local_results.into_iter().next().expect("Operation failed");
380 aggregated.consensus_labels = consensus_labels;
381
382 Ok(aggregated)
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a `rows x cols` matrix filled with 0.0, 1.0, 2.0, ... in
    /// row-major order.
    fn ramp_matrix(rows: usize, cols: usize) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * cols).map(|v| v as f64).collect();
        Array2::from_shape_vec((rows, cols), values).expect("Operation failed")
    }

    /// The default ensemble must run on a small dense matrix.
    #[test]
    fn test_simple_ensemble_clustering() {
        let samples = ramp_matrix(10, 2);
        assert!(ensemble_clustering(samples.view()).is_ok());
    }

    /// A five-member bootstrap ensemble with 80% sampling must run.
    #[test]
    fn test_bootstrap_ensemble() {
        let samples = ramp_matrix(20, 3);
        assert!(bootstrap_ensemble(samples.view(), 5, 0.8).is_ok());
    }

    /// Field values given at construction are readable afterwards.
    #[test]
    fn test_adaptation_config() {
        let adaptation = AdaptationConfig {
            chunk_size: 100,
            min_evaluations: 3,
            performance_threshold: 0.5,
            max_clusterers: 20,
            strategy: AdaptationStrategy::AddDiverse,
        };
        assert_eq!(adaptation.chunk_size, 100);
        assert_eq!(adaptation.min_evaluations, 3);
    }

    /// Field values given at construction are readable afterwards.
    #[test]
    fn test_federation_config() {
        let federation = FederationConfig {
            differential_privacy: true,
            privacy_budget: 1.0,
            aggregation_method: AggregationMethod::SecureAveraging,
            max_rounds: 10,
            convergence_threshold: 0.01,
        };
        assert!(federation.differential_privacy);
        assert_eq!(federation.privacy_budget, 1.0);
    }
}