sklears_ensemble/
compression.rs

1//! Model compression techniques for large ensembles
2//!
3//! This module provides various compression strategies to reduce memory usage
4//! and improve inference speed for large ensemble models, including knowledge
5//! distillation, pruning, quantization, and ensemble reduction techniques.
6
// NOTE: the former `rand_chacha` / `rand_core` imports were removed; all
// random number generation now goes through `scirs2_core::random` (see the
// helper functions below).
9use scirs2_core::ndarray::{Array1, Array2};
10#[allow(unused_imports)]
11use scirs2_core::random::SeedableRng;
12use sklears_core::error::{Result, SklearsError};
13use sklears_core::prelude::Predict;
14use sklears_core::traits::{Estimator, PredictProba};
15use sklears_core::types::{Float, Int};
16use std::collections::HashMap;
17
18/// Helper function to generate random f64 from scirs2_core::random::RngCore
19fn gen_f64(rng: &mut impl scirs2_core::random::RngCore) -> f64 {
20    let mut bytes = [0u8; 8];
21    rng.fill_bytes(&mut bytes);
22    f64::from_le_bytes(bytes) / f64::from_le_bytes([255u8; 8])
23}
24
25/// Helper function to generate random value in range from scirs2_core::random::RngCore
26fn gen_range_usize(
27    rng: &mut impl scirs2_core::random::RngCore,
28    range: std::ops::Range<usize>,
29) -> usize {
30    let mut bytes = [0u8; 8];
31    rng.fill_bytes(&mut bytes);
32    let val = u64::from_le_bytes(bytes);
33    range.start + (val as usize % (range.end - range.start))
34}
35
/// Compression strategy enumeration
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` (the enum is
/// fieldless, so total equality is well-defined), allowing strategies to be
/// used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionStrategy {
    /// Knowledge distillation - train a smaller model to mimic ensemble
    KnowledgeDistillation,
    /// Ensemble pruning - remove redundant or weak models
    EnsemblePruning,
    /// Model quantization - reduce precision of model parameters
    Quantization,
    /// Weight sharing - share parameters across models
    WeightSharing,
    /// Low-rank approximation - approximate weight matrices
    LowRankApproximation,
    /// Sparse representation - remove near-zero weights
    SparseRepresentation,
    /// Hierarchical compression - compress at multiple levels
    HierarchicalCompression,
    /// Bayesian optimization for ensemble size optimization
    BayesianOptimization,
}
56
/// Compression configuration
///
/// Bundles the knobs for every [`CompressionStrategy`]; each strategy reads
/// only the fields that apply to it and ignores the rest.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Primary compression strategy
    pub strategy: CompressionStrategy,
    /// Target compression ratio (0.0 to 1.0)
    pub compression_ratio: Float,
    /// Quality threshold - minimum acceptable performance
    pub quality_threshold: Float,
    /// Number of bits for quantization
    pub quantization_bits: Option<u8>,
    /// Sparsity level for sparse representation
    pub sparsity_level: Option<Float>,
    /// Rank for low-rank approximation
    pub low_rank: Option<usize>,
    /// Distillation temperature for knowledge distillation
    pub distillation_temperature: Float,
    /// Number of distillation epochs
    pub distillation_epochs: usize,
    /// Learning rate for distillation
    pub distillation_lr: Float,
    /// Enable progressive compression
    pub progressive_compression: bool,
    /// Total number of objective evaluations for Bayesian optimization
    pub bayes_opt_n_calls: usize,
    /// Number of purely random evaluations before the surrogate model is used
    pub bayes_opt_n_initial: usize,
    /// Exploration weight (kappa) for the acquisition function
    pub bayes_opt_acquisition_kappa: Float,
    /// Seed for the Bayesian optimizer's RNG (`None` = implementation default)
    pub bayes_opt_random_state: Option<u64>,
    /// Performance-cost trade-off weight
    pub performance_cost_trade_off: Float,
}
88
89impl Default for CompressionConfig {
90    fn default() -> Self {
91        Self {
92            strategy: CompressionStrategy::EnsemblePruning,
93            compression_ratio: 0.5,
94            quality_threshold: 0.95,
95            quantization_bits: Some(8),
96            sparsity_level: Some(0.1),
97            low_rank: Some(64),
98            distillation_temperature: 3.0,
99            distillation_epochs: 100,
100            distillation_lr: 0.01,
101            progressive_compression: false,
102            bayes_opt_n_calls: 50,
103            bayes_opt_n_initial: 10,
104            bayes_opt_acquisition_kappa: 2.576,
105            bayes_opt_random_state: None,
106            performance_cost_trade_off: 0.7,
107        }
108    }
109}
110
/// Compression statistics
///
/// Populated by `EnsembleCompressor::compress`; summarizes the size/accuracy
/// trade-off of the most recent compression run.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Original model size (bytes)
    pub original_size_bytes: usize,
    /// Compressed model size (bytes)
    pub compressed_size_bytes: usize,
    /// Compression ratio achieved (1.0 - compressed / original)
    pub compression_ratio: Float,
    /// Performance before compression
    pub original_accuracy: Float,
    /// Performance after compression
    pub compressed_accuracy: Float,
    /// Performance degradation (original minus compressed accuracy)
    pub accuracy_loss: Float,
    /// Compression time (seconds)
    pub compression_time_secs: Float,
    /// Inference speedup factor (estimated from the size ratio)
    pub speedup_factor: Float,
    /// Memory reduction factor (original size / compressed size)
    pub memory_reduction_factor: Float,
}
133
/// Ensemble compressor
///
/// Applies the configured [`CompressionStrategy`] to a model ensemble and
/// records the outcome of the most recent `compress` call.
pub struct EnsembleCompressor {
    // Strategy and hyper-parameters driving `compress`.
    config: CompressionConfig,
    // Statistics of the last `compress` call; `None` until first use.
    stats: Option<CompressionStats>,
}
139
/// Compressed ensemble representation
///
/// Generic over the model type `T`; produced by `EnsembleCompressor`.
#[derive(Debug, Clone)]
pub struct CompressedEnsemble<T> {
    /// Compressed models
    pub models: Vec<T>,
    /// Compression metadata (strategy, index mapping, parameters)
    pub metadata: CompressionMetadata,
    /// Model weights for voting/averaging; `None` means unweighted
    pub weights: Option<Array1<Float>>,
}
150
/// Compression metadata
#[derive(Debug, Clone)]
pub struct CompressionMetadata {
    /// Original number of models
    pub original_count: usize,
    /// Compressed number of models
    pub compressed_count: usize,
    /// Compression strategy used
    pub strategy: CompressionStrategy,
    /// Model mapping (original index -> compressed index)
    pub model_mapping: HashMap<usize, usize>,
    /// Quantization parameters (set only by quantizing strategies)
    pub quantization_params: Option<QuantizationParams>,
    /// Sparsity information (set only by the sparse-representation strategy)
    pub sparsity_info: Option<SparsityInfo>,
}
167
/// Quantization parameters
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Number of bits per parameter
    pub bits: u8,
    /// Scale factor for quantization
    pub scale: Float,
    /// Zero point for quantization
    pub zero_point: Int,
    /// Minimum value for clipping
    pub min_val: Float,
    /// Maximum value for clipping
    pub max_val: Float,
}
181
/// Sparsity information
#[derive(Debug, Clone)]
pub struct SparsityInfo {
    /// Sparsity level (fraction of zero weights; 0.0 = dense, 1.0 = all zero)
    pub sparsity_level: Float,
    /// Number of non-zero parameters
    pub non_zero_params: usize,
    /// Total number of parameters
    pub total_params: usize,
    /// Sparse representation indices (positions of the non-zero parameters)
    pub sparse_indices: Vec<usize>,
}
194
/// Knowledge distillation trainer
///
/// Produces soft targets from a teacher ensemble so a smaller student model
/// can be trained to mimic it (see `distill`).
pub struct KnowledgeDistillationTrainer {
    // Softening temperature applied to the averaged teacher probabilities.
    temperature: Float,
    alpha: Float, // Weight for distillation loss
    beta: Float,  // Weight for ground truth loss
}
201
/// Ensemble pruning algorithm
///
/// Greedily keeps high-accuracy models whose predictions are not too highly
/// correlated with already-selected ones.
pub struct EnsemblePruner {
    /// Diversity threshold for pruning
    // NOTE(review): not read by the visible `prune` path — confirm whether it
    // is used elsewhere or is a leftover.
    diversity_threshold: Float,
    /// Performance threshold (minimum accuracy for a model to be considered)
    performance_threshold: Float,
    /// Correlation threshold (maximum allowed agreement with selected models)
    correlation_threshold: Float,
}
211
/// Bayesian optimization for ensemble size selection
#[derive(Debug)]
pub struct BayesianEnsembleOptimizer {
    /// Configuration for optimization (reads the `bayes_opt_*` fields)
    config: CompressionConfig,
    /// Gaussian Process for surrogate modeling of size -> objective
    gp: SimpleGaussianProcess,
    /// Random number generator (seeded for reproducibility)
    rng: scirs2_core::random::rngs::StdRng,
    /// Evaluation history (ensemble_size, accuracy, cost, objective)
    evaluations: Vec<(usize, Float, Float, Float)>,
    /// Best (ensemble_size, objective) configuration found so far
    best_config: Option<(usize, Float)>,
}
226
/// Simple Gaussian Process for Bayesian optimization
#[derive(Debug)]
struct SimpleGaussianProcess {
    // Training inputs observed so far.
    x_train: Array2<Float>,
    // Objective values corresponding to the rows of `x_train`.
    y_train: Array1<Float>,
    // Observation noise level — presumably added to the kernel diagonal;
    // the fit/predict code is not visible here, confirm before relying on it.
    noise_level: Float,
    // Kernel hyper-parameters. NOTE(review): length scale + signal variance
    // suggest an RBF-style kernel — confirm in the GP implementation.
    length_scale: Float,
    signal_variance: Float,
}
236
/// Acquisition function for Bayesian optimization
///
/// Determines how the optimizer trades off exploring uncertain candidate
/// sizes against exploiting sizes predicted to score well.
#[derive(Debug, Clone, Copy)]
pub enum AcquisitionFunction {
    /// Expected Improvement
    ExpectedImprovement,
    /// Upper Confidence Bound; `kappa` controls the exploration bonus
    UpperConfidenceBound { kappa: Float },
    /// Probability of Improvement
    ProbabilityOfImprovement,
}
247
248impl EnsembleCompressor {
249    /// Create a new ensemble compressor
250    pub fn new(config: CompressionConfig) -> Self {
251        Self {
252            config,
253            stats: None,
254        }
255    }
256
257    /// Create compressor with knowledge distillation
258    pub fn knowledge_distillation(compression_ratio: Float, temperature: Float) -> Self {
259        Self::new(CompressionConfig {
260            strategy: CompressionStrategy::KnowledgeDistillation,
261            compression_ratio,
262            distillation_temperature: temperature,
263            ..Default::default()
264        })
265    }
266
267    /// Create compressor with ensemble pruning
268    pub fn ensemble_pruning(compression_ratio: Float, quality_threshold: Float) -> Self {
269        Self::new(CompressionConfig {
270            strategy: CompressionStrategy::EnsemblePruning,
271            compression_ratio,
272            quality_threshold,
273            ..Default::default()
274        })
275    }
276
277    /// Create compressor with quantization
278    pub fn quantization(bits: u8) -> Self {
279        Self::new(CompressionConfig {
280            strategy: CompressionStrategy::Quantization,
281            quantization_bits: Some(bits),
282            compression_ratio: 1.0 - (bits as Float / 32.0), // Assume 32-bit baseline
283            ..Default::default()
284        })
285    }
286
287    /// Create compressor with Bayesian optimization
288    pub fn bayesian_optimization(
289        performance_cost_trade_off: Float,
290        n_calls: usize,
291        random_state: Option<u64>,
292    ) -> Self {
293        Self::new(CompressionConfig {
294            strategy: CompressionStrategy::BayesianOptimization,
295            performance_cost_trade_off,
296            bayes_opt_n_calls: n_calls,
297            bayes_opt_n_initial: (n_calls / 5).max(5),
298            bayes_opt_random_state: random_state,
299            ..Default::default()
300        })
301    }
302
303    /// Compress an ensemble using the configured strategy
304    pub fn compress<T>(
305        &mut self,
306        ensemble: &[T],
307        x_val: &Array2<Float>,
308        y_val: &Array1<Int>,
309    ) -> Result<CompressedEnsemble<T>>
310    where
311        T: Clone + Predict<Array2<Float>, Array1<Int>>,
312    {
313        let start_time = std::time::Instant::now();
314
315        // Calculate original performance
316        let original_accuracy = self.evaluate_ensemble_accuracy(ensemble, x_val, y_val)?;
317        let original_size = self.estimate_ensemble_size(ensemble);
318
319        let compressed = match self.config.strategy {
320            CompressionStrategy::EnsemblePruning => {
321                self.compress_by_pruning(ensemble, x_val, y_val)?
322            }
323            CompressionStrategy::Quantization => self.compress_by_quantization(ensemble)?,
324            CompressionStrategy::WeightSharing => self.compress_by_weight_sharing(ensemble)?,
325            CompressionStrategy::SparseRepresentation => {
326                self.compress_by_sparsification(ensemble)?
327            }
328            CompressionStrategy::HierarchicalCompression => {
329                self.compress_hierarchically(ensemble, x_val, y_val)?
330            }
331            CompressionStrategy::BayesianOptimization => {
332                self.compress_by_bayesian_optimization(ensemble, x_val, y_val)?
333            }
334            _ => {
335                return Err(SklearsError::InvalidInput(format!(
336                    "Compression strategy {:?} not yet implemented",
337                    self.config.strategy
338                )));
339            }
340        };
341
342        // Calculate compressed performance
343        let compressed_accuracy = self.evaluate_compressed_accuracy(&compressed, x_val, y_val)?;
344        let compressed_size = self.estimate_compressed_size(&compressed);
345
346        // Update statistics
347        self.stats = Some(CompressionStats {
348            original_size_bytes: original_size,
349            compressed_size_bytes: compressed_size,
350            compression_ratio: 1.0 - (compressed_size as Float / original_size as Float),
351            original_accuracy,
352            compressed_accuracy,
353            accuracy_loss: original_accuracy - compressed_accuracy,
354            compression_time_secs: start_time.elapsed().as_secs_f64(),
355            speedup_factor: original_size as Float / compressed_size as Float,
356            memory_reduction_factor: original_size as Float / compressed_size as Float,
357        });
358
359        Ok(compressed)
360    }
361
362    /// Compress ensemble by pruning weak or redundant models
363    fn compress_by_pruning<T>(
364        &self,
365        ensemble: &[T],
366        x_val: &Array2<Float>,
367        y_val: &Array1<Int>,
368    ) -> Result<CompressedEnsemble<T>>
369    where
370        T: Clone + Predict<Array2<Float>, Array1<Int>>,
371    {
372        let target_count = (ensemble.len() as Float * self.config.compression_ratio) as usize;
373        let target_count = target_count.max(1); // Keep at least one model
374
375        // Calculate individual model performances
376        let mut model_scores = Vec::new();
377        for (i, model) in ensemble.iter().enumerate() {
378            let predictions = model.predict(x_val)?;
379            let accuracy = self.calculate_accuracy(&predictions, y_val);
380            model_scores.push((i, accuracy));
381        }
382
383        // Sort by performance and keep top models
384        model_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
385
386        let selected_indices: Vec<usize> = model_scores
387            .into_iter()
388            .take(target_count)
389            .map(|(i, _)| i)
390            .collect();
391
392        let mut selected_models = Vec::new();
393        let mut model_mapping = HashMap::new();
394
395        for (new_idx, &orig_idx) in selected_indices.iter().enumerate() {
396            selected_models.push(ensemble[orig_idx].clone());
397            model_mapping.insert(orig_idx, new_idx);
398        }
399
400        // Calculate uniform weights for selected models
401        let weights =
402            Array1::from_elem(selected_models.len(), 1.0 / selected_models.len() as Float);
403
404        Ok(CompressedEnsemble {
405            models: selected_models,
406            metadata: CompressionMetadata {
407                original_count: ensemble.len(),
408                compressed_count: target_count,
409                strategy: CompressionStrategy::EnsemblePruning,
410                model_mapping,
411                quantization_params: None,
412                sparsity_info: None,
413            },
414            weights: Some(weights),
415        })
416    }
417
418    /// Compress ensemble using quantization
419    fn compress_by_quantization<T>(&self, ensemble: &[T]) -> Result<CompressedEnsemble<T>>
420    where
421        T: Clone,
422    {
423        let bits = self.config.quantization_bits.unwrap_or(8);
424
425        // For demonstration, we'll create quantization parameters
426        // In a real implementation, this would quantize the actual model weights
427        let quantization_params = QuantizationParams {
428            bits,
429            scale: 1.0 / (2_i32.pow(bits as u32 - 1) as Float),
430            zero_point: 0,
431            min_val: -1.0,
432            max_val: 1.0,
433        };
434
435        // Clone all models (in practice, these would be quantized versions)
436        let compressed_models = ensemble.to_vec();
437
438        Ok(CompressedEnsemble {
439            models: compressed_models,
440            metadata: CompressionMetadata {
441                original_count: ensemble.len(),
442                compressed_count: ensemble.len(),
443                strategy: CompressionStrategy::Quantization,
444                model_mapping: (0..ensemble.len()).map(|i| (i, i)).collect(),
445                quantization_params: Some(quantization_params),
446                sparsity_info: None,
447            },
448            weights: None,
449        })
450    }
451
452    /// Compress ensemble using weight sharing
453    fn compress_by_weight_sharing<T>(&self, ensemble: &[T]) -> Result<CompressedEnsemble<T>>
454    where
455        T: Clone,
456    {
457        // Simplified weight sharing - group similar models
458        let num_groups = (ensemble.len() as Float * self.config.compression_ratio) as usize;
459        let num_groups = num_groups.max(1);
460
461        let models_per_group = ensemble.len() / num_groups;
462        let mut compressed_models = Vec::new();
463        let mut model_mapping = HashMap::new();
464
465        for group_id in 0..num_groups {
466            let start_idx = group_id * models_per_group;
467            let end_idx = if group_id == num_groups - 1 {
468                ensemble.len()
469            } else {
470                (group_id + 1) * models_per_group
471            };
472
473            // Use the first model in each group as representative
474            if start_idx < ensemble.len() {
475                compressed_models.push(ensemble[start_idx].clone());
476
477                // Map all models in group to the representative
478                for orig_idx in start_idx..end_idx {
479                    model_mapping.insert(orig_idx, group_id);
480                }
481            }
482        }
483
484        let compressed_count = compressed_models.len();
485        Ok(CompressedEnsemble {
486            models: compressed_models,
487            metadata: CompressionMetadata {
488                original_count: ensemble.len(),
489                compressed_count,
490                strategy: CompressionStrategy::WeightSharing,
491                model_mapping,
492                quantization_params: None,
493                sparsity_info: None,
494            },
495            weights: None,
496        })
497    }
498
499    /// Compress ensemble using sparsification
500    fn compress_by_sparsification<T>(&self, ensemble: &[T]) -> Result<CompressedEnsemble<T>>
501    where
502        T: Clone,
503    {
504        let sparsity_level = self.config.sparsity_level.unwrap_or(0.1);
505
506        // For demonstration, create sparsity info
507        // In practice, this would modify the actual model weights
508        let total_params = ensemble.len() * 1000; // Assumed 1000 params per model
509        let non_zero_params = (total_params as Float * (1.0 - sparsity_level)) as usize;
510
511        let sparsity_info = SparsityInfo {
512            sparsity_level,
513            non_zero_params,
514            total_params,
515            sparse_indices: (0..non_zero_params).collect(),
516        };
517
518        // Clone all models (in practice, these would be sparsified versions)
519        let compressed_models = ensemble.to_vec();
520
521        Ok(CompressedEnsemble {
522            models: compressed_models,
523            metadata: CompressionMetadata {
524                original_count: ensemble.len(),
525                compressed_count: ensemble.len(),
526                strategy: CompressionStrategy::SparseRepresentation,
527                model_mapping: (0..ensemble.len()).map(|i| (i, i)).collect(),
528                quantization_params: None,
529                sparsity_info: Some(sparsity_info),
530            },
531            weights: None,
532        })
533    }
534
535    /// Compress ensemble using hierarchical approach
536    fn compress_hierarchically<T>(
537        &self,
538        ensemble: &[T],
539        x_val: &Array2<Float>,
540        y_val: &Array1<Int>,
541    ) -> Result<CompressedEnsemble<T>>
542    where
543        T: Clone + Predict<Array2<Float>, Array1<Int>>,
544    {
545        // First, apply pruning
546        let pruned = self.compress_by_pruning(ensemble, x_val, y_val)?;
547
548        // Then, apply quantization to the pruned ensemble
549        let quantized = self.compress_by_quantization(&pruned.models)?;
550
551        let compressed_count = quantized.models.len();
552        Ok(CompressedEnsemble {
553            models: quantized.models,
554            metadata: CompressionMetadata {
555                original_count: ensemble.len(),
556                compressed_count,
557                strategy: CompressionStrategy::HierarchicalCompression,
558                model_mapping: pruned.metadata.model_mapping,
559                quantization_params: quantized.metadata.quantization_params,
560                sparsity_info: None,
561            },
562            weights: pruned.weights,
563        })
564    }
565
566    /// Compress ensemble using Bayesian optimization
567    fn compress_by_bayesian_optimization<T>(
568        &self,
569        ensemble: &[T],
570        x_val: &Array2<Float>,
571        y_val: &Array1<Int>,
572    ) -> Result<CompressedEnsemble<T>>
573    where
574        T: Clone + Predict<Array2<Float>, Array1<Int>>,
575    {
576        let mut optimizer = BayesianEnsembleOptimizer::new(
577            self.config.clone(),
578            self.config.bayes_opt_random_state.unwrap_or(42),
579        );
580
581        // Find optimal ensemble size using Bayesian optimization
582        let optimal_size = optimizer.optimize_ensemble_size(ensemble, x_val, y_val)?;
583
584        // Use the optimal size to perform intelligent pruning
585        let mut pruning_config = self.config.clone();
586        pruning_config.strategy = CompressionStrategy::EnsemblePruning;
587        pruning_config.compression_ratio = optimal_size as Float / ensemble.len() as Float;
588
589        let temp_compressor = EnsembleCompressor::new(pruning_config);
590        let mut compressed = temp_compressor.compress_by_pruning(ensemble, x_val, y_val)?;
591
592        // Update metadata to reflect Bayesian optimization strategy
593        compressed.metadata.strategy = CompressionStrategy::BayesianOptimization;
594
595        Ok(compressed)
596    }
597
598    /// Evaluate ensemble accuracy
599    fn evaluate_ensemble_accuracy<T>(
600        &self,
601        ensemble: &[T],
602        x_val: &Array2<Float>,
603        y_val: &Array1<Int>,
604    ) -> Result<Float>
605    where
606        T: Predict<Array2<Float>, Array1<Int>>,
607    {
608        let n_samples = x_val.nrows();
609        let mut correct = 0;
610
611        for i in 0..n_samples {
612            let x_sample = x_val.row(i).insert_axis(scirs2_core::ndarray::Axis(0));
613            let mut votes = HashMap::new();
614
615            // Collect votes from all models
616            for model in ensemble {
617                let prediction = model.predict(&x_sample.to_owned())?;
618                if !prediction.is_empty() {
619                    *votes.entry(prediction[0]).or_insert(0) += 1;
620                }
621            }
622
623            // Find majority vote
624            if let Some((&predicted_class, _)) = votes.iter().max_by_key(|(_, &count)| count) {
625                if predicted_class == y_val[i] {
626                    correct += 1;
627                }
628            }
629        }
630
631        Ok(correct as Float / n_samples as Float)
632    }
633
634    /// Evaluate compressed ensemble accuracy
635    fn evaluate_compressed_accuracy<T>(
636        &self,
637        compressed: &CompressedEnsemble<T>,
638        x_val: &Array2<Float>,
639        y_val: &Array1<Int>,
640    ) -> Result<Float>
641    where
642        T: Predict<Array2<Float>, Array1<Int>>,
643    {
644        self.evaluate_ensemble_accuracy(&compressed.models, x_val, y_val)
645    }
646
647    /// Calculate accuracy for predictions
648    fn calculate_accuracy(&self, predictions: &Array1<Int>, y_true: &Array1<Int>) -> Float {
649        let correct = predictions
650            .iter()
651            .zip(y_true.iter())
652            .map(|(pred, true_val)| if pred == true_val { 1 } else { 0 })
653            .sum::<i32>();
654
655        correct as Float / predictions.len() as Float
656    }
657
658    /// Estimate ensemble size in bytes
659    fn estimate_ensemble_size<T>(&self, ensemble: &[T]) -> usize {
660        // Simplified size estimation
661        // In practice, this would calculate actual memory usage
662        ensemble.len() * 1024 * 1024 // Assume 1MB per model
663    }
664
665    /// Estimate compressed ensemble size
666    fn estimate_compressed_size<T>(&self, compressed: &CompressedEnsemble<T>) -> usize {
667        let base_size = compressed.models.len() * 1024 * 1024;
668
669        match compressed.metadata.strategy {
670            CompressionStrategy::Quantization => {
671                if let Some(ref params) = compressed.metadata.quantization_params {
672                    (base_size as Float * (params.bits as Float / 32.0)) as usize
673                } else {
674                    base_size
675                }
676            }
677            CompressionStrategy::SparseRepresentation => {
678                if let Some(ref info) = compressed.metadata.sparsity_info {
679                    (base_size as Float * (1.0 - info.sparsity_level)) as usize
680                } else {
681                    base_size
682                }
683            }
684            _ => base_size,
685        }
686    }
687
688    /// Get compression statistics
689    pub fn stats(&self) -> Option<&CompressionStats> {
690        self.stats.as_ref()
691    }
692
693    /// Reset compression statistics
694    pub fn reset_stats(&mut self) {
695        self.stats = None;
696    }
697}
698
699impl KnowledgeDistillationTrainer {
700    /// Create new knowledge distillation trainer
701    pub fn new(temperature: Float, alpha: Float, beta: Float) -> Self {
702        Self {
703            temperature,
704            alpha,
705            beta,
706        }
707    }
708
709    /// Train student model using teacher ensemble
710    pub fn distill<Teacher, Student>(
711        &self,
712        teachers: &[Teacher],
713        student: Student,
714        x_train: &Array2<Float>,
715        y_train: &Array1<Int>,
716    ) -> Result<Student>
717    where
718        Teacher: PredictProba<Array2<Float>, Array2<Float>>,
719        Student: Clone,
720    {
721        // Simplified distillation - in practice this would involve
722        // actual gradient-based optimization
723
724        // Get soft targets from teacher ensemble
725        let _soft_targets = self.get_ensemble_soft_targets(teachers, x_train)?;
726
727        // Here you would train the student model using both
728        // hard targets (y_train) and soft targets from teachers
729        // with the appropriate loss function combination
730
731        Ok(student)
732    }
733
734    /// Get soft targets from teacher ensemble
735    fn get_ensemble_soft_targets<T>(
736        &self,
737        teachers: &[T],
738        x: &Array2<Float>,
739    ) -> Result<Array2<Float>>
740    where
741        T: PredictProba<Array2<Float>, Array2<Float>>,
742    {
743        let n_samples = x.nrows();
744        let first_proba = teachers[0].predict_proba(x)?;
745        let n_classes = first_proba.ncols();
746
747        let mut ensemble_proba = Array2::zeros((n_samples, n_classes));
748
749        // Average probabilities from all teachers
750        for teacher in teachers {
751            let proba = teacher.predict_proba(x)?;
752            ensemble_proba = ensemble_proba + proba;
753        }
754
755        ensemble_proba /= teachers.len() as Float;
756
757        // Apply temperature scaling
758        ensemble_proba.mapv_inplace(|p| (p / self.temperature).exp());
759
760        // Normalize probabilities
761        for mut row in ensemble_proba.rows_mut() {
762            let sum = row.sum();
763            if sum > 0.0 {
764                row /= sum;
765            }
766        }
767
768        Ok(ensemble_proba)
769    }
770}
771
772impl EnsemblePruner {
773    /// Create new ensemble pruner
774    pub fn new(
775        diversity_threshold: Float,
776        performance_threshold: Float,
777        correlation_threshold: Float,
778    ) -> Self {
779        Self {
780            diversity_threshold,
781            performance_threshold,
782            correlation_threshold,
783        }
784    }
785
786    /// Prune ensemble based on diversity and performance criteria
787    pub fn prune<T>(
788        &self,
789        ensemble: &[T],
790        x_val: &Array2<Float>,
791        y_val: &Array1<Int>,
792        target_size: usize,
793    ) -> Result<Vec<usize>>
794    where
795        T: Predict<Array2<Float>, Array1<Int>>,
796    {
797        let n_models = ensemble.len();
798        let target_size = target_size.min(n_models);
799
800        // Calculate performance scores
801        let mut model_scores = Vec::new();
802        for (i, model) in ensemble.iter().enumerate() {
803            let predictions = model.predict(x_val)?;
804            let accuracy = self.calculate_accuracy(&predictions, y_val);
805            model_scores.push((i, accuracy));
806        }
807
808        // Sort by performance
809        model_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
810
811        // Select top performers that meet diversity criteria
812        let mut selected = Vec::new();
813        selected.push(model_scores[0].0); // Start with best model
814
815        for (model_idx, score) in model_scores.iter().skip(1) {
816            if selected.len() >= target_size {
817                break;
818            }
819
820            if *score >= self.performance_threshold {
821                // Check diversity with already selected models
822                let is_diverse =
823                    self.check_diversity_with_selected(*model_idx, &selected, ensemble, x_val)?;
824
825                if is_diverse {
826                    selected.push(*model_idx);
827                }
828            }
829        }
830
831        Ok(selected)
832    }
833
834    /// Check if a model is diverse enough compared to selected models
835    fn check_diversity_with_selected<T>(
836        &self,
837        candidate_idx: usize,
838        selected: &[usize],
839        ensemble: &[T],
840        x_val: &Array2<Float>,
841    ) -> Result<bool>
842    where
843        T: Predict<Array2<Float>, Array1<Int>>,
844    {
845        let candidate_pred = ensemble[candidate_idx].predict(x_val)?;
846
847        for &selected_idx in selected {
848            let selected_pred = ensemble[selected_idx].predict(x_val)?;
849            let correlation = self.calculate_correlation(&candidate_pred, &selected_pred);
850
851            if correlation > self.correlation_threshold {
852                return Ok(false); // Too similar to existing model
853            }
854        }
855
856        Ok(true)
857    }
858
859    /// Calculate correlation between two prediction vectors
860    fn calculate_correlation(&self, pred1: &Array1<Int>, pred2: &Array1<Int>) -> Float {
861        if pred1.len() != pred2.len() {
862            return 0.0;
863        }
864
865        let n = pred1.len() as Float;
866        let agreements = pred1
867            .iter()
868            .zip(pred2.iter())
869            .map(|(a, b)| if a == b { 1.0 } else { 0.0 })
870            .sum::<Float>();
871
872        agreements / n
873    }
874
875    /// Calculate accuracy
876    fn calculate_accuracy(&self, predictions: &Array1<Int>, y_true: &Array1<Int>) -> Float {
877        let correct = predictions
878            .iter()
879            .zip(y_true.iter())
880            .map(|(pred, true_val)| if pred == true_val { 1 } else { 0 })
881            .sum::<i32>();
882
883        correct as Float / predictions.len() as Float
884    }
885}
886
impl BayesianEnsembleOptimizer {
    /// Create a new Bayesian ensemble optimizer.
    ///
    /// `config` supplies the `bayes_opt_*` budget/acquisition settings and
    /// the performance/cost trade-off; `random_state` seeds the internal RNG
    /// so the random-exploration phase is reproducible.
    pub fn new(config: CompressionConfig, random_state: u64) -> Self {
        Self {
            config,
            // Surrogate model over (ensemble size -> objective) with a small
            // fixed observation-noise level.
            gp: SimpleGaussianProcess::new(0.01),
            rng: scirs2_core::random::rngs::StdRng::seed_from_u64(random_state),
            evaluations: Vec::new(),
            best_config: None,
        }
    }

    /// Optimize ensemble size using Bayesian optimization.
    ///
    /// Phase 1 evaluates `bayes_opt_n_initial` randomly chosen sizes in
    /// `[1, ensemble.len()]`; phase 2 spends the remaining `bayes_opt_n_calls`
    /// budget refitting the surrogate and evaluating the size that maximizes
    /// the acquisition function. Returns the best size found, or
    /// `max_size / 2` if no evaluation was recorded.
    pub fn optimize_ensemble_size<T>(
        &mut self,
        ensemble: &[T],
        x_val: &Array2<Float>,
        y_val: &Array1<Int>,
    ) -> Result<usize>
    where
        T: Clone + Predict<Array2<Float>, Array1<Int>>,
    {
        let max_size = ensemble.len();
        let min_size = 1;

        // Phase 1: Random exploration
        for _ in 0..self.config.bayes_opt_n_initial {
            let ensemble_size = gen_range_usize(&mut self.rng, min_size..(max_size + 1));
            let objective =
                self.evaluate_ensemble_configuration(ensemble, x_val, y_val, ensemble_size)?;
            // The two middle tuple slots are always written as 0.0 here; only
            // (size, objective) carry information (see `evaluations()`).
            self.evaluations.push((ensemble_size, 0.0, 0.0, objective));
            self.update_best(ensemble_size, objective);
        }

        // Phase 2: Bayesian optimization
        let remaining_calls = self
            .config
            .bayes_opt_n_calls
            .saturating_sub(self.config.bayes_opt_n_initial);

        for _ in 0..remaining_calls {
            // Fit GP to current evaluations
            self.fit_surrogate_model()?;

            // Select next ensemble size using acquisition function
            let next_size = self.select_next_ensemble_size(min_size, max_size)?;
            let objective =
                self.evaluate_ensemble_configuration(ensemble, x_val, y_val, next_size)?;

            self.evaluations.push((next_size, 0.0, 0.0, objective));
            self.update_best(next_size, objective);
        }

        Ok(self
            .best_config
            .map(|(size, _)| size)
            .unwrap_or(max_size / 2))
    }

    /// Evaluate objective function for a given ensemble size.
    ///
    /// Ranks all models by individual validation accuracy, keeps the top
    /// `ensemble_size`, scores that subset by majority vote, and combines
    /// accuracy with a normalized size cost via `performance_cost_trade_off`.
    /// Out-of-range sizes return `NEG_INFINITY` so they can never win.
    fn evaluate_ensemble_configuration<T>(
        &self,
        ensemble: &[T],
        x_val: &Array2<Float>,
        y_val: &Array1<Int>,
        ensemble_size: usize,
    ) -> Result<Float>
    where
        T: Clone + Predict<Array2<Float>, Array1<Int>>,
    {
        if ensemble_size == 0 || ensemble_size > ensemble.len() {
            return Ok(Float::NEG_INFINITY);
        }

        // Select top performing models up to ensemble_size
        let mut model_scores = Vec::new();
        for (i, model) in ensemble.iter().enumerate() {
            let predictions = model.predict(x_val)?;
            let accuracy = self.calculate_accuracy(&predictions, y_val);
            model_scores.push((i, accuracy));
        }

        // Sort by performance and take top ensemble_size models
        // NOTE(review): `partial_cmp(..).unwrap()` panics if any accuracy is
        // NaN (e.g. an empty validation set makes `calculate_accuracy` do
        // 0/0) — confirm callers always pass non-empty `x_val`/`y_val`.
        model_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        let selected_models: Vec<&T> = model_scores
            .iter()
            .take(ensemble_size)
            .map(|(i, _)| &ensemble[*i])
            .collect();

        // Evaluate ensemble performance
        let ensemble_accuracy = self.evaluate_subset_accuracy(&selected_models, x_val, y_val)?;

        // Calculate cost (inverse of ensemble size, normalized)
        let cost = ensemble_size as Float / ensemble.len() as Float;

        // Combine performance and cost with trade-off parameter
        let objective = self.config.performance_cost_trade_off * ensemble_accuracy
            - (1.0 - self.config.performance_cost_trade_off) * cost;

        Ok(objective)
    }

    /// Evaluate accuracy of a subset of models via per-sample majority vote.
    fn evaluate_subset_accuracy<T>(
        &self,
        models: &[&T],
        x_val: &Array2<Float>,
        y_val: &Array1<Int>,
    ) -> Result<Float>
    where
        T: Predict<Array2<Float>, Array1<Int>>,
    {
        let n_samples = x_val.nrows();
        let mut correct = 0;

        for i in 0..n_samples {
            // Re-wrap the row as a 1-row matrix so each model's batch
            // `predict` can be reused for a single sample.
            let x_sample = x_val.row(i).insert_axis(scirs2_core::ndarray::Axis(0));
            let mut votes = HashMap::new();

            // Collect votes from selected models
            for model in models {
                let prediction = model.predict(&x_sample.to_owned())?;
                if !prediction.is_empty() {
                    *votes.entry(prediction[0]).or_insert(0) += 1;
                }
            }

            // Find majority vote
            // NOTE(review): vote ties are broken by `HashMap` iteration
            // order, which is unspecified — tied samples may flip across runs.
            if let Some((&predicted_class, _)) = votes.iter().max_by_key(|(_, &count)| count) {
                if predicted_class == y_val[i] {
                    correct += 1;
                }
            }
        }

        Ok(correct as Float / n_samples as Float)
    }

    /// Calculate accuracy for predictions.
    // NOTE(review): an identical helper exists earlier in this file; empty
    // `predictions` yields 0/0 = NaN here (see the sort caveat above).
    fn calculate_accuracy(&self, predictions: &Array1<Int>, y_true: &Array1<Int>) -> Float {
        let correct = predictions
            .iter()
            .zip(y_true.iter())
            .map(|(pred, true_val)| if pred == true_val { 1 } else { 0 })
            .sum::<i32>();

        correct as Float / predictions.len() as Float
    }

    /// Fit surrogate model to current evaluations.
    ///
    /// Builds an (n_points x 1) design matrix from the evaluated ensemble
    /// sizes and their objectives; no-op when nothing has been evaluated yet.
    fn fit_surrogate_model(&mut self) -> Result<()> {
        if self.evaluations.is_empty() {
            return Ok(());
        }

        let n_points = self.evaluations.len();
        let mut x_train = Array2::zeros((n_points, 1));
        let mut y_train = Array1::zeros(n_points);

        for (i, &(ensemble_size, _, _, objective)) in self.evaluations.iter().enumerate() {
            x_train[[i, 0]] = ensemble_size as Float;
            y_train[i] = objective;
        }

        self.gp.fit(&x_train, &y_train)?;
        Ok(())
    }

    /// Select next ensemble size using acquisition function.
    ///
    /// Exhaustively scores every size in `[min_size, max_size]` — the search
    /// space is one-dimensional and small, so no inner optimizer is needed.
    fn select_next_ensemble_size(&mut self, min_size: usize, max_size: usize) -> Result<usize> {
        let mut best_size = min_size;
        let mut best_acquisition = Float::NEG_INFINITY;

        // Evaluate acquisition function for all possible sizes
        for size in min_size..=max_size {
            let x_test = Array2::from_shape_vec((1, 1), vec![size as Float])
                .map_err(|_| SklearsError::InvalidInput("Invalid size".to_string()))?;

            let acquisition_value = self.compute_acquisition(&x_test)?;

            if acquisition_value > best_acquisition {
                best_acquisition = acquisition_value;
                best_size = size;
            }
        }

        Ok(best_size)
    }

    /// Compute acquisition function value at `x` from the surrogate's
    /// predicted mean and standard deviation.
    ///
    /// NOTE(review): `acquisition_func` is constructed locally as
    /// `UpperConfidenceBound`, so only the UCB arm is ever taken; the EI and
    /// PI arms below are currently dead code (kept for future configuration).
    fn compute_acquisition(&self, x: &Array2<Float>) -> Result<Float> {
        if self.evaluations.is_empty() {
            return Ok(0.0);
        }

        let (mean, std) = self.gp.predict(x)?;
        let mu = mean[0];
        let sigma = std[0];

        let acquisition_func = AcquisitionFunction::UpperConfidenceBound {
            kappa: self.config.bayes_opt_acquisition_kappa,
        };

        let acquisition = match acquisition_func {
            // UCB: optimistic estimate mu + kappa * sigma.
            AcquisitionFunction::UpperConfidenceBound { kappa } => mu + kappa * sigma,
            AcquisitionFunction::ExpectedImprovement => {
                let best_score = self
                    .best_config
                    .map(|(_, score)| score)
                    .unwrap_or(Float::NEG_INFINITY);
                // Near-zero sigma would make z blow up; treat as no improvement.
                if sigma <= 1e-8 {
                    0.0
                } else {
                    let improvement = mu - best_score;
                    let z = improvement / sigma;
                    // phi = standard normal CDF via erf; density = standard normal PDF.
                    let phi = 0.5 * (1.0 + erf(z / 2.0_f64.sqrt()));
                    let density = (-0.5 * z * z).exp() / (2.0 * std::f64::consts::PI).sqrt();
                    improvement * phi + sigma * density
                }
            }
            AcquisitionFunction::ProbabilityOfImprovement => {
                let best_score = self
                    .best_config
                    .map(|(_, score)| score)
                    .unwrap_or(Float::NEG_INFINITY);
                if sigma <= 1e-8 {
                    0.0
                } else {
                    let z = (mu - best_score) / sigma;
                    0.5 * (1.0 + erf(z / 2.0_f64.sqrt()))
                }
            }
        };

        Ok(acquisition)
    }

    /// Update best configuration, keeping the `(size, objective)` pair with
    /// the highest objective seen so far.
    fn update_best(&mut self, ensemble_size: usize, objective: Float) {
        if let Some((_, best_obj)) = self.best_config {
            if objective > best_obj {
                self.best_config = Some((ensemble_size, objective));
            }
        } else {
            self.best_config = Some((ensemble_size, objective));
        }
    }

    /// Get best configuration found, as `(ensemble_size, objective)`.
    pub fn best_config(&self) -> Option<(usize, Float)> {
        self.best_config
    }

    /// Get all evaluations as `(size, _, _, objective)` tuples; the two
    /// middle slots are currently always 0.0.
    pub fn evaluations(&self) -> &[(usize, Float, Float, Float)] {
        &self.evaluations
    }
}
1146
impl SimpleGaussianProcess {
    /// Construct an untrained surrogate with the given observation-noise
    /// level and unit default hyperparameters.
    fn new(noise_level: Float) -> Self {
        Self {
            x_train: Array2::zeros((0, 0)),
            y_train: Array1::zeros(0),
            noise_level,
            length_scale: 1.0,
            signal_variance: 1.0,
        }
    }

    /// Store the training data and estimate RBF hyperparameters:
    /// `length_scale` as the median pairwise Euclidean distance (floored at
    /// 0.1) and `signal_variance` as the variance of `y` (floored at 0.01).
    /// With fewer than two points the defaults from `new` are kept.
    fn fit(&mut self, x: &Array2<Float>, y: &Array1<Float>) -> Result<()> {
        self.x_train = x.clone();
        self.y_train = y.clone();

        // Simple hyperparameter estimation
        if x.nrows() > 1 {
            // Estimate length scale as median pairwise distance
            let mut distances = Vec::new();
            for i in 0..x.nrows() {
                for j in (i + 1)..x.nrows() {
                    let mut dist_sq = 0.0;
                    for k in 0..x.ncols() {
                        let diff = x[[i, k]] - x[[j, k]];
                        dist_sq += diff * diff;
                    }
                    distances.push(dist_sq.sqrt());
                }
            }

            if !distances.is_empty() {
                distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
                // Median element; floor keeps the RBF kernel well-conditioned.
                self.length_scale = distances[distances.len() / 2].max(0.1);
            }

            // Estimate signal variance as variance of y
            let y_mean = y.mean().unwrap_or(0.0);
            self.signal_variance =
                y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<Float>() / y.len() as Float;
            self.signal_variance = self.signal_variance.max(0.01);
        }

        Ok(())
    }

    /// Predict mean and standard deviation at each test point.
    ///
    /// NOTE(review): despite the type's name this is not an exact GP
    /// posterior — the mean is a normalized-RBF weighted average of the
    /// training targets (kernel smoothing), and the "variance" is the
    /// kernel-weighted spread of targets around that mean plus noise.
    fn predict(&self, x: &Array2<Float>) -> Result<(Array1<Float>, Array1<Float>)> {
        let n_test = x.nrows();
        let mut mean = Array1::zeros(n_test);
        let mut std = Array1::zeros(n_test);

        // Untrained model: zero mean, prior standard deviation everywhere.
        if self.x_train.nrows() == 0 {
            return Ok((mean, Array1::from_elem(n_test, self.signal_variance.sqrt())));
        }

        // Simple GP prediction using RBF kernel
        for i in 0..n_test {
            let mut kernel_values = Array1::zeros(self.x_train.nrows());
            let mut total_weight = 0.0;

            for j in 0..self.x_train.nrows() {
                // Squared Euclidean distance between test point i and train point j.
                let mut dist_sq = 0.0;
                for k in 0..x.ncols() {
                    let diff = x[[i, k]] - self.x_train[[j, k]];
                    dist_sq += diff * diff;
                }

                let kernel_val = self.signal_variance
                    * (-dist_sq / (2.0 * self.length_scale * self.length_scale)).exp();
                kernel_values[j] = kernel_val;
                total_weight += kernel_val;
            }

            if total_weight > 1e-8 {
                // Normalize weights so the mean is a convex combination of
                // training targets.
                kernel_values /= total_weight;
                mean[i] = kernel_values.dot(&self.y_train);

                let kernel_var: Float = kernel_values
                    .iter()
                    .zip(self.y_train.iter())
                    .map(|(&k, &y)| k * (y - mean[i]).powi(2))
                    .sum();

                let predictive_var = kernel_var + self.noise_level;
                std[i] = predictive_var.sqrt();
            } else {
                // Test point is far from all training data: fall back to the
                // global target mean and prior standard deviation.
                mean[i] = self.y_train.mean().unwrap_or(0.0);
                std[i] = self.signal_variance.sqrt();
            }
        }

        Ok((mean, std))
    }
}
1240
1241// Simple error function approximation
1242fn erf(x: f64) -> f64 {
1243    let a1 = 0.254829592;
1244    let a2 = -0.284496736;
1245    let a3 = 1.421413741;
1246    let a4 = -1.453152027;
1247    let a5 = 1.061405429;
1248    let p = 0.3275911;
1249
1250    let sign = if x < 0.0 { -1.0 } else { 1.0 };
1251    let x = x.abs();
1252
1253    let t = 1.0 / (1.0 + p * x);
1254    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
1255
1256    sign * y
1257}
1258
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    //! Unit tests for the compression strategies, the Bayesian ensemble-size
    //! optimizer, the simplified Gaussian-process surrogate, and the erf
    //! approximation.
    use super::*;
    // NOTE(review): the `array` macro import appears unused in this module.
    use scirs2_core::ndarray::array;

    // Mock model for testing: always predicts the same class for every row.
    #[derive(Clone)]
    struct MockModel {
        prediction: Int,
    }

    impl Predict<Array2<Float>, Array1<Int>> for MockModel {
        fn predict(&self, x: &Array2<Float>) -> Result<Array1<Int>> {
            Ok(Array1::from_elem(x.nrows(), self.prediction))
        }
    }

    // Defaults: ensemble pruning at a 0.5 compression ratio.
    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::default();
        assert_eq!(config.strategy, CompressionStrategy::EnsemblePruning);
        assert_eq!(config.compression_ratio, 0.5);
    }

    // Builder-style constructor wires the strategy and its parameters through.
    #[test]
    fn test_ensemble_compressor_creation() {
        let compressor = EnsembleCompressor::knowledge_distillation(0.3, 2.0);
        assert_eq!(
            compressor.config.strategy,
            CompressionStrategy::KnowledgeDistillation
        );
        assert_eq!(compressor.config.compression_ratio, 0.3);
        assert_eq!(compressor.config.distillation_temperature, 2.0);
    }

    // Pruning at ratio 0.5 should keep exactly half of a 4-model ensemble.
    #[test]
    fn test_ensemble_pruning() {
        let ensemble = vec![
            MockModel { prediction: 0 },
            MockModel { prediction: 1 },
            MockModel { prediction: 0 },
            MockModel { prediction: 1 },
        ];

        let x_val = Array2::zeros((10, 5));
        let y_val = Array1::zeros(10);

        let mut compressor = EnsembleCompressor::ensemble_pruning(0.5, 0.9);
        let compressed = compressor.compress(&ensemble, &x_val, &y_val).unwrap();

        assert_eq!(compressed.models.len(), 2); // 50% compression
        assert_eq!(
            compressed.metadata.strategy,
            CompressionStrategy::EnsemblePruning
        );
        assert!(compressed.weights.is_some());
    }

    // Quantization keeps all models but records quantization parameters.
    #[test]
    fn test_quantization_compression() {
        let ensemble = vec![MockModel { prediction: 0 }];
        let mut compressor = EnsembleCompressor::quantization(8);

        let x_val = Array2::zeros((10, 5));
        let y_val = Array1::zeros(10);

        let compressed = compressor.compress(&ensemble, &x_val, &y_val).unwrap();

        assert_eq!(compressed.models.len(), 1);
        assert_eq!(
            compressed.metadata.strategy,
            CompressionStrategy::Quantization
        );
        assert!(compressed.metadata.quantization_params.is_some());
    }

    #[test]
    fn test_knowledge_distillation_trainer() {
        let trainer = KnowledgeDistillationTrainer::new(3.0, 0.7, 0.3);
        assert_eq!(trainer.temperature, 3.0);
        assert_eq!(trainer.alpha, 0.7);
        assert_eq!(trainer.beta, 0.3);
    }

    #[test]
    fn test_ensemble_pruner() {
        let pruner = EnsemblePruner::new(0.8, 0.9, 0.7);
        assert_eq!(pruner.diversity_threshold, 0.8);
        assert_eq!(pruner.performance_threshold, 0.9);
        assert_eq!(pruner.correlation_threshold, 0.7);
    }

    // Plain struct construction sanity check for the stats record.
    #[test]
    fn test_compression_stats() {
        let stats = CompressionStats {
            original_size_bytes: 1000,
            compressed_size_bytes: 500,
            compression_ratio: 0.5,
            original_accuracy: 0.95,
            compressed_accuracy: 0.92,
            accuracy_loss: 0.03,
            compression_time_secs: 1.5,
            speedup_factor: 2.0,
            memory_reduction_factor: 2.0,
        };

        assert_eq!(stats.compression_ratio, 0.5);
        assert_eq!(stats.speedup_factor, 2.0);
    }

    #[test]
    fn test_bayesian_optimization_compressor_creation() {
        let compressor = EnsembleCompressor::bayesian_optimization(0.8, 25, Some(42));
        assert_eq!(
            compressor.config.strategy,
            CompressionStrategy::BayesianOptimization
        );
        assert_eq!(compressor.config.performance_cost_trade_off, 0.8);
        assert_eq!(compressor.config.bayes_opt_n_calls, 25);
        assert_eq!(compressor.config.bayes_opt_random_state, Some(42));
    }

    // End-to-end run of the optimizer: with n_calls = 10 the evaluation log
    // must contain exactly 10 entries and a best configuration must exist.
    #[test]
    fn test_bayesian_ensemble_optimizer() {
        let ensemble = vec![
            MockModel { prediction: 0 },
            MockModel { prediction: 1 },
            MockModel { prediction: 0 },
            MockModel { prediction: 1 },
            MockModel { prediction: 0 },
        ];

        let x_val = Array2::zeros((20, 3));
        let y_val = Array1::zeros(20);

        let config = CompressionConfig {
            strategy: CompressionStrategy::BayesianOptimization,
            performance_cost_trade_off: 0.7,
            bayes_opt_n_calls: 10,
            bayes_opt_n_initial: 3,
            bayes_opt_acquisition_kappa: 2.0,
            bayes_opt_random_state: Some(42),
            ..Default::default()
        };

        let mut optimizer = BayesianEnsembleOptimizer::new(config, 42);
        let optimal_size = optimizer
            .optimize_ensemble_size(&ensemble, &x_val, &y_val)
            .unwrap();

        // Should find some reasonable ensemble size
        assert!(optimal_size >= 1);
        assert!(optimal_size <= ensemble.len());
        assert!(optimizer.evaluations().len() == 10);
        assert!(optimizer.best_config().is_some());
    }

    // Full compression path through the Bayesian-optimization strategy.
    #[test]
    fn test_bayesian_optimization_compression() {
        let ensemble = vec![
            MockModel { prediction: 0 },
            MockModel { prediction: 1 },
            MockModel { prediction: 0 },
            MockModel { prediction: 1 },
            MockModel { prediction: 0 },
            MockModel { prediction: 1 },
        ];

        let x_val = Array2::zeros((15, 4));
        let y_val = Array1::zeros(15);

        let mut compressor = EnsembleCompressor::bayesian_optimization(0.6, 8, Some(123));
        let compressed = compressor.compress(&ensemble, &x_val, &y_val).unwrap();

        assert_eq!(
            compressed.metadata.strategy,
            CompressionStrategy::BayesianOptimization
        );
        assert!(compressed.models.len() <= ensemble.len());
        assert!(compressed.models.len() >= 1);

        // Should have compression statistics
        let stats = compressor.stats().unwrap();
        assert!(stats.compression_ratio >= 0.0);
        assert!(stats.compression_ratio <= 1.0);
    }

    // Covers both the untrained early-return path and the fitted path of the
    // surrogate's predict.
    #[test]
    fn test_simple_gaussian_process() {
        let mut gp = SimpleGaussianProcess::new(0.01);

        // Test with no training data
        let x_test = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let (mean, std) = gp.predict(&x_test).unwrap();
        assert_eq!(mean.len(), 2);
        assert_eq!(std.len(), 2);

        // Train with some data
        let x_train = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
        let y_train = Array1::from_vec(vec![0.0, 1.0, 4.0]);
        gp.fit(&x_train, &y_train).unwrap();

        // Test prediction
        let (mean, std) = gp.predict(&x_test).unwrap();
        assert_eq!(mean.len(), 2);
        assert_eq!(std.len(), 2);

        // Predictions should be finite
        for &m in mean.iter() {
            assert!(m.is_finite());
        }
        for &s in std.iter() {
            assert!(s.is_finite() && s > 0.0);
        }
    }

    // Acquisition value must be finite once the surrogate has been fitted to
    // a couple of (size, objective) evaluations.
    #[test]
    fn test_acquisition_functions() {
        let config = CompressionConfig {
            strategy: CompressionStrategy::BayesianOptimization,
            bayes_opt_acquisition_kappa: 1.96,
            performance_cost_trade_off: 0.5,
            ..Default::default()
        };

        let mut optimizer = BayesianEnsembleOptimizer::new(config, 42);

        // Add some mock evaluations
        optimizer.evaluations.push((3, 0.0, 0.0, 0.8));
        optimizer.evaluations.push((5, 0.0, 0.0, 0.6));
        optimizer.update_best(3, 0.8);

        // Test acquisition function computation
        optimizer.fit_surrogate_model().unwrap();
        let x_test = Array2::from_shape_vec((1, 1), vec![4.0]).unwrap();
        let acquisition = optimizer.compute_acquisition(&x_test).unwrap();

        // Should be finite
        assert!(acquisition.is_finite());
    }

    #[test]
    fn test_erf_approximation() {
        // Test known values
        assert!((erf(0.0) - 0.0).abs() < 1e-6);
        assert!((erf(1.0) - 0.8427).abs() < 1e-3);
        assert!((erf(-1.0) - (-0.8427)).abs() < 1e-3);

        // Test that erf is bounded between -1 and 1
        for x in [-5.0, -2.0, -1.0, 0.0, 1.0, 2.0, 5.0] {
            let result = erf(x);
            assert!(result >= -1.0 && result <= 1.0);
        }
    }
}