1use super::{compute_deltas, merge_with_base, validate_models, MergeError, Model};
9use crate::autograd::Tensor;
10use ndarray::Array1;
11use std::collections::HashMap;
12
/// Configuration for [`ties_merge`].
///
/// Prefer [`TiesConfig::new`] over a struct literal: the constructor rejects
/// densities outside `[0.0, 1.0]`, whereas the public field accepts any `f32`.
#[derive(Clone, Debug, PartialEq)]
pub struct TiesConfig {
    /// Fraction of each delta tensor's elements that survive trimming (the
    /// largest-magnitude ones are kept). Expected to lie in `[0.0, 1.0]`.
    pub density: f32,
}
21
22impl Default for TiesConfig {
23 fn default() -> Self {
24 Self { density: 0.2 }
25 }
26}
27
28impl TiesConfig {
29 pub fn new(density: f32) -> Result<Self, MergeError> {
30 if !(0.0..=1.0).contains(&density) {
31 return Err(MergeError::InvalidConfig(format!(
32 "Density must be in [0.0, 1.0], got {density}"
33 )));
34 }
35 Ok(Self { density })
36 }
37}
38
39pub fn ties_merge(
56 models: &[Model],
57 base: &Model,
58 config: &TiesConfig,
59) -> Result<Model, MergeError> {
60 if models.len() < 2 {
61 return Err(MergeError::InsufficientModels { min: 2, got: models.len() });
62 }
63
64 validate_models(models)?;
65
66 let deltas = compute_deltas(models, base)?;
68
69 let trimmed_deltas = trim_deltas(&deltas, config.density);
71
72 let merged_delta = elect_and_merge(&trimmed_deltas);
74
75 Ok(merge_with_base(base, merged_delta))
77}
78
79fn trim_deltas(deltas: &[Model], density: f32) -> Vec<Model> {
81 deltas
82 .iter()
83 .map(|delta| {
84 let mut trimmed = HashMap::new();
85 for (name, tensor) in delta {
86 trimmed.insert(name.clone(), trim_tensor(tensor, density));
87 }
88 trimmed
89 })
90 .collect()
91}
92
93fn trim_tensor(tensor: &Tensor, density: f32) -> Tensor {
95 let data = tensor.data();
96 let n = data.len();
97 let k = ((n as f32 * density).ceil() as usize).max(1).min(n);
98
99 let mut indices_and_magnitudes: Vec<(usize, f32)> =
101 data.iter().enumerate().map(|(i, &val)| (i, val.abs())).collect();
102
103 indices_and_magnitudes.sort_by(|a, b| b.1.total_cmp(&a.1));
104
105 let mut trimmed_data = Array1::zeros(n);
107 for (idx, _) in indices_and_magnitudes.iter().take(k) {
108 trimmed_data[*idx] = data[*idx];
109 }
110
111 Tensor::new(trimmed_data, false)
112}
113
114fn elect_and_merge(trimmed_deltas: &[Model]) -> Model {
116 if trimmed_deltas.is_empty() {
117 return HashMap::new();
118 }
119
120 let reference = &trimmed_deltas[0];
121 let mut merged = HashMap::new();
122
123 for name in reference.keys() {
124 let all_values: Vec<&Array1<f32>> =
126 trimmed_deltas.iter().map(|delta| delta[name].data()).collect();
127
128 let merged_tensor = elect_and_merge_parameter(&all_values);
129 merged.insert(name.clone(), merged_tensor);
130 }
131
132 merged
133}
134
135fn elect_and_merge_parameter(values: &[&Array1<f32>]) -> Tensor {
137 let n = values[0].len();
138 let mut merged_data = Array1::zeros(n);
139
140 for i in 0..n {
141 let (pos_sum, pos_count, neg_sum, neg_count) = values.iter().fold(
143 (0.0f32, 0usize, 0.0f32, 0usize),
144 |(pos_sum, pos_count, neg_sum, neg_count), arr| {
145 let val = arr[i];
146 if val > 0.0 {
147 (pos_sum + val, pos_count + 1, neg_sum, neg_count)
148 } else if val < 0.0 {
149 (pos_sum, pos_count, neg_sum + val, neg_count + 1)
150 } else {
151 (pos_sum, pos_count, neg_sum, neg_count)
152 }
153 },
154 );
155
156 merged_data[i] = match pos_count.cmp(&neg_count) {
159 std::cmp::Ordering::Greater => {
160 if pos_count > 0 {
162 pos_sum / pos_count as f32
163 } else {
164 0.0
165 }
166 }
167 std::cmp::Ordering::Less => {
168 if neg_count > 0 {
170 neg_sum / neg_count as f32
171 } else {
172 0.0
173 }
174 }
175 std::cmp::Ordering::Equal => {
176 let total = pos_sum + neg_sum;
178 let total_count = pos_count + neg_count;
179 if total_count > 0 {
180 total / total_count as f32
181 } else {
182 0.0
183 }
184 }
185 };
186 }
187
188 Tensor::new(merged_data, false)
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // Density 0.4 over five elements keeps ceil(2.0) = 2 values: -5.0 and 3.0.
    #[test]
    fn test_trim_tensor_keeps_top_k() {
        let tensor = Tensor::from_vec(vec![1.0, -5.0, 2.0, -0.1, 3.0], false);
        let trimmed = trim_tensor(&tensor, 0.4);
        let data = trimmed.data();

        assert_eq!(data[0], 0.0);
        assert_eq!(data[1], -5.0);
        assert_eq!(data[2], 0.0);
        assert_eq!(data[3], 0.0);
        assert_eq!(data[4], 3.0);
    }

    #[test]
    fn test_elect_and_merge_parameter_majority_positive() {
        let v1 = Array1::from(vec![1.0, -1.0, 0.0]);
        let v2 = Array1::from(vec![2.0, 0.0, 1.0]);
        let v3 = Array1::from(vec![3.0, -2.0, 0.0]);

        let result = elect_and_merge_parameter(&[&v1, &v2, &v3]);

        // Position 0: all positive, mean of 1, 2, 3.
        assert!((result.data()[0] - 2.0).abs() < 1e-6);

        // Position 1: two negatives against a zero, mean of -1 and -2.
        assert!((result.data()[1] - (-1.5)).abs() < 1e-6);

        // Position 2: a single positive against zeros.
        assert!((result.data()[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_ties_config_validation() {
        assert!(TiesConfig::new(0.5).is_ok());
        assert!(TiesConfig::new(0.0).is_ok());
        assert!(TiesConfig::new(1.0).is_ok());
        assert!(TiesConfig::new(-0.1).is_err());
        assert!(TiesConfig::new(1.1).is_err());
    }

    #[test]
    fn test_ties_merge_insufficient_models() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0], false));

        let models = vec![base.clone()];
        let config = TiesConfig::default();

        let result = ties_merge(&models, &base, &config);
        assert!(matches!(result, Err(MergeError::InsufficientModels { min: 2, got: 1 })));
    }

    // Density 0.0 still keeps one element (k is raised to at least 1), and it
    // must be the largest-magnitude one.
    #[test]
    fn test_trim_tensor_density_zero() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);
        let trimmed = trim_tensor(&tensor, 0.0);

        let data = trimmed.data();
        let non_zero_count = data.iter().filter(|&&x| x != 0.0).count();
        assert_eq!(non_zero_count, 1);
        assert_eq!(data[4], 5.0);
    }

    // Density 1.0 keeps every element unchanged.
    #[test]
    fn test_trim_tensor_density_one() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);
        let trimmed = trim_tensor(&tensor, 1.0);

        let data = trimmed.data();
        for (i, expected) in [1.0f32, 2.0, 3.0, 4.0, 5.0].iter().enumerate() {
            assert_eq!(data[i], *expected);
        }
    }

    // One positive against one negative is a tie, resolved as the mean of both.
    #[test]
    fn test_elect_sign_tie_breaker() {
        let v1 = Array1::from(vec![1.0]);
        let v2 = Array1::from(vec![-1.0]);

        let result = elect_and_merge_parameter(&[&v1, &v2]);

        assert!((result.data()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_elect_sign_all_zeros() {
        let v1 = Array1::from(vec![0.0, 0.0]);
        let v2 = Array1::from(vec![0.0, 0.0]);

        let result = elect_and_merge_parameter(&[&v1, &v2]);

        assert_eq!(result.data()[0], 0.0);
        assert_eq!(result.data()[1], 0.0);
    }

    #[test]
    fn test_ties_merge_two_models() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0, 0.0, 0.0], false));

        let mut model1 = HashMap::new();
        model1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], false));

        let mut model2 = HashMap::new();
        model2.insert("w".to_string(), Tensor::from_vec(vec![2.0, -1.0, 4.0], false));

        let config = TiesConfig::new(1.0).expect("config should be valid");
        let result = ties_merge(&[model1, model2], &base, &config).expect("merge should succeed");

        let w = result.get("w").expect("key should exist");
        // Positions 0 and 2 agree in sign, so they are plain means.
        assert!((w.data()[0] - 1.5).abs() < 1e-6);
        assert!((w.data()[2] - 3.5).abs() < 1e-6);
        // Position 1 is a 1-vs-1 sign tie: mean of 2.0 and -1.0.
        assert!((w.data()[1] - 0.5).abs() < 1e-6);
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn prop_trim_preserves_top_k_count(
            values in proptest::collection::vec(-100.0f32..100.0, 10..50),
            density in 0.1f32..1.0
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let trimmed = trim_tensor(&tensor, density);

            let expected_k =
                ((values.len() as f32 * density).ceil() as usize).max(1).min(values.len());
            let actual_nonzero = trimmed.data().iter().filter(|&&x| x != 0.0).count();

            // Exactly k positions are copied back, so at most k can be non-zero
            // (fewer only when a kept value is itself zero).
            prop_assert!(
                actual_nonzero <= expected_k,
                "Kept {} non-zero values, expected at most {}",
                actual_nonzero,
                expected_k
            );
        }

        #[test]
        fn prop_trim_keeps_highest_magnitudes(
            values in proptest::collection::vec(-100.0f32..100.0, 5..20),
            density in 0.3f32..0.7
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let trimmed = trim_tensor(&tensor, density);

            let kept_magnitudes: Vec<f32> = trimmed
                .data()
                .iter()
                .filter(|&&x| x != 0.0)
                .map(|x| x.abs())
                .collect();

            if !kept_magnitudes.is_empty() {
                let min_kept = kept_magnitudes.iter().copied().fold(f32::INFINITY, f32::min);

                // Every value that was zeroed out must not exceed the smallest kept one.
                for (orig, trim) in values.iter().zip(trimmed.data().iter()) {
                    if *trim == 0.0 && *orig != 0.0 {
                        prop_assert!(
                            orig.abs() <= min_kept + 1e-6,
                            "Trimmed value {} has higher magnitude than kept minimum {}",
                            orig.abs(),
                            min_kept
                        );
                    }
                }
            }
        }

        #[test]
        fn prop_elect_sign_follows_majority(
            pos_count in 1usize..5,
            neg_count in 1usize..5,
            pos_val in 0.1f32..10.0,
            neg_val in -10.0f32..-0.1
        ) {
            let mut arrays: Vec<Array1<f32>> = Vec::new();

            for _ in 0..pos_count {
                arrays.push(Array1::from(vec![pos_val]));
            }
            for _ in 0..neg_count {
                arrays.push(Array1::from(vec![neg_val]));
            }

            let refs: Vec<&Array1<f32>> = arrays.iter().collect();
            let result = elect_and_merge_parameter(&refs);

            // Ties are deliberately left unchecked here; their sign depends on the values.
            if pos_count > neg_count {
                prop_assert!(result.data()[0] > 0.0, "Expected positive, got {}", result.data()[0]);
            } else if neg_count > pos_count {
                prop_assert!(result.data()[0] < 0.0, "Expected negative, got {}", result.data()[0]);
            }
        }

        #[test]
        fn prop_ties_config_density_valid(density in 0.0f32..=1.0) {
            let config = TiesConfig::new(density);
            prop_assert!(config.is_ok());
        }

        #[test]
        fn prop_ties_config_density_invalid_negative(density in -10.0f32..-0.01) {
            let config = TiesConfig::new(density);
            prop_assert!(config.is_err());
        }

        #[test]
        fn prop_ties_config_density_invalid_above_one(density in 1.01f32..10.0) {
            let config = TiesConfig::new(density);
            prop_assert!(config.is_err());
        }

        #[test]
        fn prop_trim_idempotent_at_density_one(
            values in proptest::collection::vec(-100.0f32..100.0, 5..20)
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let trimmed = trim_tensor(&tensor, 1.0);

            for (orig, trim) in values.iter().zip(trimmed.data().iter()) {
                prop_assert!(
                    (orig - trim).abs() < 1e-6,
                    "Value changed at density 1.0: {} -> {}",
                    orig,
                    trim
                );
            }
        }

        // With only positive inputs the elected value is the plain mean.
        #[test]
        fn prop_elect_preserves_magnitude_order(
            values in proptest::collection::vec(1.0f32..10.0, 3..6)
        ) {
            let arrays: Vec<Array1<f32>> =
                values.iter().map(|&v| Array1::from(vec![v])).collect();
            let refs: Vec<&Array1<f32>> = arrays.iter().collect();

            let result = elect_and_merge_parameter(&refs);
            let expected_avg: f32 = values.iter().sum::<f32>() / values.len() as f32;

            prop_assert!(
                (result.data()[0] - expected_avg).abs() < 1e-5,
                "Expected average {}, got {}",
                expected_avg,
                result.data()[0]
            );
        }
    }
}