use crate::pipeline::StreamingPreprocessor;
/// Second-order polynomial feature expansion.
///
/// The expanded vector is the input itself followed by pairwise products
/// `x_i * x_j`, emitted with `i` ascending and, for each `i`, `j` ascending:
///
/// * full mode (`PolynomialFeatures::new()`): every pair with `j >= i`, so
///   the squares `x_i * x_i` are included;
/// * interaction-only mode (`PolynomialFeatures::interaction_only()`): only
///   pairs with `j > i`.
#[derive(Clone, Debug)]
pub struct PolynomialFeatures {
    /// When `true`, squared terms `x_i * x_i` are left out of the expansion.
    interaction_only: bool,
    /// Input width recorded by the most recent `update_and_transform` call;
    /// `None` on a fresh value and again after `reset`.
    n_input_features: Option<usize>,
}

impl PolynomialFeatures {
    /// Shared constructor: an expander in the requested mode with no input
    /// width recorded yet.
    fn with_mode(interaction_only: bool) -> Self {
        Self {
            interaction_only,
            n_input_features: None,
        }
    }

    /// Creates an expander in full mode (squares and cross products).
    pub fn new() -> Self {
        Self::with_mode(false)
    }

    /// Creates an expander in interaction-only mode (cross products only).
    pub fn interaction_only() -> Self {
        Self::with_mode(true)
    }

    /// Returns `true` when squared terms are excluded from the expansion.
    pub fn is_interaction_only(&self) -> bool {
        self.interaction_only
    }

    /// Length of the expanded vector for an input of `d` features:
    /// the `d` originals plus `d * (d - 1) / 2` products in interaction-only
    /// mode, or `d * (d + 1) / 2` products in full mode.
    fn compute_output_dim(&self, d: usize) -> usize {
        let pair_terms = if self.interaction_only {
            // `saturating_sub` keeps `d == 0` from underflowing.
            d * d.saturating_sub(1) / 2
        } else {
            d * (d + 1) / 2
        };
        d + pair_terms
    }

    /// Builds the expanded vector: a copy of `features` followed by the
    /// pairwise products selected by the current mode.
    fn generate(&self, features: &[f64]) -> Vec<f64> {
        let mut expanded = Vec::with_capacity(self.compute_output_dim(features.len()));
        expanded.extend_from_slice(features);

        // 0 keeps the diagonal (x_i * x_i); 1 starts each row just past it.
        let diagonal_skip = usize::from(self.interaction_only);
        for (row, &lhs) in features.iter().enumerate() {
            // `row + diagonal_skip` never exceeds `features.len()`, so the
            // slice is at worst empty.
            let partners = &features[row + diagonal_skip..];
            expanded.extend(partners.iter().map(|&rhs| lhs * rhs));
        }
        expanded
    }
}
impl Default for PolynomialFeatures {
fn default() -> Self {
Self::new()
}
}
impl StreamingPreprocessor for PolynomialFeatures {
    /// Records the width of `features`, then returns its polynomial expansion.
    ///
    /// The expansion itself does not depend on earlier inputs; the only state
    /// written here is the input width later reported through `output_dim`.
    fn update_and_transform(&mut self, features: &[f64]) -> Vec<f64> {
        let width = features.len();
        self.n_input_features = Some(width);
        self.generate(features)
    }

    /// Returns the polynomial expansion of `features` without modifying `self`.
    fn transform(&self, features: &[f64]) -> Vec<f64> {
        self.generate(features)
    }

    /// Expanded length corresponding to the recorded input width, or `None`
    /// when no width is currently recorded.
    fn output_dim(&self) -> Option<usize> {
        let width = self.n_input_features?;
        Some(self.compute_output_dim(width))
    }

    /// Forgets the recorded input width, so `output_dim` reports `None` again.
    fn reset(&mut self) {
        self.n_input_features = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Full mode on three features: originals first, then products row by row.
    #[test]
    fn degree2_full_output() {
        let expander = PolynomialFeatures::new();
        let wanted = vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0];
        assert_eq!(expander.transform(&[1.0, 2.0, 3.0]), wanted);
    }

    /// Interaction-only mode on three features: no squared terms.
    #[test]
    fn interaction_only_output() {
        let expander = PolynomialFeatures::interaction_only();
        let wanted = vec![1.0, 2.0, 3.0, 2.0, 3.0, 6.0];
        assert_eq!(expander.transform(&[1.0, 2.0, 3.0]), wanted);
    }

    /// A single feature in full mode yields the feature and its square.
    #[test]
    fn single_feature_full() {
        let expander = PolynomialFeatures::new();
        assert_eq!(expander.transform(&[5.0]), vec![5.0, 25.0]);
    }

    /// A single feature in interaction-only mode has no pair to multiply.
    #[test]
    fn single_feature_interaction_only() {
        let expander = PolynomialFeatures::interaction_only();
        assert_eq!(expander.transform(&[5.0]), vec![5.0]);
    }

    /// `output_dim` and the actual output length agree with the closed-form
    /// counts for input widths 1 through 5, in both modes.
    #[test]
    fn output_dim_correct() {
        let mut full = PolynomialFeatures::new();
        for d in 1usize..=5 {
            let sample: Vec<f64> = (1..=d).map(|x| x as f64).collect();
            let produced = full.update_and_transform(&sample);
            let expected = d + d * (d + 1) / 2;
            assert_eq!(
                full.output_dim(),
                Some(expected),
                "full mode: d={d}, expected output_dim={expected}"
            );
            assert_eq!(produced.len(), expected);
            full.reset();
        }

        let mut interactions = PolynomialFeatures::interaction_only();
        for d in 1usize..=5 {
            let sample: Vec<f64> = (1..=d).map(|x| x as f64).collect();
            let produced = interactions.update_and_transform(&sample);
            let expected = d + d * d.saturating_sub(1) / 2;
            assert_eq!(
                interactions.output_dim(),
                Some(expected),
                "interaction_only mode: d={d}, expected output_dim={expected}"
            );
            assert_eq!(produced.len(), expected);
            interactions.reset();
        }
    }

    /// Repeated `transform` calls on the same input give the same output.
    #[test]
    fn stateless_deterministic() {
        let expander = PolynomialFeatures::new();
        let sample = [3.0, -1.0, 2.5, 0.0];
        let runs: Vec<Vec<f64>> = (0..3).map(|_| expander.transform(&sample)).collect();
        assert_eq!(runs[0], runs[1]);
        assert_eq!(runs[1], runs[2]);
    }

    /// Both entry points produce the same expansion for the same input.
    #[test]
    fn update_and_transform_equals_transform() {
        let mut expander = PolynomialFeatures::new();
        let sample = [1.5, -2.0, 0.5];
        let via_transform = expander.transform(&sample);
        let via_update = expander.update_and_transform(&sample);
        assert_eq!(
            via_transform, via_update,
            "transform and update_and_transform must produce identical output"
        );
    }

    /// `output_dim` is `None` before any update and again after `reset`.
    #[test]
    fn reset_clears_dim() {
        let mut expander = PolynomialFeatures::new();
        assert_eq!(expander.output_dim(), None);

        let _ = expander.update_and_transform(&[1.0, 2.0]);
        assert_eq!(expander.output_dim(), Some(5));

        expander.reset();
        assert_eq!(expander.output_dim(), None);
    }

    /// An empty input expands to an empty output in both modes.
    #[test]
    fn empty_input() {
        let nothing = Vec::<f64>::new();

        let full = PolynomialFeatures::new();
        assert_eq!(full.transform(&[]), nothing);

        let interactions = PolynomialFeatures::interaction_only();
        assert_eq!(interactions.transform(&[]), nothing);
    }

    /// The accessor reflects which constructor was used.
    #[test]
    fn is_interaction_only_accessor() {
        assert!(!PolynomialFeatures::new().is_interaction_only());
        assert!(PolynomialFeatures::interaction_only().is_interaction_only());
    }

    /// Two features, full mode: x0, x1, x0*x0, x0*x1, x1*x1.
    #[test]
    fn two_features_full() {
        let expander = PolynomialFeatures::new();
        assert_eq!(
            expander.transform(&[2.0, 3.0]),
            vec![2.0, 3.0, 4.0, 6.0, 9.0]
        );
    }

    /// Two features, interaction-only mode: x0, x1, x0*x1.
    #[test]
    fn two_features_interaction_only() {
        let expander = PolynomialFeatures::interaction_only();
        assert_eq!(expander.transform(&[2.0, 3.0]), vec![2.0, 3.0, 6.0]);
    }
}