1extern crate osqp_rust_sys;
70
71use osqp_rust_sys as ffi;
72use std::error::Error;
73use std::fmt;
74use std::ptr;
75
76mod csc;
77pub use csc::CscMatrix;
78
79mod settings;
80pub use settings::{LinsysSolver, Settings};
81
82mod status;
83pub use status::{
84 DualInfeasibilityCertificate, Failure, PolishStatus, PrimalInfeasibilityCertificate, Solution,
85 Status,
86};
87
// Floating-point type used for all problem data handed to OSQP (`q`, `l`, `u`, matrix values).
// The lowercase, C-style name is intentional, hence the lint allowance.
#[allow(non_camel_case_types)]
type float = f64;
90
// Panics with a descriptive message if an OSQP call returned a non-zero exit code.
//
// `$fun` is the OSQP function name without its `osqp_` prefix (it only appears in the panic
// message) and `$ret` is the expression producing the exit code. The expression is evaluated
// exactly once: formatting the panic message must not re-run a failing FFI call.
macro_rules! check {
    ($fun:ident, $ret:expr) => {{
        let code = $ret;
        assert!(
            code == 0,
            "osqp_{} failed with exit code {}",
            stringify!($fun),
            code
        );
    }};
}
101
102pub struct Problem {
104 workspace: *mut ffi::src::src::osqp::OSQPWorkspace,
105 n: usize,
107 m: usize,
109}
110
111impl Problem {
112 #[allow(non_snake_case)]
119 pub fn new<'a, 'b, T: Into<CscMatrix<'a>>, U: Into<CscMatrix<'b>>>(
120 P: T,
121 q: &[float],
122 A: U,
123 l: &[float],
124 u: &[float],
125 settings: &Settings,
126 ) -> Result<Problem, SetupError> {
127 Problem::new_inner(P.into(), q, A.into(), l, u, settings)
129 }
130
131 #[allow(non_snake_case)]
132 fn new_inner(
133 P: CscMatrix,
134 q: &[float],
135 A: CscMatrix,
136 l: &[float],
137 u: &[float],
138 settings: &Settings,
139 ) -> Result<Problem, SetupError> {
140 let invalid_data = |msg| Err(SetupError::DataInvalid(msg));
141
142 unsafe {
143 let n = P.nrows;
149 if P.ncols != n {
150 return invalid_data("P must be a square matrix");
151 }
152 if q.len() != n {
153 return invalid_data("q must be the same number of rows as P");
154 }
155 if A.ncols != n {
156 return invalid_data("A must have the same number of columns as P");
157 }
158
159 let m = A.nrows;
161 if l.len() != m {
162 return invalid_data("l must have the same number of rows as A");
163 }
164 if u.len() != m {
165 return invalid_data("u must have the same number of rows as A");
166 }
167 if l.iter().zip(u.iter()).any(|(&l, &u)| !(l <= u)) {
168 return invalid_data("all elements of l must be less than or equal to the corresponding element of u");
169 }
170
171 if !P.is_valid() {
173 return invalid_data("P must be a valid CSC matrix");
174 }
175 if !A.is_valid() {
176 return invalid_data("A must be a valid CSC matrix");
177 }
178 if !P.is_structurally_upper_tri() {
179 return invalid_data("P must be structurally upper triangular");
180 }
181
182 let mut P_ffi = P.to_ffi();
184 let mut A_ffi = A.to_ffi();
185
186 let data = ffi::src::src::osqp::OSQPData {
187 n: n as ffi::src::src::osqp::c_int,
188 m: m as ffi::src::src::osqp::c_int,
189 P: &mut P_ffi,
190 A: &mut A_ffi,
191 q: q.as_ptr() as *mut float,
192 l: l.as_ptr() as *mut float,
193 u: u.as_ptr() as *mut float,
194 };
195
196 let settings = &settings.inner as *const ffi::src::src::osqp::OSQPSettings as *mut ffi::src::src::osqp::OSQPSettings;
197 let mut workspace: *mut ffi::src::src::osqp::OSQPWorkspace = ptr::null_mut();
198
199 let status = ffi::src::src::osqp::osqp_setup(&mut workspace, &data, settings);
200 let err = match status as ffi::src::src::osqp::osqp_error_type {
201 0 => return Ok(Problem { workspace, n, m }),
202 ffi::src::src::osqp::OSQP_DATA_VALIDATION_ERROR => SetupError::DataInvalid(""),
203 ffi::src::src::osqp::OSQP_SETTINGS_VALIDATION_ERROR => SetupError::SettingsInvalid,
204 ffi::src::src::osqp::OSQP_LINSYS_SOLVER_LOAD_ERROR => SetupError::LinsysSolverLoadFailed,
205 ffi::src::src::osqp::OSQP_LINSYS_SOLVER_INIT_ERROR => SetupError::LinsysSolverInitFailed,
206 ffi::src::src::osqp::OSQP_NONCVX_ERROR => SetupError::NonConvex,
207 ffi::src::src::osqp::OSQP_MEM_ALLOC_ERROR => SetupError::MemoryAllocationFailed,
208 _ => unreachable!(),
209 };
210
211 if !workspace.is_null() {
213 ffi::src::src::osqp::osqp_cleanup(workspace);
214 }
215 Err(err)
216 }
217 }
218
219 pub fn update_lin_cost(&mut self, q: &[float]) {
223 unsafe {
224 assert_eq!(self.n, q.len());
225 check!(
226 update_lin_cost,
227 ffi::src::src::osqp::osqp_update_lin_cost(self.workspace, q.as_ptr())
228 );
229 }
230 }
231
232 pub fn update_bounds(&mut self, l: &[float], u: &[float]) {
236 unsafe {
237 assert_eq!(self.m, l.len());
238 assert_eq!(self.m, u.len());
239 check!(
240 update_bounds,
241 ffi::src::src::osqp::osqp_update_bounds(self.workspace, l.as_ptr(), u.as_ptr())
242 );
243 }
244 }
245
246 pub fn update_lower_bound(&mut self, l: &[float]) {
250 unsafe {
251 assert_eq!(self.m, l.len());
252 check!(
253 update_lower_bound,
254 ffi::src::src::osqp::osqp_update_lower_bound(self.workspace, l.as_ptr())
255 );
256 }
257 }
258
259 pub fn update_upper_bound(&mut self, u: &[float]) {
263 unsafe {
264 assert_eq!(self.m, u.len());
265 check!(
266 update_upper_bound,
267 ffi::src::src::osqp::osqp_update_upper_bound(self.workspace, u.as_ptr())
268 );
269 }
270 }
271
272 pub fn warm_start(&mut self, x: &[float], y: &[float]) {
277 unsafe {
278 assert_eq!(self.n, x.len());
279 assert_eq!(self.m, y.len());
280 check!(
281 warm_start,
282 ffi::src::src::osqp::osqp_warm_start(self.workspace, x.as_ptr(), y.as_ptr())
283 );
284 }
285 }
286
287 pub fn warm_start_x(&mut self, x: &[float]) {
291 unsafe {
292 assert_eq!(self.n, x.len());
293 check!(
294 warm_start_x,
295 ffi::src::src::osqp::osqp_warm_start_x(self.workspace, x.as_ptr())
296 );
297 }
298 }
299
300 pub fn warm_start_y(&mut self, y: &[float]) {
304 unsafe {
305 assert_eq!(self.m, y.len());
306 check!(
307 warm_start_y,
308 ffi::src::src::osqp::osqp_warm_start_y(self.workspace, y.as_ptr())
309 );
310 }
311 }
312
313 #[allow(non_snake_case)]
318 pub fn update_P<'a, T: Into<CscMatrix<'a>>>(&mut self, P: T) {
319 self.update_P_inner(P.into());
320 }
321
322 #[allow(non_snake_case)]
323 fn update_P_inner(&mut self, P: CscMatrix) {
324 unsafe {
325 let P_ffi = CscMatrix::from_ffi((*(*self.workspace).data).P);
326 P.assert_same_sparsity_structure(&P_ffi);
327
328 check!(
329 update_P,
330 ffi::src::src::osqp::osqp_update_P(
331 self.workspace,
332 P.data.as_ptr(),
333 ptr::null(),
334 P.data.len() as ffi::src::src::osqp::c_int,
335 )
336 );
337 }
338 }
339
340 #[allow(non_snake_case)]
345 pub fn update_A<'a, T: Into<CscMatrix<'a>>>(&mut self, A: T) {
346 self.update_A_inner(A.into());
347 }
348
349 #[allow(non_snake_case)]
350 fn update_A_inner(&mut self, A: CscMatrix) {
351 unsafe {
352 let A_ffi = CscMatrix::from_ffi((*(*self.workspace).data).A);
353 A.assert_same_sparsity_structure(&A_ffi);
354
355 check!(
356 update_A,
357 ffi::src::src::osqp::osqp_update_A(
358 self.workspace,
359 A.data.as_ptr(),
360 ptr::null(),
361 A.data.len() as ffi::src::src::osqp::c_int,
362 )
363 );
364 }
365 }
366
367 #[allow(non_snake_case)]
372 pub fn update_P_A<'a, 'b, T: Into<CscMatrix<'a>>, U: Into<CscMatrix<'b>>>(
373 &mut self,
374 P: T,
375 A: U,
376 ) {
377 self.update_P_A_inner(P.into(), A.into());
378 }
379
380 #[allow(non_snake_case)]
381 fn update_P_A_inner(&mut self, P: CscMatrix, A: CscMatrix) {
382 unsafe {
383 let P_ffi = CscMatrix::from_ffi((*(*self.workspace).data).P);
384 P.assert_same_sparsity_structure(&P_ffi);
385
386 let A_ffi = CscMatrix::from_ffi((*(*self.workspace).data).A);
387 A.assert_same_sparsity_structure(&A_ffi);
388
389 check!(
390 update_P_A,
391 ffi::src::src::osqp::osqp_update_P_A(
392 self.workspace,
393 P.data.as_ptr(),
394 ptr::null(),
395 P.data.len() as ffi::src::src::osqp::c_int,
396 A.data.as_ptr(),
397 ptr::null(),
398 A.data.len() as ffi::src::src::osqp::c_int,
399 )
400 );
401 }
402 }
403
404 pub fn solve<'a>(&'a mut self) -> Status<'a> {
406 unsafe {
407 check!(solve, ffi::src::src::osqp::osqp_solve(self.workspace));
408 Status::from_problem(self)
409 }
410 }
411}
412
413impl Drop for Problem {
414 fn drop(&mut self) {
415 unsafe {
416 ffi::src::src::osqp::osqp_cleanup(self.workspace);
417 }
418 }
419}
420
// SAFETY: `Problem` exclusively owns its OSQP workspace pointer, and every method in this
// file that touches the workspace takes `&mut self`, so a shared `&Problem` gives no way to
// race on it.
// NOTE(review): this assumes OSQP keeps no thread-local or global state tied to a workspace;
// confirm against the OSQP build in use.
unsafe impl Send for Problem {}
unsafe impl Sync for Problem {}
423
/// Reasons why setting up a `Problem` can fail.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SetupError {
    /// The problem data was rejected. The message names the failed Rust-side check, or is
    /// empty when the rejection came from OSQP's own data validation.
    DataInvalid(&'static str),
    /// OSQP rejected the solver settings.
    SettingsInvalid,
    /// The linear system solver could not be loaded.
    LinsysSolverLoadFailed,
    /// The linear system solver failed to initialise.
    LinsysSolverInitFailed,
    /// OSQP reported the problem as non-convex.
    NonConvex,
    /// OSQP failed to allocate memory.
    MemoryAllocationFailed,
    #[doc(hidden)]
    __Nonexhaustive,
}

impl fmt::Display for SetupError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each piece is written through `Display::fmt` so formatter flags apply to it.
        let text = match self {
            SetupError::DataInvalid(msg) => {
                fmt::Display::fmt("problem data invalid", f)?;
                if !msg.is_empty() {
                    fmt::Display::fmt(": ", f)?;
                    fmt::Display::fmt(*msg, f)?;
                }
                return Ok(());
            }
            SetupError::SettingsInvalid => "problem settings invalid",
            SetupError::LinsysSolverLoadFailed => "linear system solver failed to load",
            SetupError::LinsysSolverInitFailed => "linear system solver failed to initialise",
            SetupError::NonConvex => "problem non-convex",
            SetupError::MemoryAllocationFailed => "memory allocation failed",
            SetupError::__Nonexhaustive => unreachable!(),
        };
        fmt::Display::fmt(text, f)
    }
}

impl Error for SetupError {}
462
#[cfg(test)]
mod tests {
    use std::iter;

    use super::*;

    /// Solution that both `update_matrices` scenarios must reach.
    const EXPECTED_X: [f64; 2] = [0.2987710845986426, 0.701227995544065];

    /// Asserts that `actual` has the same length as `expected` and matches it element-wise
    /// to within 1e-9.
    fn assert_close(expected: &[f64], actual: &[f64]) {
        assert_eq!(expected.len(), actual.len());
        let all_close = expected
            .iter()
            .zip(actual)
            .all(|(&want, &got)| (want - got).abs() < 1e-9);
        assert!(all_close);
    }

    #[test]
    #[allow(non_snake_case)]
    fn update_matrices() {
        // Placeholder data with the same sparsity as the real matrices, replaced via updates.
        let P_wrong = CscMatrix::from(&[[2.0, 1.0], [1.0, 4.0]]).into_upper_tri();
        let A_wrong = &[[2.0, 3.0], [1.0, 0.0], [0.0, 9.0]];

        let P = CscMatrix::from(&[[4.0, 1.0], [1.0, 2.0]]).into_upper_tri();
        let q = &[1.0, 1.0];
        let A = &[[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]];
        let l = &[1.0, 0.0, 0.0];
        let u = &[1.0, 0.7, 0.7];

        let settings = Settings::default()
            .alpha(1.0)
            .verbose(false)
            .adaptive_rho(false);

        // Both matrices replaced in one combined call.
        let mut combined = Problem::new(&P_wrong, q, A_wrong, l, u, &settings).unwrap();
        combined.update_P_A(&P, A);
        let status = combined.solve();
        let x = status.solution().unwrap().x();
        assert_close(&EXPECTED_X, x);

        // Same replacement done through the separate per-matrix calls.
        let mut separate = Problem::new(&P_wrong, q, A_wrong, l, u, &settings).unwrap();
        separate.update_P(&P);
        separate.update_A(A);
        let status = separate.solve();
        let x = status.solution().unwrap().x();
        assert_close(&EXPECTED_X, x);
    }

    #[test]
    #[allow(non_snake_case)]
    fn empty_A() {
        let P = CscMatrix::from(&[[4.0, 1.0], [1.0, 2.0]]).into_upper_tri();
        let q = &[1.0, 1.0];

        // A with zero constraint rows, so the bounds are empty too.
        let A = CscMatrix::from_column_iter_dense(0, 2, iter::empty());
        let no_bounds: &[f64] = &[];
        let mut prob = Problem::new(&P, q, &A, no_bounds, no_bounds, &Settings::default()).unwrap();
        prob.update_A(&A);

        // A with constraint rows but no stored entries.
        let A = CscMatrix::from(&[[0.0, 0.0], [0.0, 0.0]]);
        assert_eq!(A.data.len(), 0);
        let lower = &[0.0, 0.0];
        let upper = &[1.0, 1.0];
        let mut prob = Problem::new(&P, q, &A, lower, upper, &Settings::default()).unwrap();
        prob.update_A(&A);
    }
}