use super::types::{AdmmResult, ExtraConfig, PdmmConfig};
use crate::error::{OptimizeError, OptimizeResult};
/// Euclidean (L2) norm of a slice; returns 0.0 for an empty slice.
fn norm2(v: &[f64]) -> f64 {
    let sum_of_squares = v.iter().fold(0.0_f64, |acc, x| acc + x * x);
    sum_of_squares.sqrt()
}
/// Element-wise sum of two slices; the result is truncated to the shorter input.
fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (ai, bi) in a.iter().zip(b) {
        out.push(ai + bi);
    }
    out
}
/// Element-wise difference `a - b`; the result is truncated to the shorter input.
fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (ai, bi) in a.iter().zip(b) {
        out.push(ai - bi);
    }
    out
}
/// Multiplies every element of `a` by the scalar `s`.
fn vec_scale(a: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len());
    for &ai in a {
        out.push(ai * s);
    }
    out
}
/// Matrix-vector product: one dot product of each row of `w` with `x`.
/// Each dot product is truncated to the shorter of the row and `x`.
fn mat_vec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(w.len());
    for row in w {
        let mut dot = 0.0_f64;
        for (wi, xi) in row.iter().zip(x) {
            dot += wi * xi;
        }
        out.push(dot);
    }
    out
}
fn check_doubly_stochastic(w: &[Vec<f64>], tol: f64) -> bool {
let n = w.len();
for row in w.iter() {
if row.len() != n {
return false;
}
let s: f64 = row.iter().sum();
if (s - 1.0).abs() > tol {
return false;
}
}
for j in 0..n {
let s: f64 = w.iter().map(|row| row[j]).sum();
if (s - 1.0).abs() > tol {
return false;
}
}
true
}
/// Distributed consensus solver based on a PDMM-style (Primal-Dual Method
/// of Multipliers) iteration over a fixed communication graph.
#[derive(Debug)]
pub struct PdmmSolver {
    /// Adjacency matrix of the agent network: entry (i, j) > 0 marks agent j
    /// as a neighbour of agent i. Validated to be square by `new`.
    pub topology: Vec<Vec<f64>>,
}
impl PdmmSolver {
    /// Validates that `topology` is a square adjacency matrix and wraps it.
    ///
    /// # Errors
    /// `OptimizeError::InvalidInput` when any row's length differs from the
    /// number of rows.
    pub fn new(topology: Vec<Vec<f64>>) -> OptimizeResult<Self> {
        let n = topology.len();
        for (i, row) in topology.iter().enumerate() {
            if row.len() != n {
                return Err(OptimizeError::InvalidInput(format!(
                    "Topology row {} has length {} but expected {}",
                    i,
                    row.len(),
                    n
                )));
            }
        }
        Ok(Self { topology })
    }

    /// Runs the synchronous PDMM-style consensus iteration.
    ///
    /// `local_fns[i]` is agent i's proximal operator: it is called as
    /// `local_fns[i](v, rho_eff)` and is expected to return agent i's local
    /// minimiser given the prox centre `v` and penalty `rho_eff` (the exact
    /// objective scaling is the caller's contract; this solver only forwards
    /// the two arguments).
    ///
    /// The returned `AdmmResult.x` is the network average of the final
    /// per-agent iterates, with per-iteration primal/dual residual histories.
    ///
    /// # Errors
    /// `OptimizeError::InvalidInput` when `local_fns.len()` does not match
    /// the number of agents, or when `n_vars == 0`.
    pub fn solve<F>(
        &self,
        local_fns: &[F],
        n_vars: usize,
        config: &PdmmConfig,
    ) -> OptimizeResult<AdmmResult>
    where
        F: Fn(&[f64], f64) -> Vec<f64>,
    {
        let n_agents = self.topology.len();
        if local_fns.len() != n_agents {
            return Err(OptimizeError::InvalidInput(format!(
                "Expected {} local functions but got {}",
                n_agents,
                local_fns.len()
            )));
        }
        if n_vars == 0 {
            return Err(OptimizeError::InvalidInput("n_vars must be > 0".into()));
        }
        let rho = config.stepsize;
        // x[i]: agent i's primal iterate. lam[i][j]: dual variable held by
        // agent i for the directed edge (i, j). Everything starts at zero.
        let mut x: Vec<Vec<f64>> = (0..n_agents).map(|_| vec![0.0; n_vars]).collect();
        let mut lam: Vec<Vec<Vec<f64>>> = (0..n_agents)
            .map(|_| (0..n_agents).map(|_| vec![0.0_f64; n_vars]).collect())
            .collect();
        let mut primal_history = Vec::with_capacity(config.max_iter);
        let mut dual_history = Vec::with_capacity(config.max_iter);
        let mut converged = false;
        let mut iterations = 0;
        for iter in 0..config.max_iter {
            iterations = iter + 1;
            let x_old = x.clone();
            // Primal update: each agent aggregates its edge duals and the
            // neighbours' previous iterates, then applies its prox with a
            // penalty scaled by its degree.
            for i in 0..n_agents {
                let mut neighbours = 0usize;
                let mut agg = vec![0.0_f64; n_vars];
                for j in 0..n_agents {
                    if self.topology[i][j] > 0.0 {
                        neighbours += 1;
                        for k in 0..n_vars {
                            agg[k] += lam[i][j][k] - rho * x_old[j][k];
                        }
                    }
                }
                // max(1) guards the division below for isolated agents.
                let rho_eff = rho * (neighbours.max(1) as f64);
                let prox_arg: Vec<f64> = agg.iter().map(|a| -a / rho_eff).collect();
                x[i] = (local_fns[i])(&prox_arg, rho_eff);
            }
            // Dual ascent on each edge's consensus constraint x_i = x_j,
            // using the freshly updated primal iterates.
            for i in 0..n_agents {
                for j in 0..n_agents {
                    if self.topology[i][j] > 0.0 {
                        for k in 0..n_vars {
                            lam[i][j][k] += rho * (x[i][k] - x[j][k]);
                        }
                    }
                }
            }
            // Residuals: primal = norm of all edge disagreements; dual =
            // rho times the norm of the change in the primal iterates.
            let mut primal_sq = 0.0_f64;
            let mut dual_sq = 0.0_f64;
            for i in 0..n_agents {
                for j in 0..n_agents {
                    if self.topology[i][j] > 0.0 {
                        for k in 0..n_vars {
                            primal_sq += (x[i][k] - x[j][k]).powi(2);
                        }
                    }
                }
                for k in 0..n_vars {
                    dual_sq += (x[i][k] - x_old[i][k]).powi(2);
                }
            }
            let primal_res = primal_sq.sqrt();
            let dual_res = rho * dual_sq.sqrt();
            primal_history.push(primal_res);
            dual_history.push(dual_res);
            // NOTE(review): only the primal residual gates convergence here;
            // the dual residual is recorded but never tested — confirm this
            // is intended.
            if primal_res < config.tol {
                converged = true;
                break;
            }
        }
        // Report the network average of the final per-agent iterates.
        let mut x_consensus = vec![0.0_f64; n_vars];
        let scale = 1.0 / n_agents as f64;
        for xi in x.iter() {
            for k in 0..n_vars {
                x_consensus[k] += scale * xi[k];
            }
        }
        Ok(AdmmResult {
            x: x_consensus,
            primal_residual: primal_history,
            dual_residual: dual_history,
            converged,
            iterations,
        })
    }
}
/// Decentralized solver implementing the EXTRA gradient-tracking scheme over
/// a doubly stochastic mixing matrix.
#[derive(Debug)]
pub struct ExtraSolver {
    /// Doubly stochastic mixing matrix W (validated by `new`).
    pub w: Vec<Vec<f64>>,
    /// Companion mixing matrix W~ = (I + W) / 2, precomputed by `new`.
    pub w_tilde: Vec<Vec<f64>>,
}
impl ExtraSolver {
    /// Builds an EXTRA solver from a doubly stochastic mixing matrix `w`,
    /// precomputing the companion matrix W~ = (I + W) / 2.
    ///
    /// # Errors
    /// `OptimizeError::InvalidInput` when `w` is not square and doubly
    /// stochastic to within 1e-6.
    pub fn new(w: Vec<Vec<f64>>) -> OptimizeResult<Self> {
        let n = w.len();
        if !check_doubly_stochastic(&w, 1e-6) {
            return Err(OptimizeError::InvalidInput(
                "W must be doubly stochastic".into(),
            ));
        }
        // W~ = (I + W) / 2, built entry by entry.
        let w_tilde: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                (0..n)
                    .map(|j| {
                        let eye = if i == j { 1.0 } else { 0.0 };
                        (eye + w[i][j]) / 2.0
                    })
                    .collect()
            })
            .collect();
        Ok(Self { w, w_tilde })
    }

    /// Runs the EXTRA recursion
    ///   x^{k+1} = W~ x^k + x^k - W~ x^{k-1} - alpha * (g(x^k) - g(x^{k-1}))
    /// after the bootstrap step x^1 = W x^0 - alpha * g(x^0), with x^0 = 0.
    ///
    /// `grad_fns[i]` returns the gradient of agent i's local objective at
    /// the given point.
    ///
    /// The returned `AdmmResult.x` is the network average of the final
    /// iterates; `primal_residual` records the per-iteration consensus
    /// residual and `dual_residual` the per-iteration iterate change.
    ///
    /// # Errors
    /// `OptimizeError::InvalidInput` when `grad_fns.len()` does not match
    /// the number of agents, or when `n_vars == 0`.
    pub fn solve<F>(
        &self,
        grad_fns: &[F],
        n_vars: usize,
        config: &ExtraConfig,
    ) -> OptimizeResult<AdmmResult>
    where
        F: Fn(&[f64]) -> Vec<f64>,
    {
        let n_agents = self.w.len();
        if grad_fns.len() != n_agents {
            return Err(OptimizeError::InvalidInput(format!(
                "Expected {} gradient functions but got {}",
                n_agents,
                grad_fns.len()
            )));
        }
        if n_vars == 0 {
            return Err(OptimizeError::InvalidInput("n_vars must be > 0".into()));
        }
        let alpha = config.alpha;
        // x^0 = 0 for every agent.
        let mut x_curr: Vec<Vec<f64>> = (0..n_agents).map(|_| vec![0.0; n_vars]).collect();
        let grad_curr: Vec<Vec<f64>> = (0..n_agents).map(|i| (grad_fns[i])(&x_curr[i])).collect();
        // Bootstrap step: x^1 = W x^0 - alpha * g(x^0).
        let x_next: Vec<Vec<f64>> = (0..n_agents)
            .map(|i| {
                let wx_i: Vec<f64> = (0..n_vars)
                    .map(|k| {
                        (0..n_agents)
                            .map(|j| self.w[i][j] * x_curr[j][k])
                            .sum::<f64>()
                    })
                    .collect();
                wx_i.iter()
                    .zip(grad_curr[i].iter())
                    .map(|(w, g)| w - alpha * g)
                    .collect()
            })
            .collect();
        // Shift the two-step window: (x_prev, x_curr) = (x^0, x^1).
        // `x_curr` is deliberately shadowed here.
        let mut x_prev = x_curr.clone();
        let mut x_curr = x_next;
        let mut grad_prev = grad_curr;
        let mut primal_history = Vec::with_capacity(config.max_iter);
        let mut dual_history = Vec::with_capacity(config.max_iter);
        let mut converged = false;
        // The bootstrap step above counts as iteration 1.
        let mut iterations = 1;
        for iter in 1..config.max_iter {
            iterations = iter + 1;
            let grad_curr: Vec<Vec<f64>> =
                (0..n_agents).map(|i| (grad_fns[i])(&x_curr[i])).collect();
            // W~ x^k.
            let w_tilde_x_curr: Vec<Vec<f64>> = (0..n_agents)
                .map(|i| {
                    (0..n_vars)
                        .map(|k| {
                            (0..n_agents)
                                .map(|j| self.w_tilde[i][j] * x_curr[j][k])
                                .sum::<f64>()
                        })
                        .collect()
                })
                .collect();
            // W~ x^{k-1}.
            let w_tilde_x_prev: Vec<Vec<f64>> = (0..n_agents)
                .map(|i| {
                    (0..n_vars)
                        .map(|k| {
                            (0..n_agents)
                                .map(|j| self.w_tilde[i][j] * x_prev[j][k])
                                .sum::<f64>()
                        })
                        .collect()
                })
                .collect();
            // Main EXTRA update for every agent and coordinate.
            let x_new: Vec<Vec<f64>> = (0..n_agents)
                .map(|i| {
                    (0..n_vars)
                        .map(|k| {
                            w_tilde_x_curr[i][k] + x_curr[i][k]
                                - w_tilde_x_prev[i][k]
                                - alpha * (grad_curr[i][k] - grad_prev[i][k])
                        })
                        .collect()
                })
                .collect();
            // Consensus residual: max over agents of the distance to the
            // network mean of the new iterates.
            let x_bar: Vec<f64> = (0..n_vars)
                .map(|k| x_new.iter().map(|xi| xi[k]).sum::<f64>() / n_agents as f64)
                .collect();
            let cons_res: f64 = x_new
                .iter()
                .map(|xi| {
                    xi.iter()
                        .zip(x_bar.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f64>()
                        .sqrt()
                })
                .fold(0.0_f64, f64::max);
            // Total iterate change across all agents (stacked 2-norm).
            let dx: f64 = x_new
                .iter()
                .zip(x_curr.iter())
                .map(|(xn, xc)| {
                    xn.iter()
                        .zip(xc.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f64>()
                })
                .sum::<f64>()
                .sqrt();
            primal_history.push(cons_res);
            dual_history.push(dx);
            // Advance the window before testing convergence.
            x_prev = x_curr;
            x_curr = x_new;
            grad_prev = grad_curr;
            if cons_res < config.tol && dx < config.tol {
                converged = true;
                break;
            }
        }
        // Final consensus estimate: network average of the last iterates.
        let x_bar: Vec<f64> = (0..n_vars)
            .map(|k| x_curr.iter().map(|xi| xi[k]).sum::<f64>() / n_agents as f64)
            .collect();
        Ok(AdmmResult {
            x: x_bar,
            primal_residual: primal_history,
            dual_residual: dual_history,
            converged,
            iterations,
        })
    }
}
/// Builds the adjacency matrix of an undirected ring with `n` nodes, where
/// each node is connected to its two cyclic neighbours.
///
/// For `n <= 1` the graph has no edges. Previously `n == 1` produced a
/// spurious self-loop (both cyclic neighbours of node 0 are node 0 itself),
/// which made the lone agent its own neighbour in downstream solvers; the
/// self-edge guard below fixes that. Behavior for `n >= 2` is unchanged
/// (for `n == 2` both directions collapse onto the same single edge).
pub fn ring_topology(n: usize) -> Vec<Vec<f64>> {
    let mut adj = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        let next = (i + 1) % n;
        let prev = (i + n - 1) % n;
        // Skip self-loops; only possible when n == 1.
        if next != i {
            adj[i][next] = 1.0;
        }
        if prev != i {
            adj[i][prev] = 1.0;
        }
    }
    adj
}
/// Computes Metropolis-Hastings mixing weights for the graph `adj`.
///
/// Off-diagonal: w[i][j] = 1 / (1 + max(deg_i, deg_j)) for every edge
/// (i, j); each diagonal entry absorbs the remainder so the row sums to
/// one. For an undirected (symmetric) `adj` the result is symmetric and
/// doubly stochastic. Degrees count every strictly positive entry of a
/// row, so self-loops in `adj` inflate the degree — callers should pass a
/// loop-free adjacency matrix.
pub fn metropolis_hastings_weights(adj: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = adj.len();
    let degrees: Vec<usize> = adj
        .iter()
        .map(|row| row.iter().filter(|&&v| v > 0.0).count())
        .collect();
    let mut w = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        for j in 0..n {
            if i != j && adj[i][j] > 0.0 {
                w[i][j] = 1.0 / (1.0 + degrees[i].max(degrees[j]) as f64);
            }
        }
        // The diagonal is still 0.0 here, so this sums only the
        // off-diagonal weights assigned above.
        let off_diagonal: f64 = w[i].iter().sum();
        w[i][i] = 1.0 - off_diagonal;
    }
    w
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Metropolis-Hastings mixing weights for an n-node ring.
    fn ring_w(n: usize) -> Vec<Vec<f64>> {
        let adj = ring_topology(n);
        metropolis_hastings_weights(&adj)
    }

    #[test]
    fn test_ring_topology() {
        let adj = ring_topology(4);
        assert_eq!(adj[0][1], 1.0);
        assert_eq!(adj[0][3], 1.0);
        assert_eq!(adj[0][0], 0.0);
        assert_eq!(adj[0][2], 0.0);
    }

    #[test]
    fn test_metropolis_hastings_doubly_stochastic() {
        let w = ring_w(4);
        for row in w.iter() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "Row sum = {}", s);
        }
        let n = w.len();
        for j in 0..n {
            let s: f64 = w.iter().map(|row| row[j]).sum();
            assert!((s - 1.0).abs() < 1e-10, "Col {} sum = {}", j, s);
        }
    }

    #[test]
    fn test_pdmm_converges() {
        // Each agent i holds (1/2)(x - c_i)^2; the consensus optimum is the
        // mean of the centers, i.e. 3.0.
        let n_agents = 3;
        let n_vars = 1;
        let centers = vec![1.0_f64, 3.0, 5.0];
        let topology = vec![
            vec![0.0, 1.0, 1.0],
            vec![1.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
        ];
        let solver = PdmmSolver::new(topology).expect("PDMM creation failed");
        let config = PdmmConfig {
            stepsize: 0.2,
            max_iter: 2000,
            tol: 1e-4,
        };
        // Closed-form prox of (1/2)(x - c)^2 at v with penalty rho.
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| vec![(c + rho * v[0]) / (1.0 + rho)]);
                f
            })
            .collect();
        let result = solver
            .solve(&prox_fns, n_vars, &config)
            .expect("PDMM solve failed");
        assert!(
            result.converged,
            "PDMM should converge, iters={}",
            result.iterations
        );
        assert!(
            (result.x[0] - 3.0).abs() < 0.1,
            "x = {:.4} (expected 3.0)",
            result.x[0]
        );
    }

    #[test]
    fn test_pdmm_topology_ring() {
        // Four agents on a ring; consensus optimum is mean(centers) = 3.0.
        let centers = vec![0.0_f64, 2.0, 4.0, 6.0];
        let adj = ring_topology(4);
        let solver = PdmmSolver::new(adj).expect("PDMM ring creation failed");
        let config = PdmmConfig {
            stepsize: 0.1,
            max_iter: 5000,
            tol: 1e-3,
        };
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| vec![(c + rho * v[0]) / (1.0 + rho)]);
                f
            })
            .collect();
        let result = solver
            .solve(&prox_fns, 1, &config)
            .expect("PDMM ring solve failed");
        assert!(
            (result.x[0] - 3.0).abs() < 0.5,
            "x = {:.4} (expected ~3.0)",
            result.x[0]
        );
    }

    #[test]
    fn test_extra_exact_consensus() {
        // Each agent minimises (x - c_i)^2 with gradient 2(x - c_i); the
        // global optimum is mean(centers) = 4.0.
        let centers = vec![1.0_f64, 3.0, 5.0, 7.0];
        let w = ring_w(4);
        let solver = ExtraSolver::new(w).expect("EXTRA creation failed");
        let config = ExtraConfig {
            alpha: 0.02,
            max_iter: 2000,
            tol: 1e-4,
        };
        let grad_fns: Vec<Box<dyn Fn(&[f64]) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64]) -> Vec<f64>> =
                    Box::new(move |x: &[f64]| vec![2.0 * (x[0] - c)]);
                f
            })
            .collect();
        let result = solver
            .solve(&grad_fns, 1, &config)
            .expect("EXTRA solve failed");
        assert!(
            result.converged || result.iterations == config.max_iter,
            "EXTRA iterations: {}",
            result.iterations
        );
        assert!(
            (result.x[0] - 4.0).abs() < 0.1,
            "x = {:.4} (expected 4.0), iters={}",
            result.x[0],
            result.iterations
        );
    }

    #[test]
    fn test_extra_vs_admm_same_solution() {
        use super::super::admm::consensus_admm;

        // Same separable quadratic problem solved by EXTRA and consensus
        // ADMM; both should land near mean(centers) = 4.0.
        let centers = vec![2.0_f64, 4.0, 6.0];
        let n_agents = 3_usize;
        let w = ring_w(n_agents);
        let solver = ExtraSolver::new(w).expect("EXTRA creation failed");
        let config = ExtraConfig {
            alpha: 0.02,
            max_iter: 2000,
            tol: 1e-4,
        };
        let grad_fns: Vec<Box<dyn Fn(&[f64]) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64]) -> Vec<f64>> =
                    Box::new(move |x: &[f64]| vec![2.0 * (x[0] - c)]);
                f
            })
            .collect();
        let extra_res = solver.solve(&grad_fns, 1, &config).expect("EXTRA failed");
        let admm_config = super::super::types::AdmmConfig {
            rho: 1.0,
            max_iter: 500,
            abs_tol: 1e-6,
            rel_tol: 1e-4,
            warm_start: false,
            over_relaxation: 1.0,
        };
        // Closed-form prox of (x - c)^2 at v with penalty rho.
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| {
                        vec![(rho * v[0] + 2.0 * c) / (rho + 2.0)]
                    });
                f
            })
            .collect();
        let admm_res = consensus_admm(&prox_fns, 1, &admm_config).expect("ADMM failed");
        assert!(
            (extra_res.x[0] - 4.0).abs() < 0.2,
            "EXTRA x = {:.4}",
            extra_res.x[0]
        );
        assert!(
            (admm_res.x[0] - 4.0).abs() < 0.1,
            "ADMM x = {:.4}",
            admm_res.x[0]
        );
    }

    #[test]
    fn test_extra_solver_invalid_w() {
        // Rows sum to 1 but the columns do not, so construction must fail.
        let w = vec![vec![0.5, 0.5], vec![0.9, 0.1]];
        let result = ExtraSolver::new(w);
        assert!(result.is_err());
    }
}