use osqp_sys as ffi;
use std::fmt;
use std::slice;
use std::time::Duration;
use {float, Problem};
use osqp_sys::osqp_status_type;
/// Outcome of a solve, classified from the solver's status code.
///
/// Every variant wraps a lightweight handle that only borrows the `Problem`
/// the status was read from; the handle's accessors read the solver state
/// through that borrow.
#[derive(Clone, Debug)]
pub enum Status<'a> {
    /// The solver reported `OSQP_SOLVED`.
    Solved(Solution<'a>),
    /// The solver reported `OSQP_SOLVED_INACCURATE`.
    SolvedInaccurate(Solution<'a>),
    /// The solver reported `OSQP_MAX_ITER_REACHED`.
    MaxIterationsReached(Solution<'a>),
    /// The solver reported `OSQP_TIME_LIMIT_REACHED`.
    TimeLimitReached(Solution<'a>),
    /// The solver reported `OSQP_PRIMAL_INFEASIBLE`; carries the primal infeasibility certificate.
    PrimalInfeasible(PrimalInfeasibilityCertificate<'a>),
    /// The solver reported `OSQP_PRIMAL_INFEASIBLE_INACCURATE`.
    PrimalInfeasibleInaccurate(PrimalInfeasibilityCertificate<'a>),
    /// The solver reported `OSQP_DUAL_INFEASIBLE`; carries the dual infeasibility certificate.
    DualInfeasible(DualInfeasibilityCertificate<'a>),
    /// The solver reported `OSQP_DUAL_INFEASIBLE_INACCURATE`.
    DualInfeasibleInaccurate(DualInfeasibilityCertificate<'a>),
    /// The solver reported `OSQP_NON_CVX`; no solution data is exposed.
    NonConvex(Failure<'a>),
    // Hidden placeholder that keeps downstream `match`es non-exhaustive;
    // `Status::from_problem` never constructs it.
    #[doc(hidden)]
    __Nonexhaustive,
}
/// Handle to the solution vectors and solve statistics of a problem.
///
/// It only borrows the `Problem`; every accessor reads the solver's current
/// state through that borrow, so cloning the handle is cheap.
#[derive(Clone)]
pub struct Solution<'a> {
    // The problem whose solver state this handle reads.
    prob: &'a Problem,
}
/// Handle to the primal infeasibility certificate of a problem.
///
/// Borrows the `Problem`; the certificate vector is read on demand.
#[derive(Clone)]
pub struct PrimalInfeasibilityCertificate<'a> {
    // The problem whose solver state this handle reads.
    prob: &'a Problem,
}
/// Handle to the dual infeasibility certificate of a problem.
///
/// Borrows the `Problem`; the certificate vector is read on demand.
#[derive(Clone)]
pub struct DualInfeasibilityCertificate<'a> {
    // The problem whose solver state this handle reads.
    prob: &'a Problem,
}
/// Handle for outcomes that expose no solution data (used by `Status::NonConvex`).
///
/// It still borrows the `Problem` so that solve statistics on `Status`
/// remain reachable.
#[derive(Clone)]
pub struct Failure<'a> {
    // The problem whose solver state this handle refers to.
    prob: &'a Problem,
}
/// Outcome of the solution-polishing step, as reported by the solver.
///
/// Derives `Eq` alongside `Hash` so the status can be used as a key in
/// hashed collections (`HashMap`/`HashSet` require both).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum PolishStatus {
    /// Polishing was performed and succeeded.
    Successful,
    /// Polishing was attempted but did not succeed.
    Unsuccessful,
    /// Polishing was not performed.
    Unperformed,
    // Hidden placeholder that keeps downstream `match`es non-exhaustive.
    #[doc(hidden)]
    __Nonexhaustive,
}
impl<'a> Status<'a> {
pub(crate) fn from_problem(prob: &'a Problem) -> Status<'a> {
unsafe {
match (*(*prob.solver).info).status_val as osqp_status_type {
ffi::OSQP_SOLVED => Status::Solved(Solution { prob }),
ffi::OSQP_SOLVED_INACCURATE => Status::SolvedInaccurate(Solution { prob }),
ffi::OSQP_MAX_ITER_REACHED => Status::MaxIterationsReached(Solution { prob }),
ffi::OSQP_TIME_LIMIT_REACHED => Status::TimeLimitReached(Solution { prob }),
ffi::OSQP_PRIMAL_INFEASIBLE => {
Status::PrimalInfeasible(PrimalInfeasibilityCertificate { prob })
}
ffi::OSQP_PRIMAL_INFEASIBLE_INACCURATE => {
Status::PrimalInfeasibleInaccurate(PrimalInfeasibilityCertificate { prob })
}
ffi::OSQP_DUAL_INFEASIBLE => {
Status::DualInfeasible(DualInfeasibilityCertificate { prob })
}
ffi::OSQP_DUAL_INFEASIBLE_INACCURATE => {
Status::DualInfeasibleInaccurate(DualInfeasibilityCertificate { prob })
}
ffi::OSQP_NON_CVX => Status::NonConvex(Failure { prob }),
_ => unreachable!(),
}
}
}
pub fn x(&self) -> Option<&'a [float]> {
self.solution().map(|s| s.x())
}
pub fn solution(&self) -> Option<Solution<'a>> {
match *self {
Status::Solved(ref solution) => Some(solution.clone()),
_ => None,
}
}
pub fn iter(&self) -> u32 {
unsafe {
(*(*self.prob().solver).info).iter as u32
}
}
pub fn setup_time(&self) -> Duration {
unsafe { secs_to_duration((*(*self.prob().solver).info).setup_time) }
}
pub fn solve_time(&self) -> Duration {
unsafe { secs_to_duration((*(*self.prob().solver).info).solve_time) }
}
pub fn polish_time(&self) -> Duration {
unsafe { secs_to_duration((*(*self.prob().solver).info).polish_time) }
}
pub fn run_time(&self) -> Duration {
unsafe { secs_to_duration((*(*self.prob().solver).info).run_time) }
}
pub fn rho_updates(&self) -> u32 {
unsafe {
(*(*self.prob().solver).info).rho_updates as u32
}
}
pub fn rho_estimate(&self) -> float {
unsafe { (*(*self.prob().solver).info).rho_estimate }
}
fn prob(&self) -> &'a Problem {
match *self {
Status::Solved(ref solution)
| Status::SolvedInaccurate(ref solution)
| Status::MaxIterationsReached(ref solution)
| Status::TimeLimitReached(ref solution) => solution.prob,
Status::PrimalInfeasible(ref cert) | Status::PrimalInfeasibleInaccurate(ref cert) => {
cert.prob
}
Status::DualInfeasible(ref cert) | Status::DualInfeasibleInaccurate(ref cert) => {
cert.prob
}
Status::NonConvex(ref failure) => failure.prob,
Status::__Nonexhaustive => unreachable!(),
}
}
}
impl<'a> Solution<'a> {
    /// Returns the primal solution vector `x`, of length `n`.
    ///
    /// For `n == 0` an empty slice is returned without touching the C pointer:
    /// `slice::from_raw_parts` requires a non-null pointer even for empty slices,
    /// and a zero-sized C allocation is not guaranteed to be non-null.
    pub fn x(&self) -> &'a [float] {
        let n = self.prob.n;
        if n == 0 {
            return &[];
        }
        // SAFETY: NOTE(review): assumes the solver allocated `n` entries for `x`
        // and keeps them alive while `prob` is borrowed — confirm against setup.
        unsafe {
            let ptr = (*(*self.prob.solver).solution).x;
            debug_assert!(!ptr.is_null());
            slice::from_raw_parts(ptr, n)
        }
    }

    /// Returns the dual solution vector `y`, of length `m`.
    ///
    /// For `m == 0` (no constraints) an empty slice is returned without
    /// touching the C pointer, which need not point to an allocation then.
    pub fn y(&self) -> &'a [float] {
        let m = self.prob.m;
        if m == 0 {
            return &[];
        }
        // SAFETY: NOTE(review): assumes the solver allocated `m` entries for `y`
        // and keeps them alive while `prob` is borrowed — confirm against setup.
        unsafe {
            let ptr = (*(*self.prob.solver).solution).y;
            debug_assert!(!ptr.is_null());
            slice::from_raw_parts(ptr, m)
        }
    }

    /// Returns the outcome of the polishing step (`info.status_polish`).
    ///
    /// Code `1` maps to `Successful` and `0` to `Unperformed`. Every other code
    /// maps to `Unsuccessful`, which covers the `-1` failure code plus any code
    /// this binding does not know. Previously an unknown code panicked, which
    /// also made `Debug` formatting of a `Solution` panic.
    pub fn polish_status(&self) -> PolishStatus {
        let code = unsafe { (*(*self.prob.solver).info).status_polish };
        match code {
            1 => PolishStatus::Successful,
            0 => PolishStatus::Unperformed,
            // NOTE(review): newer OSQP releases define extra polish codes besides -1;
            // all of them leave the solution unpolished, so report them as unsuccessful.
            _ => PolishStatus::Unsuccessful,
        }
    }

    /// Objective value recorded by the solver (`info.obj_val`).
    pub fn obj_val(&self) -> float {
        unsafe { (*(*self.prob.solver).info).obj_val }
    }

    /// Primal residual recorded by the solver (`info.prim_res`).
    pub fn pri_res(&self) -> float {
        unsafe { (*(*self.prob.solver).info).prim_res }
    }

    /// Dual residual recorded by the solver (`info.dual_res`).
    pub fn dua_res(&self) -> float {
        unsafe { (*(*self.prob.solver).info).dual_res }
    }
}
impl<'a> fmt::Debug for Solution<'a> {
    /// Renders the solution vectors, polish outcome and residuals as a struct,
    /// in the order x, y, polish_status, obj_val, pri_res, dua_res.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Solution");
        out.field("x", &self.x());
        out.field("y", &self.y());
        out.field("polish_status", &self.polish_status());
        out.field("obj_val", &self.obj_val());
        out.field("pri_res", &self.pri_res());
        out.field("dua_res", &self.dua_res());
        out.finish()
    }
}
impl<'a> PrimalInfeasibilityCertificate<'a> {
    /// Returns the primal infeasibility certificate `delta_y`, one entry per
    /// constraint (length `m`).
    ///
    /// For `m == 0` an empty slice is returned without reading the C pointer:
    /// `slice::from_raw_parts` requires a non-null pointer even for empty slices,
    /// and a zero-sized C allocation is not guaranteed to be non-null.
    pub fn delta_y(&self) -> &'a [float] {
        let m = self.prob.m;
        if m == 0 {
            return &[];
        }
        // SAFETY: NOTE(review): assumes the solver allocated `m` entries for the
        // certificate and keeps them alive while `prob` is borrowed — confirm.
        unsafe {
            let ptr = (*(*self.prob.solver).solution).prim_inf_cert;
            debug_assert!(!ptr.is_null());
            slice::from_raw_parts(ptr, m)
        }
    }
}
impl<'a> fmt::Debug for PrimalInfeasibilityCertificate<'a> {
    /// Shows the certificate vector under the field name `delta_y`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let certificate = self.delta_y();
        let mut out = f.debug_struct("PrimalInfeasibilityCertificate");
        out.field("delta_y", &certificate);
        out.finish()
    }
}
impl<'a> DualInfeasibilityCertificate<'a> {
    /// Returns the dual infeasibility certificate `delta_x`, one entry per
    /// variable (length `n`).
    ///
    /// For `n == 0` an empty slice is returned without reading the C pointer,
    /// because `slice::from_raw_parts` requires a non-null pointer even for
    /// empty slices. This matches the guard used by the other raw-slice
    /// accessors.
    pub fn delta_x(&self) -> &'a [float] {
        let n = self.prob.n;
        if n == 0 {
            return &[];
        }
        // SAFETY: NOTE(review): assumes the solver allocated `n` entries for the
        // certificate and keeps them alive while `prob` is borrowed — confirm.
        unsafe {
            let ptr = (*(*self.prob.solver).solution).dual_inf_cert;
            debug_assert!(!ptr.is_null());
            slice::from_raw_parts(ptr, n)
        }
    }
}
impl<'a> fmt::Debug for DualInfeasibilityCertificate<'a> {
    /// Shows the certificate vector under the field name `delta_x`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let certificate = self.delta_x();
        let mut out = f.debug_struct("DualInfeasibilityCertificate");
        out.field("delta_x", &certificate);
        out.finish()
    }
}
impl<'a> fmt::Debug for Failure<'a> {
    /// Prints only the type name; a `Failure` exposes no solver data to show.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Failure")
    }
}
/// Converts a number of seconds into a `Duration`.
///
/// The floored value becomes whole seconds and the fractional part is
/// truncated to nanoseconds. Both conversions use `as` casts, which saturate,
/// so a negative or NaN input produces zero for each component.
fn secs_to_duration(secs: float) -> Duration {
    let (whole, frac) = (secs.floor(), secs.fract());
    let nanos = (frac * 1e9) as u32;
    Duration::new(whole as u64, nanos)
}