extern crate osqp_sys;
use osqp_sys as ffi;
use std::error::Error;
use std::fmt;
use std::ptr;
mod csc;
pub use csc::CscMatrix;
mod settings;
pub use settings::{LinsysSolver, Settings};
mod status;
pub use status::{
DualInfeasibilityCertificate, Failure, PolishStatus, PrimalInfeasibilityCertificate, Solution,
Status,
};
/// Scalar type of the vectors exchanged with OSQP in this file (`q`, `l`, `u`, `x`, `y`).
///
/// It must match OSQP's floating-point type: slices of `float` are handed to the FFI
/// functions via `as_ptr()` without any conversion.
// The lower-case name is deliberate, hence the lint allowance.
#[allow(non_camel_case_types)]
type float = f64;
// Compile-time guard, never called (hence `dead_code`). It relies on rustc rejecting a
// `transmute` between types of different sizes, so the crate only builds when
// `ffi::osqp_int` and `usize` are the same size. Because of that, the `as ffi::osqp_int`
// length/dimension casts elsewhere in this file cannot drop bits.
#[allow(dead_code)]
fn assert_osqp_int_size() {
let _osqp_int_must_be_usize = ::std::mem::transmute::<ffi::osqp_int, usize>;
}
// Panics with a descriptive message if an OSQP call returned a non-zero exit code.
//
// `$fun` only names the operation in the panic message (it is prefixed with `osqp_`).
// `$ret` is evaluated exactly once and bound to a local before the check, so the FFI call
// it usually contains is never executed a second time just to format the panic message.
macro_rules! check {
    ($fun:ident, $ret:expr) => {{
        let ret = $ret;
        assert!(
            ret == 0,
            "osqp_{} failed with exit code {}",
            stringify!($fun),
            ret
        );
    }};
}
/// An optimisation problem together with the OSQP solver that was set up for it.
///
/// Created by [`Problem::new`]; the solver is released with `osqp_cleanup` when the value
/// is dropped.
pub struct Problem {
/// Solver handle filled in by a successful `osqp_setup`; owned by this value.
solver: *mut ffi::OSQPSolver,
/// Number of variables: rows/columns of `P`, length of `q` and of `x` in warm starts.
n: usize,
/// Number of constraints: rows of `A`, length of `l`, `u` and of `y` in warm starts.
m: usize,
}
impl Problem {
    /// Validates the problem data and sets up an OSQP solver for it.
    ///
    /// The data follows OSQP's standard form: minimise `1/2 x'Px + q'x` subject to
    /// `l <= Ax <= u`. `P` is `n x n` and only its upper triangle is supplied. `q` has
    /// length `n`, `A` is `m x n`, and `l`/`u` have length `m`.
    ///
    /// # Errors
    ///
    /// Returns [`SetupError::DataInvalid`] with a message naming the failed check if the
    /// dimensions, bounds or sparsity structures are inconsistent. Returns the corresponding
    /// [`SetupError`] variant if `osqp_setup` itself reports a failure.
    ///
    /// # Panics
    ///
    /// Panics if `osqp_setup` returns an exit code this wrapper does not recognise.
    #[allow(non_snake_case)] // `P` and `A` follow the mathematical notation.
    pub fn new<'a, 'b, T: Into<CscMatrix<'a>>, U: Into<CscMatrix<'b>>>(
        P: T,
        q: &[float],
        A: U,
        l: &[float],
        u: &[float],
        settings: &Settings,
    ) -> Result<Problem, SetupError> {
        Problem::new_inner(P.into(), q, A.into(), l, u, settings)
    }

    /// Non-generic implementation of [`Problem::new`]. Presumably split out so the body
    /// is not monomorphised for every `T`/`U`.
    #[allow(non_snake_case)]
    fn new_inner(
        P: CscMatrix,
        q: &[float],
        A: CscMatrix,
        l: &[float],
        u: &[float],
        settings: &Settings,
    ) -> Result<Problem, SetupError> {
        let invalid_data = |msg| Err(SetupError::DataInvalid(msg));

        // Validate on the Rust side first, so callers get a descriptive message and OSQP
        // only ever sees data whose lengths match the dimensions passed to it.
        let n = P.nrows;
        if P.ncols != n {
            return invalid_data("P must be a square matrix");
        }
        if q.len() != n {
            return invalid_data("q must be the same number of rows as P");
        }
        if A.ncols != n {
            return invalid_data("A must have the same number of columns as P");
        }
        let m = A.nrows;
        if l.len() != m {
            return invalid_data("l must have the same number of rows as A");
        }
        if u.len() != m {
            return invalid_data("u must have the same number of rows as A");
        }
        // Written as `!(l <= u)` rather than `l > u` so that NaN bounds are rejected too.
        if l.iter().zip(u.iter()).any(|(&l, &u)| !(l <= u)) {
            return invalid_data(
                "all elements of l must be less than or equal to the corresponding element of u",
            );
        }
        if !P.is_valid() {
            return invalid_data("P must be a valid CSC matrix");
        }
        if !A.is_valid() {
            return invalid_data("A must be a valid CSC matrix");
        }
        if !P.is_structurally_upper_tri() {
            return invalid_data("P must be structurally upper triangular");
        }

        // Cast to `*mut` for the `osqp_setup` binding.
        // NOTE(review): assumes OSQP only reads the settings through this pointer — confirm.
        let settings = &settings.inner as *const ffi::OSQPSettings as *mut ffi::OSQPSettings;
        let mut solver: *mut ffi::OSQPSolver = ptr::null_mut();
        // SAFETY: `P` and `A` passed the CSC validity checks above and outlive this call.
        // `q`, `l` and `u` were checked to hold exactly `n`, `m` and `m` values, matching
        // the dimensions handed to `osqp_setup`.
        let status = unsafe {
            let P_ffi = P.to_ffi();
            let A_ffi = A.to_ffi();
            ffi::osqp_setup(
                &mut solver,
                P_ffi,
                q.as_ptr(),
                A_ffi,
                l.as_ptr(),
                u.as_ptr(),
                m as ffi::osqp_int,
                n as ffi::osqp_int,
                settings,
            )
        };

        let err = match status as ffi::osqp_error_type {
            0 => return Ok(Problem { solver, n, m }),
            ffi::OSQP_DATA_VALIDATION_ERROR => Some(SetupError::DataInvalid("")),
            ffi::OSQP_SETTINGS_VALIDATION_ERROR => Some(SetupError::SettingsInvalid),
            ffi::OSQP_ALGEBRA_LOAD_ERROR => Some(SetupError::LinsysSolverLoadFailed),
            ffi::OSQP_LINSYS_SOLVER_INIT_ERROR => Some(SetupError::LinsysSolverInitFailed),
            ffi::OSQP_NONCVX_ERROR => Some(SetupError::NonConvex),
            ffi::OSQP_MEM_ALLOC_ERROR => Some(SetupError::MemoryAllocationFailed),
            _ => None,
        };

        // `osqp_setup` may have allocated a solver before failing. Release it before
        // returning the error *or* panicking, so it is never leaked.
        if !solver.is_null() {
            // SAFETY: `solver` was produced by `osqp_setup` above and is not used afterwards.
            unsafe {
                ffi::osqp_cleanup(solver);
            }
        }

        match err {
            Some(err) => Err(err),
            None => panic!("osqp_setup failed with unrecognised exit code {}", status),
        }
    }

    /// Replaces the linear cost vector `q`.
    ///
    /// # Panics
    ///
    /// Panics if `q.len()` differs from the number of variables, or if OSQP rejects the update.
    pub fn update_lin_cost(&mut self, q: &[float]) {
        assert_eq!(self.n, q.len());
        // SAFETY: `self.solver` is the live solver owned by `self`; `q` holds `n` values and
        // the null pointers leave the bounds unchanged.
        check!(update_lin_cost, unsafe {
            ffi::osqp_update_data_vec(self.solver, q.as_ptr(), ptr::null(), ptr::null())
        });
    }

    /// Replaces both constraint bound vectors `l` and `u`.
    ///
    /// # Panics
    ///
    /// Panics if either slice's length differs from the number of constraints, or if OSQP
    /// rejects the update.
    pub fn update_bounds(&mut self, l: &[float], u: &[float]) {
        assert_eq!(self.m, l.len());
        assert_eq!(self.m, u.len());
        // SAFETY: `self.solver` is live; `l` and `u` each hold `m` values; `q` is left as is.
        check!(update_bounds, unsafe {
            ffi::osqp_update_data_vec(self.solver, ptr::null(), l.as_ptr(), u.as_ptr())
        });
    }

    /// Replaces the lower constraint bound vector `l`.
    ///
    /// # Panics
    ///
    /// Panics if `l.len()` differs from the number of constraints, or if OSQP rejects the
    /// update.
    pub fn update_lower_bound(&mut self, l: &[float]) {
        assert_eq!(self.m, l.len());
        // SAFETY: `self.solver` is live and `l` holds `m` values; `q` and `u` are left as is.
        check!(update_lower_bound, unsafe {
            ffi::osqp_update_data_vec(self.solver, ptr::null(), l.as_ptr(), ptr::null())
        });
    }

    /// Replaces the upper constraint bound vector `u`.
    ///
    /// # Panics
    ///
    /// Panics if `u.len()` differs from the number of constraints, or if OSQP rejects the
    /// update.
    pub fn update_upper_bound(&mut self, u: &[float]) {
        assert_eq!(self.m, u.len());
        // SAFETY: `self.solver` is live and `u` holds `m` values; `q` and `l` are left as is.
        check!(update_upper_bound, unsafe {
            ffi::osqp_update_data_vec(self.solver, ptr::null(), ptr::null(), u.as_ptr())
        });
    }

    /// Warm-starts the solver with primal `x` (length `n`) and dual `y` (length `m`).
    ///
    /// # Panics
    ///
    /// Panics on a length mismatch or if OSQP rejects the warm start.
    pub fn warm_start(&mut self, x: &[float], y: &[float]) {
        assert_eq!(self.n, x.len());
        assert_eq!(self.m, y.len());
        // SAFETY: `self.solver` is live; `x` holds `n` values and `y` holds `m` values.
        check!(warm_start, unsafe {
            ffi::osqp_warm_start(self.solver, x.as_ptr(), y.as_ptr())
        });
    }

    /// Warm-starts only the primal variables `x` (length `n`).
    ///
    /// # Panics
    ///
    /// Panics on a length mismatch or if OSQP rejects the warm start.
    pub fn warm_start_x(&mut self, x: &[float]) {
        assert_eq!(self.n, x.len());
        // SAFETY: `self.solver` is live and `x` holds `n` values; `y` is passed as null.
        check!(warm_start_x, unsafe {
            ffi::osqp_warm_start(self.solver, x.as_ptr(), ptr::null())
        });
    }

    /// Warm-starts only the dual variables `y` (length `m`).
    ///
    /// # Panics
    ///
    /// Panics on a length mismatch or if OSQP rejects the warm start.
    pub fn warm_start_y(&mut self, y: &[float]) {
        assert_eq!(self.m, y.len());
        // SAFETY: `self.solver` is live and `y` holds `m` values; `x` is passed as null.
        check!(warm_start_y, unsafe {
            ffi::osqp_warm_start(self.solver, ptr::null(), y.as_ptr())
        });
    }

    /// Replaces the values of `P`.
    ///
    /// Only `P.data` and its length are handed to OSQP. `P` must therefore have the same
    /// shape and sparsity structure (upper triangle only) as the matrix the problem was set
    /// up with.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `P` is not `n x n`;
    /// - `P` is not a valid CSC matrix;
    /// - `P` is not structurally upper triangular;
    /// - OSQP rejects the update, presumably including a non-zero count that differs from
    ///   the original `P`.
    #[allow(non_snake_case)]
    pub fn update_P<'a, T: Into<CscMatrix<'a>>>(&mut self, P: T) {
        self.update_P_inner(P.into());
    }

    /// Non-generic implementation of [`Problem::update_P`].
    #[allow(non_snake_case)]
    fn update_P_inner(&mut self, P: CscMatrix) {
        // Mirror the checks in `new_inner`. OSQP only receives the values, so a matrix of
        // the wrong shape but with a matching value count would otherwise be misapplied.
        assert_eq!((self.n, self.n), (P.nrows, P.ncols), "P must be an n x n matrix");
        assert!(P.is_valid(), "P must be a valid CSC matrix");
        assert!(
            P.is_structurally_upper_tri(),
            "P must be structurally upper triangular"
        );
        // SAFETY: `self.solver` is live. `P.data` holds exactly the `P.data.len()` values
        // announced to OSQP. The null/zero A arguments leave A untouched.
        check!(update_P, unsafe {
            ffi::osqp_update_data_mat(
                self.solver,
                P.data.as_ptr(),
                ptr::null(),
                P.data.len() as ffi::osqp_int,
                ptr::null(),
                ptr::null(),
                0,
            )
        });
    }

    /// Replaces the values of `A`.
    ///
    /// Only `A.data` and its length are handed to OSQP. `A` must therefore have the same
    /// shape and sparsity structure as the matrix the problem was set up with.
    ///
    /// # Panics
    ///
    /// Panics if `A` is not `m x n`, is not a valid CSC matrix, or if OSQP rejects the update.
    #[allow(non_snake_case)]
    pub fn update_A<'a, T: Into<CscMatrix<'a>>>(&mut self, A: T) {
        self.update_A_inner(A.into());
    }

    /// Non-generic implementation of [`Problem::update_A`].
    #[allow(non_snake_case)]
    fn update_A_inner(&mut self, A: CscMatrix) {
        // Mirror the checks in `new_inner`; see `update_P_inner`.
        assert_eq!((self.m, self.n), (A.nrows, A.ncols), "A must be an m x n matrix");
        assert!(A.is_valid(), "A must be a valid CSC matrix");
        // SAFETY: `self.solver` is live. `A.data` holds exactly the `A.data.len()` values
        // announced to OSQP. The null/zero P arguments leave P untouched.
        check!(update_A, unsafe {
            ffi::osqp_update_data_mat(
                self.solver,
                ptr::null(),
                ptr::null(),
                0,
                A.data.as_ptr(),
                ptr::null(),
                A.data.len() as ffi::osqp_int,
            )
        });
    }

    /// Applies `settings` to the existing solver.
    ///
    /// # Panics
    ///
    /// Panics if OSQP rejects the new settings.
    pub fn update_settings(&mut self, settings: &Settings) {
        let settings = &settings.inner as *const ffi::OSQPSettings;
        // SAFETY: `self.solver` is live and `settings` points to a valid `OSQPSettings`
        // borrowed for the duration of the call.
        check!(update_settings, unsafe {
            ffi::osqp_update_settings(self.solver, settings)
        });
    }

    /// Runs the solver and returns a [`Status`] that borrows this problem.
    ///
    /// # Panics
    ///
    /// Panics if `osqp_solve` returns a non-zero exit code.
    pub fn solve<'a>(&'a mut self) -> Status<'a> {
        // SAFETY: `self.solver` is the live solver owned by `self`.
        check!(solve, unsafe { ffi::osqp_solve(self.solver) });
        // NOTE(review): kept in `unsafe` as before, on the assumption that
        // `Status::from_problem` is an `unsafe fn` reading the solver's results — confirm.
        unsafe { Status::from_problem(self) }
    }
}
impl Drop for Problem {
    /// Hands the solver owned by this problem back to OSQP for deallocation.
    fn drop(&mut self) {
        let solver = self.solver;
        // SAFETY: `solver` is owned exclusively by this `Problem` and is freed only here.
        // Its exit code is discarded because `drop` has no way to report it.
        let _ = unsafe { ffi::osqp_cleanup(solver) };
    }
}
// SAFETY (review note): `Problem` exclusively owns the solver behind `solver`. Every method
// in this file that touches the solver takes `&mut self`, so a shared `&Problem` grants no
// access to it here; that is the basis for `Sync`.
// NOTE(review): `Send` additionally assumes OSQP keeps no thread-affine state in the solver.
// Methods on `Problem` defined in other modules must also preserve the `&mut self` rule —
// confirm both.
unsafe impl Send for Problem {}
unsafe impl Sync for Problem {}
/// Reasons why setting up a `Problem` can fail.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can store errors and compare them
/// (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SetupError {
    /// The problem data is inconsistent. The message names the failed check. It is empty
    /// when the failure was reported by OSQP's own data validation.
    DataInvalid(&'static str),
    /// OSQP rejected the solver settings.
    SettingsInvalid,
    /// OSQP reported `OSQP_ALGEBRA_LOAD_ERROR` while loading its linear algebra backend.
    LinsysSolverLoadFailed,
    /// OSQP failed to initialise the linear system solver.
    LinsysSolverInitFailed,
    /// OSQP reported that the problem is non-convex.
    NonConvex,
    /// OSQP failed to allocate memory.
    MemoryAllocationFailed,
    #[doc(hidden)]
    __Nonexhaustive,
}
impl fmt::Display for SetupError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SetupError::DataInvalid(msg) => {
"problem data invalid".fmt(f)?;
if !msg.is_empty() {
": ".fmt(f)?;
msg.fmt(f)?;
}
Ok(())
}
SetupError::SettingsInvalid => "problem settings invalid".fmt(f),
SetupError::LinsysSolverLoadFailed => "linear system solver failed to load".fmt(f),
SetupError::LinsysSolverInitFailed => {
"linear system solver failed to initialise".fmt(f)
}
SetupError::NonConvex => "problem non-convex".fmt(f),
SetupError::MemoryAllocationFailed => "memory allocation failed".fmt(f),
SetupError::__Nonexhaustive => unreachable!(),
}
}
}
// No variant wraps another error, so the default `Error` methods (no `source`) suffice.
impl Error for SetupError {}
#[cfg(test)]
mod tests {
    use std::iter;

    use super::*;

    /// `Settings` builder methods write straight through to the wrapped `OSQPSettings`.
    #[test]
    fn update_settings() {
        let settings = Settings::default().alpha(0.7).verbose(false);
        let settings = settings.adaptive_rho(true);
        assert_eq!(settings.inner.alpha, 0.7);
        assert_eq!(settings.inner.verbose, 0);
        assert_eq!(settings.inner.adaptive_rho, 1);
    }

    /// `update_rho` is visible in the settings held by the live solver.
    #[test]
    #[allow(non_snake_case)]
    fn update_rho() {
        let P = CscMatrix::from(&[[4.0, 1.0], [1.0, 2.0]]).into_upper_tri();
        let q = &[1.0, 1.0];
        let A = &[[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
        let l = &[1.0, 0.0, 0.0];
        let u = &[1.0, 0.7, 0.7];
        let settings = Settings::default().verbose(false);
        let mut prob = Problem::new(&P, q, A, l, u, &settings).unwrap();
        prob.update_rho(0.7);
        unsafe {
            let solver_ref = prob.solver.as_ref().unwrap();
            let settings_ref = solver_ref.settings.as_ref().unwrap();
            assert!(
                settings_ref.rho == 0.7,
                "Expected rho to be 0.7 but found {}",
                settings_ref.rho
            );
        }
    }

    /// `Problem::update_settings` is visible in the settings held by the live solver.
    #[test]
    #[allow(non_snake_case)]
    fn update_prob_settings() {
        let P = CscMatrix::from(&[[4.0, 1.0], [1.0, 2.0]]).into_upper_tri();
        let q = &[1.0, 1.0];
        let A = &[[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
        let l = &[1.0, 0.0, 0.0];
        let u = &[1.0, 0.7, 0.7];
        let settings = Settings::default().verbose(false);
        let mut prob = Problem::new(&P, q, A, l, u, &settings).unwrap();
        let settings = settings.max_iter(1_000_000);
        prob.update_settings(&settings);
        unsafe {
            let solver_ref = prob.solver.as_ref().unwrap();
            let settings_ref = solver_ref.settings.as_ref().unwrap();
            assert_eq!(settings_ref.max_iter, 1_000_000);
        }
    }

    /// Setting up with wrong matrices and then updating P and A gives the same solution
    /// as setting up with the right ones.
    #[test]
    #[allow(non_snake_case)]
    fn update_matrices() {
        let P_wrong = CscMatrix::from(&[[2.0, 1.0], [1.0, 4.0]]).into_upper_tri();
        let A_wrong = &[[2.0, 3.0], [1.0, 0.0], [0.0, 9.0]];
        let P = CscMatrix::from(&[[4.0, 1.0], [1.0, 2.0]]).into_upper_tri();
        let q = &[1.0, 1.0];
        let A = &[[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
        let l = &[1.0, 0.0, 0.0];
        let u = &[1.0, 0.7, 0.7];
        let settings = Settings::default().alpha(1.0).verbose(false);
        let settings = settings.adaptive_rho(false);
        let mut prob = Problem::new(&P_wrong, q, A_wrong, l, u, &settings).unwrap();
        prob.update_P(&P);
        prob.update_A(A);
        let result = prob.solve();
        let x = result.solution().unwrap().x();
        let expected = &[0.2987710845986426, 0.701227995544065];
        assert_eq!(expected.len(), x.len());
        assert!(expected.iter().zip(x).all(|(&a, &b)| (a - b).abs() < 1e-9));
    }

    /// Constraint matrices with no rows, or with no stored entries, are accepted both at
    /// setup and by `update_A`.
    #[test]
    #[allow(non_snake_case)]
    fn empty_A() {
        let P = CscMatrix::from(&[[4.0, 1.0], [1.0, 2.0]]).into_upper_tri();
        let q = &[1.0, 1.0];
        let A = CscMatrix::from_column_iter_dense(0, 2, iter::empty());
        let l = &[];
        let u = &[];
        let mut prob = Problem::new(&P, q, &A, l, u, &Settings::default()).unwrap();
        prob.update_A(&A);
        let A = CscMatrix::from(&[[0.0, 0.0], [0.0, 0.0]]);
        assert_eq!(A.data.len(), 0);
        let l = &[0.0, 0.0];
        let u = &[1.0, 1.0];
        let mut prob = Problem::new(&P, q, &A, l, u, &Settings::default()).unwrap();
        prob.update_A(&A);
    }

    /// `Problem::new` rejects inconsistent data with `SetupError::DataInvalid` before
    /// calling into OSQP.
    #[test]
    #[allow(non_snake_case)]
    fn new_rejects_invalid_data() {
        let P = CscMatrix::from(&[[4.0, 1.0], [1.0, 2.0]]).into_upper_tri();
        let q = &[1.0, 1.0];
        let A = &[[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
        let l = &[1.0, 0.0, 0.0];
        let u = &[1.0, 0.7, 0.7];
        let settings = Settings::default().verbose(false);

        // Non-square P.
        let P_rect = CscMatrix::from(&[[4.0, 1.0], [0.0, 2.0], [0.0, 0.0]]);
        assert!(matches!(
            Problem::new(&P_rect, q, A, l, u, &settings),
            Err(SetupError::DataInvalid(_))
        ));
        // q of the wrong length.
        assert!(matches!(
            Problem::new(&P, &[1.0], A, l, u, &settings),
            Err(SetupError::DataInvalid(_))
        ));
        // l > u in the first row.
        assert!(matches!(
            Problem::new(&P, q, A, &[2.0, 0.0, 0.0], u, &settings),
            Err(SetupError::DataInvalid(_))
        ));
        // NaN bounds are rejected as well.
        assert!(matches!(
            Problem::new(&P, q, A, &[1.0, f64::NAN, 0.0], u, &settings),
            Err(SetupError::DataInvalid(_))
        ));
    }
}