use std::collections::HashMap;
use clarabel::algebra::CscMatrix as ClarabelCsc;
use clarabel::solver::{
DefaultSettingsBuilder, DefaultSolver, IPSolver, SolverStatus, SupportedConeT,
};
use super::stuffing::{ConeDims, StuffedProblem, VariableMap};
use crate::expr::{Array, ExprId};
/// Outcome of a solve, reduced to the statuses this crate distinguishes.
///
/// Values are normally obtained through the `From<SolverStatus>` conversion
/// defined next to this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolveStatus {
/// The solver reported success (Clarabel's `Solved`). `solve` only fills in
/// the objective value, primal values and duals for this status.
Optimal,
/// The problem was reported primal infeasible (Clarabel's `PrimalInfeasible`).
Infeasible,
/// The problem was reported dual infeasible (Clarabel's `DualInfeasible`),
/// which this crate presents as an unbounded problem.
Unbounded,
/// The solver stopped on a resource limit. Clarabel's `MaxIterations` and
/// `MaxTime` are both converted to this variant.
MaxIterations,
/// The solver stopped because of numerical trouble.
NumericalError,
/// Fallback for solver statuses that have no dedicated variant here.
Unknown,
}
impl From<SolverStatus> for SolveStatus {
    /// Collapses Clarabel's termination status into this crate's [`SolveStatus`].
    ///
    /// Statuses without a dedicated counterpart fall back to
    /// [`SolveStatus::Unknown`].
    fn from(status: SolverStatus) -> Self {
        match status {
            SolverStatus::Solved => SolveStatus::Optimal,
            SolverStatus::PrimalInfeasible => SolveStatus::Infeasible,
            SolverStatus::DualInfeasible => SolveStatus::Unbounded,
            // There is no separate time-limit variant, so both resource
            // limits are reported through `MaxIterations`.
            SolverStatus::MaxIterations | SolverStatus::MaxTime => SolveStatus::MaxIterations,
            // Previously swallowed by the catch-all arm, which left
            // `SolveStatus::NumericalError` impossible to observe.
            SolverStatus::NumericalError => SolveStatus::NumericalError,
            _ => SolveStatus::Unknown,
        }
    }
}
/// User-facing solver options.
///
/// `solve` forwards every field to the Clarabel settings builder method of
/// the same name.
#[derive(Debug, Clone)]
pub struct Settings {
    /// Whether the solver prints progress output.
    pub verbose: bool,
    /// Upper bound on the number of solver iterations.
    pub max_iter: u32,
    /// Wall-clock limit handed to the solver (presumably seconds, following
    /// Clarabel's convention; confirm against its documentation).
    pub time_limit: f64,
    /// Absolute duality-gap tolerance.
    pub tol_gap_abs: f64,
    /// Relative duality-gap tolerance.
    pub tol_gap_rel: f64,
}

impl Default for Settings {
    /// Quiet output, 100 iterations, no time limit, and `1e-8` for both gap
    /// tolerances.
    fn default() -> Self {
        Self {
            verbose: false,
            max_iter: 100,
            time_limit: f64::INFINITY,
            tol_gap_abs: 1e-8,
            tol_gap_rel: 1e-8,
        }
    }
}
/// Result of a call to `solve`.
///
/// `solve` populates `value`, `primal` and `dual` only when the status is
/// `SolveStatus::Optimal`; for every other status they are `None`.
#[derive(Debug, Clone)]
pub struct Solution {
/// Termination status converted from the solver's own status.
pub status: SolveStatus,
/// Objective value at the returned point, including the problem's
/// objective offset.
pub value: Option<f64>,
/// Values of the problem variables, keyed by variable id.
pub primal: Option<HashMap<ExprId, Array>>,
/// Copy of the solver's `z` vector.
pub dual: Option<Vec<f64>>,
/// Solve time as reported by the solver (presumably seconds; confirm
/// against Clarabel's documentation).
pub solve_time: f64,
/// Iteration count as reported by the solver.
pub iterations: u32,
}
impl Solution {
    /// Looks up the stored value of the variable identified by `var_id`.
    ///
    /// Returns `None` when the solution holds no primal values or when the id
    /// is not among them.
    pub fn get_value(&self, var_id: ExprId) -> Option<&Array> {
        let primal = self.primal.as_ref()?;
        primal.get(&var_id)
    }

    /// Scalar value of `var`.
    ///
    /// # Panics
    ///
    /// Panics whenever [`Solution::try_value`] would return an error.
    pub fn value(&self, var: &crate::expr::Expr) -> f64 {
        let outcome = self.try_value(var);
        outcome.expect("failed to get scalar value")
    }

    /// Fallible variant of [`Solution::value`].
    ///
    /// # Errors
    ///
    /// Returns `CvxError::InvalidProblem` when `var` is not a variable, when
    /// the variable has no entry in this solution, or when its stored value is
    /// neither a scalar nor a 1x1 dense matrix.
    pub fn try_value(&self, var: &crate::expr::Expr) -> crate::Result<f64> {
        let var_id = match var.variable_id() {
            Some(id) => id,
            None => {
                return Err(crate::CvxError::InvalidProblem(
                    "Expression is not a variable".into(),
                ))
            }
        };
        let stored = match self.get_value(var_id) {
            Some(arr) => arr,
            None => {
                return Err(crate::CvxError::InvalidProblem(
                    "Variable not in solution".into(),
                ))
            }
        };
        match stored {
            Array::Scalar(v) => Ok(*v),
            // A 1x1 dense matrix is accepted as a scalar as well.
            Array::Dense(m) if m.nrows() == 1 && m.ncols() == 1 => Ok(m[(0, 0)]),
            _ => Err(crate::CvxError::InvalidProblem(
                "Variable is not scalar; use index operator for vectors/matrices".into(),
            )),
        }
    }

    /// Borrows the full dual vector, if the solution has one.
    pub fn duals(&self) -> Option<&[f64]> {
        self.dual.as_ref().map(|d| d.as_slice())
    }

    /// Dual entry at position `idx`.
    ///
    /// Returns `None` when there are no duals or `idx` is out of range.
    pub fn constraint_dual(&self, idx: usize) -> Option<f64> {
        let duals = self.dual.as_ref()?;
        duals.get(idx).copied()
    }

    /// Reports whether dual values are present.
    pub fn has_duals(&self) -> bool {
        self.dual.is_some()
    }
}
impl std::ops::Index<&crate::expr::Expr> for Solution {
    type Output = nalgebra::DMatrix<f64>;

    /// Borrows the dense matrix stored for the variable `var`.
    ///
    /// # Panics
    ///
    /// Panics when `var` is not a variable, when it has no entry in this
    /// solution, or when its stored value is a scalar (use `.value()` for
    /// those).
    fn index(&self, var: &crate::expr::Expr) -> &nalgebra::DMatrix<f64> {
        let var_id = var.variable_id().expect("Expression is not a variable");
        let stored = self.get_value(var_id).expect("Variable not in solution");
        match stored {
            Array::Scalar(_) => {
                panic!("Variable is scalar, use .value() method instead of indexing")
            }
            Array::Sparse(_) => unreachable!("Solution values are never sparse"),
            Array::Dense(matrix) => matrix,
        }
    }
}
/// Solves a stuffed conic problem with Clarabel and packages the outcome.
///
/// The problem matrices and cone dimensions are converted to Clarabel's
/// types, the solver is run once, and its status is mapped to
/// [`SolveStatus`]. Objective value, primal values and duals are returned
/// only when that status is [`SolveStatus::Optimal`]; solve time and iteration
/// count are always reported.
///
/// # Panics
///
/// Panics if Clarabel's settings builder rejects the supplied `settings`.
pub fn solve(problem: &StuffedProblem, settings: &Settings) -> Solution {
    let quad_matrix = to_clarabel_csc(&problem.p);
    let constraint_matrix = to_clarabel_csc(&problem.a);
    let cone_spec = to_clarabel_cones(&problem.cone_dims);

    let builder_result = DefaultSettingsBuilder::default()
        .verbose(settings.verbose)
        .max_iter(settings.max_iter)
        .time_limit(settings.time_limit)
        .tol_gap_abs(settings.tol_gap_abs)
        .tol_gap_rel(settings.tol_gap_rel)
        .build();
    let clarabel_settings = builder_result.unwrap();

    let mut solver = DefaultSolver::new(
        &quad_matrix,
        &problem.q,
        &constraint_matrix,
        &problem.b,
        &cone_spec,
        clarabel_settings,
    );
    solver.solve();

    let status: SolveStatus = solver.solution.status.into();
    let solve_time = solver.solution.solve_time;
    let iterations = solver.info.iterations;

    // Anything short of an optimal solve carries no usable point.
    if status != SolveStatus::Optimal {
        return Solution {
            status,
            value: None,
            primal: None,
            dual: None,
            solve_time,
            iterations,
        };
    }

    let x = &solver.solution.x;
    let primal = unpack_primal(x, &problem.var_map);
    // The objective is re-evaluated from the original data, then shifted by
    // the constant offset stored on the problem.
    let value = compute_objective(x, &problem.p, &problem.q) + problem.objective_offset;
    Solution {
        status,
        value: Some(value),
        primal: Some(primal),
        dual: Some(solver.solution.z.clone()),
        solve_time,
        iterations,
    }
}
/// Copies an `nalgebra_sparse` CSC matrix into Clarabel's CSC type.
///
/// Dimensions, column offsets, row indices and values are carried over
/// unchanged; each of the three arrays is cloned into a fresh `Vec`.
fn to_clarabel_csc(m: &nalgebra_sparse::CscMatrix<f64>) -> ClarabelCsc<f64> {
    let (rows, cols) = (m.nrows(), m.ncols());
    let col_ptr = m.col_offsets().to_vec();
    let row_idx = m.row_indices().to_vec();
    let nz_values = m.values().to_vec();
    ClarabelCsc::new(rows, cols, col_ptr, row_idx, nz_values)
}
/// Builds Clarabel's cone list from the stuffed problem's cone dimensions.
///
/// Cones are emitted in a fixed order: zero cone, nonnegative cone, one
/// second-order cone per entry of `dims.soc`, `dims.exp` exponential cones,
/// and one power cone per entry of `dims.power`. The zero and nonnegative
/// cones are left out entirely when their dimension is 0.
fn to_clarabel_cones(dims: &ConeDims) -> Vec<SupportedConeT<f64>> {
    let mut out: Vec<SupportedConeT<f64>> = Vec::new();

    if dims.zero > 0 {
        out.push(SupportedConeT::ZeroConeT(dims.zero));
    }
    if dims.nonneg > 0 {
        out.push(SupportedConeT::NonnegativeConeT(dims.nonneg));
    }

    out.extend(
        dims.soc
            .iter()
            .map(|&dim| SupportedConeT::SecondOrderConeT(dim)),
    );
    out.extend((0..dims.exp).map(|_| SupportedConeT::ExponentialConeT()));
    out.extend(
        dims.power
            .iter()
            .map(|&alpha| SupportedConeT::PowerConeT(alpha)),
    );

    out
}
/// Splits the flat solver vector `x` into one value per variable.
///
/// `var_map.id_to_col` supplies a `(start, size)` pair for each variable id;
/// the slice `x[start..start + size]` becomes that variable's value. A slice
/// of length 1 is stored as `Array::Scalar`; every other length goes through
/// `Array::from_vec`.
///
/// # Panics
///
/// Panics if a `(start, size)` range falls outside `x`.
fn unpack_primal(x: &[f64], var_map: &VariableMap) -> HashMap<ExprId, Array> {
    var_map
        .id_to_col
        .iter()
        .map(|(&var_id, &(start, size))| {
            let chunk = &x[start..start + size];
            let arr = match chunk {
                [single] => Array::Scalar(*single),
                _ => Array::from_vec(chunk.to_vec()),
            };
            (var_id, arr)
        })
        .collect()
}
/// Evaluates the objective at `x` from the problem data `p` and `q`.
///
/// The result is the dot product of `q` and `x` plus a quadratic term in
/// which each stored diagonal entry of `p` contributes
/// `0.5 * p_ii * x_i * x_i` and each stored off-diagonal entry contributes
/// `p_ij * x_i * x_j`.
///
/// NOTE(review): this weighting equals `0.5 * x' P x` only if `p` stores a
/// single triangle of a symmetric matrix. Confirm against the code that
/// builds `p`.
fn compute_objective(x: &[f64], p: &nalgebra_sparse::CscMatrix<f64>, q: &[f64]) -> f64 {
    let linear: f64 = q.iter().zip(x.iter()).map(|(qi, xi)| qi * xi).sum();

    let quadratic = p
        .triplet_iter()
        .fold(0.0_f64, |acc, (row, col, val)| {
            // Diagonal entries carry the 1/2 factor; off-diagonal entries
            // are counted at full weight.
            let weight = if row == col { 0.5 * *val } else { *val };
            acc + weight * x[row] * x[col]
        });

    linear + quadratic
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default settings are quiet and cap the solver at 100 iterations.
    #[test]
    fn test_default_settings() {
        let Settings {
            verbose, max_iter, ..
        } = Settings::default();
        assert!(!verbose);
        assert_eq!(max_iter, 100);
    }

    /// A zero cone, a nonnegative cone and one second-order cone give three
    /// entries when there are no exponential or power cones.
    #[test]
    fn test_to_clarabel_cones() {
        let cones = to_clarabel_cones(&ConeDims {
            zero: 2,
            nonneg: 3,
            soc: vec![4],
            exp: 0,
            power: vec![],
        });
        assert_eq!(cones.len(), 3);
    }
}