Skip to main content

entrenar/merge/ensemble/
merge.rs

1//! ENT-032: Main ensemble merge function
2
3use super::config::EnsembleConfig;
4use super::hierarchical::hierarchical_merge;
5use super::slerp::iterative_slerp_merge;
6use super::strategy::EnsembleStrategy;
7use super::weighted::weighted_average_merge;
8use super::{dare_merge, ties_merge, DareConfig, MergeError, Model, TiesConfig};
9
10/// Merge multiple models using the specified strategy
11///
12/// # Arguments
13/// * `models` - Models to merge (must have at least 2)
14/// * `config` - Ensemble configuration
15///
16/// # Returns
17/// Merged model combining all inputs
18pub fn ensemble_merge(models: &[Model], config: &EnsembleConfig) -> Result<Model, MergeError> {
19    if models.len() < 2 {
20        return Err(MergeError::InsufficientModels { min: 2, got: models.len() });
21    }
22
23    match &config.strategy {
24        EnsembleStrategy::WeightedAverage { weights } => weighted_average_merge(models, weights),
25        EnsembleStrategy::Ties { density } => {
26            let base = config
27                .base
28                .as_ref()
29                .ok_or_else(|| MergeError::InvalidConfig("TIES requires base model".to_string()))?;
30            let ties_config = TiesConfig::new(*density)?;
31            ties_merge(models, base, &ties_config)
32        }
33        EnsembleStrategy::Dare { drop_prob, seed } => {
34            let base = config
35                .base
36                .as_ref()
37                .ok_or_else(|| MergeError::InvalidConfig("DARE requires base model".to_string()))?;
38            let mut dare_config = DareConfig::new(*drop_prob)?;
39            if let Some(s) = seed {
40                dare_config = dare_config.with_seed(*s);
41            }
42            dare_merge(models, base, &dare_config)
43        }
44        EnsembleStrategy::IterativeSlerp { t } => iterative_slerp_merge(models, *t),
45        EnsembleStrategy::Hierarchical { leaf_strategy } => {
46            hierarchical_merge(models, leaf_strategy, config.base.as_ref())
47        }
48    }
49}