use core::cmp::PartialEq;
use std::ops::{Add, Mul, Neg, Shr, Sub};
/// Fixed-width signed integer stored as `L` 64-bit limbs in two's complement.
///
/// Limb `0` is the least significant one; the sign is the top bit of limb `L - 1`.
#[derive(Clone)]
pub struct LInt<const L: usize>([u64; L]);

impl<const L: usize> LInt<L> {
    /// The value `0` (all limbs cleared).
    pub const ZERO: Self = Self([0; L]);

    /// The value `1` (only the lowest limb set to one).
    pub const ONE: Self = {
        let mut limbs = [0u64; L];
        limbs[0] = 1;
        Self(limbs)
    };

    /// The value `-1` (every bit set).
    pub const MINUS_ONE: Self = Self([u64::MAX; L]);

    /// Builds a number from little-endian limbs, zero-filling the limbs
    /// that `data` does not cover.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds more than `L` limbs.
    pub fn new(data: &[u64]) -> Self {
        let mut limbs = [0u64; L];
        limbs[..data.len()].copy_from_slice(data);
        Self(limbs)
    }

    /// Returns `true` when the sign bit (top bit of the highest limb) is set.
    #[inline]
    pub fn is_negative(&self) -> bool {
        (self.0[L - 1] >> (u64::BITS - 1)) == 1
    }

    /// Adds two limbs plus an incoming carry bit.
    ///
    /// Returns the low 64 bits of the sum and the outgoing carry.
    #[inline]
    fn sum(first: u64, second: u64, carry: bool) -> (u64, bool) {
        let (partial, overflow_a) = first.overflowing_add(second);
        let (total, overflow_b) = partial.overflowing_add(u64::from(carry));
        // At most one of the two additions can overflow.
        (total, overflow_a | overflow_b)
    }

    /// Computes `first * second + summand + carry` in 128-bit arithmetic.
    ///
    /// Returns the low and the high 64-bit halves of the result; the
    /// expression cannot overflow 128 bits.
    #[inline]
    fn prodsum(first: u64, second: u64, summand: u64, carry: u64) -> (u64, u64) {
        let product = u128::from(first) * u128::from(second);
        let total = product + u128::from(summand) + u128::from(carry);
        (total as u64, (total >> u64::BITS) as u64)
    }
}
/// Two numbers are equal exactly when all of their limbs are equal.
impl<const L: usize> PartialEq for LInt<L> {
    fn eq(&self, other: &Self) -> bool {
        self.0.iter().eq(other.0.iter())
    }
}
/// Arithmetic (sign-preserving) right shift by `1..=63` bits.
impl<const L: usize> Shr<u32> for &LInt<L> {
    type Output = LInt<L>;

    fn shr(self, bits: u32) -> Self::Output {
        debug_assert!(
            (1..64).contains(&bits),
            "Cannot shift by 0 or more than 63 bits!"
        );
        // Number of bits pulled in from the next more significant limb.
        let fill = u64::BITS - bits;
        let mut limbs = [0u64; L];
        // Every limb except the top one borrows its upper bits from its neighbor.
        for (slot, pair) in limbs.iter_mut().zip(self.0.windows(2)) {
            *slot = (pair[0] >> bits) | (pair[1] << fill);
        }
        let top = self.0[L - 1] >> bits;
        // Replicate the sign bit into the vacated upper bits of the top limb.
        limbs[L - 1] = if self.is_negative() {
            top | (u64::MAX << fill)
        } else {
            top
        };
        LInt::<L>(limbs)
    }
}
/// By-value right shift; forwards to the implementation for references.
impl<const L: usize> Shr<u32> for LInt<L> {
    type Output = LInt<L>;

    fn shr(self, bits: u32) -> Self::Output {
        Shr::shr(&self, bits)
    }
}
impl<const L: usize> Add for &LInt<L> {
type Output = LInt<L>;
fn add(self, other: Self) -> Self::Output {
let (mut data, mut carry) = ([0; L], false);
for (i, d) in data.iter_mut().enumerate().take(L) {
(*d, carry) = Self::Output::sum(self.0[i], other.0[i], carry);
}
LInt::<L>(data)
}
}
/// Owned-plus-borrowed addition; forwards to the implementation for references.
impl<const L: usize> Add<&LInt<L>> for LInt<L> {
    type Output = LInt<L>;

    fn add(self, other: &Self) -> Self::Output {
        Add::add(&self, other)
    }
}
/// By-value addition; forwards to the implementation for references.
impl<const L: usize> Add for LInt<L> {
    type Output = LInt<L>;

    fn add(self, other: Self) -> Self::Output {
        Add::add(&self, &other)
    }
}
impl<const L: usize> Sub for &LInt<L> {
type Output = LInt<L>;
fn sub(self, other: Self) -> Self::Output {
let (mut data, mut carry) = ([0; L], true);
for (i, d) in data.iter_mut().enumerate().take(L) {
(*d, carry) = Self::Output::sum(self.0[i], !other.0[i], carry);
}
LInt::<L>(data)
}
}
/// Owned-minus-borrowed subtraction; forwards to the implementation for references.
impl<const L: usize> Sub<&LInt<L>> for LInt<L> {
    type Output = LInt<L>;

    fn sub(self, other: &Self) -> Self::Output {
        Sub::sub(&self, other)
    }
}
/// By-value subtraction; forwards to the implementation for references.
impl<const L: usize> Sub for LInt<L> {
    type Output = LInt<L>;

    fn sub(self, other: Self) -> Self::Output {
        Sub::sub(&self, &other)
    }
}
/// Two's complement negation: inverts every limb and adds one, rippling
/// the carry upwards.
impl<const L: usize> Neg for &LInt<L> {
    type Output = LInt<L>;

    fn neg(self) -> Self::Output {
        let mut limbs = [0u64; L];
        // Starts as the "+ 1"; stays set only while the inverted limbs overflow.
        let mut carry = true;
        for (slot, &limb) in limbs.iter_mut().zip(&self.0) {
            let (value, overflow) = (!limb).overflowing_add(u64::from(carry));
            *slot = value;
            carry = overflow;
        }
        LInt::<L>(limbs)
    }
}
/// By-value negation; forwards to the implementation for references.
impl<const L: usize> Neg for LInt<L> {
    type Output = LInt<L>;

    fn neg(self) -> Self::Output {
        Neg::neg(&self)
    }
}
impl<const L: usize> Mul for &LInt<L> {
type Output = LInt<L>;
fn mul(self, other: Self) -> Self::Output {
let mut data = [0; L];
for i in 0..L {
let mut carry = 0;
for k in 0..(L - i) {
(data[i + k], carry) =
Self::Output::prodsum(self.0[i], other.0[k], data[i + k], carry);
}
}
LInt::<L>(data)
}
}
/// Owned-times-borrowed multiplication; forwards to the implementation for references.
impl<const L: usize> Mul<&LInt<L>> for LInt<L> {
    type Output = LInt<L>;

    fn mul(self, other: &Self) -> Self::Output {
        Mul::mul(&self, other)
    }
}
/// By-value multiplication; forwards to the implementation for references.
impl<const L: usize> Mul for LInt<L> {
    type Output = LInt<L>;

    fn mul(self, other: Self) -> Self::Output {
        Mul::mul(&self, &other)
    }
}
impl<const L: usize> Mul<i64> for &LInt<L> {
type Output = LInt<L>;
fn mul(self, other: i64) -> Self::Output {
let mut data = [0; L];
let (other, mut carry, mask) = if other < 0 {
(-other as u64, -other as u64, u64::MAX)
} else {
(other as u64, 0, 0)
};
for (i, d) in data.iter_mut().enumerate().take(L) {
(*d, carry) = Self::Output::prodsum(self.0[i] ^ mask, other, 0, carry);
}
LInt::<L>(data)
}
}
/// By-value scalar multiplication; forwards to the implementation for references.
impl<const L: usize> Mul<i64> for LInt<L> {
    type Output = LInt<L>;

    fn mul(self, other: i64) -> Self::Output {
        Mul::mul(&self, other)
    }
}
/// Scalar-on-the-left multiplication by a borrowed number; swaps the
/// operands and reuses the `&LInt * i64` implementation.
impl<const L: usize> Mul<&LInt<L>> for i64 {
    type Output = LInt<L>;

    fn mul(self, other: &LInt<L>) -> Self::Output {
        Mul::mul(other, self)
    }
}
/// Scalar-on-the-left multiplication by an owned number; swaps the
/// operands and reuses the `LInt * i64` implementation.
impl<const L: usize> Mul<LInt<L>> for i64 {
    type Output = LInt<L>;

    fn mul(self, other: LInt<L>) -> Self::Output {
        Mul::mul(other, self)
    }
}
/// Builds 64-bit approximations of two non-negative numbers.
///
/// Returns `(ax, ay, precise)`:
/// * if both numbers fit into the lowest limb, `ax` and `ay` are their exact
///   values and `precise` is `true`;
/// * otherwise the low 32 bits of each approximation are the low 32 bits of
///   the corresponding number, the high 32 bits are the 32 bits starting at
///   the most significant set bit of the larger number, and `precise` is
///   `false`.
///
/// In debug builds, asserts that both arguments are non-negative and that
/// at least one of them is non-zero.
fn approximate<const L: usize>(x: &LInt<L>, y: &LInt<L>) -> (u64, u64, bool) {
    debug_assert!(
        !(x.is_negative() || y.is_negative()),
        "Both the arguments must be non-negative!"
    );
    debug_assert!(
        (*x != LInt::ZERO) || (*y != LInt::ZERO),
        "At least one argument must be non-zero!"
    );
    // Index of the highest limb that is non-zero in at least one argument.
    let mut i = L - 1;
    while (x.0[i] == 0) && (y.0[i] == 0) {
        i -= 1;
    }
    if i == 0 {
        return (x.0[0], y.0[0], true);
    }
    let mut h = (x.0[i], y.0[i]);
    // Align the most significant set bit of the larger number with bit 63.
    let z = h.0.leading_zeros().min(h.1.leading_zeros());
    h = (h.0 << z, h.1 << z);
    if z > 32 {
        // Fewer than 32 significant bits came from limb `i`; the vacated low
        // `z` bits must be refilled with the TOP `z` bits of limb `i - 1`.
        // Shifting by `z` instead would leave only bits below position 32,
        // which the mask below discards. Here `33 <= z <= 63`, so the shift
        // amount is in `1..=31`.
        let refill = u64::BITS - z;
        h.0 |= x.0[i - 1] >> refill;
        h.1 |= y.0[i - 1] >> refill;
    }
    // Keep the upper 32 bits of the aligned values and the exact lower
    // 32 bits of the numbers.
    let h = (h.0 & (u64::MAX << 32), h.1 & (u64::MAX << 32));
    let l = (x.0[0] & (u64::MAX >> 32), y.0[0] & (u64::MAX >> 32));
    (h.0 | l.0, h.1 | l.1, false)
}
/// Binary Jacobi symbol computation on 64-bit operands.
///
/// `t` carries the sign accumulated so far in its bit 1 (set means the
/// result is negated); the other bits are scratch. Returns `0` when the
/// operands are not coprime, otherwise `1` or `-1`.
///
/// In debug builds, asserts that `d` is odd.
fn jacobinary(mut n: u64, mut d: u64, mut t: u64) -> i64 {
    debug_assert!(d & 1 > 0, "The second argument must be odd!");
    while n != 0 {
        if n % 2 == 0 {
            // Strip all factors of two at once; each one contributes the
            // sign of (2 / d), so only the parity of the count matters.
            let shift = n.trailing_zeros();
            let two_sign = d ^ (d >> 1);
            t ^= two_sign & u64::from(shift << 1);
            n >>= shift;
            continue;
        }
        if n < d {
            // Quadratic reciprocity: the sign flips when both are 3 mod 4.
            std::mem::swap(&mut n, &mut d);
            t ^= n & d;
        }
        // `n - d` is even here, so halve it and account for (2 / d).
        n = (n - d) / 2;
        t ^= d ^ (d >> 1);
    }
    if d != 1 {
        return 0;
    }
    if t & 2 == 0 {
        1
    } else {
        -1
    }
}
/// Computes the Jacobi symbol of `n` over `d` for multi-limb operands.
///
/// `n` and `d` are little-endian slices of 64-bit limbs that are loaded into
/// `L`-limb numbers. The result is `1`, `-1`, or `0`.
///
/// In debug builds, asserts that `d` is odd and that the top limb of both
/// operands has at least 31 leading zero bits.
///
/// # Panics
///
/// Panics if either slice holds more than `L` limbs (the copy in
/// `LInt::new` goes out of range).
pub fn jacobi<const L: usize>(n: &[u64], d: &[u64]) -> i64 {
// Bit 1 of `t` accumulates the sign of the result; other bits are scratch.
let (mut n, mut d, mut t) = (LInt::<L>::new(n), LInt::<L>::new(d), 0u64);
debug_assert!(d.0[0] & 1 > 0, "The second argument must be odd!");
debug_assert!(
n.0[L - 1].leading_zeros().min(d.0[L - 1].leading_zeros()) >= 31,
"Both the arguments must be less than 2 ^ (64 * L - 31)!"
);
loop {
// `u` and `v` are the coefficient rows applied to (n, d) after the inner
// loop (they start as the identity); `i` counts the halvings still to do
// in this round (30 per round).
let (mut u, mut v, mut i) = ((1i64, 0i64), (0i64, 1i64), 30);
let (mut a, mut b, precise) = approximate(&n, &d);
if precise {
// Both numbers fit in one limb: finish with the 64-bit routine.
return jacobinary(a, b, t);
}
// Run the binary steps on the 64-bit approximations, recording them
// in `u`, `v` and updating the sign bits in `t`.
while i > 0 {
if a & 1 > 0 {
if a < b {
// Swap the operands together with their coefficient rows.
(a, b, u, v) = (b, a, v, u);
t ^= a & b;
}
a = (a - b) >> 1;
u = (u.0 - v.0, u.1 - v.1);
// Doubling `v` instead of halving `u` keeps the coefficients integral;
// the common factor is removed by the final shift by 30.
v = (v.0 << 1, v.1 << 1);
t ^= b ^ (b >> 1);
i -= 1;
} else {
// Remove as many factors of two as the remaining step budget allows.
let z = i.min(a.trailing_zeros());
t ^= (b ^ (b >> 1)) & (z << 1) as u64;
v = (v.0 << z, v.1 << z);
a >>= z;
i -= z;
}
}
// Apply the recorded steps to the full-size numbers and drop the
// accumulated factor of 2 ^ 30.
(n, d) = ((&n * u.0 + &d * u.1) >> 30, (&n * v.0 + &d * v.1) >> 30);
if n == LInt::ZERO {
// Zero unless `d` ended at exactly one; otherwise the sign from bit 1 of `t`.
return (d == LInt::ONE) as i64 * (1 - (t & 2) as i64);
}
// The steps were chosen from approximations, so a result may come out
// negative; restore non-negative values before the next round.
if n.is_negative() {
// NOTE(review): presumably accounts for the sign of (-1 / d) via bit 1
// of `d` - confirm against the algorithm's reference.
t ^= d.0[0];
n = -n;
} else if d.is_negative() {
d = -d;
}
}
}