use crate::{dim::Dim, prelude::DynDenseStoMut};
use mop_common_deps::rand::{
distributions::{Distribution, Uniform},
Rng,
};
/// Random generator for the component arrays of a CSR (compressed sparse row) matrix.
///
/// The `fill_*` methods append to the storages held here. `fill_data` reads
/// `indcs` and `ptrs`, and `fill_indcs` reads `ptrs`, so the storages are
/// expected to be populated in the order `fill_ptrs`, `fill_indcs`, `fill_data`.
#[derive(Debug)]
pub struct CsrMatrixRnd<'a, DS, R, US> {
/// Storage that `fill_data` appends the generated values to.
pub data: DS,
/// Matrix dimensions; `rows()` and `cols()` bound the generated pointers and indices.
pub dim: &'a Dim<[usize; 2]>,
/// Storage for column indices, appended to by `fill_indcs`.
pub indcs: US,
/// Total number of non-zero entries that `fill_ptrs` distributes across the rows.
pub nnz: usize,
/// Storage for row pointers, appended to by `fill_ptrs`; each consecutive pair
/// delimits one row's range inside `indcs`.
pub ptrs: US,
/// Random number generator shared by all `fill_*` methods.
pub rng: &'a mut R,
}
impl<'a, DS, R, US> CsrMatrixRnd<'a, DS, R, US>
where
    DS: DynDenseStoMut,
    R: Rng,
    US: DynDenseStoMut<Item = usize>,
{
    /// Appends one value to `data` for every stored column index.
    ///
    /// Rows are walked through consecutive pairs of `ptrs`; for each column
    /// index found in the corresponding range of `indcs`, `cb` is called with
    /// the random generator and the `[row, column]` position, and its result is
    /// pushed into `data`.
    ///
    /// `ptrs` and `indcs` must already be populated (see `fill_ptrs` and
    /// `fill_indcs`).
    ///
    /// # Panics
    ///
    /// Panics if a range described by `ptrs` falls outside of `indcs`.
    pub fn fill_data<F>(&mut self, mut cb: F)
    where
        F: FnMut(&mut R, [usize; 2]) -> DS::Item,
    {
        let indcs = self.indcs.as_slice();
        let ptrs = self.ptrs.as_slice();
        for (row_idx, range) in ptrs.windows(2).enumerate() {
            for &col_idx in &indcs[range[0]..range[1]] {
                let value = cb(&mut *self.rng, [row_idx, col_idx]);
                self.data.push(value);
            }
        }
    }

    /// Appends random, distinct and sorted column indices for every row.
    ///
    /// The number of indices of each row is the difference between two
    /// consecutive elements of `ptrs`, therefore `ptrs` must already be
    /// populated (see `fill_ptrs`). Indices are drawn from `0..cols` and
    /// re-drawn whenever the value is already present in the current row.
    ///
    /// The start of each row inside `indcs` is taken from `ptrs`, so `indcs`
    /// is expected to be empty when this method is called.
    ///
    /// # Panics
    ///
    /// Panics if a row requests more indices than the matrix has columns,
    /// because that many distinct indices do not exist and the rejection
    /// sampling below would never finish.
    pub fn fill_indcs(&mut self) {
        let cols = self.dim.cols();
        let ptrs = self.ptrs.as_slice();
        for range in ptrs.windows(2) {
            let row_start = range[0];
            let row_nnz = range[1] - row_start;
            assert!(
                row_nnz <= cols,
                "a row requests {} distinct column indices but the matrix has only {} columns",
                row_nnz,
                cols
            );
            let mut counter = 0;
            while counter < row_nnz {
                let rnd = self.rng.gen_range(0, cols);
                // Only the indices of the current row are searched for duplicates.
                if !self.indcs.as_slice()[row_start..].contains(&rnd) {
                    self.indcs.push(rnd);
                    counter += 1;
                }
            }
            // Column indices are stored in ascending order within each row.
            self.indcs.as_mut_slice()[row_start..].sort_unstable();
        }
    }

    /// Appends `rows + 1` row pointers that randomly distribute `nnz` entries
    /// across the rows.
    ///
    /// Each row first receives a random count no greater than `cols`; passes
    /// are repeated, never lowering an already sampled count, until the counts
    /// reach `nnz`. Any excess is removed from the row that caused it and the
    /// counts are finally accumulated into offsets.
    ///
    /// The whole content of `ptrs` is treated as the pointer array, so `ptrs`
    /// is expected to be empty when this method is called.
    ///
    /// # Panics
    ///
    /// Panics if `nnz` is greater than `rows * cols`, since the per-row counts
    /// are capped at `cols` and the requested total could never be reached.
    pub fn fill_ptrs(&mut self) {
        let (rows, cols, nnz) = (self.dim.rows(), self.dim.cols(), self.nnz);
        assert!(
            nnz <= rows.saturating_mul(cols),
            "cannot place {} non-zero elements in a {}x{} matrix",
            nnz,
            rows,
            cols
        );
        for _ in 0..=rows {
            self.ptrs.push(0);
        }
        let ptrs_slice = self.ptrs.as_mut_slice();
        // Until the final accumulation, `ptrs_slice[i + 1]` holds the number of
        // entries of row `i` instead of an offset.
        loop {
            let mut nnz_counter = 0;
            for prev_idx in 0..rows {
                let curr_idx = prev_idx + 1;
                if nnz_counter < nnz {
                    // The previous count is the lower bound, so counts only grow
                    // between passes.
                    let uniform = Uniform::from(ptrs_slice[curr_idx]..=cols);
                    let row_nnz = uniform.sample(&mut *self.rng);
                    nnz_counter += row_nnz;
                    ptrs_slice[curr_idx] = row_nnz;
                } else if nnz_counter > nnz {
                    // The previous row overshot the target: trim it to the exact
                    // remainder and leave the current row empty.
                    ptrs_slice[prev_idx] = nnz - (nnz_counter - ptrs_slice[prev_idx]);
                    ptrs_slice[curr_idx] = 0;
                    nnz_counter = nnz;
                } else {
                    // Target already met, remaining rows stay empty.
                    ptrs_slice[curr_idx] = 0;
                }
            }
            if nnz_counter > nnz {
                // Only the last row can still overshoot at this point.
                let last = *ptrs_slice.last().unwrap();
                *ptrs_slice.last_mut().unwrap() = nnz - (nnz_counter - last);
                break;
            } else if nnz_counter == nnz {
                break;
            }
        }
        // Prefix sum: turns the per-row counts into cumulative offsets.
        for idx in 0..ptrs_slice.len() - 1 {
            ptrs_slice[idx + 1] += ptrs_slice[idx];
        }
    }
}