entrenar/merge/ensemble/
hierarchical.rs1use super::strategy::EnsembleStrategy;
4use super::weighted::weighted_average_merge;
5use super::{
6 dare_merge, slerp_merge as slerp_merge_impl, ties_merge, DareConfig, MergeError, Model,
7 SlerpConfig, TiesConfig,
8};
9
10pub fn hierarchical_merge(
15 models: &[Model],
16 leaf_strategy: &EnsembleStrategy,
17 base: Option<&Model>,
18) -> Result<Model, MergeError> {
19 if models.len() == 1 {
20 return Ok(models[0].clone());
21 }
22
23 if models.len() == 2 {
24 return merge_pair(&models[0], &models[1], leaf_strategy, base);
25 }
26
27 let mid = models.len() / 2;
29 let left = hierarchical_merge(&models[..mid], leaf_strategy, base)?;
30 let right = hierarchical_merge(&models[mid..], leaf_strategy, base)?;
31
32 merge_pair(&left, &right, leaf_strategy, base)
33}
34
35pub fn merge_pair(
37 m1: &Model,
38 m2: &Model,
39 strategy: &EnsembleStrategy,
40 base: Option<&Model>,
41) -> Result<Model, MergeError> {
42 match strategy {
43 EnsembleStrategy::WeightedAverage { weights } => {
44 let w = if weights.len() == 2 { weights.clone() } else { vec![0.5, 0.5] };
45 weighted_average_merge(&[m1.clone(), m2.clone()], &w)
46 }
47 EnsembleStrategy::IterativeSlerp { t } => {
48 let config = SlerpConfig::new(*t)?;
49 slerp_merge_impl(m1, m2, &config)
50 }
51 EnsembleStrategy::Ties { density } => {
52 let base =
53 base.ok_or_else(|| MergeError::InvalidConfig("TIES requires base".to_string()))?;
54 let config = TiesConfig::new(*density)?;
55 ties_merge(&[m1.clone(), m2.clone()], base, &config)
56 }
57 EnsembleStrategy::Dare { drop_prob, seed } => {
58 let base =
59 base.ok_or_else(|| MergeError::InvalidConfig("DARE requires base".to_string()))?;
60 let mut config = DareConfig::new(*drop_prob)?;
61 if let Some(s) = seed {
62 config = config.with_seed(*s);
63 }
64 dare_merge(&[m1.clone(), m2.clone()], base, &config)
65 }
66 EnsembleStrategy::Hierarchical { .. } => {
67 weighted_average_merge(&[m1.clone(), m2.clone()], &[0.5, 0.5])
69 }
70 }
71}