1use scirs2_core::ndarray::{Array1, Array2};
18use thiserror::Error;
19
/// Errors produced by the ADMM solvers and their subproblems.
#[derive(Debug, Error)]
pub enum AdmmError {
    /// A vector or matrix did not have the size the solver expected.
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },

    /// `solve` was called before any subproblem was registered.
    #[error("No subproblems added")]
    NoSubproblems,

    /// An agent's local `solve` failed (message preserved verbatim).
    #[error("Subproblem solve failed: {0}")]
    SubproblemFailed(String),

    /// The iteration budget ran out before the residuals converged.
    #[error("Maximum iterations reached without convergence")]
    MaxIterationsReached,

    /// A numerical failure, e.g. a non-finite iterate or a near-singular diagonal.
    #[error("Numerical error: {0}")]
    NumericalError(String),
}
45
/// Tunable parameters shared by the ADMM solvers.
#[derive(Debug, Clone)]
pub struct AdmmConfig {
    /// Penalty parameter rho (> 0) scaling the augmented-Lagrangian term.
    pub rho: f32,
    /// Maximum number of outer ADMM iterations.
    pub max_iterations: usize,
    /// Absolute tolerance used in the stopping criterion.
    pub abs_tol: f32,
    /// Relative tolerance used in the stopping criterion.
    pub rel_tol: f32,
    /// Over-relaxation factor alpha; 1.0 disables, values in (1, 2) may accelerate.
    pub over_relaxation: f32,
    /// When true, per-iteration residuals are logged via `tracing::debug!`.
    pub verbose: bool,
}
68
69impl Default for AdmmConfig {
70 fn default() -> Self {
71 Self {
72 rho: 1.0,
73 max_iterations: 100,
74 abs_tol: 1e-4,
75 rel_tol: 1e-3,
76 over_relaxation: 1.0,
77 verbose: false,
78 }
79 }
80}
81
82impl AdmmConfig {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub fn with_rho(mut self, rho: f32) -> Self {
90 self.rho = rho;
91 self
92 }
93
94 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
96 self.max_iterations = max_iterations;
97 self
98 }
99
100 pub fn with_abs_tol(mut self, abs_tol: f32) -> Self {
102 self.abs_tol = abs_tol;
103 self
104 }
105
106 pub fn with_rel_tol(mut self, rel_tol: f32) -> Self {
108 self.rel_tol = rel_tol;
109 self
110 }
111
112 pub fn with_over_relaxation(mut self, alpha: f32) -> Self {
114 self.over_relaxation = alpha;
115 self
116 }
117
118 pub fn with_verbose(mut self, verbose: bool) -> Self {
120 self.verbose = verbose;
121 self
122 }
123}
124
/// Outcome of an ADMM run.
#[derive(Debug, Clone)]
pub struct AdmmResult {
    /// Final consensus iterate z.
    pub solution: Array1<f32>,
    /// Number of outer iterations actually performed.
    pub iterations: usize,
    /// True if the stopping criterion fired before the iteration budget ran out.
    pub converged: bool,
    /// Final primal residual (RMS over agents — see `admm_sweep`).
    pub primal_residual: f32,
    /// Final dual residual, rho * ||z_{k+1} - z_k||.
    pub dual_residual: f32,
    /// Sum of the agents' local objectives evaluated at `solution`.
    pub objective: f32,
}
143
/// A local optimization problem owned by one agent in the consensus scheme.
///
/// `Send + Sync` is required so the per-agent x-updates can run in parallel
/// (see `admm_sweep`, which uses rayon's `par_iter`).
pub trait AdmmSubproblem: Send + Sync {
    /// Solves this agent's x-update given the consensus iterate `z`, the
    /// scaled dual variable `u`, and the penalty parameter `rho`.
    fn solve(&self, z: &Array1<f32>, u: &Array1<f32>, rho: f32) -> Result<Array1<f32>, AdmmError>;

    /// Evaluates this agent's local objective at `x`.
    fn objective(&self, x: &Array1<f32>) -> f32;

    /// Dimension of this agent's decision variable.
    fn dim(&self) -> usize;
}
169
170#[inline]
174fn l2_norm(v: &Array1<f32>) -> f32 {
175 v.iter().map(|&x| x * x).sum::<f32>().sqrt()
176}
177
178fn soft_threshold(v: &Array1<f32>, kappa: f32) -> Array1<f32> {
182 v.mapv(|x| {
183 if x > kappa {
184 x - kappa
185 } else if x < -kappa {
186 x + kappa
187 } else {
188 0.0
189 }
190 })
191}
192
193fn box_clip(v: &Array1<f32>, lb: &Array1<f32>, ub: &Array1<f32>) -> Array1<f32> {
195 v.iter()
196 .zip(lb.iter())
197 .zip(ub.iter())
198 .map(|((&vi, &li), &ui)| vi.clamp(li, ui))
199 .collect()
200}
201
202fn gauss_seidel_solve(
208 q: &Array2<f32>,
209 rho: f32,
210 b: &Array1<f32>,
211) -> Result<Array1<f32>, AdmmError> {
212 let n = b.len();
213 if q.nrows() != n || q.ncols() != n {
214 return Err(AdmmError::DimensionMismatch {
215 expected: n,
216 got: q.nrows(),
217 });
218 }
219
220 let mut x = Array1::<f32>::zeros(n);
221 let max_inner = 200usize;
224 for _iter in 0..max_inner {
225 let x_old = x.clone();
226 for i in 0..n {
227 let a_ii = q[[i, i]] + rho;
228 if a_ii.abs() < f32::EPSILON {
229 return Err(AdmmError::NumericalError(format!(
230 "Near-zero diagonal at index {i}"
231 )));
232 }
233 let mut sum = 0.0f32;
234 for j in 0..n {
235 if j != i {
236 sum += (q[[i, j]] + if i == j { rho } else { 0.0 }) * x[j];
237 }
238 }
239 x[i] = (b[i] - sum) / a_ii;
240 }
241 let diff: f32 = x
243 .iter()
244 .zip(x_old.iter())
245 .map(|(a, b)| (a - b) * (a - b))
246 .sum::<f32>()
247 .sqrt();
248 if diff < 1e-8 {
249 break;
250 }
251 }
252
253 for &xi in x.iter() {
255 if !xi.is_finite() {
256 return Err(AdmmError::NumericalError(
257 "Gauss-Seidel produced non-finite value".into(),
258 ));
259 }
260 }
261 Ok(x)
262}
263
/// Quadratic agent objective `0.5 x'Qx + c'x`, optionally box-constrained.
#[derive(Debug, Clone)]
pub struct QuadraticSubproblem {
    /// Quadratic term Q (n x n; should be symmetric PSD for a convex
    /// problem — NOTE(review): not validated here).
    pub q: Array2<f32>,
    /// Linear term c (length n); also defines the problem dimension.
    pub c: Array1<f32>,
    /// Optional element-wise lower bounds.
    pub lb: Option<Array1<f32>>,
    /// Optional element-wise upper bounds.
    pub ub: Option<Array1<f32>>,
}
285
286impl QuadraticSubproblem {
287 pub fn new(q: Array2<f32>, c: Array1<f32>) -> Self {
289 Self {
290 q,
291 c,
292 lb: None,
293 ub: None,
294 }
295 }
296
297 pub fn with_bounds(mut self, lb: Array1<f32>, ub: Array1<f32>) -> Self {
299 self.lb = Some(lb);
300 self.ub = Some(ub);
301 self
302 }
303}
304
305impl AdmmSubproblem for QuadraticSubproblem {
306 fn solve(&self, z: &Array1<f32>, u: &Array1<f32>, rho: f32) -> Result<Array1<f32>, AdmmError> {
307 let n = self.dim();
308 if z.len() != n || u.len() != n {
309 return Err(AdmmError::DimensionMismatch {
310 expected: n,
311 got: z.len(),
312 });
313 }
314 let rhs: Array1<f32> = (z - u).mapv(|v| rho * v) - &self.c;
316
317 let x = gauss_seidel_solve(&self.q, rho, &rhs)?;
319
320 let x = match (&self.lb, &self.ub) {
322 (Some(lb), Some(ub)) => box_clip(&x, lb, ub),
323 (Some(lb), None) => x.mapv(|v| v.max(lb[0])),
324 (None, Some(ub)) => x.mapv(|v| v.min(ub[0])),
325 (None, None) => x,
326 };
327
328 Ok(x)
329 }
330
331 fn objective(&self, x: &Array1<f32>) -> f32 {
332 let qx: Array1<f32> = self.q.dot(x);
334 0.5 * x.iter().zip(qx.iter()).map(|(a, b)| a * b).sum::<f32>()
335 + x.iter().zip(self.c.iter()).map(|(a, b)| a * b).sum::<f32>()
336 }
337
338 fn dim(&self) -> usize {
339 self.c.len()
340 }
341}
342
/// Lasso agent objective `||A x - b||^2 + lambda ||x||_1`.
#[derive(Debug, Clone)]
pub struct LassoSubproblem {
    /// Design matrix A (m x n).
    pub a: Array2<f32>,
    /// Target vector b (length m).
    pub b: Array1<f32>,
    /// L1 regularization strength (expected >= 0; not validated).
    pub lambda: f32,
}
361
362impl LassoSubproblem {
363 pub fn new(a: Array2<f32>, b: Array1<f32>, lambda: f32) -> Self {
365 Self { a, b, lambda }
366 }
367}
368
impl AdmmSubproblem for LassoSubproblem {
    /// x-update for the lasso term: solves the ridge-regularized normal
    /// equations `(A'A + rho I) v = A'b + rho (z - u)` and then applies the
    /// soft-threshold prox with `kappa = lambda / rho` to impose the L1 term.
    ///
    /// NOTE(review): `A'A` and `A'b` are recomputed on every call; for large
    /// problems, caching them in the struct would avoid repeated work —
    /// confirm before changing the struct layout.
    fn solve(&self, z: &Array1<f32>, u: &Array1<f32>, rho: f32) -> Result<Array1<f32>, AdmmError> {
        let n = self.dim();
        if z.len() != n || u.len() != n {
            return Err(AdmmError::DimensionMismatch {
                expected: n,
                got: z.len(),
            });
        }

        // Normal-equation pieces: A'A (system matrix) and A'b (data term).
        let at = self.a.t();
        let ata: Array2<f32> = at.dot(&self.a);

        let atb: Array1<f32> = at.dot(&self.b);
        let rhs: Array1<f32> = atb + (z - u).mapv(|v| rho * v);

        let v = gauss_seidel_solve(&ata, rho, &rhs)?;

        // Shrinkage step (prox of the L1 penalty).
        let kappa = self.lambda / rho;
        Ok(soft_threshold(&v, kappa))
    }

    /// Evaluates `||A x - b||^2 + lambda ||x||_1`.
    /// NOTE(review): the quadratic part is not halved (no 0.5 factor) —
    /// presumably a deliberate convention; verify consumers expect this scaling.
    fn objective(&self, x: &Array1<f32>) -> f32 {
        let ax_minus_b: Array1<f32> = self.a.dot(x) - &self.b;
        let l2_sq: f32 = ax_minus_b.iter().map(|&v| v * v).sum();
        let l1: f32 = x.iter().map(|&v| v.abs()).sum();
        l2_sq + self.lambda * l1
    }

    /// Dimension equals the number of columns of `A`.
    fn dim(&self) -> usize {
        self.a.ncols()
    }
}
406
/// Feasibility-only agent: the indicator of the box `[lb, ub]`.
#[derive(Debug, Clone)]
pub struct ProjectionSubproblem {
    /// Element-wise lower bounds; also defines the problem dimension.
    pub lb: Array1<f32>,
    /// Element-wise upper bounds (expected `ub >= lb` component-wise —
    /// NOTE(review): not validated here; `f32::clamp` panics if violated).
    pub ub: Array1<f32>,
}
421
422impl ProjectionSubproblem {
423 pub fn new(lb: Array1<f32>, ub: Array1<f32>) -> Self {
425 Self { lb, ub }
426 }
427}
428
429impl AdmmSubproblem for ProjectionSubproblem {
430 fn solve(&self, z: &Array1<f32>, u: &Array1<f32>, _rho: f32) -> Result<Array1<f32>, AdmmError> {
431 let n = self.dim();
432 if z.len() != n || u.len() != n {
433 return Err(AdmmError::DimensionMismatch {
434 expected: n,
435 got: z.len(),
436 });
437 }
438 let v: Array1<f32> = z - u;
439 Ok(box_clip(&v, &self.lb, &self.ub))
440 }
441
442 fn objective(&self, _x: &Array1<f32>) -> f32 {
443 0.0
446 }
447
448 fn dim(&self) -> usize {
449 self.lb.len()
450 }
451}
452
/// Output of one ADMM sweep: (per-agent x iterates, per-agent scaled duals u,
/// new consensus z, primal residual, dual residual).
type AdmmSweepOutput = (Vec<Array1<f32>>, Vec<Array1<f32>>, Array1<f32>, f32, f32);
458
/// Performs one synchronous ADMM iteration over all agents:
/// (1) parallel per-agent x-updates, (2) (weighted) averaging of x and u,
/// (3) over-relaxed consensus z-update, (4) dual u-updates plus residuals.
///
/// `weights`, when given, must hold one entry per agent and is applied as-is
/// (callers are expected to have validated length and normalization).
fn admm_sweep(
    subproblems: &[Box<dyn AdmmSubproblem>],
    z: &Array1<f32>,
    us: &[Array1<f32>],
    rho: f32,
    alpha: f32,
    weights: Option<&Array1<f32>>,
) -> Result<AdmmSweepOutput, AdmmError> {
    let n_agents = subproblems.len();
    let dim = z.len();

    use rayon::prelude::*;

    // (1) Local x-updates, one per agent, in parallel.
    let x_results: Vec<Result<Array1<f32>, AdmmError>> = subproblems
        .par_iter()
        .enumerate()
        .map(|(i, sp)| sp.solve(z, &us[i], rho))
        .collect();

    // Propagate the first failure (if any) after the parallel section.
    let mut new_xs: Vec<Array1<f32>> = Vec::with_capacity(n_agents);
    for result in x_results {
        new_xs.push(result?);
    }

    // Keep the previous consensus iterate for over-relaxation and the dual residual.
    let z_old = z.clone();

    // (2) Average local solutions and duals, weighted when weights are given,
    // uniformly (1/n) otherwise.
    let (x_avg, u_avg) = match weights {
        Some(w) => {
            let mut xa = Array1::<f32>::zeros(dim);
            let mut ua = Array1::<f32>::zeros(dim);
            for i in 0..n_agents {
                xa = xa + new_xs[i].mapv(|v| v * w[i]);
                ua = ua + us[i].mapv(|v| v * w[i]);
            }
            (xa, ua)
        }
        None => {
            let mut xa = Array1::<f32>::zeros(dim);
            let mut ua = Array1::<f32>::zeros(dim);
            for i in 0..n_agents {
                xa += &new_xs[i];
                ua += &us[i];
            }
            let inv_n = 1.0 / n_agents as f32;
            (xa.mapv(|v| v * inv_n), ua.mapv(|v| v * inv_n))
        }
    };

    // (3) Over-relaxed consensus update: z_new = alpha*x_avg + (1-alpha)*z_old + u_avg.
    let z_new: Array1<f32> = x_avg.mapv(|v| alpha * v) + z_old.mapv(|v| (1.0 - alpha) * v) + &u_avg;

    // (4) Dual updates u_i += x~_i - z_new, where x~_i is the over-relaxed
    // local iterate; the same per-agent residuals feed the primal residual.
    let mut new_us: Vec<Array1<f32>> = Vec::with_capacity(n_agents);
    let mut primal_sq = 0.0f32;
    for i in 0..n_agents {
        let x_tilde: Array1<f32> =
            new_xs[i].mapv(|v| alpha * v) + z_old.mapv(|v| (1.0 - alpha) * v);
        let residual_i: Array1<f32> = &x_tilde - &z_new;
        primal_sq += residual_i.iter().map(|&v| v * v).sum::<f32>();
        let u_new: Array1<f32> = &us[i] + &residual_i;
        new_us.push(u_new);
    }
    // RMS-style primal residual, averaged over agents.
    let primal_res = (primal_sq / n_agents as f32).sqrt();

    // Dual residual: rho * ||z_new - z_old||.
    let dual_res = rho * l2_norm(&(&z_new - &z_old));

    Ok((new_xs, new_us, z_new, primal_res, dual_res))
}
540
541fn check_convergence(
546 primal_res: f32,
547 dual_res: f32,
548 config: &AdmmConfig,
549 n_agents: usize,
550 dim: usize,
551) -> bool {
552 let scale = ((n_agents * dim) as f32).sqrt();
553 let eps_primal = config.abs_tol * scale + config.rel_tol;
554 let eps_dual = config.abs_tol * scale + config.rel_tol;
555 primal_res < eps_primal && dual_res < eps_dual
556}
557
/// Global-consensus ADMM solver with uniform averaging across agents.
pub struct DistributedAdmm {
    /// Solver parameters.
    config: AdmmConfig,
    /// One local subproblem per agent.
    subproblems: Vec<Box<dyn AdmmSubproblem>>,
    /// Dimension of the shared decision variable.
    dim: usize,
}
569
570impl DistributedAdmm {
571 pub fn new(config: AdmmConfig, dim: usize) -> Self {
573 Self {
574 config,
575 subproblems: Vec::new(),
576 dim,
577 }
578 }
579
580 pub fn add_subproblem(&mut self, subproblem: Box<dyn AdmmSubproblem>) -> Result<(), AdmmError> {
585 if subproblem.dim() != self.dim {
586 return Err(AdmmError::DimensionMismatch {
587 expected: self.dim,
588 got: subproblem.dim(),
589 });
590 }
591 self.subproblems.push(subproblem);
592 Ok(())
593 }
594
595 pub fn num_subproblems(&self) -> usize {
597 self.subproblems.len()
598 }
599
600 pub fn solve(&self) -> Result<AdmmResult, AdmmError> {
602 self.solve_warm(Array1::zeros(self.dim))
603 }
604
605 pub fn solve_warm(&self, z_init: Array1<f32>) -> Result<AdmmResult, AdmmError> {
607 if self.subproblems.is_empty() {
608 return Err(AdmmError::NoSubproblems);
609 }
610 if z_init.len() != self.dim {
611 return Err(AdmmError::DimensionMismatch {
612 expected: self.dim,
613 got: z_init.len(),
614 });
615 }
616
617 let n_agents = self.subproblems.len();
618 let mut z = z_init;
619 let mut us: Vec<Array1<f32>> = vec![Array1::zeros(self.dim); n_agents];
620
621 let mut primal_res = f32::INFINITY;
622 let mut dual_res = f32::INFINITY;
623 let mut iterations = 0usize;
624 let mut converged = false;
625
626 for iter in 0..self.config.max_iterations {
627 iterations = iter + 1;
628 let (new_xs, new_us, z_new, pr, dr) = admm_sweep(
629 &self.subproblems,
630 &z,
631 &us,
632 self.config.rho,
633 self.config.over_relaxation,
634 None,
635 )?;
636
637 let _ = new_xs; us = new_us;
639 z = z_new;
640 primal_res = pr;
641 dual_res = dr;
642
643 if self.config.verbose {
644 tracing::debug!(iter = iterations, primal_res, dual_res, "ADMM iteration");
645 }
646
647 if check_convergence(primal_res, dual_res, &self.config, n_agents, self.dim) {
648 converged = true;
649 break;
650 }
651 }
652
653 let objective: f32 = self.subproblems.iter().map(|sp| sp.objective(&z)).sum();
655
656 Ok(AdmmResult {
657 solution: z,
658 iterations,
659 converged,
660 primal_residual: primal_res,
661 dual_residual: dual_res,
662 objective,
663 })
664 }
665}
666
/// Consensus ADMM solver supporting optional per-agent averaging weights.
pub struct ConsensusAdmm {
    /// Solver parameters.
    config: AdmmConfig,
    /// One local subproblem per agent.
    subproblems: Vec<Box<dyn AdmmSubproblem>>,
    /// Dimension of the shared decision variable.
    dim: usize,
    /// Optional averaging weights, one per agent; `None` means uniform 1/n.
    weights: Option<Array1<f32>>,
}
680
681impl ConsensusAdmm {
682 pub fn new(config: AdmmConfig, dim: usize) -> Self {
684 Self {
685 config,
686 subproblems: Vec::new(),
687 dim,
688 weights: None,
689 }
690 }
691
692 pub fn new_weighted(
697 config: AdmmConfig,
698 dim: usize,
699 weights: Array1<f32>,
700 ) -> Result<Self, AdmmError> {
701 let sum: f32 = weights.iter().sum();
702 if (sum - 1.0).abs() > 1e-4 {
703 return Err(AdmmError::NumericalError(format!(
704 "Weights must sum to 1.0, got {sum:.6}"
705 )));
706 }
707 Ok(Self {
708 config,
709 subproblems: Vec::new(),
710 dim,
711 weights: Some(weights),
712 })
713 }
714
715 pub fn add_subproblem(&mut self, subproblem: Box<dyn AdmmSubproblem>) -> Result<(), AdmmError> {
717 if subproblem.dim() != self.dim {
718 return Err(AdmmError::DimensionMismatch {
719 expected: self.dim,
720 got: subproblem.dim(),
721 });
722 }
723 self.subproblems.push(subproblem);
724 Ok(())
725 }
726
727 pub fn solve(&self) -> Result<AdmmResult, AdmmError> {
729 if self.subproblems.is_empty() {
730 return Err(AdmmError::NoSubproblems);
731 }
732
733 let n_agents = self.subproblems.len();
734
735 if let Some(w) = &self.weights {
737 if w.len() != n_agents {
738 return Err(AdmmError::DimensionMismatch {
739 expected: n_agents,
740 got: w.len(),
741 });
742 }
743 }
744
745 let mut z = Array1::<f32>::zeros(self.dim);
746 let mut us: Vec<Array1<f32>> = vec![Array1::zeros(self.dim); n_agents];
747
748 let mut primal_res = f32::INFINITY;
749 let mut dual_res = f32::INFINITY;
750 let mut iterations = 0usize;
751 let mut converged = false;
752
753 for iter in 0..self.config.max_iterations {
754 iterations = iter + 1;
755 let (new_xs, new_us, z_new, pr, dr) = admm_sweep(
756 &self.subproblems,
757 &z,
758 &us,
759 self.config.rho,
760 self.config.over_relaxation,
761 self.weights.as_ref(),
762 )?;
763
764 let _ = new_xs;
765 us = new_us;
766 z = z_new;
767 primal_res = pr;
768 dual_res = dr;
769
770 if self.config.verbose {
771 tracing::debug!(
772 iter = iterations,
773 primal_res,
774 dual_res,
775 "ConsensusADMM iteration"
776 );
777 }
778
779 if check_convergence(primal_res, dual_res, &self.config, n_agents, self.dim) {
780 converged = true;
781 break;
782 }
783 }
784
785 let objective: f32 = self.subproblems.iter().map(|sp| sp.objective(&z)).sum();
786
787 Ok(AdmmResult {
788 solution: z,
789 iterations,
790 converged,
791 primal_residual: primal_res,
792 dual_residual: dual_res,
793 objective,
794 })
795 }
796}
797
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    // Test helper: builds `scale * I` as a dense n x n matrix.
    fn eye(n: usize, scale: f32) -> Array2<f32> {
        let mut m = Array2::<f32>::zeros((n, n));
        for i in 0..n {
            m[[i, i]] = scale;
        }
        m
    }

    // Defaults should be sane: positive tolerances and alpha within (0, 2).
    #[test]
    fn test_admm_config_default() {
        let cfg = AdmmConfig::default();
        assert!(cfg.rho > 0.0, "rho must be positive");
        assert!(cfg.max_iterations > 0, "max_iterations must be positive");
        assert!(cfg.abs_tol > 0.0, "abs_tol must be positive");
        assert!(cfg.rel_tol > 0.0, "rel_tol must be positive");
        assert!(
            (0.0..2.0).contains(&cfg.over_relaxation),
            "over_relaxation must be in (0, 2)"
        );
    }

    // Projection onto a box must clamp each coordinate independently.
    #[test]
    fn test_projection_subproblem() {
        let lb = Array1::from_vec(vec![0.0, -1.0, 2.0]);
        let ub = Array1::from_vec(vec![1.0, 1.0, 5.0]);
        let sp = ProjectionSubproblem::new(lb.clone(), ub.clone());

        let z = Array1::from_vec(vec![-2.0, 3.0, 10.0]);
        let u = Array1::zeros(3);
        let x = sp.solve(&z, &u, 1.0).expect("projection should succeed");
        assert!((x[0] - 0.0).abs() < 1e-6, "should clip to lb[0]");
        assert!((x[1] - 1.0).abs() < 1e-6, "should clip to ub[1]");
        assert!((x[2] - 5.0).abs() < 1e-6, "should clip to ub[2]");
    }

    // With a large lambda the soft-threshold step must shrink the solution.
    #[test]
    fn test_lasso_subproblem_soft_threshold() {
        let n = 3usize;
        let a = eye(n, 1.0);
        let b = Array1::zeros(n);
        let lambda = 10.0f32;
        let sp = LassoSubproblem::new(a, b, lambda);

        let z = Array1::from_vec(vec![0.5, -0.5, 0.2]);
        let u = Array1::zeros(n);
        let x = sp.solve(&z, &u, 1.0).expect("lasso solve");

        for &xi in x.iter() {
            assert!(xi.abs() < 0.6, "soft-threshold should shrink the solution");
        }
    }

    // Unconstrained QP with Q = I, c = -2, rho = 1, z = 2:
    // (Q + rho I) x = rho z - c = 4, hence x = 2.
    #[test]
    fn test_quadratic_subproblem_unconstrained() {
        let n = 2usize;
        let q = eye(n, 1.0);
        let c = Array1::from_vec(vec![-2.0, -2.0]);
        let sp = QuadraticSubproblem::new(q, c);

        let z = Array1::from_vec(vec![2.0, 2.0]);
        let u = Array1::zeros(n);
        let x = sp.solve(&z, &u, 1.0).expect("quadratic solve");
        assert!((x[0] - 2.0).abs() < 1e-3, "x[0] ≈ 2: got {}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-3, "x[1] ≈ 2: got {}", x[1]);
    }

    // Box bounds must be honored even when the unconstrained optimum lies
    // outside the feasible region.
    #[test]
    fn test_quadratic_subproblem_constrained() {
        let n = 2usize;
        let q = eye(n, 1.0);
        let c = Array1::from_vec(vec![-5.0, -5.0]);
        let lb = Array1::from_vec(vec![0.0, 0.0]);
        let ub = Array1::from_vec(vec![1.0, 1.0]);
        let sp = QuadraticSubproblem::new(q, c).with_bounds(lb, ub);

        let z = Array1::from_vec(vec![3.0, 3.0]);
        let u = Array1::zeros(n);
        let x = sp.solve(&z, &u, 1.0).expect("constrained quadratic solve");
        assert!(x[0] <= 1.0 + 1e-6, "x[0] must be ≤ ub");
        assert!(x[1] <= 1.0 + 1e-6, "x[1] must be ≤ ub");
        assert!(x[0] >= 0.0 - 1e-6, "x[0] must be ≥ lb");
    }

    // Three box agents whose feasible regions overlap only near (1, 1):
    // the consensus iterate should land in that intersection.
    #[test]
    fn test_distributed_admm_consensus() {
        let dim = 2usize;
        let cfg = AdmmConfig::default()
            .with_rho(2.0)
            .with_max_iterations(200)
            .with_abs_tol(1e-4);

        let mut solver = DistributedAdmm::new(cfg, dim);
        solver
            .add_subproblem(Box::new(ProjectionSubproblem::new(
                Array1::from_vec(vec![1.0, 1.0]),
                Array1::from_vec(vec![3.0, 3.0]),
            )))
            .expect("add agent 0");
        solver
            .add_subproblem(Box::new(ProjectionSubproblem::new(
                Array1::from_vec(vec![0.0, 0.0]),
                Array1::from_vec(vec![1.0, 1.0]),
            )))
            .expect("add agent 1");
        solver
            .add_subproblem(Box::new(ProjectionSubproblem::new(
                Array1::from_vec(vec![1.0, 0.0]),
                Array1::from_vec(vec![2.0, 2.0]),
            )))
            .expect("add agent 2");

        let result = solver.solve().expect("distributed ADMM solve");
        assert!((result.solution[0] - 1.0).abs() < 0.1, "x[0] ≈ 1");
        assert!((result.solution[1] - 1.0).abs() < 0.1, "x[1] ≈ 1");
    }

    // Residuals must remain finite for a well-posed trivial problem.
    #[test]
    fn test_distributed_admm_convergence() {
        let dim = 4usize;
        let cfg = AdmmConfig::default()
            .with_max_iterations(500)
            .with_abs_tol(1e-3);

        let mut solver = DistributedAdmm::new(cfg, dim);
        for _ in 0..3 {
            solver
                .add_subproblem(Box::new(ProjectionSubproblem::new(
                    Array1::from_vec(vec![0.0; dim]),
                    Array1::from_vec(vec![1.0; dim]),
                )))
                .expect("add subproblem");
        }

        let result = solver.solve().expect("solve");
        assert!(result.primal_residual.is_finite());
        assert!(result.dual_residual.is_finite());
    }

    // Identical agents agree immediately, so the converged flag must be set.
    #[test]
    fn test_distributed_admm_convergence_flag() {
        let dim = 2usize;
        let cfg = AdmmConfig::default()
            .with_max_iterations(1000)
            .with_abs_tol(1e-3)
            .with_rel_tol(1e-2);

        let mut solver = DistributedAdmm::new(cfg, dim);
        for _ in 0..2 {
            solver
                .add_subproblem(Box::new(ProjectionSubproblem::new(
                    Array1::from_vec(vec![0.0, 0.0]),
                    Array1::from_vec(vec![1.0, 1.0]),
                )))
                .expect("add subproblem");
        }

        let result = solver.solve().expect("solve");
        assert!(result.converged, "should have converged");
    }

    // Weighted consensus over two disjoint boxes: the solution should stay
    // within the hull of the agents' feasible sets.
    #[test]
    fn test_consensus_admm_weighted() {
        let dim = 1usize;
        let weights = Array1::from_vec(vec![0.3f32, 0.7]);
        let cfg = AdmmConfig::default()
            .with_max_iterations(500)
            .with_abs_tol(1e-3);

        let mut solver = ConsensusAdmm::new_weighted(cfg, dim, weights).expect("new_weighted");

        solver
            .add_subproblem(Box::new(ProjectionSubproblem::new(
                Array1::from_vec(vec![0.0]),
                Array1::from_vec(vec![0.4]),
            )))
            .expect("add agent 0");
        solver
            .add_subproblem(Box::new(ProjectionSubproblem::new(
                Array1::from_vec(vec![0.6]),
                Array1::from_vec(vec![1.0]),
            )))
            .expect("add agent 1");

        let result = solver.solve().expect("weighted consensus solve");
        assert!(result.solution[0] >= 0.0 - 1e-4);
        assert!(result.solution[0] <= 1.0 + 1e-4);
    }

    // Warm-starting from a converged solution should take no more iterations
    // than the cold start did.
    #[test]
    fn test_admm_warm_start() {
        let dim = 3usize;
        let cfg = AdmmConfig::default()
            .with_rho(2.0)
            .with_max_iterations(500)
            .with_abs_tol(1e-5);

        let make_solver = |cfg: AdmmConfig| -> DistributedAdmm {
            let mut s = DistributedAdmm::new(cfg, dim);
            for _ in 0..2 {
                s.add_subproblem(Box::new(ProjectionSubproblem::new(
                    Array1::from_vec(vec![0.5, 0.5, 0.5]),
                    Array1::from_vec(vec![1.0, 1.0, 1.0]),
                )))
                .expect("add subproblem");
            }
            s
        };

        let cold = make_solver(cfg.clone()).solve().expect("cold solve");
        let warm = make_solver(cfg)
            .solve_warm(cold.solution.clone())
            .expect("warm solve");

        assert!(
            warm.iterations <= cold.iterations,
            "warm ({}) should not exceed cold ({})",
            warm.iterations,
            cold.iterations
        );
    }

    // Solving with zero registered subproblems is a usage error.
    #[test]
    fn test_admm_no_subproblems_error() {
        let cfg = AdmmConfig::default();
        let solver = DistributedAdmm::new(cfg, 3);
        let result = solver.solve();
        assert!(
            matches!(result, Err(AdmmError::NoSubproblems)),
            "expected NoSubproblems error"
        );
    }

    // Adding a subproblem of the wrong dimension must be rejected eagerly.
    #[test]
    fn test_admm_dimension_mismatch() {
        let cfg = AdmmConfig::default();
        let mut solver = DistributedAdmm::new(cfg, 3);
        let result = solver.add_subproblem(Box::new(ProjectionSubproblem::new(
            Array1::from_vec(vec![0.0; 5]),
            Array1::from_vec(vec![1.0; 5]),
        )));
        assert!(
            matches!(
                result,
                Err(AdmmError::DimensionMismatch {
                    expected: 3,
                    got: 5
                })
            ),
            "expected DimensionMismatch error"
        );
    }
}