entrenar/merge/ensemble/
merge.rs1use super::config::EnsembleConfig;
4use super::hierarchical::hierarchical_merge;
5use super::slerp::iterative_slerp_merge;
6use super::strategy::EnsembleStrategy;
7use super::weighted::weighted_average_merge;
8use super::{dare_merge, ties_merge, DareConfig, MergeError, Model, TiesConfig};
9
10pub fn ensemble_merge(models: &[Model], config: &EnsembleConfig) -> Result<Model, MergeError> {
19 if models.len() < 2 {
20 return Err(MergeError::InsufficientModels { min: 2, got: models.len() });
21 }
22
23 match &config.strategy {
24 EnsembleStrategy::WeightedAverage { weights } => weighted_average_merge(models, weights),
25 EnsembleStrategy::Ties { density } => {
26 let base = config
27 .base
28 .as_ref()
29 .ok_or_else(|| MergeError::InvalidConfig("TIES requires base model".to_string()))?;
30 let ties_config = TiesConfig::new(*density)?;
31 ties_merge(models, base, &ties_config)
32 }
33 EnsembleStrategy::Dare { drop_prob, seed } => {
34 let base = config
35 .base
36 .as_ref()
37 .ok_or_else(|| MergeError::InvalidConfig("DARE requires base model".to_string()))?;
38 let mut dare_config = DareConfig::new(*drop_prob)?;
39 if let Some(s) = seed {
40 dare_config = dare_config.with_seed(*s);
41 }
42 dare_merge(models, base, &dare_config)
43 }
44 EnsembleStrategy::IterativeSlerp { t } => iterative_slerp_merge(models, *t),
45 EnsembleStrategy::Hierarchical { leaf_strategy } => {
46 hierarchical_merge(models, leaf_strategy, config.base.as_ref())
47 }
48 }
49}