use crate::utils::OarsError;
use itertools::{zip, Itertools};
use ndarray::Array2;
use std::collections::{HashMap, HashSet};
/// Outcome of building an [`SOA`]: the array on success, or an [`OarsError`]
/// on failure. This is the return type of [`SOAConstructor::gen`].
pub type SOAResult = Result<SOA, OarsError>;
/// Implemented by types that know how to build an [`SOA`].
pub trait SOAConstructor {
/// Builds the [`SOA`] described by `self`.
///
/// # Errors
///
/// Returns an [`OarsError`] if the array cannot be constructed.
fn gen(&self) -> SOAResult;
}
/// A point set together with the strength and base it is claimed to satisfy.
/// Use `verify` to check the claim.
#[derive(Debug)]
pub struct SOA {
/// Strength `t`. `verify` checks stratification over every partition of
/// this value into positive parts.
pub strength: u32,
/// Base `s`. `verify` only accepts entries in `0..base.pow(strength)`.
pub base: u32,
/// Design matrix as consumed by `verify`: each row is one point and each
/// column is one dimension.
pub points: Array2<u32>,
}
/// A list of rows, each row being a `Vec<T>`.
type Vec2D<T> = Vec<Vec<T>>;

/// Depth-first enumeration step for [`sum_perms`].
///
/// `arr` is the partial partition built so far (positive, non-decreasing
/// parts) and `reduced_num` is what is still missing to reach `sum`. Every
/// completion of `arr` whose further parts are no smaller than `arr`'s last
/// part is appended to `res`, in lexicographically increasing order.
fn sum_perms_helper(sum: u32, reduced_num: u32, arr: &[u32], res: &mut Vec2D<u32>) {
    if reduced_num == 0 {
        res.push(arr.to_vec());
        return;
    }
    // Parts are kept non-decreasing, so each partition is produced exactly
    // once. `max(1)` guards against a zero part, which would recurse forever.
    let smallest = arr.last().copied().unwrap_or(1).max(1);
    // A part larger than the remainder can never complete the partition, so
    // there is no point scanning past it.
    let largest = reduced_num.min(sum);
    // One scratch buffer is reused for every candidate part instead of
    // allocating a new vector per candidate.
    let mut next_arr = Vec::with_capacity(arr.len() + 1);
    next_arr.extend_from_slice(arr);
    for k in smallest..=largest {
        next_arr.push(k);
        sum_perms_helper(sum, reduced_num - k, &next_arr, res);
        next_arr.pop();
    }
}

/// Returns every partition of `sum` into positive integer parts.
///
/// Each partition appears exactly once, with its parts in non-decreasing
/// order, and the partitions are listed in lexicographic order. For example,
/// `sum_perms(3)` is `[[1, 1, 1], [1, 2], [3]]`. `sum_perms(0)` yields a
/// single empty partition.
fn sum_perms(sum: u32) -> Vec2D<u32> {
    let mut res = Vec::new();
    sum_perms_helper(sum, sum, &[], &mut res);
    res
}
/// Checks whether `soa.points` has the stratification property of a strong
/// orthogonal array with strength `soa.strength` and base `soa.base`.
///
/// The check covers every partition `(u_1, ..., u_g)` of `soa.strength` into
/// positive parts, taken in every distinct order, combined with every choice
/// of `g` columns:
///
/// - Each selected column value is collapsed onto `base^u_i` levels by integer
///   division with `base^(strength - u_i)`.
/// - The array passes only if every combination of collapsed levels occurs at
///   least once, and all combinations occur equally often.
///
/// Returns `false` as soon as any of the following happens:
///
/// - an entry collapses outside the expected level range (an entry
///   `>= base^strength`);
/// - a combination never occurs;
/// - the counts are unequal.
///
/// Returns `true` otherwise.
pub fn verify(soa: &SOA) -> bool {
    let n_cols = soa.points.shape()[1];
    for partition in sum_perms(soa.strength) {
        let n_parts = partition.len();
        // Column tuples below come out in increasing column order only, so
        // the exponents must be tried in every order to cover every
        // column/exponent pairing (e.g. both [1, 2] and [2, 1]). `unique`
        // drops repeated orderings caused by equal parts such as [1, 1, 1].
        for exponents in partition.into_iter().permutations(n_parts).unique() {
            // Collapsing onto `base^u` levels keeps the `u` leading base-`base`
            // digits of a value in `0..base^strength`.
            let divisors: Vec<u32> = exponents
                .iter()
                .map(|&u| soa.base.pow(soa.strength - u))
                .collect();
            let mut counts: HashMap<Vec<u32>, u32> = exponents
                .iter()
                .map(|&u| 0..soa.base.pow(u))
                .multi_cartesian_product()
                .map(|cell| (cell, 0))
                .collect();
            for cols in (0..n_cols).combinations(n_parts) {
                counts.values_mut().for_each(|count| *count = 0);
                for row in soa.points.genrows() {
                    let cell: Vec<u32> = zip(&divisors, &cols)
                        .map(|(&divisor, &col)| row[col] / divisor)
                        .collect();
                    match counts.get_mut(&cell) {
                        Some(count) => *count += 1,
                        // The entry is too large for the declared base/strength.
                        None => return false,
                    }
                }
                // Balanced: no combination missing and all equally frequent.
                let distinct: HashSet<u32> = counts.values().copied().collect();
                if distinct.contains(&0) || distinct.len() > 1 {
                    return false;
                }
            }
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use rand::prelude::*;
    use std::collections::HashSet;

    /// Number of partitions of `n` into positive parts. It is computed with a
    /// coin-change style dynamic programme and serves as an independent
    /// oracle for `sum_perms`.
    fn partition_count(n: u32) -> usize {
        let n = n as usize;
        let mut ways = vec![0usize; n + 1];
        ways[0] = 1;
        for part in 1..=n {
            for total in part..=n {
                ways[total] += ways[total - part];
            }
        }
        ways[n]
    }

    /// `sum_perms(5)` must return exactly the seven partitions of 5, with no
    /// duplicates and nothing extra.
    #[test]
    fn test_sum_perms_ground_truth() {
        let res = sum_perms(5);
        let res_set: HashSet<Vec<u32>> = res.iter().cloned().collect();
        let ground_truth: HashSet<Vec<u32>> = vec![
            vec![1, 1, 1, 1, 1],
            vec![1, 1, 1, 2],
            vec![1, 2, 2],
            vec![1, 1, 3],
            vec![2, 3],
            vec![1, 4],
            vec![5],
        ]
        .into_iter()
        .collect();
        assert_eq!(res.len(), ground_truth.len());
        assert_eq!(res_set, ground_truth);
    }

    /// For random targets, the outputs must satisfy all of the following:
    /// - they are distinct;
    /// - their count equals the partition count of the target;
    /// - each one consists of positive, non-decreasing parts summing to the
    ///   target.
    #[test]
    fn test_sum_perms_random() {
        let mut rng = thread_rng();
        let mut targets: Vec<u32> = Vec::new();
        for _ in 0..10 {
            targets.push(rng.gen_range(1..25));
        }
        for target in targets {
            let res: Vec2D<u32> = sum_perms(target);
            let distinct: HashSet<&Vec<u32>> = res.iter().collect();
            assert_eq!(distinct.len(), res.len());
            assert_eq!(res.len(), partition_count(target));
            for array in res {
                assert!(array.iter().all(|&part| part >= 1));
                assert!(array.windows(2).all(|w| w[0] <= w[1]));
                assert!(array.into_iter().sum::<u32>() == target);
            }
        }
    }

    #[test]
    fn test_verify_valid_soa() {
        // 24 points in 11 dimensions, strength 3, base 2.
        let ground_truth = array![
            [4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [7, 6, 3, 6, 2, 2, 3, 7, 7, 6, 3],
            [5, 5, 4, 1, 4, 0, 0, 1, 5, 5, 5],
            [6, 3, 7, 6, 3, 6, 2, 2, 3, 7, 7],
            [7, 6, 3, 7, 6, 3, 6, 2, 2, 3, 7],
            [7, 7, 6, 3, 7, 6, 3, 6, 2, 2, 3],
            [5, 5, 5, 4, 1, 5, 4, 1, 4, 0, 1],
            [4, 1, 5, 5, 4, 1, 5, 4, 1, 4, 1],
            [4, 0, 1, 5, 5, 4, 1, 5, 4, 1, 5],
            [6, 2, 2, 3, 7, 7, 6, 3, 7, 6, 3],
            [5, 4, 0, 0, 1, 5, 5, 4, 1, 5, 5],
            [6, 3, 6, 2, 2, 3, 7, 7, 6, 3, 7],
            [3, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6],
            [0, 1, 4, 1, 5, 5, 4, 0, 0, 1, 4],
            [2, 2, 3, 6, 3, 7, 7, 6, 2, 2, 2],
            [1, 4, 0, 1, 4, 1, 5, 5, 4, 0, 0],
            [0, 1, 4, 0, 1, 4, 1, 5, 5, 4, 0],
            [0, 0, 1, 4, 0, 1, 4, 1, 5, 5, 4],
            [2, 2, 2, 3, 6, 2, 3, 6, 3, 7, 6],
            [3, 6, 2, 2, 3, 6, 2, 3, 6, 3, 6],
            [3, 7, 6, 2, 2, 3, 6, 2, 3, 6, 2],
            [1, 5, 5, 4, 0, 0, 1, 4, 0, 1, 4],
            [2, 3, 7, 7, 6, 2, 2, 3, 6, 2, 2],
            [1, 4, 1, 5, 5, 4, 0, 0, 1, 4, 0],
        ];
        let soa = SOA {
            strength: 3,
            base: 2,
            points: ground_truth,
        };
        assert!(verify(&soa));
        // 8 points in 3 dimensions, strength 3, base 2.
        let ground_truth = array![
            [0, 0, 0],
            [2, 3, 6],
            [3, 6, 2],
            [1, 5, 4],
            [6, 2, 3],
            [4, 1, 5],
            [5, 4, 1],
            [7, 7, 7],
        ];
        let soa = SOA {
            strength: 3,
            base: 2,
            points: ground_truth,
        };
        assert!(verify(&soa));
    }

    #[test]
    fn test_verify_invalid_soa() {
        // The valid 24-point array with its 21st row duplicated: 25 rows
        // cannot be spread evenly over 8 cells.
        let ground_truth = array![
            [4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [7, 6, 3, 6, 2, 2, 3, 7, 7, 6, 3],
            [5, 5, 4, 1, 4, 0, 0, 1, 5, 5, 5],
            [6, 3, 7, 6, 3, 6, 2, 2, 3, 7, 7],
            [7, 6, 3, 7, 6, 3, 6, 2, 2, 3, 7],
            [7, 7, 6, 3, 7, 6, 3, 6, 2, 2, 3],
            [5, 5, 5, 4, 1, 5, 4, 1, 4, 0, 1],
            [4, 1, 5, 5, 4, 1, 5, 4, 1, 4, 1],
            [4, 0, 1, 5, 5, 4, 1, 5, 4, 1, 5],
            [6, 2, 2, 3, 7, 7, 6, 3, 7, 6, 3],
            [5, 4, 0, 0, 1, 5, 5, 4, 1, 5, 5],
            [6, 3, 6, 2, 2, 3, 7, 7, 6, 3, 7],
            [3, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6],
            [0, 1, 4, 1, 5, 5, 4, 0, 0, 1, 4],
            [2, 2, 3, 6, 3, 7, 7, 6, 2, 2, 2],
            [1, 4, 0, 1, 4, 1, 5, 5, 4, 0, 0],
            [0, 1, 4, 0, 1, 4, 1, 5, 5, 4, 0],
            [0, 0, 1, 4, 0, 1, 4, 1, 5, 5, 4],
            [2, 2, 2, 3, 6, 2, 3, 6, 3, 7, 6],
            [3, 6, 2, 2, 3, 6, 2, 3, 6, 3, 6],
            [3, 7, 6, 2, 2, 3, 6, 2, 3, 6, 2],
            [3, 7, 6, 2, 2, 3, 6, 2, 3, 6, 2],
            [1, 5, 5, 4, 0, 0, 1, 4, 0, 1, 4],
            [2, 3, 7, 7, 6, 2, 2, 3, 6, 2, 2],
            [1, 4, 1, 5, 5, 4, 0, 0, 1, 4, 0],
        ];
        let soa = SOA {
            strength: 3,
            base: 2,
            points: ground_truth,
        };
        assert!(!verify(&soa));
        // The valid 24-point array with entry [0][3] changed from 0 to 1,
        // which unbalances column 3.
        let ground_truth = array![
            [4, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1],
            [7, 6, 3, 6, 2, 2, 3, 7, 7, 6, 3],
            [5, 5, 4, 1, 4, 0, 0, 1, 5, 5, 5],
            [6, 3, 7, 6, 3, 6, 2, 2, 3, 7, 7],
            [7, 6, 3, 7, 6, 3, 6, 2, 2, 3, 7],
            [7, 7, 6, 3, 7, 6, 3, 6, 2, 2, 3],
            [5, 5, 5, 4, 1, 5, 4, 1, 4, 0, 1],
            [4, 1, 5, 5, 4, 1, 5, 4, 1, 4, 1],
            [4, 0, 1, 5, 5, 4, 1, 5, 4, 1, 5],
            [6, 2, 2, 3, 7, 7, 6, 3, 7, 6, 3],
            [5, 4, 0, 0, 1, 5, 5, 4, 1, 5, 5],
            [6, 3, 6, 2, 2, 3, 7, 7, 6, 3, 7],
            [3, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6],
            [0, 1, 4, 1, 5, 5, 4, 0, 0, 1, 4],
            [2, 2, 3, 6, 3, 7, 7, 6, 2, 2, 2],
            [1, 4, 0, 1, 4, 1, 5, 5, 4, 0, 0],
            [0, 1, 4, 0, 1, 4, 1, 5, 5, 4, 0],
            [0, 0, 1, 4, 0, 1, 4, 1, 5, 5, 4],
            [2, 2, 2, 3, 6, 2, 3, 6, 3, 7, 6],
            [3, 6, 2, 2, 3, 6, 2, 3, 6, 3, 6],
            [3, 7, 6, 2, 2, 3, 6, 2, 3, 6, 2],
            [1, 5, 5, 4, 0, 0, 1, 4, 0, 1, 4],
            [2, 3, 7, 7, 6, 2, 2, 3, 6, 2, 2],
            [1, 4, 1, 5, 5, 4, 0, 0, 1, 4, 0],
        ];
        let soa = SOA {
            strength: 3,
            base: 2,
            points: ground_truth,
        };
        assert!(!verify(&soa));
    }
}