mop-structs 0.0.10

Low-level structures for MOP
Documentation
use crate::{dim::Dim, prelude::DynDenseStoMut};
use mop_common_deps::rand::{
  distributions::{Distribution, Uniform},
  Rng,
};

/// Randomly populates the parts of a sparse matrix stored in compressed sparse
/// row (CSR) layout: row pointers (`ptrs`), column indices (`indcs`) and values
/// (`data`).
///
/// Each part is filled by its own `fill_*` method, which pushes into the
/// corresponding storage, consuming randomness from `rng`.
#[derive(Debug)]
pub struct CsrMatrixRnd<'a, DS, R, US> {
  /// Value storage; `fill_data` pushes one value per stored (row, column) position.
  pub data: DS,
  /// Matrix dimensions; the fill methods read `rows()` and `cols()` from it.
  pub dim: &'a Dim<[usize; 2]>,
  /// Column-index storage; `fill_indcs` pushes sorted, per-row-distinct indices.
  pub indcs: US,
  /// Target number of stored entries that `fill_ptrs` distributes across rows.
  pub nnz: usize,
  /// Row-pointer storage; `fill_ptrs` pushes `dim.rows() + 1` cumulative offsets.
  pub ptrs: US,
  /// Random number generator used by all fill methods.
  pub rng: &'a mut R,
}

impl<'a, DS, R, US> CsrMatrixRnd<'a, DS, R, US>
where
  DS: DynDenseStoMut,
  R: Rng,
  US: DynDenseStoMut<Item = usize>,
{
  /// Pushes one value into `data` for every stored position of the matrix.
  ///
  /// Rows are visited in order using consecutive pairs of `ptrs`. Within a row,
  /// column indices are read from `indcs[ptrs[row]..ptrs[row + 1]]` in stored
  /// order. For each position, `cb` receives the random number generator and
  /// `[row_idx, col_idx]`, and its return value is pushed onto `data`.
  ///
  /// `ptrs` and `indcs` are only read here, so they are expected to be filled
  /// already (e.g. by `fill_ptrs` and `fill_indcs`). If `ptrs` holds fewer than
  /// two entries, nothing is pushed.
  pub fn fill_data<F>(&mut self, mut cb: F)
  where
    F: FnMut(&mut R, [usize; 2]) -> DS::Item,
  {
    // Split `self` into disjoint borrows so `data` and `rng` can be mutated
    // while `indcs` and `ptrs` are only read.
    let (data, indcs, ptrs, rng) = (
      &mut self.data,
      self.indcs.as_slice(),
      self.ptrs.as_slice(),
      &mut self.rng,
    );
    ptrs.windows(2).enumerate().for_each(|(row_idx, range)| {
      let cols = indcs[range[0]..range[1]].iter();
      cols.for_each(|&col_idx| data.push(cb(rng, [row_idx, col_idx])))
    });
  }

  /// Pushes random column indices into `indcs`, row by row, following `ptrs`.
  ///
  /// Each row needs `ptrs[row + 1] - ptrs[row]` column indices. They are drawn
  /// uniformly from `0..dim.cols()`, and duplicates within the row are
  /// rejected and redrawn. The row's indices are then sorted in ascending order.
  ///
  /// NOTE(review): assumes `indcs` is empty on entry. The indices already drawn
  /// for a row are located as `indcs[ptrs[row]..]`, which is only correct when
  /// `indcs.len() == ptrs[row]` at the start of each row. This is not checked.
  ///
  /// # Panics
  ///
  /// Panics if a row requires more distinct column indices than `dim.cols()`.
  /// Rejection sampling could never complete such a row.
  pub fn fill_indcs(&mut self) {
    let (dim, indcs, ptrs, rng) = (
      self.dim,
      &mut self.indcs,
      self.ptrs.as_slice(),
      &mut self.rng,
    );
    let cols = dim.cols();
    ptrs.windows(2).for_each(|range| {
      let row_start = range[0];
      let row_nnz = range[1] - range[0];
      // Without this guard the rejection loop below would spin forever.
      assert!(
        row_nnz <= cols,
        "a row requires {} distinct column indices but the matrix has only {} columns",
        row_nnz,
        cols
      );
      let mut counter = 0;
      while counter < row_nnz {
        let rnd = rng.gen_range(0, cols);
        // Keep the column only if it is not already used in this row.
        if !indcs.as_slice()[row_start..].contains(&rnd) {
          indcs.push(rnd);
          counter += 1;
        }
      }
      indcs.as_mut_slice()[row_start..].sort_unstable();
    });
  }

  /// Pushes `dim.rows() + 1` row pointers into `ptrs` so that the rows hold
  /// exactly `nnz` entries in total.
  ///
  /// Per-row entry counts are sampled in repeated passes over the rows:
  /// - In each pass, every row's count is drawn uniformly from its current
  ///   count (initially 0) up to and including `dim.cols()`.
  /// - As soon as the running total passes `nnz`, the row that crossed the
  ///   target is trimmed so the total equals `nnz`, and every later row is set
  ///   to 0.
  /// - A pass that ends below `nnz` is followed by another pass.
  ///
  /// Finally the counts are turned into cumulative offsets in place, so
  /// `ptrs[0] == 0` and `ptrs[dim.rows()] == nnz`.
  ///
  /// NOTE(review): assumes `ptrs` is empty on entry. The pointers are written
  /// through indices starting at 0 and through `last()`. This is not checked.
  ///
  /// # Panics
  ///
  /// Panics if `nnz > dim.rows() * dim.cols()`. No row layout can hold that many
  /// entries, and the sampling loop would otherwise never terminate.
  pub fn fill_ptrs(&mut self) {
    let (rows, cols, nnz) = (self.dim.rows(), self.dim.cols(), self.nnz);
    // An overflowing product means the capacity exceeds `usize::MAX`, so any
    // `nnz` fits.
    let fits = match rows.checked_mul(cols) {
      Some(capacity) => nnz <= capacity,
      None => true,
    };
    assert!(
      fits,
      "nnz ({}) exceeds the capacity of a {}x{} matrix",
      nnz,
      rows,
      cols
    );
    (0..=rows).for_each(|_| self.ptrs.push(0));
    let ptrs_slice = self.ptrs.as_mut_slice();
    // `ptrs_slice[row + 1]` holds the entry count of `row` until the final
    // prefix sum below turns the counts into offsets.
    loop {
      let mut nnz_counter = 0;
      for prev_idx in 0..rows {
        let curr_idx = prev_idx + 1;
        if nnz_counter < nnz {
          // Lower bound is the count from the previous pass, so counts never
          // shrink between passes and the total keeps making progress.
          let uniform = Uniform::from(ptrs_slice[curr_idx]..=cols);
          let row_nnz = uniform.sample(self.rng);
          nnz_counter += row_nnz;
          ptrs_slice[curr_idx] = row_nnz;
        } else if nnz_counter > nnz {
          // The previous row overshot the target: trim it so the total is `nnz`.
          ptrs_slice[prev_idx] = nnz - (nnz_counter - ptrs_slice[prev_idx]);
          ptrs_slice[curr_idx] = 0;
          nnz_counter = nnz;
        } else {
          ptrs_slice[curr_idx] = 0;
        }
      }
      if nnz_counter > nnz {
        // Only the last row can still be overshooting; trim it.
        let last = *ptrs_slice.last().unwrap();
        *ptrs_slice.last_mut().unwrap() = nnz - (nnz_counter - last);
        break;
      } else if nnz_counter == nnz {
        break;
      }
    }
    // Convert per-row counts into cumulative row offsets.
    for idx in 1..ptrs_slice.len() {
      ptrs_slice[idx] += ptrs_slice[idx - 1];
    }
  }
}