Skip to main content

sprs_rand/
lib.rs

1//! Random sparse matrix generation
2
3use crate::rand::distr::Distribution;
4use crate::rand::Rng;
5use crate::rand::RngExt;
6use crate::rand::SeedableRng;
7use sprs::indexing::SpIndex;
8use sprs::{CsMat, CsMatI};
9
10/// Re-export [`rand`](https://docs.rs/rand/0.7.3/rand/)
11/// for version compatibility
12pub mod rand {
13    pub use rand::*;
14}
15
16/// Re-export [`rand_distr`](https://docs.rs/rand_distr/0.2.2/rand_distr)
17/// for version compatibility
18pub mod rand_distr {
19    pub use rand_distr::*;
20}
21
22/// Generate a random sparse matrix matching the given density and sampling
23/// the values of its non-zero elements from the provided distribution.
24pub fn rand_csr<R, N, D, I>(
25    rng: &mut R,
26    dist: D,
27    shape: (usize, usize),
28    density: f64,
29) -> CsMatI<N, I>
30where
31    R: Rng + ?Sized,
32    D: Distribution<N>,
33    N: Copy,
34    I: SpIndex,
35{
36    assert!((0.0..=1.0).contains(&density));
37    let exp_nnz =
38        (density * (shape.0 as f64) * (shape.1 as f64)).ceil() as usize;
39    let mut indptr = Vec::with_capacity(shape.0 + 1);
40    let mut indices = Vec::with_capacity(exp_nnz);
41    let mut data = Vec::with_capacity(exp_nnz);
42    // sample row indices
43    for _ in 0..exp_nnz {
44        indices.push(I::from_usize(rng.random_range(0..shape.0)));
45        // Note: there won't be any correspondence between the data
46        // sampled here and the row sampled before, but this does not matter
47        // as we are sampling.
48        data.push(dist.sample(rng));
49    }
50    indices.sort_unstable();
51    indptr.push(I::from_usize(0));
52    let mut count = 0;
53    for &row in &indices {
54        while indptr.len() != row.index() + 1 {
55            indptr.push(I::from_usize(count));
56        }
57        count += 1;
58    }
59    while indptr.len() != shape.0 + 1 {
60        indptr.push(I::from_usize(count));
61    }
62    assert_eq!(indptr.last().unwrap().index(), exp_nnz);
63    indices.clear();
64    for row in 0..shape.0 {
65        let start = indptr[row].index();
66        let end = indptr[row + 1].index();
67        for _ in start..end {
68            loop {
69                let col = I::from_usize(rng.random_range(0..shape.1));
70                let loc = indices[start..].binary_search(&col);
71                if let Err(loc) = loc {
72                    indices.insert(start + loc, col);
73                    break;
74                }
75            }
76        }
77        indices[start..end].sort_unstable();
78    }
79
80    CsMatI::new(shape, indptr, indices, data)
81}
82
83/// Convenient wrapper for the common case of sampling a matrix with standard
84/// normal distribution of the nnz values, using a lightweight rng.
85pub fn rand_csr_std(shape: (usize, usize), density: f64) -> CsMat<f64> {
86    let mut rng = rand_pcg::Pcg64Mcg::from_rng(&mut rand::rng());
87    rand_csr(&mut rng, crate::rand_distr::StandardNormal, shape, density)
88}
89
#[cfg(test)]
mod tests {
    use rand::distr::StandardUniform;
    use rand::SeedableRng;
    use sprs::CsMat;

    /// A 0x0 matrix must come out empty whatever density is requested.
    #[test]
    fn empty_random_mat() {
        let mut rng = rand::rng();
        let shape = (0, 0);
        let generated: CsMat<f64> =
            super::rand_csr(&mut rng, StandardUniform, shape, 0.3);
        assert_eq!(generated.nnz(), 0);
    }

    /// With a fixed seed, the realised density stays close to the requested
    /// one, both for a regular shape and for a single very wide row.
    #[test]
    fn random_csr() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(1234);

        let regular: CsMat<f64> =
            super::rand_csr(&mut rng, StandardUniform, (100, 70), 0.3);
        let regular_density = regular.density();
        assert!(regular_density > 0.25);
        assert!(regular_density < 0.35);

        let single_row: CsMat<f64> =
            super::rand_csr(&mut rng, StandardUniform, (1, 10000), 0.3);
        let single_row_density = single_row.density();
        assert!(single_row_density > 0.28);
        assert!(single_row_density < 0.32);
    }

    /// The convenience wrapper returns the requested shape and values whose
    /// mean is close to zero.
    #[test]
    fn random_csr_std() {
        let generated = super::rand_csr_std((100, 1000), 0.2);
        assert_eq!(generated.shape(), (100, 1000));
        // The seed is not under our control here, so the density is not
        // checked; the mean of the stored values should be safe to check.
        let values = generated.data();
        let mean_magnitude =
            values.iter().sum::<f64>().abs() / (values.len() as f64);
        assert!(mean_magnitude < 0.05);
    }
}