1use crate::indexing::SpIndex;
4use crate::sparse::prelude::*;
5use std::cmp;
6use std::default::Default;
7
8pub fn same_storage_fast_stack<'a, N, I, Iptr, MatArray>(
11 mats: &MatArray,
12) -> CsMatI<N, I, Iptr>
13where
14 N: 'a + Clone,
15 I: 'a + SpIndex,
16 Iptr: 'a + SpIndex,
17 MatArray: AsRef<[CsMatViewI<'a, N, I, Iptr>]>,
18{
19 let mats = mats.as_ref();
20 assert!(!mats.is_empty(), "Empty stacking list");
21 let inner_dim = mats[0].inner_dims();
22 assert!(
23 mats.iter().all(|x| x.inner_dims() == inner_dim),
24 "Dimension mismatch"
25 );
26 let storage_type = mats[0].storage();
27 assert!(
28 mats.iter().all(|x| x.storage() == storage_type),
29 "Storage mismatch"
30 );
31
32 let outer_dim = mats.iter().map(CsMatBase::outer_dims).sum::<usize>();
33 let nnz = mats.iter().map(CsMatBase::nnz).sum::<usize>();
34
35 let mut res = CsMatI::empty(storage_type, inner_dim);
36 res.reserve_outer_dim_exact(outer_dim);
37 res.reserve_nnz_exact(nnz);
38 for mat in mats {
39 for vec in mat.outer_iterator() {
40 res = res.append_outer_csvec(vec.view());
41 }
42 }
43
44 res
45}
46
47pub fn vstack<'a, N, I, Iptr, MatArray>(mats: &MatArray) -> CsMatI<N, I, Iptr>
49where
50 N: 'a + Clone + Default,
51 I: 'a + SpIndex,
52 Iptr: 'a + SpIndex,
53 MatArray: AsRef<[CsMatViewI<'a, N, I, Iptr>]>,
54{
55 let mats = mats.as_ref();
56 if mats.iter().all(CsMatBase::is_csr) {
57 return same_storage_fast_stack(&mats);
58 }
59
60 let mats_csr: Vec<_> = mats.iter().map(CsMatBase::to_csr).collect();
61 let mats_csr_views: Vec<_> = mats_csr.iter().map(CsMatBase::view).collect();
62 same_storage_fast_stack(&mats_csr_views)
63}
64
65pub fn hstack<'a, N, I, Iptr, MatArray>(mats: &MatArray) -> CsMatI<N, I, Iptr>
67where
68 N: 'a + Clone + Default,
69 I: 'a + SpIndex,
70 Iptr: 'a + SpIndex,
71 MatArray: AsRef<[CsMatViewI<'a, N, I, Iptr>]>,
72{
73 let mats = mats.as_ref();
74 if mats.iter().all(CsMatBase::is_csc) {
75 return same_storage_fast_stack(&mats);
76 }
77
78 let mats_csc: Vec<_> = mats.iter().map(CsMatBase::to_csc).collect();
79 let mats_csc_views: Vec<_> = mats_csc.iter().map(CsMatBase::view).collect();
80 same_storage_fast_stack(&mats_csc_views)
81}
82
83pub fn bmat<'a, N, I, Iptr, OuterArray, InnerArray>(
95 mats: &OuterArray,
96) -> CsMatI<N, I, Iptr>
97where
98 N: 'a + Clone + Default,
99 I: 'a + SpIndex,
100 Iptr: 'a + SpIndex,
101 OuterArray: 'a + AsRef<[InnerArray]>,
102 InnerArray: 'a + AsRef<[Option<CsMatViewI<'a, N, I, Iptr>>]>,
103{
104 let mats = mats.as_ref();
105 let super_rows = mats.len();
106 assert_ne!(super_rows, 0, "Empty stacking list");
107 let super_cols = mats[0].as_ref().len();
108 assert_ne!(super_cols, 0, "Empty stacking list");
109
110 assert!(
112 mats.iter().all(|x| x.as_ref().len() == super_cols),
113 "Dimension mismatch"
114 );
115
116 assert!(
117 !mats.iter().any(|x| x.as_ref().iter().all(Option::is_none)),
118 "Empty bmat row"
119 );
120 assert!(
121 !(0..super_cols).any(|j| mats.iter().all(|x| x.as_ref()[j].is_none())),
122 "Empty bmat col"
123 );
124
125 let rows_per_row: Vec<_> = mats
127 .iter()
128 .map(|row| {
129 row.as_ref().iter().fold(0, |nrows, mopt| {
130 mopt.as_ref().map_or(nrows, |m| cmp::max(nrows, m.rows()))
131 })
132 })
133 .collect();
134 let cols_per_col: Vec<_> = (0..super_cols)
135 .map(|j| {
136 mats.iter().fold(0, |ncols, row| {
137 row.as_ref()[j]
138 .as_ref()
139 .map_or(ncols, |m| cmp::max(ncols, m.cols()))
140 })
141 })
142 .collect();
143 let mut to_vstack = Vec::with_capacity(super_rows);
144 for (i, row) in mats.iter().enumerate() {
145 let with_zeros: Vec<_> = row
146 .as_ref()
147 .iter()
148 .enumerate()
149 .map(|(j, m)| {
150 let shape = (rows_per_row[i], cols_per_col[j]);
151 m.as_ref().map_or(CsMatI::zero(shape), CsMatBase::to_owned)
152 })
153 .collect();
154 let borrows: Vec<_> = with_zeros.iter().map(CsMatBase::view).collect();
155 let stacked = hstack(&borrows);
156 to_vstack.push(stacked);
157 }
158 let borrows: Vec<_> = to_vstack.iter().map(CsMatBase::view).collect();
159 vstack(&borrows)
160}
161
// Unit tests for the stacking functions above. The fixtures `mat1`..`mat4`
// come from `crate::test_data`; their exact shapes and storage live there.
#[cfg(test)]
mod test {
    use crate::sparse::CsMat;
    use crate::test_data::{mat1, mat2, mat3, mat4};

    /// Expected result of stacking `mat1()` on top of `mat2()`: a 10x5 matrix
    /// built directly from raw `indptr` / `indices` / `data` arrays.
    fn mat1_vstack_mat2() -> CsMat<f64> {
        let indptr = vec![0, 2, 4, 5, 6, 7, 11, 13, 13, 15, 17];
        let indices = vec![2, 3, 3, 4, 2, 1, 3, 0, 1, 2, 4, 0, 3, 2, 3, 1, 2];
        let data = vec![
            3., 4., 2., 5., 5., 8., 7., 6., 7., 3., 3., 8., 9., 2., 4., 4., 4.,
        ];
        CsMat::new((10, 5), indptr, indices, data)
    }

    // Stacking an empty list must panic.
    #[test]
    #[should_panic]
    fn same_storage_fast_stack_fail_empty_stacking_list() {
        let _: CsMat<f64> = super::same_storage_fast_stack(&[]);
    }

    // Expected to panic because mat1 and mat3 presumably have different
    // inner dimensions (see test_data).
    #[test]
    #[should_panic]
    fn same_storage_fast_stack_fail_dim_mismatch() {
        let a = mat1();
        let c = mat3();
        let _ = super::same_storage_fast_stack(&[a.view(), c.view()]);
    }

    // Expected to panic because mat4 presumably uses a different storage
    // than mat1 (see test_data).
    #[test]
    #[should_panic]
    fn same_storage_fast_stack_fail_storage() {
        let a = mat1();
        let d = mat4();
        let _ = super::same_storage_fast_stack(&[a.view(), d.view()]);
    }

    // Same-storage inputs are appended along the outer dimension.
    #[test]
    fn same_storage_fast_stack_ok() {
        let a = mat1();
        let b = mat2();
        let res = super::same_storage_fast_stack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2();
        assert_eq!(res, expected);
    }

    // vstack on inputs that need no conversion.
    #[test]
    fn vstack_trivial() {
        let a = mat1();
        let b = mat2();
        let res = super::vstack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2();
        assert_eq!(res, expected);
    }

    // hstack of the transposes equals the transpose of the vstack.
    #[test]
    fn hstack_trivial() {
        let a = mat1().transpose_into();
        let b = mat2().transpose_into();
        let res = super::hstack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2().transpose_into();
        assert_eq!(res, expected);
    }

    // Mixed storage: the CSC input must be converted before stacking.
    #[test]
    fn vstack_with_conversion() {
        let a = mat1().to_csc();
        let b = mat2();
        let res = super::vstack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2();
        assert_eq!(res, expected);
    }

    // Block-rows of different lengths (2 vs 1) must be rejected.
    #[test]
    #[should_panic]
    fn bmat_fail_shapes() {
        let _: CsMat<f64> = super::bmat(&vec![vec![None, None], vec![None]]);
    }

    // A grid whose single block-row has no entries must be rejected.
    #[test]
    #[should_panic]
    fn bmat_fail_empty_stacking_list() {
        let _: CsMat<f64> = super::bmat(&[[]]);
    }

    // The first block-row contains only `None`.
    #[test]
    #[should_panic]
    fn bmat_fail_empty_bmat_row() {
        let a = mat1();
        let c = mat3();
        let _: CsMat<f64> =
            super::bmat(&[[None, None], [Some(a.view()), Some(c.view())]]);
    }

    // The second block-column contains only `None`.
    #[test]
    #[should_panic]
    fn bmat_fail_empty_bmat_col() {
        let a = mat1();
        let c = mat3();
        let _: CsMat<f64> =
            super::bmat(&[[Some(c.view()), None], [Some(a.view()), None]]);
    }

    // Block-diagonal of a 5x5 and a 4x4 identity is the 9x9 identity.
    #[test]
    fn bmat_simple() {
        let a = CsMat::<f64>::eye(5);
        let b = CsMat::<f64>::eye(4);
        let c = super::bmat(&[[Some(a.view()), None], [None, Some(b.view())]]);
        let expected = CsMat::new(
            (9, 9),
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8],
            vec![1.; 9],
        );
        assert_eq!(c, expected);
    }

    // Grids with a `None` cell, which must be filled with a zero block of
    // the matching shape.
    #[test]
    fn bmat_complex() {
        let a = mat1();
        let b = mat2();
        let c = super::bmat(&[
            [Some(a.view()), Some(b.view())],
            [Some(b.view()), None],
        ]);
        let expected = CsMat::new(
            (10, 10),
            vec![0, 6, 10, 11, 14, 17, 21, 23, 23, 25, 27],
            vec![
                2, 3, 5, 6, 7, 9, 3, 4, 5, 8, 2, 1, 7, 8, 3, 6, 7, 0, 1, 2, 4,
                0, 3, 2, 3, 1, 2,
            ],
            vec![
                3., 4., 6., 7., 3., 3., 2., 5., 8., 9., 5., 8., 2., 4., 7., 4.,
                4., 6., 7., 3., 3., 8., 9., 2., 4., 4., 4.,
            ],
        );
        assert_eq!(c, expected);

        // Blocks of different widths in the two block-columns (10x9 result).
        let d = mat3();
        let e = mat4();
        let f = super::bmat(&[
            [Some(d.view()), Some(a.view())],
            [None, Some(e.view())],
        ]);
        let expected = CsMat::new(
            (10, 9),
            vec![0, 4, 8, 10, 12, 14, 16, 18, 21, 23, 24],
            vec![
                2, 3, 6, 7, 2, 3, 7, 8, 2, 6, 1, 5, 3, 7, 4, 5, 4, 8, 4, 7, 8,
                5, 7, 4,
            ],
            vec![
                3., 4., 3., 4., 2., 5., 2., 5., 5., 5., 8., 8., 7., 7., 6., 8.,
                7., 4., 3., 2., 4., 9., 4., 3.,
            ],
        );
        assert_eq!(f, expected);
    }
}