mop_structs/matrix/csr_matrix/
csr_matrix_rnd.rs1use crate::{dim::Dim, prelude::DynDenseStoMut};
2use mop_common_deps::rand::{
3 distributions::{Distribution, Uniform},
4 Rng,
5};
6
7#[derive(Debug)]
8pub struct CsrMatrixRnd<'a, DS, R, US> {
9 pub data: DS,
10 pub dim: &'a Dim<[usize; 2]>,
11 pub indcs: US,
12 pub nnz: usize,
13 pub ptrs: US,
14 pub rng: &'a mut R,
15}
16
17impl<'a, DS, R, US> CsrMatrixRnd<'a, DS, R, US>
18where
19 DS: DynDenseStoMut,
20 R: Rng,
21 US: DynDenseStoMut<Item = usize>,
22{
23 pub fn fill_data<F>(&mut self, mut cb: F)
24 where
25 F: FnMut(&mut R, [usize; 2]) -> DS::Item,
26 {
27 let (data, indcs, ptrs, rng) = (
28 &mut self.data,
29 self.indcs.as_slice(),
30 self.ptrs.as_slice(),
31 &mut self.rng,
32 );
33 ptrs.windows(2).enumerate().for_each(|(row_idx, range)| {
34 let cols = indcs[range[0]..range[1]].iter();
35 cols.for_each(|&col_idx| data.push(cb(rng, [row_idx, col_idx])))
36 });
37 }
38
39 pub fn fill_indcs(&mut self) {
40 let (dim, indcs, ptrs, rng) = (
41 self.dim,
42 &mut self.indcs,
43 self.ptrs.as_slice(),
44 &mut self.rng,
45 );
46 ptrs.windows(2).for_each(|range| {
47 let mut counter = 0;
48 let row_nnz = range[1] - range[0];
49 while counter < row_nnz {
50 let rnd = rng.gen_range(0, dim.cols());
51 if indcs.as_slice()[range[0]..].contains(&rnd) == false {
52 indcs.push(rnd);
53 counter += 1;
54 }
55 }
56 indcs.as_mut_slice()[range[0]..].sort_unstable();
57 });
58 }
59
60 pub fn fill_ptrs(&mut self) {
61 (0..=self.dim.rows()).for_each(|_| self.ptrs.push(0));
62 let ptrs_slice = self.ptrs.as_mut_slice();
63 loop {
64 let mut nnz_counter = 0;
65 for prev_idx in 0..self.dim.rows() {
66 let curr_idx = prev_idx + 1;
67 if nnz_counter < self.nnz {
68 let uniform = Uniform::from(ptrs_slice[curr_idx]..=self.dim.cols());
69 let row_nnz = uniform.sample(self.rng);
70 nnz_counter += row_nnz;
71 ptrs_slice[curr_idx] = row_nnz;
72 } else if nnz_counter > self.nnz {
73 ptrs_slice[prev_idx] = self.nnz - (nnz_counter - ptrs_slice[prev_idx]);
74 ptrs_slice[curr_idx] = 0;
75 nnz_counter = self.nnz;
76 } else {
77 ptrs_slice[curr_idx] = 0;
78 }
79 }
80 if nnz_counter > self.nnz {
81 let last = *ptrs_slice.last().unwrap();
82 *ptrs_slice.last_mut().unwrap() = self.nnz - (nnz_counter - last);
83 break;
84 } else if nnz_counter == self.nnz {
85 break;
86 }
87 }
88 (0..ptrs_slice.len() - 1).for_each(|x| ptrs_slice[x + 1] += ptrs_slice[x])
89 }
90}