/// TIES merge: per tensor, compute each task model's delta from the base,
/// trim small-magnitude entries, resolve sign conflicts by majority vote,
/// and add the merged delta back onto the base weights.
///
/// Assumes every task model contains every tensor in `base_tensors` with a
/// matching length (checked by earlier validation, hence the `expect`).
fn ties_merge(
    base_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
    task_models: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
    density: f32,
) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
    base_tensors
        .iter()
        .map(|(name, (base_data, shape))| {
            // Trim each model's delta immediately; the raw deltas are not
            // needed once trimmed.
            let trimmed: Vec<Vec<f32>> = task_models
                .iter()
                .map(|model| {
                    let (model_data, _) = model.get(name).expect("validated above");
                    let delta: Vec<f32> = model_data
                        .iter()
                        .zip(base_data)
                        .map(|(&m, &b)| m - b)
                        .collect();
                    ties_trim(&delta, density)
                })
                .collect();
            let merged_delta = ties_elect_and_merge(&trimmed, base_data.len());
            // Re-apply the merged delta on top of the base weights.
            let result: Vec<f32> = base_data
                .iter()
                .zip(&merged_delta)
                .map(|(&b, &d)| b + d)
                .collect();
            (name.clone(), (result, shape.clone()))
        })
        .collect()
}
/// Zero out delta entries whose magnitude falls below `density * max_abs`.
///
/// The cutoff is relative to the largest absolute value in `delta`; an
/// all-(near-)zero delta is returned as zeros to avoid a degenerate cutoff.
/// NOTE(review): `density` acts here as a magnitude-threshold fraction, not
/// the "fraction of parameters kept" used by some TIES descriptions —
/// confirm this matches the intended semantics.
fn ties_trim(delta: &[f32], density: f32) -> Vec<f32> {
    let peak = delta.iter().fold(0.0f32, |acc, &x| acc.max(x.abs()));
    if peak < 1e-12 {
        return vec![0.0; delta.len()];
    }
    let cutoff = density * peak;
    delta
        .iter()
        .map(|&x| if x.abs() < cutoff { 0.0 } else { x })
        .collect()
}
/// Majority vote on the sign of coordinate `i` across all trimmed deltas.
///
/// Returns `true` when positive wins; a tie (including all zeros) also
/// elects positive. Zero entries do not vote for either side.
fn ties_elect_sign(trimmed_deltas: &[Vec<f32>], i: usize) -> bool {
    let positives = trimmed_deltas.iter().filter(|d| d[i] > 0.0).count();
    let negatives = trimmed_deltas.iter().filter(|d| d[i] < 0.0).count();
    positives >= negatives
}
/// Sum the coordinate-`i` values whose sign agrees with the elected sign.
///
/// Returns `(sum, count)` where `count` is the number of models that
/// contributed. Zero entries agree with neither sign and are excluded.
fn ties_sum_agreeing(trimmed_deltas: &[Vec<f32>], i: usize, elected_positive: bool) -> (f32, u32) {
    trimmed_deltas
        .iter()
        .map(|d| d[i])
        .filter(|&v| if elected_positive { v > 0.0 } else { v < 0.0 })
        .fold((0.0f32, 0u32), |(sum, count), v| (sum + v, count + 1))
}
/// Combine trimmed deltas coordinate-wise: elect a sign by majority vote,
/// sum the agreeing values, and normalize by the TOTAL number of models.
///
/// NOTE(review): the agreeing sum is divided by the total model count, not
/// the agreeing count, so disagreeing/zero models dampen the result —
/// confirm this normalization is intended (some TIES formulations average
/// over agreeing models only).
fn ties_elect_and_merge(trimmed_deltas: &[Vec<f32>], len: usize) -> Vec<f32> {
    let model_count = trimmed_deltas.len() as f32;
    (0..len)
        .map(|i| {
            let positive = ties_elect_sign(trimmed_deltas, i);
            let (sum, agreeing) = ties_sum_agreeing(trimmed_deltas, i, positive);
            // No agreeing values leaves the coordinate at zero.
            if agreeing > 0 { sum / model_count } else { 0.0 }
        })
        .collect()
}
/// DARE merge: randomly drop delta entries with probability `drop_rate`,
/// rescale survivors by 1/(1-drop_rate), and accumulate a per-model
/// weighted delta that is added back onto the base weights.
///
/// Each tensor gets its own RNG stream seeded from `seed` plus the tensor's
/// iteration index, so output is deterministic for a given seed and tensor
/// ordering. Falls back to uniform weights when `weights` is `None`.
/// NOTE(review): `drop_rate == 1.0` makes the rescale factor infinite —
/// presumably excluded by earlier validation; confirm.
fn dare_merge(
    base_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
    task_models: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
    drop_rate: f32,
    seed: u64,
    weights: Option<&[f32]>,
) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
    let survivor_scale = 1.0 / (1.0 - drop_rate);
    let model_count = task_models.len();
    let uniform: Vec<f32> = vec![1.0 / model_count as f32; model_count];
    let per_model_weights = weights.unwrap_or(&uniform);
    let mut merged = BTreeMap::new();
    for (tensor_idx, (name, (base_data, shape))) in base_tensors.iter().enumerate() {
        // Deterministic per-tensor stream; exactly one draw per
        // (model, element) regardless of the drop decision.
        let mut rng = StdRng::seed_from_u64(seed.wrapping_add(tensor_idx as u64));
        let mut delta_acc = vec![0.0f32; base_data.len()];
        for (model_idx, model_tensors) in task_models.iter().enumerate() {
            let (model_data, _) = model_tensors.get(name).expect("validated above");
            let weight = per_model_weights[model_idx];
            for (acc, (&m_val, &b_val)) in delta_acc
                .iter_mut()
                .zip(model_data.iter().zip(base_data.iter()))
            {
                let dropped = rng.random::<f32>() < drop_rate;
                if !dropped {
                    *acc += (m_val - b_val) * survivor_scale * weight;
                }
            }
        }
        let result: Vec<f32> = base_data
            .iter()
            .zip(&delta_acc)
            .map(|(&b, &d)| b + d)
            .collect();
        merged.insert(name.clone(), (result, shape.clone()));
    }
    merged
}
/// Merge multiple model checkpoints into a single safetensors output file.
///
/// `inputs` are paths to the models to merge and `output` is the destination
/// path. The algorithm is selected by `options.strategy`; strategies that
/// need a reference model read it from `options.base_model`.
///
/// # Errors
/// Returns an error when option validation fails, a model cannot be loaded,
/// tensors are incompatible across models, or the merged file cannot be
/// saved.
pub fn apr_merge<P: AsRef<Path>>(
    inputs: &[P],
    output: P,
    options: MergeOptions,
) -> Result<MergeReport> {
    // Fail fast on bad strategy/option combinations before loading anything.
    validate_merge_options(inputs, &options)?;
    let all_tensors = load_all_models(inputs)?;
    verify_tensor_compatibility(&all_tensors)?;
    let merged = match options.strategy {
        MergeStrategy::Average | MergeStrategy::Weighted => {
            let weights = calculate_merge_weights(inputs.len(), &options)?;
            merge_tensors(&all_tensors, &weights)
        }
        // Slerp variants interpolate between exactly two models; indexing
        // [0] and [1] assumes validate_merge_options enforced the count —
        // TODO confirm.
        MergeStrategy::Slerp => {
            // First weight doubles as the interpolation factor t (default 0.5).
            let t = options
                .weights
                .as_ref()
                .and_then(|w| w.first().copied())
                .unwrap_or(0.5);
            slerp_tensors(&all_tensors[0], &all_tensors[1], t)
        }
        MergeStrategy::NuSlerp => {
            let t = options
                .weights
                .as_ref()
                .and_then(|w| w.first().copied())
                .unwrap_or(0.5);
            nuslerp_tensors(&all_tensors[0], &all_tensors[1], t)
        }
        MergeStrategy::MultiSlerp => {
            // Uniform weights when none were supplied.
            let default_weights = vec![1.0 / inputs.len() as f32; inputs.len()];
            let weights = options.weights.as_deref().unwrap_or(&default_weights);
            multi_slerp_tensors(&all_tensors, weights)
        }
        // Delta-based strategies below all load a base model and check it is
        // tensor-compatible with the task models before merging.
        MergeStrategy::Ties => {
            let base_path = options.base_model.as_ref().expect("validated above");
            let base_tensors = load_model_tensors(base_path)?;
            verify_tensor_compatibility(&[base_tensors.clone(), all_tensors[0].clone()])?;
            ties_merge(&base_tensors, &all_tensors, options.density)
        }
        MergeStrategy::Dare => {
            let base_path = options.base_model.as_ref().expect("validated above");
            let base_tensors = load_model_tensors(base_path)?;
            verify_tensor_compatibility(&[base_tensors.clone(), all_tensors[0].clone()])?;
            dare_merge(
                &base_tensors,
                &all_tensors,
                options.drop_rate,
                options.seed,
                options.weights.as_deref(),
            )
        }
        MergeStrategy::TaskArithmetic => {
            let base_path = options.base_model.as_ref().expect("validated above");
            let base_tensors = load_model_tensors(base_path)?;
            verify_tensor_compatibility(&[base_tensors.clone(), all_tensors[0].clone()])?;
            // Default scale of 1.0 applies each task delta unmodified.
            let default_scales = vec![1.0f32; all_tensors.len()];
            let scales = options.scales.as_deref().unwrap_or(&default_scales);
            task_arithmetic_merge(&base_tensors, &all_tensors, scales)
        }
        MergeStrategy::Della => {
            let base_path = options.base_model.as_ref().expect("validated above");
            let base_tensors = load_model_tensors(base_path)?;
            verify_tensor_compatibility(&[base_tensors.clone(), all_tensors[0].clone()])?;
            della_merge(
                &base_tensors,
                &all_tensors,
                options.drop_rate,
                options.seed,
                options.weights.as_deref(),
            )
        }
        MergeStrategy::Breadcrumbs => {
            let base_path = options.base_model.as_ref().expect("validated above");
            let base_tensors = load_model_tensors(base_path)?;
            verify_tensor_compatibility(&[base_tensors.clone(), all_tensors[0].clone()])?;
            let default_scales = vec![1.0f32; all_tensors.len()];
            let scales = options.scales.as_deref().unwrap_or(&default_scales);
            breadcrumbs_merge(&base_tensors, &all_tensors, scales, options.outlier_k)
        }
        MergeStrategy::Sce => {
            let default_weights = vec![1.0 / inputs.len() as f32; inputs.len()];
            let weights = options.weights.as_deref().unwrap_or(&default_weights);
            sce_merge(&all_tensors, weights)
        }
        MergeStrategy::Passthrough => {
            let ranges = options.layer_ranges.as_ref().expect("validated above");
            passthrough_merge(&all_tensors, ranges)
        }
    };
    let output_path = output.as_ref();
    save_safetensors(output_path, &merged).map_err(|e| AprenderError::FormatError {
        message: format!("Failed to save merged model: {e}"),
    })?;
    // Best-effort size report: metadata failure degrades to 0 rather than
    // failing the whole merge.
    let output_size = fs::metadata(output_path)
        .map(|m| m.len() as usize)
        .unwrap_or(0);
    Ok(MergeReport {
        model_count: inputs.len(),
        tensor_count: merged.len(),
        output_size,
        strategy: options.strategy,
        weights_used: options.weights,
    })
}
#[cfg(test)]
#[path = "merge_tests.rs"]
mod tests;