Skip to main content

entrenar/merge/
slerp.rs

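//! SLERP (Spherical Linear intERPolation) merge algorithm
//!
//! SLERP blends two models by interpolating each parameter tensor along the
//! arc between the two weight directions rather than along the straight line
//! between them. For unit-normalized weights this keeps the blend on the unit
//! sphere, whereas linear interpolation shrinks it toward the origin.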
5
6use super::{validate_models, MergeError, Model};
7use crate::autograd::Tensor;
8use ndarray::Array1;
9use std::collections::HashMap;
10
11/// Configuration for SLERP merge
12#[derive(Clone, Debug)]
13pub struct SlerpConfig {
14    /// Interpolation parameter: t ∈ [0, 1]
15    /// t=0.0 -> 100% model1
16    /// t=0.5 -> 50/50 blend
17    /// t=1.0 -> 100% model2
18    pub t: f32,
19}
20
21impl Default for SlerpConfig {
22    fn default() -> Self {
23        Self { t: 0.5 }
24    }
25}
26
27impl SlerpConfig {
28    pub fn new(t: f32) -> Result<Self, MergeError> {
29        if !(0.0..=1.0).contains(&t) {
30            return Err(MergeError::InvalidConfig(format!(
31                "Interpolation parameter t must be in [0.0, 1.0], got {t}"
32            )));
33        }
34        Ok(Self { t })
35    }
36}
37
38/// SLERP merge: spherical linear interpolation between two models
39///
40/// # Arguments
41/// * `model1` - First model (t=0)
42/// * `model2` - Second model (t=1)
43/// * `config` - SLERP configuration (interpolation parameter t)
44///
45/// # Returns
46/// Interpolated model at parameter t
47///
48/// # Algorithm
49/// For each parameter w1, w2:
50/// 1. Compute angle: θ = arccos(w1·w2 / (|w1||w2|))
51/// 2. If θ ≈ 0 (parallel vectors), use linear interpolation
52/// 3. Otherwise: w = (sin((1-t)θ)/sinθ)*w1 + (sin(tθ)/sinθ)*w2
53///
54/// # Notes
55/// - SLERP is rotation-invariant and maintains constant angular velocity
56/// - Falls back to linear interpolation for nearly parallel weights
57/// - Only works for 2 models (unlike TIES/DARE which support N models)
58pub fn slerp_merge(
59    model1: &Model,
60    model2: &Model,
61    config: &SlerpConfig,
62) -> Result<Model, MergeError> {
63    validate_models(&[model1.clone(), model2.clone()])?;
64
65    let mut merged = HashMap::new();
66
67    for (name, tensor1) in model1 {
68        let tensor2 = &model2[name];
69        let merged_tensor = slerp_tensor(tensor1, tensor2, config.t);
70        merged.insert(name.clone(), merged_tensor);
71    }
72
73    Ok(merged)
74}
75
76/// Spherical linear interpolation between two tensors
77fn slerp_tensor(tensor1: &Tensor, tensor2: &Tensor, t: f32) -> Tensor {
78    let w1 = tensor1.data();
79    let w2 = tensor2.data();
80
81    // Compute dot product and norms
82    let dot = w1.iter().zip(w2.iter()).map(|(a, b)| a * b).sum::<f32>();
83    let norm1 = w1.iter().map(|x| x * x).sum::<f32>().sqrt();
84    let norm2 = w2.iter().map(|x| x * x).sum::<f32>().sqrt();
85
86    // Handle zero vectors
87    if norm1 < 1e-8 || norm2 < 1e-8 {
88        return linear_interp_tensor(tensor1, tensor2, t);
89    }
90
91    // Compute cosine of angle
92    let cos_theta = (dot / (norm1 * norm2)).clamp(-1.0, 1.0);
93
94    // If vectors are nearly parallel (cos_theta ≈ 1), use linear interpolation
95    const EPSILON: f32 = 1e-4;
96    if (cos_theta - 1.0).abs() < EPSILON {
97        return linear_interp_tensor(tensor1, tensor2, t);
98    }
99
100    // Compute theta and sin(theta)
101    let theta = cos_theta.acos();
102    let sin_theta = theta.sin();
103
104    // Spherical interpolation
105    // w = (sin((1-t)θ)/sinθ)*w1 + (sin(tθ)/sinθ)*w2
106    let coef1 = ((1.0 - t) * theta).sin() / sin_theta;
107    let coef2 = (t * theta).sin() / sin_theta;
108
109    let interpolated: Array1<f32> =
110        w1.iter().zip(w2.iter()).map(|(a, b)| coef1 * a + coef2 * b).collect();
111
112    Tensor::new(interpolated, false)
113}
114
115/// Linear interpolation fallback for parallel vectors
116fn linear_interp_tensor(tensor1: &Tensor, tensor2: &Tensor, t: f32) -> Tensor {
117    let w1 = tensor1.data();
118    let w2 = tensor2.data();
119
120    let interpolated: Array1<f32> =
121        w1.iter().zip(w2.iter()).map(|(a, b)| (1.0 - t) * a + t * b).collect();
122
123    Tensor::new(interpolated, false)
124}
125
126#[cfg(test)]
127mod tests {
128    use super::*;
129    use proptest::prelude::*;
130
131    #[test]
132    fn test_slerp_at_endpoints() {
133        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
134        let t2 = Tensor::from_vec(vec![4.0, 5.0, 6.0], false);
135
136        // t=0.0 should return tensor1
137        let result = slerp_tensor(&t1, &t2, 0.0);
138        for (a, b) in result.data().iter().zip(t1.data().iter()) {
139            assert!((a - b).abs() < 1e-6);
140        }
141
142        // t=1.0 should return tensor2
143        let result = slerp_tensor(&t1, &t2, 1.0);
144        for (a, b) in result.data().iter().zip(t2.data().iter()) {
145            assert!((a - b).abs() < 1e-6);
146        }
147    }
148
149    #[test]
150    fn test_slerp_midpoint() {
151        let t1 = Tensor::from_vec(vec![1.0, 0.0], false);
152        let t2 = Tensor::from_vec(vec![0.0, 1.0], false);
153
154        // t=0.5 should be close to (1/√2, 1/√2) for perpendicular vectors
155        let result = slerp_tensor(&t1, &t2, 0.5);
156        let expected_val = 1.0 / 2.0f32.sqrt();
157
158        assert!((result.data()[0] - expected_val).abs() < 1e-5);
159        assert!((result.data()[1] - expected_val).abs() < 1e-5);
160    }
161
162    #[test]
163    fn test_linear_interp_fallback_for_parallel() {
164        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
165        let t2 = Tensor::from_vec(vec![2.0, 4.0, 6.0], false); // Parallel (2x t1)
166
167        let result = slerp_tensor(&t1, &t2, 0.5);
168
169        // Should fall back to linear interpolation
170        let expected = [1.5, 3.0, 4.5]; // (1+2)/2, (2+4)/2, (3+6)/2
171        for (a, e) in result.data().iter().zip(expected.iter()) {
172            assert!((a - e).abs() < 1e-5);
173        }
174    }
175
176    #[test]
177    fn test_slerp_config_validation() {
178        assert!(SlerpConfig::new(0.0).is_ok());
179        assert!(SlerpConfig::new(0.5).is_ok());
180        assert!(SlerpConfig::new(1.0).is_ok());
181        assert!(SlerpConfig::new(-0.1).is_err());
182        assert!(SlerpConfig::new(1.1).is_err());
183    }
184
185    #[test]
186    fn test_slerp_merge() {
187        let mut model1 = HashMap::new();
188        model1.insert("w".to_string(), Tensor::from_vec(vec![1.0, 0.0], false));
189
190        let mut model2 = HashMap::new();
191        model2.insert("w".to_string(), Tensor::from_vec(vec![0.0, 1.0], false));
192
193        let config = SlerpConfig::new(0.5).expect("slerp config creation should succeed");
194        let merged = slerp_merge(&model1, &model2, &config).expect("config should be valid");
195
196        // Midpoint of perpendicular unit vectors
197        let expected_val = 1.0 / 2.0f32.sqrt();
198        assert!((merged["w"].data()[0] - expected_val).abs() < 1e-5);
199        assert!((merged["w"].data()[1] - expected_val).abs() < 1e-5);
200    }
201
202    #[test]
203    fn test_linear_interp_basic() {
204        let t1 = Tensor::from_vec(vec![0.0, 0.0], false);
205        let t2 = Tensor::from_vec(vec![10.0, 20.0], false);
206
207        let result = linear_interp_tensor(&t1, &t2, 0.3);
208        assert!((result.data()[0] - 3.0).abs() < 1e-6);
209        assert!((result.data()[1] - 6.0).abs() < 1e-6);
210    }
211
212    #[test]
213    fn test_slerp_zero_vector_fallback() {
214        let t1 = Tensor::from_vec(vec![0.0, 0.0], false);
215        let t2 = Tensor::from_vec(vec![1.0, 1.0], false);
216
217        // Should fall back to linear interpolation for zero vector
218        let result = slerp_tensor(&t1, &t2, 0.5);
219        assert!((result.data()[0] - 0.5).abs() < 1e-6);
220        assert!((result.data()[1] - 0.5).abs() < 1e-6);
221    }
222
223    #[test]
224    fn test_slerp_negative_vectors() {
225        let t1 = Tensor::from_vec(vec![1.0, 0.0], false);
226        let t2 = Tensor::from_vec(vec![-1.0, 0.0], false); // Opposite direction
227
228        let result = slerp_tensor(&t1, &t2, 0.5);
229
230        // At midpoint of opposite vectors, SLERP rotates through perpendicular
231        // The result should be perpendicular (along y-axis)
232        assert!((result.data()[0]).abs() < 1e-5);
233    }
234
235    // Property tests
236
237    proptest! {
238        #![proptest_config(ProptestConfig::with_cases(200))]
239
240        #[test]
241        fn prop_slerp_config_valid_range(t in 0.0f32..=1.0) {
242            let config = SlerpConfig::new(t);
243            prop_assert!(config.is_ok());
244        }
245
246        #[test]
247        fn prop_slerp_config_invalid_negative(t in -10.0f32..-0.01) {
248            let config = SlerpConfig::new(t);
249            prop_assert!(config.is_err());
250        }
251
252        #[test]
253        fn prop_slerp_config_invalid_above_one(t in 1.01f32..10.0) {
254            let config = SlerpConfig::new(t);
255            prop_assert!(config.is_err());
256        }
257
258        #[test]
259        fn prop_slerp_t0_returns_first(
260            values1 in proptest::collection::vec(-10.0f32..10.0, 3..10),
261            values2 in proptest::collection::vec(-10.0f32..10.0, 3..10)
262        ) {
263            let len = values1.len().min(values2.len());
264            let v1: Vec<f32> = values1.into_iter().take(len).collect();
265            let v2: Vec<f32> = values2.into_iter().take(len).collect();
266
267            let t1 = Tensor::from_vec(v1.clone(), false);
268            let t2 = Tensor::from_vec(v2, false);
269
270            let result = slerp_tensor(&t1, &t2, 0.0);
271
272            for (orig, res) in v1.iter().zip(result.data().iter()) {
273                prop_assert!(
274                    (orig - res).abs() < 1e-5,
275                    "t=0 should return first tensor: {} vs {}",
276                    orig,
277                    res
278                );
279            }
280        }
281
282        #[test]
283        fn prop_slerp_t1_returns_second(
284            values1 in proptest::collection::vec(-10.0f32..10.0, 3..10),
285            values2 in proptest::collection::vec(-10.0f32..10.0, 3..10)
286        ) {
287            let len = values1.len().min(values2.len());
288            let v1: Vec<f32> = values1.into_iter().take(len).collect();
289            let v2: Vec<f32> = values2.into_iter().take(len).collect();
290
291            let t1 = Tensor::from_vec(v1, false);
292            let t2 = Tensor::from_vec(v2.clone(), false);
293
294            let result = slerp_tensor(&t1, &t2, 1.0);
295
296            for (orig, res) in v2.iter().zip(result.data().iter()) {
297                prop_assert!(
298                    (orig - res).abs() < 1e-5,
299                    "t=1 should return second tensor: {} vs {}",
300                    orig,
301                    res
302                );
303            }
304        }
305
306        #[test]
307        fn prop_linear_interp_bounded(
308            values1 in proptest::collection::vec(-100.0f32..100.0, 3..10),
309            values2 in proptest::collection::vec(-100.0f32..100.0, 3..10),
310            t in 0.0f32..=1.0
311        ) {
312            let len = values1.len().min(values2.len());
313            let v1: Vec<f32> = values1.into_iter().take(len).collect();
314            let v2: Vec<f32> = values2.into_iter().take(len).collect();
315
316            let t1 = Tensor::from_vec(v1.clone(), false);
317            let t2 = Tensor::from_vec(v2.clone(), false);
318
319            let result = linear_interp_tensor(&t1, &t2, t);
320
321            // Result should be within bounds of inputs
322            for i in 0..len {
323                let min_val = v1[i].min(v2[i]);
324                let max_val = v1[i].max(v2[i]);
325                prop_assert!(
326                    result.data()[i] >= min_val - 1e-5 && result.data()[i] <= max_val + 1e-5,
327                    "Linear interp out of bounds: {} not in [{}, {}]",
328                    result.data()[i],
329                    min_val,
330                    max_val
331                );
332            }
333        }
334
335        #[test]
336        fn prop_slerp_symmetric(
337            values1 in proptest::collection::vec(1.0f32..10.0, 3..6),
338            values2 in proptest::collection::vec(1.0f32..10.0, 3..6),
339            t in 0.1f32..0.9
340        ) {
341            let len = values1.len().min(values2.len());
342            let v1: Vec<f32> = values1.into_iter().take(len).collect();
343            let v2: Vec<f32> = values2.into_iter().take(len).collect();
344
345            let t1 = Tensor::from_vec(v1.clone(), false);
346            let t2 = Tensor::from_vec(v2.clone(), false);
347
348            // slerp(a, b, t) should have similar properties to slerp(b, a, 1-t)
349            let result1 = slerp_tensor(&t1, &t2, t);
350            let result2 = slerp_tensor(&t2, &t1, 1.0 - t);
351
352            for (r1, r2) in result1.data().iter().zip(result2.data().iter()) {
353                prop_assert!(
354                    (r1 - r2).abs() < 1e-4,
355                    "SLERP not symmetric: {} vs {}",
356                    r1,
357                    r2
358                );
359            }
360        }
361
362        #[test]
363        fn prop_linear_interp_t0_returns_first(
364            values1 in proptest::collection::vec(-100.0f32..100.0, 3..10),
365            values2 in proptest::collection::vec(-100.0f32..100.0, 3..10)
366        ) {
367            let len = values1.len().min(values2.len());
368            let v1: Vec<f32> = values1.into_iter().take(len).collect();
369            let v2: Vec<f32> = values2.into_iter().take(len).collect();
370
371            let t1 = Tensor::from_vec(v1.clone(), false);
372            let t2 = Tensor::from_vec(v2, false);
373
374            let result = linear_interp_tensor(&t1, &t2, 0.0);
375
376            for (orig, res) in v1.iter().zip(result.data().iter()) {
377                prop_assert!(
378                    (orig - res).abs() < 1e-6,
379                    "t=0 should return first: {} vs {}",
380                    orig,
381                    res
382                );
383            }
384        }
385
386        #[test]
387        fn prop_linear_interp_midpoint_is_average(
388            values1 in proptest::collection::vec(-100.0f32..100.0, 3..10),
389            values2 in proptest::collection::vec(-100.0f32..100.0, 3..10)
390        ) {
391            let len = values1.len().min(values2.len());
392            let v1: Vec<f32> = values1.into_iter().take(len).collect();
393            let v2: Vec<f32> = values2.into_iter().take(len).collect();
394
395            let t1 = Tensor::from_vec(v1.clone(), false);
396            let t2 = Tensor::from_vec(v2.clone(), false);
397
398            let result = linear_interp_tensor(&t1, &t2, 0.5);
399
400            for i in 0..len {
401                let expected = f32::midpoint(v1[i], v2[i]);
402                prop_assert!(
403                    (result.data()[i] - expected).abs() < 1e-5,
404                    "Midpoint not average: {} vs {}",
405                    result.data()[i],
406                    expected
407                );
408            }
409        }
410    }
411}