Skip to main content

entrenar/merge/
ties.rs

//! TIES (TrIm, Elect Sign & Merge) merge algorithm
//!
//! TIES merges multiple fine-tuned models by:
//! 1. **Trim**: Keep only the top-k% largest-magnitude delta parameters (eliminate noise)
//! 2. **Elect Sign**: Majority vote per parameter to resolve sign conflicts
//! 3. **Merge**: Average the parameters that agree with the elected sign
7
8use super::{compute_deltas, merge_with_base, validate_models, MergeError, Model};
9use crate::autograd::Tensor;
10use ndarray::Array1;
11use std::collections::HashMap;
12
/// Settings that control how aggressively TIES trims each task delta.
#[derive(Debug, Clone)]
pub struct TiesConfig {
    /// Fraction of each delta tensor's entries that survive trimming,
    /// expressed as a value between 0.0 and 1.0.
    ///
    /// A larger density keeps more entries (milder trimming); a smaller one
    /// keeps fewer. Commonly used values range from 0.2 (keep 20%) to
    /// 0.5 (keep 50%).
    pub density: f32,
}
21
22impl Default for TiesConfig {
23    fn default() -> Self {
24        Self { density: 0.2 }
25    }
26}
27
28impl TiesConfig {
29    pub fn new(density: f32) -> Result<Self, MergeError> {
30        if !(0.0..=1.0).contains(&density) {
31            return Err(MergeError::InvalidConfig(format!(
32                "Density must be in [0.0, 1.0], got {density}"
33            )));
34        }
35        Ok(Self { density })
36    }
37}
38
39/// TIES merge: trim, elect sign, merge same-sign parameters
40///
41/// # Arguments
42/// * `models` - Fine-tuned models to merge
43/// * `base` - Base model (pre-fine-tuning checkpoint)
44/// * `config` - TIES configuration (density parameter)
45///
46/// # Returns
47/// Merged model combining all input models
48///
49/// # Algorithm
50/// 1. Compute deltas: Δᵢ = model_i - base
51/// 2. Trim: Keep only top-k% magnitude values per delta
52/// 3. Elect sign: For each parameter, take sign of majority
53/// 4. Merge: Average trimmed deltas with elected sign
54/// 5. Add back to base: merged = base + averaged_delta
55pub fn ties_merge(
56    models: &[Model],
57    base: &Model,
58    config: &TiesConfig,
59) -> Result<Model, MergeError> {
60    if models.len() < 2 {
61        return Err(MergeError::InsufficientModels { min: 2, got: models.len() });
62    }
63
64    validate_models(models)?;
65
66    // Step 1: Compute deltas (model - base)
67    let deltas = compute_deltas(models, base)?;
68
69    // Step 2: Trim deltas (keep top-k% magnitude)
70    let trimmed_deltas = trim_deltas(&deltas, config.density);
71
72    // Step 3 & 4: Elect sign and merge
73    let merged_delta = elect_and_merge(&trimmed_deltas);
74
75    // Step 5: Add back to base
76    Ok(merge_with_base(base, merged_delta))
77}
78
79/// Trim each delta to keep only top-k% magnitude parameters
80fn trim_deltas(deltas: &[Model], density: f32) -> Vec<Model> {
81    deltas
82        .iter()
83        .map(|delta| {
84            let mut trimmed = HashMap::new();
85            for (name, tensor) in delta {
86                trimmed.insert(name.clone(), trim_tensor(tensor, density));
87            }
88            trimmed
89        })
90        .collect()
91}
92
93/// Trim a single tensor to keep only top-k% magnitude values
94fn trim_tensor(tensor: &Tensor, density: f32) -> Tensor {
95    let data = tensor.data();
96    let n = data.len();
97    let k = ((n as f32 * density).ceil() as usize).max(1).min(n);
98
99    // Get magnitude-sorted indices
100    let mut indices_and_magnitudes: Vec<(usize, f32)> =
101        data.iter().enumerate().map(|(i, &val)| (i, val.abs())).collect();
102
103    indices_and_magnitudes.sort_by(|a, b| b.1.total_cmp(&a.1));
104
105    // Keep top-k magnitude values, zero out rest
106    let mut trimmed_data = Array1::zeros(n);
107    for (idx, _) in indices_and_magnitudes.iter().take(k) {
108        trimmed_data[*idx] = data[*idx];
109    }
110
111    Tensor::new(trimmed_data, false)
112}
113
114/// Elect sign per parameter and merge same-sign values
115fn elect_and_merge(trimmed_deltas: &[Model]) -> Model {
116    if trimmed_deltas.is_empty() {
117        return HashMap::new();
118    }
119
120    let reference = &trimmed_deltas[0];
121    let mut merged = HashMap::new();
122
123    for name in reference.keys() {
124        // Collect all delta values for this parameter
125        let all_values: Vec<&Array1<f32>> =
126            trimmed_deltas.iter().map(|delta| delta[name].data()).collect();
127
128        let merged_tensor = elect_and_merge_parameter(&all_values);
129        merged.insert(name.clone(), merged_tensor);
130    }
131
132    merged
133}
134
135/// Elect sign and merge for a single parameter across all models
136fn elect_and_merge_parameter(values: &[&Array1<f32>]) -> Tensor {
137    let n = values[0].len();
138    let mut merged_data = Array1::zeros(n);
139
140    for i in 0..n {
141        // Count positive and negative non-zero values
142        let (pos_sum, pos_count, neg_sum, neg_count) = values.iter().fold(
143            (0.0f32, 0usize, 0.0f32, 0usize),
144            |(pos_sum, pos_count, neg_sum, neg_count), arr| {
145                let val = arr[i];
146                if val > 0.0 {
147                    (pos_sum + val, pos_count + 1, neg_sum, neg_count)
148                } else if val < 0.0 {
149                    (pos_sum, pos_count, neg_sum + val, neg_count + 1)
150                } else {
151                    (pos_sum, pos_count, neg_sum, neg_count)
152                }
153            },
154        );
155
156        // Elect sign by majority vote (most non-zero contributors)
157        // Merge by averaging same-sign values
158        merged_data[i] = match pos_count.cmp(&neg_count) {
159            std::cmp::Ordering::Greater => {
160                // Positive wins: average positive values only
161                if pos_count > 0 {
162                    pos_sum / pos_count as f32
163                } else {
164                    0.0
165                }
166            }
167            std::cmp::Ordering::Less => {
168                // Negative wins: average negative values only
169                if neg_count > 0 {
170                    neg_sum / neg_count as f32
171                } else {
172                    0.0
173                }
174            }
175            std::cmp::Ordering::Equal => {
176                // Tie or all zero: take overall average
177                let total = pos_sum + neg_sum;
178                let total_count = pos_count + neg_count;
179                if total_count > 0 {
180                    total / total_count as f32
181                } else {
182                    0.0
183                }
184            }
185        };
186    }
187
188    Tensor::new(merged_data, false)
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Trimming keeps exactly the two largest-magnitude entries at density 0.4.
    #[test]
    fn test_trim_tensor_keeps_top_k() {
        let tensor = Tensor::from_vec(vec![1.0, -5.0, 2.0, -0.1, 3.0], false);
        let trimmed = trim_tensor(&tensor, 0.4); // Keep top 40% = 2 values

        // Should keep -5.0 and 3.0 (highest magnitudes)
        let data = trimmed.data();
        assert_eq!(data[0], 0.0); // 1.0 trimmed
        assert_eq!(data[1], -5.0); // kept
        assert_eq!(data[2], 0.0); // 2.0 trimmed
        assert_eq!(data[3], 0.0); // -0.1 trimmed
        assert_eq!(data[4], 3.0); // kept
    }

    /// The sign with more non-zero votes wins and only its values are averaged.
    #[test]
    fn test_elect_and_merge_parameter_majority_positive() {
        let v1 = Array1::from(vec![1.0, -1.0, 0.0]);
        let v2 = Array1::from(vec![2.0, 0.0, 1.0]);
        let v3 = Array1::from(vec![3.0, -2.0, 0.0]);

        let result = elect_and_merge_parameter(&[&v1, &v2, &v3]);

        // Index 0: 3 positive votes -> average (1+2+3)/3 = 2.0
        assert!((result.data()[0] - 2.0).abs() < 1e-6);

        // Index 1: 2 negative votes -> average (-1-2)/2 = -1.5
        assert!((result.data()[1] - (-1.5)).abs() < 1e-6);

        // Index 2: 1 positive vote -> 1.0
        assert!((result.data()[2] - 1.0).abs() < 1e-6);
    }

    /// `TiesConfig::new` accepts the closed interval [0, 1] and nothing else.
    #[test]
    fn test_ties_config_validation() {
        assert!(TiesConfig::new(0.5).is_ok());
        assert!(TiesConfig::new(0.0).is_ok());
        assert!(TiesConfig::new(1.0).is_ok());
        assert!(TiesConfig::new(-0.1).is_err());
        assert!(TiesConfig::new(1.1).is_err());
    }

    /// Merging a single model is rejected.
    #[test]
    fn test_ties_merge_insufficient_models() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0], false));

        let models = vec![base.clone()];
        let config = TiesConfig::default();

        let result = ties_merge(&models, &base, &config);
        assert!(matches!(result, Err(MergeError::InsufficientModels { min: 2, got: 1 })));
    }

    /// Density 0 still keeps one entry: the largest-magnitude one.
    #[test]
    fn test_trim_tensor_density_zero() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);
        let trimmed = trim_tensor(&tensor, 0.0);

        // ceil(5 * 0.0) = 0 is raised to 1, so exactly the maximum survives.
        let data = trimmed.data();
        let non_zero_count = data.iter().filter(|&&x| x != 0.0).count();
        assert_eq!(non_zero_count, 1);
        assert_eq!(data[4], 5.0);
    }

    /// Density 1.0 leaves the tensor untouched.
    #[test]
    fn test_trim_tensor_density_one() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);
        let trimmed = trim_tensor(&tensor, 1.0);

        // Density 1.0 should keep all values
        let data = trimmed.data();
        for (i, expected) in [1.0f32, 2.0, 3.0, 4.0, 5.0].iter().enumerate() {
            assert_eq!(data[i], *expected);
        }
    }

    /// Equal vote counts fall back to the mean of all non-zero values.
    #[test]
    fn test_elect_sign_tie_breaker() {
        // Test tie case: equal positive and negative votes
        let v1 = Array1::from(vec![1.0]);
        let v2 = Array1::from(vec![-1.0]);

        let result = elect_and_merge_parameter(&[&v1, &v2]);

        // Tie: should average all values (1 + -1) / 2 = 0
        assert!((result.data()[0] - 0.0).abs() < 1e-6);
    }

    /// Indices where no model contributes stay at zero.
    #[test]
    fn test_elect_sign_all_zeros() {
        let v1 = Array1::from(vec![0.0, 0.0]);
        let v2 = Array1::from(vec![0.0, 0.0]);

        let result = elect_and_merge_parameter(&[&v1, &v2]);

        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[1], 0.0);
    }

    /// End-to-end merge of two models against an all-zero base.
    #[test]
    fn test_ties_merge_two_models() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0, 0.0, 0.0], false));

        let mut model1 = HashMap::new();
        model1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], false));

        let mut model2 = HashMap::new();
        model2.insert("w".to_string(), Tensor::from_vec(vec![2.0, -1.0, 4.0], false));

        let config = TiesConfig::new(1.0).expect("config should be valid"); // Keep all
        let result = ties_merge(&[model1, model2], &base, &config).expect("merge should succeed");

        // Both positive at index 0: average (1+2)/2 = 1.5
        // Mixed at index 1: one positive and one negative vote is a tie,
        //   so all non-zero values are averaged: (2 + -1)/2 = 0.5
        // Both positive at index 2: average (3+4)/2 = 3.5
        let w = result.get("w").expect("key should exist");
        assert!((w.data()[0] - 1.5).abs() < 1e-6);
        assert!((w.data()[1] - 0.5).abs() < 1e-6);
        assert!((w.data()[2] - 3.5).abs() < 1e-6);
    }

    // Property tests

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn prop_trim_preserves_top_k_count(
            values in proptest::collection::vec(-100.0f32..100.0, 10..50),
            density in 0.1f32..1.0
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let trimmed = trim_tensor(&tensor, density);

            let expected_k = ((values.len() as f32 * density).ceil() as usize).max(1).min(values.len());
            let actual_nonzero = trimmed.data().iter().filter(|&&x| x != 0.0).count();

            // Exactly expected_k positions are kept; fewer appear non-zero
            // only when some kept inputs were themselves zero.
            prop_assert!(actual_nonzero <= expected_k);
        }

        #[test]
        fn prop_trim_keeps_highest_magnitudes(
            values in proptest::collection::vec(-100.0f32..100.0, 5..20),
            density in 0.3f32..0.7
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let trimmed = trim_tensor(&tensor, density);

            // Find the minimum magnitude among kept values
            let kept_magnitudes: Vec<f32> = trimmed.data()
                .iter()
                .filter(|&&x| x != 0.0)
                .map(|x| x.abs())
                .collect();

            if !kept_magnitudes.is_empty() {
                let min_kept = kept_magnitudes.iter().copied().fold(f32::INFINITY, f32::min);

                // All trimmed values should have magnitude <= min_kept
                for (orig, trim) in values.iter().zip(trimmed.data().iter()) {
                    if *trim == 0.0 && *orig != 0.0 {
                        prop_assert!(
                            orig.abs() <= min_kept + 1e-6,
                            "Trimmed value {} has higher magnitude than kept minimum {}",
                            orig.abs(),
                            min_kept
                        );
                    }
                }
            }
        }

        #[test]
        fn prop_elect_sign_follows_majority(
            pos_count in 1usize..5,
            neg_count in 1usize..5,
            pos_val in 0.1f32..10.0,
            neg_val in -10.0f32..-0.1
        ) {
            let mut arrays: Vec<Array1<f32>> = Vec::new();

            for _ in 0..pos_count {
                arrays.push(Array1::from(vec![pos_val]));
            }
            for _ in 0..neg_count {
                arrays.push(Array1::from(vec![neg_val]));
            }

            let refs: Vec<&Array1<f32>> = arrays.iter().collect();
            let result = elect_and_merge_parameter(&refs);

            if pos_count > neg_count {
                // Positive majority: result should be positive
                prop_assert!(result.data()[0] > 0.0, "Expected positive, got {}", result.data()[0]);
            } else if neg_count > pos_count {
                // Negative majority: result should be negative
                prop_assert!(result.data()[0] < 0.0, "Expected negative, got {}", result.data()[0]);
            }
            // Tie case: the sign depends on the relative magnitudes, don't assert
        }

        #[test]
        fn prop_ties_config_density_valid(density in 0.0f32..=1.0) {
            let config = TiesConfig::new(density);
            prop_assert!(config.is_ok());
        }

        #[test]
        fn prop_ties_config_density_invalid_negative(density in -10.0f32..-0.01) {
            let config = TiesConfig::new(density);
            prop_assert!(config.is_err());
        }

        #[test]
        fn prop_ties_config_density_invalid_above_one(density in 1.01f32..10.0) {
            let config = TiesConfig::new(density);
            prop_assert!(config.is_err());
        }

        #[test]
        fn prop_trim_idempotent_at_density_one(
            values in proptest::collection::vec(-100.0f32..100.0, 5..20)
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let trimmed = trim_tensor(&tensor, 1.0);

            // At density 1.0, all values should be preserved
            for (orig, trim) in values.iter().zip(trimmed.data().iter()) {
                prop_assert!(
                    (orig - trim).abs() < 1e-6,
                    "Value changed at density 1.0: {} -> {}",
                    orig,
                    trim
                );
            }
        }

        #[test]
        fn prop_elect_preserves_magnitude_order(
            values in proptest::collection::vec(1.0f32..10.0, 3..6)
        ) {
            // All same sign: result should be average
            let arrays: Vec<Array1<f32>> = values.iter().map(|&v| Array1::from(vec![v])).collect();
            let refs: Vec<&Array1<f32>> = arrays.iter().collect();

            let result = elect_and_merge_parameter(&refs);
            let expected_avg: f32 = values.iter().sum::<f32>() / values.len() as f32;

            prop_assert!(
                (result.data()[0] - expected_avg).abs() < 1e-5,
                "Expected average {}, got {}",
                expected_avg,
                result.data()[0]
            );
        }
    }
}