use crate::*;
use core::ops::{Add, BitAnd, Shr};
use packed_simd::{u16x8, u32x4, u64x2, u8x16};
/// Masks selecting the upper half of every bit group, ordered from the widest
/// group to the narrowest.
///
/// Entry `i` keeps the high `64 >> i` bits of each `128 >> i`-bit group:
/// index 0 is the upper 64 bits of the word, index 6 is every odd-numbered bit.
pub const LEFT_MASKS: [u128; 7] = [
0xFFFF_FFFF_FFFF_FFFF_0000_0000_0000_0000,
0xFFFF_FFFF_0000_0000_FFFF_FFFF_0000_0000,
0xFFFF_0000_FFFF_0000_FFFF_0000_FFFF_0000,
0xFF00_FF00_FF00_FF00_FF00_FF00_FF00_FF00,
0xF0F0_F0F0_F0F0_F0F0_F0F0_F0F0_F0F0_F0F0,
0xCCCC_CCCC_CCCC_CCCC_CCCC_CCCC_CCCC_CCCC,
0xAAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAA_AAAA,
];
/// Masks selecting the lower half of every bit group, ordered from the widest
/// group to the narrowest.
///
/// Entry `i` keeps the low `64 >> i` bits of each `128 >> i`-bit group:
/// index 0 is the lower 64 bits of the word, index 6 is every even-numbered bit.
/// Each entry is the bitwise complement of the `LEFT_MASKS` entry at the same index.
pub const RIGHT_MASKS: [u128; 7] = [
0x0000_0000_0000_0000_FFFF_FFFF_FFFF_FFFF,
0x0000_0000_FFFF_FFFF_0000_0000_FFFF_FFFF,
0x0000_FFFF_0000_FFFF_0000_FFFF_0000_FFFF,
0x00FF_00FF_00FF_00FF_00FF_00FF_00FF_00FF,
0x0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F_0F0F,
0x3333_3333_3333_3333_3333_3333_3333_3333,
0x5555_5555_5555_5555_5555_5555_5555_5555,
];
// `ONESn` has the least-significant bit of every `n`-bit field set, i.e. the
// value 1 replicated into each `n`-bit field of the 128-bit word.
/// The value 1 in every 2-bit field.
pub const ONES2: u128 = 0x5555_5555_5555_5555_5555_5555_5555_5555;
/// The value 1 in every 4-bit field.
pub const ONES4: u128 = 0x1111_1111_1111_1111_1111_1111_1111_1111;
/// The value 1 in every 8-bit field.
pub const ONES8: u128 = 0x0101_0101_0101_0101_0101_0101_0101_0101;
/// The value 1 in every 16-bit field.
pub const ONES16: u128 = 0x0001_0001_0001_0001_0001_0001_0001_0001;
/// The value 1 in every 32-bit field.
pub const ONES32: u128 = 0x0000_0001_0000_0001_0000_0001_0000_0001;
/// The value 1 in every 64-bit field.
pub const ONES64: u128 = 0x0000_0000_0000_0001_0000_0000_0000_0001;
// `SIGNSn` has the most-significant bit of every `n`-bit field set: the
// per-field 1 from `ONESn` shifted up by `n - 1`.
/// Top bit of every 2-bit field.
pub const SIGNS2: u128 = ONES2 << 1;
/// Top bit of every 4-bit field.
pub const SIGNS4: u128 = ONES4 << 3;
/// Top bit of every 8-bit field.
pub const SIGNS8: u128 = ONES8 << 7;
/// Top bit of every 16-bit field.
pub const SIGNS16: u128 = ONES16 << 15;
/// Top bit of every 32-bit field.
pub const SIGNS32: u128 = ONES32 << 31;
/// Top bit of every 64-bit field.
pub const SIGNS64: u128 = ONES64 << 63;
// `WEIGHT_MASKn` keeps the low `log2(n) + 1` bits of every `n`-bit field,
// which is exactly enough bits to represent a value in `0..=n`.
/// Low 2 bits of every 2-bit field, i.e. the whole word. Written as a literal
/// because the shift-and-subtract form used for the wider fields would
/// underflow for 2-bit fields.
pub const WEIGHT_MASK2: u128 = 0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF;
/// Low 3 bits of every 4-bit field.
pub const WEIGHT_MASK4: u128 = (ONES4 << 3) - ONES4;
/// Low 4 bits of every 8-bit field.
pub const WEIGHT_MASK8: u128 = (ONES8 << 4) - ONES8;
/// Low 5 bits of every 16-bit field.
pub const WEIGHT_MASK16: u128 = (ONES16 << 5) - ONES16;
/// Low 6 bits of every 32-bit field.
pub const WEIGHT_MASK32: u128 = (ONES32 << 6) - ONES32;
/// Low 7 bits of every 64-bit field.
pub const WEIGHT_MASK64: u128 = (ONES64 << 7) - ONES64;
// `WEIGHT_MSBn` sets bit `log2(n)` of every `n`-bit field, which is the
// highest bit kept by the matching `WEIGHT_MASKn`.
/// Bit 1 of every 2-bit field.
pub const WEIGHT_MSB2: u128 = ONES2 << 1;
/// Bit 2 of every 4-bit field.
pub const WEIGHT_MSB4: u128 = ONES4 << 2;
/// Bit 3 of every 8-bit field.
pub const WEIGHT_MSB8: u128 = ONES8 << 3;
/// Bit 4 of every 16-bit field.
pub const WEIGHT_MSB16: u128 = ONES16 << 4;
/// Bit 5 of every 32-bit field.
pub const WEIGHT_MSB32: u128 = ONES32 << 5;
/// Bit 6 of every 64-bit field.
pub const WEIGHT_MSB64: u128 = ONES64 << 6;
impl Bits1<u128> {
#[inline]
pub fn from_element(e: u128) -> Self {
let n1 = e | e << 1;
let n2 = n1 | n1 << 2;
let n3 = n2 | n2 << 4;
let n4 = n3 | n3 << 8;
let n5 = n4 | n4 << 16;
let n6 = n5 | n5 << 32;
let n7 = n6 | n6 << 64;
Self(n7)
}
#[inline]
pub fn any(self) -> Self {
self
}
#[inline]
pub fn union(left: Bits2<u128>, right: Bits2<u128>) -> Self {
let Bits2(left) = left;
let Bits2(right) = right;
let left = (left & LEFT_MASKS[5]) >> 1 | left & RIGHT_MASKS[5];
let left = (left & LEFT_MASKS[4]) >> 2 | left & RIGHT_MASKS[4];
let left = (left & LEFT_MASKS[3]) >> 4 | left & RIGHT_MASKS[3];
let left = (left & LEFT_MASKS[2]) >> 8 | left & RIGHT_MASKS[2];
let left = (left & LEFT_MASKS[1]) >> 16 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[0]) >> 32 | left & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[5]) >> 1 | right & RIGHT_MASKS[5];
let right = (right & LEFT_MASKS[4]) >> 2 | right & RIGHT_MASKS[4];
let right = (right & LEFT_MASKS[3]) >> 4 | right & RIGHT_MASKS[3];
let right = (right & LEFT_MASKS[2]) >> 8 | right & RIGHT_MASKS[2];
let right = (right & LEFT_MASKS[1]) >> 16 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[0]) >> 32 | right & RIGHT_MASKS[0];
Self(left << 64 | right)
}
#[inline]
pub fn count_ones(self) -> u32 {
self.0.count_ones()
}
#[inline]
pub fn pack_ones(self) -> Bits2<u128> {
let Self(x) = self;
let lower = (x & LEFT_MASKS[6]) >> 1 | x & RIGHT_MASKS[6];
let upper = x & (x & RIGHT_MASKS[6]) << 1;
Bits2(lower | upper)
}
#[inline]
pub fn sum_weight(self) -> u128 {
self.0.count_ones().into()
}
#[inline]
pub fn sum_weight2(self) -> Bits2<u128> {
let (left, right) = self.split();
left + right
}
#[inline]
pub fn minhwd(self, other: Self) -> Self {
Self(self.0 ^ other.0)
}
#[inline]
pub fn maxhwd(self, other: Self) -> Self {
Self(self.0 ^ other.0)
}
#[inline]
pub fn split(self) -> (Bits2<u128>, Bits2<u128>) {
let Self(n) = self;
(Bits2((n & LEFT_MASKS[6]) >> 1), Bits2(n & RIGHT_MASKS[6]))
}
#[inline]
pub fn halve(self) -> (Bits2<u128>, Bits2<u128>) {
let Self(n) = self;
let left = (n & LEFT_MASKS[0]) >> 64;
let left = (left & LEFT_MASKS[1]) << 32 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[2]) << 16 | left & RIGHT_MASKS[2];
let left = (left & LEFT_MASKS[3]) << 8 | left & RIGHT_MASKS[3];
let left = (left & LEFT_MASKS[4]) << 4 | left & RIGHT_MASKS[4];
let left = (left & LEFT_MASKS[5]) << 2 | left & RIGHT_MASKS[5];
let left = (left & LEFT_MASKS[6]) << 1 | left & RIGHT_MASKS[6];
let right = n & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[1]) << 32 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[2]) << 16 | right & RIGHT_MASKS[2];
let right = (right & LEFT_MASKS[3]) << 8 | right & RIGHT_MASKS[3];
let right = (right & LEFT_MASKS[4]) << 4 | right & RIGHT_MASKS[4];
let right = (right & LEFT_MASKS[5]) << 2 | right & RIGHT_MASKS[5];
let right = (right & LEFT_MASKS[6]) << 1 | right & RIGHT_MASKS[6];
(Bits2(left), Bits2(right))
}
}
impl Add for Bits1<u128> {
    type Output = Self;

    /// Adds the two wrapped words as plain `u128` integers.
    ///
    /// The addition is performed on the whole word, so carries are not
    /// confined to individual fields; callers must ensure per-field sums fit.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}
impl BitAnd<u128> for Bits1<u128> {
type Output = Self;
#[inline]
fn bitand(self, rhs: u128) -> Self {
Self(self.0 & rhs)
}
}
impl Shr<u32> for Bits1<u128> {
type Output = Self;
#[inline]
fn shr(self, rhs: u32) -> Self {
Self(self.0 >> rhs)
}
}
impl Bits2<u128> {
#[inline]
pub fn from_element(e: u128) -> Self {
let n2 = e | e << 2;
let n3 = n2 | n2 << 4;
let n4 = n3 | n3 << 8;
let n5 = n4 | n4 << 16;
let n6 = n5 | n5 << 32;
let n7 = n6 | n6 << 64;
Self(n7)
}
#[inline]
pub fn any(self) -> Self {
let Self(x) = self;
Self((x & LEFT_MASKS[6]) >> 1 | x & RIGHT_MASKS[6])
}
#[inline]
pub fn union(left: Bits4<u128>, right: Bits4<u128>) -> Self {
let Bits4(left) = left;
let Bits4(right) = right;
let left = (left & LEFT_MASKS[4]) >> 2 | left & RIGHT_MASKS[4];
let left = (left & LEFT_MASKS[3]) >> 4 | left & RIGHT_MASKS[3];
let left = (left & LEFT_MASKS[2]) >> 8 | left & RIGHT_MASKS[2];
let left = (left & LEFT_MASKS[1]) >> 16 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[0]) >> 32 | left & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[4]) >> 2 | right & RIGHT_MASKS[4];
let right = (right & LEFT_MASKS[3]) >> 4 | right & RIGHT_MASKS[3];
let right = (right & LEFT_MASKS[2]) >> 8 | right & RIGHT_MASKS[2];
let right = (right & LEFT_MASKS[1]) >> 16 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[0]) >> 32 | right & RIGHT_MASKS[0];
Self(left << 64 | right)
}
#[inline]
pub fn squash(self) -> Bits1<u128> {
Bits1::union(Self(0), self)
}
#[inline]
pub fn count_ones(self) -> u32 {
self.0.count_ones()
}
#[inline]
pub fn pack_ones(self) -> Bits4<u128> {
let Self(x) = self;
let x: u8x16 = unsafe { core::mem::transmute(x) };
let left = x & 0xF0;
let right = x & 0x0F;
let left_count = left.count_ones();
let right_count = right.count_ones();
let one = u8x16::from([1; 16]);
let left_out = (one << left_count) - 1;
let right_out = (one << right_count) - 1;
let x = left_out << 4 | right_out;
Bits4(unsafe { core::mem::transmute(x) })
}
#[inline]
pub fn sum_weight(self) -> u128 {
self.sum_weight2()
.sum_weight2()
.sum_weight2()
.sum_weight2()
.sum_weight2()
.sum_weight2()
.0
}
#[inline]
pub fn sum_weight2(self) -> Bits4<u128> {
let (left, right) = self.split();
left + right
}
#[inline]
pub fn minhwd(self, other: Self) -> Self {
let Self(a) = self;
let Self(b) = other;
let low = RIGHT_MASKS[6] & (a ^ b);
let high = LEFT_MASKS[6] & (b & !a & !a << 1 | a & !b & !b << 1);
Self(low | high)
}
#[inline]
pub fn maxhwd(self, other: Self) -> Self {
let Self(a) = self;
let Self(b) = other;
let low = RIGHT_MASKS[6] & (a ^ b);
let high = LEFT_MASKS[6] & (b & !a & !a << 1 | a & !b & !b << 1 | a << 1 & b << 1);
Self(low | high)
}
#[inline]
pub fn split(self) -> (Bits4<u128>, Bits4<u128>) {
let Self(n) = self;
(Bits4((n & LEFT_MASKS[5]) >> 2), Bits4(n & RIGHT_MASKS[5]))
}
#[inline]
pub fn halve(self) -> (Bits4<u128>, Bits4<u128>) {
let Self(n) = self;
let left = (n & LEFT_MASKS[0]) >> 64;
let left = (left & LEFT_MASKS[1]) << 32 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[2]) << 16 | left & RIGHT_MASKS[2];
let left = (left & LEFT_MASKS[3]) << 8 | left & RIGHT_MASKS[3];
let left = (left & LEFT_MASKS[4]) << 4 | left & RIGHT_MASKS[4];
let left = (left & LEFT_MASKS[5]) << 2 | left & RIGHT_MASKS[5];
let right = n & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[1]) << 32 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[2]) << 16 | right & RIGHT_MASKS[2];
let right = (right & LEFT_MASKS[3]) << 8 | right & RIGHT_MASKS[3];
let right = (right & LEFT_MASKS[4]) << 4 | right & RIGHT_MASKS[4];
let right = (right & LEFT_MASKS[5]) << 2 | right & RIGHT_MASKS[5];
(Bits4(left), Bits4(right))
}
}
impl Add for Bits2<u128> {
    type Output = Self;

    /// Adds the two wrapped words as plain `u128` integers.
    ///
    /// The addition is performed on the whole word, so carries are not
    /// confined to individual fields; callers must ensure per-field sums fit.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}
impl BitAnd<u128> for Bits2<u128> {
type Output = Self;
#[inline]
fn bitand(self, rhs: u128) -> Self {
Self(self.0 & rhs)
}
}
impl Shr<u32> for Bits2<u128> {
type Output = Self;
#[inline]
fn shr(self, rhs: u32) -> Self {
Self(self.0 >> rhs)
}
}
impl Bits4<u128> {
#[inline]
pub fn from_element(e: u128) -> Self {
let n3 = e | e << 4;
let n4 = n3 | n3 << 8;
let n5 = n4 | n4 << 16;
let n6 = n5 | n5 << 32;
let n7 = n6 | n6 << 64;
Self(n7)
}
#[inline]
pub fn any(self) -> Self {
let Bits2(x) = Bits2(self.0).any();
Self((x & LEFT_MASKS[5]) >> 2 | x & RIGHT_MASKS[5])
}
#[inline]
pub fn union(left: Bits8<u128>, right: Bits8<u128>) -> Self {
let Bits8(left) = left;
let Bits8(right) = right;
let left = (left & LEFT_MASKS[3]) >> 4 | left & RIGHT_MASKS[3];
let left = (left & LEFT_MASKS[2]) >> 8 | left & RIGHT_MASKS[2];
let left = (left & LEFT_MASKS[1]) >> 16 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[0]) >> 32 | left & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[3]) >> 4 | right & RIGHT_MASKS[3];
let right = (right & LEFT_MASKS[2]) >> 8 | right & RIGHT_MASKS[2];
let right = (right & LEFT_MASKS[1]) >> 16 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[0]) >> 32 | right & RIGHT_MASKS[0];
Self(left << 64 | right)
}
#[inline]
pub fn squash(self) -> Bits2<u128> {
Bits2::union(Self(0), self)
}
#[inline]
pub fn count_ones(self) -> u32 {
self.0.count_ones()
}
#[inline]
pub fn pack_ones(self) -> Bits8<u128> {
let Self(x) = self;
let x: u8x16 = unsafe { core::mem::transmute(x) };
let counted = x.count_ones();
let x = (u8x16::from([1; 16]) << (counted % 8) ^ counted >> 3) - 1;
Bits8(unsafe { core::mem::transmute(x) })
}
#[inline]
pub fn sum_weight(self) -> u128 {
self.sum_weight2()
.sum_weight2()
.sum_weight2()
.sum_weight2()
.sum_weight2()
.0
}
#[inline]
pub fn sum_weight2(self) -> Bits8<u128> {
let (left, right) = self.split();
left + right
}
#[inline]
pub fn minhwd(self, other: Self) -> Self {
let Self(a) = self;
let Self(b) = other;
let m = a + (b ^ WEIGHT_MASK4);
let high = m & WEIGHT_MSB4;
let offset = (high ^ WEIGHT_MSB4) >> 2;
let flips = high | high >> 1 | high >> 2;
Self(((m ^ flips) + offset) & WEIGHT_MASK4)
}
#[inline]
pub fn split(self) -> (Bits8<u128>, Bits8<u128>) {
let Self(n) = self;
(Bits8((n & LEFT_MASKS[4]) >> 4), Bits8(n & RIGHT_MASKS[4]))
}
#[inline]
pub fn halve(self) -> (Bits8<u128>, Bits8<u128>) {
let Self(n) = self;
let left = (n & LEFT_MASKS[0]) >> 64;
let left = (left & LEFT_MASKS[1]) << 32 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[2]) << 16 | left & RIGHT_MASKS[2];
let left = (left & LEFT_MASKS[3]) << 8 | left & RIGHT_MASKS[3];
let left = (left & LEFT_MASKS[4]) << 4 | left & RIGHT_MASKS[4];
let right = n & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[1]) << 32 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[2]) << 16 | right & RIGHT_MASKS[2];
let right = (right & LEFT_MASKS[3]) << 8 | right & RIGHT_MASKS[3];
let right = (right & LEFT_MASKS[4]) << 4 | right & RIGHT_MASKS[4];
(Bits8(left), Bits8(right))
}
}
impl Add for Bits4<u128> {
    type Output = Self;

    /// Adds the two wrapped words as plain `u128` integers.
    ///
    /// The addition is performed on the whole word, so carries are not
    /// confined to individual fields; callers must ensure per-field sums fit.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}
impl BitAnd<u128> for Bits4<u128> {
type Output = Self;
#[inline]
fn bitand(self, rhs: u128) -> Self {
Self(self.0 & rhs)
}
}
impl Shr<u32> for Bits4<u128> {
type Output = Self;
#[inline]
fn shr(self, rhs: u32) -> Self {
Self(self.0 >> rhs)
}
}
impl Bits8<u128> {
#[inline]
pub fn from_element(e: u128) -> Self {
let n4 = e | e << 8;
let n5 = n4 | n4 << 16;
let n6 = n5 | n5 << 32;
let n7 = n6 | n6 << 64;
Self(n7)
}
#[inline]
pub fn any(self) -> Self {
let Bits4(x) = Bits4(self.0).any();
Self((x & LEFT_MASKS[4]) >> 4 | x & RIGHT_MASKS[4])
}
#[inline]
pub fn union(left: Bits16<u128>, right: Bits16<u128>) -> Self {
let Bits16(left) = left;
let Bits16(right) = right;
let left = (left & LEFT_MASKS[2]) >> 8 | left & RIGHT_MASKS[2];
let left = (left & LEFT_MASKS[1]) >> 16 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[0]) >> 32 | left & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[2]) >> 8 | right & RIGHT_MASKS[2];
let right = (right & LEFT_MASKS[1]) >> 16 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[0]) >> 32 | right & RIGHT_MASKS[0];
Self(left << 64 | right)
}
#[inline]
pub fn squash(self) -> Bits4<u128> {
Bits4::union(Self(0), self)
}
#[inline]
pub fn count_ones(self) -> u32 {
self.0.count_ones()
}
#[inline]
pub fn pack_ones(self) -> Bits16<u128> {
let Self(x) = self;
let x: u16x8 = unsafe { core::mem::transmute(x) };
let counted = x.count_ones();
let x = (u16x8::from([1; 8]) << (counted % 16) ^ counted >> 4) - 1;
Bits16(unsafe { core::mem::transmute(x) })
}
#[inline]
pub fn sum_weight(self) -> u128 {
self.sum_weight2()
.sum_weight2()
.sum_weight2()
.sum_weight2()
.0
}
#[inline]
pub fn sum_weight2(self) -> Bits16<u128> {
let (left, right) = self.split();
left + right
}
#[inline]
pub fn minhwd(self, other: Self) -> Self {
let Self(a) = self;
let Self(b) = other;
let m = a + (b ^ WEIGHT_MASK8);
let high = m & WEIGHT_MSB8;
let offset = (high ^ WEIGHT_MSB8) >> 3;
let flips = high | high >> 1;
let flips = flips | flips >> 2;
Self(((m ^ flips) + offset) & WEIGHT_MASK8)
}
#[inline]
pub fn split(self) -> (Bits16<u128>, Bits16<u128>) {
let Self(n) = self;
(Bits16((n & LEFT_MASKS[3]) >> 8), Bits16(n & RIGHT_MASKS[3]))
}
#[inline]
pub fn halve(self) -> (Bits16<u128>, Bits16<u128>) {
let Self(n) = self;
let left = (n & LEFT_MASKS[0]) >> 64;
let left = (left & LEFT_MASKS[1]) << 32 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[2]) << 16 | left & RIGHT_MASKS[2];
let left = (left & LEFT_MASKS[3]) << 8 | left & RIGHT_MASKS[3];
let right = n & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[1]) << 32 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[2]) << 16 | right & RIGHT_MASKS[2];
let right = (right & LEFT_MASKS[3]) << 8 | right & RIGHT_MASKS[3];
(Bits16(left), Bits16(right))
}
}
impl Add for Bits8<u128> {
    type Output = Self;

    /// Adds the two wrapped words as plain `u128` integers.
    ///
    /// The addition is performed on the whole word, so carries are not
    /// confined to individual fields; callers must ensure per-field sums fit.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}
impl BitAnd<u128> for Bits8<u128> {
type Output = Self;
#[inline]
fn bitand(self, rhs: u128) -> Self {
Self(self.0 & rhs)
}
}
impl Shr<u32> for Bits8<u128> {
type Output = Self;
#[inline]
fn shr(self, rhs: u32) -> Self {
Self(self.0 >> rhs)
}
}
impl Bits16<u128> {
#[inline]
pub fn from_element(e: u128) -> Self {
let n5 = e | e << 16;
let n6 = n5 | n5 << 32;
let n7 = n6 | n6 << 64;
Self(n7)
}
#[inline]
pub fn any(self) -> Self {
let Bits8(x) = Bits8(self.0).any();
Self((x & LEFT_MASKS[3]) >> 8 | x & RIGHT_MASKS[3])
}
#[inline]
pub fn union(left: Bits32<u128>, right: Bits32<u128>) -> Self {
let Bits32(left) = left;
let Bits32(right) = right;
let left = (left & LEFT_MASKS[1]) >> 16 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[0]) >> 32 | left & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[1]) >> 16 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[0]) >> 32 | right & RIGHT_MASKS[0];
Self(left << 64 | right)
}
#[inline]
pub fn squash(self) -> Bits8<u128> {
Bits8::union(Self(0), self)
}
#[inline]
pub fn count_ones(self) -> u32 {
self.0.count_ones()
}
#[inline]
pub fn pack_ones(self) -> Bits32<u128> {
let Self(x) = self;
let x: u32x4 = unsafe { core::mem::transmute(x) };
let counted = x.count_ones();
let x = (u32x4::from([1; 4]) << (counted % 32) ^ counted >> 5) - 1;
Bits32(unsafe { core::mem::transmute(x) })
}
#[inline]
pub fn sum_weight(self) -> u128 {
self.sum_weight2().sum_weight2().sum_weight2().0
}
#[inline]
pub fn sum_weight2(self) -> Bits32<u128> {
let (left, right) = self.split();
left + right
}
#[inline]
pub fn minhwd(self, other: Self) -> Self {
let Self(a) = self;
let Self(b) = other;
let m = a + (b ^ WEIGHT_MASK16);
let high = m & WEIGHT_MSB16;
let offset = (high ^ WEIGHT_MSB16) >> 4;
let flips = high | high >> 1;
let flips = flips | flips >> 2 | high >> 4;
Self(((m ^ flips) + offset) & WEIGHT_MASK16)
}
#[inline]
pub fn split(self) -> (Bits32<u128>, Bits32<u128>) {
let Self(n) = self;
(
Bits32((n & LEFT_MASKS[2]) >> 16),
Bits32(n & RIGHT_MASKS[2]),
)
}
#[inline]
pub fn halve(self) -> (Bits32<u128>, Bits32<u128>) {
let Self(n) = self;
let left = (n & LEFT_MASKS[0]) >> 64;
let left = (left & LEFT_MASKS[1]) << 32 | left & RIGHT_MASKS[1];
let left = (left & LEFT_MASKS[2]) << 16 | left & RIGHT_MASKS[2];
let right = n & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[1]) << 32 | right & RIGHT_MASKS[1];
let right = (right & LEFT_MASKS[2]) << 16 | right & RIGHT_MASKS[2];
(Bits32(left), Bits32(right))
}
}
impl Add for Bits16<u128> {
    type Output = Self;

    /// Adds the two wrapped words as plain `u128` integers.
    ///
    /// The addition is performed on the whole word, so carries are not
    /// confined to individual fields; callers must ensure per-field sums fit.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}
impl BitAnd<u128> for Bits16<u128> {
type Output = Self;
#[inline]
fn bitand(self, rhs: u128) -> Self {
Self(self.0 & rhs)
}
}
impl Shr<u32> for Bits16<u128> {
type Output = Self;
#[inline]
fn shr(self, rhs: u32) -> Self {
Self(self.0 >> rhs)
}
}
impl Bits32<u128> {
#[inline]
pub fn from_element(e: u128) -> Self {
let n6 = e | e << 32;
let n7 = n6 | n6 << 64;
Self(n7)
}
#[inline]
pub fn any(self) -> Self {
let Bits16(x) = Bits16(self.0).any();
Self((x & LEFT_MASKS[2]) >> 16 | x & RIGHT_MASKS[2])
}
#[inline]
pub fn union(left: Bits64<u128>, right: Bits64<u128>) -> Self {
let Bits64(left) = left;
let Bits64(right) = right;
let left = (left & LEFT_MASKS[0]) >> 32 | left & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[0]) >> 32 | right & RIGHT_MASKS[0];
Self(left << 64 | right)
}
#[inline]
pub fn squash(self) -> Bits16<u128> {
Bits16::union(Self(0), self)
}
#[inline]
pub fn count_ones(self) -> u32 {
self.0.count_ones()
}
#[inline]
pub fn pack_ones(self) -> Bits64<u128> {
let Self(x) = self;
let x: u64x2 = unsafe { core::mem::transmute(x) };
let counted = x.count_ones();
let x = (u64x2::from([1; 2]) << (counted % 64) ^ counted >> 6) - 1;
Bits64(unsafe { core::mem::transmute(x) })
}
#[inline]
pub fn sum_weight(self) -> u128 {
self.sum_weight2().sum_weight2().0
}
#[inline]
pub fn sum_weight2(self) -> Bits64<u128> {
let (left, right) = self.split();
left + right
}
#[inline]
pub fn minhwd(self, other: Self) -> Self {
let Self(a) = self;
let Self(b) = other;
let m = a + (b ^ WEIGHT_MASK32);
let high = m & WEIGHT_MSB32;
let offset = (high ^ WEIGHT_MSB32) >> 5;
let flips = high | high >> 1;
let flips = flips | flips >> 2;
let flips = flips | flips >> 2;
Self(((m ^ flips) + offset) & WEIGHT_MASK32)
}
#[inline]
pub fn split(self) -> (Bits64<u128>, Bits64<u128>) {
let Self(n) = self;
(
Bits64((n & LEFT_MASKS[1]) >> 32),
Bits64(n & RIGHT_MASKS[1]),
)
}
#[inline]
pub fn halve(self) -> (Bits64<u128>, Bits64<u128>) {
let Self(n) = self;
let left = (n & LEFT_MASKS[0]) >> 64;
let left = (left & LEFT_MASKS[1]) << 32 | left & RIGHT_MASKS[1];
let right = n & RIGHT_MASKS[0];
let right = (right & LEFT_MASKS[1]) << 32 | right & RIGHT_MASKS[1];
(Bits64(left), Bits64(right))
}
}
impl Add for Bits32<u128> {
    type Output = Self;

    /// Adds the two wrapped words as plain `u128` integers.
    ///
    /// The addition is performed on the whole word, so carries are not
    /// confined to individual fields; callers must ensure per-field sums fit.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}
impl BitAnd<u128> for Bits32<u128> {
type Output = Self;
#[inline]
fn bitand(self, rhs: u128) -> Self {
Self(self.0 & rhs)
}
}
impl Shr<u32> for Bits32<u128> {
type Output = Self;
#[inline]
fn shr(self, rhs: u32) -> Self {
Self(self.0 >> rhs)
}
}
impl Bits64<u128> {
    /// Broadcasts a 64-bit element into both lanes by OR-ing `e` with itself
    /// shifted left by 64 bits.
    #[inline]
    pub fn from_element(e: u128) -> Self {
        Self(e | e << 64)
    }

    /// Sets the low bit of each 64-bit lane that has any bit set; all other
    /// bits of the lane are cleared.
    #[inline]
    pub fn any(self) -> Self {
        // First reduce each 32-bit half, then fold the upper half onto the lower.
        let Bits32(x) = Bits32(self.0).any();
        Self((x & LEFT_MASKS[1]) >> 32 | x & RIGHT_MASKS[1])
    }

    /// Combines two words into one: `left` becomes the upper 64-bit lane and
    /// `right` the lower.
    ///
    /// `right` is OR-ed in unmasked, so it is expected to fit in 64 bits.
    #[inline]
    pub fn union(left: Bits128<u128>, right: Bits128<u128>) -> Self {
        let Bits128(left) = left;
        let Bits128(right) = right;
        Self(left << 64 | right)
    }

    /// Narrows to 32-bit fields by calling `Bits32::union` with an all-zero
    /// `left` operand, so the result occupies the low 64 bits.
    #[inline]
    pub fn squash(self) -> Bits32<u128> {
        Bits32::union(Self(0), self)
    }

    /// Number of set bits in the whole word.
    #[inline]
    pub fn count_ones(self) -> u32 {
        self.0.count_ones()
    }

    /// Moves all set bits of the 128-bit word to its low end.
    ///
    /// A word with `c` set bits becomes a word whose low `c` bits are set,
    /// including the all-ones input, which maps to all ones.
    #[inline]
    pub fn pack_ones(self) -> Bits128<u128> {
        let Self(x) = self;
        let ones = x.count_ones();
        // `1 << 128` would overflow the shift (panic in debug builds, wrong
        // result otherwise), so the full word is handled explicitly.
        let packed = if ones >= 128 {
            !0u128
        } else {
            (1u128 << ones) - 1
        };
        Bits128(packed)
    }

    /// Sum of the two 64-bit lanes.
    #[inline]
    pub fn sum_weight(self) -> u128 {
        self.sum_weight2().0
    }

    /// Adds the upper and lower 64-bit lanes into a single value.
    #[inline]
    pub fn sum_weight2(self) -> Bits128<u128> {
        let (left, right) = self.split();
        left + right
    }

    /// Per-lane absolute difference `|a - b|`.
    ///
    /// Only the low 32 bits of each lane take part: the casts to `i32`
    /// truncate. Presumably fine because lanes hold weights in `0..=64` —
    /// TODO confirm with callers.
    #[inline]
    pub fn minhwd(self, other: Self) -> Self {
        let Self(a) = self;
        let Self(b) = other;
        let a_low = a as i32;
        let a_high = (a >> 64) as i32;
        let b_low = b as i32;
        let b_high = (b >> 64) as i32;
        Self(((a_high - b_high).abs() as u128) << 64 | (a_low - b_low).abs() as u128)
    }

    /// Separates the two 64-bit lanes.
    ///
    /// Returns `(upper lane shifted down by sixty-four, lower lane)`.
    #[inline]
    pub fn split(self) -> (Bits128<u128>, Bits128<u128>) {
        let Self(n) = self;
        (
            Bits128((n & LEFT_MASKS[0]) >> 64),
            Bits128(n & RIGHT_MASKS[0]),
        )
    }

    /// Returns `(upper 64 bits, lower 64 bits)`, each in the low half of its
    /// own word. At this width the result is the same as `split`.
    #[inline]
    pub fn halve(self) -> (Bits128<u128>, Bits128<u128>) {
        let Self(n) = self;
        let left = (n & LEFT_MASKS[0]) >> 64;
        let right = n & RIGHT_MASKS[0];
        (Bits128(left), Bits128(right))
    }
}
impl Add for Bits64<u128> {
    type Output = Self;

    /// Adds the two wrapped words as plain `u128` integers.
    ///
    /// The addition is performed on the whole word, so carries are not
    /// confined to individual lanes; callers must ensure per-lane sums fit.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}
impl BitAnd<u128> for Bits64<u128> {
type Output = Self;
#[inline]
fn bitand(self, rhs: u128) -> Self {
Self(self.0 & rhs)
}
}
impl Shr<u32> for Bits64<u128> {
type Output = Self;
#[inline]
fn shr(self, rhs: u32) -> Self {
Self(self.0 >> rhs)
}
}
impl Bits128<u128> {
#[inline]
pub fn any(self) -> Self {
let Bits64(x) = Bits64(self.0).any();
Self((x & LEFT_MASKS[0]) >> 64 | x & RIGHT_MASKS[0])
}
#[inline]
pub fn squash(self) -> Bits64<u128> {
Bits64::union(Self(0), self)
}
#[inline]
pub fn count_ones(self) -> u32 {
self.0.count_ones()
}
#[inline]
pub fn sum_weight(self) -> u128 {
self.0
}
#[inline]
pub fn minhwd(self, other: Self) -> Self {
Self((self.0 as i32 - other.0 as i32).abs() as u128)
}
}
impl From<Bits128<u128>> for u128 {
#[inline]
fn from(n: Bits128<u128>) -> u128 {
n.0
}
}
impl BitAnd<u128> for Bits128<u128> {
type Output = Self;
#[inline]
fn bitand(self, rhs: u128) -> Self {
Self(self.0 & rhs)
}
}
impl Shr<u32> for Bits128<u128> {
type Output = Self;
#[inline]
fn shr(self, rhs: u32) -> Self {
Self(self.0 >> rhs)
}
}
impl Add for Bits128<u128> {
    type Output = Self;

    /// Adds the two wrapped words as plain `u128` integers.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self(self.0 + rhs.0)
    }
}