1use scirs2_core::ndarray::{Array1, Array2};
4use sklears_core::{error::Result, types::Float};
5
/// Common interface for an estimator that takes part in an ensemble.
///
/// The trait bundles three kinds of operations: bookkeeping values
/// (weight, performance, confidence), prediction entry points, and optional
/// metadata about the underlying model. The trait itself does not fix the
/// scale or meaning of the numeric values; see each implementor for details.
pub trait EnsembleMember {
    /// Returns the weight currently assigned to this member.
    fn weight(&self) -> Float;

    /// Replaces the weight assigned to this member.
    fn set_weight(&mut self, weight: Float);

    /// Returns the performance score currently recorded for this member.
    fn performance(&self) -> Float;

    /// Replaces the recorded performance score.
    fn update_performance(&mut self, performance: Float);

    /// Returns the member's confidence value.
    // NOTE(review): the expected range (e.g. 0..=1) is not defined here — confirm with implementors.
    fn confidence(&self) -> Float;

    /// Produces predictions for the samples in `x`.
    // NOTE(review): implementors in this file treat each row of `x` as one sample.
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>>;

    /// Produces probability predictions for the samples in `x`.
    // NOTE(review): presumably one row per sample and one column per class — confirm.
    fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>>;

    /// Reports whether this member supports `predict_proba`.
    fn supports_proba(&self) -> bool;

    /// Returns per-feature importances, or `None` when they are unavailable.
    fn feature_importance(&self) -> Option<Array1<Float>>;

    /// Returns a scalar complexity measure for this member.
    fn complexity(&self) -> Float;

    /// Reports whether this member considers itself fitted.
    fn is_fitted(&self) -> bool;

    /// Returns the number of classes, or `None` when unknown.
    fn n_classes(&self) -> Option<usize>;

    /// Returns the number of input features, or `None` when unknown.
    fn n_features(&self) -> Option<usize>;

    /// Produces an uncertainty estimate for the samples in `x`.
    fn uncertainty(&self, x: &Array2<Float>) -> Result<Array1<Float>>;

    /// Returns a human-readable name for this member.
    fn name(&self) -> String;

    /// Returns a boxed copy of this member, allowing trait objects to be duplicated.
    fn clone_estimator(&self) -> Box<dyn EnsembleMember + Send + Sync>;
}
56
/// Configurable estimator whose outputs are a simple deterministic function
/// of each row's feature sum and a fixed `bias`.
///
/// Presumably intended for tests and examples (inferred from the name).
#[derive(Debug, Clone)]
pub struct MockEstimator {
    /// Value reported by `EnsembleMember::weight`.
    weight: Float,
    /// Value reported by `EnsembleMember::performance`.
    performance: Float,
    /// Value reported by `EnsembleMember::confidence`.
    confidence: Float,
    /// Offset added to a row's feature sum before it is thresholded or squashed.
    bias: Float,
    /// Whether `predict_proba` is allowed to succeed.
    supports_proba: bool,
    /// Whether the estimator reports itself as fitted.
    is_fitted: bool,
    /// Number of classes reported; `predict_proba` falls back to 2 when `None`.
    n_classes: Option<usize>,
    /// Number of features reported; `feature_importance` returns `None` when unset.
    n_features: Option<usize>,
    /// Label returned by `EnsembleMember::name`.
    name: String,
}
70
71impl MockEstimator {
72 pub fn new(bias: Float) -> Self {
73 Self {
74 weight: 1.0,
75 performance: 0.8,
76 confidence: 0.9,
77 bias,
78 supports_proba: true,
79 is_fitted: true,
80 n_classes: Some(2),
81 n_features: Some(2),
82 name: format!("MockEstimator_{}", bias),
83 }
84 }
85
86 pub fn with_weight(mut self, weight: Float) -> Self {
87 self.weight = weight;
88 self
89 }
90
91 pub fn with_performance(mut self, performance: Float) -> Self {
92 self.performance = performance;
93 self
94 }
95
96 pub fn with_confidence(mut self, confidence: Float) -> Self {
97 self.confidence = confidence;
98 self
99 }
100
101 pub fn with_proba_support(mut self, supports: bool) -> Self {
102 self.supports_proba = supports;
103 self
104 }
105
106 pub fn with_fitted_status(mut self, fitted: bool) -> Self {
107 self.is_fitted = fitted;
108 self
109 }
110
111 pub fn with_classes(mut self, n_classes: usize) -> Self {
112 self.n_classes = Some(n_classes);
113 self
114 }
115
116 pub fn with_features(mut self, n_features: usize) -> Self {
117 self.n_features = Some(n_features);
118 self
119 }
120
121 pub fn with_name(mut self, name: String) -> Self {
122 self.name = name;
123 self
124 }
125}
126
127impl EnsembleMember for MockEstimator {
128 fn weight(&self) -> Float {
129 self.weight
130 }
131
132 fn set_weight(&mut self, weight: Float) {
133 self.weight = weight;
134 }
135
136 fn performance(&self) -> Float {
137 self.performance
138 }
139
140 fn update_performance(&mut self, performance: Float) {
141 self.performance = performance;
142 }
143
144 fn confidence(&self) -> Float {
145 self.confidence
146 }
147
148 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
149 if !self.is_fitted {
150 return Err(sklears_core::error::SklearsError::NotFitted {
151 operation: "predict".to_string(),
152 });
153 }
154
155 let n_samples = x.nrows();
156 let mut predictions = Array1::zeros(n_samples);
157
158 for i in 0..n_samples {
160 let feature_sum: Float = x.row(i).sum();
161 let prediction = if feature_sum + self.bias > 0.0 {
162 1.0
163 } else {
164 0.0
165 };
166 predictions[i] = prediction;
167 }
168
169 Ok(predictions)
170 }
171
172 fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
173 if !self.supports_proba {
174 return Err(sklears_core::error::SklearsError::InvalidOperation(
175 "Estimator does not support probability predictions".to_string(),
176 ));
177 }
178
179 if !self.is_fitted {
180 return Err(sklears_core::error::SklearsError::NotFitted {
181 operation: "predict".to_string(),
182 });
183 }
184
185 let n_samples = x.nrows();
186 let n_classes = self.n_classes.unwrap_or(2);
187 let mut probabilities = Array2::zeros((n_samples, n_classes));
188
189 for i in 0..n_samples {
191 let feature_sum: Float = x.row(i).sum();
192 let logit = feature_sum + self.bias;
193
194 if n_classes == 2 {
195 let prob_class_1 = 1.0 / (1.0 + (-logit).exp());
197 probabilities[[i, 0]] = 1.0 - prob_class_1;
198 probabilities[[i, 1]] = prob_class_1;
199 } else {
200 let prob_per_class = 1.0 / n_classes as Float;
202 for j in 0..n_classes {
203 probabilities[[i, j]] = prob_per_class;
204 }
205 probabilities[[i, 0]] += self.bias * 0.1;
207
208 let row_sum: Float = probabilities.row(i).sum();
210 if row_sum > 0.0 {
211 for j in 0..n_classes {
212 probabilities[[i, j]] /= row_sum;
213 }
214 }
215 }
216 }
217
218 Ok(probabilities)
219 }
220
221 fn supports_proba(&self) -> bool {
222 self.supports_proba
223 }
224
225 fn feature_importance(&self) -> Option<Array1<Float>> {
226 if let Some(n_features) = self.n_features {
227 let mut importance = Array1::ones(n_features) / n_features as Float;
229 if n_features > 0 {
230 importance[0] += self.bias.abs() * 0.1; }
232
233 let total: Float = importance.sum();
235 if total > 0.0 {
236 importance.mapv_inplace(|x| x / total);
237 }
238
239 Some(importance)
240 } else {
241 None
242 }
243 }
244
245 fn complexity(&self) -> Float {
246 self.bias.abs() + 1.0
248 }
249
250 fn is_fitted(&self) -> bool {
251 self.is_fitted
252 }
253
254 fn n_classes(&self) -> Option<usize> {
255 self.n_classes
256 }
257
258 fn n_features(&self) -> Option<usize> {
259 self.n_features
260 }
261
262 fn uncertainty(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
263 if !self.is_fitted {
264 return Err(sklears_core::error::SklearsError::NotFitted {
265 operation: "predict".to_string(),
266 });
267 }
268
269 let n_samples = x.nrows();
270 let mut uncertainty = Array1::zeros(n_samples);
271
272 for i in 0..n_samples {
274 let feature_sum: Float = x.row(i).sum();
275 let logit = feature_sum + self.bias;
276
277 let prob = 1.0 / (1.0 + (-logit).exp());
279 let entropy = -prob * prob.ln() - (1.0 - prob) * (1.0 - prob).ln();
280 uncertainty[i] = entropy;
281 }
282
283 Ok(uncertainty)
284 }
285
286 fn name(&self) -> String {
287 self.name.clone()
288 }
289
290 fn clone_estimator(&self) -> Box<dyn EnsembleMember + Send + Sync> {
291 Box::new(self.clone())
292 }
293}
294
/// Placeholder ensemble member standing in for an estimator defined elsewhere.
///
/// It stores only bookkeeping values; its `EnsembleMember` implementation in
/// this file returns `NotImplemented` from every prediction method.
#[derive(Debug)]
pub struct ExternalEstimatorWrapper {
    /// Value reported by `EnsembleMember::weight`.
    weight: Float,
    /// Value reported by `EnsembleMember::performance`.
    performance: Float,
    /// Value reported by `EnsembleMember::confidence`.
    confidence: Float,
    /// Label returned by `EnsembleMember::name`.
    name: String,
}
303
304impl ExternalEstimatorWrapper {
305 pub fn new(name: String) -> Self {
306 Self {
307 weight: 1.0,
308 performance: 0.0,
309 confidence: 0.5,
310 name,
311 }
312 }
313}
314
315impl EnsembleMember for ExternalEstimatorWrapper {
316 fn weight(&self) -> Float {
317 self.weight
318 }
319
320 fn set_weight(&mut self, weight: Float) {
321 self.weight = weight;
322 }
323
324 fn performance(&self) -> Float {
325 self.performance
326 }
327
328 fn update_performance(&mut self, performance: Float) {
329 self.performance = performance;
330 }
331
332 fn confidence(&self) -> Float {
333 self.confidence
334 }
335
336 fn predict(&self, _x: &Array2<Float>) -> Result<Array1<Float>> {
337 Err(sklears_core::error::SklearsError::NotImplemented(
338 "External estimator prediction not implemented".to_string(),
339 ))
340 }
341
342 fn predict_proba(&self, _x: &Array2<Float>) -> Result<Array2<Float>> {
343 Err(sklears_core::error::SklearsError::NotImplemented(
344 "External estimator probability prediction not implemented".to_string(),
345 ))
346 }
347
348 fn supports_proba(&self) -> bool {
349 false
350 }
351
352 fn feature_importance(&self) -> Option<Array1<Float>> {
353 None
354 }
355
356 fn complexity(&self) -> Float {
357 1.0
358 }
359
360 fn is_fitted(&self) -> bool {
361 true
362 }
363
364 fn n_classes(&self) -> Option<usize> {
365 None
366 }
367
368 fn n_features(&self) -> Option<usize> {
369 None
370 }
371
372 fn uncertainty(&self, _x: &Array2<Float>) -> Result<Array1<Float>> {
373 Err(sklears_core::error::SklearsError::NotImplemented(
374 "External estimator uncertainty estimation not implemented".to_string(),
375 ))
376 }
377
378 fn name(&self) -> String {
379 self.name.clone()
380 }
381
382 fn clone_estimator(&self) -> Box<dyn EnsembleMember + Send + Sync> {
383 Box::new(Self {
384 weight: self.weight,
385 performance: self.performance,
386 confidence: self.confidence,
387 name: self.name.clone(),
388 })
389 }
390}
391
392pub mod ensemble_utils {
394 use super::*;
395
396 pub fn calculate_ensemble_diversity(
398 estimators: &[Box<dyn EnsembleMember + Send + Sync>],
399 x: &Array2<Float>,
400 ) -> Result<Float> {
401 if estimators.len() < 2 {
402 return Ok(0.0);
403 }
404
405 let n_samples = x.nrows();
406 let n_estimators = estimators.len();
407
408 let mut all_predictions = Vec::new();
410 for estimator in estimators {
411 let predictions = estimator.predict(x)?;
412 all_predictions.push(predictions);
413 }
414
415 let mut total_disagreement = 0.0;
417 let mut n_pairs = 0;
418
419 for i in 0..n_estimators {
420 for j in (i + 1)..n_estimators {
421 let mut disagreements = 0;
422 for sample_idx in 0..n_samples {
423 if (all_predictions[i][sample_idx] - all_predictions[j][sample_idx]).abs()
424 > 1e-6
425 {
426 disagreements += 1;
427 }
428 }
429 total_disagreement += disagreements as Float / n_samples as Float;
430 n_pairs += 1;
431 }
432 }
433
434 if n_pairs > 0 {
435 Ok(total_disagreement / n_pairs as Float)
436 } else {
437 Ok(0.0)
438 }
439 }
440
441 pub fn update_ensemble_weights(
443 estimators: &mut [Box<dyn EnsembleMember + Send + Sync>],
444 recent_performances: &[Float],
445 learning_rate: Float,
446 ) {
447 if estimators.len() != recent_performances.len() {
448 return;
449 }
450
451 let total_performance: Float = recent_performances.iter().sum();
453
454 if total_performance > 1e-8 {
455 for (estimator, &performance) in estimators.iter_mut().zip(recent_performances.iter()) {
456 let current_weight = estimator.weight();
457 let target_weight = performance / total_performance;
458 let new_weight = current_weight + learning_rate * (target_weight - current_weight);
459 estimator.set_weight(new_weight.max(0.01)); }
461 }
462 }
463
464 pub fn prune_ensemble(
466 estimators: &mut Vec<Box<dyn EnsembleMember + Send + Sync>>,
467 performance_threshold: Float,
468 min_ensemble_size: usize,
469 ) {
470 if estimators.len() <= min_ensemble_size {
471 return;
472 }
473
474 estimators.retain(|estimator| estimator.performance() >= performance_threshold);
475
476 if estimators.len() < min_ensemble_size {
478 }
481 }
482
483 pub fn get_ensemble_stats(
485 estimators: &[Box<dyn EnsembleMember + Send + Sync>],
486 ) -> EnsembleStats {
487 if estimators.is_empty() {
488 return EnsembleStats::default();
489 }
490
491 let weights: Vec<Float> = estimators.iter().map(|e| e.weight()).collect();
492 let performances: Vec<Float> = estimators.iter().map(|e| e.performance()).collect();
493 let confidences: Vec<Float> = estimators.iter().map(|e| e.confidence()).collect();
494
495 let mean_weight = weights.iter().sum::<Float>() / weights.len() as Float;
496 let mean_performance = performances.iter().sum::<Float>() / performances.len() as Float;
497 let mean_confidence = confidences.iter().sum::<Float>() / confidences.len() as Float;
498
499 let weight_variance = weights
500 .iter()
501 .map(|&w| (w - mean_weight).powi(2))
502 .sum::<Float>()
503 / weights.len() as Float;
504
505 EnsembleStats {
506 n_estimators: estimators.len(),
507 mean_weight,
508 mean_performance,
509 mean_confidence,
510 weight_variance,
511 total_complexity: estimators.iter().map(|e| e.complexity()).sum(),
512 }
513 }
514}
515
516#[derive(Debug, Clone)]
518pub struct EnsembleStats {
519 pub n_estimators: usize,
520 pub mean_weight: Float,
521 pub mean_performance: Float,
522 pub mean_confidence: Float,
523 pub weight_variance: Float,
524 pub total_complexity: Float,
525}
526
527impl Default for EnsembleStats {
528 fn default() -> Self {
529 Self {
530 n_estimators: 0,
531 mean_weight: 0.0,
532 mean_performance: 0.0,
533 mean_confidence: 0.0,
534 weight_variance: 0.0,
535 total_complexity: 0.0,
536 }
537 }
538}