1use osqp_rust_sys as ffi;
2use std::borrow::Cow;
3use std::iter;
4use std::slice;
5
6use crate::float;
7
/// Early-returns `false` from the enclosing function if `$check` evaluates to `false`.
///
/// Intended for `-> bool` validation functions (such as `CscMatrix::is_valid`) so that each
/// invariant can be written as a single line and the first violated one short-circuits.
/// The expression is evaluated exactly once.
macro_rules! check {
    ($check:expr) => {
        match $check {
            true => {}
            false => return false,
        }
    };
}
15
/// A sparse matrix stored in compressed sparse column (CSC) format.
///
/// The stored entries of column `j` occupy the range `indptr[j]..indptr[j + 1]` of both
/// `indices` (their row numbers) and `data` (their values). The buffers are held in `Cow`s so
/// a matrix can either own its storage or borrow it (for example from another `CscMatrix`, or
/// from a matrix owned by the OSQP C library).
#[derive(Clone, Debug, PartialEq)]
pub struct CscMatrix<'a> {
    /// Number of rows in the matrix.
    pub nrows: usize,
    /// Number of columns in the matrix.
    pub ncols: usize,
    /// Column pointers: `ncols + 1` nondecreasing offsets into `indices`/`data`. Entry `j` is
    /// where column `j` starts, and the final entry is the total number of stored entries.
    pub indptr: Cow<'a, [usize]>,
    /// Row index of each stored entry, grouped by column. Validation (`is_valid`) requires the
    /// row indices within each column to be strictly increasing and less than `nrows`.
    pub indices: Cow<'a, [usize]>,
    /// Value of each stored entry, in the same order as `indices`.
    pub data: Cow<'a, [float]>,
}
36
37impl<'a> CscMatrix<'a> {
38 pub fn from_column_iter_dense<I: IntoIterator<Item = float>>(
43 nrows: usize,
44 ncols: usize,
45 iter: I,
46 ) -> CscMatrix<'static> {
47 CscMatrix::from_iter_dense_inner(nrows, ncols, |size| {
48 let mut data = Vec::with_capacity(size);
49 data.extend(iter.into_iter().take(size));
50 assert_eq!(size, data.len(), "not enough elements in iterator");
51 data
52 })
53 }
54
55 pub fn from_row_iter_dense<I: IntoIterator<Item = float>>(
63 nrows: usize,
64 ncols: usize,
65 iter: I,
66 ) -> CscMatrix<'static> {
67 CscMatrix::from_iter_dense_inner(nrows, ncols, |size| {
68 let mut iter = iter.into_iter();
69 let mut data = vec![0.0; size];
70 for r in 0..nrows {
71 for c in 0..ncols {
72 data[c * nrows + r] = iter.next().expect("not enough elements in iterator");
73 }
74 }
75 data
76 })
77 }
78
79 fn from_iter_dense_inner<F: FnOnce(usize) -> Vec<float>>(
80 nrows: usize,
81 ncols: usize,
82 f: F,
83 ) -> CscMatrix<'static> {
84 let size = nrows
85 .checked_mul(ncols)
86 .expect("overflow calculating matrix size");
87
88 let data = f(size);
89
90 CscMatrix {
91 nrows,
92 ncols,
93 indptr: Cow::Owned((0..ncols + 1).map(|i| i * nrows).collect()),
94 indices: Cow::Owned(iter::repeat(0..nrows).take(ncols).flat_map(|i| i).collect()),
95 data: Cow::Owned(data),
96 }
97 }
98
99 pub fn is_structurally_upper_tri(&self) -> bool {
107 for col in 0..self.indptr.len().saturating_sub(1) {
108 let col_data_start_idx = self.indptr[col];
109 let col_data_end_idx = self.indptr[col + 1];
110
111 for &row in &self.indices[col_data_start_idx..col_data_end_idx] {
112 if row > col {
113 return false;
114 }
115 }
116 }
117
118 true
119 }
120
121 pub fn into_upper_tri(self) -> CscMatrix<'a> {
128 if self.is_structurally_upper_tri() {
129 return self;
130 }
131
132 let mut indptr = self.indptr.into_owned();
133 let mut indices = self.indices.into_owned();
134 let mut data = self.data.into_owned();
135
136 let mut next_data_idx = 0;
137
138 for col in 0..indptr.len().saturating_sub(1) {
139 let col_start_idx = indptr[col];
140 let next_col_start_idx = indptr[col + 1];
141
142 indptr[col] = next_data_idx;
143
144 for data_idx in col_start_idx..next_col_start_idx {
145 let row = indices[data_idx];
146
147 if row <= col {
148 data[next_data_idx] = data[data_idx];
149 indices[next_data_idx] = row;
150 next_data_idx += 1;
151 }
152 }
153 }
154
155 if let Some(data_len) = indptr.last_mut() {
156 *data_len = next_data_idx
157 }
158 indices.truncate(next_data_idx);
159 data.truncate(next_data_idx);
160
161 CscMatrix {
162 indptr: Cow::Owned(indptr),
163 indices: Cow::Owned(indices),
164 data: Cow::Owned(data),
165 ..self
166 }
167 }
168
169 pub(crate) unsafe fn to_ffi(&self) -> ffi::src::src::osqp::csc {
170 ffi::src::src::osqp::csc {
173 nzmax: self.data.len() as ffi::src::src::osqp::c_int,
174 m: self.nrows as ffi::src::src::osqp::c_int,
175 n: self.ncols as ffi::src::src::osqp::c_int,
176 p: self.indptr.as_ptr() as *mut usize as *mut ffi::src::src::osqp::c_int,
177 i: self.indices.as_ptr() as *mut usize as *mut ffi::src::src::osqp::c_int,
178 x: self.data.as_ptr() as *mut float,
179 nz: -1,
180 }
181 }
182
183 pub(crate) unsafe fn from_ffi<'b>(csc: *const ffi::src::src::osqp::csc) -> CscMatrix<'b> {
184 let nrows = (*csc).m as usize;
185 let ncols = (*csc).n as usize;
186 let indptr = Cow::Borrowed(slice::from_raw_parts((*csc).p as *const usize, ncols + 1));
187 let nnz = if indptr[ncols] == 0 {
189 0
190 } else {
191 (*csc).nzmax as usize
192 };
193
194 CscMatrix {
195 nrows,
196 ncols,
197 indptr,
198 indices: Cow::Borrowed(slice::from_raw_parts((*csc).i as *const usize, nnz)),
199 data: Cow::Borrowed(slice::from_raw_parts((*csc).x as *const float, nnz)),
200 }
201 }
202
203 pub(crate) fn assert_same_sparsity_structure(&self, other: &CscMatrix) {
204 assert_eq!(self.nrows, other.nrows);
205 assert_eq!(self.ncols, other.ncols);
206 assert_eq!(&*self.indptr, &*other.indptr);
207 assert_eq!(&*self.indices, &*other.indices);
208 assert_eq!(self.data.len(), other.data.len());
209 }
210
211 pub(crate) fn is_valid(&self) -> bool {
212 let max_idx = isize::max_value() as usize;
213 check!(self.nrows <= max_idx);
214 check!(self.ncols <= max_idx);
215 check!(self.indptr.len() <= max_idx);
216 check!(self.indices.len() <= max_idx);
217 check!(self.data.len() <= max_idx);
218
219 check!(self.indptr.len() == self.ncols + 1);
221 check!(self.indptr[self.ncols] == self.data.len());
222 let mut prev_row_idx = 0;
223 for &row_idx in self.indptr.iter() {
224 check!(row_idx >= prev_row_idx);
226 prev_row_idx = row_idx;
227 }
228
229 check!(self.data.len() == self.indices.len());
231 check!(self.indices.iter().all(|r| *r < self.nrows));
232 for i in 0..self.ncols {
233 let row_indices = &self.indices[self.indptr[i] as usize..self.indptr[i + 1] as usize];
234 let mut row_indices = row_indices.iter();
235 if let Some(&first_row) = row_indices.next() {
236 let mut prev_row = first_row;
237 for &row in row_indices {
238 check!(row > prev_row);
240 prev_row = row;
241 }
242 check!(prev_row < self.nrows);
243 }
244 }
245
246 true
247 }
248}
249
250impl<'a, 'b: 'a> From<&'a CscMatrix<'b>> for CscMatrix<'a> {
252 fn from(mat: &'a CscMatrix<'b>) -> CscMatrix<'a> {
253 CscMatrix {
254 nrows: mat.nrows,
255 ncols: mat.ncols,
256 indptr: (*mat.indptr).into(),
257 indices: (*mat.indices).into(),
258 data: (*mat.data).into(),
259 }
260 }
261}
262
263impl<'a, I: 'a, J: 'a> From<I> for CscMatrix<'static>
270where
271 I: IntoIterator<Item = J>,
272 J: IntoIterator<Item = &'a float>,
273{
274 fn from(rows: I) -> CscMatrix<'static> {
275 let rows: Vec<Vec<float>> = rows
276 .into_iter()
277 .map(|r| r.into_iter().map(|&v| v).collect())
278 .collect();
279
280 let nrows = rows.len();
281 let ncols = rows.iter().map(|r| r.len()).next().unwrap_or(0);
282 assert!(rows.iter().all(|r| r.len() == ncols));
283 let nnz = rows.iter().flat_map(|r| r).filter(|&&v| v != 0.0).count();
284
285 let mut indptr = Vec::with_capacity(ncols + 1);
286 let mut indices = Vec::with_capacity(nnz);
287 let mut data = Vec::with_capacity(nnz);
288
289 indptr.push(0);
290 for c in 0..ncols {
291 for r in 0..nrows {
292 let value = rows[r][c];
293 if value != 0.0 {
294 indices.push(r);
295 data.push(value);
296 }
297 }
298 indptr.push(data.len());
299 }
300
301 CscMatrix {
302 nrows,
303 ncols,
304 indptr: indptr.into(),
305 indices: indices.into(),
306 data: data.into(),
307 }
308 }
309}
310
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::*;

    /// Returns `true` when `cow` is a borrowed view rather than an owned copy.
    fn is_borrowed<B: ?Sized + ToOwned>(cow: &Cow<B>) -> bool {
        match *cow {
            Cow::Borrowed(_) => true,
            Cow::Owned(_) => false,
        }
    }

    #[test]
    fn csc_from_array() {
        let dense = &[[1.0, 2.0], [3.0, 0.0], [0.0, 4.0]];
        let csc: CscMatrix = dense.into();

        // The zeros at (1, 1) and (2, 0) must not be stored.
        assert_eq!(3, csc.nrows);
        assert_eq!(2, csc.ncols);
        assert_eq!(&[0, 2, 4], &*csc.indptr);
        assert_eq!(&[0, 1, 0, 2], &*csc.indices);
        assert_eq!(&[1.0, 3.0, 2.0, 4.0], &*csc.data);
    }

    #[test]
    fn csc_from_ref() {
        let dense = &[[1.0, 2.0], [3.0, 0.0], [0.0, 4.0]];
        let owned: CscMatrix = dense.into();
        let view: CscMatrix = (&owned).into();

        // Converting from a reference must not copy any buffer.
        assert!(is_borrowed(&view.indptr));
        assert!(is_borrowed(&view.indices));
        assert!(is_borrowed(&view.data));

        assert_eq!(owned.nrows, view.nrows);
        assert_eq!(owned.ncols, view.ncols);
        assert_eq!(owned.indptr, view.indptr);
        assert_eq!(owned.indices, view.indices);
        assert_eq!(owned.data, view.data);
    }

    #[test]
    fn csc_from_iter_dense() {
        // The same 4x3 matrix, expressed three different ways.
        let column_major = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let row_major = [
            1.0, 5.0, 9.0, 2.0, 6.0, 10.0, 3.0, 7.0, 11.0, 4.0, 8.0, 12.0,
        ];

        let from_columns = CscMatrix::from_column_iter_dense(4, 3, column_major.iter().cloned());
        let from_rows = CscMatrix::from_row_iter_dense(4, 3, row_major.iter().cloned());
        let from_array: CscMatrix = (&[
            [1.0, 5.0, 9.0],
            [2.0, 6.0, 10.0],
            [3.0, 7.0, 11.0],
            [4.0, 8.0, 12.0],
        ])
            .into();

        assert_eq!(from_columns, from_array);
        assert_eq!(from_rows, from_array);
    }

    #[test]
    fn same_sparsity_structure_ok() {
        let lhs: CscMatrix = (&[[1.0, 2.0, 0.0], [3.0, 0.0, 0.0], [0.0, 5.0, 0.0]]).into();
        let rhs: CscMatrix = (&[[7.0, 8.0, 0.0], [9.0, 0.0, 0.0], [0.0, 10.0, 0.0]]).into();
        lhs.assert_same_sparsity_structure(&rhs);
    }

    #[test]
    #[should_panic]
    fn different_sparsity_structure_panics() {
        // `lhs` has an extra entry at (2, 2).
        let lhs: CscMatrix = (&[[1.0, 2.0, 0.0], [3.0, 0.0, 0.0], [0.0, 5.0, 6.0]]).into();
        let rhs: CscMatrix = (&[[7.0, 8.0, 0.0], [9.0, 0.0, 0.0], [0.0, 10.0, 0.0]]).into();
        lhs.assert_same_sparsity_structure(&rhs);
    }

    #[test]
    fn is_structurally_upper_tri() {
        let sparse_upper: CscMatrix =
            (&[[1.0, 0.0, 5.0], [0.0, 3.0, 4.0], [0.0, 0.0, 2.0]]).into();
        // Same values, but the dense constructor stores the below-diagonal zeros explicitly.
        let dense_upper: CscMatrix = CscMatrix::from_row_iter_dense(
            3,
            3,
            [1.0, 0.0, 5.0, 0.0, 3.0, 4.0, 0.0, 0.0, 2.0]
                .iter()
                .cloned(),
        );
        let lower_entries: CscMatrix =
            (&[[7.0, 2.0, 0.0], [9.0, 0.0, 5.0], [7.0, 10.0, 0.0]]).into();

        assert!(sparse_upper.is_structurally_upper_tri());
        assert!(!dense_upper.is_structurally_upper_tri());
        assert!(!lower_entries.is_structurally_upper_tri());
    }

    #[test]
    fn into_upper_tri() {
        let full: CscMatrix = (&[[1.0, 0.0, 5.0], [7.0, 3.0, 4.0], [6.0, 0.0, 2.0]]).into();
        let expected: CscMatrix =
            (&[[1.0, 0.0, 5.0], [0.0, 3.0, 4.0], [0.0, 0.0, 2.0]]).into();
        assert_eq!(full.into_upper_tri(), expected);
    }
}