use crate::perm_vec::PermutationVector;
use crate::utils::{Float, Integer};
use crate::utils::{OarsError, OarsResult};
use itertools::Itertools;
use ndarray::Array2;
use num::{pow, ToPrimitive};
use rand::prelude::*;
use std::ops::Index;
#[cfg(feature = "serialize")]
use serde_derive::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// An orthogonal array stored together with the parameters it claims to satisfy.
///
/// The parameters are kept next to the point matrix so that routines such as `verify`
/// (which checks them against `points`) and `normalize` can read them.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct OA<T: Integer> {
    /// Number of levels (distinct symbols) a factor can take; entries are expected to lie in
    /// `0..levels`.
    pub levels: T,
    /// Strength: every projection onto `strength` columns is expected to contain each
    /// possible tuple of levels equally often.
    pub strength: T,
    /// Number of factors, i.e. the expected number of columns of `points`.
    pub factors: T,
    /// Index: how many times each `strength`-tuple is expected to appear in every projection.
    pub index: T,
    /// The array itself: one run per row, one factor per column.
    pub points: Array2<T>,
}
impl<T> fmt::Display for OA<T>
where
    T: fmt::Display + Integer,
{
    /// Render the array's parameters, one per line, followed by the point matrix.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("OA:\n")?;
        writeln!(f, " levels: {}", self.levels)?;
        writeln!(f, " strength: {}", self.strength)?;
        writeln!(f, " factors: {}", self.factors)?;
        writeln!(f, " index: {}", self.index)?;
        // The matrix is separated from the header by a blank line and followed by one.
        f.write_str("points:\n\n")?;
        writeln!(f, "{}\n", self.points)
    }
}
impl<T: Integer> Index<[usize; 2]> for OA<T> {
    type Output = T;

    /// Borrow the entry at `[row, column]` of the underlying `points` matrix.
    ///
    /// Out-of-range coordinates are handled exactly as indexing `points` directly would
    /// handle them.
    fn index(&self, idx: [usize; 2]) -> &Self::Output {
        let [row, col] = idx;
        &self.points[[row, col]]
    }
}
/// Outcome of building an orthogonal array: the array itself, or the `OarsError` describing
/// why it could not be built.
pub type OAResult<T> = Result<OA<T>, OarsError>;
/// Normalize an orthogonal array into a point set in the unit hypercube (Art Owen's technique).
///
/// Entry `a` in column `j` of `oa.points` becomes `(pi_j(a) + jitter * u) / oa.levels`.
/// Here `u` is drawn uniformly from `[0, 1)` for every entry, and `pi_j` is a permutation of
/// the levels `0..oa.levels`.
///
/// Without `randomize`, every `pi_j` is the identity. With `randomize`, each column gets its
/// own random permutation of the *levels*. Permuting symbols within a column keeps the
/// orthogonal-array stratification intact; shuffling the rows of each column independently
/// would destroy it.
///
/// Row `i` of the returned matrix corresponds to row `i` of `oa.points`.
///
/// # Errors
///
/// Returns [`OarsError::InvalidParams`] in either of these cases:
/// - `jitter` is outside `[0, 1]` or is NaN;
/// - an entry of `oa.points` is not a level in `0..oa.levels`.
pub fn normalize<T: Integer, U: Float>(
    oa: &OA<T>,
    jitter: U,
    randomize: bool,
) -> OarsResult<Array2<U>> {
    if oa.points.ndim() != 2 {
        return Err(OarsError::InvalidParams(
            "The `points` array in `oa` must be two dimensional".to_owned(),
        ));
    }

    // `contains` is false for NaN, so a NaN jitter is rejected instead of silently poisoning
    // every output coordinate.
    if !(0.0..=1.0).contains(&jitter.to_f64().unwrap()) {
        return Err(OarsError::InvalidParams(
            "`jitter` must be between 0 and 1".to_owned(),
        ));
    }

    // A negative level count cannot describe any valid entry; treating it as zero levels
    // makes the range check below reject every entry.
    let levels = oa.levels.to_usize().unwrap_or(0);

    // Entries index the per-column level permutations below. An out-of-range entry would
    // otherwise panic there, or produce a coordinate outside the unit cube.
    let all_in_range = oa
        .points
        .iter()
        .all(|value| matches!(value.to_usize(), Some(level) if level < levels));
    if !all_in_range {
        return Err(OarsError::InvalidParams(
            "Every entry of `oa.points` must be in the range [0, `oa.levels`)".to_owned(),
        ));
    }

    let (n_rows, n_cols) = oa.points.dim();

    // One permutation of the levels per column (identity unless `randomize` is set).
    let mut perms: Vec<PermutationVector> = Vec::with_capacity(n_cols);
    for _ in 0..n_cols {
        let mut perm = PermutationVector::new(levels);
        if randomize {
            perm.shuffle();
        }
        perms.push(perm);
    }

    // Divide by the number of levels, not the strength. A permuted level plus a jitter
    // offset below 1 lies in `[0, levels)`, so dividing maps it into `[0, 1)`.
    let scale = U::from(oa.levels).unwrap();
    let mut rng = rand::thread_rng();
    let mut point_set = Array2::<U>::zeros((n_rows, n_cols));
    for ((i, j), value) in oa.points.indexed_iter() {
        // The range check above guarantees this conversion succeeds and is in bounds.
        let level = perms[j][value.to_usize().unwrap()];
        let offset = jitter * U::from(rng.gen::<f64>()).unwrap();
        point_set[[i, j]] = (U::from(level).unwrap() + offset) / scale;
    }

    Ok(point_set)
}
/// Check whether `oa` is a valid orthogonal array for its declared parameters.
///
/// `oa` is accepted when all of the following hold:
/// - `oa.points` has exactly `oa.factors` columns;
/// - every entry is a level in `0..oa.levels`;
/// - for every choice of `oa.strength` distinct columns, each of the
///   `oa.levels.pow(oa.strength)` possible level tuples appears in exactly `oa.index` rows.
///
/// # Errors
///
/// Returns [`OarsError::InvalidParams`] if `oa.points` is not two-dimensional.
pub fn verify<T: Integer>(oa: &OA<T>) -> OarsResult<bool> {
    if oa.points.ndim() != 2 {
        return Err(OarsError::InvalidParams(
            "`oa.points` must be two-dimensional".to_owned(),
        ));
    }

    let n_cols = oa.points.shape()[1];
    if n_cols != oa.factors.to_usize().unwrap() {
        return Ok(false);
    }

    let levels = oa.levels.to_u64().unwrap();
    let index = oa.index.to_u64().unwrap();
    let strength = oa.strength.to_usize().unwrap();

    // Tuples are encoded below as base-`levels` numbers. An out-of-range symbol would alias
    // onto a valid code (with 3 levels, the pair (3, 0) encodes like (0, 1)), or onto a code
    // that is never checked. Either way an invalid array could verify, so reject such
    // entries up front.
    let all_in_range = oa
        .points
        .iter()
        .all(|value| matches!(value.to_u64(), Some(level) if level < levels));
    if !all_in_range {
        return Ok(false);
    }

    // `None` means `levels^strength` overflows u64. Such parameters are treated as invalid:
    // a non-degenerate array would need more rows than could ever exist.
    let n_tuples = levels.checked_pow(oa.strength.to_u32().unwrap());

    for selection in (0..n_cols).combinations(strength) {
        let n_tuples = match n_tuples {
            Some(n) => n,
            None => return Ok(false),
        };

        let mut tuple_count: HashMap<u64, u64> = HashMap::new();
        for row in oa.points.outer_iter() {
            // The first selected column is the least significant base-`levels` digit.
            // Every digit is below `levels` and `levels^strength` fits in u64, so no
            // overflow is possible.
            let code: u64 = selection
                .iter()
                .enumerate()
                .map(|(power, &col)| row[col].to_u64().unwrap() * pow(levels, power))
                .sum();
            *tuple_count.entry(code).or_insert(0) += 1;
        }

        // Read-only lookups: a missing code simply means that tuple never occurred.
        let balanced =
            (0..n_tuples).all(|code| tuple_count.get(&code).copied().unwrap_or(0) == index);
        if !balanced {
            return Ok(false);
        }
    }

    Ok(true)
}
/// A type that can build an orthogonal array whose entries have integer type `T`.
pub trait OAConstructor<T: Integer> {
    /// Build the orthogonal array, or return an `OarsError` explaining why it cannot be built.
    fn gen(&self) -> OAResult<T>;
}
/// Counterpart of `OAConstructor` exposing `gen_par` instead of `gen`.
///
/// NOTE(review): the name suggests a parallel construction strategy; confirm against the
/// implementors.
pub trait ParOAConstructor<T: Integer> {
    /// Build the orthogonal array, or return an `OarsError` explaining why it cannot be built.
    fn gen_par(&self) -> OAResult<T>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Wrap `points` in an `OA` with 3 levels, 3 factors and index 1.
    fn three_level_oa(points: Array2<i32>, strength: i32) -> OA<i32> {
        OA {
            levels: 3,
            strength,
            factors: 3,
            index: 1,
            points,
        }
    }

    #[test]
    fn test_verify_oa_bad_in() {
        // Strength 3 over 3 levels with index 1 requires all 27 triples exactly once.
        // Nine runs cannot provide that (and `[1, 0, 0]` even appears twice).
        let oa = three_level_oa(
            arr2(&[
                [0, 0, 0],
                [0, 1, 1],
                [0, 2, 2],
                [1, 0, 0],
                [1, 0, 0],
                [1, 2, 2],
                [2, 0, 0],
                [2, 1, 1],
                [2, 2, 2],
            ]),
            3,
        );
        let is_valid = verify(&oa).unwrap();
        assert!(!is_valid);
    }

    #[test]
    fn test_verify_oa_good_in() {
        // Every ordered pair of levels appears exactly once in each pair of columns.
        let oa = three_level_oa(
            arr2(&[
                [0, 0, 0],
                [0, 1, 1],
                [0, 2, 2],
                [1, 0, 1],
                [1, 1, 2],
                [1, 2, 0],
                [2, 0, 2],
                [2, 1, 0],
                [2, 2, 1],
            ]),
            2,
        );
        let is_valid = verify(&oa).unwrap();
        assert!(is_valid);
    }
}