use super::types::{AdmmConfig, AdmmResult, ConsensusNode};
use crate::error::{OptimizeError, OptimizeResult};
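/// Dense matrix-vector product `A * x`, with `A` stored row-major as `Vec<Vec<f64>>`.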
fn mat_vec(a: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
a.iter()
        .map(|row| row.iter().zip(x.iter()).map(|(aij, xj)| aij * xj).sum())
.collect()
}
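/// Euclidean inner product of two equal-length slices.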
fn dot(a: &[f64], b: &[f64]) -> f64 {
a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
}
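/// Euclidean (l2) norm of a vector.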
fn norm2(v: &[f64]) -> f64 {
v.iter().map(|x| x * x).sum::<f64>().sqrt()
}
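/// Elementwise soft-thresholding, the proximal operator of `kappa * ||v||_1`:
/// `S_kappa(v_i) = sign(v_i) * max(|v_i| - kappa, 0)`.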
fn soft_threshold(v: &[f64], kappa: f64) -> Vec<f64> {
v.iter()
.map(|&vi| {
if vi > kappa {
vi - kappa
} else if vi < -kappa {
vi + kappa
} else {
0.0
}
})
.collect()
}
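/// Solves the regularized normal equations `(A^T A + rho * I) x = rhs` with
/// the conjugate gradient method. The system matrix is symmetric positive
/// definite for any `rho > 0`, so CG is applicable.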
fn solve_normal_equations(
a: &[Vec<f64>],
rhs: &[f64],
rho: f64,
tol: f64,
max_iter: usize,
) -> Vec<f64> {
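    // Form M = A^T A + rho*I explicitly (n x n); fine for the modest n this
    // solver targets, though it costs O(m * n^2) to assemble.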
let n = rhs.len();
let mut m = vec![vec![0.0; n]; n];
for i in 0..n {
for j in 0..n {
let mut sum = 0.0;
for row in a.iter() {
sum += row[i] * row[j];
}
m[i][j] = sum;
}
m[i][i] += rho;
}
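    // Conjugate gradient on M x = rhs, starting from x = 0 (so r_0 = rhs).
    // The `.max(1e-300)` guards below avoid division by zero on exact
    // convergence.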
let mut x = vec![0.0; n];
let mut r = rhs.to_vec();
let mut p = r.clone();
let mut rsold = dot(&r, &r);
for _ in 0..max_iter {
if rsold.sqrt() < tol {
break;
}
let mp: Vec<f64> = (0..n)
.map(|i| m[i].iter().zip(p.iter()).map(|(mi, pi)| mi * pi).sum())
.collect();
let alpha = rsold / dot(&p, &mp).max(1e-300);
for i in 0..n {
x[i] += alpha * p[i];
r[i] -= alpha * mp[i];
}
let rsnew = dot(&r, &r);
let beta = rsnew / rsold.max(1e-300);
for i in 0..n {
p[i] = r[i] + beta * p[i];
}
rsold = rsnew;
}
x
}
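/// Per-agent state for consensus ADMM: each `ConsensusNode` carries a local
/// iterate `local_x` and a scaled dual variable `dual_y`.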
#[derive(Debug)]
pub struct AdmmSolver {
pub workers: Vec<ConsensusNode>,
}
impl AdmmSolver {
pub fn new(n_agents: usize, n_vars: usize) -> Self {
Self {
workers: (0..n_agents).map(|_| ConsensusNode::new(n_vars)).collect(),
}
}
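    /// Builds a solver whose agents resume from the given iterates. Errors if
    /// no iterates are supplied or their lengths differ.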
pub fn warm_start(initial_xs: Vec<Vec<f64>>) -> OptimizeResult<Self> {
if initial_xs.is_empty() {
return Err(OptimizeError::InvalidInput(
"Need at least one agent".into(),
));
}
let n = initial_xs[0].len();
for (i, x) in initial_xs.iter().enumerate() {
if x.len() != n {
return Err(OptimizeError::InvalidInput(format!(
"Agent {} has x length {} but expected {}",
i,
x.len(),
n
)));
}
}
Ok(Self {
workers: initial_xs.into_iter().map(ConsensusNode::warm).collect(),
})
}
pub fn n_agents(&self) -> usize {
self.workers.len()
}
pub fn n_vars(&self) -> usize {
if self.workers.is_empty() {
0
} else {
self.workers[0].local_x.len()
}
}
}
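/// Global-consensus ADMM in scaled form (Boyd et al., "Distributed
/// Optimization and Statistical Learning via the Alternating Direction Method
/// of Multipliers", sec. 7.1). Agent `i` supplies a proximal operator
/// `prox_i(v, rho) = argmin_x f_i(x) + (rho / 2) * ||x - v||^2`, and each
/// iteration performs
///
///   x_i <- prox_i(z - u_i, rho)
///   z   <- mean_i(x_hat_i + u_i),  where x_hat_i = alpha*x_i + (1 - alpha)*z_old
///   u_i <- u_i + x_hat_i - z
///
/// with over-relaxation parameter `alpha`. Returns the consensus variable `z`
/// together with the residual histories.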
pub fn consensus_admm<F>(
proximal_fns: &[F],
n_vars: usize,
config: &AdmmConfig,
) -> OptimizeResult<AdmmResult>
where
F: Fn(&[f64], f64) -> Vec<f64>,
{
let n_agents = proximal_fns.len();
if n_agents == 0 {
return Err(OptimizeError::InvalidInput(
"Need at least one agent".into(),
));
}
if n_vars == 0 {
return Err(OptimizeError::InvalidInput("n_vars must be > 0".into()));
}
    // `config.warm_start` has no effect here: this entry point receives no
    // prior iterate to resume from, so every agent starts from zeros.
    let mut solver = AdmmSolver::new(n_agents, n_vars);
let mut z = vec![0.0_f64; n_vars];
let mut primal_history = Vec::with_capacity(config.max_iter);
let mut dual_history = Vec::with_capacity(config.max_iter);
let mut converged = false;
let mut iterations = 0;
for iter in 0..config.max_iter {
iterations = iter + 1;
let z_old = z.clone();
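        // x-update: each agent evaluates its proximal operator at z - u_i.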
for (i, worker) in solver.workers.iter_mut().enumerate() {
let v: Vec<f64> = z_old
.iter()
.zip(worker.dual_y.iter())
.map(|(zi, ui)| zi - ui)
.collect();
worker.local_x = (proximal_fns[i])(&v, config.rho);
}
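        // z-update: average the over-relaxed iterates x_hat_i plus the scaled
        // duals u_i across all agents.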
let n_inv = 1.0 / n_agents as f64;
let alpha = config.over_relaxation;
for k in 0..n_vars {
let sum: f64 = solver
.workers
.iter()
.map(|w| alpha * w.local_x[k] + (1.0 - alpha) * z_old[k] + w.dual_y[k])
.sum();
z[k] = n_inv * sum;
}
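        // Scaled dual update: u_i += x_hat_i - z.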
for worker in solver.workers.iter_mut() {
for k in 0..n_vars {
let x_hat = alpha * worker.local_x[k] + (1.0 - alpha) * z_old[k];
                worker.dual_y[k] += x_hat - z[k];
}
}
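        // Stopping criterion (Boyd et al., sec. 3.3): the primal residual is
        // ||x_i - z|| stacked over agents; the dual residual is
        // rho * sqrt(N) * ||z - z_old||.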
let mut primal_sq = 0.0_f64;
let mut x_norm_sq = 0.0_f64;
let mut u_norm_sq = 0.0_f64;
for worker in solver.workers.iter() {
for k in 0..n_vars {
let diff = worker.local_x[k] - z[k];
primal_sq += diff * diff;
x_norm_sq += worker.local_x[k] * worker.local_x[k];
u_norm_sq += worker.dual_y[k] * worker.dual_y[k];
}
}
let primal_res = primal_sq.sqrt();
let dual_sq: f64 = z
.iter()
.zip(z_old.iter())
.map(|(zn, zo)| (zn - zo) * (zn - zo))
.sum();
let dual_res = config.rho * (n_agents as f64).sqrt() * dual_sq.sqrt();
primal_history.push(primal_res);
dual_history.push(dual_res);
let n_total = (n_agents * n_vars) as f64;
let eps_pri = n_total.sqrt() * config.abs_tol
+ config.rel_tol * x_norm_sq.sqrt().max((n_agents as f64).sqrt() * norm2(&z));
let eps_dual =
n_total.sqrt() * config.abs_tol + config.rel_tol * config.rho * u_norm_sq.sqrt();
if primal_res < eps_pri && dual_res < eps_dual {
converged = true;
break;
}
}
Ok(AdmmResult {
x: z,
primal_residual: primal_history,
dual_residual: dual_history,
converged,
iterations,
})
}
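/// Lasso via ADMM: minimizes `(1/2) * ||A x - b||^2 + lambda * ||x||_1` by
/// alternating a ridge-regularized least-squares update in `x` with a
/// soft-thresholding update in `z` (Boyd et al., sec. 6.4). The returned
/// iterate is `z`, which carries the exact zeros produced by thresholding.
///
/// A minimal usage sketch (the surrounding error context for `?` is
/// illustrative; adapt to this crate's layout):
///
/// ```ignore
/// let a = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
/// let b = vec![1.0, 0.0, 1.0];
/// let result = solve_lasso_admm(&a, &b, 0.1, &AdmmConfig::default())?;
/// println!("x = {:?} after {} iterations", result.x, result.iterations);
/// ```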
pub fn solve_lasso_admm(
a: &[Vec<f64>],
b: &[f64],
lambda: f64,
config: &AdmmConfig,
) -> OptimizeResult<AdmmResult> {
if a.is_empty() {
return Err(OptimizeError::InvalidInput("A is empty".into()));
}
let m = a.len();
let n = a[0].len();
if b.len() != m {
return Err(OptimizeError::InvalidInput(format!(
"A has {} rows but b has length {}",
m,
b.len()
)));
}
if n == 0 {
return Err(OptimizeError::InvalidInput("n_vars = 0".into()));
}
let rho = config.rho;
    // `config.warm_start` has no effect here: no prior iterate is supplied,
    // so x always starts from zeros.
    let mut x = vec![0.0; n];
let mut z = vec![0.0_f64; n];
let mut u = vec![0.0_f64; n];
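    // Cache A^T b once; the x-update right-hand side reuses it every iteration.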
let mut at_b = vec![0.0_f64; n];
for j in 0..n {
let mut s = 0.0;
for i in 0..m {
s += a[i][j] * b[i];
}
at_b[j] = s;
}
let mut primal_history = Vec::with_capacity(config.max_iter);
let mut dual_history = Vec::with_capacity(config.max_iter);
let mut converged = false;
let mut iterations = 0;
for iter in 0..config.max_iter {
iterations = iter + 1;
let z_old = z.clone();
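        // x-update: solve (A^T A + rho*I) x = A^T b + rho*(z - u) by CG. The
        // system matrix is re-formed on every call; caching it (or a
        // factorization) would be faster but is kept simple here.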
let mut rhs = at_b.clone();
for k in 0..n {
rhs[k] += rho * (z[k] - u[k]);
}
x = solve_normal_equations(a, &rhs, rho, 1e-10, 200);
let alpha = config.over_relaxation;
let x_hat: Vec<f64> = x
.iter()
.zip(z_old.iter())
.map(|(xi, zi)| alpha * xi + (1.0 - alpha) * zi)
.collect();
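        // z-update: soft-threshold x_hat + u, the proximal step for
        // lambda * ||.||_1 with parameter lambda / rho.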
let v: Vec<f64> = x_hat
.iter()
.zip(u.iter())
.map(|(xhi, ui)| xhi + ui)
.collect();
z = soft_threshold(&v, lambda / rho);
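        // Scaled dual update: u += x_hat - z.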
for k in 0..n {
u[k] += x_hat[k] - z[k];
}
let primal_sq: f64 = x
.iter()
.zip(z.iter())
.map(|(xi, zi)| (xi - zi).powi(2))
.sum();
let primal_res = primal_sq.sqrt();
let dual_sq: f64 = z
.iter()
.zip(z_old.iter())
.map(|(zn, zo)| (zn - zo).powi(2))
.sum();
let dual_res = rho * dual_sq.sqrt();
primal_history.push(primal_res);
dual_history.push(dual_res);
let eps_pri =
(n as f64).sqrt() * config.abs_tol + config.rel_tol * (norm2(&x).max(norm2(&z)));
let eps_dual = (n as f64).sqrt() * config.abs_tol + config.rel_tol * rho * norm2(&u);
if primal_res < eps_pri && dual_res < eps_dual {
converged = true;
break;
}
}
Ok(AdmmResult {
x: z,
primal_residual: primal_history,
dual_residual: dual_history,
converged,
iterations,
})
}
#[cfg(test)]
mod tests {
use super::*;
fn make_lasso_problem() -> (Vec<Vec<f64>>, Vec<f64>) {
let x_true = [1.0_f64, 0.0, -1.0, 0.0];
let a: Vec<Vec<f64>> = (0..10)
.map(|i| {
let t = i as f64 / 9.0;
vec![1.0, t, t * t, t * t * t]
})
.collect();
let b: Vec<f64> = a
.iter()
.map(|row| row.iter().zip(x_true.iter()).map(|(r, x)| r * x).sum())
.collect();
(a, b)
}
#[test]
fn test_admm_lasso_convergence() {
let (a, b) = make_lasso_problem();
let config = AdmmConfig::default();
let result = solve_lasso_admm(&a, &b, 0.01, &config).expect("Lasso ADMM failed");
assert!(result.converged, "Should converge");
let n = result.primal_residual.len();
assert!(n > 1);
assert!(
result.primal_residual[n - 1] <= result.primal_residual[0] + 1e-10,
"Primal residual should not increase"
);
}
#[test]
fn test_admm_lasso_sparsity() {
let (a, b) = make_lasso_problem();
        let config = AdmmConfig {
            max_iter: 2000,
            rho: 0.5,
            ..AdmmConfig::default()
        };
let result = solve_lasso_admm(&a, &b, 0.5, &config).expect("Lasso failed");
let n_nonzero = result.x.iter().filter(|&&v| v.abs() > 0.05).count();
assert!(
n_nonzero <= 2,
"Expected sparse solution, got {} nonzero components: {:?}",
n_nonzero,
result.x
);
}
#[test]
fn test_admm_primal_dual_residuals_tracked() {
let (a, b) = make_lasso_problem();
let config = AdmmConfig::default();
let result = solve_lasso_admm(&a, &b, 0.01, &config).expect("Lasso ADMM failed");
assert_eq!(
result.primal_residual.len(),
result.dual_residual.len(),
"Residual histories must have equal length"
);
assert!(!result.primal_residual.is_empty());
}
#[test]
fn test_admm_converged_flag() {
let (a, b) = make_lasso_problem();
        let config = AdmmConfig {
            max_iter: 5000,
            abs_tol: 1e-6,
            rel_tol: 1e-4,
            ..AdmmConfig::default()
        };
let result = solve_lasso_admm(&a, &b, 0.01, &config).expect("Failed");
assert!(result.converged, "Should converge with generous budget");
}
#[test]
    // `warm_start` currently has no effect on `solve_lasso_admm` (no prior
    // iterate is supplied), so this only checks that the flag is accepted.
    fn test_admm_warm_start_flag_accepted() {
let (a, b) = make_lasso_problem();
let config = AdmmConfig {
rho: 1.0,
max_iter: 500,
abs_tol: 1e-5,
rel_tol: 1e-3,
warm_start: false,
over_relaxation: 1.0,
};
let cold = solve_lasso_admm(&a, &b, 0.05, &config).expect("Cold ADMM failed");
let warm_config = AdmmConfig {
warm_start: true,
..config.clone()
};
let warm = solve_lasso_admm(&a, &b, 0.05, &warm_config).expect("Warm ADMM failed");
assert!(cold.converged || warm.converged);
}
#[test]
fn test_consensus_admm_mean() {
let centers = vec![vec![1.0_f64, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
let n_vars = 2;
let config = AdmmConfig {
rho: 1.0,
max_iter: 500,
abs_tol: 1e-6,
rel_tol: 1e-4,
warm_start: false,
over_relaxation: 1.0,
};
let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
.iter()
.map(|c| {
let ci = c.clone();
let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
Box::new(move |v: &[f64], rho: f64| {
v.iter()
.zip(ci.iter())
.map(|(vi, ci)| (rho * vi + ci) / (rho + 1.0))
.collect()
});
f
})
.collect();
let result = consensus_admm(&prox_fns, n_vars, &config).expect("Consensus ADMM failed");
assert!(result.converged, "Should converge");
assert!(
(result.x[0] - 3.0).abs() < 0.01,
"x[0] = {:.4} (expected 3.0)",
result.x[0]
);
assert!(
(result.x[1] - 4.0).abs() < 0.01,
"x[1] = {:.4} (expected 4.0)",
result.x[1]
);
}
#[test]
fn test_admm_config_default() {
let cfg = AdmmConfig::default();
assert_eq!(cfg.max_iter, 1000);
assert!((cfg.abs_tol - 1e-4).abs() < 1e-15);
}
#[test]
fn test_lasso_zero_lambda() {
let (a, b) = make_lasso_problem();
        let config = AdmmConfig {
            max_iter: 3000,
            abs_tol: 1e-6,
            rel_tol: 1e-4,
            ..AdmmConfig::default()
        };
let result = solve_lasso_admm(&a, &b, 0.0, &config).expect("Lasso λ=0 failed");
let ax = mat_vec(&a, &result.x);
let res: f64 = ax
.iter()
.zip(b.iter())
.map(|(ai, bi)| (ai - bi).powi(2))
.sum::<f64>()
.sqrt();
assert!(res < 0.1, "Residual ||Az - b|| = {:.4}", res);
}
}