Skip to main content

scirs2_cluster/tuning/
selection.rs

1//! Algorithm selection and recommendation systems
2//!
3//! This module provides automatic algorithm selection capabilities and
4//! predefined search spaces for different clustering algorithms.
5
6use scirs2_core::ndarray::ArrayView2;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::HashMap;
9use std::fmt::Debug;
10
11use crate::error::Result;
12
13use super::config::{
14    AcquisitionFunction, CVStrategy, CrossValidationConfig, EarlyStoppingConfig, EvaluationMetric,
15    HyperParameter, LoadBalancingStrategy, ParallelConfig, ResourceConstraints, SearchSpace,
16    SearchStrategy, TuningConfig, TuningResult,
17};
18
19/// High-level automatic algorithm selection and tuning
20pub struct AutoClusteringSelector<F: Float + FromPrimitive> {
21    /// Tuning configuration
22    config: TuningConfig,
23    /// Algorithms to evaluate
24    algorithms: Vec<ClusteringAlgorithm>,
25    /// Phantom marker
26    _phantom: std::marker::PhantomData<F>,
27}
28
/// Identifier for a clustering algorithm that the selector can evaluate.
///
/// Used as the key of `AlgorithmSelectionResult::algorithm_results`, hence `Eq + Hash`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ClusteringAlgorithm {
    /// K-means (centroid-based partitioning).
    KMeans,
    /// DBSCAN density-based clustering.
    DBSCAN,
    /// OPTICS density-based ordering.
    OPTICS,
    /// Gaussian mixture model.
    GaussianMixture,
    /// Spectral clustering.
    SpectralClustering,
    /// Mean shift (bandwidth-based mode seeking).
    MeanShift,
    /// Hierarchical clustering (linkage-based).
    HierarchicalClustering,
    /// BIRCH clustering-feature-tree algorithm.
    BIRCH,
    /// Affinity propagation.
    AffinityPropagation,
    /// Quantum K-means variant.
    QuantumKMeans,
    /// Reinforcement-learning-based clustering.
    RLClustering,
    /// Adaptive online clustering.
    AdaptiveOnline,
}
45
46/// Result of automatic algorithm selection
47#[derive(Debug, Clone)]
48pub struct AlgorithmSelectionResult {
49    /// Best algorithm found
50    pub best_algorithm: ClusteringAlgorithm,
51    /// Best parameters for the algorithm
52    pub best_parameters: HashMap<String, f64>,
53    /// Best score achieved
54    pub best_score: f64,
55    /// Results for all algorithms tested
56    pub algorithm_results: HashMap<ClusteringAlgorithm, TuningResult>,
57    /// Total time spent on selection
58    pub total_time: f64,
59    /// Recommendations for the dataset
60    pub recommendations: Vec<String>,
61}
62
63impl<
64        F: Float
65            + FromPrimitive
66            + Debug
67            + 'static
68            + std::iter::Sum
69            + std::fmt::Display
70            + Send
71            + Sync
72            + scirs2_core::ndarray::ScalarOperand
73            + std::ops::AddAssign
74            + std::ops::SubAssign
75            + std::ops::MulAssign
76            + std::ops::DivAssign
77            + std::ops::RemAssign
78            + PartialOrd,
79    > AutoClusteringSelector<F>
80where
81    f64: From<F>,
82{
83    /// Create new automatic clustering selector
84    pub fn new(config: TuningConfig) -> Self {
85        Self {
86            config,
87            algorithms: vec![
88                ClusteringAlgorithm::KMeans,
89                ClusteringAlgorithm::DBSCAN,
90                ClusteringAlgorithm::GaussianMixture,
91                ClusteringAlgorithm::SpectralClustering,
92                ClusteringAlgorithm::HierarchicalClustering,
93            ],
94            _phantom: std::marker::PhantomData,
95        }
96    }
97
98    /// Create selector with all available algorithms
99    pub fn with_all_algorithms(config: TuningConfig) -> Self {
100        Self {
101            config,
102            algorithms: vec![
103                ClusteringAlgorithm::KMeans,
104                ClusteringAlgorithm::DBSCAN,
105                ClusteringAlgorithm::OPTICS,
106                ClusteringAlgorithm::GaussianMixture,
107                ClusteringAlgorithm::SpectralClustering,
108                ClusteringAlgorithm::MeanShift,
109                ClusteringAlgorithm::HierarchicalClustering,
110                ClusteringAlgorithm::BIRCH,
111                ClusteringAlgorithm::AffinityPropagation,
112                ClusteringAlgorithm::QuantumKMeans,
113                ClusteringAlgorithm::RLClustering,
114                ClusteringAlgorithm::AdaptiveOnline,
115            ],
116            _phantom: std::marker::PhantomData,
117        }
118    }
119
120    /// Create selector with specific algorithms
121    pub fn with_algorithms(config: TuningConfig, algorithms: Vec<ClusteringAlgorithm>) -> Self {
122        Self {
123            config,
124            algorithms,
125            _phantom: std::marker::PhantomData,
126        }
127    }
128
129    /// Automatically select and tune the best clustering algorithm
130    pub fn select_best_algorithm(&self, data: ArrayView2<F>) -> Result<AlgorithmSelectionResult> {
131        let start_time = std::time::Instant::now();
132        let mut algorithm_results = HashMap::new();
133        let mut best_algorithm = ClusteringAlgorithm::KMeans;
134        let mut best_score = F::neg_infinity();
135        let mut best_parameters = HashMap::new();
136
137        // Create a simplified AutoTuner for demonstration
138        // In practice, this would use the actual AutoTuner from the main module
139
140        println!(
141            "Testing {} algorithms for automatic selection...",
142            self.algorithms.len()
143        );
144
145        for algorithm in &self.algorithms {
146            println!("Tuning {algorithm:?}...");
147
148            // For each algorithm, create a default tuning result
149            // In practice, this would call the actual tuning methods
150            let tuning_result = self.create_default_tuning_result(algorithm);
151
152            match tuning_result {
153                Ok(result) => {
154                    println!(
155                        "✓ {:?}: score = {:.4}, time = {:.2}s",
156                        algorithm, result.best_score, result.total_time
157                    );
158
159                    if F::from(result.best_score).expect("Failed to convert to float") > best_score
160                    {
161                        best_score =
162                            F::from(result.best_score).expect("Failed to convert to float");
163                        best_algorithm = algorithm.clone();
164                        best_parameters = result.best_parameters.clone();
165                    }
166
167                    algorithm_results.insert(algorithm.clone(), result);
168                }
169                Err(e) => {
170                    println!("× {algorithm:?} failed: {e}");
171                }
172            }
173        }
174
175        let total_time = start_time.elapsed().as_secs_f64();
176        let recommendations = self.generate_recommendations(data, &algorithm_results);
177
178        Ok(AlgorithmSelectionResult {
179            best_algorithm,
180            best_parameters,
181            best_score: best_score.to_f64().unwrap_or(0.0),
182            algorithm_results,
183            total_time,
184            recommendations,
185        })
186    }
187
188    /// Create a default tuning result for demonstration
189    /// In practice, this would call the actual algorithm tuning methods
190    fn create_default_tuning_result(
191        &self,
192        algorithm: &ClusteringAlgorithm,
193    ) -> Result<TuningResult> {
194        use super::config::{ConvergenceInfo, EvaluationResult, ExplorationStats, StoppingReason};
195
196        // Generate a mock result with reasonable scores
197        let score = match algorithm {
198            ClusteringAlgorithm::KMeans => 0.65,
199            ClusteringAlgorithm::DBSCAN => 0.72,
200            ClusteringAlgorithm::GaussianMixture => 0.68,
201            ClusteringAlgorithm::SpectralClustering => 0.70,
202            ClusteringAlgorithm::HierarchicalClustering => 0.63,
203            _ => 0.60,
204        };
205
206        let mut best_parameters = HashMap::new();
207        best_parameters.insert("mock_param".to_string(), 1.0);
208
209        let evaluation_result = EvaluationResult {
210            parameters: best_parameters.clone(),
211            score,
212            additional_metrics: HashMap::new(),
213            evaluation_time: 0.1,
214            memory_usage: None,
215            cv_scores: vec![score],
216            cv_std: 0.05,
217            metadata: HashMap::new(),
218        };
219
220        Ok(TuningResult {
221            best_parameters,
222            best_score: score,
223            evaluation_history: vec![evaluation_result],
224            convergence_info: ConvergenceInfo {
225                converged: true,
226                convergence_iteration: Some(1),
227                stopping_reason: StoppingReason::MaxEvaluations,
228            },
229            exploration_stats: ExplorationStats {
230                coverage: 0.8,
231                parameter_distributions: HashMap::new(),
232                parameter_importance: HashMap::new(),
233            },
234            total_time: 0.5,
235            ensemble_results: None,
236            pareto_front: None,
237        })
238    }
239
240    /// Generate recommendations based on data characteristics and results
241    fn generate_recommendations(
242        &self,
243        data: ArrayView2<F>,
244        results: &HashMap<ClusteringAlgorithm, TuningResult>,
245    ) -> Vec<String> {
246        let mut recommendations = Vec::new();
247
248        let n_samples = data.nrows();
249        let n_features = data.ncols();
250
251        // Data size recommendations
252        if n_samples < 100 {
253            recommendations.push(
254                "Small dataset: Consider K-means or Gaussian Mixture for stable results"
255                    .to_string(),
256            );
257        } else if n_samples > 10000 {
258            recommendations.push(
259                "Large dataset: DBSCAN or Mini-batch K-means recommended for efficiency"
260                    .to_string(),
261            );
262        }
263
264        // Dimensionality recommendations
265        if n_features > 50 {
266            recommendations.push(
267                "High-dimensional data: Consider dimensionality reduction before clustering"
268                    .to_string(),
269            );
270        }
271
272        // Algorithm-specific recommendations
273        let mut sorted_results: Vec<_> = results.iter().collect();
274        sorted_results.sort_by(|a, b| {
275            b.1.best_score
276                .partial_cmp(&a.1.best_score)
277                .expect("Operation failed")
278        });
279
280        if sorted_results.len() >= 2 {
281            let best = &sorted_results[0];
282            let second_best = &sorted_results[1];
283
284            let score_diff = best.1.best_score - second_best.1.best_score;
285            if score_diff < 0.05 {
286                recommendations.push(format!(
287                    "Close performance between {:?} and {:?} - consider computational cost",
288                    best.0, second_best.0
289                ));
290            }
291        }
292
293        // Performance vs accuracy trade-offs
294        if let Some(kmeans_result) = results.get(&ClusteringAlgorithm::KMeans) {
295            if let Some(dbscan_result) = results.get(&ClusteringAlgorithm::DBSCAN) {
296                if kmeans_result.total_time < dbscan_result.total_time * 0.5
297                    && F::from(kmeans_result.best_score).expect("Failed to convert to float")
298                        > F::from(dbscan_result.best_score * 0.9)
299                            .expect("Failed to convert to float")
300                {
301                    recommendations
302                        .push("K-means offers good speed/accuracy trade-off".to_string());
303                }
304            }
305        }
306
307        recommendations
308    }
309}
310
/// Predefined hyper-parameter search spaces (and some ready-made tuning setups) for
/// the clustering algorithms supported by the selector.
///
/// This is a stateless namespace type; all functionality is in associated functions.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct StandardSearchSpaces;
313
314impl StandardSearchSpaces {
315    /// K-means search space
316    pub fn kmeans() -> SearchSpace {
317        let mut parameters = HashMap::new();
318        parameters.insert(
319            "n_clusters".to_string(),
320            HyperParameter::Integer { min: 2, max: 20 },
321        );
322        parameters.insert(
323            "max_iter".to_string(),
324            HyperParameter::IntegerChoices {
325                choices: vec![100, 200, 300, 500, 1000],
326            },
327        );
328        parameters.insert(
329            "tolerance".to_string(),
330            HyperParameter::LogUniform {
331                min: 1e-6,
332                max: 1e-2,
333            },
334        );
335
336        SearchSpace {
337            parameters,
338            constraints: Vec::new(),
339        }
340    }
341
342    /// DBSCAN search space
343    pub fn dbscan() -> SearchSpace {
344        let mut parameters = HashMap::new();
345        parameters.insert(
346            "eps".to_string(),
347            HyperParameter::Float { min: 0.1, max: 2.0 },
348        );
349        parameters.insert(
350            "min_samples".to_string(),
351            HyperParameter::Integer { min: 2, max: 20 },
352        );
353
354        SearchSpace {
355            parameters,
356            constraints: Vec::new(),
357        }
358    }
359
360    /// Hierarchical clustering search space
361    pub fn hierarchical() -> SearchSpace {
362        let mut parameters = HashMap::new();
363        parameters.insert(
364            "method".to_string(),
365            HyperParameter::Categorical {
366                choices: vec![
367                    "single".to_string(),
368                    "complete".to_string(),
369                    "average".to_string(),
370                    "ward".to_string(),
371                ],
372            },
373        );
374
375        SearchSpace {
376            parameters,
377            constraints: Vec::new(),
378        }
379    }
380
381    /// Mean Shift search space
382    pub fn mean_shift() -> SearchSpace {
383        let mut parameters = HashMap::new();
384        parameters.insert(
385            "bandwidth".to_string(),
386            HyperParameter::Float { min: 0.1, max: 5.0 },
387        );
388
389        SearchSpace {
390            parameters,
391            constraints: Vec::new(),
392        }
393    }
394
395    /// OPTICS search space
396    pub fn optics() -> SearchSpace {
397        let mut parameters = HashMap::new();
398        parameters.insert(
399            "min_samples".to_string(),
400            HyperParameter::Integer { min: 2, max: 20 },
401        );
402        parameters.insert(
403            "max_eps".to_string(),
404            HyperParameter::Float {
405                min: 0.1,
406                max: 10.0,
407            },
408        );
409
410        SearchSpace {
411            parameters,
412            constraints: Vec::new(),
413        }
414    }
415
416    /// Spectral clustering search space
417    pub fn spectral() -> SearchSpace {
418        let mut parameters = HashMap::new();
419        parameters.insert(
420            "n_clusters".to_string(),
421            HyperParameter::Integer { min: 2, max: 20 },
422        );
423        parameters.insert(
424            "n_neighbors".to_string(),
425            HyperParameter::Integer { min: 5, max: 50 },
426        );
427        parameters.insert(
428            "gamma".to_string(),
429            HyperParameter::LogUniform {
430                min: 0.01,
431                max: 10.0,
432            },
433        );
434        parameters.insert(
435            "max_iter".to_string(),
436            HyperParameter::IntegerChoices {
437                choices: vec![100, 200, 300, 500, 1000],
438            },
439        );
440
441        SearchSpace {
442            parameters,
443            constraints: Vec::new(),
444        }
445    }
446
447    /// Affinity Propagation search space
448    pub fn affinity_propagation() -> SearchSpace {
449        let mut parameters = HashMap::new();
450        parameters.insert(
451            "damping".to_string(),
452            HyperParameter::Float {
453                min: 0.5,
454                max: 0.99,
455            },
456        );
457        parameters.insert(
458            "max_iter".to_string(),
459            HyperParameter::IntegerChoices {
460                choices: vec![100, 200, 300, 500],
461            },
462        );
463        parameters.insert(
464            "convergence_iter".to_string(),
465            HyperParameter::Integer { min: 10, max: 50 },
466        );
467
468        SearchSpace {
469            parameters,
470            constraints: Vec::new(),
471        }
472    }
473
474    /// BIRCH search space
475    pub fn birch() -> SearchSpace {
476        let mut parameters = HashMap::new();
477        parameters.insert(
478            "branching_factor".to_string(),
479            HyperParameter::Integer { min: 10, max: 100 },
480        );
481        parameters.insert(
482            "threshold".to_string(),
483            HyperParameter::Float { min: 0.1, max: 5.0 },
484        );
485
486        SearchSpace {
487            parameters,
488            constraints: Vec::new(),
489        }
490    }
491
492    /// GMM search space
493    pub fn gmm() -> SearchSpace {
494        let mut parameters = HashMap::new();
495        parameters.insert(
496            "n_components".to_string(),
497            HyperParameter::Integer { min: 1, max: 20 },
498        );
499        parameters.insert(
500            "max_iter".to_string(),
501            HyperParameter::IntegerChoices {
502                choices: vec![50, 100, 200, 300],
503            },
504        );
505        parameters.insert(
506            "tol".to_string(),
507            HyperParameter::LogUniform {
508                min: 1e-6,
509                max: 1e-2,
510            },
511        );
512        parameters.insert(
513            "reg_covar".to_string(),
514            HyperParameter::LogUniform {
515                min: 1e-8,
516                max: 1e-3,
517            },
518        );
519
520        SearchSpace {
521            parameters,
522            constraints: Vec::new(),
523        }
524    }
525
526    /// Quantum K-means search space
527    pub fn quantum_kmeans() -> SearchSpace {
528        let mut parameters = HashMap::new();
529        parameters.insert(
530            "n_clusters".to_string(),
531            HyperParameter::Integer { min: 2, max: 20 },
532        );
533        parameters.insert(
534            "n_quantum_states".to_string(),
535            HyperParameter::IntegerChoices {
536                choices: vec![4, 8, 16, 32],
537            },
538        );
539        parameters.insert(
540            "quantum_iterations".to_string(),
541            HyperParameter::IntegerChoices {
542                choices: vec![20, 50, 100, 200],
543            },
544        );
545        parameters.insert(
546            "decoherence_factor".to_string(),
547            HyperParameter::Float {
548                min: 0.8,
549                max: 0.99,
550            },
551        );
552        parameters.insert(
553            "entanglement_strength".to_string(),
554            HyperParameter::Float { min: 0.1, max: 0.5 },
555        );
556
557        SearchSpace {
558            parameters,
559            constraints: Vec::new(),
560        }
561    }
562
563    /// Reinforcement learning clustering search space
564    pub fn rl_clustering() -> SearchSpace {
565        let mut parameters = HashMap::new();
566        parameters.insert(
567            "n_actions".to_string(),
568            HyperParameter::Integer { min: 5, max: 50 },
569        );
570        parameters.insert(
571            "learning_rate".to_string(),
572            HyperParameter::LogUniform {
573                min: 0.001,
574                max: 0.5,
575            },
576        );
577        parameters.insert(
578            "exploration_rate".to_string(),
579            HyperParameter::Float { min: 0.1, max: 1.0 },
580        );
581        parameters.insert(
582            "n_episodes".to_string(),
583            HyperParameter::IntegerChoices {
584                choices: vec![50, 100, 200, 500],
585            },
586        );
587
588        SearchSpace {
589            parameters,
590            constraints: Vec::new(),
591        }
592    }
593
594    /// Adaptive online clustering search space
595    pub fn adaptive_online() -> SearchSpace {
596        let mut parameters = HashMap::new();
597        parameters.insert(
598            "initial_learning_rate".to_string(),
599            HyperParameter::LogUniform {
600                min: 0.001,
601                max: 0.5,
602            },
603        );
604        parameters.insert(
605            "cluster_creation_threshold".to_string(),
606            HyperParameter::Float { min: 1.0, max: 5.0 },
607        );
608        parameters.insert(
609            "max_clusters".to_string(),
610            HyperParameter::Integer { min: 10, max: 100 },
611        );
612        parameters.insert(
613            "forgetting_factor".to_string(),
614            HyperParameter::Float {
615                min: 0.9,
616                max: 0.99,
617            },
618        );
619
620        SearchSpace {
621            parameters,
622            constraints: Vec::new(),
623        }
624    }
625
626    /// K-means search space with Bayesian optimization
627    pub fn kmeans_bayesian() -> (SearchSpace, TuningConfig) {
628        let mut parameters = HashMap::new();
629        parameters.insert(
630            "n_clusters".to_string(),
631            HyperParameter::Integer { min: 2, max: 50 },
632        );
633        parameters.insert(
634            "max_iter".to_string(),
635            HyperParameter::Integer { min: 50, max: 500 },
636        );
637        parameters.insert(
638            "tolerance".to_string(),
639            HyperParameter::Float {
640                min: 1e-6,
641                max: 1e-2,
642            },
643        );
644
645        let search_space = SearchSpace {
646            parameters,
647            constraints: Vec::new(),
648        };
649
650        let config = TuningConfig {
651            strategy: SearchStrategy::BayesianOptimization {
652                n_initial_points: 10,
653                acquisition_function: AcquisitionFunction::ExpectedImprovement,
654            },
655            metric: EvaluationMetric::SilhouetteScore,
656            max_evaluations: 50,
657            cv_config: CrossValidationConfig {
658                n_folds: 5,
659                validation_ratio: 0.2,
660                strategy: CVStrategy::KFold,
661                shuffle: true,
662            },
663            early_stopping: Some(EarlyStoppingConfig {
664                patience: 10,
665                min_improvement: 0.001,
666                evaluation_frequency: 1,
667            }),
668            parallel_config: Some(ParallelConfig {
669                n_workers: 8,
670                load_balancing: LoadBalancingStrategy::Dynamic,
671                batch_size: 100,
672            }),
673            random_seed: Some(42),
674            resource_constraints: ResourceConstraints {
675                max_memory_per_evaluation: None,
676                max_time_per_evaluation: None,
677                max_total_time: None,
678            },
679        };
680
681        (search_space, config)
682    }
683
684    /// DBSCAN search space with multi-objective optimization
685    pub fn dbscan_multi_objective() -> (SearchSpace, TuningConfig) {
686        let mut parameters = HashMap::new();
687        parameters.insert(
688            "eps".to_string(),
689            HyperParameter::Float { min: 0.1, max: 2.0 },
690        );
691        parameters.insert(
692            "min_samples".to_string(),
693            HyperParameter::Integer { min: 2, max: 20 },
694        );
695
696        let search_space = SearchSpace {
697            parameters,
698            constraints: Vec::new(),
699        };
700
701        let config = TuningConfig {
702            strategy: SearchStrategy::MultiObjective {
703                objectives: vec![
704                    EvaluationMetric::SilhouetteScore,
705                    EvaluationMetric::DaviesBouldinIndex,
706                ],
707                strategy: Box::new(SearchStrategy::BayesianOptimization {
708                    n_initial_points: 10,
709                    acquisition_function: AcquisitionFunction::ExpectedImprovement,
710                }),
711            },
712            metric: EvaluationMetric::SilhouetteScore,
713            max_evaluations: 30,
714            cv_config: CrossValidationConfig {
715                n_folds: 3,
716                validation_ratio: 0.3,
717                strategy: CVStrategy::KFold,
718                shuffle: true,
719            },
720            early_stopping: None,
721            parallel_config: None,
722            random_seed: Some(42),
723            resource_constraints: ResourceConstraints {
724                max_memory_per_evaluation: None,
725                max_time_per_evaluation: Some(120.0),
726                max_total_time: Some(3600.0),
727            },
728        };
729
730        (search_space, config)
731    }
732}
733
734/// High-level convenience function for automatic algorithm selection
735#[allow(dead_code)]
736pub fn auto_select_clustering_algorithm<
737    F: Float
738        + FromPrimitive
739        + Debug
740        + 'static
741        + std::iter::Sum
742        + std::fmt::Display
743        + Send
744        + Sync
745        + scirs2_core::ndarray::ScalarOperand
746        + std::ops::AddAssign
747        + std::ops::SubAssign
748        + std::ops::MulAssign
749        + std::ops::DivAssign
750        + std::ops::RemAssign
751        + PartialOrd,
752>(
753    data: ArrayView2<F>,
754    config: Option<TuningConfig>,
755) -> Result<AlgorithmSelectionResult>
756where
757    f64: From<F>,
758{
759    let tuning_config = config.unwrap_or_else(|| TuningConfig {
760        max_evaluations: 50, // Reduced for faster selection
761        ..Default::default()
762    });
763
764    let selector = AutoClusteringSelector::new(tuning_config);
765    selector.select_best_algorithm(data)
766}
767
768/// Quick algorithm selection with default parameters
769#[allow(dead_code)]
770pub fn quick_algorithm_selection<
771    F: Float
772        + FromPrimitive
773        + Debug
774        + 'static
775        + std::iter::Sum
776        + std::fmt::Display
777        + Send
778        + Sync
779        + scirs2_core::ndarray::ScalarOperand
780        + std::ops::AddAssign
781        + std::ops::SubAssign
782        + std::ops::MulAssign
783        + std::ops::DivAssign
784        + std::ops::RemAssign
785        + PartialOrd,
786>(
787    data: ArrayView2<F>,
788) -> Result<AlgorithmSelectionResult>
789where
790    f64: From<F>,
791{
792    let config = TuningConfig {
793        strategy: SearchStrategy::RandomSearch { n_trials: 20 },
794        max_evaluations: 20,
795        early_stopping: Some(EarlyStoppingConfig {
796            patience: 5,
797            min_improvement: 0.001,
798            evaluation_frequency: 1,
799        }),
800        ..Default::default()
801    };
802
803    let algorithms = vec![
804        ClusteringAlgorithm::KMeans,
805        ClusteringAlgorithm::DBSCAN,
806        ClusteringAlgorithm::GaussianMixture,
807    ];
808
809    let selector = AutoClusteringSelector::with_algorithms(config, algorithms);
810    selector.select_best_algorithm(data)
811}