Skip to main content

entrenar/merge/ensemble/
weighted.rs

1//! ENT-032: Weighted average merging
2
3use super::{MergeError, Model};
4use std::collections::HashMap;
5
6/// Weighted average of multiple models
7pub fn weighted_average_merge(models: &[Model], weights: &[f32]) -> Result<Model, MergeError> {
8    let weights = normalize_weights(weights, models.len())?;
9    let reference = &models[0];
10    let mut merged = HashMap::new();
11
12    for name in reference.keys() {
13        let tensor = weighted_sum_param(name, models, &weights)?;
14        merged.insert(name.clone(), tensor);
15    }
16
17    Ok(merged)
18}
19
20/// Normalize weights or create uniform weights.
21fn normalize_weights(weights: &[f32], n: usize) -> Result<Vec<f32>, MergeError> {
22    if weights.is_empty() {
23        return Ok(vec![1.0 / n as f32; n]);
24    }
25    if weights.len() != n {
26        return Err(MergeError::InvalidConfig(format!(
27            "Weights length {} doesn't match models length {n}",
28            weights.len(),
29        )));
30    }
31    let sum: f32 = weights.iter().sum();
32    if sum <= 0.0 {
33        return Err(MergeError::InvalidConfig("Weights must sum to positive value".to_string()));
34    }
35    Ok(weights.iter().map(|w| w / sum).collect())
36}
37
38/// Compute weighted sum for a single parameter across models.
39fn weighted_sum_param(
40    name: &str,
41    models: &[Model],
42    weights: &[f32],
43) -> Result<crate::autograd::Tensor, MergeError> {
44    let param_len = models[0][name].len();
45    let mut weighted_sum = ndarray::Array1::<f32>::zeros(param_len);
46
47    for (model, weight) in models.iter().zip(weights.iter()) {
48        let param = model
49            .get(name)
50            .ok_or_else(|| MergeError::IncompatibleArchitectures(format!("Missing {name}")))?;
51        if param.len() != param_len {
52            return Err(MergeError::ShapeMismatch(name.to_string()));
53        }
54        weighted_sum = weighted_sum + param.data() * *weight;
55    }
56
57    Ok(crate::autograd::Tensor::new(weighted_sum, false))
58}