1mod dare;
11mod ensemble;
12mod slerp;
13mod ties;
14
15#[cfg(test)]
16mod tests;
17
18#[cfg(test)]
19mod commutativity;
20
21pub use dare::{dare_merge, DareConfig};
22pub use ensemble::{ensemble_merge, EnsembleConfig, EnsembleStrategy};
23pub use slerp::{slerp_merge, SlerpConfig};
24pub use ties::{ties_merge, TiesConfig};
25
26use crate::autograd::Tensor;
27use std::collections::HashMap;
28
/// A model represented as a map from parameter name to that parameter's [`Tensor`].
pub type Model = HashMap<String, Tensor>;
31
/// Errors reported by the merge helpers and model validation in this module.
#[derive(Debug, thiserror::Error)]
pub enum MergeError {
    /// A parameter present in one model is missing from another (or from the base model).
    /// The payload is a human-readable description naming the missing parameter.
    #[error("Models have incompatible architectures: {0}")]
    IncompatibleArchitectures(String),

    /// Two tensors stored under the same parameter name report different `len()` values.
    /// The payload identifies the parameter (callers may append extra detail to it).
    #[error("Parameter {0} has mismatched shapes")]
    ShapeMismatch(String),

    /// A merge configuration was rejected; the payload describes why.
    // NOTE(review): not constructed by the helpers visible in this file — presumably raised
    // by the individual merge strategies; confirm against the submodules.
    #[error("Invalid merge configuration: {0}")]
    InvalidConfig(String),

    /// Fewer models were supplied than the operation requires.
    #[error("Insufficient models provided: need at least {min}, got {got}")]
    InsufficientModels { min: usize, got: usize },
}
47
48pub(crate) fn compute_deltas(models: &[Model], base: &Model) -> Result<Vec<Model>, MergeError> {
50 models
51 .iter()
52 .map(|model| {
53 let mut delta = HashMap::new();
54 for (name, tensor) in model {
55 let base_tensor = base.get(name).ok_or_else(|| {
56 MergeError::IncompatibleArchitectures(format!(
57 "Base model missing parameter: {name}"
58 ))
59 })?;
60
61 if tensor.len() != base_tensor.len() {
62 return Err(MergeError::ShapeMismatch(name.clone()));
63 }
64
65 let delta_data = tensor.data() - base_tensor.data();
67 delta.insert(name.clone(), Tensor::new(delta_data, false));
68 }
69 Ok(delta)
70 })
71 .collect()
72}
73
74pub(crate) fn merge_with_base(base: &Model, delta: Model) -> Model {
76 let mut merged = HashMap::new();
77 for (name, base_tensor) in base {
78 if let Some(delta_tensor) = delta.get(name) {
79 let merged_data = base_tensor.data() + delta_tensor.data();
80 merged.insert(name.clone(), Tensor::new(merged_data, false));
81 } else {
82 merged.insert(name.clone(), base_tensor.clone());
83 }
84 }
85 merged
86}
87
88pub(crate) fn validate_models(models: &[Model]) -> Result<(), MergeError> {
90 if models.is_empty() {
91 return Err(MergeError::InsufficientModels { min: 1, got: 0 });
92 }
93
94 let reference = &models[0];
95 for (i, model) in models.iter().enumerate().skip(1) {
96 validate_model_keys(reference, model, i)?;
97 validate_model_shapes(reference, model, i)?;
98 }
99
100 Ok(())
101}
102
103fn validate_model_keys(reference: &Model, model: &Model, idx: usize) -> Result<(), MergeError> {
105 for name in reference.keys() {
106 if !model.contains_key(name) {
107 return Err(MergeError::IncompatibleArchitectures(format!(
108 "Model {idx} missing parameter: {name}"
109 )));
110 }
111 }
112 Ok(())
113}
114
115fn validate_model_shapes(reference: &Model, model: &Model, idx: usize) -> Result<(), MergeError> {
117 for (name, ref_tensor) in reference {
118 let model_tensor = &model[name];
119 if ref_tensor.len() != model_tensor.len() {
120 return Err(MergeError::ShapeMismatch(format!(
121 "{} (model 0: {}, model {}: {})",
122 name,
123 ref_tensor.len(),
124 idx,
125 model_tensor.len()
126 )));
127 }
128 }
129 Ok(())
130}