// cova_solver/admm.rs
//! ADMM (Alternating Direction Method of Multipliers) using Clarabel
//!
//! This module implements ADMM by using Clarabel to solve the convex subproblems.
//! ADMM solves problems of the form:
//!
//! minimize    f(x) + g(z)
//! subject to  Ax + Bz = c
//!
//! With the scaled dual variable u = y/ρ, the algorithm alternates between:
//! 1. x-update: minimize f(x) + (ρ/2)||Ax + Bz - c + u||²
//! 2. z-update: minimize g(z) + (ρ/2)||Ax + Bz - c + u||²
//! 3. u-update: u := u + (Ax + Bz - c)
14use clarabel::{algebra::*, solver::*};
15use cova_algebra::tensors::{DMatrix, DVector};
16
17use crate::{
18  SolverError, SolverResult,
19  traits::{OptimizationProblem, Solution},
20};
21
22/// ADMM parameters
23#[derive(Debug, Clone)]
24pub struct AdmmParams {
25  /// Penalty parameter (ρ)
26  pub rho:              f64,
27  /// Primal tolerance
28  pub primal_tolerance: f64,
29  /// Dual tolerance
30  pub dual_tolerance:   f64,
31  /// Maximum iterations
32  pub max_iterations:   usize,
33  /// Over-relaxation parameter (α)
34  pub alpha:            f64,
35}
36
37impl Default for AdmmParams {
38  fn default() -> Self {
39    Self {
40      rho:              1.0,
41      primal_tolerance: 1e-6,
42      dual_tolerance:   1e-6,
43      max_iterations:   1000,
44      alpha:            1.0,
45    }
46  }
47}
48
49/// ADMM solver using Clarabel for subproblems
50#[derive(Debug)]
51pub struct AdmmSolver {
52  params: AdmmParams,
53}
54
55impl AdmmSolver {
56  /// Create a new ADMM solver
57  pub fn new() -> Self { Self { params: AdmmParams::default() } }
58
59  /// Create ADMM solver with custom parameters
60  pub fn with_params(params: AdmmParams) -> Self { Self { params } }
61
62  /// Solve quadratic programming with equality constraints using ADMM
63  ///
64  /// minimize    (1/2) x^T P x + q^T x + g(z)
65  /// subject to  Ax + Bz = c
66  ///
67  /// where g(z) is handled by the z-update (e.g., indicator functions, norms)
68  pub fn solve_qp_admm<F>(
69    &mut self,
70    p: &DMatrix<f64>,
71    q: &DVector<f64>,
72    a: &DMatrix<f64>,
73    b: &DMatrix<f64>,
74    c: &DVector<f64>,
75    z_update: F,
76  ) -> SolverResult<Solution>
77  where
78    F: Fn(&DVector<f64>, &DVector<f64>, f64) -> DVector<f64>,
79  {
80    let n = q.len();
81    let m = c.len();
82    let z_dim = b.ncols();
83
84    // Initialize variables
85    let mut x = DVector::zeros(n);
86    let mut z = DVector::zeros(z_dim);
87    let mut u = DVector::zeros(m); // Dual variable
88
89    // Set up Clarabel solver for x-update
90    // The x-update subproblem is:
91    // minimize (1/2) x^T P x + q^T x + (ρ/2)||Ax + Bz - c + u||²
92    // This becomes: minimize (1/2) x^T (P + ρA^TA) x + (q + ρA^T(Bz - c + u))^T x
93
94    let ata = a.transpose() * a;
95    let p_aug = p + self.params.rho * &ata;
96
97    for iteration in 0..self.params.max_iterations {
98      let z_old = z.clone();
99
100      // x-update: solve QP using Clarabel
101      let rhs = b * &z - c + &u;
102      let q_aug = q + self.params.rho * (a.transpose() * &rhs);
103
104      x = self.solve_x_update(&p_aug, &q_aug)?;
105
106      // z-update: apply proximal operator (problem-specific)
107      let ax_plus_u = a * &x + &u;
108      let z_target = &ax_plus_u + c;
109      z = z_update(&z_target, &z_old, self.params.rho);
110
111      // u-update: dual variable update
112      let residual = a * &x + b * &z - c;
113      u = &u + self.params.rho * &residual;
114
115      // Check convergence
116      let primal_residual = residual.norm();
117      let dual_residual = self.params.rho * (a.transpose() * (&z - &z_old)).norm();
118
119      if primal_residual <= self.params.primal_tolerance
120        && dual_residual <= self.params.dual_tolerance
121      {
122        let objective_value = 0.5f64.mul_add(x.dot(&(p * &x)), q.dot(&x));
123        return Ok(Solution {
124          x,
125          objective_value,
126          iterations: iteration as u64 + 1,
127          converged: true,
128          termination: "Converged".to_string(),
129        });
130      }
131    }
132
133    let objective_value = 0.5f64.mul_add(x.dot(&(p * &x)), q.dot(&x));
134    Ok(Solution {
135      x,
136      objective_value,
137      iterations: self.params.max_iterations as u64,
138      converged: false,
139      termination: "MaxIterations".to_string(),
140    })
141  }
142
143  /// Solve the x-update subproblem using Clarabel
144  fn solve_x_update(&self, p: &DMatrix<f64>, q: &DVector<f64>) -> SolverResult<DVector<f64>> {
145    let n = q.len();
146
147    // Convert dense P matrix to sparse CSC format for Clarabel
148    // For now, use a simple dense-to-sparse conversion
149    let (col_offsets, row_indices, values) = dense_to_csc(p);
150    let p_csc = CscMatrix::new(n, n, col_offsets, row_indices, values);
151
152    let q_vec: Vec<f64> = q.iter().cloned().collect();
153
154    // No constraints for the x-update (it's unconstrained QP)
155    let a_csc = CscMatrix::new(0, n, vec![0; n + 1], vec![], vec![]);
156    let b_vec: Vec<f64> = Vec::new();
157    let cones: Vec<SupportedConeT<f64>> = Vec::new();
158
159    // Set up Clarabel settings
160    let settings = DefaultSettingsBuilder::default().max_iter(1000).verbose(false).build().unwrap();
161
162    // Create and solve the problem
163    let mut solver =
164      DefaultSolver::new(&p_csc, &q_vec, &a_csc, &b_vec, &cones, settings).map_err(|e| {
165        SolverError::NumericalError { message: format!("Failed to create Clarabel solver: {e:?}") }
166      })?;
167
168    solver.solve();
169
170    let result = solver.solution;
171    match result.status {
172      SolverStatus::Solved => Ok(DVector::from_vec(result.x)),
173      _ => Err(SolverError::NumericalError {
174        message: format!("Clarabel failed with status: {:?}", result.status),
175      }),
176    }
177  }
178
179  /// Solve LASSO problem: minimize ||Ax - b||² + λ||x||₁
180  pub fn solve_lasso(
181    &mut self,
182    a: &DMatrix<f64>,
183    b: &DVector<f64>,
184    lambda: f64,
185  ) -> SolverResult<Solution> {
186    let n = a.ncols();
187
188    // LASSO as ADMM:
189    // minimize ||Ax - b||² + λ||z||₁
190    // subject to x = z
191
192    let p = 2.0 * (a.transpose() * a);
193    let q = -2.0 * (a.transpose() * b);
194    let a_constraint = DMatrix::identity(n, n);
195    let b_constraint = -DMatrix::identity(n, n);
196    let c = DVector::zeros(n);
197
198    // z-update for LASSO: soft thresholding
199    let z_update = move |target: &DVector<f64>, _z_old: &DVector<f64>, rho: f64| {
200      let threshold = lambda / rho;
201      target.map(|val| {
202        if val > threshold {
203          val - threshold
204        } else if val < -threshold {
205          val + threshold
206        } else {
207          0.0
208        }
209      })
210    };
211
212    self.solve_qp_admm(&p, &q, &a_constraint, &b_constraint, &c, z_update)
213  }
214
215  /// Solve basis pursuit: minimize ||x||₁ subject to Ax = b
216  pub fn solve_basis_pursuit(
217    &mut self,
218    a: &DMatrix<f64>,
219    b: &DVector<f64>,
220  ) -> SolverResult<Solution> {
221    // Use LASSO with very small λ to approximate basis pursuit
222    self.solve_lasso(a, b, 1e-8)
223  }
224}
225
226/// Convert dense matrix to CSC format (col_offsets, row_indices, values)
227fn dense_to_csc(matrix: &DMatrix<f64>) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
228  let nrows = matrix.nrows();
229  let ncols = matrix.ncols();
230
231  let mut col_offsets = vec![0];
232  let mut row_indices = Vec::new();
233  let mut values = Vec::new();
234
235  for j in 0..ncols {
236    for i in 0..nrows {
237      let val = matrix[(i, j)];
238      if val.abs() > 1e-15 {
239        row_indices.push(i);
240        values.push(val);
241      }
242    }
243    col_offsets.push(row_indices.len());
244  }
245
246  (col_offsets, row_indices, values)
247}
248
249impl Default for AdmmSolver {
250  fn default() -> Self { Self::new() }
251}
252
253impl OptimizationProblem for AdmmSolver {
254  fn dimension(&self) -> usize {
255    // This is problem-dependent, so we'll return 0 as placeholder
256    0
257  }
258
259  fn solve(&self) -> SolverResult<Solution> {
260    Err(SolverError::InvalidProblem {
261      message: "ADMM solver requires specific problem setup via solve_lasso or solve_basis_pursuit"
262        .to_string(),
263    })
264  }
265}
266
267#[cfg(test)]
268mod tests {
269  use super::*;
270
271  #[test]
272  fn test_lasso_simple() {
273    let a = DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]);
274    let b = DVector::from_vec(vec![1.0, 2.0]);
275    let lambda = 0.1;
276
277    let mut solver = AdmmSolver::new();
278    let result = solver.solve_lasso(&a, &b, lambda).unwrap();
279
280    // Should find sparse solution close to [1.0, 2.0] with some shrinkage
281    assert!(result.converged);
282    assert!(result.x.len() == 2);
283  }
284
285  #[test]
286  fn test_basis_pursuit() {
287    let a = DMatrix::from_row_slice(1, 2, &[1.0, 1.0]);
288    let b = DVector::from_vec(vec![1.0]);
289
290    let mut solver = AdmmSolver::new();
291    let result = solver.solve_basis_pursuit(&a, &b).unwrap();
292
293    // Should find sparse solution that satisfies Ax = b
294    assert!(result.converged);
295    let residual = (&a * &result.x - &b).norm();
296    assert!(residual < 1e-3);
297  }
298}