entrenar/merge/ensemble/
weighted.rs1use super::{MergeError, Model};
4use std::collections::HashMap;
5
6pub fn weighted_average_merge(models: &[Model], weights: &[f32]) -> Result<Model, MergeError> {
8 let weights = normalize_weights(weights, models.len())?;
9 let reference = &models[0];
10 let mut merged = HashMap::new();
11
12 for name in reference.keys() {
13 let tensor = weighted_sum_param(name, models, &weights)?;
14 merged.insert(name.clone(), tensor);
15 }
16
17 Ok(merged)
18}
19
20fn normalize_weights(weights: &[f32], n: usize) -> Result<Vec<f32>, MergeError> {
22 if weights.is_empty() {
23 return Ok(vec![1.0 / n as f32; n]);
24 }
25 if weights.len() != n {
26 return Err(MergeError::InvalidConfig(format!(
27 "Weights length {} doesn't match models length {n}",
28 weights.len(),
29 )));
30 }
31 let sum: f32 = weights.iter().sum();
32 if sum <= 0.0 {
33 return Err(MergeError::InvalidConfig("Weights must sum to positive value".to_string()));
34 }
35 Ok(weights.iter().map(|w| w / sum).collect())
36}
37
38fn weighted_sum_param(
40 name: &str,
41 models: &[Model],
42 weights: &[f32],
43) -> Result<crate::autograd::Tensor, MergeError> {
44 let param_len = models[0][name].len();
45 let mut weighted_sum = ndarray::Array1::<f32>::zeros(param_len);
46
47 for (model, weight) in models.iter().zip(weights.iter()) {
48 let param = model
49 .get(name)
50 .ok_or_else(|| MergeError::IncompatibleArchitectures(format!("Missing {name}")))?;
51 if param.len() != param_len {
52 return Err(MergeError::ShapeMismatch(name.to_string()));
53 }
54 weighted_sum = weighted_sum + param.data() * *weight;
55 }
56
57 Ok(crate::autograd::Tensor::new(weighted_sum, false))
58}