use crate::Error;
use arrayvec::ArrayVec;
use hashx::HashX;
use std::{cmp, mem};
/// Equihash `N` parameter: the number of low bits that must be zero in the
/// wrapping sum of all item hashes of a solution.
///
/// Each tree level below the root checks half as many bits as the level above
/// it (see `check_tree_sums`), so with `N = 60` the levels check 60, 30 and 15
/// bits.
pub(crate) const EQUIHASH_N: usize = 60;
/// Equihash `K` parameter: the number of levels in the binary tree of item
/// hashes. A solution has `2^EQUIHASH_K` items (see [`Solution::NUM_ITEMS`]).
pub(crate) const EQUIHASH_K: usize = 3;
/// One item of a solution.
///
/// Items are the leaves of the solution's binary tree. Each item is hashed with
/// [`item_hash`], and the hashes are summed up the tree.
pub type SolutionItem = u16;
/// One hash value computed from a [`SolutionItem`].
///
/// Must be wide enough to hold [`EQUIHASH_N`] bits.
pub(crate) type HashValue = u64;
/// Hash a single [`SolutionItem`] with the given HashX instance.
///
/// The item is converted losslessly (via `Into`) to the integer input type of
/// `HashX::hash_to_u64`, and the resulting hash is returned as a [`HashValue`].
#[inline(always)]
pub(crate) fn item_hash(func: &HashX, item: SolutionItem) -> HashValue {
    let input = item.into();
    func.hash_to_u64(input)
}
/// A list of up to 8 [`Solution`]s, stored in a fixed-capacity `ArrayVec`.
pub type SolutionArray = ArrayVec<Solution, 8>;
/// The items of one [`Solution`]: exactly [`Solution::NUM_ITEMS`] entries.
pub type SolutionItemArray = [SolutionItem; Solution::NUM_ITEMS];
/// The serialized form of one [`Solution`]: [`Solution::NUM_BYTES`] bytes,
/// holding each item in little-endian order (see `Solution::to_bytes`).
pub type SolutionByteArray = [u8; Solution::NUM_BYTES];
/// An Equihash-style solution: [`Solution::NUM_ITEMS`] items kept in canonical
/// tree order.
///
/// The `items` field is private. The constructors in this file either verify
/// tree order (`try_from_array`, `try_from_bytes`) or establish it
/// (`sort_from_array`) before building a value.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Solution {
    // Solution items, in the order accepted by `check_tree_order`.
    items: SolutionItemArray,
}
impl Solution {
    /// Number of items in a solution: one leaf for each position at the bottom
    /// of a binary tree with [`EQUIHASH_K`] levels, i.e. `2^EQUIHASH_K`.
    pub const NUM_ITEMS: usize = 1 << EQUIHASH_K;

    /// Size in bytes of the serialized form produced by [`Solution::to_bytes`].
    pub const NUM_BYTES: usize = Self::NUM_ITEMS * Self::ITEM_SIZE;

    /// Size in bytes of one serialized [`SolutionItem`].
    const ITEM_SIZE: usize = mem::size_of::<SolutionItem>();

    /// Build a `Solution` from items that are already in canonical tree order.
    ///
    /// The items are copied as-is and never reordered.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Order`] if `items` is not in tree order, as checked by
    /// `check_tree_order`.
    pub fn try_from_array(items: &SolutionItemArray) -> Result<Self, Error> {
        if check_tree_order(items) {
            Ok(Self { items: *items })
        } else {
            Err(Error::Order)
        }
    }

    /// Build a `Solution` from items in any order, first sorting them into
    /// canonical tree order.
    pub(crate) fn sort_from_array(mut items: SolutionItemArray) -> Self {
        sort_into_tree_order(&mut items);
        Self { items }
    }

    /// Decode a `Solution` from its byte form: [`Solution::NUM_ITEMS`] items,
    /// each stored as a little-endian [`SolutionItem`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Order`] if the decoded items are not in tree order.
    pub fn try_from_bytes(bytes: &SolutionByteArray) -> Result<Self, Error> {
        let mut items: SolutionItemArray = Default::default();
        // NUM_BYTES == NUM_ITEMS * ITEM_SIZE, so `chunks_exact` yields exactly
        // one ITEM_SIZE-byte chunk per item with no remainder.
        for (item, chunk) in items.iter_mut().zip(bytes.chunks_exact(Self::ITEM_SIZE)) {
            *item = SolutionItem::from_le_bytes(
                chunk
                    .try_into()
                    .expect("chunks_exact yields ITEM_SIZE-byte chunks"),
            );
        }
        Self::try_from_array(&items)
    }

    /// Encode this solution as bytes: each item in order, little-endian.
    ///
    /// This is the inverse of [`Solution::try_from_bytes`].
    pub fn to_bytes(&self) -> SolutionByteArray {
        let mut bytes: SolutionByteArray = Default::default();
        for (chunk, item) in bytes
            .chunks_exact_mut(Self::ITEM_SIZE)
            .zip(self.items.iter())
        {
            chunk.copy_from_slice(&item.to_le_bytes());
        }
        bytes
    }
}
impl AsRef<SolutionItemArray> for Solution {
    /// Borrow the solution's items, in their stored (tree) order.
    fn as_ref(&self) -> &SolutionItemArray {
        let Self { items } = self;
        items
    }
}
impl From<Solution> for SolutionItemArray {
fn from(solution: Solution) -> SolutionItemArray {
solution.items
}
}
/// Decide whether two sibling branches are already in canonical order.
///
/// The branches are compared lexicographically, starting from each one's
/// *last* item and moving toward the first. The result is `true` when `left`
/// is less than or equal to `right` under that ordering.
#[inline(always)]
fn branches_are_sorted(left: &[SolutionItem], right: &[SolutionItem]) -> bool {
    let ordering = left.iter().rev().cmp(right.iter().rev());
    ordering != cmp::Ordering::Greater
}
/// Check whether `items` is in canonical tree order.
///
/// The slice is treated as the leaves of a binary tree. At every node, the
/// left half must not compare greater than the right half, as decided by
/// [`branches_are_sorted`]. Returns `true` only if this holds at every level.
///
/// Slices with fewer than two items are trivially in order. Callers are
/// expected to pass a power-of-two length; other lengths still terminate but
/// are split into halves of unequal size.
#[inline(always)]
fn check_tree_order(items: &[SolutionItem]) -> bool {
    // Stopping at `len < 2` (rather than only at `len == 2`) guarantees
    // termination: `split_at(0)` on a 0- or 1-item slice would otherwise
    // produce an empty half and recurse forever.
    if items.len() < 2 {
        return true;
    }
    let (left, right) = items.split_at(items.len() / 2);
    branches_are_sorted(left, right) && check_tree_order(left) && check_tree_order(right)
}
/// Reorder `items` in place so that it is in the canonical tree order verified
/// by `check_tree_order`.
///
/// This works bottom-up. When the slice has more than two items, each half is
/// first put into tree order recursively. Then, if [`branches_are_sorted`]
/// reports the halves out of order, their contents are swapped wholesale,
/// which leaves each half's internal order intact.
///
/// NOTE: the length is expected to be a power of two. For other lengths the
/// halves differ in size, and `swap_with_slice` panics if a swap is needed.
#[inline(always)]
fn sort_into_tree_order(items: &mut [SolutionItem]) {
    let has_subtrees = items.len() > 2;
    let mid = items.len() / 2;
    let (lower, upper) = items.split_at_mut(mid);
    if has_subtrees {
        sort_into_tree_order(lower);
        sort_into_tree_order(upper);
    }
    let in_order = branches_are_sorted(lower, upper);
    if !in_order {
        lower.swap_with_slice(upper);
    }
}
/// Recursively verify the hash-sum condition for the subtree formed by `items`.
///
/// The subtree's sum is the wrapping sum of [`item_hash`] over all of its
/// items. Its low `n_bits` bits must all be zero. Each half of the subtree is
/// checked the same way with `n_bits / 2`, down to pairs of items.
///
/// Returns the full (unmasked) sum on success so the parent level can add it.
/// Returns `Err(())` as soon as any subtree fails; the left half is checked
/// before the right.
///
/// NOTE: `n_bits` must be below the bit width of [`HashValue`], or the mask
/// shift overflows. The length of `items` is expected to be a power of two of
/// at least 2; a 0- or 1-item slice never reaches the two-item leaf case and
/// recurses without terminating.
#[inline(always)]
fn check_tree_sums(func: &HashX, items: &[SolutionItem], n_bits: usize) -> Result<HashValue, ()> {
    let sum = match items {
        [first, second] => item_hash(func, *first).wrapping_add(item_hash(func, *second)),
        _ => {
            let (left, right) = items.split_at(items.len() / 2);
            let child_bits = n_bits / 2;
            let left_sum = check_tree_sums(func, left, child_bits)?;
            let right_sum = check_tree_sums(func, right, child_bits)?;
            left_sum.wrapping_add(right_sum)
        }
    };
    let one: HashValue = 1;
    let low_bits_mask = (one << n_bits) - 1;
    if (sum & low_bits_mask) != 0 {
        return Err(());
    }
    Ok(sum)
}
pub(crate) fn check_all_tree_sums(func: &HashX, solution: &Solution) -> Result<(), Error> {
match check_tree_sums(func, solution.as_ref(), EQUIHASH_N) {
Ok(_unused_bits) => Ok(()),
Err(()) => Err(Error::HashSum),
}
}