use crate::metaheuristics::{Budget, CmaEs, PerturbativeMetaheuristic, SearchSpace};
use std::collections::BTreeMap;
use super::merge::{MergeOptions, MergeStrategy};
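
/// Configuration for a CMA-ES search over model-merge parameters.
///
/// One search dimension is allocated per model weight, plus one each for
/// `density` and `drop_rate` when the corresponding flags are set.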
#[derive(Debug, Clone)]
pub struct EvolutionaryMergeConfig {
    /// Merge strategy the optimized weights are intended for.
    pub strategy: MergeStrategy,
    /// Evaluation budget handed to CMA-ES.
    pub max_evaluations: usize,
    /// Initial CMA-ES step size.
    pub sigma: f64,
    /// RNG seed, for reproducible searches.
    pub seed: u64,
    /// Number of models being merged; one weight logit is searched per model.
    pub num_models: usize,
    /// Also search over the `density` field of `MergeOptions`.
    pub optimize_density: bool,
    /// Also search over the `drop_rate` field of `MergeOptions`.
    pub optimize_drop_rate: bool,
}
impl Default for EvolutionaryMergeConfig {
fn default() -> Self {
Self {
strategy: MergeStrategy::Weighted,
max_evaluations: 100,
sigma: 0.3,
seed: 42,
num_models: 2,
optimize_density: false,
optimize_drop_rate: false,
}
}
}
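
/// Outcome of the search: the decoded best parameters, the score they
/// achieved, and `MergeOptions` ready for re-running the merge.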
#[derive(Debug, Clone)]
pub struct EvolutionaryMergeResult {
    /// Softmax-normalized per-model weights (non-negative, summing to 1).
    pub weights: Vec<f32>,
    /// Decoded density: the optimized value, or the 0.2 fallback.
    pub density: f32,
    /// Decoded drop rate: the optimized value, or the 0.9 fallback.
    pub drop_rate: f32,
    /// Objective value of the best solution found.
    pub best_score: f64,
    /// Number of objective evaluations actually consumed.
    pub evaluations: usize,
    /// Options for re-running the winning merge outside this search.
    pub merge_options: MergeOptions,
}
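
/// Number of CMA-ES search dimensions implied by the config.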
fn param_dim(config: &EvolutionaryMergeConfig) -> usize {
    let mut dim = config.num_models;
    if config.optimize_density {
dim += 1;
}
if config.optimize_drop_rate {
dim += 1;
}
dim
}
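
/// Decode a raw CMA-ES parameter vector into `(weights, density, drop_rate)`.
///
/// Layout: the first `num_models` entries are weight logits, normalized with a
/// softmax; they are optionally followed by a density logit and a drop-rate
/// logit, each squashed through a sigmoid. Parameters that are not being
/// optimized fall back to fixed defaults (density 0.2, drop rate 0.9).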
pub fn decode_params(params: &[f64], config: &EvolutionaryMergeConfig) -> (Vec<f32>, f32, f32) {
    let raw_weights = &params[..config.num_models];
let weights = softmax_normalize(raw_weights);
let mut idx = config.num_models;
let density = if config.optimize_density {
let v = sigmoid(params[idx]);
idx += 1;
v as f32
} else {
        0.2
    };
let drop_rate = if config.optimize_drop_rate {
sigmoid(params[idx]) as f32
} else {
        0.9
    };
(weights, density, drop_rate)
}
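
/// Numerically stable softmax: subtracting the maximum before exponentiating
/// avoids overflow without changing the result, since
/// `exp(x_i - m) / Σ exp(x_j - m) = exp(x_i) / Σ exp(x_j)`.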
pub fn softmax_normalize(raw: &[f64]) -> Vec<f32> {
let max_val = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exp_vals: Vec<f64> = raw.iter().map(|&x| (x - max_val).exp()).collect();
let sum: f64 = exp_vals.iter().sum();
exp_vals.iter().map(|&v| (v / sum) as f32).collect()
}
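
/// Logistic sigmoid, mapping any real logit into (0, 1).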
fn sigmoid(x: f64) -> f64 {
1.0 / (1.0 + (-x).exp())
}
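
/// Assemble `MergeOptions` from the config and the decoded search parameters,
/// leaving the remaining fields at fixed defaults.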
pub fn build_merge_options(
config: &EvolutionaryMergeConfig,
weights: Vec<f32>,
density: f32,
drop_rate: f32,
) -> MergeOptions {
MergeOptions {
strategy: config.strategy,
weights: Some(weights),
base_model: None,
drop_rate,
density,
seed: config.seed,
scales: None,
outlier_k: 3.0,
layer_ranges: None,
}
}
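
/// Euclidean norm, accumulated in f64 to limit rounding error on long tensors.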
fn vector_norm_f64(v: &[f32]) -> f64 {
v.iter()
.map(|&x| f64::from(x) * f64::from(x))
.sum::<f64>()
.sqrt()
}
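
/// Merge every tensor listed in the first model, which serves as the
/// reference for names and shapes. Panics (via indexing) if a later model
/// lacks a tensor that the first one has, or if `models` is empty.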
pub fn merge_tensors_in_memory(
models: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
weights: &[f32],
strategy: MergeStrategy,
) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
let reference = &models[0];
let mut merged = BTreeMap::new();
for (name, (_, shape)) in reference {
let merged_data = merge_single_tensor(models, name, weights, strategy);
merged.insert(name.clone(), (merged_data, shape.clone()));
}
merged
}
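
/// Merge a single named tensor: SLERP when exactly two models are merged
/// under the `Slerp` strategy, otherwise a weighted element-wise sum.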
fn merge_single_tensor(
models: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
name: &str,
weights: &[f32],
strategy: MergeStrategy,
) -> Vec<f32> {
match strategy {
        MergeStrategy::Slerp if models.len() == 2 => {
            // With two softmax weights, the second one doubles as the
            // interpolation parameter t in [0, 1].
            let (a, _) = &models[0][name];
            let (b, _) = &models[1][name];
            slerp_vectors(a, b, weights[1])
        }
_ => {
let data_len = models[0][name].0.len();
let mut result = vec![0.0f32; data_len];
for (model_idx, model) in models.iter().enumerate() {
let (data, _) = &model[name];
let w = weights[model_idx];
for (i, &val) in data.iter().enumerate() {
result[i] += val * w;
}
}
result
}
}
}
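
/// Spherical linear interpolation between two flattened tensors:
/// `slerp(a, b; t) = sin((1 - t)ω)/sin(ω) · a + sin(tω)/sin(ω) · b`,
/// where `cos(ω) = a·b / (|a| |b|)`. Falls back to plain linear interpolation
/// when either vector is near zero or the angle ω is tiny, where the
/// spherical form divides by a near-zero `sin(ω)`.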
fn slerp_vectors(a: &[f32], b: &[f32], t: f32) -> Vec<f32> {
let norm_a = vector_norm_f64(a);
let norm_b = vector_norm_f64(b);
if norm_a < 1e-12 || norm_b < 1e-12 {
return lerp_vectors(a, b, t);
}
let dot: f64 = a
.iter()
.zip(b.iter())
.map(|(&x, &y)| f64::from(x) * f64::from(y))
.sum();
let cos_omega = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
let omega = cos_omega.acos();
if omega.abs() < 1e-6 {
return lerp_vectors(a, b, t);
}
let sin_omega = omega.sin();
let t64 = f64::from(t);
let coeff_a = ((1.0 - t64) * omega).sin() / sin_omega;
let coeff_b = (t64 * omega).sin() / sin_omega;
a.iter()
.zip(b.iter())
.map(|(&x, &y)| (coeff_a * f64::from(x) + coeff_b * f64::from(y)) as f32)
.collect()
}
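
/// Plain linear interpolation, the numerical fallback for SLERP.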
fn lerp_vectors(a: &[f32], b: &[f32], t: f32) -> Vec<f32> {
a.iter()
.zip(b.iter())
.map(|(&x, &y)| x * (1.0 - t) + y * t)
.collect()
}
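
/// Run a CMA-ES search over merge parameters, scoring each candidate by
/// merging the models in memory and handing the result to `objective_fn`.
///
/// A minimal usage sketch; the toy tensors and the squared-error objective
/// (assuming lower scores are better) are illustrative, not part of this API:
///
/// ```ignore
/// use std::collections::BTreeMap;
///
/// // Two single-tensor "models" of shape [2].
/// let mut a = BTreeMap::new();
/// a.insert("w".to_string(), (vec![1.0f32, 0.0], vec![2]));
/// let mut b = BTreeMap::new();
/// b.insert("w".to_string(), (vec![0.0f32, 1.0], vec![2]));
///
/// let config = EvolutionaryMergeConfig { max_evaluations: 50, ..Default::default() };
/// let result = evolutionary_merge(&[a, b], &config, |merged| {
///     // Hypothetical objective: squared distance to a target tensor.
///     let (data, _) = &merged["w"];
///     data.iter()
///         .zip([0.5f32, 0.5])
///         .map(|(&x, t)| f64::from((x - t) * (x - t)))
///         .sum()
/// });
/// assert_eq!(result.weights.len(), 2);
/// ```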
pub fn evolutionary_merge<F>(
models: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
config: &EvolutionaryMergeConfig,
objective_fn: F,
) -> EvolutionaryMergeResult
where
F: Fn(&BTreeMap<String, (Vec<f32>, Vec<usize>)>) -> f64,
{
    let dim = param_dim(config);
    // Bound each logit to [-3, 3]: wide enough for the decoded sigmoid values
    // to span roughly 0.05..0.95 while staying out of the flat tails.
    let space = SearchSpace::continuous(dim, -3.0, 3.0);
    let mut cma = CmaEs::new(dim)
        .with_seed(config.seed)
        .with_sigma(config.sigma);
    let objective = |params: &[f64]| -> f64 {
        // Only the weights influence the in-memory merge; density and
        // drop_rate still occupy search dimensions when enabled, and their
        // decoded values are surfaced through `merge_options` below.
        let (weights, _density, _drop_rate) = decode_params(params, config);
        let merged = merge_tensors_in_memory(models, &weights, config.strategy);
        objective_fn(&merged)
    };
let result = cma.optimize(
&objective,
&space,
Budget::Evaluations(config.max_evaluations),
);
let (weights, density, drop_rate) = decode_params(&result.solution, config);
let merge_options = build_merge_options(config, weights.clone(), density, drop_rate);
EvolutionaryMergeResult {
weights,
density,
drop_rate,
best_score: result.objective_value,
evaluations: result.evaluations,
merge_options,
}
}
#[cfg(test)]
#[path = "evolutionary_merge_tests.rs"]
mod tests;