use super::*;
use crate::autograd::Tensor;
use std::collections::HashMap;
/// Builds a model that holds exactly one parameter, `name`, backed by a
/// tensor created from `values`.
///
/// The tensor is built with `Tensor::from_vec(values, false)`; the `false`
/// flag is presumably "does not require gradients" — confirm against
/// `Tensor::from_vec`.
fn create_model(name: &str, values: Vec<f32>) -> Model {
    let tensor = Tensor::from_vec(values, false);
    HashMap::from([(name.to_owned(), tensor)])
}
/// Builds one single-parameter model per entry of `values_per_model`.
///
/// Every model stores its values under the parameter name `"w"`, and the
/// output keeps the input order.
fn create_models(values_per_model: Vec<Vec<f32>>) -> Vec<Model> {
    let mut models = Vec::with_capacity(values_per_model.len());
    for values in values_per_model {
        models.push(create_model("w", values));
    }
    models
}
#[cfg(test)]
mod ties_properties {
    //! Property tests for `ties_merge`.
    //!
    //! Every element-wise comparison first pins the tensor length. `zip` stops
    //! at the shorter iterator, so without that check a truncated or empty
    //! merge result would make the loop assert nothing and the test would pass
    //! vacuously.

    use super::*;
    use crate::merge::{ties_merge, TiesConfig};

    /// Reordering the input models must not change the merged parameters.
    #[test]
    fn ties_is_permutation_invariant() {
        let base = create_model("w", vec![0.0, 0.0, 0.0, 0.0]);
        let models = create_models(vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![-1.0, -2.0, 3.0, 4.0],
            vec![1.0, -2.0, -3.0, 4.0],
        ]);
        let config = TiesConfig::new(0.5).expect("config should be valid");
        let result1 = ties_merge(&models, &base, &config).expect("TIES merge should succeed");
        let permuted = vec![models[2].clone(), models[0].clone(), models[1].clone()];
        let result2 = ties_merge(&permuted, &base, &config).expect("TIES merge should succeed");

        let expected_len = base["w"].data().len();
        let r1_data = result1["w"].data();
        let r2_data = result2["w"].data();
        assert_eq!(r1_data.len(), expected_len, "merged tensor should match the base shape");
        assert_eq!(r2_data.len(), expected_len, "merged tensor should match the base shape");
        for (a, b) in r1_data.iter().zip(r2_data.iter()) {
            assert!((a - b).abs() < 1e-5, "TIES should be permutation-invariant: {a} != {b}");
        }
    }

    /// Merging a model with itself must never flip the sign of a parameter
    /// that survives the merge.
    #[test]
    fn ties_with_identical_models_has_same_deltas() {
        let base = create_model("w", vec![0.0, 0.0, 0.0, 0.0, 0.0]);
        let model = create_model("w", vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let models = vec![model.clone(), model.clone()];
        let config = TiesConfig::new(0.8).expect("config should be valid");
        let result = ties_merge(&models, &base, &config).expect("TIES merge should succeed");

        let expected = model["w"].data();
        let actual = result["w"].data();
        assert_eq!(actual.len(), expected.len(), "merged tensor should match the input shape");
        for (a, e) in actual.iter().zip(expected.iter()) {
            // Near-zero outputs carry no meaningful sign, so they are skipped.
            if a.abs() > 1e-6 {
                assert!(a * e > 0.0, "Sign mismatch: {a} vs {e}");
            }
        }
    }

    /// When every model equals the base, all deltas are zero and the merge
    /// must hand back the base parameters unchanged.
    #[test]
    fn ties_preserves_zero_deltas() {
        let base = create_model("w", vec![5.0, 10.0]);
        let models = vec![base.clone(), base.clone()];
        let config = TiesConfig::default();
        let result = ties_merge(&models, &base, &config).expect("TIES merge should succeed");

        let expected = base["w"].data();
        let actual = result["w"].data();
        assert_eq!(actual.len(), expected.len(), "merged tensor should match the base shape");
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-5, "zero deltas should leave the base unchanged: {a} != {e}");
        }
    }
}
#[cfg(test)]
mod dare_properties {
    //! Property tests for `dare_merge`.
    //!
    //! Tensor lengths are asserted before any `zip`-based comparison, so a
    //! short or empty result cannot make a test pass without checking anything.

    use super::*;
    use crate::merge::{dare_merge, DareConfig};

    /// With a fixed seed, two runs over the same inputs must agree exactly.
    #[test]
    fn dare_is_deterministic_with_seed() {
        let base = create_model("w", vec![0.0, 0.0]);
        let models = create_models(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let config = DareConfig::new(0.5).expect("config should be valid").with_seed(42);
        let result1 = dare_merge(&models, &base, &config).expect("DARE merge should succeed");
        let result2 = dare_merge(&models, &base, &config).expect("DARE merge should succeed");

        let expected_len = base["w"].data().len();
        let r1_data = result1["w"].data();
        let r2_data = result2["w"].data();
        assert_eq!(r1_data.len(), expected_len, "merged tensor should match the base shape");
        assert_eq!(r2_data.len(), expected_len, "merged tensor should match the base shape");
        for (a, b) in r1_data.iter().zip(r2_data.iter()) {
            assert!((a - b).abs() < 1e-10, "seeded DARE should be deterministic: {a} != {b}");
        }
    }

    /// A drop rate of zero keeps every delta, so DARE reduces to the
    /// element-wise mean of the models (the base here is all zeros).
    #[test]
    fn dare_with_zero_drop_is_average() {
        let base = create_model("w", vec![0.0, 0.0]);
        let models = create_models(vec![vec![2.0, 4.0], vec![4.0, 6.0]]);
        let config = DareConfig::new(0.0).expect("config should be valid");
        let result = dare_merge(&models, &base, &config).expect("DARE merge should succeed");

        let data = result["w"].data();
        assert_eq!(data.len(), 2, "merged tensor should match the base shape");
        assert!((data[0] - 3.0).abs() < 1e-5, "expected mean 3.0, got {}", data[0]);
        assert!((data[1] - 5.0).abs() < 1e-5, "expected mean 5.0, got {}", data[1]);
    }

    /// When every model equals the base, the merge must return the base.
    #[test]
    fn dare_preserves_zero_deltas() {
        let base = create_model("w", vec![7.0, 14.0]);
        let models = vec![base.clone(), base.clone()];
        let config = DareConfig::default();
        let result = dare_merge(&models, &base, &config).expect("DARE merge should succeed");

        let expected = base["w"].data();
        let actual = result["w"].data();
        assert_eq!(actual.len(), expected.len(), "merged tensor should match the base shape");
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-5, "zero deltas should leave the base unchanged: {a} != {e}");
        }
    }
}
#[cfg(test)]
mod slerp_properties {
    //! Property tests for `slerp_merge`: endpoint behaviour, continuity in
    //! `t`, and the antipodal midpoint case.
    //!
    //! Lengths are asserted before `zip`-based comparisons so a truncated
    //! result cannot pass vacuously.

    use super::*;
    use crate::merge::{slerp_merge, SlerpConfig};

    /// At `t = 0.0` the interpolation must reproduce the first model.
    #[test]
    fn slerp_at_t0_returns_model1() {
        let model1 = create_model("w", vec![1.0, 2.0, 3.0]);
        let model2 = create_model("w", vec![4.0, 5.0, 6.0]);
        let config = SlerpConfig::new(0.0).expect("slerp config creation should succeed");
        let result = slerp_merge(&model1, &model2, &config).expect("SLERP merge should succeed");

        let expected = model1["w"].data();
        let actual = result["w"].data();
        assert_eq!(actual.len(), expected.len(), "merged tensor should match the input shape");
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-6, "t = 0 should return model1: {a} != {e}");
        }
    }

    /// At `t = 1.0` the interpolation must reproduce the second model.
    #[test]
    fn slerp_at_t1_returns_model2() {
        let model1 = create_model("w", vec![1.0, 2.0, 3.0]);
        let model2 = create_model("w", vec![4.0, 5.0, 6.0]);
        let config = SlerpConfig::new(1.0).expect("slerp config creation should succeed");
        let result = slerp_merge(&model1, &model2, &config).expect("SLERP merge should succeed");

        let expected = model2["w"].data();
        let actual = result["w"].data();
        assert_eq!(actual.len(), expected.len(), "merged tensor should match the input shape");
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-6, "t = 1 should return model2: {a} != {e}");
        }
    }

    /// A small step in `t` (0.50 -> 0.51) must only move the output slightly.
    #[test]
    fn slerp_is_continuous() {
        let model1 = create_model("w", vec![1.0, 0.0]);
        let model2 = create_model("w", vec![0.0, 1.0]);
        let config1 = SlerpConfig::new(0.5).expect("slerp config creation should succeed");
        let config2 = SlerpConfig::new(0.51).expect("slerp config creation should succeed");
        let result1 = slerp_merge(&model1, &model2, &config1).expect("SLERP merge should succeed");
        let result2 = slerp_merge(&model1, &model2, &config2).expect("SLERP merge should succeed");

        let r1_data = result1["w"].data();
        let r2_data = result2["w"].data();
        assert_eq!(r1_data.len(), 2, "merged tensor should match the input shape");
        assert_eq!(r2_data.len(), 2, "merged tensor should match the input shape");
        for (a, b) in r1_data.iter().zip(r2_data.iter()) {
            assert!((a - b).abs() < 0.1, "SLERP should be continuous in t: {a} vs {b}");
        }
    }

    /// Interpolating halfway between the opposite 1-D vectors `[1]` and `[-1]`
    /// must land near zero.
    #[test]
    fn slerp_symmetric_models_at_midpoint() {
        let model1 = create_model("w", vec![1.0]);
        let model2 = create_model("w", vec![-1.0]);
        let config = SlerpConfig::new(0.5).expect("slerp config creation should succeed");
        let result = slerp_merge(&model1, &model2, &config).expect("SLERP merge should succeed");

        let data = result["w"].data();
        assert_eq!(data.len(), 1, "merged tensor should match the input shape");
        assert!(data[0].abs() < 0.1, "midpoint of opposite vectors should be ~0, got {}", data[0]);
    }
}
#[cfg(test)]
mod integration_tests {
    //! Cross-algorithm checks that run TIES, DARE and SLERP on the same
    //! inputs, plus shape-validation behaviour shared by the merges.

    use super::*;
    use crate::merge::{dare_merge, slerp_merge, ties_merge, DareConfig, SlerpConfig, TiesConfig};

    /// All three merges must produce full-length, finite outputs for the same
    /// inputs. DARE with no drop must also equal the element-wise mean, which
    /// mirrors `dare_with_zero_drop_is_average`.
    #[test]
    fn test_three_way_merge_comparison() {
        let base = create_model("w", vec![0.0, 0.0, 0.0]);
        let models = create_models(vec![vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0]]);

        let ties_config = TiesConfig::default();
        let ties_result =
            ties_merge(&models, &base, &ties_config).expect("TIES merge should succeed");
        let dare_config = DareConfig::new(0.0).expect("config should be valid");
        let dare_result =
            dare_merge(&models, &base, &dare_config).expect("DARE merge should succeed");
        let slerp_config = SlerpConfig::new(0.5).expect("slerp config creation should succeed");
        let slerp_result =
            slerp_merge(&models[0], &models[1], &slerp_config).expect("SLERP merge should succeed");

        let ties_data = ties_result["w"].data();
        let dare_data = dare_result["w"].data();
        let slerp_data = slerp_result["w"].data();

        // Without these length checks, an empty result would pass every
        // finiteness check below.
        assert_eq!(ties_data.len(), 3, "TIES output should match the input shape");
        assert_eq!(dare_data.len(), 3, "DARE output should match the input shape");
        assert_eq!(slerp_data.len(), 3, "SLERP output should match the input shape");

        assert!(ties_data.iter().all(|v| v.is_finite()), "TIES produced a non-finite value");
        assert!(dare_data.iter().all(|v| v.is_finite()), "DARE produced a non-finite value");
        assert!(slerp_data.iter().all(|v| v.is_finite()), "SLERP produced a non-finite value");

        // A zero drop rate keeps every delta; with an all-zero base the result
        // is the plain mean of the two models.
        for (got, want) in dare_data.iter().zip([1.5_f32, 2.5, 3.5].iter()) {
            assert!((got - want).abs() < 1e-5, "DARE(0.0) should average: {got} != {want}");
        }
    }

    /// Models whose parameter shapes disagree must be rejected by both
    /// base-relative merges.
    #[test]
    fn test_incompatible_shapes_rejected() {
        let base = create_model("w", vec![0.0, 0.0]);
        let model1 = create_model("w", vec![1.0, 2.0]);
        let model2 = create_model("w", vec![3.0, 4.0, 5.0]);
        let models = vec![model1, model2];

        let ties_result = ties_merge(&models, &base, &TiesConfig::default());
        assert!(ties_result.is_err(), "TIES should reject mismatched shapes");
        let dare_result = dare_merge(&models, &base, &DareConfig::default());
        assert!(dare_result.is_err(), "DARE should reject mismatched shapes");
    }
}
#[cfg(test)]
mod mod_coverage_tests {
    //! Unit tests for the shared merge helpers (`compute_deltas`,
    //! `merge_with_base`, `validate_models`) and for `MergeError`'s
    //! `Display` output.

    use super::*;
    use crate::merge::{compute_deltas, merge_with_base, validate_models, MergeError};

    /// With an all-zero base, each delta must equal its model's values
    /// element for element.
    #[test]
    fn test_compute_deltas_basic() {
        let base = create_model("w", vec![0.0, 0.0, 0.0]);
        let model1 = create_model("w", vec![1.0, 2.0, 3.0]);
        let model2 = create_model("w", vec![4.0, 5.0, 6.0]);
        let models = vec![model1, model2];
        let deltas = compute_deltas(&models, &base).expect("compute_deltas should succeed");

        assert_eq!(deltas.len(), 2);
        // Subtracting 0.0 is exact, so the whole vectors can be compared
        // exactly rather than only their first element.
        let delta0: Vec<f32> = deltas[0]["w"].data().iter().copied().collect();
        let delta1: Vec<f32> = deltas[1]["w"].data().iter().copied().collect();
        assert_eq!(delta0, vec![1.0, 2.0, 3.0]);
        assert_eq!(delta1, vec![4.0, 5.0, 6.0]);
    }

    /// A model that lacks one of the base's parameters is an architecture
    /// mismatch.
    #[test]
    fn test_compute_deltas_missing_param() {
        let base = create_model("w", vec![0.0]);
        let mut model = HashMap::new();
        model.insert("other".to_string(), Tensor::from_vec(vec![1.0], false));
        let result = compute_deltas(&[model], &base);
        assert!(matches!(result, Err(MergeError::IncompatibleArchitectures(_))));
    }

    /// A parameter whose length differs from the base is a shape mismatch.
    #[test]
    fn test_compute_deltas_shape_mismatch() {
        let base = create_model("w", vec![0.0, 0.0]);
        let model = create_model("w", vec![1.0, 2.0, 3.0]);
        let result = compute_deltas(&[model], &base);
        assert!(matches!(result, Err(MergeError::ShapeMismatch(_))));
    }

    /// Adding a delta to the base must produce `base + delta` element-wise.
    #[test]
    fn test_merge_with_base() {
        let base = create_model("w", vec![10.0, 20.0]);
        let delta = create_model("w", vec![1.0, 2.0]);
        let merged = merge_with_base(&base, delta);

        let data = merged["w"].data();
        assert_eq!(data.len(), 2);
        assert_eq!(data[0], 11.0);
        assert_eq!(data[1], 22.0);
    }

    /// Base parameters that have no matching delta must pass through
    /// unchanged, and no parameter may be dropped.
    #[test]
    fn test_merge_with_base_missing_delta() {
        let mut base = HashMap::new();
        base.insert("w".to_string(), Tensor::from_vec(vec![1.0], false));
        base.insert("b".to_string(), Tensor::from_vec(vec![2.0], false));
        let delta = create_model("w", vec![0.5]);
        let merged = merge_with_base(&base, delta);

        assert_eq!(merged.len(), 2, "every base parameter should survive the merge");
        assert_eq!(merged["w"].data()[0], 1.5);
        assert_eq!(merged["b"].data()[0], 2.0);
    }

    /// Validating an empty model list reports `InsufficientModels`.
    #[test]
    fn test_validate_models_empty() {
        let models: Vec<Model> = vec![];
        let result = validate_models(&models);
        assert!(matches!(result, Err(MergeError::InsufficientModels { min: 1, got: 0 })));
    }

    /// Models with disjoint parameter names are incompatible.
    #[test]
    fn test_validate_models_missing_param() {
        let model1 = create_model("w", vec![1.0]);
        let model2 = create_model("b", vec![2.0]);
        let models = vec![model1, model2];
        let result = validate_models(&models);
        assert!(matches!(result, Err(MergeError::IncompatibleArchitectures(_))));
    }

    /// Same parameter name with different lengths is a shape mismatch.
    #[test]
    fn test_validate_models_shape_mismatch() {
        let model1 = create_model("w", vec![1.0, 2.0]);
        let model2 = create_model("w", vec![3.0, 4.0, 5.0]);
        let models = vec![model1, model2];
        let result = validate_models(&models);
        assert!(matches!(result, Err(MergeError::ShapeMismatch(_))));
    }

    /// Matching names and shapes validate successfully.
    #[test]
    fn test_validate_models_valid() {
        let model1 = create_model("w", vec![1.0, 2.0]);
        let model2 = create_model("w", vec![3.0, 4.0]);
        let models = vec![model1, model2];
        assert!(validate_models(&models).is_ok());
    }

    /// Each `MergeError` variant's `Display` text contains its identifying
    /// keyword.
    #[test]
    fn test_merge_error_display() {
        let err1 = MergeError::IncompatibleArchitectures("test".to_string());
        assert!(err1.to_string().contains("incompatible"));
        let err2 = MergeError::ShapeMismatch("param".to_string());
        assert!(err2.to_string().contains("shape"));
        let err3 = MergeError::InvalidConfig("bad config".to_string());
        assert!(err3.to_string().contains("Invalid"));
        let err4 = MergeError::InsufficientModels { min: 2, got: 1 };
        assert!(err4.to_string().contains("Insufficient"));
    }
}