Skip to main content

entrenar/merge/
dare.rs

1//! DARE (Drop And REscale) merge algorithm
2//!
3//! DARE merges models by randomly dropping delta parameters with probability p,
4//! then rescaling the remaining values to maintain expected magnitude.
5
6use super::{compute_deltas, merge_with_base, validate_models, MergeError, Model};
7use crate::autograd::Tensor;
8use ndarray::Array1;
9use rand::Rng;
10use std::collections::HashMap;
11
/// Settings that control a DARE merge.
#[derive(Clone, Debug)]
pub struct DareConfig {
    /// Probability, per delta element, that the element is zeroed out.
    ///
    /// Larger values drop more of each delta, which makes the merged update
    /// sparser. Values between 0.3 and 0.7 are typical.
    pub drop_prob: f32,

    /// Seed for the RNG that draws the dropout mask.
    ///
    /// `Some(seed)` makes the merge use a seeded generator; `None` falls back
    /// to the thread-local generator, so results vary from run to run.
    pub seed: Option<u64>,
}
23
24impl Default for DareConfig {
25    fn default() -> Self {
26        Self { drop_prob: 0.5, seed: None }
27    }
28}
29
30impl DareConfig {
31    pub fn new(drop_prob: f32) -> Result<Self, MergeError> {
32        if !(0.0..=1.0).contains(&drop_prob) {
33            return Err(MergeError::InvalidConfig(format!(
34                "Drop probability must be in [0.0, 1.0], got {drop_prob}"
35            )));
36        }
37        Ok(Self { drop_prob, seed: None })
38    }
39
40    pub fn with_seed(mut self, seed: u64) -> Self {
41        self.seed = Some(seed);
42        self
43    }
44}
45
46/// DARE merge: drop and rescale delta parameters
47///
48/// # Arguments
49/// * `models` - Fine-tuned models to merge
50/// * `base` - Base model (pre-fine-tuning checkpoint)
51/// * `config` - DARE configuration (drop probability)
52///
53/// # Returns
54/// Merged model with sparsified deltas
55///
56/// # Algorithm
57/// 1. Compute deltas: Δᵢ = model_i - base
58/// 2. Drop: Apply Bernoulli(1-p) mask to each delta
59/// 3. Rescale: Multiply kept values by 1/(1-p) to maintain expected value
60/// 4. Average: Take mean across all masked deltas
61/// 5. Add back to base: merged = base + averaged_delta
62pub fn dare_merge(
63    models: &[Model],
64    base: &Model,
65    config: &DareConfig,
66) -> Result<Model, MergeError> {
67    if models.is_empty() {
68        return Err(MergeError::InsufficientModels { min: 1, got: 0 });
69    }
70
71    validate_models(models)?;
72
73    // Step 1: Compute deltas
74    let deltas = compute_deltas(models, base)?;
75
76    // Step 2 & 3: Drop and rescale
77    let masked_deltas = if let Some(seed) = config.seed {
78        // For deterministic merging (testing), use seeded RNG
79        use rand::SeedableRng;
80        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
81        drop_and_rescale_deltas(&deltas, config.drop_prob, &mut rng)
82    } else {
83        // For normal use, use thread-local RNG
84        let mut rng = rand::rng();
85        drop_and_rescale_deltas(&deltas, config.drop_prob, &mut rng)
86    };
87
88    // Step 4: Average masked deltas
89    let averaged_delta = average_deltas(&masked_deltas);
90
91    // Step 5: Add back to base
92    Ok(merge_with_base(base, averaged_delta))
93}
94
95/// Drop parameters with probability p and rescale by 1/(1-p)
96fn drop_and_rescale_deltas<R: Rng>(deltas: &[Model], drop_prob: f32, rng: &mut R) -> Vec<Model> {
97    let keep_prob = 1.0 - drop_prob;
98    let scale = if keep_prob > 0.0 { 1.0 / keep_prob } else { 1.0 };
99
100    deltas
101        .iter()
102        .map(|delta| {
103            let mut masked = HashMap::new();
104            for (name, tensor) in delta {
105                masked.insert(name.clone(), drop_and_rescale_tensor(tensor, drop_prob, scale, rng));
106            }
107            masked
108        })
109        .collect()
110}
111
112/// Apply Bernoulli dropout mask to a single tensor
113fn drop_and_rescale_tensor<R: Rng>(
114    tensor: &Tensor,
115    drop_prob: f32,
116    scale: f32,
117    rng: &mut R,
118) -> Tensor {
119    let data = tensor.data();
120    let masked_data: Array1<f32> = data
121        .iter()
122        .map(|&val| {
123            if rng.random::<f32>() < drop_prob {
124                0.0 // Drop
125            } else {
126                val * scale // Keep and rescale
127            }
128        })
129        .collect();
130
131    Tensor::new(masked_data, false)
132}
133
134/// Average multiple delta models
135fn average_deltas(deltas: &[Model]) -> Model {
136    if deltas.is_empty() {
137        return HashMap::new();
138    }
139
140    let n = deltas.len() as f32;
141    let reference = &deltas[0];
142    let mut averaged = HashMap::new();
143
144    for name in reference.keys() {
145        let sum_data: Array1<f32> = deltas
146            .iter()
147            .map(|delta| delta[name].data())
148            .fold(Array1::zeros(reference[name].len()), |acc, data| &acc + data);
149
150        let avg_data = sum_data / n;
151        averaged.insert(name.clone(), Tensor::new(avg_data, false));
152    }
153
154    averaged
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use rand::SeedableRng;

    /// Two runs with identically seeded RNGs must agree element for element, and every
    /// output must be either a dropped zero or the matching input times the scale.
    #[test]
    fn test_drop_and_rescale_tensor_deterministic() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = Tensor::from_vec(values.clone(), false);

        let mut rng_a = rand::rngs::StdRng::seed_from_u64(42);
        let mut rng_b = rand::rngs::StdRng::seed_from_u64(42);
        let first = drop_and_rescale_tensor(&tensor, 0.5, 2.0, &mut rng_a);
        let second = drop_and_rescale_tensor(&tensor, 0.5, 2.0, &mut rng_b);

        // With drop_prob=0.5, scale=2.0:
        // - Dropped values -> 0.0
        // - Kept values -> original * 2.0
        for ((&orig, &a), &b) in values.iter().zip(first.data().iter()).zip(second.data().iter()) {
            assert!(a == 0.0 || a == orig * 2.0, "unexpected masked value {a} for input {orig}");
            assert_eq!(a, b, "same seed produced different masks");
        }
    }

    #[test]
    fn test_average_deltas() {
        let mut delta1 = HashMap::new();
        delta1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 2.0], false));

        let mut delta2 = HashMap::new();
        delta2.insert("w".to_string(), Tensor::from_vec(vec![3.0, 4.0], false));

        let averaged = average_deltas(&[delta1, delta2]);

        let expected = [2.0, 3.0]; // (1+3)/2, (2+4)/2
        let actual = averaged["w"].data();
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dare_config_validation() {
        assert!(DareConfig::new(0.5).is_ok());
        assert!(DareConfig::new(0.0).is_ok());
        assert!(DareConfig::new(1.0).is_ok());
        assert!(DareConfig::new(-0.1).is_err());
        assert!(DareConfig::new(1.1).is_err());
    }

    #[test]
    fn test_dare_merge_with_seed_is_deterministic() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0, 0.0], false));

        let mut model1 = base.clone();
        model1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 2.0], false));

        let mut model2 = base.clone();
        model2.insert("w".to_string(), Tensor::from_vec(vec![3.0, 4.0], false));

        let models = vec![model1, model2];
        let config = DareConfig::new(0.5).expect("config should be valid").with_seed(42);

        let result1 = dare_merge(&models, &base, &config).expect("merge should succeed");
        let result2 = dare_merge(&models, &base, &config).expect("merge should succeed");

        // Same seed should produce same results
        let r1_data = result1["w"].data();
        let r2_data = result2["w"].data();
        for (a, b) in r1_data.iter().zip(r2_data.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_drop_prob_zero_keeps_all() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);

        // drop_prob=0 means keep all, scale=1
        let masked = drop_and_rescale_tensor(&tensor, 0.0, 1.0, &mut rng);

        let data = masked.data();
        assert_eq!(data[0], 1.0);
        assert_eq!(data[4], 5.0);
    }

    #[test]
    fn test_drop_prob_one_drops_all() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);

        // drop_prob=1.0 means drop all
        let masked = drop_and_rescale_tensor(&tensor, 1.0, 1.0, &mut rng);

        let data = masked.data();
        for &val in data {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_dare_merge_empty_models() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0], false));

        let models: Vec<Model> = vec![];
        let config = DareConfig::default();

        let result = dare_merge(&models, &base, &config);
        assert!(matches!(result, Err(MergeError::InsufficientModels { min: 1, got: 0 })));
    }

    #[test]
    fn test_dare_merge_single_model() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0, 0.0], false));

        let mut model1 = HashMap::new();
        model1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 2.0], false));

        let models = vec![model1];
        let config = DareConfig::new(0.0).expect("config should be valid").with_seed(42); // Keep all

        let result = dare_merge(&models, &base, &config).expect("merge should succeed");

        // With drop_prob=0, should get model1 back
        let w = result.get("w").expect("key should exist");
        assert!((w.data()[0] - 1.0).abs() < 1e-6);
        assert!((w.data()[1] - 2.0).abs() < 1e-6);
    }

    // Property tests

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn prop_dare_config_valid_range(drop_prob in 0.0f32..=1.0) {
            let config = DareConfig::new(drop_prob);
            prop_assert!(config.is_ok());
        }

        #[test]
        fn prop_dare_config_invalid_negative(drop_prob in -10.0f32..-0.01) {
            let config = DareConfig::new(drop_prob);
            prop_assert!(config.is_err());
        }

        #[test]
        fn prop_dare_config_invalid_above_one(drop_prob in 1.01f32..10.0) {
            let config = DareConfig::new(drop_prob);
            prop_assert!(config.is_err());
        }

        #[test]
        fn prop_drop_and_rescale_output_values(
            values in proptest::collection::vec(1.0f32..10.0, 10..50),
            drop_prob in 0.0f32..1.0,
            seed in 0u64..1000
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
            let keep_prob = 1.0 - drop_prob;
            let scale = if keep_prob > 0.0 { 1.0 / keep_prob } else { 1.0 };

            let masked = drop_and_rescale_tensor(&tensor, drop_prob, scale, &mut rng);

            // Each value should be either 0 (dropped) or original * scale (kept).
            // Inputs are strictly positive, so a zero output can only mean "dropped".
            for (orig, result) in values.iter().zip(masked.data().iter()) {
                if *result != 0.0 {
                    let expected = orig * scale;
                    prop_assert!(
                        (result - expected).abs() < 1e-4,
                        "Expected {} * {} = {}, got {}",
                        orig,
                        scale,
                        expected,
                        result
                    );
                }
            }
        }

        #[test]
        fn prop_drop_prob_zero_preserves_values(
            values in proptest::collection::vec(-100.0f32..100.0, 5..20),
            seed in 0u64..1000
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

            // drop_prob=0 with scale=1 should preserve all values
            let masked = drop_and_rescale_tensor(&tensor, 0.0, 1.0, &mut rng);

            for (orig, result) in values.iter().zip(masked.data().iter()) {
                prop_assert!(
                    (orig - result).abs() < 1e-6,
                    "Value not preserved: {} -> {}",
                    orig,
                    result
                );
            }
        }

        #[test]
        fn prop_drop_prob_one_zeros_all(
            values in proptest::collection::vec(-100.0f32..100.0, 5..20),
            seed in 0u64..1000
        ) {
            let tensor = Tensor::from_vec(values, false);
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

            // drop_prob=1.0 should zero all values
            let masked = drop_and_rescale_tensor(&tensor, 1.0, 1.0, &mut rng);

            for &val in masked.data() {
                prop_assert_eq!(val, 0.0);
            }
        }

        #[test]
        fn prop_average_deltas_is_mean(
            v1 in proptest::collection::vec(-100.0f32..100.0, 5..10),
            v2 in proptest::collection::vec(-100.0f32..100.0, 5..10)
        ) {
            // Ensure same length
            let len = v1.len().min(v2.len());
            let v1: Vec<f32> = v1.into_iter().take(len).collect();
            let v2: Vec<f32> = v2.into_iter().take(len).collect();

            let mut delta1 = HashMap::new();
            delta1.insert("w".to_string(), Tensor::from_vec(v1.clone(), false));

            let mut delta2 = HashMap::new();
            delta2.insert("w".to_string(), Tensor::from_vec(v2.clone(), false));

            let averaged = average_deltas(&[delta1, delta2]);
            let avg_data = averaged["w"].data();

            for i in 0..len {
                let expected = f32::midpoint(v1[i], v2[i]);
                prop_assert!(
                    (avg_data[i] - expected).abs() < 1e-5,
                    "Average mismatch at {}: expected {}, got {}",
                    i,
                    expected,
                    avg_data[i]
                );
            }
        }

        #[test]
        fn prop_dare_deterministic_with_same_seed(
            delta_values in proptest::collection::vec(-10.0f32..10.0, 5..15),
            seed in 0u64..1000,
            drop_prob in 0.1f32..0.9
        ) {
            let mut base = HashMap::new();
            base.insert("w".to_string(), Tensor::from_vec(vec![0.0; delta_values.len()], false));

            let mut model1 = HashMap::new();
            model1.insert("w".to_string(), Tensor::from_vec(delta_values, false));

            let models = vec![model1];
            let config = DareConfig::new(drop_prob).expect("config should be valid").with_seed(seed);

            let result1 = dare_merge(&models, &base, &config).expect("merge should succeed");
            let result2 = dare_merge(&models, &base, &config).expect("merge should succeed");

            // Same seed should produce identical results
            for (a, b) in result1["w"].data().iter().zip(result2["w"].data().iter()) {
                prop_assert!(
                    (a - b).abs() < 1e-6,
                    "Non-deterministic result: {} vs {}",
                    a,
                    b
                );
            }
        }
    }
}