// clarabel/algebra/csc/block_concatenate.rs

1use crate::algebra::hvcat_dim_check;
2use crate::algebra::matrix_traits::ShapedMatrix;
3use crate::algebra::MatrixConcatenationError;
4use std::cmp::max;
5
6use crate::algebra::{BlockConcatenate, CscMatrix, FloatT, MatrixShape};
7
8impl<T> BlockConcatenate for CscMatrix<T>
9where
10    T: FloatT,
11{
12    fn hcat(A: &Self, B: &Self) -> Result<Self, MatrixConcatenationError> {
13        Self::hvcat(&[&[A, B]])
14    }
15
16    fn vcat(A: &Self, B: &Self) -> Result<Self, MatrixConcatenationError> {
17        Self::hvcat(&[&[A], &[B]])
18    }
19
20    //PJG: This might be modifiable to allow Adjoint and Symmetric
21    //inputs as well.
22    fn blockdiag(mats: &[&Self]) -> Result<Self, MatrixConcatenationError> {
23        if mats.is_empty() {
24            return Err(MatrixConcatenationError::IncompatibleDimension);
25        }
26
27        let mut nrows = 0;
28        let mut ncols = 0;
29        let mut nnzM = 0;
30        for mat in mats {
31            nrows += mat.nrows();
32            ncols += mat.ncols();
33            nnzM += mat.nnz();
34        }
35        let mut M = CscMatrix::<T>::spalloc((nrows, ncols), nnzM);
36
37        // assemble the column counts
38        M.colptr.fill(0);
39
40        let mut nextcol = 0;
41        for mat in mats {
42            M.colcount_block(mat, nextcol, MatrixShape::N);
43            nextcol += mat.ncols();
44        }
45
46        M.colcount_to_colptr();
47
48        //PJG: create fake data map showing where the
49        //entries go.   Probably this should be an Option
50        //instead, but that requires rewriting some of the
51        //KKT assembly code.
52
53        // unwrap is fine since this is unreachable for empty input
54        let dummylength = mats.iter().map(|m| m.nnz()).max().unwrap();
55        let mut dummymap = vec![0; dummylength];
56
57        // fill in data and rebuild colptr
58        let mut nextrow = 0;
59        let mut nextcol = 0;
60        for mat in mats {
61            M.fill_block(mat, &mut dummymap, nextrow, nextcol, MatrixShape::N);
62            nextrow += mat.nrows();
63            nextcol += mat.ncols();
64        }
65
66        M.backshift_colptrs();
67
68        Ok(M)
69    }
70
71    fn hvcat(mats: &[&[&Self]]) -> Result<Self, MatrixConcatenationError> {
72        // check for consistent block dimensions
73        hvcat_dim_check(mats)?;
74
75        // dimensions are consistent and nonzero, so count
76        // total rows and columns by counting along the border
77        let nrows = mats.iter().map(|blockrow| blockrow[0].nrows()).sum();
78        let ncols = mats[0].iter().map(|topblock| topblock.ncols()).sum();
79
80        let mut nnzM = 0;
81        let mut maxblocknnz = 0; // for dummy mapping below
82        for &blockrow in mats {
83            for mat in blockrow {
84                let blocknnz = mat.nnz();
85                maxblocknnz = max(maxblocknnz, blocknnz);
86                nnzM += blocknnz;
87            }
88        }
89
90        let mut M = CscMatrix::<T>::spalloc((nrows, ncols), nnzM);
91
92        // assemble the column counts
93        M.colptr.fill(0);
94        let mut currentcol = 0;
95        for i in 0..mats[0].len() {
96            for blockrow in mats {
97                M.colcount_block(blockrow[i], currentcol, MatrixShape::N);
98            }
99            currentcol += mats[0][i].ncols();
100        }
101
102        M.colcount_to_colptr();
103
104        //PJG: create fake data maps showing where the
105        //entries go.   Probably this should be an Option
106        //instead, but that requires rewriting some of the
107        //KKT assembly code
108        let mut dummymap = vec![0; maxblocknnz];
109
110        // fill in data and rebuild colptr
111        let mut currentcol = 0;
112        for i in 0..mats[0].len() {
113            let mut currentrow = 0;
114            for blockrow in mats {
115                M.fill_block(
116                    blockrow[i],
117                    &mut dummymap,
118                    currentrow,
119                    currentcol,
120                    MatrixShape::N,
121                );
122                currentrow += blockrow[i].nrows();
123            }
124            currentcol += mats[0][i].ncols();
125        }
126
127        M.backshift_colptrs();
128
129        Ok(M)
130    }
131}
132
/// Checks `hcat`, `vcat`, `hvcat` and `blockdiag` against dense references,
/// including rectangular diagonal blocks and the empty-input error path.
#[test]
fn test_dense_concatenate() {
    let A = CscMatrix::from(&[
        [1., 3.], //
        [2., 4.], //
    ]);
    let B = CscMatrix::from(&[
        [5., 7.], //
        [6., 8.], //
    ]);

    // horizontal
    let C = CscMatrix::hcat(&A, &B).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 3., 5., 7.], //
        [2., 4., 6., 8.], //
    ]);

    assert_eq!(C, Ctest);

    // vertical
    let C = CscMatrix::vcat(&A, &B).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 3.], //
        [2., 4.], //
        [5., 7.], //
        [6., 8.], //
    ]);
    assert_eq!(C, Ctest);

    // 2 x 2 block
    let C = CscMatrix::hvcat(&[&[&A, &B], &[&B, &A]]).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 3., 5., 7.], //
        [2., 4., 6., 8.], //
        [5., 7., 1., 3.], //
        [6., 8., 2., 4.], //
    ]);
    assert_eq!(C, Ctest);

    // block diagonal
    let C = CscMatrix::blockdiag(&[&A, &B]).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 3., 0., 0.], //
        [2., 4., 0., 0.], //
        [0., 0., 5., 7.], //
        [0., 0., 6., 8.], //
    ]);
    assert_eq!(C, Ctest);

    // block diagonal with rectangular blocks: row and column offsets
    // advance independently
    let R = CscMatrix::from(&[
        [1., 2., 3.], //
    ]);
    let S = CscMatrix::from(&[
        [4.], //
        [5.], //
    ]);
    let C = CscMatrix::blockdiag(&[&R, &S]).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 2., 3., 0.], //
        [0., 0., 0., 4.], //
        [0., 0., 0., 5.], //
    ]);
    assert_eq!(C, Ctest);

    // block diagonal of no blocks is rejected
    assert!(CscMatrix::<f64>::blockdiag(&[]).is_err());
}
182}