1use clarabel::{algebra::*, solver::*};
15use cova_algebra::tensors::{DMatrix, DVector};
16
17use crate::{
18 SolverError, SolverResult,
19 traits::{OptimizationProblem, Solution},
20};
21
/// Tuning parameters for [`AdmmSolver`].
#[derive(Clone, Debug)]
pub struct AdmmParams {
    /// Augmented-Lagrangian penalty ρ; should be positive.
    pub rho: f64,
    /// Stopping tolerance on the primal residual norm `‖Ax + Bz − c‖`.
    pub primal_tolerance: f64,
    /// Stopping tolerance on the dual residual norm.
    pub dual_tolerance: f64,
    /// Upper bound on the number of ADMM iterations.
    pub max_iterations: usize,
    /// Over-relaxation factor α; 1.0 means no relaxation (values in (0, 2) are conventional).
    pub alpha: f64,
}

impl Default for AdmmParams {
    /// ρ = 1, both residual tolerances 1e-6, at most 1000 iterations, α = 1.
    fn default() -> Self {
        let tolerance = 1e-6;
        Self {
            alpha: 1.0,
            max_iterations: 1000,
            primal_tolerance: tolerance,
            dual_tolerance: tolerance,
            rho: 1.0,
        }
    }
}

/// ADMM solver for split problems such as LASSO; its configuration lives in [`AdmmParams`].
#[derive(Debug)]
pub struct AdmmSolver {
    params: AdmmParams,
}
54
55impl AdmmSolver {
56 pub fn new() -> Self { Self { params: AdmmParams::default() } }
58
59 pub fn with_params(params: AdmmParams) -> Self { Self { params } }
61
62 pub fn solve_qp_admm<F>(
69 &mut self,
70 p: &DMatrix<f64>,
71 q: &DVector<f64>,
72 a: &DMatrix<f64>,
73 b: &DMatrix<f64>,
74 c: &DVector<f64>,
75 z_update: F,
76 ) -> SolverResult<Solution>
77 where
78 F: Fn(&DVector<f64>, &DVector<f64>, f64) -> DVector<f64>,
79 {
80 let n = q.len();
81 let m = c.len();
82 let z_dim = b.ncols();
83
84 let mut x = DVector::zeros(n);
86 let mut z = DVector::zeros(z_dim);
87 let mut u = DVector::zeros(m); let ata = a.transpose() * a;
95 let p_aug = p + self.params.rho * &ata;
96
97 for iteration in 0..self.params.max_iterations {
98 let z_old = z.clone();
99
100 let rhs = b * &z - c + &u;
102 let q_aug = q + self.params.rho * (a.transpose() * &rhs);
103
104 x = self.solve_x_update(&p_aug, &q_aug)?;
105
106 let ax_plus_u = a * &x + &u;
108 let z_target = &ax_plus_u + c;
109 z = z_update(&z_target, &z_old, self.params.rho);
110
111 let residual = a * &x + b * &z - c;
113 u = &u + self.params.rho * &residual;
114
115 let primal_residual = residual.norm();
117 let dual_residual = self.params.rho * (a.transpose() * (&z - &z_old)).norm();
118
119 if primal_residual <= self.params.primal_tolerance
120 && dual_residual <= self.params.dual_tolerance
121 {
122 let objective_value = 0.5f64.mul_add(x.dot(&(p * &x)), q.dot(&x));
123 return Ok(Solution {
124 x,
125 objective_value,
126 iterations: iteration as u64 + 1,
127 converged: true,
128 termination: "Converged".to_string(),
129 });
130 }
131 }
132
133 let objective_value = 0.5f64.mul_add(x.dot(&(p * &x)), q.dot(&x));
134 Ok(Solution {
135 x,
136 objective_value,
137 iterations: self.params.max_iterations as u64,
138 converged: false,
139 termination: "MaxIterations".to_string(),
140 })
141 }
142
143 fn solve_x_update(&self, p: &DMatrix<f64>, q: &DVector<f64>) -> SolverResult<DVector<f64>> {
145 let n = q.len();
146
147 let (col_offsets, row_indices, values) = dense_to_csc(p);
150 let p_csc = CscMatrix::new(n, n, col_offsets, row_indices, values);
151
152 let q_vec: Vec<f64> = q.iter().cloned().collect();
153
154 let a_csc = CscMatrix::new(0, n, vec![0; n + 1], vec![], vec![]);
156 let b_vec: Vec<f64> = Vec::new();
157 let cones: Vec<SupportedConeT<f64>> = Vec::new();
158
159 let settings = DefaultSettingsBuilder::default().max_iter(1000).verbose(false).build().unwrap();
161
162 let mut solver =
164 DefaultSolver::new(&p_csc, &q_vec, &a_csc, &b_vec, &cones, settings).map_err(|e| {
165 SolverError::NumericalError { message: format!("Failed to create Clarabel solver: {e:?}") }
166 })?;
167
168 solver.solve();
169
170 let result = solver.solution;
171 match result.status {
172 SolverStatus::Solved => Ok(DVector::from_vec(result.x)),
173 _ => Err(SolverError::NumericalError {
174 message: format!("Clarabel failed with status: {:?}", result.status),
175 }),
176 }
177 }
178
179 pub fn solve_lasso(
181 &mut self,
182 a: &DMatrix<f64>,
183 b: &DVector<f64>,
184 lambda: f64,
185 ) -> SolverResult<Solution> {
186 let n = a.ncols();
187
188 let p = 2.0 * (a.transpose() * a);
193 let q = -2.0 * (a.transpose() * b);
194 let a_constraint = DMatrix::identity(n, n);
195 let b_constraint = -DMatrix::identity(n, n);
196 let c = DVector::zeros(n);
197
198 let z_update = move |target: &DVector<f64>, _z_old: &DVector<f64>, rho: f64| {
200 let threshold = lambda / rho;
201 target.map(|val| {
202 if val > threshold {
203 val - threshold
204 } else if val < -threshold {
205 val + threshold
206 } else {
207 0.0
208 }
209 })
210 };
211
212 self.solve_qp_admm(&p, &q, &a_constraint, &b_constraint, &c, z_update)
213 }
214
215 pub fn solve_basis_pursuit(
217 &mut self,
218 a: &DMatrix<f64>,
219 b: &DVector<f64>,
220 ) -> SolverResult<Solution> {
221 self.solve_lasso(a, b, 1e-8)
223 }
224}
225
226fn dense_to_csc(matrix: &DMatrix<f64>) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
228 let nrows = matrix.nrows();
229 let ncols = matrix.ncols();
230
231 let mut col_offsets = vec![0];
232 let mut row_indices = Vec::new();
233 let mut values = Vec::new();
234
235 for j in 0..ncols {
236 for i in 0..nrows {
237 let val = matrix[(i, j)];
238 if val.abs() > 1e-15 {
239 row_indices.push(i);
240 values.push(val);
241 }
242 }
243 col_offsets.push(row_indices.len());
244 }
245
246 (col_offsets, row_indices, values)
247}
248
249impl Default for AdmmSolver {
250 fn default() -> Self { Self::new() }
251}
252
253impl OptimizationProblem for AdmmSolver {
254 fn dimension(&self) -> usize {
255 0
257 }
258
259 fn solve(&self) -> SolverResult<Solution> {
260 Err(SolverError::InvalidProblem {
261 message: "ADMM solver requires specific problem setup via solve_lasso or solve_basis_pursuit"
262 .to_string(),
263 })
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lasso_simple() {
        // With A = I the objective ‖x − b‖² + λ‖x‖₁ separates per coordinate. Its minimizer is
        // xᵢ = sign(bᵢ)·max(|bᵢ| − λ/2, 0), here [0.95, 1.95].
        let a = DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]);
        let b = DVector::from_vec(vec![1.0, 2.0]);
        let lambda = 0.1;

        let mut solver = AdmmSolver::new();
        let result = solver.solve_lasso(&a, &b, lambda).unwrap();

        assert!(result.converged);
        assert_eq!(result.x.len(), 2);
        let expected = [0.95, 1.95];
        for (got, want) in result.x.iter().zip(expected) {
            assert!((got - want).abs() < 1e-3, "got {got}, want {want}");
        }
    }

    #[test]
    fn test_lasso_large_lambda_gives_zero() {
        // λ/2 exceeds every |bᵢ|, so the soft-threshold solution is exactly zero.
        let a = DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, 1.0]);
        let b = DVector::from_vec(vec![1.0, 2.0]);

        let mut solver = AdmmSolver::new();
        let result = solver.solve_lasso(&a, &b, 5.0).unwrap();

        assert!(result.converged);
        for value in result.x.iter() {
            assert!(value.abs() < 1e-3, "expected ~0, got {value}");
        }
    }

    #[test]
    fn test_zero_iterations_reports_not_converged() {
        let params = AdmmParams { max_iterations: 0, ..AdmmParams::default() };
        let mut solver = AdmmSolver::with_params(params);
        let a = DMatrix::from_row_slice(1, 1, &[1.0]);
        let b = DVector::from_vec(vec![1.0]);

        let result = solver.solve_lasso(&a, &b, 0.1).unwrap();

        assert!(!result.converged);
        assert_eq!(result.iterations, 0);
        assert_eq!(result.termination, "MaxIterations");
    }

    #[test]
    fn test_basis_pursuit() {
        let a = DMatrix::from_row_slice(1, 2, &[1.0, 1.0]);
        let b = DVector::from_vec(vec![1.0]);

        let mut solver = AdmmSolver::new();
        let result = solver.solve_basis_pursuit(&a, &b).unwrap();

        assert!(result.converged);
        let residual = (&a * &result.x - &b).norm();
        assert!(residual < 1e-3);
    }

    #[test]
    fn test_dense_to_csc_layout() {
        // [[1, 0],
        //  [2, 3]] -> column 0 holds rows {0, 1}, column 1 holds row {1}.
        let m = DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 2.0, 3.0]);
        let (col_offsets, row_indices, values) = dense_to_csc(&m);

        assert_eq!(col_offsets, vec![0, 2, 3]);
        assert_eq!(row_indices, vec![0, 1, 1]);
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }
}