use std::borrow::Cow;
use std::collections::HashMap;

use thiserror::Error;
/// Errors produced while validating a merge configuration or combining tensors.
#[derive(Debug, Error)]
pub enum MergeError {
    /// Two tensors that must be combined element-wise have incompatible sizes.
    #[error("shape mismatch for tensor '{name}': {a:?} vs {b:?}")]
    ShapeMismatch {
        /// Name of the first tensor involved in the comparison.
        name: String,
        /// Shape of the first tensor.
        a: Vec<usize>,
        /// Shape of the second tensor.
        b: Vec<usize>,
    },
    /// A tensor with zero elements was supplied; carries that tensor's name.
    #[error("empty tensor: '{0}'")]
    EmptyTensor(String),
    /// `MergeConfig::alpha` is outside `[0.0, 1.0]` (NaN is also rejected).
    #[error("invalid alpha {0}: must be in [0.0, 1.0]")]
    InvalidAlpha(f32),
    /// `MergeConfig::density` is outside `(0.0, 1.0]`.
    #[error("invalid density {0}: must be in (0.0, 1.0]")]
    InvalidDensity(f32),
    /// SLERP was requested but one of the inputs has (near-)zero L2 norm.
    ///
    /// NOTE(review): "Sierp" is a misspelling of "Slerp"; the variant is kept
    /// as-is because renaming a public variant breaks callers that match on it.
    #[error("SLERP failed: zero vector")]
    SierpZeroVector,
}
/// A named tensor of `f32` values stored as a flat buffer plus a shape.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Identifier; `merge_models_with_stats` pairs tensors across models by this name.
    pub name: String,
    /// Flattened element values.
    /// NOTE(review): element order (row- vs. column-major) is not established
    /// here — confirm with the code that produces these tensors. Nothing
    /// enforces that `data.len()` equals the product of `shape`.
    pub data: Vec<f32>,
    /// Tensor dimensions; their product is what `element_count()` returns.
    pub shape: Vec<usize>,
}
impl WeightTensor {
    /// Builds a tensor from its name, flat data buffer and shape.
    ///
    /// No validation is performed: `data.len()` is not checked against the
    /// product of `shape`.
    pub fn new(name: impl Into<String>, data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self {
            name: name.into(),
            data,
            shape,
        }
    }

    /// Builds a zero-filled tensor holding `shape.iter().product()` elements.
    pub fn zeros(name: impl Into<String>, shape: Vec<usize>) -> Self {
        let n = shape.iter().product();
        Self {
            name: name.into(),
            data: vec![0.0f32; n],
            shape,
        }
    }

    /// Number of elements implied by `shape` (the product of its dimensions,
    /// which is 1 for an empty shape). This does not look at `data`.
    pub fn element_count(&self) -> usize {
        self.shape.iter().product()
    }

    /// Euclidean (L2) norm of `data`.
    pub fn l2_norm(&self) -> f32 {
        self.data.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Cosine similarity between `self` and `other`, guaranteed to be in
    /// `[-1.0, 1.0]` (or NaN if the data contains NaN).
    ///
    /// Returns `0.0` when either tensor has zero L2 norm. The raw ratio is
    /// clamped because f32 rounding in the dot product and norms can push it
    /// marginally outside `[-1, 1]`, which is not a valid cosine.
    ///
    /// # Errors
    /// - [`MergeError::EmptyTensor`] if either tensor has zero elements.
    /// - [`MergeError::ShapeMismatch`] if the element counts differ.
    pub fn cosine_similarity(&self, other: &WeightTensor) -> Result<f32, MergeError> {
        let n = self.element_count();
        if n == 0 {
            return Err(MergeError::EmptyTensor(self.name.clone()));
        }
        if other.element_count() == 0 {
            return Err(MergeError::EmptyTensor(other.name.clone()));
        }
        if n != other.element_count() {
            return Err(MergeError::ShapeMismatch {
                name: self.name.clone(),
                a: self.shape.clone(),
                b: other.shape.clone(),
            });
        }
        let dot: f32 = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a * b)
            .sum();
        let denom = self.l2_norm() * other.l2_norm();
        if denom == 0.0 {
            return Ok(0.0);
        }
        Ok((dot / denom).clamp(-1.0, 1.0))
    }

    /// Element-wise sum `self + other`; the result keeps `self`'s name and shape.
    ///
    /// # Errors
    /// [`MergeError::ShapeMismatch`] if the tensors are not compatible.
    pub fn add(&self, other: &WeightTensor) -> Result<WeightTensor, MergeError> {
        self.zip_with(other, |a, b| a + b)
    }

    /// Element-wise difference `self - other`; the result keeps `self`'s name and shape.
    ///
    /// # Errors
    /// [`MergeError::ShapeMismatch`] if the tensors are not compatible.
    pub fn sub(&self, other: &WeightTensor) -> Result<WeightTensor, MergeError> {
        self.zip_with(other, |a, b| a - b)
    }

    /// Returns a copy with every element multiplied by `alpha`.
    pub fn scale(&self, alpha: f32) -> WeightTensor {
        let data: Vec<f32> = self.data.iter().map(|x| x * alpha).collect();
        WeightTensor::new(self.name.clone(), data, self.shape.clone())
    }

    /// Linear interpolation `(1 - t) * self + t * other` via `linear_merge`;
    /// the result keeps `self`'s name and shape.
    ///
    /// # Errors
    /// [`MergeError::ShapeMismatch`] if the tensors are not compatible.
    pub fn lerp(&self, other: &WeightTensor, t: f32) -> Result<WeightTensor, MergeError> {
        check_compatible(self, other)?;
        let data = linear_merge(&self.data, &other.data, t);
        Ok(WeightTensor::new(
            self.name.clone(),
            data,
            self.shape.clone(),
        ))
    }

    /// Applies `op` pairwise to the elements of `self` and `other` after a
    /// compatibility check; the result keeps `self`'s name and shape.
    fn zip_with(
        &self,
        other: &WeightTensor,
        op: impl Fn(f32, f32) -> f32,
    ) -> Result<WeightTensor, MergeError> {
        check_compatible(self, other)?;
        let data = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(&a, &b)| op(a, b))
            .collect();
        Ok(WeightTensor::new(
            self.name.clone(),
            data,
            self.shape.clone(),
        ))
    }
}
/// Strategy used to combine two tensors (dispatched by `apply_merge_method`).
#[derive(Debug, Clone, PartialEq)]
pub enum MergeMethod {
    /// Weighted average `(1 - alpha) * a + alpha * b` (see `linear_merge`).
    Linear,
    /// Spherical linear interpolation at `alpha` (see `slerp`); rejected with
    /// `MergeError::SierpZeroVector` when either input has near-zero norm.
    Slerp,
    /// Magnitude-trimmed, sign-aware merge using `alpha` and `density` (see `ties_merge`).
    Ties,
    /// `base + alpha * (other - base)` (see `task_vector_merge`).
    TaskVector,
    /// Randomly drops elements of the delta `other - base` and rescales the
    /// survivors (see `dare_merge`).
    Dare {
        /// Seed for the internal deterministic LCG; equal seeds give equal results.
        seed: u64,
        /// Per-element probability of dropping a delta. Not checked by `validate_config`.
        dropout_rate: f32,
    },
}
/// Parameters controlling how tensors are merged.
#[derive(Debug, Clone)]
pub struct MergeConfig {
    /// Which merge algorithm to apply.
    pub method: MergeMethod,
    /// Blend weight given to the second (`other` / fine-tuned) input; must be in `[0.0, 1.0]`.
    pub alpha: f32,
    /// When true, `merge_tensors` divides each input by its L2 norm before
    /// merging; the merged result is not scaled back afterwards.
    pub normalize: bool,
    /// Fraction of largest-magnitude entries kept by TIES trimming (only used
    /// by `MergeMethod::Ties`); must be in `(0.0, 1.0]`.
    pub density: f32,
}
impl Default for MergeConfig {
fn default() -> Self {
Self {
method: MergeMethod::Linear,
alpha: 0.5,
normalize: false,
density: 0.5,
}
}
}
/// Summary statistics gathered by `merge_models_with_stats`.
#[derive(Debug, Clone)]
pub struct MergeStats {
    /// Base tensors that had a same-named counterpart and were merged.
    pub tensors_merged: usize,
    /// Base tensors without a counterpart, copied through unchanged.
    pub tensors_copied: usize,
    /// Sum of `element_count()` over all base tensors.
    pub total_params: usize,
    /// Mean cosine similarity over the merged pairs for which it could be
    /// computed; `0.0` when there were none.
    pub mean_cosine_similarity: f32,
    /// The merge method that was used.
    pub method: MergeMethod,
}
impl MergeStats {
pub fn summary(&self) -> String {
format!(
"method={:?} merged={} copied={} total_params={} mean_cosine_sim={:.4}",
self.method,
self.tensors_merged,
self.tensors_copied,
self.total_params,
self.mean_cosine_similarity,
)
}
}
/// Element-wise linear interpolation `(1 - alpha) * a + alpha * b`.
///
/// Only the first `min(a.len(), b.len())` elements are produced.
pub fn linear_merge(a: &[f32], b: &[f32], alpha: f32) -> Vec<f32> {
    let keep = 1.0 - alpha;
    let mut blended = Vec::with_capacity(a.len().min(b.len()));
    for (&x, &y) in a.iter().zip(b) {
        blended.push(keep * x + alpha * y);
    }
    blended
}
/// Spherical linear interpolation between `a` and `b` at parameter `t`.
///
/// Only the first `min(a.len(), b.len())` elements take part and are returned.
/// The angle is measured between the normalised vectors. The SLERP
/// coefficients are then applied to the original, unnormalised values.
/// The function falls back to [`linear_merge`] in three cases:
/// - either vector has near-zero norm;
/// - the vectors are nearly parallel (`cos θ > 0.9995`);
/// - `sin θ` is numerically zero (nearly antiparallel).
pub fn slerp(a: &[f32], b: &[f32], t: f32) -> Vec<f32> {
    // Above this cosine, sin(θ) is tiny and the SLERP coefficients lose
    // precision, while plain lerp is an excellent approximation.
    const PARALLEL_COS_THRESHOLD: f32 = 0.9995;

    let n = a.len().min(b.len());
    if n == 0 {
        return Vec::new();
    }
    // Measure norms over the same prefix that the dot product and the output
    // use. Taking norms over the full slices made cos(θ) inconsistent when the
    // lengths differ.
    let (a, b) = (&a[..n], &b[..n]);
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a < f32::EPSILON || norm_b < f32::EPSILON {
        return linear_merge(a, b, t);
    }
    // Normalising per element keeps the products small, avoiding overflow of
    // the raw dot product; the clamp absorbs rounding before `acos`.
    let cos_theta: f32 = a
        .iter()
        .zip(b)
        .map(|(ai, bi)| (ai / norm_a) * (bi / norm_b))
        .sum::<f32>()
        .clamp(-1.0, 1.0);
    if cos_theta > PARALLEL_COS_THRESHOLD {
        return linear_merge(a, b, t);
    }
    let theta = cos_theta.acos();
    let sin_theta = theta.sin();
    if sin_theta.abs() < f32::EPSILON {
        return linear_merge(a, b, t);
    }
    let coeff_a = ((1.0 - t) * theta).sin() / sin_theta;
    let coeff_b = (t * theta).sin() / sin_theta;
    a.iter()
        .zip(b)
        .map(|(ai, bi)| coeff_a * ai + coeff_b * bi)
        .collect()
}
/// Simplified TIES-style merge of `a` and `b`.
///
/// Each input is first trimmed independently with `trim_by_magnitude(_, density)`.
/// Then, per element:
/// - where the signs (`f32::signum`) agree, the result is
///   `(1 - alpha) * a + alpha * b`;
/// - where they disagree, only the larger-magnitude value survives, scaled by
///   its own weight (`1 - alpha` for `a`, `alpha` for `b`).
///
/// Returns an empty vector if either input is empty.
///
/// NOTE(review): this operates on the raw values passed in, not on deltas
/// from a base model as in the TIES-Merging paper — confirm this is intended.
pub fn ties_merge(a: &[f32], b: &[f32], alpha: f32, density: f32) -> Vec<f32> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let kept_a = trim_by_magnitude(a, density);
    let kept_b = trim_by_magnitude(b, density);
    let beta = 1.0 - alpha;
    kept_a
        .into_iter()
        .zip(kept_b)
        .map(|(x, y)| {
            if x.signum() == y.signum() {
                beta * x + alpha * y
            } else if x.abs() >= y.abs() {
                x * beta
            } else {
                y * alpha
            }
        })
        .collect()
}
/// Adds an `alpha`-scaled task vector (`finetuned - base`) back onto `base`.
///
/// Only the first `min(base.len(), finetuned.len())` elements are produced.
pub fn task_vector_merge(base: &[f32], finetuned: &[f32], alpha: f32) -> Vec<f32> {
    let apply = |(&origin, &tuned): (&f32, &f32)| {
        let task_delta = tuned - origin;
        origin + alpha * task_delta
    };
    base.iter().zip(finetuned).map(apply).collect()
}
/// DARE merge: applies a randomly sparsified, rescaled task vector to `base`.
///
/// For each element, the delta `finetuned - base` is handled in one of two ways:
/// - it is dropped (set to zero) when the next LCG draw is below `dropout_rate`;
/// - otherwise it is multiplied by `1 / (1 - dropout_rate)`, so the expected
///   delta is preserved (assuming uniform draws).
///
/// The result is `base + alpha * sparse_delta`. The same `seed` always yields
/// the same output. Only the first `min(base.len(), finetuned.len())` elements
/// are produced.
///
/// `dropout_rate` is clamped to `[0.0, 1.0]`. A negative rate previously gave a
/// rescale factor below one and silently shrank every delta. With a rate of
/// `1.0`, every delta contributes `0.0`, so finite inputs yield `base`.
pub fn dare_merge(
    base: &[f32],
    finetuned: &[f32],
    alpha: f32,
    dropout_rate: f32,
    seed: u64,
) -> Vec<f32> {
    // `f32::clamp` passes NaN through unchanged, so a NaN rate behaves as before.
    let dropout_rate = dropout_rate.clamp(0.0, 1.0);
    let mut state = seed;
    // The factor only applies to kept deltas; at a rate of 1.0 it is 0.0 to
    // avoid dividing by zero.
    let rescale = if dropout_rate < 1.0 {
        1.0 / (1.0 - dropout_rate)
    } else {
        0.0
    };
    base.iter()
        .zip(finetuned.iter())
        .map(|(b, f)| {
            let rand_val = lcg_next(&mut state);
            let delta = f - b;
            let sparse_delta = if rand_val < dropout_rate {
                0.0
            } else {
                delta * rescale
            };
            b + alpha * sparse_delta
        })
        .collect()
}
pub fn merge_tensors(
base: &WeightTensor,
other: &WeightTensor,
config: &MergeConfig,
) -> Result<WeightTensor, MergeError> {
validate_config(config)?;
check_compatible(base, other)?;
if base.element_count() == 0 {
return Err(MergeError::EmptyTensor(base.name.clone()));
}
let (a_data, b_data) = if config.normalize {
let norm_a = base.l2_norm();
let norm_b = other.l2_norm();
let a_norm = if norm_a > f32::EPSILON {
base.data.iter().map(|x| x / norm_a).collect()
} else {
base.data.clone()
};
let b_norm = if norm_b > f32::EPSILON {
other.data.iter().map(|x| x / norm_b).collect()
} else {
other.data.clone()
};
(a_norm, b_norm)
} else {
(base.data.clone(), other.data.clone())
};
let merged_data = apply_merge_method(&a_data, &b_data, config)?;
Ok(WeightTensor::new(
base.name.clone(),
merged_data,
base.shape.clone(),
))
}
/// Merges two models tensor-by-tensor, discarding the statistics that
/// [`merge_models_with_stats`] also computes.
///
/// # Errors
/// Propagates any error from [`merge_models_with_stats`].
pub fn merge_models(
    base: &[WeightTensor],
    other: &[WeightTensor],
    config: &MergeConfig,
) -> Result<Vec<WeightTensor>, MergeError> {
    merge_models_with_stats(base, other, config).map(|(merged, _stats)| merged)
}
/// Merges each tensor of `base` with the same-named tensor of `other` and
/// returns the merged model together with [`MergeStats`].
///
/// Output order follows `base`. How tensors are handled:
/// - base tensors without a counterpart are cloned through unchanged;
/// - tensors that appear only in `other` are ignored;
/// - if `other` repeats a name, its last occurrence is used.
///
/// # Errors
/// Returns the first error from config validation or from [`merge_tensors`].
pub fn merge_models_with_stats(
    base: &[WeightTensor],
    other: &[WeightTensor],
    config: &MergeConfig,
) -> Result<(Vec<WeightTensor>, MergeStats), MergeError> {
    validate_config(config)?;

    let mut lookup: HashMap<&str, &WeightTensor> = HashMap::new();
    for tensor in other {
        lookup.insert(tensor.name.as_str(), tensor);
    }

    let mut merged_model = Vec::with_capacity(base.len());
    let mut merged = 0usize;
    let mut copied = 0usize;
    let mut params = 0usize;
    let mut similarity_total = 0.0f32;
    let mut similarity_samples = 0usize;

    for tensor in base {
        params += tensor.element_count();
        match lookup.get(tensor.name.as_str()) {
            Some(&counterpart) => {
                // Similarity is best-effort: an error here (empty tensor or
                // size mismatch) just leaves the pair out of the mean.
                if let Ok(similarity) = tensor.cosine_similarity(counterpart) {
                    similarity_total += similarity;
                    similarity_samples += 1;
                }
                merged_model.push(merge_tensors(tensor, counterpart, config)?);
                merged += 1;
            }
            None => {
                merged_model.push(tensor.clone());
                copied += 1;
            }
        }
    }

    let mean_cosine_similarity = match similarity_samples {
        0 => 0.0,
        samples => similarity_total / samples as f32,
    };

    Ok((
        merged_model,
        MergeStats {
            tensors_merged: merged,
            tensors_copied: copied,
            total_params: params,
            mean_cosine_similarity,
            method: config.method.clone(),
        },
    ))
}
/// Validates the numeric parameters of `config`.
///
/// NaN fails both checks. `MergeMethod::Dare`'s `dropout_rate` is not checked here.
///
/// # Errors
/// - [`MergeError::InvalidAlpha`] unless `alpha` is in `[0.0, 1.0]`.
/// - [`MergeError::InvalidDensity`] unless `density` is in `(0.0, 1.0]`.
fn validate_config(config: &MergeConfig) -> Result<(), MergeError> {
    // `RangeInclusive::contains` is false for NaN, so NaN alpha is rejected.
    if !(0.0..=1.0).contains(&config.alpha) {
        return Err(MergeError::InvalidAlpha(config.alpha));
    }
    // Test membership rather than exclusion. Every comparison with NaN is
    // false, so the old `density <= 0.0 || density > 1.0` let NaN through.
    let density_ok = config.density > 0.0 && config.density <= 1.0;
    if !density_ok {
        return Err(MergeError::InvalidDensity(config.density));
    }
    Ok(())
}
/// Checks that two tensors can be combined element-wise.
///
/// Two things are compared:
/// - the element counts implied by the shapes (so `[2, 3]` and `[3, 2]` are
///   accepted);
/// - the actual `data` lengths. `WeightTensor::new` does not enforce
///   `data.len() == element_count()`, and the element-wise operations zip the
///   buffers, which would silently truncate on a mismatch.
///
/// # Errors
/// [`MergeError::ShapeMismatch`] carrying `a`'s name and both shapes.
fn check_compatible(a: &WeightTensor, b: &WeightTensor) -> Result<(), MergeError> {
    if a.element_count() != b.element_count() || a.data.len() != b.data.len() {
        return Err(MergeError::ShapeMismatch {
            name: a.name.clone(),
            a: a.shape.clone(),
            b: b.shape.clone(),
        });
    }
    Ok(())
}
/// Dispatches to the element-wise merge function selected by `config.method`.
///
/// # Errors
/// [`MergeError::SierpZeroVector`] when SLERP is requested and either input
/// has L2 norm below `f32::EPSILON`. That case is rejected here rather than
/// passed to `slerp`, which would otherwise fall back to linear interpolation.
fn apply_merge_method(a: &[f32], b: &[f32], config: &MergeConfig) -> Result<Vec<f32>, MergeError> {
    let alpha = config.alpha;
    let merged = match &config.method {
        MergeMethod::Linear => linear_merge(a, b, alpha),
        MergeMethod::Slerp => {
            let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
            if l2(a) < f32::EPSILON || l2(b) < f32::EPSILON {
                return Err(MergeError::SierpZeroVector);
            }
            slerp(a, b, alpha)
        }
        MergeMethod::Ties => ties_merge(a, b, alpha, config.density),
        MergeMethod::TaskVector => task_vector_merge(a, b, alpha),
        MergeMethod::Dare { seed, dropout_rate } => {
            dare_merge(a, b, alpha, *dropout_rate, *seed)
        }
    };
    Ok(merged)
}
/// Zeroes the smallest-magnitude entries of `data`, keeping roughly a
/// `density` fraction of them.
///
/// `trim_count = round((1 - density) * len)` entries are targeted. Every value
/// whose magnitude is below the `trim_count`-th smallest magnitude (0-based)
/// becomes `0.0`. Values tied with that threshold are all kept, so slightly
/// more than `density` may survive. Edge cases:
/// - `density >= 1.0` returns an unchanged copy;
/// - if `trim_count` reaches `len` (e.g. `density <= 0.0`), every entry is zeroed;
/// - NaN entries are ordered above all finite magnitudes, so they are never
///   trimmed.
fn trim_by_magnitude(data: &[f32], density: f32) -> Vec<f32> {
    if data.is_empty() {
        return Vec::new();
    }
    if density >= 1.0 {
        return data.to_vec();
    }
    let len = data.len();
    // Float -> usize `as` casts saturate (NaN -> 0), so this cannot wrap.
    let trim_count = ((1.0 - density) * len as f32).round() as usize;
    if trim_count >= len {
        // Keep nothing. The old `f32::MAX` sentinel still kept MAX/∞ values.
        return vec![0.0; len];
    }
    // Selecting the k-th smallest magnitude is O(n), versus O(n log n) for a
    // full sort. `total_cmp` is a genuine total order. The old
    // `partial_cmp(..).unwrap_or(Equal)` is not one when NaN is present, and
    // since Rust 1.81 the sort routines may panic on such comparators.
    let mut magnitudes: Vec<f32> = data.iter().map(|x| x.abs()).collect();
    let (_, &mut threshold, _) = magnitudes.select_nth_unstable_by(trim_count, f32::total_cmp);
    data.iter()
        .map(|&x| if x.abs() < threshold { 0.0 } else { x })
        .collect()
}
/// Advances a 64-bit linear congruential generator and returns a value in `[0, 1)`.
///
/// Uses Knuth's MMIX multiplier and increment. The top 32 bits of the new
/// state are scaled by `2^-32`. Because `u32 -> f32` rounds to 24 significant
/// bits, states whose top bits are near `u32::MAX` used to round to exactly
/// `1.0`. The result is therefore capped at the largest `f32` below one. This
/// restores the half-open range and leaves every other output bit-for-bit
/// unchanged, so seeded sequences stay reproducible.
#[inline]
fn lcg_next(state: &mut u64) -> f32 {
    // Largest f32 strictly less than 1.0, i.e. 1 - 2^-24.
    const BELOW_ONE: f32 = 1.0 - f32::EPSILON / 2.0;
    *state = state
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1_442_695_040_888_963_407);
    let bits = (*state >> 32) as u32;
    ((bits as f32) / (u32::MAX as f32 + 1.0)).min(BELOW_ONE)
}
#[cfg(test)]
mod tests {
    use super::*;

    // The deterministic PRNG must stay inside the unit interval.
    #[test]
    fn lcg_produces_values_in_unit_interval() {
        let mut rng_state: u64 = 42;
        for _ in 0..1000 {
            let v = lcg_next(&mut rng_state);
            assert!((0.0..=1.0).contains(&v), "lcg value {v} out of [0,1]");
        }
    }

    // A density of 1.0 keeps every entry untouched.
    #[test]
    fn trim_by_magnitude_density_one_noop() {
        let input: [f32; 5] = [0.1, 0.5, -0.3, 0.9, -0.7];
        assert_eq!(trim_by_magnitude(&input, 1.0), input.to_vec());
    }

    // Keeping 60% of five values zeroes the two smallest magnitudes.
    #[test]
    fn trim_by_magnitude_zeros_smallest() {
        let out = trim_by_magnitude(&[1.0, 2.0, 3.0, 4.0, 5.0], 0.6);
        assert_eq!(out[0], 0.0, "1.0 should be trimmed");
        assert_eq!(out[1], 0.0, "2.0 should be trimmed");
        assert!(out[2] != 0.0, "3.0 should be kept");
    }

    #[test]
    fn validate_config_rejects_bad_alpha() {
        let too_large = MergeConfig {
            alpha: 1.5,
            ..MergeConfig::default()
        };
        assert!(validate_config(&too_large).is_err());
    }

    #[test]
    fn validate_config_rejects_zero_density() {
        let no_density = MergeConfig {
            density: 0.0,
            ..MergeConfig::default()
        };
        assert!(validate_config(&no_density).is_err());
    }
}