// mop_structs/matrix/csr_matrix/csr_matrix_rnd.rs

1use crate::{dim::Dim, prelude::DynDenseStoMut};
2use mop_common_deps::rand::{
3  distributions::{Distribution, Uniform},
4  Rng,
5};
6
7#[derive(Debug)]
8pub struct CsrMatrixRnd<'a, DS, R, US> {
9  pub data: DS,
10  pub dim: &'a Dim<[usize; 2]>,
11  pub indcs: US,
12  pub nnz: usize,
13  pub ptrs: US,
14  pub rng: &'a mut R,
15}
16
17impl<'a, DS, R, US> CsrMatrixRnd<'a, DS, R, US>
18where
19  DS: DynDenseStoMut,
20  R: Rng,
21  US: DynDenseStoMut<Item = usize>,
22{
23  pub fn fill_data<F>(&mut self, mut cb: F)
24  where
25    F: FnMut(&mut R, [usize; 2]) -> DS::Item,
26  {
27    let (data, indcs, ptrs, rng) = (
28      &mut self.data,
29      self.indcs.as_slice(),
30      self.ptrs.as_slice(),
31      &mut self.rng,
32    );
33    ptrs.windows(2).enumerate().for_each(|(row_idx, range)| {
34      let cols = indcs[range[0]..range[1]].iter();
35      cols.for_each(|&col_idx| data.push(cb(rng, [row_idx, col_idx])))
36    });
37  }
38
39  pub fn fill_indcs(&mut self) {
40    let (dim, indcs, ptrs, rng) = (
41      self.dim,
42      &mut self.indcs,
43      self.ptrs.as_slice(),
44      &mut self.rng,
45    );
46    ptrs.windows(2).for_each(|range| {
47      let mut counter = 0;
48      let row_nnz = range[1] - range[0];
49      while counter < row_nnz {
50        let rnd = rng.gen_range(0, dim.cols());
51        if indcs.as_slice()[range[0]..].contains(&rnd) == false {
52          indcs.push(rnd);
53          counter += 1;
54        }
55      }
56      indcs.as_mut_slice()[range[0]..].sort_unstable();
57    });
58  }
59
60  pub fn fill_ptrs(&mut self) {
61    (0..=self.dim.rows()).for_each(|_| self.ptrs.push(0));
62    let ptrs_slice = self.ptrs.as_mut_slice();
63    loop {
64      let mut nnz_counter = 0;
65      for prev_idx in 0..self.dim.rows() {
66        let curr_idx = prev_idx + 1;
67        if nnz_counter < self.nnz {
68          let uniform = Uniform::from(ptrs_slice[curr_idx]..=self.dim.cols());
69          let row_nnz = uniform.sample(self.rng);
70          nnz_counter += row_nnz;
71          ptrs_slice[curr_idx] = row_nnz;
72        } else if nnz_counter > self.nnz {
73          ptrs_slice[prev_idx] = self.nnz - (nnz_counter - ptrs_slice[prev_idx]);
74          ptrs_slice[curr_idx] = 0;
75          nnz_counter = self.nnz;
76        } else {
77          ptrs_slice[curr_idx] = 0;
78        }
79      }
80      if nnz_counter > self.nnz {
81        let last = *ptrs_slice.last().unwrap();
82        *ptrs_slice.last_mut().unwrap() = self.nnz - (nnz_counter - last);
83        break;
84      } else if nnz_counter == self.nnz {
85        break;
86      }
87    }
88    (0..ptrs_slice.len() - 1).for_each(|x| ptrs_slice[x + 1] += ptrs_slice[x])
89  }
90}