use ndarray::Array2;
use super::Constructor;
use crate::error::{Error, Result};
use crate::gf::DynamicGf;
use crate::oa::{OA, OAParams};
use crate::utils::is_prime_power;
/// A difference scheme `D(r, c, q)` over the finite field `GF(q)`.
///
/// `data` is an `r x c` matrix of field elements. The scheme property (checked by
/// `DifferenceScheme::verify`) is that, for every pair of distinct columns, the
/// row-wise differences hit each field element exactly `r / q` times.
#[derive(Clone, Debug)]
pub struct DifferenceScheme {
    /// The `r x c` scheme matrix; entries are field elements encoded as integers.
    pub data: Array2<u32>,
    /// Number of symbols, i.e. the order of `field`.
    pub q: u32,
    /// Field whose arithmetic tables are used for sums and differences.
    pub field: DynamicGf,
}
impl DifferenceScheme {
    /// Wraps an existing matrix as a difference scheme over `field`.
    ///
    /// No check is made that `data` actually has the difference-scheme property;
    /// call [`DifferenceScheme::verify`] for that.
    ///
    /// # Panics
    ///
    /// Panics if `field.order() != q`.
    pub fn new(data: Array2<u32>, q: u32, field: DynamicGf) -> Self {
        assert_eq!(field.order(), q);
        Self { data, q, field }
    }

    /// Builds the linear difference scheme `D(q, q, q)` whose entry `(i, j)` is
    /// the field product `i * j`.
    ///
    /// For two distinct columns `j1 != j2`, the row-wise differences are
    /// `i * (j1 - j2)`. As `i` ranges over the field, these cover every element
    /// exactly once.
    ///
    /// # Errors
    ///
    /// Returns [`Error::LevelsNotPrimePower`] if `q` is not a prime power, and
    /// propagates any error from constructing the field.
    pub fn linear(q: u32) -> Result<Self> {
        if !is_prime_power(q) {
            return Err(Error::LevelsNotPrimePower {
                levels: q,
                algorithm: "DifferenceScheme::linear",
            });
        }
        let field = DynamicGf::new(q)?;
        let n = q as usize;
        let tables = field.tables();
        let data = Array2::from_shape_fn((n, n), |(i, j)| tables.mul(i as u32, j as u32));
        Ok(Self { data, q, field })
    }

    /// Returns a copy of this scheme with an all-zero column prepended at index 0.
    ///
    /// The existing columns shift right by one. The result need not be a
    /// difference scheme. For example, column 0 of [`DifferenceScheme::linear`]
    /// is already all zeros, so prepending another zero column breaks the property.
    pub fn with_zero_column(&self) -> Self {
        let (rows, cols) = self.data.dim();
        let data = Array2::from_shape_fn((rows, cols + 1), |(i, j)| {
            if j == 0 {
                0
            } else {
                self.data[[i, j - 1]]
            }
        });
        Self {
            data,
            q: self.q,
            field: self.field.clone(),
        }
    }

    /// Checks the difference-scheme property.
    ///
    /// Returns `true` iff all of the following hold:
    /// - `q` is non-zero and divides the number of rows;
    /// - every entry lies in `0..q`;
    /// - for every pair of distinct columns `(c1, c2)`, the differences
    ///   `data[r, c1] - data[r, c2]` take each field value exactly `rows / q` times.
    ///
    /// Malformed input yields `false` rather than a panic.
    pub fn verify(&self) -> bool {
        let s = self.q as usize;
        let (rows, cols) = self.data.dim();
        // Guard the division below; a scheme over zero symbols is meaningless.
        if s == 0 || rows % s != 0 {
            return false;
        }
        // Entries outside 0..q are not field elements. Reject them here instead of
        // letting the table lookup (or `counts[diff]`) index out of bounds.
        if self.data.iter().any(|&v| v >= self.q) {
            return false;
        }
        let expected_count = rows / s;
        let tables = self.field.tables();
        // One buffer reused for every column pair.
        let mut counts: Vec<usize> = vec![0; s];
        for c1 in 0..cols {
            let left = self.data.column(c1);
            for c2 in (c1 + 1)..cols {
                let right = self.data.column(c2);
                counts.fill(0);
                for (&a, &b) in left.iter().zip(right.iter()) {
                    counts[tables.sub(a, b) as usize] += 1;
                }
                if counts.iter().any(|&n| n != expected_count) {
                    return false;
                }
            }
        }
        true
    }

    /// Kronecker sum of two schemes over the same field.
    ///
    /// The result has `r1 * r2` rows and `c1 * c2` columns. Entry
    /// `(i1 * r2 + i2, j1 * c2 + j2)` equals `self[i1, j1] + other[i2, j2]`,
    /// using field addition. When both inputs are difference schemes, so is the
    /// result.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameters error if the two schemes have different `q`.
    pub fn kronecker_sum(&self, other: &Self) -> Result<Self> {
        if self.q != other.q {
            return Err(Error::invalid_params(format!(
                "Cannot compute Kronecker sum: schemes have different levels {} and {}",
                self.q, other.q
            )));
        }
        let (r1, c1) = self.data.dim();
        let (r2, c2) = other.data.dim();
        let tables = self.field.tables();
        // The closure only runs for a non-empty shape, so r2 and c2 are non-zero there.
        let data = Array2::from_shape_fn((r1 * r2, c1 * c2), |(row, col)| {
            let (i1, i2) = (row / r2, row % r2);
            let (j1, j2) = (col / c2, col % c2);
            tables.add(self.data[[i1, j1]], other.data[[i2, j2]])
        });
        Ok(Self {
            data,
            q: self.q,
            field: self.field.clone(),
        })
    }

    /// Develops the scheme into an orthogonal array with `r * q` runs.
    ///
    /// Run `i * q + g` is scheme row `i` with the field element `g` added to
    /// every entry. The reported strength is 2 when there are at least two
    /// columns, and 1 otherwise.
    ///
    /// This does not call [`DifferenceScheme::verify`]. If `self` is not a valid
    /// difference scheme, the returned array's stated strength may be wrong.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the OA parameters.
    pub fn to_oa(&self) -> Result<OA> {
        let s = self.q;
        let levels = s as usize;
        let (rows, factors) = self.data.dim();
        let runs = rows * levels;
        let tables = self.field.tables();
        let oa_data = Array2::from_shape_fn((runs, factors), |(run, j)| {
            let i = run / levels;
            let g = (run % levels) as u32;
            tables.add(self.data[[i, j]], g)
        });
        let strength = if factors > 1 { 2 } else { 1 };
        let params = OAParams::new(runs, factors, s, strength)?;
        Ok(OA::new(oa_data, params))
    }
}
#[derive(Debug, Clone)]
pub struct LinearDifferenceScheme {
q: u32,
}
impl LinearDifferenceScheme {
pub fn new(q: u32) -> Result<Self> {
if !is_prime_power(q) {
return Err(Error::LevelsNotPrimePower {
levels: q,
algorithm: "LinearDifferenceScheme",
});
}
Ok(Self { q })
}
}
impl Constructor for LinearDifferenceScheme {
    /// Identifier reported as `algorithm` in errors raised by this constructor.
    fn name(&self) -> &'static str {
        "LinearDifferenceScheme"
    }

    /// Human-readable description of the produced array family.
    // NOTE(review): the scheme actually built by `DifferenceScheme::linear` is
    // D(q, q, q); the (q+1)-th column is appended during expansion in `construct`.
    // Confirm whether this label is intentional before relying on it.
    fn family(&self) -> &'static str {
        "OA(q^2, k, q, 2) via Difference Scheme D(q, q+1, q)"
    }

    /// Number of levels per factor (`q`).
    fn levels(&self) -> u32 {
        self.q
    }

    /// Nominal strength of the family. `construct` reports `min(2, factors)`.
    fn strength(&self) -> u32 {
        2
    }

    /// Number of runs, `q^2`.
    fn runs(&self) -> usize {
        (self.q * self.q) as usize
    }

    /// Largest supported factor count, `q + 1`.
    fn max_factors(&self) -> usize {
        (self.q + 1) as usize
    }

    /// Builds an `OA(q^2, factors, q, 2)` for `factors <= q + 1`.
    ///
    /// Columns `0..q` are the developed linear scheme, with value `i * j + g` in
    /// run `i * q + g`. The optional column `q` holds the scheme row index `i`.
    /// That value is constant on each block of `q` runs, where every other column
    /// takes `q` distinct values, so each level pair appears exactly once.
    ///
    /// # Errors
    ///
    /// Returns [`Error::TooManyFactors`] if `factors > q + 1`. Otherwise propagates
    /// errors from building the scheme or the OA parameters.
    fn construct(&self, factors: usize) -> Result<OA> {
        let max_factors = self.max_factors();
        // Validate the request before doing any field/table work.
        if factors > max_factors {
            return Err(Error::TooManyFactors {
                factors,
                max: max_factors,
                algorithm: self.name(),
            });
        }
        let ds = DifferenceScheme::linear(self.q)?;
        let s = ds.q;
        let levels = s as usize;
        let runs = ds.data.nrows() * levels;
        let tables = ds.field.tables();
        let oa_data = Array2::from_shape_fn((runs, factors), |(run, j)| {
            let i = run / levels;
            if j < levels {
                tables.add(ds.data[[i, j]], (run % levels) as u32)
            } else {
                // Only reachable for j == q, since factors is capped at q + 1 above.
                i as u32
            }
        });
        let strength = 2.min(factors as u32);
        let params = OAParams::new(runs, factors, s, strength)?;
        Ok(OA::new(oa_data, params))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oa::verify_strength;

    #[test]
    fn test_linear_ds_construction() {
        let ds = DifferenceScheme::linear(3).unwrap();
        assert_eq!(ds.data.nrows(), 3);
        assert_eq!(ds.data.ncols(), 3);
        assert_eq!(ds.data[[1, 1]], 1);
        assert_eq!(ds.data[[2, 2]], 1);
    }

    #[test]
    fn test_ds_expansion_basic() {
        let ds = DifferenceScheme::linear(3).unwrap();
        let oa = ds.to_oa().unwrap();
        assert_eq!(oa.runs(), 9);
        assert_eq!(oa.factors(), 3);
        assert_eq!(oa.levels(), 3);
        let result = verify_strength(&oa, 2).unwrap();
        assert!(result.is_valid);
    }

    #[test]
    fn test_constructor_interface_l9() {
        let cons = LinearDifferenceScheme::new(3).unwrap();
        let oa = cons.construct(4).unwrap();
        assert_eq!(oa.runs(), 9);
        assert_eq!(oa.factors(), 4);
        assert_eq!(oa.levels(), 3);
        let result = verify_strength(&oa, 2).unwrap();
        assert!(result.is_valid, "L9 from DS should be valid");
    }

    #[test]
    fn test_ds_verification() {
        let ds = DifferenceScheme::linear(3).unwrap();
        assert!(ds.verify());
        // Column 0 of the linear scheme is already all zeros, so a second zero
        // column makes one column pair have constant (zero) differences.
        let ds_aug = ds.with_zero_column();
        assert!(!ds_aug.verify());
    }

    #[test]
    fn test_kronecker_sum() {
        let ds1 = DifferenceScheme::linear(2).unwrap();
        let ds2 = DifferenceScheme::linear(2).unwrap();
        let ds_kron = ds1.kronecker_sum(&ds2).unwrap();
        assert_eq!(ds_kron.data.nrows(), 4);
        assert_eq!(ds_kron.data.ncols(), 4);
        assert_eq!(ds_kron.q, 2);
        assert!(ds_kron.verify());
        let oa = ds_kron.to_oa().unwrap();
        assert_eq!(oa.runs(), 8);
        assert_eq!(oa.factors(), 4);
        let result = verify_strength(&oa, 2).unwrap();
        assert!(result.is_valid);
    }

    #[test]
    fn test_kronecker_sum_rejects_mismatched_levels() {
        let ds2 = DifferenceScheme::linear(2).unwrap();
        let ds3 = DifferenceScheme::linear(3).unwrap();
        assert!(ds2.kronecker_sum(&ds3).is_err());
    }

    #[test]
    fn test_rejects_non_prime_power_levels() {
        assert!(matches!(
            DifferenceScheme::linear(6),
            Err(Error::LevelsNotPrimePower { .. })
        ));
        assert!(matches!(
            LinearDifferenceScheme::new(6),
            Err(Error::LevelsNotPrimePower { .. })
        ));
    }

    #[test]
    fn test_construct_rejects_too_many_factors() {
        let cons = LinearDifferenceScheme::new(3).unwrap();
        assert!(matches!(
            cons.construct(5),
            Err(Error::TooManyFactors {
                factors: 5,
                max: 4,
                ..
            })
        ));
    }

    #[test]
    fn test_extension_field_full_array() {
        // GF(4) is not a prime field, so this exercises non-modular table arithmetic.
        let ds = DifferenceScheme::linear(4).unwrap();
        assert!(ds.verify());
        let cons = LinearDifferenceScheme::new(4).unwrap();
        let oa = cons.construct(5).unwrap();
        assert_eq!(oa.runs(), 16);
        assert_eq!(oa.factors(), 5);
        assert_eq!(oa.levels(), 4);
        let result = verify_strength(&oa, 2).unwrap();
        assert!(result.is_valid, "OA(16, 5, 4, 2) from DS should be valid");
    }
}