// kizzasi_logic/distributed_admm.rs
1//! Distributed ADMM (Alternating Direction Method of Multipliers) for constraint solving
2//!
3//! Implements the consensus form of ADMM for distributed optimization:
4//!
5//! ```text
6//! minimize   Σᵢ fᵢ(xᵢ)
7//! subject to xᵢ = z  for all i   (consensus constraint)
8//! ```
9//!
10//! Each sub-problem is solved independently (in parallel via rayon), and the
11//! global consensus variable `z` is updated by averaging. Convergence is
12//! measured using primal and dual residuals with optional over-relaxation.
13//!
14//! ## References
15//! - Boyd et al., "Distributed Optimization and Statistical Learning via ADMM", 2011.
16
17use scirs2_core::ndarray::{Array1, Array2};
18use thiserror::Error;
19
20// ─────────────────────────────── Error type ───────────────────────────────
21
/// Errors that can occur during ADMM solving
#[derive(Debug, Error)]
pub enum AdmmError {
    /// Local variable dimension differs from the global consensus dimension
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },

    /// `solve()` was called on a solver with no registered sub-problems
    #[error("No subproblems added")]
    NoSubproblems,

    /// A local sub-problem returned an error; the string carries the cause
    #[error("Subproblem solve failed: {0}")]
    SubproblemFailed(String),

    /// The iteration limit was reached before the residuals satisfied the tolerances.
    ///
    /// NOTE(review): the solvers in this file report non-convergence via
    /// `AdmmResult::converged == false` and never return this variant —
    /// confirm which behaviour downstream callers expect.
    #[error("Maximum iterations reached without convergence")]
    MaxIterationsReached,

    /// A floating-point operation produced a non-finite value
    #[error("Numerical error: {0}")]
    NumericalError(String),
}
45
46// ─────────────────────────────── Config ──────────────────────────────────
47
/// Configuration for the ADMM algorithm
///
/// Defaults are provided via [`Default`]; the `with_*` builder methods allow
/// fluent customisation. Tolerances feed the simplified stopping criterion in
/// `check_convergence`.
#[derive(Debug, Clone)]
pub struct AdmmConfig {
    /// Augmented Lagrangian penalty parameter ρ (default 1.0).
    ///
    /// Larger values impose stronger coupling between sub-problems.
    /// Expected to be positive; this is not validated here.
    pub rho: f32,
    /// Maximum number of ADMM outer iterations (default 100)
    pub max_iterations: usize,
    /// Absolute convergence tolerance for primal/dual residuals (default 1e-4)
    pub abs_tol: f32,
    /// Relative convergence tolerance (default 1e-3)
    pub rel_tol: f32,
    /// Over-relaxation parameter α ∈ [1.0, 1.8] (default 1.0 = no relaxation).
    ///
    /// Values in (1, 1.8) often accelerate convergence. Not validated on set.
    pub over_relaxation: f32,
    /// Print iteration statistics when `true` (emitted via `tracing::debug!`)
    pub verbose: bool,
}
68
69impl Default for AdmmConfig {
70    fn default() -> Self {
71        Self {
72            rho: 1.0,
73            max_iterations: 100,
74            abs_tol: 1e-4,
75            rel_tol: 1e-3,
76            over_relaxation: 1.0,
77            verbose: false,
78        }
79    }
80}
81
82impl AdmmConfig {
83    /// Create a config with sensible defaults
84    pub fn new() -> Self {
85        Self::default()
86    }
87
88    /// Set the penalty parameter ρ
89    pub fn with_rho(mut self, rho: f32) -> Self {
90        self.rho = rho;
91        self
92    }
93
94    /// Set the maximum number of iterations
95    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
96        self.max_iterations = max_iterations;
97        self
98    }
99
100    /// Set the absolute tolerance
101    pub fn with_abs_tol(mut self, abs_tol: f32) -> Self {
102        self.abs_tol = abs_tol;
103        self
104    }
105
106    /// Set the relative tolerance
107    pub fn with_rel_tol(mut self, rel_tol: f32) -> Self {
108        self.rel_tol = rel_tol;
109        self
110    }
111
112    /// Set the over-relaxation parameter α (must be in `(0, 2)`)
113    pub fn with_over_relaxation(mut self, alpha: f32) -> Self {
114        self.over_relaxation = alpha;
115        self
116    }
117
118    /// Enable verbose iteration logging
119    pub fn with_verbose(mut self, verbose: bool) -> Self {
120        self.verbose = verbose;
121        self
122    }
123}
124
125// ─────────────────────────────── Result ──────────────────────────────────
126
/// The output of a completed ADMM solve
///
/// Returned by both `DistributedAdmm` and `ConsensusAdmm`. When the solver
/// exhausts `max_iterations` without meeting the tolerances it still returns
/// `Ok` with `converged == false` rather than an error.
#[derive(Debug, Clone)]
pub struct AdmmResult {
    /// The consensus solution vector `z`
    pub solution: Array1<f32>,
    /// Number of iterations performed (≤ `max_iterations`)
    pub iterations: usize,
    /// Whether the algorithm converged within the tolerance before `max_iterations`
    pub converged: bool,
    /// Final primal residual ‖xᵢ − z‖ (root-mean-square across agents)
    pub primal_residual: f32,
    /// Final dual residual ρ‖z − z_old‖
    pub dual_residual: f32,
    /// Sum of local objectives Σᵢ fᵢ(z) evaluated at the consensus solution
    pub objective: f32,
}
143
144// ─────────────────────────────── Trait ───────────────────────────────────
145
/// A local sub-problem solved by one ADMM agent
///
/// Each agent minimises:
/// ```text
///   fᵢ(xᵢ) + (ρ/2) ‖xᵢ − z + uᵢ‖²
/// ```
/// where `z` is the current global consensus variable and `uᵢ` is the
/// scaled dual variable for agent `i`.
///
/// Implementations must be `Send + Sync` because `solve` is invoked from a
/// rayon parallel iterator during each sweep.
pub trait AdmmSubproblem: Send + Sync {
    /// Solve the local proximal sub-problem and return the updated `xᵢ`.
    ///
    /// The returned vector must have length `dim()`.
    ///
    /// # Arguments
    /// * `z` – current global consensus variable
    /// * `u` – current scaled dual variable for this agent
    /// * `rho` – penalty parameter
    fn solve(&self, z: &Array1<f32>, u: &Array1<f32>, rho: f32) -> Result<Array1<f32>, AdmmError>;

    /// Evaluate the local objective fᵢ(x)
    fn objective(&self, x: &Array1<f32>) -> f32;

    /// Dimension of the local variable xᵢ (must equal the global dimension)
    fn dim(&self) -> usize;
}
169
170// ─────────────────────────── Helper utilities ────────────────────────────
171
172/// Compute the L2 norm of an array
173#[inline]
174fn l2_norm(v: &Array1<f32>) -> f32 {
175    v.iter().map(|&x| x * x).sum::<f32>().sqrt()
176}
177
178/// Element-wise soft-threshold (shrinkage) operator
179///
180/// `shrink(v, κ)ᵢ = sign(vᵢ) · max(|vᵢ| − κ, 0)`
181fn soft_threshold(v: &Array1<f32>, kappa: f32) -> Array1<f32> {
182    v.mapv(|x| {
183        if x > kappa {
184            x - kappa
185        } else if x < -kappa {
186            x + kappa
187        } else {
188            0.0
189        }
190    })
191}
192
193/// Box projection: clip each element of `v` to `[lb, ub]`
194fn box_clip(v: &Array1<f32>, lb: &Array1<f32>, ub: &Array1<f32>) -> Array1<f32> {
195    v.iter()
196        .zip(lb.iter())
197        .zip(ub.iter())
198        .map(|((&vi, &li), &ui)| vi.clamp(li, ui))
199        .collect()
200}
201
202/// Solve the symmetric positive-definite linear system (Q + ρ I) x = b using
203/// the Gauss-Seidel method.
204///
205/// This avoids any LAPACK/BLAS dependency while remaining correct for PSD
206/// matrices. For small dimensions the iteration converges very quickly.
207fn gauss_seidel_solve(
208    q: &Array2<f32>,
209    rho: f32,
210    b: &Array1<f32>,
211) -> Result<Array1<f32>, AdmmError> {
212    let n = b.len();
213    if q.nrows() != n || q.ncols() != n {
214        return Err(AdmmError::DimensionMismatch {
215            expected: n,
216            got: q.nrows(),
217        });
218    }
219
220    let mut x = Array1::<f32>::zeros(n);
221    // The effective matrix is A = Q + rho*I
222    // Gauss-Seidel: xᵢ ← (bᵢ − Σⱼ≠ᵢ Aᵢⱼ xⱼ) / Aᵢᵢ
223    let max_inner = 200usize;
224    for _iter in 0..max_inner {
225        let x_old = x.clone();
226        for i in 0..n {
227            let a_ii = q[[i, i]] + rho;
228            if a_ii.abs() < f32::EPSILON {
229                return Err(AdmmError::NumericalError(format!(
230                    "Near-zero diagonal at index {i}"
231                )));
232            }
233            let mut sum = 0.0f32;
234            for j in 0..n {
235                if j != i {
236                    sum += (q[[i, j]] + if i == j { rho } else { 0.0 }) * x[j];
237                }
238            }
239            x[i] = (b[i] - sum) / a_ii;
240        }
241        // Check convergence of the inner loop
242        let diff: f32 = x
243            .iter()
244            .zip(x_old.iter())
245            .map(|(a, b)| (a - b) * (a - b))
246            .sum::<f32>()
247            .sqrt();
248        if diff < 1e-8 {
249            break;
250        }
251    }
252
253    // Verify the solution is finite
254    for &xi in x.iter() {
255        if !xi.is_finite() {
256            return Err(AdmmError::NumericalError(
257                "Gauss-Seidel produced non-finite value".into(),
258            ));
259        }
260    }
261    Ok(x)
262}
263
264// ──────────────────────── Concrete sub-problems ───────────────────────────
265
/// Quadratic sub-problem: minimise ½ xᵀ Q x + cᵀ x  subject to  lb ≤ x ≤ ub
///
/// The augmented problem is:
/// ```text
///   minimise  ½ xᵀ Q x + cᵀ x + (ρ/2) ‖x − (z − u)‖²
/// = minimise  ½ xᵀ (Q + ρI) x + (c − ρ(z − u))ᵀ x
/// ```
/// Solution: x̂ = (Q + ρI)⁻¹ (ρ(z − u) − c), then clipped to [lb, ub].
///
/// NOTE(review): `q`/`c`/`lb`/`ub` lengths are assumed consistent — the
/// constructors do not validate them; confirm at the call sites.
#[derive(Debug, Clone)]
pub struct QuadraticSubproblem {
    /// Positive semi-definite quadratic term (n × n)
    pub q: Array2<f32>,
    /// Linear cost term (n); also defines `dim()`
    pub c: Array1<f32>,
    /// Optional per-element lower bounds
    pub lb: Option<Array1<f32>>,
    /// Optional per-element upper bounds
    pub ub: Option<Array1<f32>>,
}
285
286impl QuadraticSubproblem {
287    /// Create an unconstrained quadratic sub-problem
288    pub fn new(q: Array2<f32>, c: Array1<f32>) -> Self {
289        Self {
290            q,
291            c,
292            lb: None,
293            ub: None,
294        }
295    }
296
297    /// Add box constraints lb ≤ x ≤ ub
298    pub fn with_bounds(mut self, lb: Array1<f32>, ub: Array1<f32>) -> Self {
299        self.lb = Some(lb);
300        self.ub = Some(ub);
301        self
302    }
303}
304
305impl AdmmSubproblem for QuadraticSubproblem {
306    fn solve(&self, z: &Array1<f32>, u: &Array1<f32>, rho: f32) -> Result<Array1<f32>, AdmmError> {
307        let n = self.dim();
308        if z.len() != n || u.len() != n {
309            return Err(AdmmError::DimensionMismatch {
310                expected: n,
311                got: z.len(),
312            });
313        }
314        // RHS: ρ(z − u) − c
315        let rhs: Array1<f32> = (z - u).mapv(|v| rho * v) - &self.c;
316
317        // Solve (Q + ρI) x = rhs
318        let x = gauss_seidel_solve(&self.q, rho, &rhs)?;
319
320        // Clip to box if bounds are present
321        let x = match (&self.lb, &self.ub) {
322            (Some(lb), Some(ub)) => box_clip(&x, lb, ub),
323            (Some(lb), None) => x.mapv(|v| v.max(lb[0])),
324            (None, Some(ub)) => x.mapv(|v| v.min(ub[0])),
325            (None, None) => x,
326        };
327
328        Ok(x)
329    }
330
331    fn objective(&self, x: &Array1<f32>) -> f32 {
332        // ½ xᵀ Q x + cᵀ x
333        let qx: Array1<f32> = self.q.dot(x);
334        0.5 * x.iter().zip(qx.iter()).map(|(a, b)| a * b).sum::<f32>()
335            + x.iter().zip(self.c.iter()).map(|(a, b)| a * b).sum::<f32>()
336    }
337
338    fn dim(&self) -> usize {
339        self.c.len()
340    }
341}
342
/// LASSO sub-problem: minimise ‖Ax − b‖² + λ‖x‖₁
///
/// The augmented problem admits a closed-form solution via soft-thresholding.
///
/// The x-update for this sub-problem is:
/// ```text
///   v = (AᵀA + ρI)⁻¹ (Aᵀb + ρ(z − u))
///   x̂ = shrink(v, λ/ρ)
/// ```
#[derive(Debug, Clone)]
pub struct LassoSubproblem {
    /// Measurement matrix (m × n); `n` defines `dim()`
    pub a: Array2<f32>,
    /// Observation vector (m)
    pub b: Array1<f32>,
    /// L1 regularisation weight λ
    pub lambda: f32,
}
361
impl LassoSubproblem {
    /// Create a new LASSO sub-problem from `A`, `b` and the L1 weight λ.
    ///
    /// NOTE(review): dimensions are not validated here; `a.nrows()` is
    /// assumed to equal `b.len()` — confirm at the call sites.
    pub fn new(a: Array2<f32>, b: Array1<f32>, lambda: f32) -> Self {
        Self { a, b, lambda }
    }
}
368
impl AdmmSubproblem for LassoSubproblem {
    /// x-update: ridge-regularised least-squares solve followed by
    /// soft-thresholding with κ = λ/ρ (see the struct-level docs).
    fn solve(&self, z: &Array1<f32>, u: &Array1<f32>, rho: f32) -> Result<Array1<f32>, AdmmError> {
        let n = self.dim();
        if z.len() != n || u.len() != n {
            return Err(AdmmError::DimensionMismatch {
                expected: n,
                got: z.len(),
            });
        }

        // Build AᵀA (n × n).
        // NOTE(review): AᵀA and Aᵀb are recomputed on every call, i.e. once
        // per ADMM outer iteration — caching them would require changing the
        // struct, so it is only flagged here.
        let at = self.a.t();
        let ata: Array2<f32> = at.dot(&self.a);

        // RHS: Aᵀb + ρ(z − u)
        let atb: Array1<f32> = at.dot(&self.b);
        let rhs: Array1<f32> = atb + (z - u).mapv(|v| rho * v);

        // Solve (AᵀA + ρI) v = rhs
        let v = gauss_seidel_solve(&ata, rho, &rhs)?;

        // Apply soft-threshold with κ = λ/ρ
        let kappa = self.lambda / rho;
        Ok(soft_threshold(&v, kappa))
    }

    /// Local objective ‖Ax − b‖² + λ‖x‖₁ (note: no ½ on the quadratic term)
    fn objective(&self, x: &Array1<f32>) -> f32 {
        let ax_minus_b: Array1<f32> = self.a.dot(x) - &self.b;
        let l2_sq: f32 = ax_minus_b.iter().map(|&v| v * v).sum();
        let l1: f32 = x.iter().map(|&v| v.abs()).sum();
        l2_sq + self.lambda * l1
    }

    /// Variable dimension: number of columns of A
    fn dim(&self) -> usize {
        self.a.ncols()
    }
}
406
/// Projection sub-problem: project onto the box [lb, ub]
///
/// The augmented problem is:
/// ```text
///   minimise  ‖x − v‖²  s.t.  lb ≤ x ≤ ub,  where v = z − u
/// ```
/// which has the closed-form solution x̂ = clip(z − u, lb, ub).
#[derive(Debug, Clone)]
pub struct ProjectionSubproblem {
    /// Per-element lower bounds; its length defines `dim()`
    pub lb: Array1<f32>,
    /// Per-element upper bounds (expected same length as `lb`)
    pub ub: Array1<f32>,
}
421
impl ProjectionSubproblem {
    /// Create a box-projection sub-problem.
    ///
    /// NOTE(review): `lb.len() == ub.len()` is not validated; `dim()` reports
    /// `lb.len()` and `box_clip` silently truncates to the shorter bound —
    /// confirm callers always pass equal-length bounds.
    pub fn new(lb: Array1<f32>, ub: Array1<f32>) -> Self {
        Self { lb, ub }
    }
}
428
429impl AdmmSubproblem for ProjectionSubproblem {
430    fn solve(&self, z: &Array1<f32>, u: &Array1<f32>, _rho: f32) -> Result<Array1<f32>, AdmmError> {
431        let n = self.dim();
432        if z.len() != n || u.len() != n {
433            return Err(AdmmError::DimensionMismatch {
434                expected: n,
435                got: z.len(),
436            });
437        }
438        let v: Array1<f32> = z - u;
439        Ok(box_clip(&v, &self.lb, &self.ub))
440    }
441
442    fn objective(&self, _x: &Array1<f32>) -> f32 {
443        // Indicator function: 0 inside the box, but we treat it as 0 here
444        // because feasible iterates will always be inside [lb, ub].
445        0.0
446    }
447
448    fn dim(&self) -> usize {
449        self.lb.len()
450    }
451}
452
453// ─────────────────────── Core ADMM iteration logic ───────────────────────
454
/// Internal output type for one ADMM sweep, in order:
/// `(local_xs, dual_us, z_new, primal_residual, dual_residual)`
type AdmmSweepOutput = (Vec<Array1<f32>>, Vec<Array1<f32>>, Array1<f32>, f32, f32);
458
/// Run one full ADMM sweep and return
/// `(new_xs, new_us, z_new, primal_res, dual_res)` — see [`AdmmSweepOutput`].
/// (A previous doc comment listed the tuple fields in the wrong order.)
///
/// Steps: (1) parallel x-update per agent, (2) consensus z-update with
/// optional over-relaxation α, (3) scaled dual update plus residuals.
///
/// Uses rayon for parallel x-updates when the `parallel` feature of scirs2-core is active.
fn admm_sweep(
    subproblems: &[Box<dyn AdmmSubproblem>],
    z: &Array1<f32>,
    us: &[Array1<f32>],
    rho: f32,
    alpha: f32,
    weights: Option<&Array1<f32>>,
) -> Result<AdmmSweepOutput, AdmmError> {
    let n_agents = subproblems.len();
    let dim = z.len();

    // ── Step 1: parallel x-update ─────────────────────────────────────────
    // Each agent i solves:  xᵢ ← argmin fᵢ(xᵢ) + (ρ/2)‖xᵢ − z + uᵢ‖²
    use rayon::prelude::*;

    let x_results: Vec<Result<Array1<f32>, AdmmError>> = subproblems
        .par_iter()
        .enumerate()
        .map(|(i, sp)| sp.solve(z, &us[i], rho))
        .collect();

    // Propagate the first sub-problem error, if any.
    let mut new_xs: Vec<Array1<f32>> = Vec::with_capacity(n_agents);
    for result in x_results {
        new_xs.push(result?);
    }

    // ── Step 2: z-update with optional over-relaxation ────────────────────
    // z_new = (1/N) Σᵢ (αxᵢ + (1−α)z + uᵢ)
    //       = α * x_avg + (1−α) * z + u_avg
    let z_old = z.clone();

    let (x_avg, u_avg) = match weights {
        Some(w) => {
            // Weighted average: x_avg = Σᵢ wᵢ xᵢ.
            // No 1/N factor here — the weights are assumed to sum to 1
            // (validated by `ConsensusAdmm::new_weighted`).
            let mut xa = Array1::<f32>::zeros(dim);
            let mut ua = Array1::<f32>::zeros(dim);
            for i in 0..n_agents {
                xa = xa + new_xs[i].mapv(|v| v * w[i]);
                ua = ua + us[i].mapv(|v| v * w[i]);
            }
            (xa, ua)
        }
        None => {
            // Simple (uniform) average over all agents
            let mut xa = Array1::<f32>::zeros(dim);
            let mut ua = Array1::<f32>::zeros(dim);
            for i in 0..n_agents {
                xa += &new_xs[i];
                ua += &us[i];
            }
            let inv_n = 1.0 / n_agents as f32;
            (xa.mapv(|v| v * inv_n), ua.mapv(|v| v * inv_n))
        }
    };

    // Over-relaxation folded into the z-update: z_new = α x_avg + (1−α) z_old + u_avg
    let z_new: Array1<f32> = x_avg.mapv(|v| alpha * v) + z_old.mapv(|v| (1.0 - alpha) * v) + &u_avg;

    // ── Step 3: dual variable update ─────────────────────────────────────
    // uᵢ ← uᵢ + α xᵢ + (1−α) z − z_new
    //     = uᵢ + x̃ᵢ − z_new   (with x̃ᵢ = α xᵢ + (1−α) z_old)
    let mut new_us: Vec<Array1<f32>> = Vec::with_capacity(n_agents);
    let mut primal_sq = 0.0f32;
    for i in 0..n_agents {
        let x_tilde: Array1<f32> =
            new_xs[i].mapv(|v| alpha * v) + z_old.mapv(|v| (1.0 - alpha) * v);
        let residual_i: Array1<f32> = &x_tilde - &z_new;
        // Accumulate ‖x̃ᵢ − z_new‖² for the (relaxed) primal residual.
        primal_sq += residual_i.iter().map(|&v| v * v).sum::<f32>();
        let u_new: Array1<f32> = &us[i] + &residual_i;
        new_us.push(u_new);
    }
    // Root-mean-square primal residual across agents.
    let primal_res = (primal_sq / n_agents as f32).sqrt();

    // Dual residual: ρ ‖z_new − z_old‖
    let dual_res = rho * l2_norm(&(&z_new - &z_old));

    Ok((new_xs, new_us, z_new, primal_res, dual_res))
}
540
541/// Check ADMM convergence using the Boyd et al. stopping criteria.
542///
543/// Returns `true` when both the primal and dual residuals fall below the
544/// combined absolute + relative threshold.
545fn check_convergence(
546    primal_res: f32,
547    dual_res: f32,
548    config: &AdmmConfig,
549    n_agents: usize,
550    dim: usize,
551) -> bool {
552    let scale = ((n_agents * dim) as f32).sqrt();
553    let eps_primal = config.abs_tol * scale + config.rel_tol;
554    let eps_dual = config.abs_tol * scale + config.rel_tol;
555    primal_res < eps_primal && dual_res < eps_dual
556}
557
558// ──────────────────────── DistributedAdmm ────────────────────────────────
559
/// Distributed ADMM solver coordinating multiple local sub-problems
///
/// Runs the standard consensus ADMM algorithm with optional over-relaxation
/// and parallelism via rayon. Non-convergence within `max_iterations` is
/// reported through `AdmmResult::converged == false`, not as an error.
pub struct DistributedAdmm {
    // Algorithm parameters (ρ, tolerances, iteration budget, verbosity).
    config: AdmmConfig,
    // One boxed local solver per agent; each must match `dim`.
    subproblems: Vec<Box<dyn AdmmSubproblem>>,
    // Length of the global consensus variable `z`.
    dim: usize,
}
569
570impl DistributedAdmm {
571    /// Create a new solver for variables of length `dim`
572    pub fn new(config: AdmmConfig, dim: usize) -> Self {
573        Self {
574            config,
575            subproblems: Vec::new(),
576            dim,
577        }
578    }
579
580    /// Register a new local sub-problem.
581    ///
582    /// Returns `Err(AdmmError::DimensionMismatch)` if the sub-problem has a
583    /// different variable dimension from the solver.
584    pub fn add_subproblem(&mut self, subproblem: Box<dyn AdmmSubproblem>) -> Result<(), AdmmError> {
585        if subproblem.dim() != self.dim {
586            return Err(AdmmError::DimensionMismatch {
587                expected: self.dim,
588                got: subproblem.dim(),
589            });
590        }
591        self.subproblems.push(subproblem);
592        Ok(())
593    }
594
595    /// Number of registered sub-problems
596    pub fn num_subproblems(&self) -> usize {
597        self.subproblems.len()
598    }
599
600    /// Solve starting from the zero vector
601    pub fn solve(&self) -> Result<AdmmResult, AdmmError> {
602        self.solve_warm(Array1::zeros(self.dim))
603    }
604
605    /// Solve with a user-provided warm-start for `z`
606    pub fn solve_warm(&self, z_init: Array1<f32>) -> Result<AdmmResult, AdmmError> {
607        if self.subproblems.is_empty() {
608            return Err(AdmmError::NoSubproblems);
609        }
610        if z_init.len() != self.dim {
611            return Err(AdmmError::DimensionMismatch {
612                expected: self.dim,
613                got: z_init.len(),
614            });
615        }
616
617        let n_agents = self.subproblems.len();
618        let mut z = z_init;
619        let mut us: Vec<Array1<f32>> = vec![Array1::zeros(self.dim); n_agents];
620
621        let mut primal_res = f32::INFINITY;
622        let mut dual_res = f32::INFINITY;
623        let mut iterations = 0usize;
624        let mut converged = false;
625
626        for iter in 0..self.config.max_iterations {
627            iterations = iter + 1;
628            let (new_xs, new_us, z_new, pr, dr) = admm_sweep(
629                &self.subproblems,
630                &z,
631                &us,
632                self.config.rho,
633                self.config.over_relaxation,
634                None,
635            )?;
636
637            let _ = new_xs; // local x values are internal; z carries the consensus
638            us = new_us;
639            z = z_new;
640            primal_res = pr;
641            dual_res = dr;
642
643            if self.config.verbose {
644                tracing::debug!(iter = iterations, primal_res, dual_res, "ADMM iteration");
645            }
646
647            if check_convergence(primal_res, dual_res, &self.config, n_agents, self.dim) {
648                converged = true;
649                break;
650            }
651        }
652
653        // Evaluate total objective at the consensus solution
654        let objective: f32 = self.subproblems.iter().map(|sp| sp.objective(&z)).sum();
655
656        Ok(AdmmResult {
657            solution: z,
658            iterations,
659            converged,
660            primal_residual: primal_res,
661            dual_residual: dual_res,
662            objective,
663        })
664    }
665}
666
667// ──────────────────────── ConsensusAdmm ──────────────────────────────────
668
/// Consensus ADMM solver with optional per-agent weighting
///
/// Identical to [`DistributedAdmm`] but allows a user-supplied weight vector
/// so that the z-update becomes a weighted average of local variables.
/// The weights must be non-negative and sum to 1.
pub struct ConsensusAdmm {
    // Algorithm parameters (ρ, tolerances, iteration budget, verbosity).
    config: AdmmConfig,
    // One boxed local solver per agent; each must match `dim`.
    subproblems: Vec<Box<dyn AdmmSubproblem>>,
    // Length of the global consensus variable `z`.
    dim: usize,
    // `None` = uniform averaging; `Some(w)` = weighted averaging with w.
    weights: Option<Array1<f32>>,
}
680
681impl ConsensusAdmm {
682    /// Create a new consensus ADMM solver (uniform weights)
683    pub fn new(config: AdmmConfig, dim: usize) -> Self {
684        Self {
685            config,
686            subproblems: Vec::new(),
687            dim,
688            weights: None,
689        }
690    }
691
692    /// Create a weighted consensus ADMM solver.
693    ///
694    /// `weights` must have length equal to the number of sub-problems that
695    /// will be added, and must sum to approximately 1.0.
696    pub fn new_weighted(
697        config: AdmmConfig,
698        dim: usize,
699        weights: Array1<f32>,
700    ) -> Result<Self, AdmmError> {
701        let sum: f32 = weights.iter().sum();
702        if (sum - 1.0).abs() > 1e-4 {
703            return Err(AdmmError::NumericalError(format!(
704                "Weights must sum to 1.0, got {sum:.6}"
705            )));
706        }
707        Ok(Self {
708            config,
709            subproblems: Vec::new(),
710            dim,
711            weights: Some(weights),
712        })
713    }
714
715    /// Register a local sub-problem
716    pub fn add_subproblem(&mut self, subproblem: Box<dyn AdmmSubproblem>) -> Result<(), AdmmError> {
717        if subproblem.dim() != self.dim {
718            return Err(AdmmError::DimensionMismatch {
719                expected: self.dim,
720                got: subproblem.dim(),
721            });
722        }
723        self.subproblems.push(subproblem);
724        Ok(())
725    }
726
727    /// Run the consensus ADMM solve
728    pub fn solve(&self) -> Result<AdmmResult, AdmmError> {
729        if self.subproblems.is_empty() {
730            return Err(AdmmError::NoSubproblems);
731        }
732
733        let n_agents = self.subproblems.len();
734
735        // Validate weights length if provided
736        if let Some(w) = &self.weights {
737            if w.len() != n_agents {
738                return Err(AdmmError::DimensionMismatch {
739                    expected: n_agents,
740                    got: w.len(),
741                });
742            }
743        }
744
745        let mut z = Array1::<f32>::zeros(self.dim);
746        let mut us: Vec<Array1<f32>> = vec![Array1::zeros(self.dim); n_agents];
747
748        let mut primal_res = f32::INFINITY;
749        let mut dual_res = f32::INFINITY;
750        let mut iterations = 0usize;
751        let mut converged = false;
752
753        for iter in 0..self.config.max_iterations {
754            iterations = iter + 1;
755            let (new_xs, new_us, z_new, pr, dr) = admm_sweep(
756                &self.subproblems,
757                &z,
758                &us,
759                self.config.rho,
760                self.config.over_relaxation,
761                self.weights.as_ref(),
762            )?;
763
764            let _ = new_xs;
765            us = new_us;
766            z = z_new;
767            primal_res = pr;
768            dual_res = dr;
769
770            if self.config.verbose {
771                tracing::debug!(
772                    iter = iterations,
773                    primal_res,
774                    dual_res,
775                    "ConsensusADMM iteration"
776                );
777            }
778
779            if check_convergence(primal_res, dual_res, &self.config, n_agents, self.dim) {
780                converged = true;
781                break;
782            }
783        }
784
785        let objective: f32 = self.subproblems.iter().map(|sp| sp.objective(&z)).sum();
786
787        Ok(AdmmResult {
788            solution: z,
789            iterations,
790            converged,
791            primal_residual: primal_res,
792            dual_residual: dual_res,
793            objective,
794        })
795    }
796}
797
798// ─────────────────────────────── Tests ───────────────────────────────────
799
800#[cfg(test)]
801mod tests {
802    use super::*;
803    use scirs2_core::ndarray::{Array1, Array2};
804
805    // ─── helpers ──────────────────────────────────────────────────────────
806
    /// Build an n×n identity matrix scaled by `scale` (off-diagonals stay 0)
    fn eye(n: usize, scale: f32) -> Array2<f32> {
        let mut m = Array2::<f32>::zeros((n, n));
        for i in 0..n {
            m[[i, i]] = scale;
        }
        m
    }
815
816    // ─── 1. Default config ───────────────────────────────────────────────
817
    #[test]
    fn test_admm_config_default() {
        // Sanity-check that every default lies in its documented valid range.
        let cfg = AdmmConfig::default();
        assert!(cfg.rho > 0.0, "rho must be positive");
        assert!(cfg.max_iterations > 0, "max_iterations must be positive");
        assert!(cfg.abs_tol > 0.0, "abs_tol must be positive");
        assert!(cfg.rel_tol > 0.0, "rel_tol must be positive");
        assert!(
            (0.0..2.0).contains(&cfg.over_relaxation),
            "over_relaxation must be in (0, 2)"
        );
    }
830
831    // ─── 2. ProjectionSubproblem ─────────────────────────────────────────
832
    #[test]
    fn test_projection_subproblem() {
        // Each coordinate has a different box; all three starting values lie
        // outside it and must be clipped to the nearest bound.
        let lb = Array1::from_vec(vec![0.0, -1.0, 2.0]);
        let ub = Array1::from_vec(vec![1.0, 1.0, 5.0]);
        let sp = ProjectionSubproblem::new(lb.clone(), ub.clone());

        // z outside the box, u = 0 → result should be clipped
        let z = Array1::from_vec(vec![-2.0, 3.0, 10.0]);
        let u = Array1::zeros(3);
        let x = sp.solve(&z, &u, 1.0).expect("projection should succeed");
        assert!((x[0] - 0.0).abs() < 1e-6, "should clip to lb[0]");
        assert!((x[1] - 1.0).abs() < 1e-6, "should clip to ub[1]");
        assert!((x[2] - 5.0).abs() < 1e-6, "should clip to ub[2]");
    }
847
848    // ─── 3. LassoSubproblem – soft threshold ────────────────────────────
849
    #[test]
    fn test_lasso_subproblem_soft_threshold() {
        // A = I, b = 0, λ large → solution should shrink towards zero
        // (κ = λ/ρ = 10 dwarfs the ridge-solve output, so the shrinkage
        // operator dominates).
        let n = 3usize;
        let a = eye(n, 1.0);
        let b = Array1::zeros(n);
        let lambda = 10.0f32;
        let sp = LassoSubproblem::new(a, b, lambda);

        let z = Array1::from_vec(vec![0.5, -0.5, 0.2]);
        let u = Array1::zeros(n);
        let x = sp.solve(&z, &u, 1.0).expect("lasso solve");

        // With large lambda the solution should be very small in magnitude
        for &xi in x.iter() {
            assert!(xi.abs() < 0.6, "soft-threshold should shrink the solution");
        }
    }
868
869    // ─── 4. QuadraticSubproblem – unconstrained ─────────────────────────
870
    #[test]
    fn test_quadratic_subproblem_unconstrained() {
        // Q = I, c = [-2, -2]: the standalone minimum is at x = [2, 2], and
        // with ADMM coupling (rho=1, z=[2,2], u=[0,0]) the proximal solve
        // gives (Q + I)x = ρ(z-u) - c = [2,2] - [-2,-2] = [4,4], i.e.
        // x = [2, 2] — the fixed point coincides with the true minimiser.
        let n = 2usize;
        let q = eye(n, 1.0);
        let c = Array1::from_vec(vec![-2.0, -2.0]);
        let sp = QuadraticSubproblem::new(q, c);

        let z = Array1::from_vec(vec![2.0, 2.0]);
        let u = Array1::zeros(n);
        let x = sp.solve(&z, &u, 1.0).expect("quadratic solve");
        assert!((x[0] - 2.0).abs() < 1e-3, "x[0] ≈ 2: got {}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-3, "x[1] ≈ 2: got {}", x[1]);
    }
888
889    // ─── 5. QuadraticSubproblem – box constrained ────────────────────────
890
    #[test]
    fn test_quadratic_subproblem_constrained() {
        // Unconstrained optimum lies well above ub = [1, 1], so the box
        // projection after the linear solve must activate.
        let n = 2usize;
        let q = eye(n, 1.0);
        let c = Array1::from_vec(vec![-5.0, -5.0]);
        let lb = Array1::from_vec(vec![0.0, 0.0]);
        let ub = Array1::from_vec(vec![1.0, 1.0]); // clip at 1
        let sp = QuadraticSubproblem::new(q, c).with_bounds(lb, ub);

        let z = Array1::from_vec(vec![3.0, 3.0]);
        let u = Array1::zeros(n);
        let x = sp.solve(&z, &u, 1.0).expect("constrained quadratic solve");
        // Solution without bounds would be >> 1, so clipping should kick in
        assert!(x[0] <= 1.0 + 1e-6, "x[0] must be ≤ ub");
        assert!(x[1] <= 1.0 + 1e-6, "x[1] must be ≤ ub");
        assert!(x[0] >= 0.0 - 1e-6, "x[0] must be ≥ lb");
    }
908
909    // ─── 6. DistributedAdmm – basic consensus ───────────────────────────
910
    #[test]
    fn test_distributed_admm_consensus() {
        // Three projection sub-problems with overlapping boxes.
        // The intersection is the single point [1, 1], so the consensus
        // variable should settle near [1, 1] within the iteration budget.
        let dim = 2usize;
        let cfg = AdmmConfig::default()
            .with_rho(2.0)
            .with_max_iterations(200)
            .with_abs_tol(1e-4);

        let mut solver = DistributedAdmm::new(cfg, dim);
        // Agent 0: [1, 3] × [1, 3]
        solver
            .add_subproblem(Box::new(ProjectionSubproblem::new(
                Array1::from_vec(vec![1.0, 1.0]),
                Array1::from_vec(vec![3.0, 3.0]),
            )))
            .expect("add agent 0");
        // Agent 1: [0, 1] × [0, 1]
        solver
            .add_subproblem(Box::new(ProjectionSubproblem::new(
                Array1::from_vec(vec![0.0, 0.0]),
                Array1::from_vec(vec![1.0, 1.0]),
            )))
            .expect("add agent 1");
        // Agent 2: [1, 2] × [0, 2]
        solver
            .add_subproblem(Box::new(ProjectionSubproblem::new(
                Array1::from_vec(vec![1.0, 0.0]),
                Array1::from_vec(vec![2.0, 2.0]),
            )))
            .expect("add agent 2");

        let result = solver.solve().expect("distributed ADMM solve");
        // Consensus should be near the common point [1, 1]
        assert!((result.solution[0] - 1.0).abs() < 0.1, "x[0] ≈ 1");
        assert!((result.solution[1] - 1.0).abs() < 0.1, "x[1] ≈ 1");
    }
949
950    // ─── 7. DistributedAdmm – convergence within max_iter ───────────────
951
952    #[test]
953    fn test_distributed_admm_convergence() {
954        let dim = 4usize;
955        let cfg = AdmmConfig::default()
956            .with_max_iterations(500)
957            .with_abs_tol(1e-3);
958
959        let mut solver = DistributedAdmm::new(cfg, dim);
960        for _ in 0..3 {
961            solver
962                .add_subproblem(Box::new(ProjectionSubproblem::new(
963                    Array1::from_vec(vec![0.0; dim]),
964                    Array1::from_vec(vec![1.0; dim]),
965                )))
966                .expect("add subproblem");
967        }
968
969        let result = solver.solve().expect("solve");
970        // Primal/dual residuals must be finite
971        assert!(result.primal_residual.is_finite());
972        assert!(result.dual_residual.is_finite());
973    }
974
975    // ─── 8. DistributedAdmm – converged flag ────────────────────────────
976
977    #[test]
978    fn test_distributed_admm_convergence_flag() {
979        let dim = 2usize;
980        let cfg = AdmmConfig::default()
981            .with_max_iterations(1000)
982            .with_abs_tol(1e-3)
983            .with_rel_tol(1e-2);
984
985        let mut solver = DistributedAdmm::new(cfg, dim);
986        // Two identical boxes → trivial consensus, should converge fast
987        for _ in 0..2 {
988            solver
989                .add_subproblem(Box::new(ProjectionSubproblem::new(
990                    Array1::from_vec(vec![0.0, 0.0]),
991                    Array1::from_vec(vec![1.0, 1.0]),
992                )))
993                .expect("add subproblem");
994        }
995
996        let result = solver.solve().expect("solve");
997        assert!(result.converged, "should have converged");
998    }
999
1000    // ─── 9. ConsensusAdmm – weighted average ────────────────────────────
1001
1002    #[test]
1003    fn test_consensus_admm_weighted() {
1004        let dim = 1usize;
1005        // Two projection sub-problems on disjoint intervals:
1006        //   Agent 0: [0, 0.4] (weight 0.3)
1007        //   Agent 1: [0.6, 1.0] (weight 0.7)
1008        // Weighted average of endpoints ≈ 0.3*0.4 + 0.7*0.6 = 0.54  (roughly)
1009        let weights = Array1::from_vec(vec![0.3f32, 0.7]);
1010        let cfg = AdmmConfig::default()
1011            .with_max_iterations(500)
1012            .with_abs_tol(1e-3);
1013
1014        let mut solver = ConsensusAdmm::new_weighted(cfg, dim, weights).expect("new_weighted");
1015
1016        solver
1017            .add_subproblem(Box::new(ProjectionSubproblem::new(
1018                Array1::from_vec(vec![0.0]),
1019                Array1::from_vec(vec![0.4]),
1020            )))
1021            .expect("add agent 0");
1022        solver
1023            .add_subproblem(Box::new(ProjectionSubproblem::new(
1024                Array1::from_vec(vec![0.6]),
1025                Array1::from_vec(vec![1.0]),
1026            )))
1027            .expect("add agent 1");
1028
1029        let result = solver.solve().expect("weighted consensus solve");
1030        // Solution should be between the two boxes
1031        assert!(result.solution[0] >= 0.0 - 1e-4);
1032        assert!(result.solution[0] <= 1.0 + 1e-4);
1033    }
1034
1035    // ─── 10. Warm start reduces iteration count ──────────────────────────
1036
1037    #[test]
1038    fn test_admm_warm_start() {
1039        let dim = 3usize;
1040        let cfg = AdmmConfig::default()
1041            .with_rho(2.0)
1042            .with_max_iterations(500)
1043            .with_abs_tol(1e-5);
1044
1045        let make_solver = |cfg: AdmmConfig| -> DistributedAdmm {
1046            let mut s = DistributedAdmm::new(cfg, dim);
1047            for _ in 0..2 {
1048                s.add_subproblem(Box::new(ProjectionSubproblem::new(
1049                    Array1::from_vec(vec![0.5, 0.5, 0.5]),
1050                    Array1::from_vec(vec![1.0, 1.0, 1.0]),
1051                )))
1052                .expect("add subproblem");
1053            }
1054            s
1055        };
1056
1057        let cold = make_solver(cfg.clone()).solve().expect("cold solve");
1058        // Warm start from the already-computed solution
1059        let warm = make_solver(cfg)
1060            .solve_warm(cold.solution.clone())
1061            .expect("warm solve");
1062
1063        // Warm start should converge in fewer or equal iterations
1064        // (with a perfect warm start it may take 1 iteration)
1065        assert!(
1066            warm.iterations <= cold.iterations,
1067            "warm ({}) should not exceed cold ({})",
1068            warm.iterations,
1069            cold.iterations
1070        );
1071    }
1072
1073    // ─── 11. Error: no subproblems ───────────────────────────────────────
1074
1075    #[test]
1076    fn test_admm_no_subproblems_error() {
1077        let cfg = AdmmConfig::default();
1078        let solver = DistributedAdmm::new(cfg, 3);
1079        let result = solver.solve();
1080        assert!(
1081            matches!(result, Err(AdmmError::NoSubproblems)),
1082            "expected NoSubproblems error"
1083        );
1084    }
1085
1086    // ─── 12. Error: dimension mismatch ──────────────────────────────────
1087
1088    #[test]
1089    fn test_admm_dimension_mismatch() {
1090        let cfg = AdmmConfig::default();
1091        let mut solver = DistributedAdmm::new(cfg, 3);
1092        // Add a sub-problem with dim = 5 (wrong)
1093        let result = solver.add_subproblem(Box::new(ProjectionSubproblem::new(
1094            Array1::from_vec(vec![0.0; 5]),
1095            Array1::from_vec(vec![1.0; 5]),
1096        )));
1097        assert!(
1098            matches!(
1099                result,
1100                Err(AdmmError::DimensionMismatch {
1101                    expected: 3,
1102                    got: 5
1103                })
1104            ),
1105            "expected DimensionMismatch error"
1106        );
1107    }
1108}