1use crate::algebra::hvcat_dim_check;
2use crate::algebra::matrix_traits::ShapedMatrix;
3use crate::algebra::MatrixConcatenationError;
4use std::cmp::max;
5
6use crate::algebra::{BlockConcatenate, CscMatrix, FloatT, MatrixShape};
7
8impl<T> BlockConcatenate for CscMatrix<T>
9where
10 T: FloatT,
11{
12 fn hcat(A: &Self, B: &Self) -> Result<Self, MatrixConcatenationError> {
13 Self::hvcat(&[&[A, B]])
14 }
15
16 fn vcat(A: &Self, B: &Self) -> Result<Self, MatrixConcatenationError> {
17 Self::hvcat(&[&[A], &[B]])
18 }
19
20 fn blockdiag(mats: &[&Self]) -> Result<Self, MatrixConcatenationError> {
23 if mats.is_empty() {
24 return Err(MatrixConcatenationError::IncompatibleDimension);
25 }
26
27 let mut nrows = 0;
28 let mut ncols = 0;
29 let mut nnzM = 0;
30 for mat in mats {
31 nrows += mat.nrows();
32 ncols += mat.ncols();
33 nnzM += mat.nnz();
34 }
35 let mut M = CscMatrix::<T>::spalloc((nrows, ncols), nnzM);
36
37 M.colptr.fill(0);
39
40 let mut nextcol = 0;
41 for mat in mats {
42 M.colcount_block(mat, nextcol, MatrixShape::N);
43 nextcol += mat.ncols();
44 }
45
46 M.colcount_to_colptr();
47
48 let dummylength = mats.iter().map(|m| m.nnz()).max().unwrap();
55 let mut dummymap = vec![0; dummylength];
56
57 let mut nextrow = 0;
59 let mut nextcol = 0;
60 for mat in mats {
61 M.fill_block(mat, &mut dummymap, nextrow, nextcol, MatrixShape::N);
62 nextrow += mat.nrows();
63 nextcol += mat.ncols();
64 }
65
66 M.backshift_colptrs();
67
68 Ok(M)
69 }
70
71 fn hvcat(mats: &[&[&Self]]) -> Result<Self, MatrixConcatenationError> {
72 hvcat_dim_check(mats)?;
74
75 let nrows = mats.iter().map(|blockrow| blockrow[0].nrows()).sum();
78 let ncols = mats[0].iter().map(|topblock| topblock.ncols()).sum();
79
80 let mut nnzM = 0;
81 let mut maxblocknnz = 0; for &blockrow in mats {
83 for mat in blockrow {
84 let blocknnz = mat.nnz();
85 maxblocknnz = max(maxblocknnz, blocknnz);
86 nnzM += blocknnz;
87 }
88 }
89
90 let mut M = CscMatrix::<T>::spalloc((nrows, ncols), nnzM);
91
92 M.colptr.fill(0);
94 let mut currentcol = 0;
95 for i in 0..mats[0].len() {
96 for blockrow in mats {
97 M.colcount_block(blockrow[i], currentcol, MatrixShape::N);
98 }
99 currentcol += mats[0][i].ncols();
100 }
101
102 M.colcount_to_colptr();
103
104 let mut dummymap = vec![0; maxblocknnz];
109
110 let mut currentcol = 0;
112 for i in 0..mats[0].len() {
113 let mut currentrow = 0;
114 for blockrow in mats {
115 M.fill_block(
116 blockrow[i],
117 &mut dummymap,
118 currentrow,
119 currentcol,
120 MatrixShape::N,
121 );
122 currentrow += blockrow[i].nrows();
123 }
124 currentcol += mats[0][i].ncols();
125 }
126
127 M.backshift_colptrs();
128
129 Ok(M)
130 }
131}
132
/// Checks `hcat`, `vcat`, `hvcat` and `blockdiag` on small fully dense 2x2
/// blocks. Also covers the degenerate single-block layouts and the
/// empty-input error path of `blockdiag`.
#[test]
fn test_dense_concatenate() {
    let A = CscMatrix::from(&[
        [1., 3.],
        [2., 4.],
    ]);
    let B = CscMatrix::from(&[
        [5., 7.],
        [6., 8.],
    ]);

    // [A B]
    let C = CscMatrix::hcat(&A, &B).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 3., 5., 7.],
        [2., 4., 6., 8.],
    ]);
    assert_eq!(C, Ctest);

    // [A; B]
    let C = CscMatrix::vcat(&A, &B).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 3.],
        [2., 4.],
        [5., 7.],
        [6., 8.],
    ]);
    assert_eq!(C, Ctest);

    // [A B; B A]
    let C = CscMatrix::hvcat(&[&[&A, &B], &[&B, &A]]).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 3., 5., 7.],
        [2., 4., 6., 8.],
        [5., 7., 1., 3.],
        [6., 8., 2., 4.],
    ]);
    assert_eq!(C, Ctest);

    // diag(A, B)
    let C = CscMatrix::blockdiag(&[&A, &B]).unwrap();
    let Ctest = CscMatrix::from(&[
        [1., 3., 0., 0.],
        [2., 4., 0., 0.],
        [0., 0., 5., 7.],
        [0., 0., 6., 8.],
    ]);
    assert_eq!(C, Ctest);

    // A single block passes through unchanged.
    assert_eq!(CscMatrix::hvcat(&[&[&A]]).unwrap(), A);
    assert_eq!(CscMatrix::blockdiag(&[&A]).unwrap(), A);

    // An empty block list has no shape and is rejected.
    assert!(CscMatrix::<f64>::blockdiag(&[]).is_err());
}