// https://github.com/andrewmilson/ministark/blob/main/gpu-poly/src/metal/u128.h.metal
#ifndef u128_h
#define u128_h
#include <metal_stdlib>
/// 128-bit unsigned integer stored as two 64-bit limbs (`high`, `low`).
///
/// Addition, subtraction and multiplication wrap modulo 2^128. The default
/// constructor is trivial, so `u128 x;` leaves both limbs uninitialized.
class u128
{
public:
    u128() = default;

    /// Low limb is `l` converted to `unsigned long`, high limb is zero.
    /// A negative `l` wraps inside the low limb only; it is not extended to 128 bits.
    constexpr u128(int l) : high(0), low(l) {}

    /// Low limb is `l`, high limb is zero.
    constexpr u128(unsigned long l) : high(0), low(l) {}

    /// Produces 1 for `true` and 0 for `false`.
    constexpr u128(bool b) : high(0), low(b) {}

    /// Builds a value from its high limb `h` and low limb `l` (note the argument order).
    constexpr u128(unsigned long h, unsigned long l) : high(h), low(l) {}

    /// Wrapping 128-bit addition.
    constexpr u128 operator+(const u128 rhs) const
    {
        const unsigned long sum_low = low + rhs.low;
        // The low limb wrapped around exactly when the sum is smaller than an operand.
        const unsigned long carry = sum_low < low;
        return u128(high + rhs.high + carry, sum_low);
    }

    /// Adds `rhs` in place and returns a copy of the updated value.
    constexpr u128 operator+=(const u128 rhs)
    {
        *this = *this + rhs;
        return *this;
    }

    /// Wrapping 128-bit subtraction.
    constexpr inline u128 operator-(const u128 rhs) const
    {
        const unsigned long diff_low = low - rhs.low;
        // The low limb borrowed exactly when the difference is larger than the minuend.
        const unsigned long borrow = diff_low > low;
        return u128(high - rhs.high - borrow, diff_low);
    }

    /// Subtracts `rhs` in place and returns a copy of the updated value.
    constexpr u128 operator-=(const u128 rhs)
    {
        *this = *this - rhs;
        return *this;
    }

    constexpr bool operator==(const u128 rhs) const
    {
        return high == rhs.high && low == rhs.low;
    }

    constexpr bool operator!=(const u128 rhs) const
    {
        return !(*this == rhs);
    }

    /// Unsigned comparison: the high limbs decide unless they are equal.
    constexpr bool operator<(const u128 rhs) const
    {
        return ((high == rhs.high) && (low < rhs.low)) || (high < rhs.high);
    }

    /// Limb-wise bitwise AND.
    constexpr u128 operator&(const u128 rhs) const
    {
        return u128(high & rhs.high, low & rhs.low);
    }

    /// Limb-wise bitwise OR.
    constexpr u128 operator|(const u128 rhs) const
    {
        return u128(high | rhs.high, low | rhs.low);
    }

    /// Unsigned comparison: the high limbs decide unless they are equal.
    constexpr bool operator>(const u128 rhs) const
    {
        return ((high == rhs.high) && (low > rhs.low)) || (high > rhs.high);
    }

    constexpr bool operator>=(const u128 rhs) const
    {
        return !(*this < rhs);
    }

    constexpr bool operator<=(const u128 rhs) const
    {
        return !(*this > rhs);
    }

    /// Logical right shift by `shift` bits; any `shift >= 128` yields zero.
    ///
    /// Written without branches: each case is selected by multiplying with a
    /// 0/1 predicate. Every underlying 64-bit shift uses a count in [0, 63],
    /// so no operand is ever shifted by its full width or more.
    ///
    /// Reference semantics:
    ///   shift <  64 : u128(high >> shift, (high << (64 - shift)) | (low >> shift))
    ///   shift < 128 : u128(0, high >> (shift - 64))
    ///   otherwise   : u128(0)
    constexpr inline u128 operator>>(unsigned shift) const
    {
        // In-limb shift count; equals `shift` below 64 and `shift - 64` in [64, 128).
        const unsigned s = shift & 63u;
        // Bits that drop from the high limb into the low limb, i.e. high << (64 - s).
        // Split into two shifts so that s == 0 produces 0 instead of a shift by 64.
        const unsigned long spill = (high << 1) << (63u - s);
        const unsigned long in_low_range = shift < 64u;
        const unsigned long in_high_range = (shift >= 64u) & (shift < 128u);
        const unsigned long new_low = in_low_range * ((low >> s) | spill)
                                    | in_high_range * (high >> s);
        const unsigned long new_high = in_low_range * (high >> s);
        return u128(new_high, new_low);
    }

    /// Left shift by `shift` bits; any `shift >= 128` yields zero.
    ///
    /// Written without branches: each case is selected by multiplying with a
    /// 0/1 predicate. Every underlying 64-bit shift uses a count in [0, 63],
    /// so no operand is ever shifted by its full width or more.
    ///
    /// Reference semantics:
    ///   shift <  64 : u128((high << shift) | (low >> (64 - shift)), low << shift)
    ///   shift < 128 : u128(low << (shift - 64), 0)
    ///   otherwise   : u128(0)
    constexpr inline u128 operator<<(unsigned shift) const
    {
        // In-limb shift count; equals `shift` below 64 and `shift - 64` in [64, 128).
        const unsigned s = shift & 63u;
        // Bits that carry from the low limb into the high limb, i.e. low >> (64 - s).
        // Split into two shifts so that s == 0 produces 0 instead of a shift by 64.
        const unsigned long spill = (low >> 1) >> (63u - s);
        const unsigned long in_low_range = shift < 64u;
        const unsigned long in_high_range = (shift >= 64u) & (shift < 128u);
        const unsigned long new_low = in_low_range * (low << s);
        const unsigned long new_high = in_low_range * ((high << s) | spill)
                                     | in_high_range * (low << s);
        return u128(new_high, new_low);
    }

    /// Shifts right in place and returns a copy of the updated value.
    constexpr u128 operator>>=(unsigned rhs)
    {
        *this = *this >> rhs;
        return *this;
    }

    /// Multiplies by 0 or 1: returns `*this` for `true` and zero for `false`.
    ///
    /// CAUTION: an `int` argument such as `x * 5` converts to `bool` through a
    /// standard conversion, which overload resolution ranks above the
    /// user-defined conversion to `u128`. Such calls therefore land here and
    /// multiply by 1. Write `x * u128(5)` for a full multiplication.
    u128 operator*(const bool rhs) const
    {
        return u128(high * rhs, low * rhs);
    }

    /// Wrapping 128-bit multiplication (the low 128 bits of the full product).
    ///
    /// With a = high:low and b = rhs.high:rhs.low,
    ///   a * b mod 2^128 = low * rhs.low
    ///                   + ((low * rhs.high + high * rhs.low) << 64)
    /// The high * rhs.high term only affects bits at or above 2^128 and is dropped.
    u128 operator*(const u128 rhs) const
    {
        // Low limb of the result: low 64 bits of low * rhs.low.
        const unsigned long product_low = low * rhs.low;
        // Carry out of that product: its upper 64 bits.
        const unsigned long product_carry = metal::mulhi(low, rhs.low);
        // Cross terms; only their low 64 bits land inside the result.
        const unsigned long cross_a = low * rhs.high;
        const unsigned long cross_b = high * rhs.low;
        return u128(cross_a + cross_b + product_carry, product_low);
    }

    /// Multiplies in place and returns a copy of the updated value.
    u128 operator*=(const u128 rhs)
    {
        *this = *this * rhs;
        return *this;
    }

    // Declaration order fixes the in-memory layout: the high limb comes first.
    // Do not reorder without updating every producer/consumer of raw u128 data.
    unsigned long high;
    unsigned long low;
    // TODO: Could get better performance with smaller limb size.
    // Not sure what the word size is for the M1 GPU.
};
#endif /* u128_h */