use core::cmp;
use core::fmt;
use core::iter::Map;
use core::mem;
use core::ops::Range;
use core::ops::{Index, IndexMut};
use core::slice;
use crate::local_prelude::*;
use crate::util::{div_rem, round_up_to_next};
/// A borrowed, read-only view of a bit matrix stored row after row in a slice of `Block`s.
///
/// Every row is `row_bits` bits wide; the accessors in this module give each row
/// `round_up_to_next(row_bits, BITS) / BITS` consecutive blocks of `slice`.
pub struct BitSubMatrix<'a> {
    /// Width of every row, in bits.
    row_bits: usize,
    /// Backing storage; rows are laid out back to back.
    slice: &'a [Block],
}
/// A mutably borrowed view of a bit matrix stored row after row in a slice of `Block`s.
///
/// Every row is `row_bits` bits wide; the accessors in this module give each row
/// `round_up_to_next(row_bits, BITS) / BITS` consecutive blocks of `slice`.
pub struct BitSubMatrixMut<'a> {
    /// Width of every row, in bits.
    row_bits: usize,
    /// Backing storage; rows are laid out back to back.
    slice: &'a mut [Block],
}
impl<'a> BitSubMatrix<'a> {
    /// Wraps `slice` as a read-only matrix whose rows are `row_bits` bits wide.
    ///
    /// Each row occupies `round_up_to_next(row_bits, BITS) / BITS` consecutive blocks of
    /// `slice`. The length of `slice` is not checked against `row_bits`.
    pub fn new(slice: &[Block], row_bits: usize) -> BitSubMatrix {
        BitSubMatrix { slice, row_bits }
    }

    /// Builds a read-only view of `rows` rows, `row_bits` bits each, from a raw pointer.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the contract of [`slice::from_raw_parts`] for
    /// `round_up_to_next(row_bits, BITS) / BITS * rows` elements. It must be non-null,
    /// properly aligned, and point to that many initialized `Block`s that stay valid and
    /// unmutated for the lifetime `'a`.
    #[inline]
    pub unsafe fn from_raw_parts(ptr: *const Block, rows: usize, row_bits: usize) -> Self {
        let len = round_up_to_next(row_bits, BITS) / BITS * rows;
        BitSubMatrix {
            slice: slice::from_raw_parts(ptr, len),
            row_bits,
        }
    }

    /// Iterates over the rows, yielding each one as a `&BitSlice`.
    ///
    /// Rows are consecutive chunks of `round_up_to_next(row_bits, BITS) / BITS` blocks. If the
    /// backing slice is not a whole number of rows, the final item is a shorter chunk.
    /// A matrix with `row_bits == 0` yields no rows.
    pub fn iter(&self) -> Map<slice::Chunks<Block>, fn(&[Block]) -> &BitSlice> {
        fn f(arg: &[Block]) -> &BitSlice {
            // SAFETY: NOTE(review): relies on `BitSlice` having the same layout as `[Block]`
            // (a transparent wrapper around it); confirm against its definition.
            unsafe { mem::transmute(arg) }
        }
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        // `chunks` panics on a chunk size of 0. Zero-width rows occupy no blocks, so there is
        // nothing to yield: iterate an empty prefix with a dummy non-zero chunk size instead.
        let (rows, chunk_size) = if row_size == 0 {
            (&self.slice[..0], 1)
        } else {
            (self.slice, row_size)
        };
        rows.chunks(chunk_size).map(f)
    }
}
impl<'a> BitSubMatrixMut<'a> {
pub fn new(slice: &mut [Block], row_bits: usize) -> BitSubMatrixMut {
BitSubMatrixMut {
slice: slice,
row_bits: row_bits,
}
}
#[inline]
pub unsafe fn from_raw_parts(ptr: *mut Block, rows: usize, row_bits: usize) -> Self {
BitSubMatrixMut {
slice: slice::from_raw_parts_mut(ptr, round_up_to_next(row_bits, BITS) / BITS * rows),
row_bits: row_bits,
}
}
#[inline]
fn num_rows(&self) -> usize {
let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
if row_size == 0 {
0
} else {
self.slice.len() / row_size
}
}
#[inline]
pub fn set(&mut self, row: usize, col: usize, enabled: bool) {
let row_size_in_bits = round_up_to_next(self.row_bits, BITS);
let bit = row * row_size_in_bits + col;
let (block, i) = div_rem(bit, BITS);
assert!(block < self.slice.len() && col < self.row_bits);
unsafe {
let elt = self.slice.get_unchecked_mut(block);
if enabled {
*elt |= 1 << i;
} else {
*elt &= !(1 << i);
}
}
}
#[inline]
pub fn sub_matrix(&self, range: Range<usize>) -> BitSubMatrix {
let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
BitSubMatrix {
slice: &self.slice[range.start * row_size..range.end * row_size],
row_bits: self.row_bits,
}
}
#[inline]
pub fn split_at(&self, row: usize) -> (BitSubMatrix, BitSubMatrix) {
(
self.sub_matrix(0..row),
self.sub_matrix(row..self.num_rows()),
)
}
#[inline]
pub fn split_at_mut(&mut self, row: usize) -> (BitSubMatrixMut, BitSubMatrixMut) {
let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
let (first, second) = self.slice.split_at_mut(row * row_size);
(
BitSubMatrixMut::new(first, self.row_bits),
BitSubMatrixMut::new(second, self.row_bits),
)
}
pub fn transitive_closure(&mut self) {
assert_eq!(self.num_rows(), self.row_bits);
for pos in 0..self.row_bits {
let (mut rows0, mut rows1a) = self.split_at_mut(pos);
let (row, mut rows1b) = rows1a.split_at_mut(1);
for dst_row in rows0.iter_mut().chain(rows1b.iter_mut()) {
if dst_row[pos] {
for (dst, src) in dst_row.iter_blocks_mut().zip(row[0].iter_blocks()) {
*dst |= src;
}
}
}
}
}
pub fn reflexive_closure(&mut self) {
for i in 0..cmp::min(self.row_bits, self.num_rows()) {
self.set(i, i, true);
}
}
pub fn iter_mut(&mut self) -> Map<slice::ChunksMut<Block>, fn(&mut [Block]) -> &mut BitSlice> {
fn f(arg: &mut [Block]) -> &mut BitSlice {
unsafe { mem::transmute(arg) }
}
let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
self.slice.chunks_mut(row_size).map(f)
}
}
impl<'a> Index<usize> for BitSubMatrixMut<'a> {
type Output = BitSlice;
#[inline]
fn index(&self, row: usize) -> &BitSlice {
let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
unsafe { mem::transmute(&self.slice[row * row_size..(row + 1) * row_size]) }
}
}
impl<'a> IndexMut<usize> for BitSubMatrixMut<'a> {
#[inline]
fn index_mut(&mut self, row: usize) -> &mut BitSlice {
let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
unsafe { mem::transmute(&mut self.slice[row * row_size..(row + 1) * row_size]) }
}
}
impl<'a> Index<usize> for BitSubMatrix<'a> {
type Output = BitSlice;
#[inline]
fn index(&self, row: usize) -> &BitSlice {
let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
unsafe { mem::transmute(&self.slice[row * row_size..(row + 1) * row_size]) }
}
}
impl<'a> fmt::Debug for BitSubMatrix<'a> {
    /// Renders the matrix as text, one line per row.
    ///
    /// Each bit produced by `BitSlice::iter_bits(row_bits)` is printed as `1` (set) or
    /// `0` (clear). Every line, including the last, ends with `\n`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let width = self.row_bits;
        self.iter().try_for_each(|row| {
            row.iter_bits(width)
                .try_for_each(|bit| write!(fmt, "{}", if bit { 1 } else { 0 }))?;
            write!(fmt, "\n")
        })
    }
}