1use super::{compute_deltas, merge_with_base, validate_models, MergeError, Model};
7use crate::autograd::Tensor;
8use ndarray::Array1;
9use rand::Rng;
10use std::collections::HashMap;
11
/// Configuration for [`dare_merge`].
///
/// Prefer [`DareConfig::new`], which rejects out-of-range drop probabilities;
/// the fields are public, so a config can also be built directly.
#[derive(Clone, Debug, PartialEq)]
pub struct DareConfig {
    /// Probability in `[0.0, 1.0]` that each delta element is zeroed.
    /// Surviving elements are rescaled by `1 / (1 - drop_prob)`.
    pub drop_prob: f32,

    /// Optional RNG seed. `Some(seed)` draws the drop masks from a `StdRng` seeded
    /// with it; `None` uses `rand::rng()`, so masks are not reproducible.
    pub seed: Option<u64>,
}
23
24impl Default for DareConfig {
25 fn default() -> Self {
26 Self { drop_prob: 0.5, seed: None }
27 }
28}
29
30impl DareConfig {
31 pub fn new(drop_prob: f32) -> Result<Self, MergeError> {
32 if !(0.0..=1.0).contains(&drop_prob) {
33 return Err(MergeError::InvalidConfig(format!(
34 "Drop probability must be in [0.0, 1.0], got {drop_prob}"
35 )));
36 }
37 Ok(Self { drop_prob, seed: None })
38 }
39
40 pub fn with_seed(mut self, seed: u64) -> Self {
41 self.seed = Some(seed);
42 self
43 }
44}
45
46pub fn dare_merge(
63 models: &[Model],
64 base: &Model,
65 config: &DareConfig,
66) -> Result<Model, MergeError> {
67 if models.is_empty() {
68 return Err(MergeError::InsufficientModels { min: 1, got: 0 });
69 }
70
71 validate_models(models)?;
72
73 let deltas = compute_deltas(models, base)?;
75
76 let masked_deltas = if let Some(seed) = config.seed {
78 use rand::SeedableRng;
80 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
81 drop_and_rescale_deltas(&deltas, config.drop_prob, &mut rng)
82 } else {
83 let mut rng = rand::rng();
85 drop_and_rescale_deltas(&deltas, config.drop_prob, &mut rng)
86 };
87
88 let averaged_delta = average_deltas(&masked_deltas);
90
91 Ok(merge_with_base(base, averaged_delta))
93}
94
95fn drop_and_rescale_deltas<R: Rng>(deltas: &[Model], drop_prob: f32, rng: &mut R) -> Vec<Model> {
97 let keep_prob = 1.0 - drop_prob;
98 let scale = if keep_prob > 0.0 { 1.0 / keep_prob } else { 1.0 };
99
100 deltas
101 .iter()
102 .map(|delta| {
103 let mut masked = HashMap::new();
104 for (name, tensor) in delta {
105 masked.insert(name.clone(), drop_and_rescale_tensor(tensor, drop_prob, scale, rng));
106 }
107 masked
108 })
109 .collect()
110}
111
112fn drop_and_rescale_tensor<R: Rng>(
114 tensor: &Tensor,
115 drop_prob: f32,
116 scale: f32,
117 rng: &mut R,
118) -> Tensor {
119 let data = tensor.data();
120 let masked_data: Array1<f32> = data
121 .iter()
122 .map(|&val| {
123 if rng.random::<f32>() < drop_prob {
124 0.0 } else {
126 val * scale }
128 })
129 .collect();
130
131 Tensor::new(masked_data, false)
132}
133
134fn average_deltas(deltas: &[Model]) -> Model {
136 if deltas.is_empty() {
137 return HashMap::new();
138 }
139
140 let n = deltas.len() as f32;
141 let reference = &deltas[0];
142 let mut averaged = HashMap::new();
143
144 for name in reference.keys() {
145 let sum_data: Array1<f32> = deltas
146 .iter()
147 .map(|delta| delta[name].data())
148 .fold(Array1::zeros(reference[name].len()), |acc, data| &acc + data);
149
150 let avg_data = sum_data / n;
151 averaged.insert(name.clone(), Tensor::new(avg_data, false));
152 }
153
154 averaged
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use rand::SeedableRng;

    // Kept elements must equal `orig * scale`, dropped ones 0.0, and reseeding with
    // the same seed must reproduce exactly the same output.
    #[test]
    fn test_drop_and_rescale_tensor_deterministic() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let masked = drop_and_rescale_tensor(&tensor, 0.5, 2.0, &mut rng);

        let mut rng_again = rand::rngs::StdRng::seed_from_u64(42);
        let masked_again = drop_and_rescale_tensor(&tensor, 0.5, 2.0, &mut rng_again);

        let original = tensor.data();
        let data = masked.data();
        assert_eq!(data.len(), original.len());
        for (&orig, &val) in original.iter().zip(data.iter()) {
            assert!(val == 0.0 || val == orig * 2.0);
        }
        for (a, b) in data.iter().zip(masked_again.data().iter()) {
            assert_eq!(a, b);
        }
    }

    #[test]
    fn test_average_deltas() {
        let mut delta1 = HashMap::new();
        delta1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 2.0], false));

        let mut delta2 = HashMap::new();
        delta2.insert("w".to_string(), Tensor::from_vec(vec![3.0, 4.0], false));

        let averaged = average_deltas(&[delta1, delta2]);

        let expected = [2.0, 3.0];
        let actual = averaged["w"].data();
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_dare_config_validation() {
        assert!(DareConfig::new(0.5).is_ok());
        assert!(DareConfig::new(0.0).is_ok());
        assert!(DareConfig::new(1.0).is_ok());
        assert!(DareConfig::new(-0.1).is_err());
        assert!(DareConfig::new(1.1).is_err());
        // NaN is not contained in [0.0, 1.0], so it must be rejected too.
        assert!(DareConfig::new(f32::NAN).is_err());
    }

    #[test]
    fn test_dare_merge_with_seed_is_deterministic() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0, 0.0], false));

        let mut model1 = base.clone();
        model1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 2.0], false));

        let mut model2 = base.clone();
        model2.insert("w".to_string(), Tensor::from_vec(vec![3.0, 4.0], false));

        let models = vec![model1, model2];
        let config = DareConfig::new(0.5).expect("config should be valid").with_seed(42);

        let result1 = dare_merge(&models, &base, &config).expect("merge should succeed");
        let result2 = dare_merge(&models, &base, &config).expect("merge should succeed");

        let r1_data = result1["w"].data();
        let r2_data = result2["w"].data();
        for (a, b) in r1_data.iter().zip(r2_data.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_drop_prob_zero_keeps_all() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);

        let masked = drop_and_rescale_tensor(&tensor, 0.0, 1.0, &mut rng);

        // A uniform draw in [0, 1) is never < 0.0, so every element survives unscaled.
        let data = masked.data();
        assert_eq!(data.len(), 5);
        for (&orig, &val) in tensor.data().iter().zip(data.iter()) {
            assert_eq!(val, orig);
        }
    }

    #[test]
    fn test_drop_prob_one_drops_all() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], false);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);

        let masked = drop_and_rescale_tensor(&tensor, 1.0, 1.0, &mut rng);

        let data = masked.data();
        for &val in data {
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_dare_merge_empty_models() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0], false));

        let models: Vec<Model> = vec![];
        let config = DareConfig::default();

        let result = dare_merge(&models, &base, &config);
        assert!(matches!(result, Err(MergeError::InsufficientModels { min: 1, got: 0 })));
    }

    #[test]
    fn test_dare_merge_single_model() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![0.0, 0.0], false));

        let mut model1 = HashMap::new();
        model1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 2.0], false));

        let models = vec![model1];
        // drop_prob = 0.0 keeps the whole delta, so the merge must reproduce model1.
        let config = DareConfig::new(0.0).expect("config should be valid").with_seed(42);
        let result = dare_merge(&models, &base, &config).expect("merge should succeed");

        let w = result.get("w").expect("key should exist");
        assert!((w.data()[0] - 1.0).abs() < 1e-6);
        assert!((w.data()[1] - 2.0).abs() < 1e-6);
    }

    proptest! {
        // 200 cases per property keeps the suite fast while still sampling widely.
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn prop_dare_config_valid_range(drop_prob in 0.0f32..=1.0) {
            let config = DareConfig::new(drop_prob);
            prop_assert!(config.is_ok());
        }

        #[test]
        fn prop_dare_config_invalid_negative(drop_prob in -10.0f32..-0.01) {
            let config = DareConfig::new(drop_prob);
            prop_assert!(config.is_err());
        }

        #[test]
        fn prop_dare_config_invalid_above_one(drop_prob in 1.01f32..10.0) {
            let config = DareConfig::new(drop_prob);
            prop_assert!(config.is_err());
        }

        // Inputs are strictly positive, so a 0.0 output can only mean "dropped";
        // every non-zero output must be the input times the rescale factor.
        #[test]
        fn prop_drop_and_rescale_output_values(
            values in proptest::collection::vec(1.0f32..10.0, 10..50),
            drop_prob in 0.0f32..1.0,
            seed in 0u64..1000
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
            let keep_prob = 1.0 - drop_prob;
            let scale = if keep_prob > 0.0 { 1.0 / keep_prob } else { 1.0 };

            let masked = drop_and_rescale_tensor(&tensor, drop_prob, scale, &mut rng);

            for (orig, result) in values.iter().zip(masked.data().iter()) {
                if *result != 0.0 {
                    let expected = orig * scale;
                    prop_assert!(
                        (result - expected).abs() < 1e-4,
                        "Expected {} * {} = {}, got {}",
                        orig,
                        scale,
                        expected,
                        result
                    );
                }
            }
        }

        #[test]
        fn prop_drop_prob_zero_preserves_values(
            values in proptest::collection::vec(-100.0f32..100.0, 5..20),
            seed in 0u64..1000
        ) {
            let tensor = Tensor::from_vec(values.clone(), false);
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

            let masked = drop_and_rescale_tensor(&tensor, 0.0, 1.0, &mut rng);

            for (orig, result) in values.iter().zip(masked.data().iter()) {
                prop_assert!(
                    (orig - result).abs() < 1e-6,
                    "Value not preserved: {} -> {}",
                    orig,
                    result
                );
            }
        }

        #[test]
        fn prop_drop_prob_one_zeros_all(
            values in proptest::collection::vec(-100.0f32..100.0, 5..20),
            seed in 0u64..1000
        ) {
            let tensor = Tensor::from_vec(values, false);
            let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

            let masked = drop_and_rescale_tensor(&tensor, 1.0, 1.0, &mut rng);

            for &val in masked.data() {
                prop_assert_eq!(val, 0.0);
            }
        }

        #[test]
        fn prop_average_deltas_is_mean(
            v1 in proptest::collection::vec(-100.0f32..100.0, 5..10),
            v2 in proptest::collection::vec(-100.0f32..100.0, 5..10)
        ) {
            // Truncate both vectors to a common length so the tensors are compatible.
            let len = v1.len().min(v2.len());
            let v1: Vec<f32> = v1.into_iter().take(len).collect();
            let v2: Vec<f32> = v2.into_iter().take(len).collect();

            let mut delta1 = HashMap::new();
            delta1.insert("w".to_string(), Tensor::from_vec(v1.clone(), false));

            let mut delta2 = HashMap::new();
            delta2.insert("w".to_string(), Tensor::from_vec(v2.clone(), false));

            let averaged = average_deltas(&[delta1, delta2]);
            let avg_data = averaged["w"].data();

            for i in 0..len {
                let expected = f32::midpoint(v1[i], v2[i]);
                prop_assert!(
                    (avg_data[i] - expected).abs() < 1e-5,
                    "Average mismatch at {}: expected {}, got {}",
                    i,
                    expected,
                    avg_data[i]
                );
            }
        }

        #[test]
        fn prop_dare_deterministic_with_same_seed(
            delta_values in proptest::collection::vec(-10.0f32..10.0, 5..15),
            seed in 0u64..1000,
            drop_prob in 0.1f32..0.9
        ) {
            let mut base = HashMap::new();
            base.insert("w".to_string(), Tensor::from_vec(vec![0.0; delta_values.len()], false));

            let mut model1 = HashMap::new();
            model1.insert("w".to_string(), Tensor::from_vec(delta_values, false));

            let models = vec![model1];
            let config = DareConfig::new(drop_prob).expect("config should be valid").with_seed(seed);

            let result1 = dare_merge(&models, &base, &config).expect("merge should succeed");
            let result2 = dare_merge(&models, &base, &config).expect("merge should succeed");

            for (a, b) in result1["w"].data().iter().zip(result2["w"].data().iter()) {
                prop_assert!(
                    (a - b).abs() < 1e-6,
                    "Non-deterministic result: {} vs {}",
                    a,
                    b
                );
            }
        }
    }
}