use std::sync::Arc;
use num_bigint::{BigInt, BigUint, Sign};
use num_traits::Zero;
use crate::primes::{first_n_primes, gcd, mod_inverse};
use crate::RAYON_CHANNEL_THRESHOLD;
/// The list of moduli ("channels") an RNS value is expressed against.
///
/// The moduli live behind an `Arc`, so cloning a `Channels` only bumps a
/// reference count and every clone shares the same underlying vector.
#[derive(Clone, Debug)]
pub struct Channels(pub Arc<Vec<u64>>);
impl Channels {
pub fn new(moduli: Vec<u64>) -> Self {
debug_assert!(
Self::pairwise_coprime(&moduli),
"RNS channels must be pairwise coprime"
);
Channels(Arc::new(moduli))
}
pub fn standard(n: usize) -> Self {
Channels(Arc::new(first_n_primes(n)))
}
fn pairwise_coprime(moduli: &[u64]) -> bool {
for i in 0..moduli.len() {
for j in (i + 1)..moduli.len() {
if gcd(moduli[i], moduli[j]) != 1 {
return false;
}
}
}
true
}
#[inline]
pub fn len(&self) -> usize {
self.0.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
#[inline]
pub fn modulus(&self, c: usize) -> u64 {
self.0[c]
}
#[inline]
pub fn moduli(&self) -> &[u64] {
&self.0
}
pub fn capacity(&self) -> BigUint {
self.0.iter().map(|&m| BigUint::from(m)).product()
}
pub fn signed_capacity(&self) -> BigInt {
BigInt::from(self.capacity() / BigUint::from(2u8))
}
}
/// Two channel sets are equal when they hold the same moduli in the same order.
impl PartialEq for Channels {
    fn eq(&self, other: &Self) -> bool {
        // Clones share one allocation, so pointer identity settles the common
        // case without walking the vectors.
        if Arc::ptr_eq(&self.0, &other.0) {
            return true;
        }
        *self.0 == *other.0
    }
}

impl Eq for Channels {}
/// An integer stored as one residue per channel of a [`Channels`] set.
#[derive(Clone, Debug)]
pub struct RnsInt {
/// Residue of the value for each channel, in the same order as
/// `channels.moduli()`.
pub residues: Vec<u64>,
/// The moduli the residues are taken against.
pub channels: Channels,
/// Sign flag. `from_bigint` copies it from the sign of its input; values
/// produced through `finish` (arithmetic results, `from_residues`) derive it
/// from the sign of `to_bigint()`.
pub negative: bool,
}
impl RnsInt {
    /// Encodes `n` as one residue in `[0, m)` per channel modulus `m`.
    ///
    /// The `negative` flag is copied from the sign of `n`; it is not
    /// re-derived from the residues.
    ///
    /// # Panics
    ///
    /// Panics if any channel modulus is zero.
    pub fn from_bigint(n: &BigInt, channels: Channels) -> Self {
        let negative = n.sign() == Sign::Minus;
        let residues = channels
            .moduli()
            .iter()
            .map(|&m| {
                let mm = BigInt::from(m);
                // `%` can yield a negative remainder for negative `n`, so
                // shift by `m` and reduce again to land in [0, m).
                let r = ((n % &mm) + &mm) % &mm;
                r.to_biguint()
                    .expect("normalised residue is non-negative")
                    .try_into()
                    .expect("residue is smaller than a u64 modulus")
            })
            .collect();
        RnsInt {
            residues,
            channels,
            negative,
        }
    }

    /// Convenience wrapper around [`RnsInt::from_bigint`] for machine integers.
    pub fn from_i64(n: i64, channels: Channels) -> Self {
        Self::from_bigint(&BigInt::from(n), channels)
    }

    /// The value zero: every residue is 0 and `negative` is `false`.
    pub fn zero(channels: Channels) -> Self {
        RnsInt {
            residues: vec![0; channels.len()],
            channels,
            negative: false,
        }
    }

    /// Builds a value from raw residues.
    ///
    /// Each residue is reduced modulo its channel so that the stored
    /// representation is canonical: `is_zero` and the modular arithmetic
    /// below both rely on residues being in `[0, m)`.
    ///
    /// # Panics
    ///
    /// Panics if `residues` and `channels` differ in length, or if any
    /// channel modulus is zero.
    pub fn from_residues(mut residues: Vec<u64>, channels: Channels) -> Self {
        assert_eq!(
            residues.len(),
            channels.len(),
            "residue/moduli length mismatch"
        );
        for (r, &m) in residues.iter_mut().zip(channels.moduli()) {
            *r %= m;
        }
        Self::finish(residues, channels)
    }

    /// Reconstructs the signed integer this value represents.
    ///
    /// The residues are CRT-decoded to an unsigned representative `u`; with
    /// `M = channels.capacity()`, `u` is returned as-is unless `2u > M`, in
    /// which case `u - M` (a negative number) is returned.
    pub fn to_bigint(&self) -> BigInt {
        let u = garner_crt(&self.residues, self.channels.moduli());
        let m = self.channels.capacity();
        if &u * 2u8 > m {
            BigInt::from_biguint(Sign::Plus, u) - BigInt::from_biguint(Sign::Plus, m)
        } else {
            BigInt::from_biguint(Sign::Plus, u)
        }
    }

    /// `true` when every residue is 0.
    pub fn is_zero(&self) -> bool {
        self.residues.iter().all(|&r| r == 0)
    }

    /// Channel-wise modular addition.
    ///
    /// Debug builds assert that both operands use the same channels.
    pub fn add(&self, other: &Self) -> Self {
        self.debug_assert_same_channels(other);
        let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
            // Widen first: `a + b` can exceed u64 when the modulus is > 2^63.
            let (a, b, m) = (u128::from(a), u128::from(b), u128::from(m));
            ((a + b) % m) as u64
        });
        Self::finish(out, self.channels.clone())
    }

    /// Channel-wise modular subtraction (`self - other`).
    ///
    /// Debug builds assert that both operands use the same channels.
    pub fn sub(&self, other: &Self) -> Self {
        self.debug_assert_same_channels(other);
        let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
            // `b % m` keeps `a + m - b` non-negative even if `b` was stored
            // unreduced; widening keeps `a + m` from overflowing u64.
            let (a, b, m) = (u128::from(a), u128::from(b % m), u128::from(m));
            ((a + m - b) % m) as u64
        });
        Self::finish(out, self.channels.clone())
    }

    /// Channel-wise modular multiplication.
    ///
    /// Debug builds assert that both operands use the same channels.
    pub fn mul(&self, other: &Self) -> Self {
        self.debug_assert_same_channels(other);
        let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
            gpu_mul_channel(a, b, m)
        });
        Self::finish(out, self.channels.clone())
    }

    /// Additive inverse, computed as `0 - self`.
    pub fn neg(&self) -> Self {
        Self::zero(self.channels.clone()).sub(self)
    }

    /// Wraps `residues` and sets `negative` from the sign of the decoded value.
    ///
    /// This runs a full `to_bigint` reconstruction, so every arithmetic
    /// result pays for one CRT decode.
    fn finish(residues: Vec<u64>, channels: Channels) -> Self {
        let mut v = RnsInt {
            residues,
            channels,
            negative: false,
        };
        v.negative = v.to_bigint().sign() == Sign::Minus;
        v
    }

    /// Debug-only guard: residues taken against different moduli cannot be
    /// combined channel by channel.
    #[inline]
    fn debug_assert_same_channels(&self, other: &Self) {
        debug_assert!(
            self.channels == other.channels,
            "RNS operands must share the same channels"
        );
    }
}
/// Applies `f(a[i], b[i], moduli[i])` to every channel and collects the results.
///
/// The three slices are zipped, so the output is as long as the shortest of
/// them. When `a` has at least `RAYON_CHANNEL_THRESHOLD` entries the channels
/// are processed through rayon's parallel iterators; otherwise a plain
/// sequential iterator is used.
fn channel_map(
    a: &[u64],
    b: &[u64],
    moduli: &[u64],
    f: impl Fn(u64, u64, u64) -> u64 + Sync + Send,
) -> Vec<u64> {
    use rayon::prelude::*;

    // Few channels: stay on the calling thread.
    if a.len() < RAYON_CHANNEL_THRESHOLD {
        return a
            .iter()
            .zip(b.iter())
            .zip(moduli.iter())
            .map(|((&lhs, &rhs), &modulus)| f(lhs, rhs, modulus))
            .collect();
    }

    a.par_iter()
        .zip(b.par_iter())
        .zip(moduli.par_iter())
        .map(|((&lhs, &rhs), &modulus)| f(lhs, rhs, modulus))
        .collect()
}
/// Reconstructs the unsigned integer with the given residues using Garner's
/// mixed-radix form of the Chinese Remainder Theorem.
///
/// The mixed-radix digits `c[i]` satisfy
/// `x = c[0] + c[1]*m[0] + c[2]*m[0]*m[1] + ...`, each `c[i]` being computed
/// modulo `m[i]` only, so all per-channel work stays in machine words and
/// only the final Horner evaluation uses big integers.
///
/// Residues are reduced modulo their own modulus before use, so an unreduced
/// input still yields the representative below the product of the moduli.
/// An empty input yields zero.
///
/// # Panics
///
/// Panics if the slices differ in length, if any modulus is zero, or if
/// `mod_inverse` fails for some pair of moduli (the moduli are expected to
/// be pairwise coprime).
pub fn garner_crt(residues: &[u64], moduli: &[u64]) -> BigUint {
    let k = residues.len();
    assert_eq!(k, moduli.len(), "residue/moduli length mismatch");
    if k == 0 {
        return BigUint::zero();
    }
    assert!(
        moduli.iter().all(|&m| m != 0),
        "CRT moduli must be non-zero"
    );

    // Start from canonical residues; this keeps c[0] < m[0] and lets the
    // subtraction below run in u64 without overflow.
    let mut c: Vec<u64> = residues
        .iter()
        .zip(moduli)
        .map(|(&r, &m)| r % m)
        .collect();

    for i in 1..k {
        let mi = moduli[i];
        for j in 0..i {
            let inv = mod_inverse(moduli[j] % mi, mi)
                .expect("channels must be pairwise coprime for CRT");
            let cj = c[j] % mi;
            let ci = c[i];
            // (ci - cj) mod mi without ever forming ci + mi, which could
            // exceed u64 for moduli above 2^63.
            let diff = if ci >= cj { ci - cj } else { ci + (mi - cj) };
            c[i] = ((diff as u128 * inv as u128) % mi as u128) as u64;
        }
    }

    // Horner evaluation of the mixed-radix digits, most significant first.
    let mut result = BigUint::from(c[k - 1]);
    for (&m, &digit) in moduli[..k - 1].iter().zip(&c[..k - 1]).rev() {
        result = result * BigUint::from(m) + BigUint::from(digit);
    }
    result
}
/// Returns `(a + b) mod m` for a single channel.
///
/// The sum is formed in `u128`, so it cannot overflow for any `u64` inputs
/// (the same widening `gpu_mul_channel` uses for the product).
///
/// # Panics
///
/// Panics if `m` is zero.
#[inline]
pub fn gpu_add_channel(a: u64, b: u64, m: u64) -> u64 {
    ((u128::from(a) + u128::from(b)) % u128::from(m)) as u64
}
/// Returns `(a * b) mod m` for a single channel.
///
/// The product is formed in `u128`, so it cannot overflow for any `u64`
/// inputs.
///
/// # Panics
///
/// Panics if `m` is zero.
#[inline]
pub fn gpu_mul_channel(a: u64, b: u64, m: u64) -> u64 {
    let product = u128::from(a) * u128::from(b);
    let reduced = product % u128::from(m);
    reduced as u64
}
/// Largest modulus considered "safe": 65535, i.e. `2^16 - 1`.
///
/// NOTE(review): nothing in this file reads this constant. Presumably it
/// bounds moduli so that the product of two residues fits in 32 bits for a
/// narrower arithmetic path; confirm against the consumer.
pub const MAX_SAFE_MODULUS: u64 = u16::MAX as u64;
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel set used by every test: `Channels::standard(16)`.
    fn ch() -> Channels {
        Channels::standard(16)
    }

    /// A positive value survives encode followed by decode.
    #[test]
    fn roundtrip_positive() {
        let expected = BigInt::from(123_456_789);
        let encoded = RnsInt::from_i64(123_456_789, ch());
        assert_eq!(encoded.to_bigint(), expected);
    }

    /// A negative value keeps both its sign flag and its magnitude.
    #[test]
    fn roundtrip_negative() {
        let encoded = RnsInt::from_i64(-42, ch());
        assert!(encoded.negative);
        assert_eq!(encoded.to_bigint(), BigInt::from(-42));
    }

    /// Channel-wise add, sub (both orders) and mul match ordinary arithmetic.
    #[test]
    fn add_sub_mul() {
        let a = RnsInt::from_i64(1000, ch());
        let b = RnsInt::from_i64(337, ch());

        let sum = a.add(&b).to_bigint();
        assert_eq!(sum, BigInt::from(1337));

        let difference = a.sub(&b).to_bigint();
        assert_eq!(difference, BigInt::from(663));

        let reversed_difference = b.sub(&a).to_bigint();
        assert_eq!(reversed_difference, BigInt::from(-663));

        let product = a.mul(&b).to_bigint();
        assert_eq!(product, BigInt::from(337_000));
    }

    /// Textbook CRT instances.
    #[test]
    fn garner_classic() {
        // x = 2 (mod 3), x = 3 (mod 5), x = 2 (mod 7)  =>  x = 23
        let sun_tzu = garner_crt(&[2, 3, 2], &[3, 5, 7]);
        assert_eq!(sun_tzu, BigUint::from(23u8));

        // x = 0 (mod 2), x = 1 (mod 3), x = 0 (mod 5)  =>  x = 10
        let ten = garner_crt(&[0, 1, 0], &[2, 3, 5]);
        assert_eq!(ten, BigUint::from(10u8));
    }

    /// `is_zero` is true for zero and false for one.
    #[test]
    fn is_zero_works() {
        let zero = RnsInt::zero(ch());
        assert!(zero.is_zero());

        let one = RnsInt::from_i64(1, ch());
        assert!(!one.is_zero());
    }
}