//! DAM merge-coefficient utilities: optimizer configuration and validation,
//! per-tensor coefficient storage, merge and regularization losses, softmax
//! normalization, and a derivative-free simplex search over coefficients.
use std::collections::BTreeMap;
use crate::error::{AprenderError, Result};
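/// Hyperparameters for DAM coefficient optimization: step size, iteration
/// budget, L2 regularization strength, and the seed used to derive the
/// initial simplex perturbations.
///
/// A minimal usage sketch; the `aprender::dam` path is an assumption, adjust
/// it to wherever this module lives in the crate:
///
/// ```ignore
/// use aprender::dam::DamConfig;
///
/// let config = DamConfig { num_iterations: 200, ..DamConfig::default() };
/// config.validate().expect("positive, finite hyperparameters are valid");
/// ```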
#[derive(Debug, Clone)]
pub struct DamConfig {
pub learning_rate: f64,
pub num_iterations: usize,
pub regularization: f64,
pub seed: u64,
}
impl Default for DamConfig {
fn default() -> Self {
Self {
learning_rate: 0.01,
num_iterations: 100,
regularization: 0.01,
seed: 42,
}
}
}
impl DamConfig {
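    /// Checks that all hyperparameters are in range: `learning_rate` must be
    /// positive and finite, `num_iterations` non-zero, and `regularization`
    /// non-negative and finite.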
pub fn validate(&self) -> Result<()> {
if self.learning_rate <= 0.0 || !self.learning_rate.is_finite() {
return Err(AprenderError::FormatError {
message: format!(
"learning_rate must be positive finite, got {}",
self.learning_rate
),
});
}
if self.num_iterations == 0 {
return Err(AprenderError::FormatError {
message: "num_iterations must be > 0".to_string(),
});
}
if self.regularization < 0.0 || !self.regularization.is_finite() {
return Err(AprenderError::FormatError {
message: format!(
"regularization must be non-negative finite, got {}",
self.regularization
),
});
}
Ok(())
}
}
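/// Per-tensor merge coefficients, keyed by tensor name.
///
/// Coefficients are stored unnormalized; [`DamCoefficients::normalized_weights`]
/// applies a softmax so each tensor's weights sum to 1. A small sketch, assuming
/// the module is exposed as `aprender::dam`:
///
/// ```ignore
/// use aprender::dam::DamCoefficients;
///
/// let names = vec!["layer.0.weight".to_string()];
/// let coeffs = DamCoefficients::uniform(&names, 3);
/// // Zero-initialized coefficients softmax to equal weights of 1/3 each.
/// let w = coeffs.normalized_weights("layer.0.weight").unwrap();
/// assert!(w.iter().all(|&x| (x - 1.0 / 3.0).abs() < 1e-12));
/// ```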
#[derive(Debug, Clone)]
pub struct DamCoefficients {
pub per_tensor: BTreeMap<String, Vec<f64>>,
}
impl DamCoefficients {
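    /// Builds zero-initialized coefficients for every tensor name; after
    /// softmax normalization these correspond to uniform merge weights.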
#[must_use]
pub fn uniform(tensor_names: &[String], num_models: usize) -> Self {
        let init = vec![0.0; num_models];
        let per_tensor = tensor_names
.iter()
.map(|name| (name.clone(), init.clone()))
.collect();
Self { per_tensor }
}
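    /// Returns the softmax-normalized weights for `tensor_name`, or `None` if
    /// the tensor is unknown.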
#[must_use]
pub fn normalized_weights(&self, tensor_name: &str) -> Option<Vec<f64>> {
self.per_tensor.get(tensor_name).map(|w| softmax(w))
}
}
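/// Loss terms for DAM coefficient optimization: a mean-squared merge loss and
/// an L2 regularization penalty on the coefficients.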
#[derive(Debug, Clone)]
pub struct DamLoss {
config: DamConfig,
}
impl DamLoss {
#[must_use]
pub fn new(config: DamConfig) -> Self {
Self { config }
}
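    /// Mean squared error between a merged tensor and its target, computed over
    /// the overlapping prefix of the two slices; empty input yields 0.0.
    ///
    /// For example, `merged = [1.0, 2.0]` against `target = [1.0, 4.0]` gives
    /// (0^2 + 2^2) / 2 = 2.0.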
#[must_use]
pub fn compute_merge_loss(merged: &[f64], target: &[f64]) -> f64 {
if merged.is_empty() || target.is_empty() {
return 0.0;
}
let n = merged.len().min(target.len());
let sum_sq: f64 = merged[..n]
.iter()
.zip(&target[..n])
.map(|(m, t)| {
let d = m - t;
d * d
})
.sum();
sum_sq / n as f64
}
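    /// Mean of the squared coefficients (an L2 penalty averaged over length);
    /// empty input yields 0.0.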
#[must_use]
pub fn compute_regularization(coefficients: &[f64]) -> f64 {
if coefficients.is_empty() {
return 0.0;
}
let sum_sq: f64 = coefficients.iter().map(|c| c * c).sum();
sum_sq / coefficients.len() as f64
}
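    /// In-place gradient descent update `c[i] -= lr * g[i]`, applied over the
    /// overlapping prefix of the two slices.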
pub fn gradient_step(coefficients: &mut [f64], gradients: &[f64], lr: f64) {
let n = coefficients.len().min(gradients.len());
for i in 0..n {
coefficients[i] -= lr * gradients[i];
}
}
#[must_use]
pub fn config(&self) -> &DamConfig {
&self.config
}
}
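/// Numerically stable softmax: the maximum is subtracted before exponentiating,
/// and the result falls back to a uniform distribution if every exponential
/// underflows to zero.
///
/// A sketch, assuming the module is exposed as `aprender::dam`:
///
/// ```ignore
/// let w = aprender::dam::softmax(&[0.0, 0.0, 0.0]);
/// assert_eq!(w, vec![1.0 / 3.0; 3]);
/// ```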
#[must_use]
pub fn softmax(x: &[f64]) -> Vec<f64> {
if x.is_empty() {
return vec![];
}
let max_val = x.iter().copied().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = x.iter().map(|&xi| (xi - max_val).exp()).collect();
let sum: f64 = exps.iter().sum();
if sum == 0.0 {
return vec![1.0 / x.len() as f64; x.len()];
}
exps.iter().map(|&e| e / sum).collect()
}
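/// Derivative-free search for merge coefficients using a Nelder-Mead style
/// simplex (reflection, expansion, contraction, shrink), minimizing
/// `loss_fn(coeffs) + config.regularization * mean(coeffs[i]^2)`.
///
/// Returns the raw coefficients of the best vertex found; pass them through
/// [`softmax`] to obtain normalized merge weights. A sketch with a toy loss,
/// assuming the module is exposed as `aprender::dam`:
///
/// ```ignore
/// use aprender::dam::{optimize_coefficients, softmax, DamConfig};
///
/// // Quadratic loss minimized near coefficients (1.0, -1.0).
/// let loss = |c: &[f64]| (c[0] - 1.0).powi(2) + (c[1] + 1.0).powi(2);
/// let coeffs = optimize_coefficients(2, loss, &DamConfig::default());
/// let weights = softmax(&coeffs);
/// assert!((weights.iter().sum::<f64>() - 1.0).abs() < 1e-12);
/// ```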
pub fn optimize_coefficients(
num_models: usize,
loss_fn: impl Fn(&[f64]) -> f64,
config: &DamConfig,
) -> Vec<f64> {
if num_models == 0 {
return vec![];
}
if num_models == 1 {
return vec![1.0];
}
let n = num_models;
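    // Initial simplex: the origin plus one vertex per coefficient, each
    // perturbed along its own axis with a sign derived deterministically from
    // the seed via LCG-style mixing constants.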
let mut simplex: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
simplex.push(vec![0.0; n]);
let perturbation = 0.5;
for i in 0..n {
let mut vertex = vec![0.0; n];
let state = config
.seed
.wrapping_add(i as u64)
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let sign = if (state >> 33).is_multiple_of(2) {
1.0
} else {
-1.0
};
vertex[i] = sign * perturbation;
simplex.push(vertex);
}
let total_loss = |coeffs: &[f64]| -> f64 {
loss_fn(coeffs) + config.regularization * DamLoss::compute_regularization(coeffs)
};
let mut losses: Vec<f64> = simplex.iter().map(|v| total_loss(v)).collect();
    // Standard Nelder-Mead coefficients: reflection (alpha), expansion (gamma),
    // contraction (rho), and shrink (sigma).
    let alpha = 1.0;
    let gamma = 2.0;
    let rho = 0.5;
    let sigma = 0.5;
for _iter in 0..config.num_iterations {
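        // Rank vertices by total loss: indices[0] is the best, indices[n] the worst.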
let mut indices: Vec<usize> = (0..=n).collect();
indices.sort_by(|&a, &b| {
losses[a]
.partial_cmp(&losses[b])
.unwrap_or(std::cmp::Ordering::Equal)
});
let best_idx = indices[0];
let worst_idx = indices[n];
let second_worst_idx = indices[n - 1];
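        // Converged: best and worst losses are effectively equal.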
let spread = losses[worst_idx] - losses[best_idx];
if spread.abs() < 1e-12 {
break;
}
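        // Centroid of every vertex except the worst.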
let mut centroid = vec![0.0; n];
for &idx in &indices[..n] {
for j in 0..n {
centroid[j] += simplex[idx][j];
}
}
for c in &mut centroid {
*c /= n as f64;
}
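        // Reflection: mirror the worst vertex through the centroid.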
let reflected: Vec<f64> = centroid
.iter()
.zip(&simplex[worst_idx])
.map(|(&c, &w)| c + alpha * (c - w))
.collect();
let reflected_loss = total_loss(&reflected);
if reflected_loss < losses[second_worst_idx] && reflected_loss >= losses[best_idx] {
simplex[worst_idx] = reflected;
losses[worst_idx] = reflected_loss;
continue;
}
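        // Expansion: the reflection beat the best vertex, so try stepping further.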
if reflected_loss < losses[best_idx] {
let expanded: Vec<f64> = centroid
.iter()
.zip(&reflected)
.map(|(&c, &r)| c + gamma * (r - c))
.collect();
let expanded_loss = total_loss(&expanded);
if expanded_loss < reflected_loss {
simplex[worst_idx] = expanded;
losses[worst_idx] = expanded_loss;
} else {
simplex[worst_idx] = reflected;
losses[worst_idx] = reflected_loss;
}
continue;
}
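        // Contraction: pull the worst vertex toward the centroid.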
let contracted: Vec<f64> = centroid
.iter()
.zip(&simplex[worst_idx])
.map(|(&c, &w)| c + rho * (w - c))
.collect();
let contracted_loss = total_loss(&contracted);
if contracted_loss < losses[worst_idx] {
simplex[worst_idx] = contracted;
losses[worst_idx] = contracted_loss;
continue;
}
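        // Shrink: contract every other vertex toward the current best.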
let best = simplex[best_idx].clone();
for i in 0..=n {
if i == best_idx {
continue;
}
for j in 0..n {
simplex[i][j] = best[j] + sigma * (simplex[i][j] - best[j]);
}
losses[i] = total_loss(&simplex[i]);
}
}
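    // Return the vertex with the lowest total loss.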
let best_idx = losses
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0);
simplex[best_idx].clone()
}
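/// Summary of a coefficient optimization run: the final loss, the number of
/// iterations, the raw coefficients, and whether the search converged.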
#[derive(Debug, Clone)]
pub struct DamReport {
pub final_loss: f64,
pub num_iterations: usize,
pub coefficients: Vec<f64>,
pub converged: bool,
}
impl DamReport {
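    /// Softmax-normalized coefficients, which sum to 1 and can be used directly
    /// as merge weights.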
#[must_use]
pub fn normalized_coefficients(&self) -> Vec<f64> {
softmax(&self.coefficients)
}
}
#[cfg(test)]
#[path = "dam_tests.rs"]
mod tests;