use std::{
cmp::Ordering,
collections::{HashSet, VecDeque},
fmt::{Debug, Display, Formatter},
ops::{Index, Mul, MulAssign, Not},
rc::Rc,
sync::Arc,
};
use itertools::Itertools;
use crate::{Bell, IncompatibleStages, Parity, RowBuf, SameStageVec, Stage};
use super::{MulIntoError, RowAccumulator};
#[allow(unused_imports)]
use crate::Block;
/// Iterator over the [`Bell`]s of a [`Row`], yielding each bell by value.
/// This is the return type of [`Row::bell_iter`] and of `<&Row as IntoIterator>::into_iter`.
pub type BellIter<'a> = std::iter::Cloned<std::slice::Iter<'a, Bell>>;
/// A borrowed row: an unsized sequence of [`Bell`]s (the owned counterpart is [`RowBuf`]).
///
/// `#[repr(transparent)]` guarantees that `Row` has exactly the same layout as `[Bell]`.
/// The pointer casts in `to_rc`, `to_arc`, `from_slice_unchecked` and
/// `from_mut_slice_unchecked` rely on this guarantee.
#[derive(Eq, PartialEq, PartialOrd, Ord, Hash)]
#[repr(transparent)] pub struct Row {
// The bells in order, index 0 first.
// NOTE(review): methods such as `inv` and `parity` assume every bell index is unique and
// `< len` (i.e. a permutation); the unchecked constructors do not verify this. Confirm invariant.
bell_slice: [Bell],
}
impl Row {
/// Returns the [`Stage`] of this row, i.e. how many bells it contains.
#[inline]
pub fn stage(&self) -> Stage {
    let num_bells = self.bell_slice.len();
    Stage::new(num_bells as u8)
}
/// Returns the bells of this row as a slice.
pub fn bells(&self) -> &[Bell] {
    &self.bell_slice[..]
}
#[inline]
pub fn bell_iter(&self) -> BellIter {
self.bell_slice.iter().cloned()
}
/// Returns the bells of this row as a slice (same as [`Row::bells`]).
#[inline]
pub fn slice(&self) -> &[Bell] {
    self.bells()
}
/// Returns the index (place) at which `bell` first appears in this row, or `None` if it does
/// not appear.
#[inline]
pub fn place_of(&self, bell: Bell) -> Option<usize> {
    self.bell_slice.iter().position(|&candidate| candidate == bell)
}
/// Returns the bell at index `place`, without bounds checking.
///
/// # Safety
///
/// `place` must be less than `self.stage().num_bells()`. This calls `<[Bell]>::get_unchecked`,
/// for which an out-of-bounds index is undefined behaviour.
#[inline]
pub unsafe fn get_bell_unchecked(&self, place: usize) -> Bell {
*self.bell_slice.get_unchecked(place)
}
/// Returns the [`Parity`] of this row.
///
/// The parity comes from the number of swaps needed to sort the bells back into rounds
/// (cycle sort). The sort runs on a scratch copy, so `self` is left unchanged. (Previously this
/// method took `&mut self` and sorted the row itself, silently turning it into rounds.)
///
/// NOTE(review): this assumes the row is a valid permutation. Duplicate bells make the loop
/// spin forever, and out-of-range bells make `swap` panic.
pub fn parity(&self) -> Parity {
    // Scratch copy, so that computing the parity doesn't scramble `self`.
    let mut bells = self.bell_slice.to_vec();
    let mut num_swaps = 0;
    let mut place = 0;
    while place < bells.len() {
        let cur_bell = bells[place];
        if cur_bell == Bell::from_index(place as u8) {
            place += 1;
        } else {
            // Send `cur_bell` to its home place. Each swap fixes one more bell, so this loop
            // does at most `len - 1` swaps.
            bells.swap(place, cur_bell.index());
            num_swaps += 1;
        }
    }
    Parity::from_number(num_swaps)
}
/// Returns `true` if every bell is in its home place (i.e. this row is rounds).
pub fn is_rounds(&self) -> bool {
    !self
        .bell_iter()
        .enumerate()
        .any(|(place, bell)| bell.index() != place)
}
/// Returns `true` if this row is backrounds: the bell at index `i` has index `n - 1 - i`,
/// where `n` is the number of bells.
pub fn is_backrounds(&self) -> bool {
    let num_bells = self.bell_slice.len();
    self.bell_iter()
        .enumerate()
        .all(|(place, bell)| bell.index() + place + 1 == num_bells)
}
/// Returns the smallest stage that covers every bell not in its home place.
///
/// Equivalently, this is one more than the index of the last bell that is out of place. Rows
/// whose bells are all in place (including rounds) give [`Stage::ONE`].
pub fn effective_stage(&self) -> Stage {
    let last_moving_place = (0..self.bell_slice.len())
        .rev()
        .find(|&place| self.bell_slice[place].index() != place);
    match last_moving_place {
        Some(place) => Stage::new(place as u8 + 1),
        None => Stage::ONE,
    }
}
/// Swaps the bells at indices `a` and `b`.
///
/// # Panics
///
/// Panics if either index is out of bounds.
#[inline]
pub fn swap(&mut self, a: usize, b: usize) {
    let bells = &mut self.bell_slice;
    bells.swap(a, b);
}
#[track_caller]
pub fn copy_from(&mut self, other: &Row) {
self.check_stage(other);
self.bell_slice.copy_from_slice(&other.bell_slice);
}
/// Multiplies `self` by `rhs`, returning an error if their stages differ.
///
/// The product `p` satisfies `p[i] == self[rhs[i]]`.
pub fn try_mul(&self, rhs: &Self) -> Result<RowBuf, IncompatibleStages> {
    IncompatibleStages::test_err(self.stage(), rhs.stage())
        // SAFETY: `test_err` returned `Ok`, so the two stages are equal.
        .map(|_| unsafe { self.mul_unchecked(rhs) })
}
/// Multiplies `self` by `rhs` without checking their stages: `result[i] = self[rhs[i]]`.
///
/// # Safety
///
/// Callers must ensure `self` and `rhs` have the same stage, as [`Row::try_mul`] checks before
/// calling this. Otherwise the result is built with `RowBuf::from_bell_iter_unchecked` from
/// bells that may not form a valid row.
pub unsafe fn mul_unchecked(&self, rhs: &Row) -> RowBuf {
    let lhs = &self.bell_slice;
    let product = rhs.bell_iter().map(|b| lhs[b.index()]);
    RowBuf::from_bell_iter_unchecked(product)
}
/// Writes `self * rhs` into `out`, reusing `out`'s allocation.
///
/// # Panics
///
/// Panics if `self` and `rhs` have different stages.
#[track_caller]
pub fn mul_into(&self, rhs: &Row, out: &mut RowBuf) {
    self.check_stage(rhs);
    // SAFETY: `check_stage` has just asserted that `self` and `rhs` share a stage.
    unsafe {
        self.mul_into_unchecked(rhs, out);
    }
}
/// Replaces the contents of `out` with `self * rhs`, without checking stages.
///
/// `out` is resized to `rhs`'s length.
///
/// # Safety
///
/// Callers must ensure `self` and `rhs` have the same stage; otherwise `out` may end up holding
/// bells that do not form a valid row.
pub unsafe fn mul_into_unchecked(&self, rhs: &Row, out: &mut RowBuf) {
    let lhs = &self.bell_slice;
    let dest = &mut out.bell_vec;
    dest.clear();
    dest.extend(rhs.bell_iter().map(|b| lhs[b.index()]));
}
/// Writes `self * rhs` into the existing row `out`.
///
/// # Errors
///
/// Returns [`MulIntoError::RhsStage`] if `rhs` has a different stage to `self`, and
/// [`MulIntoError::IntoStage`] if `out` does. `rhs` is checked first.
pub fn mul_into_row(&self, rhs: &Row, out: &mut Row) -> Result<(), MulIntoError> {
    IncompatibleStages::test_err(self.stage(), rhs.stage()).map_err(MulIntoError::RhsStage)?;
    IncompatibleStages::test_err(self.stage(), out.stage()).map_err(MulIntoError::IntoStage)?;
    // SAFETY: both checks above passed, so all three rows share a stage.
    unsafe {
        self.mul_into_row_unchecked(rhs, out);
    }
    Ok(())
}
/// Writes `self * rhs` into `out` without checking `self`'s stage.
///
/// # Panics
///
/// Panics (via `zip_eq`) if `out` and `rhs` have different lengths.
///
/// # Safety
///
/// Callers must ensure `self` has the same stage as `rhs` and `out`; otherwise `out` may be left
/// holding bells that do not form a valid row.
pub unsafe fn mul_into_row_unchecked(&self, rhs: &Row, out: &mut Row) {
    let lhs = &self.bell_slice;
    out.bell_slice
        .iter_mut()
        .zip_eq(rhs.bell_slice.iter())
        .for_each(|(dest, src)| *dest = lhs[src.index()]);
}
/// Returns the inverse of this row: the row `r` with `r[self[i]] == i` for every index `i`.
pub fn inv(&self) -> RowBuf {
    let num_bells = self.stage().num_bells();
    let mut inverse = vec![Bell::TREBLE; num_bells];
    for (place, bell) in self.bell_iter().enumerate() {
        inverse[bell.index()] = Bell::from_index(place as u8);
    }
    // SAFETY: if `self` is a valid row, every slot of `inverse` is overwritten exactly once,
    // so it is a valid row too.
    unsafe { RowBuf::from_vec_unchecked(inverse) }
}
/// Writes the inverse of this row into `out`.
///
/// # Panics
///
/// Panics if `self` and `out` have different stages.
#[track_caller]
pub fn inv_into(&self, out: &mut Row) {
    self.check_stage(out);
    self.bell_slice
        .iter()
        .enumerate()
        .for_each(|(place, bell)| out.bell_slice[bell.index()] = Bell::from_index(place as u8));
}
/// Writes the inverse of this row into `out`, first resizing `out` to `self`'s stage.
pub fn inv_into_buf(&self, out: &mut RowBuf) {
    let target_len = self.bell_slice.len();
    // Pad with trebles or truncate so that `out` has exactly `target_len` bells. For a valid
    // `self`, every slot is overwritten by the loop below.
    match out.stage().cmp(&self.stage()) {
        Ordering::Less => out.bell_vec.resize(target_len, Bell::TREBLE),
        Ordering::Greater => out.bell_vec.truncate(target_len),
        Ordering::Equal => {}
    }
    debug_assert_eq!(out.stage(), self.stage());
    for (place, bell) in self.bell_iter().enumerate() {
        out.bell_vec[bell.index()] = Bell::from_index(place as u8);
    }
}
/// Raises this row to any integer power.
///
/// A negative exponent `-k` gives `self.inv().pow_u(k)`, and a non-negative exponent `k` gives
/// `self.pow_u(k)`.
pub fn pow_i(&self, exponent: isize) -> RowBuf {
    if exponent < 0 {
        // `unsigned_abs` avoids the overflow that `-exponent` hits for `isize::MIN`.
        self.inv().pow_u(exponent.unsigned_abs())
    } else {
        self.pow_u(exponent as usize)
    }
}
pub fn pow_u(&self, exponent: usize) -> RowBuf {
let mut accumulator = RowAccumulator::rounds(self.stage());
for _ in 0..exponent {
accumulator.post_accumulate(self);
}
accumulator.into_total()
}
/// Solves `a * x = b` for `x`, returning `x = a⁻¹ * b`.
///
/// # Panics
///
/// Panics if `a` and `b` have different stages.
#[inline]
pub fn solve_ax_equals_b(a: &Self, b: &Self) -> RowBuf {
    let a_inv = a.inv();
    a_inv * b
}
/// Solves `x * a = b` for `x`, returning `x = b * a⁻¹`.
///
/// # Panics
///
/// Panics if `a` and `b` have different stages.
#[inline]
pub fn solve_xa_equals_b(a: &Self, b: &Self) -> RowBuf {
    let a_inv = a.inv();
    b * a_inv
}
/// Returns `true` if `bell` is in its home place in this row. Bells beyond the row's stage are
/// never fixed.
pub fn is_fixed(&self, bell: Bell) -> bool {
    match self.bell_slice.get(bell.index()) {
        Some(&occupant) => occupant == bell,
        None => false,
    }
}
/// Iterates over the bells that are in their home place, in increasing order of place.
pub fn fixed_bells(&self) -> impl Iterator<Item = Bell> + '_ {
    self.bell_iter()
        .enumerate()
        .filter(|(place, bell)| bell.index() == *place)
        .map(|(_, bell)| bell)
}
/// Replaces the contents of `other` with a copy of this row, reusing `other`'s allocation.
#[inline]
pub fn copy_into(&self, other: &mut RowBuf) {
    let dest = &mut other.bell_vec;
    dest.clear();
    dest.extend(self.bell_iter());
}
/// Clones this row into a new reference-counted [`Rc<Row>`].
#[inline]
pub fn to_rc(&self) -> Rc<Row> {
let rc_of_bells: Rc<[Bell]> = self.bell_slice.to_vec().into();
// Reinterpret the `Rc<[Bell]>` allocation as an `Rc<Row>` without copying the bells again.
let ptr_to_bells = Rc::into_raw(rc_of_bells);
let ptr_to_row = ptr_to_bells as *const Row;
// SAFETY: `Row` is `#[repr(transparent)]` over `[Bell]`, so the two types share layout and
// slice-length metadata. The pointer came straight from `Rc::into_raw` and is consumed exactly
// once here.
unsafe { Rc::from_raw(ptr_to_row) }
}
/// Clones this row into a new atomically reference-counted [`Arc<Row>`].
#[inline]
pub fn to_arc(&self) -> Arc<Row> {
let arc_of_bells: Arc<[Bell]> = self.bell_slice.to_vec().into();
// Reinterpret the `Arc<[Bell]>` allocation as an `Arc<Row>` without copying the bells again.
let ptr_to_bells = Arc::into_raw(arc_of_bells);
let ptr_to_row = ptr_to_bells as *const Row;
// SAFETY: `Row` is `#[repr(transparent)]` over `[Bell]`, so the two types share layout and
// slice-length metadata. The pointer came straight from `Arc::into_raw` and is consumed
// exactly once here.
unsafe { Arc::from_raw(ptr_to_row) }
}
/// Returns the order of this row: the smallest `n >= 1` such that `self^n` is rounds.
///
/// This performs `order - 1` row multiplications. NOTE(review): it only terminates if `self` is
/// a valid permutation.
#[inline]
pub fn order(&self) -> usize {
    let mut power = RowAccumulator::new(self.to_owned());
    let mut exponent = 1;
    loop {
        if power.total().is_rounds() {
            return exponent;
        }
        // SAFETY: the accumulator only ever holds powers of `self`, which share its stage.
        unsafe { power.post_accumulate_unchecked(self) };
        exponent += 1;
    }
}
/// Returns the successive powers `[self, self^2, self^3, ...]`, up to and including the first
/// power that is rounds.
///
/// If `self` is already rounds, the result is just `[self]`.
pub fn closure(&self) -> Vec<RowBuf> {
    let mut powers = Vec::new();
    let mut current: RowBuf = self.to_owned();
    while !current.is_rounds() {
        // SAFETY: every power of `self` has the same stage as `self`.
        let next = unsafe { current.mul_unchecked(self) };
        powers.push(std::mem::replace(&mut current, next));
    }
    powers.push(current);
    powers
}
/// Returns the successive powers `[rounds, self, self^2, ...]`, stopping just before the power
/// that returns to rounds.
///
/// If `self` is rounds, the result is just `[rounds]`.
pub fn closure_from_rounds(&self) -> Vec<RowBuf> {
    let mut powers = vec![RowBuf::rounds(self.stage())];
    loop {
        // SAFETY: `powers` only holds powers of `self`, which all share its stage.
        let next = unsafe { powers[powers.len() - 1].mul_unchecked(self) };
        if next.is_rounds() {
            break powers;
        }
        powers.push(next);
    }
}
/// Returns every product `r1 * r2 * ... * rk`, where `ri` is drawn from the `i`th set in
/// `row_sets`.
///
/// Rows drawn from earlier sets vary fastest in the output order. If there are no sets, or any
/// set is empty, the result is empty.
///
/// # Panics
///
/// Panics (via `Mul`) if two rows being multiplied have different stages. NOTE(review):
/// presumably `SameStageVec::push` also panics on rows whose stage isn't `stage`; confirm.
pub fn multi_cartesian_product(
    row_sets: impl IntoIterator<Item = impl IntoIterator<Item = impl AsRef<Self>>>,
    stage: Stage,
) -> SameStageVec {
    let mut sets = row_sets.into_iter();
    let first_set = match sets.next() {
        Some(set) => set,
        None => return SameStageVec::new(stage),
    };
    // `product` holds the product of all the sets seen so far; `scratch` is a reusable buffer
    // for building the next one.
    let mut product = SameStageVec::new(stage);
    for row in first_set {
        product.push(row.as_ref());
    }
    let mut scratch = SameStageVec::new(stage);
    for set in sets {
        if product.is_empty() {
            return SameStageVec::new(stage);
        }
        scratch.clear();
        for r2 in set {
            let r2 = r2.as_ref();
            for r1 in &product {
                scratch.push(&(r1 * r2));
            }
        }
        std::mem::swap(&mut product, &mut scratch);
    }
    product
}
/// Returns the smallest set of rows that contains every row in `rows` and is closed under
/// multiplication, i.e. the group they generate.
///
/// An empty `rows` gives an empty set.
///
/// # Panics
///
/// Panics if the rows do not all have the same stage.
pub fn least_group_containing<'a>(
    rows: impl IntoIterator<Item = &'a Self> + Clone,
) -> HashSet<RowBuf>
where
    Self: 'a,
{
    // Checking neighbouring pairs is enough: equality is transitive.
    rows.clone()
        .into_iter()
        .tuple_windows()
        .for_each(|(r1, r2)| assert_eq!(r1.stage(), r2.stage()));

    let mut group = HashSet::<RowBuf>::new();
    let mut to_visit = VecDeque::<RowBuf>::new();
    for generator in rows.clone() {
        if group.insert(generator.to_owned()) {
            to_visit.push_back(generator.to_owned());
        }
    }
    // Breadth-first search: right-multiply each newly found row by every generator until no new
    // rows appear.
    while let Some(row) = to_visit.pop_front() {
        for generator in rows.clone() {
            // SAFETY: all generators share a stage (asserted above), and so do their products.
            let product = unsafe { row.mul_unchecked(generator) };
            if !group.contains(&product) {
                to_visit.push_back(product.clone());
                group.insert(product);
            }
        }
    }
    group
}
/// Returns `true` if `rows`, taken as a set (duplicates ignored), forms a group under row
/// multiplication.
///
/// This uses the one-step subgroup test: the set must be non-empty, and `a * b⁻¹` must be in
/// the set for every `a` and `b` in it.
///
/// # Panics
///
/// Panics if the rows do not all have the same stage.
pub fn is_group<'a>(rows: impl IntoIterator<Item = &'a Row>) -> bool
where
    Self: 'a,
{
    let members: HashSet<&Row> = rows.into_iter().collect();
    for (r1, r2) in members.iter().tuple_windows() {
        assert_eq!(r1.stage(), r2.stage());
    }
    if members.is_empty() {
        return false;
    }
    // Scratch buffers, reused for every pair. `inv_into_buf` resizes them as needed.
    let mut b_inv = RowBuf::rounds(Stage::ONE);
    let mut a_mul_b_inv = RowBuf::rounds(Stage::ONE);
    members.iter().all(|&b| {
        b.inv_into_buf(&mut b_inv);
        members.iter().all(|&a| {
            // SAFETY: `a` and `b_inv` share a stage (all members were asserted equal above).
            unsafe { a.mul_into_unchecked(&b_inv, &mut a_mul_b_inv) };
            members.contains(&*a_mul_b_inv)
        })
    })
}
/// Computes a cheap hash by reading the bell indices as the digits of a base-`n` number, where
/// `n` is the stage and the bell at index 0 is the least significant digit.
///
/// While `n^n` fits in a `usize`, distinct valid rows of the same stage always get distinct
/// hashes. For larger stages the arithmetic wraps, so collisions become possible. (Previously
/// the unchecked `*=` / `+=` panicked on overflow in debug builds from stage 16 upwards.)
pub fn fast_hash(&self) -> usize {
    let radix = self.stage().num_bells();
    let mut accum: usize = 0;
    let mut multiplier: usize = 1;
    for b in self.bell_iter() {
        // Wrapping ops give the same values as the old unchecked ops did in release builds.
        accum = accum.wrapping_add(b.index().wrapping_mul(multiplier));
        multiplier = multiplier.wrapping_mul(radix);
    }
    accum
}
/// Reinterprets a slice of bells as a [`Row`], without checking that it forms a valid row.
///
/// # Safety
///
/// The caller must ensure `slice` satisfies `Row`'s invariants. NOTE(review): presumably each
/// bell appears exactly once and every bell index is `< slice.len()`; confirm against the
/// checked constructors. The pointer cast itself is sound because `Row` is
/// `#[repr(transparent)]` over `[Bell]`.
#[inline]
pub unsafe fn from_slice_unchecked(slice: &[Bell]) -> &Row {
&*(slice as *const [Bell] as *const Row)
}
/// Reinterprets a mutable slice of bells as a mutable [`Row`], without checking that it forms
/// a valid row.
///
/// # Safety
///
/// Same contract as [`Row::from_slice_unchecked`]: `slice` must satisfy `Row`'s invariants
/// (NOTE(review): presumably a permutation; confirm). The cast relies on `Row` being
/// `#[repr(transparent)]` over `[Bell]`.
#[inline]
pub unsafe fn from_mut_slice_unchecked(slice: &mut [Bell]) -> &mut Row {
&mut *(slice as *mut [Bell] as *mut Row)
}
/// Asserts that `self` and `row` have the same stage, reporting the caller's location on
/// failure.
#[track_caller]
fn check_stage(&self, row: &Row) {
    let (lhs_stage, rhs_stage) = (self.stage(), row.stage());
    assert_eq!(
        lhs_stage, rhs_stage,
        "Stage mismatch: LHS has stage {:?} but RHS has stage {:?}",
        lhs_stage, rhs_stage,
    );
}
}
impl Index<usize> for Row {
    type Output = Bell;

    /// Returns a reference to the bell at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than the number of bells.
    fn index(&self, index: usize) -> &Bell {
        Index::index(&self.bell_slice, index)
    }
}
impl Not for &RowBuf {
type Output = RowBuf;
fn not(self) -> Self::Output {
self.inv()
}
}
impl Not for &Row {
type Output = RowBuf;
fn not(self) -> Self::Output {
self.inv()
}
}
impl Not for RowBuf {
type Output = RowBuf;
fn not(self) -> Self::Output {
self.inv()
}
}
impl Mul for &Row {
type Output = RowBuf;
#[inline]
fn mul(self, rhs: &Row) -> Self::Output {
self.try_mul(rhs).unwrap()
}
}
// Generates `impl Mul<RHS> for LHS` (with `Output = RowBuf`) for each listed `RHS`. Each impl
// delegates to `Row::try_mul` and panics (via `unwrap`) if the stages differ. Together with
// `impl Mul for &Row`, this covers every owned/borrowed combination of `Row` and `RowBuf`.
macro_rules! mul_impl {
    ($lhs:ty => $($rhs:ty),+ $(,)?) => {
        $(
            impl Mul<$rhs> for $lhs {
                type Output = RowBuf;

                #[inline]
                fn mul(self, rhs: $rhs) -> Self::Output {
                    self.try_mul(&rhs).unwrap()
                }
            }
        )+
    };
}
mul_impl!(RowBuf => RowBuf, &RowBuf, &Row);
mul_impl!(&RowBuf => RowBuf, &RowBuf, &Row);
mul_impl!(&Row => RowBuf, &RowBuf);
// `a *= b` replaces `a` with `a * b`. It panics if the stages differ.
impl MulAssign<&Row> for RowBuf {
    fn mul_assign(&mut self, rhs: &Row) {
        let product = self.as_row() * rhs;
        *self = product;
    }
}
impl MulAssign<&RowBuf> for RowBuf {
    fn mul_assign(&mut self, rhs: &RowBuf) {
        let rhs_row: &Row = rhs.as_row();
        *self *= rhs_row;
    }
}
impl MulAssign<RowBuf> for RowBuf {
    fn mul_assign(&mut self, rhs: RowBuf) {
        let rhs_row: &Row = rhs.as_row();
        *self *= rhs_row;
    }
}
impl<'row> IntoIterator for &'row Row {
type Item = Bell;
type IntoIter = BellIter<'row>;
fn into_iter(self) -> Self::IntoIter {
self.bell_iter()
}
}
impl Debug for Row {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Row({})", self)
}
}
/// Displays a row as its bells written one after another, with no separators.
impl Display for Row {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.bell_iter().try_for_each(|bell| write!(f, "{}", bell))
    }
}
/// Wrapper that displays a row without its trailing in-place bells: only the first
/// `effective_stage()` bells are written, so rounds is displayed as just its first bell.
pub struct ShortRow<'r>(pub &'r Row);
impl Display for ShortRow<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let row = self.0;
        let num_shown = row.effective_stage().num_bells();
        row.bell_iter()
            .take(num_shown)
            .try_for_each(|bell| write!(f, "{}", bell))
    }
}
#[derive(Clone)]
pub struct DbgRow<'r>(pub &'r Row);
impl Debug for DbgRow<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[cfg(test)]
mod tests {
    use std::ops::Deref;

    use crate::{Row, RowBuf};

    /// Parses every string in `rows` as a `RowBuf`, panicking on invalid input.
    fn parse_all(rows: &[&str]) -> Vec<RowBuf> {
        rows.iter().map(|s| RowBuf::parse(s).unwrap()).collect()
    }

    #[test]
    fn is_group() {
        #[track_caller]
        fn check(rows: &[&str]) {
            let rows: Vec<RowBuf> = rows.iter().map(|s| RowBuf::parse(s).unwrap()).collect();
            println!("Is {:?} a group?", rows);
            assert!(Row::is_group(rows.iter().map(|r| r.deref())));
        }
        check(&["1234", "1342", "1423"]);
        check(&["1"]);
        check(&["1234", "1324"]);
        check(&["1234", "1234", "1234", "1324"]);
        check(&["1234", "4123", "3412", "2341"]);
        check(&["123456", "134256", "142356", "132456", "124356", "143256"]);
        #[rustfmt::skip]
        check(&[
            "123456", "134562", "145623", "156234", "162345",
            "165432", "126543", "132654", "143265", "154326",
        ]);
        check(&["123456", "234561", "345612", "456123", "561234", "612345"]);
        #[rustfmt::skip]
        check(&[
            "123456", "234561", "345612", "456123", "561234", "612345",
            "654321", "165432", "216543", "321654", "432165", "543216",
        ]);
    }

    #[test]
    fn is_non_group() {
        #[track_caller]
        fn check(groups: &[&str]) {
            let rows: Vec<RowBuf> = groups.iter().map(|s| RowBuf::parse(s).unwrap()).collect();
            println!("Is {:?} not a group?", groups);
            assert!(!Row::is_group(rows.iter().map(|r| r.deref())));
        }
        check(&["21"]);
        check(&["123456", "134256", "142356", "132456", "124356"]);
        check(&[]);
        check(&[
            "123456", "134256", "142356", "132456", "124356", "143256", "213456",
        ]);
    }

    #[test]
    fn order() {
        #[track_caller]
        fn check(row: &str, exp_order: usize) {
            assert_eq!(RowBuf::parse(row).unwrap().order(), exp_order);
        }
        check("1", 1);
        check("1234", 1);
        check("123456789", 1);
        check("21", 2);
        check("2134", 2);
        check("2143", 2);
        check("23145", 3);
        check("23451", 5);
        check("23154", 6);
        check("231564", 3);
        check("1452367890", 2);
    }

    #[test]
    fn inv() {
        #[track_caller]
        fn check(row: &str, exp_inv: &str) {
            let row = RowBuf::parse(row).unwrap();
            let inv = row.inv();
            assert_eq!(inv, RowBuf::parse(exp_inv).unwrap());
            // Multiplying by the inverse on either side must give rounds.
            assert!((&row * &inv).is_rounds());
            assert!((&inv * &row).is_rounds());
        }
        check("1", "1");
        check("1234", "1234");
        check("2143", "2143");
        check("23451", "51234");
        check("231564", "312645");
    }

    #[test]
    fn pow_i() {
        #[track_caller]
        fn check(row: &str, exponent: isize, exp_result: &str) {
            let row = RowBuf::parse(row).unwrap();
            assert_eq!(row.pow_i(exponent), RowBuf::parse(exp_result).unwrap());
        }
        check("1", 3, "1");
        check("23451", 0, "12345");
        check("23451", 1, "23451");
        check("23451", 2, "34512");
        check("23451", 5, "12345");
        check("23451", 7, "34512");
        check("23451", -1, "51234");
        check("23451", -2, "45123");
        check("2143", -3, "2143");
    }

    #[test]
    fn closure() {
        let row = RowBuf::parse("231").unwrap();
        assert_eq!(row.closure(), parse_all(&["231", "312", "123"]));
        assert_eq!(row.closure_from_rounds(), parse_all(&["123", "231", "312"]));

        let rounds = RowBuf::parse("1234").unwrap();
        assert_eq!(rounds.closure(), parse_all(&["1234"]));
        assert_eq!(rounds.closure_from_rounds(), parse_all(&["1234"]));
    }

    #[test]
    fn effective_stage() {
        #[track_caller]
        fn check(row: &str, exp_num_bells: usize) {
            let row = RowBuf::parse(row).unwrap();
            assert_eq!(row.effective_stage().num_bells(), exp_num_bells);
        }
        check("1", 1);
        check("12345678", 1);
        check("21345678", 2);
        check("13245678", 3);
        check("12345687", 8);
        check("4321", 4);
    }

    #[test]
    fn is_backrounds() {
        assert!(RowBuf::parse("1").unwrap().is_backrounds());
        assert!(RowBuf::parse("21").unwrap().is_backrounds());
        assert!(RowBuf::parse("87654321").unwrap().is_backrounds());
        assert!(!RowBuf::parse("1234").unwrap().is_backrounds());
        assert!(!RowBuf::parse("4312").unwrap().is_backrounds());
    }

    #[test]
    fn fast_hash() {
        // Bell indices read as base-`stage` digits, with index 0 least significant.
        assert_eq!(RowBuf::parse("1").unwrap().fast_hash(), 0);
        assert_eq!(RowBuf::parse("123").unwrap().fast_hash(), 21);
        assert_eq!(RowBuf::parse("213").unwrap().fast_hash(), 19);
    }
}