sprs/sparse/
construct.rs

1//! High level construction of sparse matrices by stacking, by block, ...
2
3use crate::indexing::SpIndex;
4use crate::sparse::prelude::*;
5use std::cmp;
6use std::default::Default;
7
8/// Stack the given matrices into a new one, using the most efficient stacking
9/// direction (ie vertical stack for CSR matrices, horizontal stack for CSC)
10pub fn same_storage_fast_stack<'a, N, I, Iptr, MatArray>(
11    mats: &MatArray,
12) -> CsMatI<N, I, Iptr>
13where
14    N: 'a + Clone,
15    I: 'a + SpIndex,
16    Iptr: 'a + SpIndex,
17    MatArray: AsRef<[CsMatViewI<'a, N, I, Iptr>]>,
18{
19    let mats = mats.as_ref();
20    assert!(!mats.is_empty(), "Empty stacking list");
21    let inner_dim = mats[0].inner_dims();
22    assert!(
23        mats.iter().all(|x| x.inner_dims() == inner_dim),
24        "Dimension mismatch"
25    );
26    let storage_type = mats[0].storage();
27    assert!(
28        mats.iter().all(|x| x.storage() == storage_type),
29        "Storage mismatch"
30    );
31
32    let outer_dim = mats.iter().map(CsMatBase::outer_dims).sum::<usize>();
33    let nnz = mats.iter().map(CsMatBase::nnz).sum::<usize>();
34
35    let mut res = CsMatI::empty(storage_type, inner_dim);
36    res.reserve_outer_dim_exact(outer_dim);
37    res.reserve_nnz_exact(nnz);
38    for mat in mats {
39        for vec in mat.outer_iterator() {
40            res = res.append_outer_csvec(vec.view());
41        }
42    }
43
44    res
45}
46
47/// Construct a sparse matrix by vertically stacking other matrices
48pub fn vstack<'a, N, I, Iptr, MatArray>(mats: &MatArray) -> CsMatI<N, I, Iptr>
49where
50    N: 'a + Clone + Default,
51    I: 'a + SpIndex,
52    Iptr: 'a + SpIndex,
53    MatArray: AsRef<[CsMatViewI<'a, N, I, Iptr>]>,
54{
55    let mats = mats.as_ref();
56    if mats.iter().all(CsMatBase::is_csr) {
57        return same_storage_fast_stack(&mats);
58    }
59
60    let mats_csr: Vec<_> = mats.iter().map(CsMatBase::to_csr).collect();
61    let mats_csr_views: Vec<_> = mats_csr.iter().map(CsMatBase::view).collect();
62    same_storage_fast_stack(&mats_csr_views)
63}
64
65/// Construct a sparse matrix by horizontally stacking other matrices
66pub fn hstack<'a, N, I, Iptr, MatArray>(mats: &MatArray) -> CsMatI<N, I, Iptr>
67where
68    N: 'a + Clone + Default,
69    I: 'a + SpIndex,
70    Iptr: 'a + SpIndex,
71    MatArray: AsRef<[CsMatViewI<'a, N, I, Iptr>]>,
72{
73    let mats = mats.as_ref();
74    if mats.iter().all(CsMatBase::is_csc) {
75        return same_storage_fast_stack(&mats);
76    }
77
78    let mats_csc: Vec<_> = mats.iter().map(CsMatBase::to_csc).collect();
79    let mats_csc_views: Vec<_> = mats_csc.iter().map(CsMatBase::view).collect();
80    same_storage_fast_stack(&mats_csc_views)
81}
82
83/// Specify a sparse matrix by constructing it from blocks of other matrices
84///
85/// # Examples
86/// ```
87/// use sprs::CsMat;
88/// let a = CsMat::<f64>::eye(3);
89/// let b = CsMat::<f64>::eye(4);
90/// let c = sprs::bmat(&[[Some(a.view()), None],
91///                      [None, Some(b.view())]]);
92/// assert_eq!(c.rows(), 7);
93/// ```
94pub fn bmat<'a, N, I, Iptr, OuterArray, InnerArray>(
95    mats: &OuterArray,
96) -> CsMatI<N, I, Iptr>
97where
98    N: 'a + Clone + Default,
99    I: 'a + SpIndex,
100    Iptr: 'a + SpIndex,
101    OuterArray: 'a + AsRef<[InnerArray]>,
102    InnerArray: 'a + AsRef<[Option<CsMatViewI<'a, N, I, Iptr>>]>,
103{
104    let mats = mats.as_ref();
105    let super_rows = mats.len();
106    assert_ne!(super_rows, 0, "Empty stacking list");
107    let super_cols = mats[0].as_ref().len();
108    assert_ne!(super_cols, 0, "Empty stacking list");
109
110    // check input has matrix shape
111    assert!(
112        mats.iter().all(|x| x.as_ref().len() == super_cols),
113        "Dimension mismatch"
114    );
115
116    assert!(
117        !mats.iter().any(|x| x.as_ref().iter().all(Option::is_none)),
118        "Empty bmat row"
119    );
120    assert!(
121        !(0..super_cols).any(|j| mats.iter().all(|x| x.as_ref()[j].is_none())),
122        "Empty bmat col"
123    );
124
125    // find out the shapes of the None elements
126    let rows_per_row: Vec<_> = mats
127        .iter()
128        .map(|row| {
129            row.as_ref().iter().fold(0, |nrows, mopt| {
130                mopt.as_ref().map_or(nrows, |m| cmp::max(nrows, m.rows()))
131            })
132        })
133        .collect();
134    let cols_per_col: Vec<_> = (0..super_cols)
135        .map(|j| {
136            mats.iter().fold(0, |ncols, row| {
137                row.as_ref()[j]
138                    .as_ref()
139                    .map_or(ncols, |m| cmp::max(ncols, m.cols()))
140            })
141        })
142        .collect();
143    let mut to_vstack = Vec::with_capacity(super_rows);
144    for (i, row) in mats.iter().enumerate() {
145        let with_zeros: Vec<_> = row
146            .as_ref()
147            .iter()
148            .enumerate()
149            .map(|(j, m)| {
150                let shape = (rows_per_row[i], cols_per_col[j]);
151                m.as_ref().map_or(CsMatI::zero(shape), CsMatBase::to_owned)
152            })
153            .collect();
154        let borrows: Vec<_> = with_zeros.iter().map(CsMatBase::view).collect();
155        let stacked = hstack(&borrows);
156        to_vstack.push(stacked);
157    }
158    let borrows: Vec<_> = to_vstack.iter().map(CsMatBase::view).collect();
159    vstack(&borrows)
160}
161
#[cfg(test)]
mod test {
    use crate::sparse::CsMat;
    use crate::test_data::{mat1, mat2, mat3, mat4};

    /// Expected result of stacking `mat1()` on top of `mat2()` (CSR, 10x5).
    fn mat1_vstack_mat2() -> CsMat<f64> {
        let indptr = vec![0, 2, 4, 5, 6, 7, 11, 13, 13, 15, 17];
        let indices = vec![2, 3, 3, 4, 2, 1, 3, 0, 1, 2, 4, 0, 3, 2, 3, 1, 2];
        let data = vec![
            3., 4., 2., 5., 5., 8., 7., 6., 7., 3., 3., 8., 9., 2., 4., 4., 4.,
        ];
        CsMat::new((10, 5), indptr, indices, data)
    }

    // The `expected` messages below pin the panic to the intended check, so a
    // test cannot pass because of an unrelated panic (e.g. an index error).

    #[test]
    #[should_panic(expected = "Empty stacking list")]
    fn same_storage_fast_stack_fail_empty_stacking_list() {
        let _: CsMat<f64> = super::same_storage_fast_stack(&[]);
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch")]
    fn same_storage_fast_stack_fail_dim_mismatch() {
        let a = mat1();
        let c = mat3();
        let _ = super::same_storage_fast_stack(&[a.view(), c.view()]);
    }

    #[test]
    #[should_panic(expected = "Storage mismatch")]
    fn same_storage_fast_stack_fail_storage() {
        let a = mat1();
        let d = mat4();
        let _ = super::same_storage_fast_stack(&[a.view(), d.view()]);
    }

    #[test]
    fn same_storage_fast_stack_ok() {
        let a = mat1();
        let b = mat2();
        let res = super::same_storage_fast_stack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2();
        assert_eq!(res, expected);
    }

    #[test]
    fn vstack_trivial() {
        let a = mat1();
        let b = mat2();
        let res = super::vstack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2();
        assert_eq!(res, expected);
    }

    #[test]
    fn hstack_trivial() {
        let a = mat1().transpose_into();
        let b = mat2().transpose_into();
        let res = super::hstack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2().transpose_into();
        assert_eq!(res, expected);
    }

    #[test]
    fn vstack_with_conversion() {
        let a = mat1().to_csc();
        let b = mat2();
        let res = super::vstack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2();
        assert_eq!(res, expected);
    }

    #[test]
    fn hstack_with_conversion() {
        // `a` is already CSC, `b` is CSR and must be converted.
        let a = mat1().transpose_into();
        let b = mat2().transpose_into().to_csr();
        let res = super::hstack(&[a.view(), b.view()]);
        let expected = mat1_vstack_mat2().transpose_into();
        assert_eq!(res, expected);
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch")]
    fn bmat_fail_shapes() {
        let _: CsMat<f64> = super::bmat(&vec![vec![None, None], vec![None]]);
    }

    #[test]
    #[should_panic(expected = "Empty stacking list")]
    fn bmat_fail_empty_stacking_list() {
        let _: CsMat<f64> = super::bmat(&[[]]);
    }

    #[test]
    #[should_panic(expected = "Empty bmat row")]
    fn bmat_fail_empty_bmat_row() {
        let a = mat1();
        let c = mat3();
        let _: CsMat<f64> =
            super::bmat(&[[None, None], [Some(a.view()), Some(c.view())]]);
    }

    #[test]
    #[should_panic(expected = "Empty bmat col")]
    fn bmat_fail_empty_bmat_col() {
        let a = mat1();
        let c = mat3();
        let _: CsMat<f64> =
            super::bmat(&[[Some(c.view()), None], [Some(a.view()), None]]);
    }

    #[test]
    fn bmat_simple() {
        let a = CsMat::<f64>::eye(5);
        let b = CsMat::<f64>::eye(4);
        let c = super::bmat(&[[Some(a.view()), None], [None, Some(b.view())]]);
        let expected = CsMat::new(
            (9, 9),
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8],
            vec![1.; 9],
        );
        assert_eq!(c, expected);
    }

    #[test]
    fn bmat_complex() {
        let a = mat1();
        let b = mat2();
        let c = super::bmat(&[
            [Some(a.view()), Some(b.view())],
            [Some(b.view()), None],
        ]);
        let expected = CsMat::new(
            (10, 10),
            vec![0, 6, 10, 11, 14, 17, 21, 23, 23, 25, 27],
            vec![
                2, 3, 5, 6, 7, 9, 3, 4, 5, 8, 2, 1, 7, 8, 3, 6, 7, 0, 1, 2, 4,
                0, 3, 2, 3, 1, 2,
            ],
            vec![
                3., 4., 6., 7., 3., 3., 2., 5., 8., 9., 5., 8., 2., 4., 7., 4.,
                4., 6., 7., 3., 3., 8., 9., 2., 4., 4., 4.,
            ],
        );
        assert_eq!(c, expected);

        let d = mat3();
        let e = mat4();
        let f = super::bmat(&[
            [Some(d.view()), Some(a.view())],
            [None, Some(e.view())],
        ]);
        let expected = CsMat::new(
            (10, 9),
            vec![0, 4, 8, 10, 12, 14, 16, 18, 21, 23, 24],
            vec![
                2, 3, 6, 7, 2, 3, 7, 8, 2, 6, 1, 5, 3, 7, 4, 5, 4, 8, 4, 7, 8,
                5, 7, 4,
            ],
            vec![
                3., 4., 3., 4., 2., 5., 2., 5., 5., 5., 8., 8., 7., 7., 6., 8.,
                7., 4., 3., 2., 4., 9., 4., 3.,
            ],
        );
        assert_eq!(f, expected);
    }
}