use pyo3::prelude::*;
use pyo3::types::PyList;
use crate::OABuilder;
use crate::oa::OA;
/// Python class `OAParams`: a small bundle of orthogonal-array parameters.
///
/// Each field is exposed to Python as a read-only attribute (getter only, no
/// setter). No `#[new]` constructor is declared alongside this struct here.
/// NOTE(review): confirm whether one exists elsewhere; otherwise Python code
/// cannot instantiate this class directly.
#[pyclass(name = "OAParams")]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PyOAParams {
    /// Number of runs (presumably the row count of the array; confirm against `OA`).
    #[pyo3(get)]
    pub runs: usize,
    /// Number of factors (presumably the column count of the array; confirm against `OA`).
    #[pyo3(get)]
    pub factors: usize,
    /// Strength of the array.
    #[pyo3(get)]
    pub strength: u32,
}
/// Python class `OA`: a thin wrapper that exposes a Rust `OA` value to Python.
///
/// The `construct` and `construct_mixed` functions in this module return
/// instances of this class.
#[pyclass(name = "OA")]
pub struct PyOA {
    /// The wrapped `OA` value that the Python-facing methods read from.
    inner: OA,
}
// Python-facing accessors for `OA`. Each one forwards to the wrapped value.
#[pymethods]
impl PyOA {
    /// Number of runs, as reported by the wrapped `OA::runs()`.
    #[getter]
    fn runs(&self) -> usize {
        self.inner.runs()
    }

    /// Number of factors, as reported by the wrapped `OA::factors()`.
    #[getter]
    fn factors(&self) -> usize {
        self.inner.factors()
    }

    /// Strength, as reported by the wrapped `OA::strength()`.
    #[getter]
    fn strength(&self) -> u32 {
        self.inner.strength()
    }

    /// Returns the array contents as a Python `list` of row `list`s.
    ///
    /// The outer list has one entry per row of `OA::data()`. Each inner list
    /// holds that row's values in column order.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while appending to a Python list.
    fn data(&self, py: Python<'_>) -> PyResult<PyObject> {
        let matrix = self.inner.data();
        let n_cols = matrix.ncols();
        let outer = PyList::empty(py);
        for r in 0..matrix.nrows() {
            let row = PyList::empty(py);
            (0..n_cols).try_for_each(|c| row.append(matrix[[r, c]]))?;
            outer.append(row)?;
        }
        Ok(outer.into())
    }

    /// Returns `True` when every entry of the balance report's
    /// `factor_balance` is true. If that collection is empty, this is
    /// vacuously `True`.
    fn is_balanced(&self) -> bool {
        let report = self.inner.balance_report();
        report.factor_balance.iter().all(|balanced| *balanced)
    }
}
/// Builds an orthogonal array from a single `levels` value and a `factors` count.
///
/// `levels`, `factors` and `strength` are forwarded unchanged to `OABuilder`.
/// On the Python side `strength` defaults to 2.
///
/// Raises `ValueError` with the builder's error message if `build()` fails.
#[pyfunction]
#[pyo3(signature = (levels, factors, strength=2))]
fn construct(levels: u32, factors: usize, strength: u32) -> PyResult<PyOA> {
    OABuilder::new()
        .levels(levels)
        .factors(factors)
        .strength(strength)
        .build()
        .map(|inner| PyOA { inner })
        .map_err(|err| pyo3::exceptions::PyValueError::new_err(err.to_string()))
}
/// Builds an orthogonal array from a per-factor list of level counts.
///
/// `levels` goes to `OABuilder::mixed_levels` and `strength` to
/// `OABuilder::strength`. On the Python side `strength` defaults to 2.
///
/// Raises `ValueError` with the builder's error message if `build()` fails.
#[pyfunction]
#[pyo3(signature = (levels, strength=2))]
fn construct_mixed(levels: Vec<u32>, strength: u32) -> PyResult<PyOA> {
    OABuilder::new()
        .mixed_levels(levels)
        .strength(strength)
        .build()
        .map(|inner| PyOA { inner })
        .map_err(|err| pyo3::exceptions::PyValueError::new_err(err.to_string()))
}
/// The `taguchi` Python extension module.
///
/// Registers the `OAParams` and `OA` classes, then the `construct` and
/// `construct_mixed` functions.
#[pymodule]
fn taguchi(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    // Classes.
    m.add_class::<PyOAParams>()?;
    m.add_class::<PyOA>()?;

    // Free functions.
    let construct_fn = wrap_pyfunction!(construct, m)?;
    m.add_function(construct_fn)?;
    let construct_mixed_fn = wrap_pyfunction!(construct_mixed, m)?;
    m.add_function(construct_mixed_fn)?;

    Ok(())
}