use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
use ndarray::Array2;
/// Configuration for expanding each input row into polynomial and interaction terms.
///
/// Fitting it (via `FitUnsupervised`) yields a `FittedPolynomialFeatures` whose output is a
/// constant bias column of ones followed by every monomial of total degree `1..=degree`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PolynomialFeatures {
/// Maximum total degree of the generated terms. `new` sets it to 2; `fit` rejects 0.
pub degree: usize,
/// When true, every feature appears with power 1 at most, so only products of distinct
/// features are generated (no `x_i^k` with `k > 1`).
pub interaction_only: bool,
}
impl PolynomialFeatures {
pub fn new() -> Self {
Self {
degree: 2,
interaction_only: false,
}
}
pub fn with_degree(mut self, degree: usize) -> Self {
self.degree = degree;
self
}
pub fn with_interaction_only(mut self, interaction_only: bool) -> Self {
self.interaction_only = interaction_only;
self
}
}
impl Default for PolynomialFeatures {
fn default() -> Self {
Self::new()
}
}
/// State produced by fitting `PolynomialFeatures`.
///
/// Holds the expected input width and the precomputed list of monomials; `transform`
/// emits one output column per monomial.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedPolynomialFeatures<F: Float> {
/// Number of columns in the data passed to `fit`; `transform` rejects any other width.
n_features: usize,
/// Maximum total degree requested at fit time (reported by `degree()`).
degree: usize,
/// Whether only interaction terms were requested (reported by `interaction_only()`).
interaction_only: bool,
/// One entry per output column. Each entry lists `(feature_index, power)` factors whose
/// product gives that column; an empty entry is the constant bias column of ones.
combinations: Vec<Vec<(usize, usize)>>,
/// Ties the fitted state to the float type `F` without storing any `F` values.
_marker: std::marker::PhantomData<F>,
}
/// Lists every monomial of total degree `0..=max_degree` over `n_features` inputs.
///
/// Each monomial is a `Vec` of `(feature_index, power)` factors. Feature indices strictly
/// increase, every power is at least 1, and the total degree is the sum of the powers.
///
/// Ordering:
/// - The empty monomial (the constant bias term) always comes first.
/// - Then come all degree-1 terms, then all degree-2 terms, and so on.
/// - Within one degree, terms are ordered by feature index, with higher powers of earlier
///   features first. For two features and degree 3 this gives
///   `x0^3, x0^2*x1, x0*x1^2, x1^3`.
///
/// With `interaction_only`, every power is 1. Only products of distinct features appear,
/// and no term can have a degree greater than `n_features`.
fn enumerate_combinations(
    n_features: usize,
    max_degree: usize,
    interaction_only: bool,
) -> Vec<Vec<(usize, usize)>> {
    /// Appends to `combos` every monomial of exactly `remaining` additional degree (>= 1).
    ///
    /// Only features `start_feature..n_features` are used, and each result is prefixed
    /// by the factors already in `current`.
    fn extend_exact(
        start_feature: usize,
        remaining: usize,
        n_features: usize,
        interaction_only: bool,
        current: &mut Vec<(usize, usize)>,
        combos: &mut Vec<Vec<(usize, usize)>>,
    ) {
        // In interaction-only mode each factor contributes exactly 1. A branch that needs
        // more degree than there are unused features can never complete, so skip it.
        if interaction_only && remaining > n_features.saturating_sub(start_feature) {
            return;
        }
        // The largest power a single factor may take does not depend on the feature.
        let max_power = if interaction_only { 1 } else { remaining };
        for feat in start_feature..n_features {
            // Higher powers first, so that e.g. x0^2 precedes x0*x1.
            for power in (1..=max_power).rev() {
                current.push((feat, power));
                if power == remaining {
                    combos.push(current.clone());
                } else {
                    extend_exact(
                        feat + 1,
                        remaining - power,
                        n_features,
                        interaction_only,
                        current,
                        combos,
                    );
                }
                current.pop();
            }
        }
    }

    // The bias term (empty monomial) comes first.
    let mut combos: Vec<Vec<(usize, usize)>> = vec![Vec::new()];
    // A monomial never has more factors than features (indices strictly increase), nor
    // more factors than its degree, so this buffer never reallocates.
    let mut current = Vec::with_capacity(n_features.min(max_degree));
    for degree in 1..=max_degree {
        extend_exact(
            0,
            degree,
            n_features,
            interaction_only,
            &mut current,
            &mut combos,
        );
    }
    combos
}
impl<F: Float> FitUnsupervised<F> for PolynomialFeatures {
    type Fitted = FittedPolynomialFeatures<F>;

    /// Records the input width of `x` and precomputes the monomials that `transform`
    /// will emit. Only the shape of `x` is used, not its values.
    ///
    /// # Errors
    /// - `RustMlError::EmptyInput` if `x` has no elements (checked first).
    /// - `RustMlError::InvalidParameter` if `degree` is 0.
    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
        if x.is_empty() {
            return Err(RustMlError::EmptyInput("input array is empty".into()));
        }
        if self.degree == 0 {
            return Err(RustMlError::InvalidParameter(
                "degree must be at least 1".into(),
            ));
        }

        let (degree, interaction_only) = (self.degree, self.interaction_only);
        let n_features = x.ncols();
        Ok(FittedPolynomialFeatures {
            combinations: enumerate_combinations(n_features, degree, interaction_only),
            n_features,
            degree,
            interaction_only,
            _marker: std::marker::PhantomData,
        })
    }
}
impl<F: Float> Transform<F> for FittedPolynomialFeatures<F> {
    /// Expands `x` into one output column per monomial recorded at fit time.
    ///
    /// An empty monomial yields a column of ones (the bias term). Every other column holds,
    /// for each row, the product of that row's features raised to the recorded powers.
    ///
    /// # Errors
    /// Returns `RustMlError::ShapeMismatch` when `x` does not have the number of columns
    /// seen during `fit`.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
        let n_in = x.ncols();
        if n_in != self.n_features {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} features, got {}",
                self.n_features, n_in
            )));
        }

        // Start from all ones so the bias column needs no further work.
        let mut expanded = Array2::<F>::ones((x.nrows(), self.combinations.len()));
        for (col, terms) in self.combinations.iter().enumerate() {
            if terms.is_empty() {
                continue;
            }
            let mut target = expanded.column_mut(col);
            for (row, cell) in target.iter_mut().enumerate() {
                let mut product = F::one();
                for &(feat, power) in terms {
                    let base = x[[row, feat]];
                    // base^power as `power` successive multiplications.
                    for _ in 0..power {
                        product *= base;
                    }
                }
                *cell = product;
            }
        }
        Ok(expanded)
    }
}
impl<F: Float> FittedPolynomialFeatures<F> {
    /// Number of input columns seen during `fit`; `transform` rejects any other width.
    pub fn n_input_features(&self) -> usize {
        self.n_features
    }

    /// Number of columns `transform` produces: one per monomial, bias column included.
    pub fn n_output_features(&self) -> usize {
        let monomials = &self.combinations;
        monomials.len()
    }

    /// Maximum total degree that was requested at fit time.
    pub fn degree(&self) -> usize {
        self.degree
    }

    /// Whether only products of distinct features were requested at fit time.
    pub fn interaction_only(&self) -> bool {
        self.interaction_only
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Asserts that `out` has exactly `expected.len()` columns and that row `row`
    /// matches `expected` element-wise within 1e-10.
    fn assert_row(out: &Array2<f64>, row: usize, expected: &[f64]) {
        assert_eq!(out.ncols(), expected.len());
        for (j, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(out[[row, j]], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_degree2_two_features() {
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        // [1, x0, x1, x0^2, x0*x1, x1^2]
        assert_row(&out, 0, &[1.0, 2.0, 3.0, 4.0, 6.0, 9.0]);
    }

    #[test]
    fn test_interaction_only_degree2() {
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new().with_interaction_only(true);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        // [1, x0, x1, x0*x1]
        assert_row(&out, 0, &[1.0, 2.0, 3.0, 6.0]);
    }

    #[test]
    fn test_degree3_single_feature() {
        let x = array![[3.0]];
        let poly = PolynomialFeatures::new().with_degree(3);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_row(&out, 0, &[1.0, 3.0, 9.0, 27.0]);
    }

    #[test]
    fn test_degree3_two_features_ordering() {
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new().with_degree(3);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        // [1, x0, x1, x0^2, x0*x1, x1^2, x0^3, x0^2*x1, x0*x1^2, x1^3]
        assert_row(
            &out,
            0,
            &[1.0, 2.0, 3.0, 4.0, 6.0, 9.0, 8.0, 12.0, 18.0, 27.0],
        );
    }

    #[test]
    fn test_interaction_only_degree_exceeds_features() {
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new()
            .with_degree(4)
            .with_interaction_only(true);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        // With only two features, no interaction term can exceed degree 2.
        assert_row(&out, 0, &[1.0, 2.0, 3.0, 6.0]);
    }

    #[test]
    fn test_degree1() {
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new().with_degree(1);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_row(&out, 0, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_degree0_error() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new().with_degree(0);
        let result = FitUnsupervised::<f64>::fit(&poly, &x);
        assert!(matches!(result, Err(RustMlError::InvalidParameter(_))));
    }

    #[test]
    fn test_multiple_rows() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.nrows(), 2);
        assert_row(&out, 0, &[1.0, 1.0, 2.0, 1.0, 2.0, 4.0]);
        assert_row(&out, 1, &[1.0, 3.0, 4.0, 9.0, 12.0, 16.0]);
    }

    #[test]
    fn test_three_features_degree2() {
        let x = array![[1.0, 2.0, 3.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_row(
            &out,
            0,
            &[1.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0],
        );
    }

    #[test]
    fn test_three_features_interaction_only() {
        let x = array![[2.0, 3.0, 5.0]];
        let poly = PolynomialFeatures::new().with_interaction_only(true);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_row(&out, 0, &[1.0, 2.0, 3.0, 5.0, 6.0, 10.0, 15.0]);
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let poly = PolynomialFeatures::new();
        let result = FitUnsupervised::<f64>::fit(&poly, &x);
        assert!(matches!(result, Err(RustMlError::EmptyInput(_))));
    }

    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(matches!(
            fitted.transform(&x_wrong),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }

    #[test]
    fn test_transform_zero_rows() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let no_rows: Array2<f64> = Array2::zeros((0, 2));
        let out = fitted.transform(&no_rows).unwrap();
        assert_eq!(out.dim(), (0, 6));
    }

    #[test]
    fn test_bias_column_all_ones() {
        let x = array![[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        for i in 0..3 {
            assert_abs_diff_eq!(out[[i, 0]], 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_n_output_features() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        assert_eq!(fitted.n_input_features(), 2);
        assert_eq!(fitted.n_output_features(), 6);
        assert_eq!(fitted.degree(), 2);
        assert!(!fitted.interaction_only());
    }

    #[test]
    fn test_output_count_matches_closed_form() {
        /// Binomial coefficient C(n, k) via the exact multiplicative formula.
        fn binomial(n: usize, k: usize) -> usize {
            if k > n {
                return 0;
            }
            (0..k).fold(1, |acc, i| acc * (n - i) / (i + 1))
        }

        for n in 1..=4usize {
            // Only the column count matters to `fit`.
            let x: Array2<f64> = Array2::zeros((1, n));
            for d in 1..=4usize {
                let full_poly = PolynomialFeatures::new().with_degree(d);
                let full = FitUnsupervised::<f64>::fit(&full_poly, &x).unwrap();
                // All monomials of degree <= d in n variables: C(n + d, d).
                assert_eq!(full.n_output_features(), binomial(n + d, d));

                let inter_poly = PolynomialFeatures::new()
                    .with_degree(d)
                    .with_interaction_only(true);
                let inter = FitUnsupervised::<f64>::fit(&inter_poly, &x).unwrap();
                // Subsets of at most d distinct features: sum of C(n, k) for k in 0..=min(d, n).
                let expected: usize = (0..=d.min(n)).map(|k| binomial(n, k)).sum();
                assert_eq!(inter.n_output_features(), expected);
            }
        }
    }

    #[test]
    fn test_f32() {
        let x = array![[2.0f32, 3.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f32>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 6);
        let expected = [1.0f32, 2.0, 3.0, 4.0, 6.0, 9.0];
        for (j, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(out[[0, j]], want, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_default() {
        let poly = PolynomialFeatures::default();
        assert_eq!(poly.degree, 2);
        assert!(!poly.interaction_only);
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        /// Deterministic pseudo-random data in [-2, 2] derived from `seed`.
        fn make_data(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};
            let mut values = Vec::with_capacity(rows * cols);
            for i in 0..(rows * cols) {
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                let bits = h.finish();
                let v = (bits as f64 / u64::MAX as f64) * 4.0 - 2.0;
                values.push(v);
            }
            Array2::from_shape_vec((rows, cols), values).unwrap()
        }

        proptest! {
            #[test]
            fn poly_bias_column_all_ones(
                rows in 1..20usize,
                cols in 1..5usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let poly = PolynomialFeatures::new();
                let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
                let out = fitted.transform(&x).unwrap();
                for i in 0..rows {
                    prop_assert!((out[[i, 0]] - 1.0).abs() < 1e-10,
                        "bias column should be 1.0, got {}", out[[i, 0]]);
                }
            }

            #[test]
            fn poly_original_features_preserved(
                rows in 1..20usize,
                cols in 1..5usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let poly = PolynomialFeatures::new();
                let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
                let out = fitted.transform(&x).unwrap();
                for i in 0..rows {
                    for j in 0..cols {
                        prop_assert!((out[[i, 1 + j]] - x[[i, j]]).abs() < 1e-10,
                            "original feature not preserved at ({}, {})", i, j);
                    }
                }
            }
        }
    }
}