use super::strategy::EnsembleStrategy;
use super::Model;
/// Configuration describing how a set of models is combined into one.
///
/// Pairs the combination [`EnsembleStrategy`] with an optional base model.
/// The `Default` derive relies on both field types implementing `Default`
/// (`Option<Model>` defaults to `None`).
#[derive(Debug, Clone, Default)]
pub struct EnsembleConfig {
    /// Optional base model; `None` when the strategy does not use one.
    pub base: Option<Model>,
    /// Which ensembling / merging strategy to apply.
    pub strategy: EnsembleStrategy,
}
impl EnsembleConfig {
    /// Builds a weighted-average ensemble with no base model.
    ///
    /// `weights` is stored unchanged in [`EnsembleStrategy::WeightedAverage`].
    /// No length or normalisation checks happen here. Presumably they happen
    /// where the strategy is applied (NOTE(review): confirm).
    #[must_use]
    pub fn weighted_average(weights: Vec<f32>) -> Self {
        Self { base: None, strategy: EnsembleStrategy::WeightedAverage { weights } }
    }

    /// Builds a uniform-average ensemble with no base model.
    ///
    /// This is represented as a weighted average with an empty weight
    /// vector. The strategy implementation presumably treats empty weights
    /// as equal weighting of all members (NOTE(review): confirm).
    #[must_use]
    pub fn uniform_average() -> Self {
        // Same value as the original inline construction; delegating keeps the
        // two weighted-average constructors in sync.
        Self::weighted_average(Vec::new())
    }

    /// Builds a TIES configuration anchored on `base`.
    ///
    /// `density` is forwarded unchanged to [`EnsembleStrategy::Ties`].
    /// Its valid range is not checked here. It is presumably the fraction
    /// of parameters kept (TODO confirm).
    #[must_use]
    pub fn ties(base: Model, density: f32) -> Self {
        Self { base: Some(base), strategy: EnsembleStrategy::Ties { density } }
    }

    /// Builds a DARE configuration anchored on `base`.
    ///
    /// `drop_prob` and `seed` are forwarded unchanged to
    /// [`EnsembleStrategy::Dare`]. Neither is validated here. Presumably
    /// `seed: None` means a non-fixed RNG seed (TODO confirm).
    #[must_use]
    pub fn dare(base: Model, drop_prob: f32, seed: Option<u64>) -> Self {
        Self { base: Some(base), strategy: EnsembleStrategy::Dare { drop_prob, seed } }
    }

    /// Builds an iterative-SLERP configuration with no base model.
    ///
    /// `t` is forwarded unchanged to [`EnsembleStrategy::IterativeSlerp`].
    /// It is presumably the interpolation factor (TODO confirm its range).
    #[must_use]
    pub fn iterative_slerp(t: f32) -> Self {
        Self { base: None, strategy: EnsembleStrategy::IterativeSlerp { t } }
    }

    /// Builds a hierarchical configuration that applies `leaf_strategy`
    /// inside an [`EnsembleStrategy::Hierarchical`] wrapper. No base model
    /// is attached.
    #[must_use]
    pub fn hierarchical(leaf_strategy: EnsembleStrategy) -> Self {
        Self {
            base: None,
            // Boxed because the variant nests an `EnsembleStrategy` inside an
            // `EnsembleStrategy`; a recursive type needs indirection.
            strategy: EnsembleStrategy::Hierarchical { leaf_strategy: Box::new(leaf_strategy) },
        }
    }

    /// Attaches `base` as the base model and returns the updated config.
    ///
    /// Any previously set base model is replaced and dropped. `self` is
    /// consumed, so the returned value must be used.
    #[must_use = "`with_base` consumes the config and returns the updated one"]
    pub fn with_base(mut self, base: Model) -> Self {
        self.base = Some(base);
        self
    }
}