extern crate rayon;
extern crate rug;
#[macro_use]
extern crate log;
extern crate env_logger;
mod utils;
use rayon::prelude::*;
use rug::Integer;
use std::error::Error;
use std::fmt;
use std::fmt::Display;
use utils::*;
/// Errors reported by [`compute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeError {
    /// Fewer than two moduli were supplied, so there is nothing to compare against.
    NotEnoughModuli,
}

impl ComputeError {
    /// The bare error message (without the `ComputeError:` prefix), shared by
    /// the `Display` impl and `Error::description`.
    fn message(&self) -> &'static str {
        match *self {
            ComputeError::NotEnoughModuli => "Not enough moduli",
        }
    }
}

impl Display for ComputeError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Formatted directly rather than through the deprecated
        // `Error::description`, which would emit a deprecation warning.
        write!(f, "ComputeError: {}", self.message())
    }
}

impl Error for ComputeError {
    // Still implemented so existing callers of this (deprecated) method keep
    // receiving the real message instead of the std default placeholder.
    fn description(&self) -> &str {
        self.message()
    }
}
/// Outcome of [`compute`]: one entry per input modulus, holding `Some(g)` when
/// `g` (the gcd of that modulus with the rest of the batch) is not 1, or `None`
/// otherwise.
pub type ComputeResult = std::result::Result<Vec<Option<Integer>>, ComputeError>;
/// Levels of a product tree, stored root-first.
///
/// `levels[0]` holds the root (the product of every leaf) and the last entry
/// holds the leaves themselves; every inner node is the product of the nodes
/// directly below it.
struct ProductTree {
    /// Tree levels, ordered from the root down to the leaves.
    levels: Vec<Vec<Integer>>,
}
/// Builds the product tree of `moduli`, consuming them as its leaves.
///
/// Each level is obtained from the one below by multiplying adjacent pairs of
/// nodes in parallel. When a level has an odd number of nodes, its last node is
/// carried up unchanged, so no modulus is ever dropped from the product.
/// Construction stops as soon as a level holds at most one node.
///
/// The returned levels are ordered root-first: `levels[0]` is the root and the
/// last level is `moduli` itself. An empty or single-element input yields a
/// tree whose only level is the input (previously an empty input recursed
/// forever).
fn compute_product_tree(moduli: Vec<Integer>) -> ProductTree {
    // Built leaves-first, then reversed so callers see the root first.
    let mut levels = vec![moduli];
    loop {
        let above: Vec<Integer> = {
            let below = levels
                .last()
                .expect("levels always contains at least the leaves");
            if below.len() <= 1 {
                break;
            }
            below
                .par_chunks(2)
                .map(|pair| match pair {
                    [left, right] => Integer::from(left * right),
                    // Odd-sized level: carry the unpaired node up as-is. Its
                    // remainder then passes through unchanged, because
                    // (x mod v^2) mod v^2 == x mod v^2.
                    [unpaired] => unpaired.clone(),
                    _ => unreachable!("par_chunks(2) yields chunks of one or two nodes"),
                })
                .collect()
        };
        levels.push(above);
    }
    levels.reverse();
    ProductTree { levels }
}
/// Walks the product tree from the root down, reducing it to a remainder tree.
///
/// The root level is kept as-is. Every node of each following level is squared
/// and replaced by `parent % node^2`, where `parent` is the already-reduced
/// value at index `i / 2` of the level above. Assuming each node divides its
/// parent (as in a product tree), the final level holds `root mod leaf^2` for
/// every leaf, in the same order as the leaves of `tree`.
///
/// Consumes `tree`; nodes are squared in place rather than copied.
///
/// `None` can only arise from a tree with no levels. Such a tree also underflows
/// `len() - 1`, so callers must never pass one.
fn compute_remainders(tree: ProductTree) -> Option<Vec<Integer>> {
    let level_count = tree.levels.len() - 1;
    trace!("computing remainders for {} levels", level_count);

    let mut levels = tree.levels.into_iter();
    // The root needs no reduction: it is the full product itself.
    let mut reduced = levels.next()?;

    for (depth, nodes) in (1usize..).zip(levels) {
        trace!("computing remainder level {}/{}", depth, level_count);
        let parents = reduced;
        reduced = nodes
            .into_par_iter()
            .enumerate()
            .map(|(index, mut node)| {
                node.square_mut();
                &parents[index / 2] % node
            })
            .collect();
    }

    Some(reduced)
}
/// Final step of the batch GCD.
///
/// For each modulus `n` paired with its remainder `r` (expected to be
/// `P mod n^2`, where `P` is the product of the batch), returns `gcd(r / n, n)`.
///
/// When `n` divides `P`, it also divides `r`, and `r / n ≡ P / n (mod n)`. The
/// result is therefore `gcd(P / n, n)`: the part of `n` shared with the other
/// moduli, or 1 if nothing is shared.
///
/// `remainders[i]` must correspond to `moduli[i]`. If the slices differ in
/// length, only the common prefix is processed.
fn compute_gcds(remainders: &[Integer], moduli: &[Integer]) -> Vec<Integer> {
    trace!("computing quotients and gcd");
    moduli
        .par_iter()
        .zip(remainders.par_iter())
        .map(|(n, r)| {
            let reduced = Integer::from(r / n);
            reduced.gcd(n)
        })
        .collect()
}
/// Finds, for every modulus, the factor it shares with the rest of the batch.
///
/// Uses a product tree followed by a remainder tree (batch GCD), so every gcd is
/// obtained without comparing each pair of moduli. Entry `i` of the result is
/// `Some(g)` when `g = gcd(moduli[i], product of the other moduli)` is not 1, and
/// `None` when `moduli[i]` is coprime to all the others. This assumes the
/// padding added by `pad_ints` is multiplicatively neutral (TODO confirm). A
/// modulus that is fully covered by the others is returned whole.
///
/// # Errors
///
/// Returns [`ComputeError::NotEnoughModuli`] when fewer than two moduli are given.
///
/// # Panics
///
/// NOTE(review): a zero modulus makes the remainder tree divide by zero, which
/// `rug` reports by panicking. Input is not validated here; confirm whether
/// callers guarantee non-zero moduli.
pub fn compute(moduli: &[Integer]) -> ComputeResult {
    if moduli.len() < 2 {
        return Err(ComputeError::NotEnoughModuli);
    }
    let (padded_moduli, pad_size) = pad_ints(moduli.to_vec());
    trace!("added {} padding to moduli", pad_size);
    trace!("computing product tree");
    let tree = compute_product_tree(padded_moduli);
    // `compute_remainders` only yields `None` for a tree with no levels, and a
    // product tree always has at least its leaf level.
    let remainders = compute_remainders(tree)
        .expect("a product tree always has at least its leaf level");
    let gcds = compute_gcds(&unpad_ints(remainders, pad_size), moduli);
    Ok(gcds
        .into_iter()
        .map(|gcd| if gcd == 1 { None } else { Some(gcd) })
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts plain machine integers into a vector of `Integer`s.
    fn ints(values: &[i32]) -> Vec<Integer> {
        values.iter().map(|&v| Integer::from(v)).collect()
    }

    #[test]
    fn it_should_fail_on_zero_moduli() {
        let empty: [Integer; 0] = [];
        assert!(compute(&empty).is_err());
    }

    #[test]
    fn it_should_fail_on_single_moduli() {
        let single = vec![Integer::new()];
        assert!(compute(&single).is_err());
    }

    #[test]
    fn it_should_return_gcd_of_two_moduli() {
        let result = compute(&ints(&[6, 15])).unwrap();
        let expected = vec![Some(Integer::from(3)); 2];
        assert_eq!(result, expected);
    }

    #[test]
    fn it_should_find_gcd_for_many_moduli() {
        let moduli = ints(&[31 * 41, 41, 61, 71 * 31, 101 * 131, 131 * 151]);
        let expected: Vec<Option<Integer>> = [
            Some(31 * 41),
            Some(41),
            None,
            Some(31),
            Some(131),
            Some(131),
        ]
        .iter()
        .map(|&factor| factor.map(Integer::from))
        .collect();
        let result = compute(&moduli).unwrap();
        assert_eq!(result, expected);
    }
}