kizzasi_inference/
ensemble.rs

1//! Model ensembling for robust inference
2//!
3//! This module provides ensembling strategies to combine predictions from
4//! multiple models, improving accuracy and robustness.
5//!
6//! ## Ensemble Methods
7//!
//! 1. **Averaging**: Average predictions from all models
//! 2. **Weighted**: Weighted average using caller-supplied per-model weights
//! 3. **Voting**: Majority voting for discrete outputs
//! 4. **Product of experts**: Multiply per-model probabilities and renormalize
12
13use crate::error::{InferenceError, InferenceResult};
14use crate::sampling::{Sampler, SamplingConfig};
15use kizzasi_model::AutoregressiveModel;
16use scirs2_core::ndarray::Array1;
17
/// Ensemble combination strategy
///
/// Selects how `ModelEnsemble` merges the per-model predictions produced
/// during a single inference step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleStrategy {
    /// Simple averaging of all model outputs (element-wise mean)
    Average,
    /// Weighted average (weights must be provided via `EnsembleConfig::weights`)
    Weighted,
    /// Voting: every model votes for the index of its largest output, and the
    /// combined result holds the fraction of votes each index received
    Voting,
    /// Product of experts (softmax each model's output, multiply the
    /// resulting probabilities element-wise, then renormalize)
    ProductOfExperts,
}
30
/// Configuration for model ensemble
///
/// Built either through `Default`/`new` plus the chained setters, or
/// indirectly through `EnsembleBuilder`.
#[derive(Debug, Clone)]
pub struct EnsembleConfig {
    /// Ensemble strategy
    pub strategy: EnsembleStrategy,
    /// Model weights (for weighted averaging), one entry per model.
    ///
    /// When present they are validated as the ensemble is constructed: the
    /// count must equal the number of models and they must sum to 1.0.
    pub weights: Option<Vec<f32>>,
    /// Whether to normalize outputs before combining.
    ///
    /// When `true`, the `Average` and `Weighted` strategies apply a softmax
    /// to the combined output. `Voting` and `ProductOfExperts` do not consult
    /// this flag.
    pub normalize_outputs: bool,
    /// Temperature for final sampling; handed to the sampler's
    /// `SamplingConfig` when the ensemble is constructed
    pub temperature: f32,
}
43
44impl Default for EnsembleConfig {
45    fn default() -> Self {
46        Self {
47            strategy: EnsembleStrategy::Average,
48            weights: None,
49            normalize_outputs: true,
50            temperature: 1.0,
51        }
52    }
53}
54
55impl EnsembleConfig {
56    /// Create a new ensemble configuration
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Set ensemble strategy
62    pub fn strategy(mut self, strategy: EnsembleStrategy) -> Self {
63        self.strategy = strategy;
64        self
65    }
66
67    /// Set model weights for weighted averaging
68    pub fn weights(mut self, weights: Vec<f32>) -> Self {
69        self.weights = Some(weights);
70        self
71    }
72
73    /// Enable/disable output normalization
74    pub fn normalize_outputs(mut self, normalize: bool) -> Self {
75        self.normalize_outputs = normalize;
76        self
77    }
78
79    /// Set temperature for final sampling
80    pub fn temperature(mut self, temp: f32) -> Self {
81        self.temperature = temp;
82        self
83    }
84}
85
/// Model ensemble for combining multiple models
///
/// Every call to `step` feeds the same input to each component model and
/// merges the resulting predictions according to `config.strategy`.
pub struct ModelEnsemble {
    /// Component models, stepped in the order they are stored
    models: Vec<Box<dyn AutoregressiveModel>>,
    /// Configuration (strategy, optional weights, normalization, temperature)
    config: EnsembleConfig,
    /// Sampler for final output, configured with `config.temperature`.
    /// It is exposed through `sampler_mut`; `step` itself does not use it.
    sampler: Sampler,
}
95
96impl ModelEnsemble {
97    /// Create a new model ensemble
98    pub fn new(
99        models: Vec<Box<dyn AutoregressiveModel>>,
100        config: EnsembleConfig,
101    ) -> InferenceResult<Self> {
102        if models.is_empty() {
103            return Err(InferenceError::ForwardError(
104                "Ensemble must contain at least one model".to_string(),
105            ));
106        }
107
108        // Validate weights if provided
109        if let Some(ref weights) = config.weights {
110            if weights.len() != models.len() {
111                return Err(InferenceError::DimensionMismatch {
112                    expected: models.len(),
113                    got: weights.len(),
114                });
115            }
116
117            // Check weights are positive and sum to 1
118            let sum: f32 = weights.iter().sum();
119            if (sum - 1.0).abs() > 1e-6 {
120                return Err(InferenceError::ForwardError(format!(
121                    "Ensemble weights must sum to 1.0, got {}",
122                    sum
123                )));
124            }
125        }
126
127        let sampler_config = SamplingConfig::new().temperature(config.temperature);
128        let sampler = Sampler::new(sampler_config);
129
130        Ok(Self {
131            models,
132            config,
133            sampler,
134        })
135    }
136
137    /// Get number of models in ensemble
138    pub fn num_models(&self) -> usize {
139        self.models.len()
140    }
141
142    /// Perform ensemble inference step
143    pub fn step(&mut self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
144        // Collect predictions from all models
145        let mut predictions = Vec::with_capacity(self.models.len());
146
147        for model in &mut self.models {
148            let pred = model
149                .step(input)
150                .map_err(|e| InferenceError::ForwardError(e.to_string()))?;
151            predictions.push(pred);
152        }
153
154        // Combine predictions based on strategy
155        self.combine_predictions(&predictions)
156    }
157
158    /// Combine predictions from multiple models
159    fn combine_predictions(&mut self, predictions: &[Array1<f32>]) -> InferenceResult<Array1<f32>> {
160        if predictions.is_empty() {
161            return Err(InferenceError::ForwardError(
162                "No predictions to combine".to_string(),
163            ));
164        }
165
166        let output_dim = predictions[0].len();
167
168        // Verify all predictions have same dimension
169        for pred in predictions {
170            if pred.len() != output_dim {
171                return Err(InferenceError::DimensionMismatch {
172                    expected: output_dim,
173                    got: pred.len(),
174                });
175            }
176        }
177
178        match self.config.strategy {
179            EnsembleStrategy::Average => self.combine_average(predictions, output_dim),
180            EnsembleStrategy::Weighted => self.combine_weighted(predictions, output_dim),
181            EnsembleStrategy::Voting => self.combine_voting(predictions),
182            EnsembleStrategy::ProductOfExperts => {
183                self.combine_product_of_experts(predictions, output_dim)
184            }
185        }
186    }
187
188    /// Average ensemble
189    fn combine_average(
190        &self,
191        predictions: &[Array1<f32>],
192        output_dim: usize,
193    ) -> InferenceResult<Array1<f32>> {
194        let mut combined = Array1::zeros(output_dim);
195        let n = predictions.len() as f32;
196
197        for pred in predictions {
198            combined += pred;
199        }
200
201        combined /= n;
202
203        if self.config.normalize_outputs {
204            combined = self.normalize(&combined);
205        }
206
207        Ok(combined)
208    }
209
210    /// Weighted average ensemble
211    fn combine_weighted(
212        &self,
213        predictions: &[Array1<f32>],
214        output_dim: usize,
215    ) -> InferenceResult<Array1<f32>> {
216        let weights = self.config.weights.as_ref().ok_or_else(|| {
217            InferenceError::ForwardError("Weights not provided for weighted ensemble".to_string())
218        })?;
219
220        let mut combined = Array1::zeros(output_dim);
221
222        for (pred, &weight) in predictions.iter().zip(weights.iter()) {
223            combined += &(pred * weight);
224        }
225
226        if self.config.normalize_outputs {
227            combined = self.normalize(&combined);
228        }
229
230        Ok(combined)
231    }
232
233    /// Voting ensemble (for discrete outputs)
234    fn combine_voting(&mut self, predictions: &[Array1<f32>]) -> InferenceResult<Array1<f32>> {
235        // For each model, get the argmax
236        let votes: Vec<usize> = predictions
237            .iter()
238            .map(|pred| {
239                pred.iter()
240                    .enumerate()
241                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
242                    .map(|(idx, _)| idx)
243                    .unwrap_or(0)
244            })
245            .collect();
246
247        // Count votes
248        let output_dim = predictions[0].len();
249        let mut vote_counts = vec![0usize; output_dim];
250        for &vote in &votes {
251            if vote < output_dim {
252                vote_counts[vote] += 1;
253            }
254        }
255
256        // Convert counts to probabilities
257        let total_votes = votes.len() as f32;
258        let combined = Array1::from_vec(
259            vote_counts
260                .iter()
261                .map(|&count| count as f32 / total_votes)
262                .collect(),
263        );
264
265        Ok(combined)
266    }
267
268    /// Product of experts ensemble
269    fn combine_product_of_experts(
270        &self,
271        predictions: &[Array1<f32>],
272        output_dim: usize,
273    ) -> InferenceResult<Array1<f32>> {
274        let mut combined = Array1::ones(output_dim);
275
276        // Multiply all predictions (after softmax)
277        for pred in predictions {
278            let normalized = self.softmax(pred);
279            combined *= &normalized;
280        }
281
282        // Normalize the product
283        let sum: f32 = combined.sum();
284        if sum > 0.0 {
285            combined /= sum;
286        }
287
288        Ok(combined)
289    }
290
291    /// Normalize output to probabilities
292    fn normalize(&self, output: &Array1<f32>) -> Array1<f32> {
293        self.softmax(output)
294    }
295
296    /// Apply softmax
297    fn softmax(&self, x: &Array1<f32>) -> Array1<f32> {
298        let max_x = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
299        let exp_x = x.mapv(|v| (v - max_x).exp());
300        let sum_exp: f32 = exp_x.sum();
301
302        if sum_exp > 0.0 {
303            exp_x / sum_exp
304        } else {
305            Array1::from_elem(x.len(), 1.0 / x.len() as f32)
306        }
307    }
308
309    /// Get ensemble configuration
310    pub fn config(&self) -> &EnsembleConfig {
311        &self.config
312    }
313
314    /// Get mutable access to sampler
315    pub fn sampler_mut(&mut self) -> &mut Sampler {
316        &mut self.sampler
317    }
318}
319
/// Builder for creating model ensembles
///
/// Collects models and configuration values through chained calls; `build`
/// hands both to `ModelEnsemble::new`, which performs the validation.
pub struct EnsembleBuilder {
    /// Models added so far, in insertion order
    models: Vec<Box<dyn AutoregressiveModel>>,
    /// Configuration being assembled; starts from `EnsembleConfig::default()`
    config: EnsembleConfig,
}
325
326impl EnsembleBuilder {
327    /// Create a new ensemble builder
328    pub fn new() -> Self {
329        Self {
330            models: Vec::new(),
331            config: EnsembleConfig::default(),
332        }
333    }
334
335    /// Add a model to the ensemble
336    pub fn add_model(mut self, model: Box<dyn AutoregressiveModel>) -> Self {
337        self.models.push(model);
338        self
339    }
340
341    /// Add multiple models
342    pub fn add_models(mut self, models: Vec<Box<dyn AutoregressiveModel>>) -> Self {
343        self.models.extend(models);
344        self
345    }
346
347    /// Set ensemble strategy
348    pub fn strategy(mut self, strategy: EnsembleStrategy) -> Self {
349        self.config.strategy = strategy;
350        self
351    }
352
353    /// Set model weights
354    pub fn weights(mut self, weights: Vec<f32>) -> Self {
355        self.config.weights = Some(weights);
356        self
357    }
358
359    /// Set temperature
360    pub fn temperature(mut self, temp: f32) -> Self {
361        self.config.temperature = temp;
362        self
363    }
364
365    /// Build the ensemble
366    pub fn build(self) -> InferenceResult<ModelEnsemble> {
367        ModelEnsemble::new(self.models, self.config)
368    }
369}
370
371impl Default for EnsembleBuilder {
372    fn default() -> Self {
373        Self::new()
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use kizzasi_model::s4::{S4Config, S4D};

    /// Construct the small diagonal S4 model shared by every test here
    /// (input dim 1, hidden dim 64, state dim 16, two layers).
    fn create_test_model() -> S4D {
        let config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);

        S4D::new(config).unwrap()
    }

    /// Start a builder pre-loaded with `count` freshly created test models.
    fn builder_with_models(count: usize) -> EnsembleBuilder {
        let mut builder = EnsembleBuilder::new();
        for _ in 0..count {
            builder = builder.add_model(Box::new(create_test_model()));
        }
        builder
    }

    /// One-element input matching the test model's input dimension.
    fn scalar_input() -> Array1<f32> {
        Array1::from_vec(vec![0.5])
    }

    #[test]
    fn test_ensemble_creation() {
        let built = builder_with_models(2).build();

        assert!(built.is_ok());
        assert_eq!(built.unwrap().num_models(), 2);
    }

    #[test]
    fn test_ensemble_average() {
        let mut ensemble = builder_with_models(2)
            .strategy(EnsembleStrategy::Average)
            .build()
            .unwrap();

        let result = ensemble.step(&scalar_input());

        assert!(result.is_ok());
    }

    #[test]
    fn test_ensemble_weighted() {
        let mut ensemble = builder_with_models(2)
            .strategy(EnsembleStrategy::Weighted)
            .weights(vec![0.7, 0.3])
            .build()
            .unwrap();

        let result = ensemble.step(&scalar_input());

        assert!(result.is_ok());
    }

    #[test]
    fn test_ensemble_voting() {
        let mut ensemble = builder_with_models(3)
            .strategy(EnsembleStrategy::Voting)
            .build()
            .unwrap();

        let result = ensemble.step(&scalar_input());

        assert!(result.is_ok());
    }

    #[test]
    fn test_invalid_weights() {
        // The weights add up to 1.1, so construction must be rejected.
        let built = builder_with_models(2)
            .strategy(EnsembleStrategy::Weighted)
            .weights(vec![0.5, 0.6])
            .build();

        assert!(built.is_err());
    }
}