1use super::{validate_models, MergeError, Model};
7use crate::autograd::Tensor;
8use ndarray::Array1;
9use std::collections::HashMap;
10
/// Configuration for a spherical linear interpolation (SLERP) merge.
///
/// Prefer [`SlerpConfig::new`], which rejects values of `t` outside `[0.0, 1.0]`.
/// The field is public, so building the struct literally bypasses that check.
#[derive(Clone, Debug, PartialEq)]
pub struct SlerpConfig {
    /// Interpolation parameter: `0.0` reproduces the first model, `1.0` the second,
    /// and values in between blend the two.
    pub t: f32,
}
20
21impl Default for SlerpConfig {
22 fn default() -> Self {
23 Self { t: 0.5 }
24 }
25}
26
27impl SlerpConfig {
28 pub fn new(t: f32) -> Result<Self, MergeError> {
29 if !(0.0..=1.0).contains(&t) {
30 return Err(MergeError::InvalidConfig(format!(
31 "Interpolation parameter t must be in [0.0, 1.0], got {t}"
32 )));
33 }
34 Ok(Self { t })
35 }
36}
37
38pub fn slerp_merge(
59 model1: &Model,
60 model2: &Model,
61 config: &SlerpConfig,
62) -> Result<Model, MergeError> {
63 validate_models(&[model1.clone(), model2.clone()])?;
64
65 let mut merged = HashMap::new();
66
67 for (name, tensor1) in model1 {
68 let tensor2 = &model2[name];
69 let merged_tensor = slerp_tensor(tensor1, tensor2, config.t);
70 merged.insert(name.clone(), merged_tensor);
71 }
72
73 Ok(merged)
74}
75
76fn slerp_tensor(tensor1: &Tensor, tensor2: &Tensor, t: f32) -> Tensor {
78 let w1 = tensor1.data();
79 let w2 = tensor2.data();
80
81 let dot = w1.iter().zip(w2.iter()).map(|(a, b)| a * b).sum::<f32>();
83 let norm1 = w1.iter().map(|x| x * x).sum::<f32>().sqrt();
84 let norm2 = w2.iter().map(|x| x * x).sum::<f32>().sqrt();
85
86 if norm1 < 1e-8 || norm2 < 1e-8 {
88 return linear_interp_tensor(tensor1, tensor2, t);
89 }
90
91 let cos_theta = (dot / (norm1 * norm2)).clamp(-1.0, 1.0);
93
94 const EPSILON: f32 = 1e-4;
96 if (cos_theta - 1.0).abs() < EPSILON {
97 return linear_interp_tensor(tensor1, tensor2, t);
98 }
99
100 let theta = cos_theta.acos();
102 let sin_theta = theta.sin();
103
104 let coef1 = ((1.0 - t) * theta).sin() / sin_theta;
107 let coef2 = (t * theta).sin() / sin_theta;
108
109 let interpolated: Array1<f32> =
110 w1.iter().zip(w2.iter()).map(|(a, b)| coef1 * a + coef2 * b).collect();
111
112 Tensor::new(interpolated, false)
113}
114
115fn linear_interp_tensor(tensor1: &Tensor, tensor2: &Tensor, t: f32) -> Tensor {
117 let w1 = tensor1.data();
118 let w2 = tensor2.data();
119
120 let interpolated: Array1<f32> =
121 w1.iter().zip(w2.iter()).map(|(a, b)| (1.0 - t) * a + t * b).collect();
122
123 Tensor::new(interpolated, false)
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds a tensor from `values`, passing `false` as the flag like every test here.
    fn tensor(values: &[f32]) -> Tensor {
        Tensor::from_vec(values.to_vec(), false)
    }

    /// Trims both vectors to the length of the shorter one so they pair up element-wise.
    fn same_len(mut first: Vec<f32>, mut second: Vec<f32>) -> (Vec<f32>, Vec<f32>) {
        let len = first.len().min(second.len());
        first.truncate(len);
        second.truncate(len);
        (first, second)
    }

    /// Asserts that each element of `actual` is within `tol` of the paired `expected` value.
    fn assert_all_close<'a>(
        actual: &Tensor,
        expected: impl IntoIterator<Item = &'a f32>,
        tol: f32,
    ) {
        for (got, want) in actual.data().iter().zip(expected) {
            assert!((got - want).abs() < tol);
        }
    }

    /// `t = 0` must reproduce the first tensor and `t = 1` the second.
    #[test]
    fn test_slerp_at_endpoints() {
        let first = tensor(&[1.0, 2.0, 3.0]);
        let second = tensor(&[4.0, 5.0, 6.0]);

        for (t, expected) in [(0.0, &first), (1.0, &second)] {
            let out = slerp_tensor(&first, &second, t);
            assert_all_close(&out, expected.data().iter(), 1e-6);
        }
    }

    /// Halfway between two orthogonal unit vectors lies on the unit circle at 45 degrees.
    #[test]
    fn test_slerp_midpoint() {
        let out = slerp_tensor(&tensor(&[1.0, 0.0]), &tensor(&[0.0, 1.0]), 0.5);
        let want = 1.0 / 2.0f32.sqrt();

        for i in 0..2 {
            assert!((out.data()[i] - want).abs() < 1e-5);
        }
    }

    /// Parallel inputs take the linear-interpolation fallback.
    #[test]
    fn test_linear_interp_fallback_for_parallel() {
        let out = slerp_tensor(&tensor(&[1.0, 2.0, 3.0]), &tensor(&[2.0, 4.0, 6.0]), 0.5);

        assert_all_close(&out, &[1.5, 3.0, 4.5], 1e-5);
    }

    /// `SlerpConfig::new` accepts `[0, 1]` and rejects values outside it.
    #[test]
    fn test_slerp_config_validation() {
        for t in [0.0, 0.5, 1.0] {
            assert!(SlerpConfig::new(t).is_ok());
        }
        for t in [-0.1, 1.1] {
            assert!(SlerpConfig::new(t).is_err());
        }
    }

    /// End-to-end merge of two single-tensor models.
    #[test]
    fn test_slerp_merge() {
        let model1 = HashMap::from([("w".to_string(), tensor(&[1.0, 0.0]))]);
        let model2 = HashMap::from([("w".to_string(), tensor(&[0.0, 1.0]))]);

        let config = SlerpConfig::new(0.5).expect("slerp config creation should succeed");
        let merged = slerp_merge(&model1, &model2, &config).expect("config should be valid");

        let want = 1.0 / 2.0f32.sqrt();
        for i in 0..2 {
            assert!((merged["w"].data()[i] - want).abs() < 1e-5);
        }
    }

    /// Plain linear interpolation at `t = 0.3`.
    #[test]
    fn test_linear_interp_basic() {
        let out = linear_interp_tensor(&tensor(&[0.0, 0.0]), &tensor(&[10.0, 20.0]), 0.3);

        assert!((out.data()[0] - 3.0).abs() < 1e-6);
        assert!((out.data()[1] - 6.0).abs() < 1e-6);
    }

    /// A zero vector has no direction, so the linear fallback is used.
    #[test]
    fn test_slerp_zero_vector_fallback() {
        let out = slerp_tensor(&tensor(&[0.0, 0.0]), &tensor(&[1.0, 1.0]), 0.5);

        for i in 0..2 {
            assert!((out.data()[i] - 0.5).abs() < 1e-6);
        }
    }

    /// Opposite vectors: only the first component is checked; it must be (near) zero
    /// at the midpoint.
    #[test]
    fn test_slerp_negative_vectors() {
        let out = slerp_tensor(&tensor(&[1.0, 0.0]), &tensor(&[-1.0, 0.0]), 0.5);

        assert!((out.data()[0]).abs() < 1e-5);
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn prop_slerp_config_valid_range(t in 0.0f32..=1.0) {
            prop_assert!(SlerpConfig::new(t).is_ok());
        }

        #[test]
        fn prop_slerp_config_invalid_negative(t in -10.0f32..-0.01) {
            prop_assert!(SlerpConfig::new(t).is_err());
        }

        #[test]
        fn prop_slerp_config_invalid_above_one(t in 1.01f32..10.0) {
            prop_assert!(SlerpConfig::new(t).is_err());
        }

        #[test]
        fn prop_slerp_t0_returns_first(
            values1 in proptest::collection::vec(-10.0f32..10.0, 3..10),
            values2 in proptest::collection::vec(-10.0f32..10.0, 3..10)
        ) {
            let (v1, v2) = same_len(values1, values2);

            let result = slerp_tensor(&tensor(&v1), &tensor(&v2), 0.0);

            for (want, got) in v1.iter().zip(result.data().iter()) {
                prop_assert!(
                    (want - got).abs() < 1e-5,
                    "t=0 should return first tensor: {} vs {}",
                    want,
                    got
                );
            }
        }

        #[test]
        fn prop_slerp_t1_returns_second(
            values1 in proptest::collection::vec(-10.0f32..10.0, 3..10),
            values2 in proptest::collection::vec(-10.0f32..10.0, 3..10)
        ) {
            let (v1, v2) = same_len(values1, values2);

            let result = slerp_tensor(&tensor(&v1), &tensor(&v2), 1.0);

            for (want, got) in v2.iter().zip(result.data().iter()) {
                prop_assert!(
                    (want - got).abs() < 1e-5,
                    "t=1 should return second tensor: {} vs {}",
                    want,
                    got
                );
            }
        }

        #[test]
        fn prop_linear_interp_bounded(
            values1 in proptest::collection::vec(-100.0f32..100.0, 3..10),
            values2 in proptest::collection::vec(-100.0f32..100.0, 3..10),
            t in 0.0f32..=1.0
        ) {
            let (v1, v2) = same_len(values1, values2);

            let result = linear_interp_tensor(&tensor(&v1), &tensor(&v2), t);

            // Each output element must stay between its two inputs (with a small tolerance).
            for (i, (a, b)) in v1.iter().zip(v2.iter()).enumerate() {
                let lo = a.min(*b);
                let hi = a.max(*b);
                let got = result.data()[i];
                prop_assert!(
                    got >= lo - 1e-5 && got <= hi + 1e-5,
                    "Linear interp out of bounds: {} not in [{}, {}]",
                    got,
                    lo,
                    hi
                );
            }
        }

        #[test]
        fn prop_slerp_symmetric(
            values1 in proptest::collection::vec(1.0f32..10.0, 3..6),
            values2 in proptest::collection::vec(1.0f32..10.0, 3..6),
            t in 0.1f32..0.9
        ) {
            let (v1, v2) = same_len(values1, values2);
            let (a, b) = (tensor(&v1), tensor(&v2));

            // Swapping the operands and mirroring t must give the same point.
            let forward = slerp_tensor(&a, &b, t);
            let backward = slerp_tensor(&b, &a, 1.0 - t);

            for (r1, r2) in forward.data().iter().zip(backward.data().iter()) {
                prop_assert!(
                    (r1 - r2).abs() < 1e-4,
                    "SLERP not symmetric: {} vs {}",
                    r1,
                    r2
                );
            }
        }

        #[test]
        fn prop_linear_interp_t0_returns_first(
            values1 in proptest::collection::vec(-100.0f32..100.0, 3..10),
            values2 in proptest::collection::vec(-100.0f32..100.0, 3..10)
        ) {
            let (v1, v2) = same_len(values1, values2);

            let result = linear_interp_tensor(&tensor(&v1), &tensor(&v2), 0.0);

            for (want, got) in v1.iter().zip(result.data().iter()) {
                prop_assert!(
                    (want - got).abs() < 1e-6,
                    "t=0 should return first: {} vs {}",
                    want,
                    got
                );
            }
        }

        #[test]
        fn prop_linear_interp_midpoint_is_average(
            values1 in proptest::collection::vec(-100.0f32..100.0, 3..10),
            values2 in proptest::collection::vec(-100.0f32..100.0, 3..10)
        ) {
            let (v1, v2) = same_len(values1, values2);

            let result = linear_interp_tensor(&tensor(&v1), &tensor(&v2), 0.5);

            for (i, (a, b)) in v1.iter().zip(v2.iter()).enumerate() {
                let expected = f32::midpoint(*a, *b);
                let got = result.data()[i];
                prop_assert!(
                    (got - expected).abs() < 1e-5,
                    "Midpoint not average: {} vs {}",
                    got,
                    expected
                );
            }
        }
    }
}