#ifndef _MINISKETCH_INT_UTILS_H_
#define _MINISKETCH_INT_UTILS_H_
#include <stdint.h>
#include <stdlib.h>
#include <algorithm>
#include <limits>
#include <type_traits>
/** Rotate the 64-bit value x left by `bits` positions.
 *
 *  `bits` must lie strictly between 0 and 64: for 0 or 64 one of the two
 *  shifts below would be a shift by the full width of uint64_t, which is
 *  undefined behaviour. The static_assert turns such an instantiation into a
 *  compile-time error instead.
 */
template<int bits>
static constexpr inline uint64_t Rot(uint64_t x) {
    static_assert(bits > 0 && bits < 64, "Rot requires 0 < bits < 64");
    return (x << bits) | (x >> (64 - bits));
}
/** Apply one mixing round to the four 64-bit state words, in place.
 *
 *  The round consists only of additions, rotations (via Rot<>) and XORs.
 *  The statement order matters: each step consumes values produced by the
 *  steps before it.
 */
static inline void SipHashRound(uint64_t& v0, uint64_t& v1, uint64_t& v2, uint64_t& v3) {
    // First half: mix v1 into v0, and v3 into v2.
    v0 += v1;
    v1 = Rot<13>(v1);
    v1 ^= v0;
    v0 = Rot<32>(v0);

    v2 += v3;
    v3 = Rot<16>(v3);
    v3 ^= v2;

    // Second half: cross the pairs (v3 into v0, v1 into v2).
    v0 += v3;
    v3 = Rot<21>(v3);
    v3 ^= v0;

    v2 += v1;
    v1 = Rot<17>(v1);
    v1 ^= v2;
    v2 = Rot<32>(v2);
}
/** Keyed hash of a single 64-bit value.
 *
 *  @param k0,k1  the two 64-bit halves of the key
 *  @param data   the 64-bit value to hash
 *  @return       64-bit hash: the XOR of the four state words after mixing
 *
 *  The round schedule is 2 rounds after absorbing `data`, 2 rounds after
 *  absorbing a fixed final word, and 4 finalization rounds.
 *  NOTE(review): this looks like SipHash-2-4 specialised to one 8-byte
 *  message, with 0x0800000000000000 (8 << 56) presumably encoding the
 *  message length -- confirm against the SipHash specification.
 */
inline uint64_t SipHash(uint64_t k0, uint64_t k1, uint64_t data) {
    // Key the four state words with fixed initialization constants.
    uint64_t v0 = k0 ^ 0x736f6d6570736575ULL;
    uint64_t v1 = k1 ^ 0x646f72616e646f6dULL;
    uint64_t v2 = k0 ^ 0x6c7967656e657261ULL;
    uint64_t v3 = k1 ^ 0x7465646279746573ULL;

    // Absorb the data word.
    v3 ^= data;
    for (int round = 0; round < 2; ++round) {
        SipHashRound(v0, v1, v2, v3);
    }
    v0 ^= data;

    // Absorb the fixed final word.
    v3 ^= 0x800000000000000ULL;
    for (int round = 0; round < 2; ++round) {
        SipHashRound(v0, v1, v2, v3);
    }
    v0 ^= 0x800000000000000ULL;

    // Finalization.
    v2 ^= 0xFF;
    for (int round = 0; round < 4; ++round) {
        SipHashRound(v0, v1, v2, v3);
    }
    return v0 ^ v1 ^ v2 ^ v3;
}
/** Packs fixed-width integers into a byte buffer, least significant bit first.
 *
 *  Bits that do not yet fill a whole byte are held in `state`; `offset` is the
 *  number of bits held (always below 8 between calls). The writer performs no
 *  bounds checking on the output buffer; the caller must provide enough room.
 */
class BitWriter {
    unsigned char state = 0;
    int offset = 0;
    unsigned char* out;

public:
    /** Start writing at `output`. */
    BitWriter(unsigned char* output) : out(output) {}

    /** Append the low BITS bits of `val` to the stream. */
    template<int BITS, typename I>
    inline void Write(I val) {
        int remaining = BITS;

        // If the pending partial byte can be completed, top it up and emit it.
        if (remaining + offset >= 8) {
            const int room = 8 - offset;
            state |= ((val & ((I(1) << room) - 1)) << offset);
            *out = state;
            ++out;
            val >>= room;
            remaining -= room;
            state = 0;
            offset = 0;
        }

        // Emit as many whole bytes as are left.
        for (; remaining >= 8; remaining -= 8) {
            *out = val & 255;
            ++out;
            val >>= 8;
        }

        // Keep the leftover (fewer than 8) bits for the next call.
        state |= ((val & ((I(1) << remaining) - 1)) << offset);
        offset += remaining;
    }

    /** Emit the pending partial byte, if any; its unused high bits are zero. */
    inline void Flush() {
        if (offset == 0) return;
        *out = state;
        ++out;
        state = 0;
        offset = 0;
    }
};
/** Extracts fixed-width integers from a byte buffer, least significant bit first.
 *
 *  `state` holds the not-yet-consumed high bits of the most recently read
 *  byte and `offset` is how many such bits are available. The reader performs
 *  no bounds checking on the input buffer.
 */
class BitReader {
    unsigned char state = 0;
    int offset = 0;
    const unsigned char* in;

public:
    /** Start reading at `input`. */
    BitReader(const unsigned char* input) : in(input) {}

    /** Read the next BITS bits from the stream and return them as an I. */
    template<int BITS, typename I>
    inline I Read() {
        const int want = BITS;

        // Fast path: the buffered bits alone satisfy the request.
        if (want <= offset) {
            const I result = state & ((1 << want) - 1);
            state >>= want;
            offset -= want;
            return result;
        }

        // Start from the buffered bits, then append whole bytes while they fit.
        I acc = state;
        int have = offset;
        for (; have + 8 <= want; have += 8) {
            acc |= ((I(*in)) << have);
            ++in;
        }

        if (have >= want) {
            // The request ended exactly on a byte boundary; nothing is buffered.
            state = 0;
            offset = 0;
            return acc;
        }

        // Take the low `need` bits of one more byte and buffer the rest of it.
        const int need = want - have;
        const unsigned char next = *in;
        ++in;
        acc |= (next & ((I(1) << need) - 1)) << have;
        state = next >> need;
        offset = 8 - need;
        return acc;
    }
};
/** Return a value of type I whose low BITS bits are set and all others clear.
 *
 *  Built by shifting an all-ones value up so that only BITS ones survive the
 *  conversion back to I, then shifting back down. The shift distance is
 *  (digits of I) - BITS, so BITS is expected to be in [1, digits of I].
 */
template<int BITS, typename I>
constexpr inline I Mask() {
    using Limits = std::numeric_limits<I>;
    return I(I(~I(0)) << (Limits::digits - BITS)) >> (Limits::digits - BITS);
}
/** Collection of static helpers for BITS-bit values stored in the unsigned
 *  integer type I.
 *
 *  The class holds no data; it is meant to be passed as a policy type (see
 *  the `F` template parameters elsewhere in this file).
 */
template<typename I, int BITS>
class BitsInt {
private:
    static_assert(std::is_integral<I>::value && std::is_unsigned<I>::value, "BitsInt requires an unsigned integer type");
    static_assert(BITS > 0 && BITS <= std::numeric_limits<I>::digits, "BitsInt requires 1 <= Bits <= representation type size");

    /** Value with exactly the low BITS bits set. */
    static constexpr I MASK = Mask<BITS, I>();

public:
    /** The underlying representation type. */
    typedef I Repr;

    /** Number of significant bits. */
    static constexpr int SIZE = BITS;

    /** Exchange a and b. */
    static void inline Swap(I& a, I& b) {
        std::swap(a, b);
    }

    /** Whether a equals zero. */
    static constexpr inline bool IsZero(I a) { return a == 0; }

    /** Clear every bit of val above the low BITS bits. */
    static constexpr inline I Mask(I val) { return val & MASK; }

    /** Shift val left by `bits` and drop anything that leaves the low BITS bits.
     *  `bits` must be smaller than the width of I (after integer promotion). */
    static constexpr inline I Shift(I val, int bits) { return ((val << bits) & MASK); }

    /** Shift val left by `bits` without masking; the result may have bits set
     *  above position BITS - 1. */
    static constexpr inline I UnsafeShift(I val, int bits) { return (val << bits); }

    /** Extract Count bits of val starting at bit position Offset. */
    template<int Offset, int Count>
    static constexpr inline int MidBits(I val) {
        static_assert(Count > 0, "BitsInt::MidBits needs Count > 0");
        static_assert(Count + Offset <= BITS, "BitsInt::MidBits overflow of Count+Offset");
        return (val >> Offset) & ((I(1) << Count) - 1);
    }

    /** Extract the Count highest of the BITS bits of val.
     *  No masking is applied, so val must not have bits set at or above
     *  position BITS. */
    template<int Count>
    static constexpr inline int TopBits(I val) {
        static_assert(Count > 0, "BitsInt::TopBits needs Count > 0");
        static_assert(Count <= BITS, "BitsInt::TopBits needs Count <= BITS");
        return val >> (BITS - Count);
    }

    /** Return val ^ v if cond is true, and val otherwise, without branching:
     *  -I(cond) is all-ones for true and zero for false. */
    static inline constexpr I CondXorWith(I val, bool cond, I v) {
        return val ^ (-I(cond) & v);
    }

    /** Same as the three-argument CondXorWith, with the XOR operand given as
     *  the compile-time constant MOD. */
    template<I MOD>
    static inline constexpr I CondXorWith(I val, bool cond) {
        return val ^ (-I(cond) & MOD);
    }

    /** Return the bit length of val: the position of its highest set bit plus
     *  one, or 0 when val is 0.
     *
     *  With HAVE_CLZ the count-leading-zeros builtin of the narrowest
     *  sufficiently wide type is used and `max` is ignored. Otherwise the
     *  search walks down from bit max - 1, so `max` must be an upper bound on
     *  the bit length of val (and at most the width of I). */
    static inline int Bits(I val, int max) {
#ifdef HAVE_CLZ
        (void)max;
        if (val == 0) return 0;
        if (std::numeric_limits<unsigned>::digits >= std::numeric_limits<I>::digits) {
            return std::numeric_limits<unsigned>::digits - __builtin_clz(val);
        } else if (std::numeric_limits<unsigned long>::digits >= std::numeric_limits<I>::digits) {
            return std::numeric_limits<unsigned long>::digits - __builtin_clzl(val);
        } else {
            return std::numeric_limits<unsigned long long>::digits - __builtin_clzll(val);
        }
#else
        while (max && (val >> (max - 1) == 0)) --max;
        return max;
#endif
    }
};
/** One step of a linear feedback shift register over the representation of F.
 *
 *  Call(a) shifts a left by one position (via F::Shift) and XORs in MOD when
 *  the top bit of a (F::TopBits<1>) was set.
 */
template<typename F, uint32_t MOD>
struct LFSR {
    using I = typename F::Repr;

    static inline constexpr I Call(const I& a) {
        return F::template CondXorWith<MOD>(
            F::Shift(a, 1),
            F::template TopBits<1>(a));
    }
};
/** Compile-time unrolled helper for GFMul.
 *
 *  GFMulHelper<I, N, L, F, K>::Run(a, b) handles bits N-K .. N-1 of b: it
 *  XORs `a` into the result when bit N-K of b is set, and recurses on
 *  L::Call(a) for the remaining K-1 bits. Starting from K == N this yields
 *  the XOR, over every set bit i of b, of L::Call applied i times to a.
 */
template<typename I, int N, typename L, typename F, int K> struct GFMulHelper;

/** Recursion end: no bits of b left, contributes nothing. */
template<typename I, int N, typename L, typename F>
struct GFMulHelper<I, N, L, F, 0>
{
    static inline constexpr I Run(const I&, const I&) { return I(0); }
};

/** Recursive step for K remaining bits. */
template<typename I, int N, typename L, typename F, int K>
struct GFMulHelper
{
    static inline constexpr I Run(const I& a, const I& b)
    {
        return F::CondXorWith(
            GFMulHelper<I, N, L, F, K - 1>::Run(L::Call(a), b),
            F::template MidBits<N - K, 1>(b),
            a);
    }
};
template<typename I, int N, typename L, typename F> inline constexpr I GFMul(const I& a, const I& b) { return GFMulHelper<I, N, L, F, N>::Run(a, b); }
/** Compute an inverse of x using a shift-and-XOR extended Euclidean algorithm.
 *
 *  @tparam I     integer-like representation type
 *  @tparam F     policy providing IsZero, Shift, UnsafeShift, Bits and Swap
 *  @tparam BITS  number of significant bits in an element
 *  @tparam MOD   initial value of the remainder `r`. NOTE(review): presumably
 *                the modulus polynomial without its x^BITS term, since the
 *                initial length of r is taken to be BITS + 1 -- confirm
 *                against the callers.
 *  @param x      element to invert
 *  @return       x itself when x is 0 or 1, otherwise the value accumulated in
 *                `t` once the remainder has been reduced to zero
 */
template<typename I, typename F, int BITS, uint32_t MOD>
inline I InvExtGCD(I x)
{
    // 0 is returned unchanged and 1 is its own inverse. Returning early for 1
    // is required for correctness: with newr == 1 the first iteration below
    // would use q == BITS, and shifting by BITS is undefined behaviour when
    // BITS equals the width of I (e.g. 32 bits in a uint32_t).
    if (F::IsZero(x) || x == I(1)) return x;

    I t(0), newt(1);
    I r(MOD), newr = x;
    // rlen / newrlen track the bit lengths of r / newr.
    int rlen = BITS + 1, newrlen = F::Bits(newr, BITS);
    while (newr) {
        // Align the top bit of newr with the top bit of r and cancel it.
        int q = rlen - newrlen;
        r ^= F::Shift(newr, q);
        t ^= F::UnsafeShift(newt, q);
        // The top bit of r was just cancelled, so its length is at most rlen - 1.
        rlen = F::Bits(r, rlen - 1);
        // Keep r as the larger of the two remainders.
        if (r < newr) {
            F::Swap(t, newt);
            F::Swap(r, newr);
            std::swap(rlen, newrlen);
        }
    }
    return t;
}
/** Compute an inverse of x1 through a fixed chain of MUL and SQR* calls.
 *
 *  Assuming MUL multiplies and SQR, SQR2, SQR4, SQR8, SQR16 square their
 *  argument 1, 2, 4, 8 and 16 times respectively (TODO confirm against the
 *  callers), the intermediate xK equals x1^(2^K - 1), r ends up as
 *  x1^(2^(BITS-1) - 1), and the returned SQR(r) is x1^(2^BITS - 2).
 *
 *  All branch conditions depend only on the compile-time constant INV_EXP.
 */
template<typename I, typename F, int BITS, I (*MUL)(I, I), I (*SQR)(I), I (*SQR2)(I), I(*SQR4)(I), I(*SQR8)(I), I(*SQR16)(I)>
inline I InvLadder(I x1)
{
    static constexpr int INV_EXP = BITS - 1;

    // Build the doubling chain, as far as INV_EXP requires.
    I x2 = I();
    I x4 = I();
    I x8 = I();
    I x16 = I();
    I x32 = I();
    if (INV_EXP >= 2) x2 = MUL(SQR(x1), x1);
    if (INV_EXP >= 4) x4 = MUL(SQR2(x2), x2);
    if (INV_EXP >= 8) x8 = MUL(SQR4(x4), x4);
    if (INV_EXP >= 16) x16 = MUL(SQR8(x8), x8);
    if (INV_EXP >= 32) x32 = MUL(SQR16(x16), x16);

    // Start from the largest chain element that fits in INV_EXP.
    I r = (INV_EXP >= 32) ? x32
        : (INV_EXP >= 16) ? x16
        : (INV_EXP >= 8)  ? x8
        : (INV_EXP >= 4)  ? x4
        : (INV_EXP >= 2)  ? x2
        : x1;

    // Append the remaining lower set bits of INV_EXP, highest first.
    if (INV_EXP >= 32 && (INV_EXP & 16) != 0) r = MUL(SQR16(r), x16);
    if (INV_EXP >= 16 && (INV_EXP & 8) != 0) r = MUL(SQR8(r), x8);
    if (INV_EXP >= 8 && (INV_EXP & 4) != 0) r = MUL(SQR4(r), x4);
    if (INV_EXP >= 4 && (INV_EXP & 2) != 0) r = MUL(SQR2(r), x2);
    if (INV_EXP >= 2 && (INV_EXP & 1) != 0) r = MUL(SQR(r), x1);

    return SQR(r);
}
#endif