use crate::rand::{Rng, *};
use crate::sparse::{Node, SparseMatrix};
use crate::util::*;
use rand::seq::IteratorRandom;
use rayon::prelude::*;
use std::fmt;
use std::fmt::{Display, Formatter};
/// Reasons why the MacKay–Neal construction can fail or has to retry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Fewer than `wc` rows still have a weight below `wr`, so the current
    /// column cannot be filled. Handled internally by backtracking.
    NoAvailRows,
    /// The newly inserted column closed a cycle within the `min_girth - 1`
    /// bound. Handled internally by retrying the column.
    GirthTooSmall,
    /// The backtracking budget (`backtrack_trials`) ran out.
    NoMoreBacktrack,
    /// The girth-retry budget (`girth_trials`) ran out.
    NoMoreTrials,
}

impl Display for Error {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::NoAvailRows => "no rows available",
            Self::GirthTooSmall => "girth is too small",
            Self::NoMoreBacktrack => "exceeded backtrack trials",
            Self::NoMoreTrials => "exceeded girth trials",
        };
        f.write_str(message)
    }
}

/// Lets [`Error`] be used with `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for Error {}
/// Result type used by this module, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Parameters for building a random sparse matrix with the MacKay–Neal
/// construction (columns are filled one at a time, left to right).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Number of rows of the generated matrix.
    pub nrows: usize,
    /// Number of columns of the generated matrix.
    pub ncols: usize,
    /// Maximum row weight: a row whose weight has reached `wr` is no longer
    /// eligible to receive new entries.
    pub wr: usize,
    /// Column weight: the number of rows selected for every column.
    pub wc: usize,
    /// How many of the most recently filled columns are cleared when a column
    /// cannot be filled because too few rows are available.
    pub backtrack_cols: usize,
    /// How many times backtracking may happen before the run fails with
    /// [`Error::NoMoreBacktrack`].
    pub backtrack_trials: usize,
    /// Optional minimum girth. When `Some(g)`, a new column is rejected if the
    /// matrix's `girth_at_node_with_max` query finds a cycle through it with
    /// bound `g - 1` (presumably "a cycle shorter than `g`" — confirm against
    /// `SparseMatrix`). `None` disables the check.
    pub min_girth: Option<usize>,
    /// How many girth rejections are tolerated before the run fails with
    /// [`Error::NoMoreTrials`].
    pub girth_trials: usize,
    /// Strategy used to choose the rows of each column.
    pub fill_policy: FillPolicy,
}
impl Config {
    /// Builds a matrix with the MacKay–Neal construction, seeding the random
    /// number generator with `seed`. The same configuration and seed always
    /// produce the same result.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NoMoreBacktrack`] when the backtracking budget
    /// (`backtrack_trials`) is exhausted, or [`Error::NoMoreTrials`] when the
    /// girth-retry budget (`girth_trials`) is exhausted.
    pub fn run(&self, seed: u64) -> Result<SparseMatrix> {
        MacKayNeal::new(self, seed).run()
    }

    /// Tries `max_tries` consecutive seeds starting at `start_seed`, in
    /// parallel, and returns the first successful `(seed, matrix)` pair found
    /// by any worker, or `None` if every seed fails (or `max_tries == 0`).
    ///
    /// Which successful seed is returned is unspecified when several succeed.
    /// The matrix can always be reproduced with [`Config::run`] and the
    /// returned seed.
    ///
    /// Seeds wrap around past `u64::MAX`, so exactly `max_tries` distinct
    /// seeds are tried.
    pub fn search(&self, start_seed: u64, max_tries: u64) -> Option<(u64, SparseMatrix)> {
        // Iterate over offsets instead of `start_seed..start_seed + max_tries`:
        // that upper bound overflows `u64` for seeds near `u64::MAX`.
        (0..max_tries)
            .into_par_iter()
            .map(|offset| start_seed.wrapping_add(offset))
            .find_map_any(|seed| self.run(seed).ok().map(|h| (seed, h)))
    }
}
/// Strategy used to pick the rows of each new column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FillPolicy {
    /// Pick `wc` rows at random (via `IteratorRandom::choose_multiple`) among
    /// the rows whose weight is still below `wr`.
    Random,
    /// Pick `wc` rows among those whose weight is still below `wr`, ordered by
    /// current row weight through `sort_by_random_sel` (presumably lowest
    /// weight first with random tie-breaking — confirm against `crate::util`).
    Uniform,
}
/// State of a single construction run, created from a [`Config`] and a seed.
struct MacKayNeal {
    /// Maximum row weight; rows at this weight are no longer selectable.
    wr: usize,
    /// Number of rows set in each column.
    wc: usize,
    /// Matrix under construction; columns `0..current_col` are filled.
    h: SparseMatrix,
    /// Seeded random number generator driving row selection.
    rng: Rng,
    /// Number of trailing columns cleared on each backtrack.
    backtrack_cols: usize,
    /// Remaining backtracks; decremented by `backtrack`.
    backtrack_trials: usize,
    /// Optional minimum girth; `None` disables the girth check.
    min_girth: Option<usize>,
    /// Remaining girth retries; decremented by `retry_girth`.
    girth_trials: usize,
    /// Row-selection strategy.
    fill_policy: FillPolicy,
    /// Index of the next column to fill.
    current_col: usize,
}
impl MacKayNeal {
fn new(conf: &Config, seed: u64) -> MacKayNeal {
MacKayNeal {
wr: conf.wr,
wc: conf.wc,
h: SparseMatrix::new(conf.nrows, conf.ncols),
rng: Rng::seed_from_u64(seed),
backtrack_cols: conf.backtrack_cols,
backtrack_trials: conf.backtrack_trials,
min_girth: conf.min_girth,
girth_trials: conf.girth_trials,
fill_policy: conf.fill_policy,
current_col: 0,
}
}
fn try_insert_column(&mut self) -> Result<()> {
let rows = self.select_rows()?;
self.h.insert_col(self.current_col, rows.into_iter());
if let Some(g) = self.min_girth
&& self
.h
.girth_at_node_with_max(Node::Col(self.current_col), g - 1)
.is_some()
{
self.h.clear_col(self.current_col);
return Err(Error::GirthTooSmall);
}
Ok(())
}
fn select_rows(&mut self) -> Result<Vec<usize>> {
match self.fill_policy {
FillPolicy::Random => {
let h = &self.h;
let wr = self.wr;
let avail_rows = (0..self.h.num_rows()).filter(|&r| h.row_weight(r) < wr);
let select_rows = avail_rows.choose_multiple(&mut self.rng, self.wc);
if select_rows.len() < self.wc {
return Err(Error::NoAvailRows);
}
Ok(select_rows)
}
FillPolicy::Uniform => {
let avail_rows: Vec<(usize, usize)> = (0..self.h.num_rows())
.filter_map(|r| {
let w = self.h.row_weight(r);
if w < self.wr { Some((r, w)) } else { None }
})
.collect();
avail_rows
.sort_by_random_sel(self.wc, |(_, x), (_, y)| x.cmp(y), &mut self.rng)
.map(|a| a.into_iter().map(|(x, _)| x).collect())
.ok_or(Error::NoAvailRows)
}
}
}
fn backtrack(&mut self) -> Result<()> {
if self.backtrack_trials == 0 {
return Err(Error::NoMoreBacktrack);
}
self.backtrack_trials -= 1;
let b = std::cmp::min(self.current_col, self.backtrack_cols);
let a = self.current_col - b;
for col in a..self.current_col {
self.h.clear_col(col);
}
self.current_col = a;
Ok(())
}
fn retry_girth(&mut self) -> Result<()> {
if self.girth_trials == 0 {
return Err(Error::NoMoreTrials);
}
self.girth_trials -= 1;
Ok(())
}
fn run(mut self) -> Result<SparseMatrix> {
while self.current_col < self.h.num_cols() {
match self.try_insert_column() {
Ok(_) => self.current_col += 1,
Err(Error::NoAvailRows) => self.backtrack()?,
Err(Error::GirthTooSmall) => self.retry_girth()?,
Err(e) => return Err(e),
};
}
Ok(self.h)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression test: with seed 187, a 4-row, 8-column matrix with row
    /// weight 4 and column weight 2 (no backtracking, no girth constraint)
    /// must reproduce a known alist output.
    #[test]
    fn small_matrix() {
        let conf = Config {
            nrows: 4,
            ncols: 8,
            wr: 4,
            wc: 2,
            backtrack_cols: 0,
            backtrack_trials: 0,
            min_girth: None,
            girth_trials: 0,
            fill_policy: FillPolicy::Random,
        };
        // The expected text below depends on the exact random stream for
        // seed 187. Changes to the RNG or to the row-selection logic require
        // regenerating it.
        let h = conf.run(187).unwrap();
        let alist = "8 4
2 4
2 2 2 2 2 2 2 2
4 4 4 4
1 3
2 4
2 3
1 4
1 4
1 4
2 3
2 3
1 4 5 6
2 3 7 8
1 3 7 8
2 4 5 6
";
        assert_eq!(h.alist(), alist);
    }
}