use core::cmp::PartialEq;
use std::ops::{Add, Mul, Neg, Sub};
/// Signed integer stored as `L` little-endian limbs of `B` bits each, interpreted as two's
/// complement over the whole `B * L`-bit width (the sign is the top bit of the top limb).
/// Limbs are expected to stay below `2^B`; the arithmetic operators mask their results.
#[derive(Clone)]
struct CInt<const B: usize, const L: usize>(pub [u64; L]);

impl<const B: usize, const L: usize> CInt<B, L> {
    /// Mask selecting the low `B` bits of a limb.
    pub const MASK: u64 = u64::MAX >> (64 - B);
    /// The value `-1`: every limb holds `MASK`.
    pub const MINUS_ONE: Self = Self([Self::MASK; L]);
    /// The value `0`.
    pub const ZERO: Self = Self([0; L]);
    /// The value `1`.
    pub const ONE: Self = {
        let mut limbs = [0; L];
        limbs[0] = 1;
        Self(limbs)
    };

    /// Arithmetic right shift by one whole limb (`B` bits): the lowest limb is dropped and
    /// the top limb is refilled with the sign (`MASK` if negative, `0` otherwise).
    pub fn shift(&self) -> Self {
        let fill = if self.is_negative() { Self::MASK } else { 0 };
        let mut limbs = [fill; L];
        limbs[..L - 1].copy_from_slice(&self.0[1..]);
        Self(limbs)
    }

    /// Returns the least significant limb.
    pub fn lowest(&self) -> u64 {
        self.0[0]
    }

    /// `true` when the top limb exceeds `MASK >> 1`, i.e. its bit `B - 1` (the sign bit) is set.
    pub fn is_negative(&self) -> bool {
        let top = self.0[L - 1];
        top > Self::MASK >> 1
    }
}
impl<const B: usize, const L: usize> PartialEq for CInt<B, L> {
    /// Limb-wise equality: two values are equal exactly when every stored limb matches.
    fn eq(&self, other: &Self) -> bool {
        self.0.iter().zip(other.0.iter()).all(|(a, b)| a == b)
    }
}
impl<const B: usize, const L: usize> Add for &CInt<B, L> {
    type Output = CInt<B, L>;

    /// Limb-wise addition with carry propagation; the result wraps modulo `2^(B * L)`
    /// (the carry out of the top limb is discarded).
    fn add(self, other: Self) -> Self::Output {
        let mask = CInt::<B, L>::MASK;
        let mut carry = 0u64;
        let mut limbs = [0u64; L];
        for ((out, &lhs), &rhs) in limbs.iter_mut().zip(&self.0).zip(&other.0) {
            let total = lhs + rhs + carry;
            *out = total & mask;
            carry = total >> B;
        }
        CInt(limbs)
    }
}
impl<const B: usize, const L: usize> Add<&CInt<B, L>> for CInt<B, L> {
    type Output = CInt<B, L>;

    /// `CInt + &CInt`: delegates to the reference-based addition.
    fn add(self, other: &Self) -> Self::Output {
        <&CInt<B, L> as Add>::add(&self, other)
    }
}
impl<const B: usize, const L: usize> Add for CInt<B, L> {
    type Output = CInt<B, L>;

    /// `CInt + CInt`: borrows both operands and forwards to `&CInt + &CInt`.
    fn add(self, other: Self) -> Self::Output {
        let (lhs, rhs) = (&self, &other);
        lhs + rhs
    }
}
impl<const B: usize, const L: usize> Sub for &CInt<B, L> {
    type Output = CInt<B, L>;

    /// Computes `self - other` as `self + !other + 1` (two's complement), limb by limb,
    /// wrapping modulo `2^(B * L)`.
    fn sub(self, other: Self) -> Self::Output {
        let mask = CInt::<B, L>::MASK;
        let mut limbs = [0u64; L];
        // The initial carry of 1 supplies the "+ 1" of the two's-complement negation.
        let mut carry = 1u64;
        for ((out, &lhs), &rhs) in limbs.iter_mut().zip(&self.0).zip(&other.0) {
            let total = lhs + (rhs ^ mask) + carry;
            *out = total & mask;
            carry = total >> B;
        }
        CInt(limbs)
    }
}
impl<const B: usize, const L: usize> Sub<&CInt<B, L>> for CInt<B, L> {
    type Output = CInt<B, L>;

    /// `CInt - &CInt`: delegates to the reference-based subtraction.
    fn sub(self, other: &Self) -> Self::Output {
        <&CInt<B, L> as Sub>::sub(&self, other)
    }
}
impl<const B: usize, const L: usize> Sub for CInt<B, L> {
    type Output = CInt<B, L>;

    /// `CInt - CInt`: borrows both operands and forwards to `&CInt - &CInt`.
    fn sub(self, other: Self) -> Self::Output {
        let (lhs, rhs) = (&self, &other);
        lhs - rhs
    }
}
impl<const B: usize, const L: usize> Neg for &CInt<B, L> {
    type Output = CInt<B, L>;

    /// Two's-complement negation (`!x + 1`) over the `B * L`-bit width.
    fn neg(self) -> Self::Output {
        let mask = CInt::<B, L>::MASK;
        // The "+ 1" enters as the initial carry and ripples up from the lowest limb;
        // `array::map` visits the limbs in order, so the carry flows correctly.
        let mut carry = 1u64;
        let limbs = self.0.map(|limb| {
            let total = (limb ^ mask) + carry;
            carry = total >> B;
            total & mask
        });
        CInt(limbs)
    }
}
impl<const B: usize, const L: usize> Neg for CInt<B, L> {
    type Output = CInt<B, L>;

    /// `-CInt`: delegates to the reference-based negation.
    fn neg(self) -> Self::Output {
        <&CInt<B, L> as Neg>::neg(&self)
    }
}
impl<const B: usize, const L: usize> Mul for &CInt<B, L> {
    type Output = CInt<B, L>;

    /// Schoolbook multiplication truncated to `L` limbs, i.e. the product modulo
    /// `2^(B * L)`. That truncated product is the same for signed and unsigned readings of
    /// the limbs, so it is also the correct two's-complement product.
    fn mul(self, other: Self) -> Self::Output {
        let mask = CInt::<B, L>::MASK;
        let mut acc = [0u64; L];
        for (row, &lhs) in self.0.iter().enumerate() {
            let mut carry: u64 = 0;
            // Only result columns below L are kept; the final carry of each row is dropped.
            for (col, &rhs) in other.0[..L - row].iter().enumerate() {
                let slot = &mut acc[row + col];
                let wide =
                    u128::from(*slot) + u128::from(carry) + u128::from(lhs) * u128::from(rhs);
                *slot = wide as u64 & mask;
                carry = (wide >> B) as u64;
            }
        }
        CInt(acc)
    }
}
impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for CInt<B, L> {
    type Output = CInt<B, L>;

    /// `CInt * &CInt`: delegates to the reference-based multiplication.
    fn mul(self, other: &Self) -> Self::Output {
        <&CInt<B, L> as Mul>::mul(&self, other)
    }
}
impl<const B: usize, const L: usize> Mul for CInt<B, L> {
    type Output = CInt<B, L>;

    /// `CInt * CInt`: borrows both operands and forwards to `&CInt * &CInt`.
    fn mul(self, other: Self) -> Self::Output {
        let (lhs, rhs) = (&self, &other);
        lhs * rhs
    }
}
impl<const B: usize, const L: usize> Mul<i64> for &CInt<B, L> {
    type Output = CInt<B, L>;

    /// Multiplies by a signed machine word, wrapping modulo `2^(B * L)`.
    ///
    /// A negative factor `-k` is handled through `x * (-k) = (!x) * k + k` (because
    /// `!x = -x - 1`). The limbs are complemented with `MASK`, and the magnitude `k` is
    /// pre-loaded as the initial carry.
    ///
    /// The magnitude is taken with `unsigned_abs`, so `i64::MIN` (magnitude `2^63`) is
    /// handled correctly. Plain negation would overflow on it: a panic in debug builds, and
    /// in release builds a sign-extended `u128` multiplier.
    fn mul(self, other: i64) -> Self::Output {
        let mut data = [0u64; L];
        let magnitude = other.unsigned_abs();
        let (mut carry, mask) = if other < 0 {
            (magnitude, CInt::<B, L>::MASK)
        } else {
            (0, 0)
        };
        for (d, &limb) in data.iter_mut().zip(&self.0) {
            // Limb < 2^64 and magnitude <= 2^63, so the product plus carry fits in u128,
            // and `sum >> B` stays below 2^64 for B >= 62.
            let sum = u128::from(carry) + u128::from(limb ^ mask) * u128::from(magnitude);
            *d = sum as u64 & CInt::<B, L>::MASK;
            carry = (sum >> B) as u64;
        }
        CInt::<B, L>(data)
    }
}
impl<const B: usize, const L: usize> Mul<i64> for CInt<B, L> {
    type Output = CInt<B, L>;

    /// `CInt * i64`: delegates to `&CInt * i64`.
    fn mul(self, other: i64) -> Self::Output {
        <&CInt<B, L> as Mul<i64>>::mul(&self, other)
    }
}
impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for i64 {
    type Output = CInt<B, L>;

    /// `i64 * &CInt`: multiplication is commutative, so swap the operands and use
    /// `&CInt * i64`.
    fn mul(self, other: &CInt<B, L>) -> Self::Output {
        <&CInt<B, L> as Mul<i64>>::mul(other, self)
    }
}
impl<const B: usize, const L: usize> Mul<CInt<B, L>> for i64 {
    type Output = CInt<B, L>;

    /// `i64 * CInt`: swap the operands and use `CInt * i64`.
    fn mul(self, other: CInt<B, L>) -> Self::Output {
        <CInt<B, L> as Mul<i64>>::mul(other, self)
    }
}
/// 2x2 matrix of signed words, indexed `t[row][col]`. `BYInverter::jump` produces it, and it
/// is applied to the pairs `(f, g)` and `(d, e)`.
type Matrix = [[i64; 2]; 2];

/// Modular inverter for one fixed modulus. All multi-limb quantities are stored as `L`
/// limbs of 62 bits.
pub struct BYInverter<const L: usize> {
    /// Low-limb inverse of the modulus (`modulus[0]^-1 mod 2^62`). It is only meaningful
    /// for an odd modulus and is used by the per-round correction in `de`.
    inverse: i64,
    /// The modulus, repacked into 62-bit limbs.
    modulus: CInt<62, L>,
    /// Starting value of the `e` accumulator in `invert`. The result of `invert` is
    /// `adjuster * value^-1 mod modulus`.
    adjuster: CInt<62, L>,
}
// Bernstein–Yang ("safegcd") style modular inversion.
// Each round of `invert` does three things:
// - runs 62 divsteps on the lowest limbs only (`jump`);
// - applies the resulting transition matrix to the full-width `(f, g)` (`fg`);
// - applies the same matrix to the residue accumulators `(d, e)` (`de`).
// The loop is data-dependent (`while g != ZERO`), so no constant-time claim is made here.
impl<const L: usize> BYInverter<L> {
/// Runs 62 divsteps on the lowest 62-bit limbs of `f` and `g`, starting from `delta`.
///
/// Returns the updated `delta` and the accumulated transition matrix `t`. The matrix is
/// built so that applying it to the full-width values yields multiples of 2^62. `fg` and
/// `de` then divide those out with a one-limb `shift`.
///
/// `f` must be odd. It starts as the low limb of the (odd) modulus, and after a swap it
/// takes the odd `g`. Oddness is required by the `-1/f mod 32` identity used below.
fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) {
let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128);
let mut t: Matrix = [[1, 0], [0, 1]];
loop {
// Strip all trailing zeros of g in one go, capped at the remaining step budget.
// A zero g reports 128 trailing zeros, which consumes every remaining step.
// Row 0 is doubled instead of halving row 1, so every matrix entry stays an integer.
let zeros = steps.min(g.trailing_zeros() as i64);
(steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
t[0] = [t[0][0] << zeros, t[0][1] << zeros];
if steps == 0 {
break;
}
// g is odd here. With positive delta the roles swap: f <- g, g <- -f,
// the matrix rows swap the same way, and delta changes sign.
if delta > 0 {
(delta, f, g) = (-delta, g as i64, -f as i128);
(t[0], t[1]) = (t[1], [-t[0][0], -t[0][1]]);
}
// Clear up to k = min(steps, 1 - delta, 5) low bits of g at once: w = -g/f mod 2^k.
// For odd f, (3 * f) ^ 28 == -1/f (mod 32), derived from Montgomery's
// (3 * f) ^ 2 == 1/f (mod 32). Then g += w*f and row 1 += w * row 0.
let mask = (1 << steps.min(1 - delta).min(5)) - 1;
let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask;
t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]];
g += w as i128 * f as i128;
}
(delta, t)
}
/// Applies the transition matrix `t` from `jump` to the full-width `f` and `g`. Both
/// results are then divided by 2^62 with a one-limb arithmetic `shift`. The dropped low
/// limb is expected to be zero by construction of `t`.
fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
(
(t[0][0] * &f + t[0][1] * &g).shift(),
(t[1][0] * &f + t[1][1] * &g).shift(),
)
}
/// Applies the transition matrix `t` to the accumulators `d` and `e`. These are only
/// meaningful modulo `self.modulus`.
///
/// Before the one-limb `shift` (division by 2^62), multiples `md`/`me` of the modulus are
/// added so that each low limb becomes zero. `self.inverse` (modulus^-1 mod 2^62) lets
/// those multiples be computed from the low limbs alone.
fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
let mask = CInt::<62, L>::MASK as i64;
// For each negative input, start the correction with the matching matrix coefficient.
// This effectively replaces a negative d (or e) by d + modulus.
// NOTE(review): presumably this keeps d/e inside the range bounded by the safegcd
// analysis — confirm.
let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64;
let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64;
// Low 62 bits of t * [d, e]. Wrapping arithmetic is fine because only these bits are kept.
let cd = t[0][0]
.wrapping_mul(d.lowest() as i64)
.wrapping_add(t[0][1].wrapping_mul(e.lowest() as i64))
& mask;
let ce = t[1][0]
.wrapping_mul(d.lowest() as i64)
.wrapping_add(t[1][1].wrapping_mul(e.lowest() as i64))
& mask;
// Adjust so that md == -inverse * cd (mod 2^62), hence md * modulus == -cd (mod 2^62).
// After this, t * [d, e] + md * modulus has a zero low limb (same for me/ce).
md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask;
let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus;
let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus;
(cd.shift(), ce.shift())
}
/// Maps the final accumulator to a canonical residue, negating it first when `negate` is
/// set (i.e. when `f` ended as -1).
///
/// For `value` in (-2*modulus, modulus) the result lies in [0, modulus).
/// NOTE(review): that input range is assumed from the safegcd bounds; it is not checked here.
fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> {
if value.is_negative() {
value = value + &self.modulus;
}
if negate {
value = -value;
}
if value.is_negative() {
value = value + &self.modulus;
}
value
}
/// Repacks a little-endian sequence of `I`-bit words into `S` little-endian words of `O`
/// bits each.
///
/// Only the low `min(input.len() * I, S * O)` bits are copied; excess input is silently
/// truncated. Every output word that received bits is masked to `O` bits, and the
/// remaining output words stay zero. Input words are assumed to fit in `I` bits. Otherwise
/// their high bits could leak into neighbouring output bits.
const fn convert<const I: usize, const O: usize, const S: usize>(input: &[u64]) -> [u64; S] {
// `Ord::min` cannot be called in a const fn, hence this local helper.
const fn min(a: usize, b: usize) -> usize {
if a > b {
b
} else {
a
}
}
let (total, mut output, mut bits) = (min(input.len() * I, S * O), [0; S], 0);
while bits < total {
// Copy the largest chunk that crosses neither an input-word nor an output-word
// boundary. Any extra high bits that land at or above bit O are cleared by the
// masking pass below.
let (i, o) = (bits % I, bits % O);
output[bits / O] |= (input[bits / I] >> i) << o;
bits += min(I - i, O - o);
}
let mask = u64::MAX >> (64 - O);
// Number of output words touched: ceil(total / O).
let mut filled = total / O + if total % O > 0 { 1 } else { 0 };
while filled > 0 {
filled -= 1;
output[filled] &= mask;
}
output
}
/// Returns `value^-1 mod 2^62` as an `i64`, for odd `value`.
///
/// Starts from Montgomery's `(3 * v) ^ 2`, which is correct modulo 2^5. It then refines
/// with Newton/Hensel steps `x <- x * (1 + y)`, `y <- y * y`, where `y = 1 - x * v`.
/// Each step doubles the number of correct low bits (5 -> 10 -> 20 -> 40 -> 80 >= 64).
/// An even `value` has no inverse and yields a meaningless result.
const fn inv(value: u64) -> i64 {
let x = value.wrapping_mul(3) ^ 2;
let y = 1u64.wrapping_sub(x.wrapping_mul(value));
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
(x.wrapping_mul(y.wrapping_add(1)) & CInt::<62, L>::MASK) as i64
}
/// Builds an inverter for `modulus`.
///
/// `modulus` and `adjuster` are converted from 64-bit words into `L` limbs of 62 bits;
/// bits beyond `62 * L` are silently dropped. `modulus^-1 mod 2^62` is precomputed for
/// `de`.
///
/// `modulus` must be odd for the results to be meaningful (see `inv`).
/// NOTE(review): presumably `L` must also leave headroom above the modulus width for the
/// intermediate values — confirm with the callers.
///
/// # Panics
/// Panics if `modulus` is empty (it reads `modulus[0]`).
pub const fn new(modulus: &[u64], adjuster: &[u64]) -> Self {
Self {
modulus: CInt::<62, L>(Self::convert::<64, 62, L>(modulus)),
adjuster: CInt::<62, L>(Self::convert::<64, 62, L>(adjuster)),
inverse: Self::inv(modulus[0]),
}
}
/// Computes `adjuster * value^-1 mod modulus` and returns it as `S` little-endian 64-bit
/// words.
///
/// Returns `None` when `value` is not invertible, i.e. when `f` does not end as +1 or -1.
///
/// Each round maintains two invariants (mod modulus):
/// - `d * value == f * adjuster`
/// - `e * value == g * adjuster`
///
/// When `g` reaches zero, `|f|` is `gcd(modulus, value)`. An inverse therefore exists
/// exactly when `f` is +1 or -1, and then the answer is `+d` or `-d` respectively.
pub fn invert<const S: usize>(&self, value: &[u64]) -> Option<[u64; S]> {
let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone());
let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value));
let (mut delta, mut f) = (1, self.modulus.clone());
let mut matrix;
while g != CInt::ZERO {
(delta, matrix) = Self::jump(&f, &g, delta);
(f, g) = Self::fg(f, g, matrix);
(d, e) = self.de(d, e, matrix);
}
// f == -1 means the inverse is -d; `norm` applies that negation.
let antiunit = f == CInt::MINUS_ONE;
if (f != CInt::ONE) && !antiunit {
return None;
}
Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0))
}
}