use crate::oa::{OAConstructor, OAResult, OA};
use crate::utils::{poly_eval, to_base_fixed, Integer, OarsError, OarsResult};
use ndarray::Array2;
use num::pow::pow;
use oars_proc_macro::Checked;
use primes::is_prime;
use std::cmp::min;
#[cfg(feature = "parallel")]
use crate::oa::ParOAConstructor;
#[cfg(feature = "parallel")]
use ndarray::{concatenate, parallel::prelude::*, Axis};
#[cfg(feature = "parallel")]
use rayon::iter::IntoParallelIterator;
impl<T: Integer> BushChecked<T> {
pub fn verify(self) -> OarsResult<Bush<T>> {
if !is_prime(self.prime_base.to_u64().unwrap()) {
return Err(OarsError::InvalidParams("Base is not prime".to_owned()));
}
if self.dimensions < T::from(2).unwrap()
|| self.dimensions > self.prime_base + T::from(1).unwrap()
{
return Err(OarsError::InvalidParams(
"Dimensions must be less than `prime_base` + 1".to_owned(),
));
}
if self.strength < T::from(1).unwrap() || self.strength > self.prime_base {
return Err(OarsError::InvalidParams(
"`strength` must be between 1 and `prime_base` (inclusive)".to_owned(),
));
}
Ok(Bush {
strength: self.strength,
prime_base: self.prime_base,
dimensions: self.dimensions,
})
}
}
/// Parameters for the Bush orthogonal-array construction.
///
/// Build this through `BushChecked::verify`, which validates the parameters. The generated
/// array has `prime_base^strength` rows, `dimensions` columns, `prime_base` levels and index 1.
#[derive(Checked)]
pub struct Bush<T: Integer> {
    /// Number of levels per column; `verify` requires it to be prime.
    pub prime_base: T,
    /// Strength of the array; `verify` requires `1 <= strength <= prime_base`. It is also the
    /// number of base-`prime_base` digits used as polynomial coefficients for each row.
    pub strength: T,
    /// Number of columns (factors); `verify` requires `2 <= dimensions <= prime_base + 1`.
    pub dimensions: T,
}
impl<T: Integer> OAConstructor<T> for Bush<T> {
    /// Generate the Bush orthogonal array serially.
    ///
    /// The array has `prime_base^strength` rows and `dimensions` columns. Row `i` uses the
    /// `strength` base-`prime_base` digits of `i` (from `to_base_fixed`) as polynomial
    /// coefficients.
    ///
    /// - Each column `j < min(dimensions, prime_base)` stores `poly_eval(&coeffs, j) % prime_base`.
    /// - When `dimensions == prime_base + 1`, the final column stores `i % prime_base`.
    ///
    /// # Panics
    ///
    /// Panics if any of the `to_usize` / `T::from` numeric conversions fail.
    fn gen(&self) -> OAResult<T> {
        let n = pow(self.prime_base, self.strength.to_usize().unwrap())
            .to_usize()
            .unwrap();
        // Columns filled by polynomial evaluation: at most `prime_base` of them.
        let poly_dims = min(self.dimensions, self.prime_base).to_usize().unwrap();
        // `verify` caps `dimensions` at `prime_base + 1`. Only that maximal case has an extra
        // column, and it sits right after the polynomial columns, at index `prime_base`.
        let extra_col = if self.dimensions == self.prime_base + T::from(1).unwrap() {
            Some(self.prime_base.to_usize().unwrap())
        } else {
            None
        };
        let mut points = Array2::<T>::zeros((n, self.dimensions.to_usize().unwrap()));

        for i in 0..n {
            let row = T::from(i).unwrap();
            let coeffs = to_base_fixed(row, self.prime_base, self.strength);
            for j in 0..poly_dims {
                points[[i, j]] = poly_eval(&coeffs, T::from(j).unwrap()) % self.prime_base;
            }
            if let Some(col) = extra_col {
                // Derived from `i` itself: `i - 1` would underflow `usize` on row 0.
                // NOTE(review): Bush's extra column is normally the leading polynomial
                // coefficient. `i % prime_base` is that coefficient only if `poly_eval` treats
                // the lowest base-`prime_base` digit from `to_base_fixed` as the highest-degree
                // term. Confirm against `utils`.
                points[[i, col]] = row % self.prime_base;
            }
        }

        Ok(OA {
            strength: self.strength,
            levels: self.prime_base,
            index: T::from(1).unwrap(),
            factors: self.dimensions,
            points,
        })
    }
}
#[cfg(feature = "parallel")]
impl<T: Integer> ParOAConstructor<T> for Bush<T> {
    /// Generate the Bush orthogonal array, evaluating rows in parallel with rayon.
    ///
    /// The array has `prime_base^strength` rows. Row `r` uses the `strength`
    /// base-`prime_base` digits of `r` (from `to_base_fixed`) as polynomial coefficients.
    ///
    /// - Each of the first `min(dimensions, prime_base)` columns `j` stores
    ///   `poly_eval(&coeffs, j) % prime_base`.
    /// - When `dimensions == prime_base + 1`, a final column holding `r % prime_base` is appended.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ndarray::concatenate` if appending the final column fails.
    ///
    /// # Panics
    ///
    /// Panics if any of the `to_usize` / `T::from` numeric conversions fail.
    fn gen_par(&self) -> OAResult<T> {
        let n = pow(self.prime_base, self.strength.to_usize().unwrap())
            .to_usize()
            .unwrap();
        let poly_dims = min(
            self.dimensions.to_usize().unwrap(),
            self.prime_base.to_usize().unwrap(),
        );

        // Each row is independent, so the polynomial columns are filled one row per task.
        let mut poly_points = Array2::<T>::zeros((n, poly_dims));
        poly_points
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(row_idx, mut row)| {
                let coeffs =
                    to_base_fixed(T::from(row_idx).unwrap(), self.prime_base, self.strength);
                for (col_idx, cell) in row.iter_mut().enumerate() {
                    *cell = poly_eval(&coeffs, T::from(col_idx).unwrap()) % self.prime_base;
                }
            });

        let points = if self.dimensions == self.prime_base + T::from(1).unwrap() {
            // Derived from the row index itself: `row_idx - 1` would underflow on row 0.
            let last_col = Array2::<T>::from_shape_fn((n, 1), |(row_idx, _)| {
                T::from(row_idx).unwrap() % self.prime_base
            });
            concatenate(Axis(1), &[poly_points.view(), last_col.view()])?
        } else {
            poly_points
        };

        Ok(OA {
            strength: self.strength,
            levels: self.prime_base,
            index: T::from(1).unwrap(),
            factors: self.dimensions,
            points,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-prime bases must be rejected.
    #[test]
    fn bush_non_prime() {
        for &prime_base in &[4, 9, 100] {
            let bush = BushChecked {
                strength: 2,
                prime_base,
                dimensions: 3,
            };
            assert!(
                bush.verify().is_err(),
                "prime_base = {} should be rejected",
                prime_base
            );
        }
    }

    /// `dimensions` outside `2..=prime_base + 1` must be rejected.
    #[test]
    fn bush_bad_dims() {
        for &(prime_base, dimensions) in &[(5, 7), (7, 11), (13, 17), (5, 1), (5, 0)] {
            let bush = BushChecked {
                strength: 2,
                prime_base,
                dimensions,
            };
            assert!(
                bush.verify().is_err(),
                "prime_base = {}, dimensions = {} should be rejected",
                prime_base,
                dimensions
            );
        }
    }

    /// `strength` outside `1..=prime_base` must be rejected.
    #[test]
    fn bush_bad_strength() {
        for &strength in &[0, 6, 7] {
            let bush = BushChecked {
                strength,
                prime_base: 5,
                dimensions: 3,
            };
            assert!(
                bush.verify().is_err(),
                "strength = {} should be rejected",
                strength
            );
        }
    }

    /// The inclusive bounds of every parameter must be accepted.
    #[test]
    fn bush_valid_bounds() {
        for &(strength, prime_base, dimensions) in &[(5, 5, 6), (1, 2, 2), (2, 3, 4)] {
            let bush = BushChecked {
                strength,
                prime_base,
                dimensions,
            };
            assert!(
                bush.verify().is_ok(),
                "strength = {}, prime_base = {}, dimensions = {} should be accepted",
                strength,
                prime_base,
                dimensions
            );
        }
    }
}