Skip to main content

entrenar/merge/
mod.rs

//! Model merging methods (TIES, DARE, SLERP, ensemble)
//!
//! This module provides model merging algorithms for combining
//! multiple fine-tuned models:
//!
//! - **TIES**: TrIm, Elect Sign & Merge (TIES-Merging)
//! - **DARE**: Drop And REscale for stochastic merging
//! - **SLERP**: Spherical Linear intERPolation for smooth blending
//! - **Ensemble**: [`ensemble_merge`], configured via [`EnsembleConfig`] and [`EnsembleStrategy`]
9
10mod dare;
11mod ensemble;
12mod slerp;
13mod ties;
14
15#[cfg(test)]
16mod tests;
17
18#[cfg(test)]
19mod commutativity;
20
21pub use dare::{dare_merge, DareConfig};
22pub use ensemble::{ensemble_merge, EnsembleConfig, EnsembleStrategy};
23pub use slerp::{slerp_merge, SlerpConfig};
24pub use ties::{ties_merge, TiesConfig};
25
26use crate::autograd::Tensor;
27use std::collections::HashMap;
28
29/// A model represented as a collection of named tensors
30pub type Model = HashMap<String, Tensor>;
31
32/// Error types for model merging operations
33#[derive(Debug, thiserror::Error)]
34pub enum MergeError {
35    #[error("Models have incompatible architectures: {0}")]
36    IncompatibleArchitectures(String),
37
38    #[error("Parameter {0} has mismatched shapes")]
39    ShapeMismatch(String),
40
41    #[error("Invalid merge configuration: {0}")]
42    InvalidConfig(String),
43
44    #[error("Insufficient models provided: need at least {min}, got {got}")]
45    InsufficientModels { min: usize, got: usize },
46}
47
48/// Compute delta weights (model - base) for each model
49pub(crate) fn compute_deltas(models: &[Model], base: &Model) -> Result<Vec<Model>, MergeError> {
50    models
51        .iter()
52        .map(|model| {
53            let mut delta = HashMap::new();
54            for (name, tensor) in model {
55                let base_tensor = base.get(name).ok_or_else(|| {
56                    MergeError::IncompatibleArchitectures(format!(
57                        "Base model missing parameter: {name}"
58                    ))
59                })?;
60
61                if tensor.len() != base_tensor.len() {
62                    return Err(MergeError::ShapeMismatch(name.clone()));
63                }
64
65                // Delta = model - base
66                let delta_data = tensor.data() - base_tensor.data();
67                delta.insert(name.clone(), Tensor::new(delta_data, false));
68            }
69            Ok(delta)
70        })
71        .collect()
72}
73
74/// Merge deltas back with base model
75pub(crate) fn merge_with_base(base: &Model, delta: Model) -> Model {
76    let mut merged = HashMap::new();
77    for (name, base_tensor) in base {
78        if let Some(delta_tensor) = delta.get(name) {
79            let merged_data = base_tensor.data() + delta_tensor.data();
80            merged.insert(name.clone(), Tensor::new(merged_data, false));
81        } else {
82            merged.insert(name.clone(), base_tensor.clone());
83        }
84    }
85    merged
86}
87
88/// Validate that all models have compatible architectures
89pub(crate) fn validate_models(models: &[Model]) -> Result<(), MergeError> {
90    if models.is_empty() {
91        return Err(MergeError::InsufficientModels { min: 1, got: 0 });
92    }
93
94    let reference = &models[0];
95    for (i, model) in models.iter().enumerate().skip(1) {
96        validate_model_keys(reference, model, i)?;
97        validate_model_shapes(reference, model, i)?;
98    }
99
100    Ok(())
101}
102
103/// Check all reference parameters exist in the model.
104fn validate_model_keys(reference: &Model, model: &Model, idx: usize) -> Result<(), MergeError> {
105    for name in reference.keys() {
106        if !model.contains_key(name) {
107            return Err(MergeError::IncompatibleArchitectures(format!(
108                "Model {idx} missing parameter: {name}"
109            )));
110        }
111    }
112    Ok(())
113}
114
115/// Check all parameter shapes match between reference and model.
116fn validate_model_shapes(reference: &Model, model: &Model, idx: usize) -> Result<(), MergeError> {
117    for (name, ref_tensor) in reference {
118        let model_tensor = &model[name];
119        if ref_tensor.len() != model_tensor.len() {
120            return Err(MergeError::ShapeMismatch(format!(
121                "{} (model 0: {}, model {}: {})",
122                name,
123                ref_tensor.len(),
124                idx,
125                model_tensor.len()
126            )));
127        }
128    }
129    Ok(())
130}