#![allow(clippy::use_self)]
use crate::Uint;
/// Lehmer update matrix with implicit signs.
///
/// Fields `.0`..`.3` hold the absolute values of the 2x2 entries in
/// row-major order; `.4` selects which of two sign patterns applies:
///
/// ```text
///     true          false
///  [ .0 -.1]    [-.0  .1]
///  [-.2  .3]    [ .2 -.3]
/// ```
///
/// Applying the matrix to a pair `(a, b)` is the signed matrix-vector
/// product with this layout (see [`Matrix::apply`] and
/// [`Matrix::apply_u128`]).
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Matrix(pub u64, pub u64, pub u64, pub u64, pub bool);

impl Matrix {
    /// The identity matrix: applying it leaves `(a, b)` unchanged.
    pub const IDENTITY: Self = Self(1, 0, 0, 1, true);

    /// Returns the matrix product `self * other`.
    ///
    /// In every product entry both summands carry the same sign, so the
    /// absolute values can simply be added. The resulting sign pattern is
    /// `true` exactly when both factors share the same pattern.
    ///
    /// Entries use plain `u64` arithmetic: the caller must ensure they fit
    /// (overflow panics in debug builds).
    #[allow(clippy::suspicious_operation_groupings)]
    #[must_use]
    pub const fn compose(self, other: Self) -> Self {
        Self(
            self.0 * other.0 + self.1 * other.2,
            self.0 * other.1 + self.1 * other.3,
            self.2 * other.0 + self.3 * other.2,
            self.2 * other.1 + self.3 * other.3,
            // Equivalent to `self.4 == other.4`.
            self.4 ^ !other.4,
        )
    }

    /// Replaces `(a, b)` in place with the matrix applied to them.
    ///
    /// With sign pattern `true` this computes `(.0*a - .1*b, .3*b - .2*a)`;
    /// with `false` it computes `(.1*b - .0*a, .2*a - .3*b)`. The arithmetic
    /// is `Uint`'s `*` and `-` (NOTE(review): assumed to wrap modulo
    /// 2^BITS so that intermediate overflow cancels out — confirm).
    /// A zero-bit `Uint` is left untouched.
    pub fn apply<const BITS: usize, const LIMBS: usize>(
        &self,
        a: &mut Uint<BITS, LIMBS>,
        b: &mut Uint<BITS, LIMBS>,
    ) {
        if BITS == 0 {
            return;
        }
        let (c, d) = if self.4 {
            (
                Uint::from(self.0) * *a - Uint::from(self.1) * *b,
                Uint::from(self.3) * *b - Uint::from(self.2) * *a,
            )
        } else {
            (
                Uint::from(self.1) * *b - Uint::from(self.0) * *a,
                Uint::from(self.2) * *a - Uint::from(self.3) * *b,
            )
        };
        *a = c;
        *b = d;
    }

    /// Applies the matrix to a `u128` pair and returns the new pair.
    ///
    /// Intermediate products and differences may overflow, so everything is
    /// computed with wrapping arithmetic (modulo 2^128). The output equals
    /// the true signed result only when that result lies in `0..2^128`.
    #[must_use]
    pub const fn apply_u128(&self, a: u128, b: u128) -> (u128, u128) {
        if self.4 {
            (
                (self.0 as u128)
                    .wrapping_mul(a)
                    .wrapping_sub((self.1 as u128).wrapping_mul(b)),
                (self.3 as u128)
                    .wrapping_mul(b)
                    .wrapping_sub((self.2 as u128).wrapping_mul(a)),
            )
        } else {
            (
                (self.1 as u128)
                    .wrapping_mul(b)
                    .wrapping_sub((self.0 as u128).wrapping_mul(a)),
                (self.2 as u128)
                    .wrapping_mul(a)
                    .wrapping_sub((self.3 as u128).wrapping_mul(b)),
            )
        }
    }

    /// Computes a Lehmer update matrix for `a >= b`.
    ///
    /// If `a` fits in 64 bits, this returns the exact extended-Euclid matrix
    /// from [`Matrix::from_u64`]. Otherwise both values are shifted right
    /// so that `a` keeps at most its top 128 bits. The results go to
    /// [`Matrix::from_u128_prefix`].
    ///
    /// # Panics
    ///
    /// Panics if `a < b`.
    #[must_use]
    pub fn from<const BITS: usize, const LIMBS: usize>(
        a: Uint<BITS, LIMBS>,
        b: Uint<BITS, LIMBS>,
    ) -> Self {
        assert!(a >= b);
        let s = a.bit_len();
        if s <= 64 {
            Self::from_u64(a.try_into().unwrap(), b.try_into().unwrap())
        } else if s <= 128 {
            Self::from_u128_prefix(a.try_into().unwrap(), b.try_into().unwrap())
        } else {
            // Align both values to the leading 128 bits of `a`.
            let a = a >> (s - 128);
            let b = b >> (s - 128);
            Self::from_u128_prefix(a.try_into().unwrap(), b.try_into().unwrap())
        }
    }

    /// Runs the extended Euclidean algorithm on `r0 >= r1` in 64 bits.
    ///
    /// Returns a matrix mapping `(r0, r1)` to `(gcd(r0, r1), 0)`, or
    /// [`Matrix::IDENTITY`] when `r1 == 0`. The precondition `r0 >= r1` is
    /// only checked by `debug_assert!`.
    #[must_use]
    pub fn from_u64(mut r0: u64, mut r1: u64) -> Self {
        debug_assert!(r0 >= r1);
        if r1 == 0_u64 {
            return Matrix::IDENTITY;
        }
        let mut q00 = 1_u64;
        let mut q01 = 0_u64;
        let mut q10 = 0_u64;
        let mut q11 = 1_u64;
        loop {
            // Unrolled twice so `r0`/`r1` alternate roles without swapping;
            // the half that terminates determines the sign pattern.
            let q = r0 / r1;
            r0 -= q * r1;
            q00 += q * q10;
            q01 += q * q11;
            if r0 == 0_u64 {
                return Matrix(q10, q11, q00, q01, false);
            }
            let q = r1 / r0;
            r1 -= q * r0;
            q10 += q * q00;
            q11 += q * q01;
            if r1 == 0_u64 {
                return Matrix(q00, q01, q10, q11, true);
            }
        }
    }

    /// Computes a Lehmer update matrix from 64-bit prefixes `a0 >= a1`.
    ///
    /// `a0` must have its top bit set. The Euclidean remainder sequence is
    /// run until a remainder drops below 2^32. Then up to two final steps
    /// are accepted only if a set of remainder/cofactor inequalities holds.
    /// `test_from_u64_prefix` pins the intended property: applying a
    /// non-identity result to the full integers preserves their gcd, does
    /// not increase `a`, and strictly decreases `b`.
    ///
    /// Preconditions are only checked by `debug_assert!`.
    #[must_use]
    #[allow(clippy::redundant_else)]
    #[allow(clippy::cognitive_complexity)]
    pub fn from_u64_prefix(a0: u64, mut a1: u64) -> Self {
        const LIMIT: u64 = 1_u64 << 32;
        debug_assert!(a0 >= 1_u64 << 63);
        debug_assert!(a0 >= a1);

        // Each cofactor pair (u, v) is packed as `u << 32 | v`, so one
        // add/multiply updates both halves. The unpacking below (`>> 32`,
        // `% LIMIT`) is only correct while every cofactor stays below 2^32.
        let mut k0 = 1_u64 << 32; // u0 = 1, v0 = 0
        let mut k1 = 1_u64; // u1 = 0, v1 = 1
        let mut even = true;
        if a1 < LIMIT {
            return Matrix::IDENTITY;
        }

        // First step: a2 = a0 mod a1.
        let q = a0 / a1;
        let mut a2 = a0 - q * a1;
        let mut k2 = k0 + q * k1;
        if a2 < LIMIT {
            let u2 = k2 >> 32;
            let v2 = k2 % LIMIT;
            // Accept the single step only if the inequalities hold.
            if a2 >= v2 && a1 - a2 >= u2 {
                return Matrix(0, 1, u2, v2, false);
            } else {
                return Matrix::IDENTITY;
            }
        }

        // Second step: a3 = a1 mod a2.
        let q = a1 / a2;
        let mut a3 = a1 - q * a2;
        let mut k3 = k1 + q * k2;

        // Keep the last three remainders and last four packed cofactors.
        // The loop is unrolled twice so parity is static; `even` records
        // whether it stopped after the second half (or never ran).
        while a3 >= LIMIT {
            // Slide the windows one step. `a3`/`k3` are seeded with the new
            // `a1`/`k1`, so the in-place update yields the next values.
            a1 = a2;
            a2 = a3;
            a3 = a1;
            k0 = k1;
            k1 = k2;
            k2 = k3;
            k3 = k1;
            debug_assert!(a2 < a3);
            debug_assert!(a2 > 0);
            let q = a3 / a2;
            a3 -= q * a2;
            k3 += q * k2;
            if a3 < LIMIT {
                even = false;
                break;
            }
            a1 = a2;
            a2 = a3;
            a3 = a1;
            k0 = k1;
            k1 = k2;
            k2 = k3;
            k3 = k1;
            debug_assert!(a2 < a3);
            debug_assert!(a2 > 0);
            let q = a3 / a2;
            a3 -= q * a2;
            k3 += q * k2;
        }

        // Unpack the cofactors.
        let u0 = k0 >> 32;
        let u1 = k1 >> 32;
        let u2 = k2 >> 32;
        let u3 = k3 >> 32;
        let v0 = k0 % LIMIT;
        let v1 = k1 % LIMIT;
        let v2 = k2 % LIMIT;
        let v3 = k3 % LIMIT;
        debug_assert!(a2 >= LIMIT);
        debug_assert!(a3 < LIMIT);

        // Decide whether steps i + 1 and i + 2 may be taken; the roles of u
        // and v and the sign pattern alternate with parity.
        // NOTE(review): these look like Jebelean's exact conditions — confirm.
        if even {
            debug_assert!(a2 >= v2);
            if a1 - a2 >= u2 + u1 {
                if a3 >= u3 && a2 - a3 >= v3 + v2 {
                    Matrix(u2, v2, u3, v3, true)
                } else {
                    Matrix(u1, v1, u2, v2, false)
                }
            } else {
                Matrix(u0, v0, u1, v1, true)
            }
        } else {
            debug_assert!(a2 >= u2);
            if a1 - a2 >= v2 + v1 {
                if a3 >= v3 && a2 - a3 >= u3 + u2 {
                    Matrix(u2, v2, u3, v3, false)
                } else {
                    Matrix(u1, v1, u2, v2, true)
                }
            } else {
                Matrix(u0, v0, u1, v1, false)
            }
        }
    }

    /// Computes a Lehmer update matrix from 128-bit prefixes `r0 >= r1`.
    ///
    /// Both values are shifted left together until `r0` has its top bit
    /// set. Their high 64 bits then go to [`Matrix::from_u64_prefix`].
    /// This is a single-precision step: bits below that 64-bit prefix do
    /// not influence the result.
    ///
    /// Returns [`Matrix::IDENTITY`] when `r1 == 0`.
    #[must_use]
    pub fn from_u128_prefix(r0: u128, r1: u128) -> Self {
        debug_assert!(r0 >= r1);
        if r1 == 0 {
            // The prefix computation returns the identity for a zero second
            // prefix anyway. Returning early also avoids `r0 << 128` (a
            // debug-mode overflow panic) when `r0 == 0`.
            return Matrix::IDENTITY;
        }
        let s = r0.leading_zeros();
        let r0s = r0 << s;
        let r1s = r1 << s;
        Self::from_u64_prefix((r0s >> 64) as u64, (r1s >> 64) as u64)
    }
}
#[cfg(test)]
#[allow(clippy::cast_lossless)]
#[allow(clippy::many_single_char_names)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use core::{
        cmp::{max, min},
        mem::swap,
    };
    use proptest::{proptest, test_runner::Config};
    use std::str::FromStr;

    /// Reference Euclidean gcd on `u128`.
    fn gcd(mut a: u128, mut b: u128) -> u128 {
        while b != 0 {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    /// Reference Euclidean gcd on `Uint`.
    fn gcd_uint<const BITS: usize, const LIMBS: usize>(
        mut a: Uint<BITS, LIMBS>,
        mut b: Uint<BITS, LIMBS>,
    ) -> Uint<BITS, LIMBS> {
        while b != Uint::ZERO {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    #[test]
    fn test_from_u64_example() {
        let (a, b) = (252, 105);
        let m = Matrix::from_u64(a, b);
        assert_eq!(m, Matrix(2, 5, 5, 12, false));
        let (a, b) = m.apply_u128(a as u128, b as u128);
        assert_eq!(a, 21);
        assert_eq!(b, 0);
    }

    #[test]
    fn test_from_u64() {
        proptest!(|(a: u64, b: u64)| {
            let (a, b) = (max(a, b), min(a, b));
            let m = Matrix::from_u64(a, b);
            let (c, d) = m.apply_u128(a as u128, b as u128);
            assert!(c >= d);
            assert_eq!(c, gcd(a as u128, b as u128));
            assert_eq!(d, 0);
        });
    }

    #[test]
    fn test_from_u64_prefix() {
        proptest!(|(a: u128, b: u128)| {
            // Clamp `a` to at least 1: for `a == 0` the normalizing shift
            // below would be 128 bits, which panics in debug builds.
            let (a, b) = (max(a, b).max(1), min(a, b));
            let s = a.leading_zeros();
            let (sa, sb) = (a << s, b << s);
            let m = Matrix::from_u64_prefix((sa >> 64) as u64, (sb >> 64) as u64);
            let (c, d) = m.apply_u128(a, b);
            assert!(c >= d);
            if m == Matrix::IDENTITY {
                assert_eq!(c, a);
                assert_eq!(d, b);
            } else {
                assert!(c <= a);
                assert!(d < b);
                assert_eq!(gcd(a, b), gcd(c, d));
            }
        });
    }

    /// Checks the `Matrix::from` + `Matrix::apply` invariants for one pair:
    /// the identity leaves the pair alone; otherwise the pair shrinks and
    /// its gcd is preserved.
    fn test_from_uint_one<const BITS: usize, const LIMBS: usize>(
        a: Uint<BITS, LIMBS>,
        b: Uint<BITS, LIMBS>,
    ) {
        let (a, b) = (max(a, b), min(a, b));
        let m = Matrix::from(a, b);
        let (mut c, mut d) = (a, b);
        m.apply(&mut c, &mut d);
        assert!(c >= d);
        if m == Matrix::IDENTITY {
            assert_eq!(c, a);
            assert_eq!(d, b);
        } else {
            assert!(c <= a);
            assert!(d < b);
            assert_eq!(gcd_uint(a, b), gcd_uint(c, d));
        }
    }

    #[test]
    fn test_from_uint_cases() {
        type U129 = Uint<129, 3>;
        test_from_uint_one(
            U129::from_str("0x01de6ef6f3caa963a548d7a411b05b9988").unwrap(),
            U129::from_str("0x006d7c4641f88b729a97889164dd8d07db").unwrap(),
        );
    }

    #[test]
    #[allow(clippy::absurd_extreme_comparisons)]
    fn test_from_uint_proptest() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let mut config = Config::default();
            config.cases = min(config.cases, if BITS > 500 { 12 } else { 40 });
            proptest!(config, |(a: U, b: U)| {
                test_from_uint_one(a, b);
            });
        });
    }
}
#[cfg(feature = "bench")]
pub mod bench {
    use super::*;
    use crate::{const_for, nlimbs};
    use ::proptest::{
        arbitrary::Arbitrary,
        strategy::{Strategy, ValueTree},
        test_runner::TestRunner,
    };
    use core::cmp::{max, min};
    use criterion::{black_box, BatchSize, Criterion};

    /// Registers all `Matrix` benchmarks with `criterion`.
    pub fn group(criterion: &mut Criterion) {
        bench_from_u64(criterion);
        bench_from_u64_prefix(criterion);
        const_for!(BITS in BENCH {
            const LIMBS: usize = nlimbs(BITS);
            bench_apply::<BITS, LIMBS>(criterion);
        });
    }

    /// Benchmarks `Matrix::from_u64` on random ordered `u64` pairs.
    fn bench_from_u64(criterion: &mut Criterion) {
        let input = (u64::arbitrary(), u64::arbitrary());
        let mut runner = TestRunner::deterministic();
        criterion.bench_function("algorithms/gcd/matrix/from_u64", move |bencher| {
            bencher.iter_batched(
                || {
                    let (a, b) = input.new_tree(&mut runner).unwrap().current();
                    (max(a, b), min(a, b))
                },
                |(a, b)| black_box(Matrix::from_u64(black_box(a), black_box(b))),
                BatchSize::SmallInput,
            );
        });
    }

    /// Benchmarks `Matrix::from_u64_prefix` on random normalized prefixes.
    fn bench_from_u64_prefix(criterion: &mut Criterion) {
        let input = (u64::arbitrary(), u64::arbitrary());
        let mut runner = TestRunner::deterministic();
        criterion.bench_function("algorithms/gcd/matrix/from_u64_prefix", move |bencher| {
            bencher.iter_batched(
                || {
                    let (a, b) = input.new_tree(&mut runner).unwrap().current();
                    // `a` is clamped to at least 1 so the shift below is < 64.
                    let (a, b) = (max(a, b).max(1), min(a, b));
                    // `from_u64_prefix` expects `a` to have its top bit set,
                    // so normalize both values by the same shift, as
                    // `from_u128_prefix` does.
                    let s = a.leading_zeros();
                    (a << s, b << s)
                },
                |(a, b)| black_box(Matrix::from_u64_prefix(black_box(a), black_box(b))),
                BatchSize::SmallInput,
            );
        });
    }

    /// Benchmarks `Matrix::apply` on random `Uint<BITS, LIMBS>` pairs.
    fn bench_apply<const BITS: usize, const LIMBS: usize>(criterion: &mut Criterion) {
        let input = (
            Uint::<BITS, LIMBS>::arbitrary(),
            Uint::<BITS, LIMBS>::arbitrary(),
        );
        let mut runner = TestRunner::deterministic();
        criterion.bench_function(
            &format!("algorithms/gcd/matrix/apply/{BITS}"),
            move |bencher| {
                bencher.iter_batched(
                    || {
                        let (a, b) = input.new_tree(&mut runner).unwrap().current();
                        let (a, b) = (max(a, b), min(a, b));
                        let m = Matrix::from(a, b);
                        (a, b, m)
                    },
                    |(a, b, m)| {
                        let (mut a, mut b) = (black_box(a), black_box(b));
                        black_box(m).apply(&mut a, &mut b);
                        // Return the results so the optimizer cannot discard
                        // the computation as dead code.
                        black_box((a, b))
                    },
                    BatchSize::SmallInput,
                );
            },
        );
    }
}