kizzasi_inference/ensemble.rs

use crate::error::{InferenceError, InferenceResult};
use crate::sampling::{Sampler, SamplingConfig};
use kizzasi_model::AutoregressiveModel;
use scirs2_core::ndarray::Array1;

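/// Strategy used to combine the per-model predictions of an ensemble:
/// an unweighted mean, a weighted mean, majority voting over each model's
/// argmax, or a product of experts over softmax-normalized outputs.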
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleStrategy {
    Average,
    Weighted,
    Voting,
    ProductOfExperts,
}

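/// Configuration for a [`ModelEnsemble`]: the combination strategy, optional
/// per-model weights (which must sum to 1.0), whether to softmax-normalize the
/// combined output, and the sampling temperature.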
#[derive(Debug, Clone)]
pub struct EnsembleConfig {
    pub strategy: EnsembleStrategy,
    pub weights: Option<Vec<f32>>,
    pub normalize_outputs: bool,
    pub temperature: f32,
}

impl Default for EnsembleConfig {
    fn default() -> Self {
        Self {
            strategy: EnsembleStrategy::Average,
            weights: None,
            normalize_outputs: true,
            temperature: 1.0,
        }
    }
}

impl EnsembleConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn strategy(mut self, strategy: EnsembleStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    pub fn weights(mut self, weights: Vec<f32>) -> Self {
        self.weights = Some(weights);
        self
    }

    pub fn normalize_outputs(mut self, normalize: bool) -> Self {
        self.normalize_outputs = normalize;
        self
    }

    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }
}

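/// An ensemble of autoregressive models whose per-step predictions are
/// combined according to an [`EnsembleConfig`].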
pub struct ModelEnsemble {
    models: Vec<Box<dyn AutoregressiveModel>>,
    config: EnsembleConfig,
    sampler: Sampler,
}

impl ModelEnsemble {
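    /// Creates a new ensemble from `models` and `config`.
    ///
    /// Returns an error if `models` is empty, or if weights are provided whose
    /// length does not match the number of models or whose sum is not 1.0
    /// (within a tolerance of 1e-6).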
    pub fn new(
        models: Vec<Box<dyn AutoregressiveModel>>,
        config: EnsembleConfig,
    ) -> InferenceResult<Self> {
        if models.is_empty() {
            return Err(InferenceError::ForwardError(
                "Ensemble must contain at least one model".to_string(),
            ));
        }

        if let Some(ref weights) = config.weights {
            if weights.len() != models.len() {
                return Err(InferenceError::DimensionMismatch {
                    expected: models.len(),
                    got: weights.len(),
                });
            }

            let sum: f32 = weights.iter().sum();
            if (sum - 1.0).abs() > 1e-6 {
                return Err(InferenceError::ForwardError(format!(
                    "Ensemble weights must sum to 1.0, got {}",
                    sum
                )));
            }
        }

        let sampler_config = SamplingConfig::new().temperature(config.temperature);
        let sampler = Sampler::new(sampler_config);

        Ok(Self {
            models,
            config,
            sampler,
        })
    }

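    /// Returns the number of models in the ensemble.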
    pub fn num_models(&self) -> usize {
        self.models.len()
    }

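    /// Runs a single autoregressive step on every model and combines the
    /// per-model predictions according to the configured strategy.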
    pub fn step(&mut self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
        let mut predictions = Vec::with_capacity(self.models.len());

        for model in &mut self.models {
            let pred = model
                .step(input)
                .map_err(|e| InferenceError::ForwardError(e.to_string()))?;
            predictions.push(pred);
        }

        self.combine_predictions(&predictions)
    }

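    /// Checks that all predictions share the same dimension, then dispatches
    /// to the combiner selected by the configured strategy.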
    fn combine_predictions(&mut self, predictions: &[Array1<f32>]) -> InferenceResult<Array1<f32>> {
        if predictions.is_empty() {
            return Err(InferenceError::ForwardError(
                "No predictions to combine".to_string(),
            ));
        }

        let output_dim = predictions[0].len();

        for pred in predictions {
            if pred.len() != output_dim {
                return Err(InferenceError::DimensionMismatch {
                    expected: output_dim,
                    got: pred.len(),
                });
            }
        }

        match self.config.strategy {
            EnsembleStrategy::Average => self.combine_average(predictions, output_dim),
            EnsembleStrategy::Weighted => self.combine_weighted(predictions, output_dim),
            EnsembleStrategy::Voting => self.combine_voting(predictions),
            EnsembleStrategy::ProductOfExperts => {
                self.combine_product_of_experts(predictions, output_dim)
            }
        }
    }

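    /// Element-wise arithmetic mean of the predictions, optionally
    /// softmax-normalized.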
    fn combine_average(
        &self,
        predictions: &[Array1<f32>],
        output_dim: usize,
    ) -> InferenceResult<Array1<f32>> {
        let mut combined = Array1::zeros(output_dim);
        let n = predictions.len() as f32;

        for pred in predictions {
            combined += pred;
        }

        combined /= n;

        if self.config.normalize_outputs {
            combined = self.normalize(&combined);
        }

        Ok(combined)
    }

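    /// Weighted sum of the predictions using the configured per-model weights,
    /// optionally softmax-normalized.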
    fn combine_weighted(
        &self,
        predictions: &[Array1<f32>],
        output_dim: usize,
    ) -> InferenceResult<Array1<f32>> {
        let weights = self.config.weights.as_ref().ok_or_else(|| {
            InferenceError::ForwardError("Weights not provided for weighted ensemble".to_string())
        })?;

        let mut combined = Array1::zeros(output_dim);

        for (pred, &weight) in predictions.iter().zip(weights.iter()) {
            combined += &(pred * weight);
        }

        if self.config.normalize_outputs {
            combined = self.normalize(&combined);
        }

        Ok(combined)
    }

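    /// Majority voting: each model votes for its argmax index, and the output
    /// is the fraction of votes received by each output dimension.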
    fn combine_voting(&mut self, predictions: &[Array1<f32>]) -> InferenceResult<Array1<f32>> {
        let votes: Vec<usize> = predictions
            .iter()
            .map(|pred| {
                pred.iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(idx, _)| idx)
                    .unwrap_or(0)
            })
            .collect();

        let output_dim = predictions[0].len();
        let mut vote_counts = vec![0usize; output_dim];
        for &vote in &votes {
            if vote < output_dim {
                vote_counts[vote] += 1;
            }
        }

        let total_votes = votes.len() as f32;
        let combined = Array1::from_vec(
            vote_counts
                .iter()
                .map(|&count| count as f32 / total_votes)
                .collect(),
        );

        Ok(combined)
    }

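    /// Product of experts: multiplies the models' softmax distributions
    /// element-wise and renormalizes the result to sum to 1.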
    fn combine_product_of_experts(
        &self,
        predictions: &[Array1<f32>],
        output_dim: usize,
    ) -> InferenceResult<Array1<f32>> {
        let mut combined = Array1::ones(output_dim);

        for pred in predictions {
            let normalized = self.softmax(pred);
            combined *= &normalized;
        }

        let sum: f32 = combined.sum();
        if sum > 0.0 {
            combined /= sum;
        }

        Ok(combined)
    }

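    /// Normalizes an output vector into a probability distribution via softmax.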
    fn normalize(&self, output: &Array1<f32>) -> Array1<f32> {
        self.softmax(output)
    }

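    /// Numerically stable softmax; falls back to a uniform distribution if all
    /// exponentials underflow to zero.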
    fn softmax(&self, x: &Array1<f32>) -> Array1<f32> {
        let max_x = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_x = x.mapv(|v| (v - max_x).exp());
        let sum_exp: f32 = exp_x.sum();

        if sum_exp > 0.0 {
            exp_x / sum_exp
        } else {
            Array1::from_elem(x.len(), 1.0 / x.len() as f32)
        }
    }

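    /// Returns the ensemble configuration.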
    pub fn config(&self) -> &EnsembleConfig {
        &self.config
    }

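    /// Returns a mutable reference to the internal sampler.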
    pub fn sampler_mut(&mut self) -> &mut Sampler {
        &mut self.sampler
    }
}

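/// Builder for assembling a [`ModelEnsemble`] from individual models and
/// configuration options.
///
/// A minimal usage sketch, adapted from the tests at the bottom of this file;
/// `model_a`, `model_b`, and `input` are placeholders for any boxed
/// [`AutoregressiveModel`] implementations (e.g. `S4D` from
/// `kizzasi_model::s4`) and an `Array1<f32>` input of matching dimension:
///
/// ```ignore
/// let mut ensemble = EnsembleBuilder::new()
///     .add_model(Box::new(model_a))
///     .add_model(Box::new(model_b))
///     .strategy(EnsembleStrategy::Weighted)
///     .weights(vec![0.7, 0.3])
///     .build()?;
/// let output = ensemble.step(&input)?;
/// ```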
pub struct EnsembleBuilder {
    models: Vec<Box<dyn AutoregressiveModel>>,
    config: EnsembleConfig,
}

impl EnsembleBuilder {
    pub fn new() -> Self {
        Self {
            models: Vec::new(),
            config: EnsembleConfig::default(),
        }
    }

    pub fn add_model(mut self, model: Box<dyn AutoregressiveModel>) -> Self {
        self.models.push(model);
        self
    }

    pub fn add_models(mut self, models: Vec<Box<dyn AutoregressiveModel>>) -> Self {
        self.models.extend(models);
        self
    }

    pub fn strategy(mut self, strategy: EnsembleStrategy) -> Self {
        self.config.strategy = strategy;
        self
    }

    pub fn weights(mut self, weights: Vec<f32>) -> Self {
        self.config.weights = Some(weights);
        self
    }

    pub fn temperature(mut self, temp: f32) -> Self {
        self.config.temperature = temp;
        self
    }

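    /// Consumes the builder and constructs the [`ModelEnsemble`], validating
    /// the model list and any provided weights.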
    pub fn build(self) -> InferenceResult<ModelEnsemble> {
        ModelEnsemble::new(self.models, self.config)
    }
}

impl Default for EnsembleBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use kizzasi_model::s4::{S4Config, S4D};

    #[test]
    fn test_ensemble_creation() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let ensemble = EnsembleBuilder::new()
            .add_model(Box::new(model1))
            .add_model(Box::new(model2))
            .build();

        assert!(ensemble.is_ok());
        let ensemble = ensemble.unwrap();
        assert_eq!(ensemble.num_models(), 2);
    }

    #[test]
    fn test_ensemble_average() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let mut ensemble = EnsembleBuilder::new()
            .add_model(Box::new(model1))
            .add_model(Box::new(model2))
            .strategy(EnsembleStrategy::Average)
            .build()
            .unwrap();

        let input = Array1::from_vec(vec![0.5]);
        let output = ensemble.step(&input);

        assert!(output.is_ok());
    }

    #[test]
    fn test_ensemble_weighted() {
        let model1 = create_test_model();
        let model2 = create_test_model();

        let mut ensemble = EnsembleBuilder::new()
            .add_model(Box::new(model1))
            .add_model(Box::new(model2))
            .strategy(EnsembleStrategy::Weighted)
            .weights(vec![0.7, 0.3])
            .build()
            .unwrap();

        let input = Array1::from_vec(vec![0.5]);
        let output = ensemble.step(&input);

        assert!(output.is_ok());
    }

    #[test]
    fn test_ensemble_voting() {
        let model1 = create_test_model();
        let model2 = create_test_model();
        let model3 = create_test_model();

        let mut ensemble = EnsembleBuilder::new()
            .add_model(Box::new(model1))
            .add_model(Box::new(model2))
            .add_model(Box::new(model3))
            .strategy(EnsembleStrategy::Voting)
            .build()
            .unwrap();

        let input = Array1::from_vec(vec![0.5]);
        let output = ensemble.step(&input);

        assert!(output.is_ok());
    }

    #[test]
    fn test_invalid_weights() {
        let model1 = create_test_model();
        let model2 = create_test_model();

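        // Weights sum to 1.1, so ensemble construction should fail validation.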
        let result = EnsembleBuilder::new()
            .add_model(Box::new(model1))
            .add_model(Box::new(model2))
            .strategy(EnsembleStrategy::Weighted)
            .weights(vec![0.5, 0.6])
            .build();

        assert!(result.is_err());
    }

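    /// Builds a small S4D model shared by the tests above.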
    fn create_test_model() -> S4D {
        let config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);

        S4D::new(config).unwrap()
    }
}