/// Tuning parameters for an ADMM solver run.
///
/// The [`Default`] configuration is `rho = 1.0`, `max_iter = 1000`,
/// `abs_tol = 1e-4`, `rel_tol = 1e-3`, `warm_start = false` and
/// `over_relaxation = 1.0`.
#[derive(Clone, Debug)]
pub struct AdmmConfig {
    /// Penalty parameter.
    /// NOTE(review): presumably the augmented-Lagrangian penalty ρ — confirm
    /// against the solver that consumes this config.
    pub rho: f64,
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Absolute convergence tolerance.
    pub abs_tol: f64,
    /// Relative convergence tolerance.
    pub rel_tol: f64,
    /// Whether the solver should warm-start from a caller-provided iterate.
    pub warm_start: bool,
    /// Over-relaxation factor.
    /// NOTE(review): `1.0` (the default) presumably means "no relaxation" — confirm.
    pub over_relaxation: f64,
}

impl Default for AdmmConfig {
    /// Baseline configuration used when the caller tunes nothing.
    fn default() -> AdmmConfig {
        AdmmConfig {
            max_iter: 1_000,
            rho: 1.0,
            rel_tol: 1.0e-3,
            abs_tol: 1.0e-4,
            over_relaxation: 1.0,
            warm_start: false,
        }
    }
}
/// Settings for a PDMM-style iterative solver.
///
/// NOTE(review): "PDMM" is presumably the primal-dual method of multipliers;
/// confirm against the solver that reads this config.
///
/// The [`Default`] configuration is `stepsize = 0.5`, `max_iter = 500`,
/// `tol = 1e-4`.
#[derive(Clone, Debug)]
pub struct PdmmConfig {
    /// Step size used by the update rule.
    pub stepsize: f64,
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f64,
}

impl Default for PdmmConfig {
    /// Baseline configuration used when the caller tunes nothing.
    fn default() -> PdmmConfig {
        let tol = 1.0e-4;
        PdmmConfig {
            tol,
            max_iter: 500,
            stepsize: 0.5,
        }
    }
}
/// Settings for an EXTRA-style iterative solver.
///
/// NOTE(review): "EXTRA" presumably refers to the exact first-order
/// decentralized algorithm; confirm against the solver that reads this config.
///
/// The [`Default`] configuration is `alpha = 0.05`, `max_iter = 500`,
/// `tol = 1e-4`.
#[derive(Clone, Debug)]
pub struct ExtraConfig {
    /// Step-size parameter.
    pub alpha: f64,
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f64,
}

impl Default for ExtraConfig {
    /// Baseline configuration used when the caller tunes nothing.
    fn default() -> ExtraConfig {
        let alpha = 5.0e-2;
        ExtraConfig {
            tol: 1.0e-4,
            max_iter: 500,
            alpha,
        }
    }
}
/// Outcome of an ADMM solve.
#[derive(Clone, Debug)]
pub struct AdmmResult {
    /// Final solution vector.
    pub x: Vec<f64>,
    /// Primal residual data.
    /// NOTE(review): unclear from here whether this is a per-iteration history
    /// or the final residual vector — confirm against the producer.
    pub primal_residual: Vec<f64>,
    /// Dual residual data (same caveat as `primal_residual`).
    pub dual_residual: Vec<f64>,
    /// Whether the solver reported convergence.
    pub converged: bool,
    /// Number of iterations performed.
    pub iterations: usize,
}
/// Per-node state for a consensus-style iterative solver.
///
/// Holds three vectors: `local_x`, `local_z` and `dual_y`. Both constructors
/// create them with equal lengths. The fields are public, so callers must
/// keep the lengths in sync themselves. [`ConsensusNode::primal_residual`]
/// relies on that invariant.
#[derive(Debug, Clone)]
pub struct ConsensusNode {
    /// Local primal iterate.
    pub local_x: Vec<f64>,
    /// Local copy of the variable `local_x` is compared against in
    /// [`ConsensusNode::primal_residual`] (presumably the shared consensus
    /// variable — confirm).
    pub local_z: Vec<f64>,
    /// Dual variable associated with the node.
    pub dual_y: Vec<f64>,
}

impl ConsensusNode {
    /// Creates a node with `n_vars` variables.
    ///
    /// `local_x`, `local_z` and `dual_y` are all initialised to zero.
    #[must_use]
    pub fn new(n_vars: usize) -> Self {
        Self {
            local_x: vec![0.0; n_vars],
            local_z: vec![0.0; n_vars],
            dual_y: vec![0.0; n_vars],
        }
    }

    /// Creates a node warm-started from `x0`.
    ///
    /// Both `local_x` and `local_z` are set to `x0`, so the initial primal
    /// residual is zero. `dual_y` is zeroed with the same length.
    #[must_use]
    pub fn warm(x0: Vec<f64>) -> Self {
        let n = x0.len();
        Self {
            local_z: x0.clone(),
            dual_y: vec![0.0; n],
            local_x: x0,
        }
    }

    /// Returns the element-wise primal residual `local_x - local_z`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `local_x` and `local_z` have different
    /// lengths. Because the fields are public, this can only happen through
    /// external mutation. In release builds the check is compiled out and the
    /// result is truncated to the shorter vector, as with [`Iterator::zip`].
    #[must_use]
    pub fn primal_residual(&self) -> Vec<f64> {
        // `zip` would otherwise silently drop the tail of the longer vector,
        // yielding a residual of the wrong length.
        debug_assert_eq!(
            self.local_x.len(),
            self.local_z.len(),
            "ConsensusNode: local_x and local_z must have the same length"
        );
        self.local_x
            .iter()
            .zip(&self.local_z)
            .map(|(xi, zi)| xi - zi)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing against the literal defaults.
    const EPS: f64 = 1e-15;

    /// Returns true when `a` and `b` differ by less than [`EPS`].
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    #[test]
    fn test_admm_config_default() {
        let cfg = AdmmConfig::default();
        assert!(approx_eq(cfg.rho, 1.0));
        assert_eq!(cfg.max_iter, 1000);
        assert!(approx_eq(cfg.abs_tol, 1e-4));
        assert!(approx_eq(cfg.rel_tol, 1e-3));
        assert!(!cfg.warm_start);
        assert!(approx_eq(cfg.over_relaxation, 1.0));
    }

    #[test]
    fn test_pdmm_config_default() {
        let cfg = PdmmConfig::default();
        assert!(approx_eq(cfg.stepsize, 0.5));
        assert_eq!(cfg.max_iter, 500);
        assert!(approx_eq(cfg.tol, 1e-4));
    }

    #[test]
    fn test_extra_config_default() {
        let cfg = ExtraConfig::default();
        assert!(approx_eq(cfg.alpha, 0.05));
        assert_eq!(cfg.max_iter, 500);
        assert!(approx_eq(cfg.tol, 1e-4));
    }

    #[test]
    fn test_consensus_node_new() {
        let node = ConsensusNode::new(3);
        // All three state vectors must be sized and zeroed, not just `local_x`.
        for v in [&node.local_x, &node.local_z, &node.dual_y].iter() {
            assert_eq!(v.len(), 3);
            assert!(v.iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_consensus_node_warm() {
        let node = ConsensusNode::warm(vec![1.0, 2.0, 3.0]);
        assert_eq!(node.local_x, vec![1.0, 2.0, 3.0]);
        assert_eq!(node.local_z, vec![1.0, 2.0, 3.0]);
        assert_eq!(node.dual_y.len(), 3);
        assert!(node.dual_y.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_primal_residual() {
        let mut node = ConsensusNode::new(2);
        node.local_x = vec![1.0, 2.0];
        node.local_z = vec![0.5, 1.5];
        let res = node.primal_residual();
        assert_eq!(res.len(), 2);
        assert!(approx_eq(res[0], 0.5));
        assert!(approx_eq(res[1], 0.5));
    }

    #[test]
    fn test_primal_residual_warm_node_is_zero() {
        let node = ConsensusNode::warm(vec![4.0, -1.5]);
        assert!(node.primal_residual().iter().all(|&r| r == 0.0));
    }

    #[test]
    fn test_primal_residual_empty() {
        let node = ConsensusNode::new(0);
        assert!(node.primal_residual().is_empty());
    }
}