Skip to main content

anofox_ml_preprocessing/
polynomial_features.rs

1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4/// Generates polynomial and interaction features.
5///
6/// For a feature vector `[a, b]` and `degree=2`:
7/// - `interaction_only=false`: `[1, a, b, a^2, ab, b^2]`
8/// - `interaction_only=true`:  `[1, a, b, ab]`
9///
10/// Implements `FitUnsupervised` for pipeline compatibility. The fit step only
11/// records the number of input features; all computation happens in `transform`.
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct PolynomialFeatures {
14    /// Maximum degree of polynomial features.
15    pub degree: usize,
16    /// If true, only produce interaction features (no `x_i^k` for k > 1).
17    pub interaction_only: bool,
18}
19
20impl PolynomialFeatures {
21    /// Create a new `PolynomialFeatures` with default degree 2 and interaction_only = false.
22    pub fn new() -> Self {
23        Self {
24            degree: 2,
25            interaction_only: false,
26        }
27    }
28
29    /// Set the maximum polynomial degree.
30    pub fn with_degree(mut self, degree: usize) -> Self {
31        self.degree = degree;
32        self
33    }
34
35    /// Set whether to produce only interaction features.
36    pub fn with_interaction_only(mut self, interaction_only: bool) -> Self {
37        self.interaction_only = interaction_only;
38        self
39    }
40}
41
42impl Default for PolynomialFeatures {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48/// Fitted PolynomialFeatures — stores the number of input features.
49#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
50#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
51pub struct FittedPolynomialFeatures<F: Float> {
52    n_features: usize,
53    degree: usize,
54    interaction_only: bool,
55    /// Pre-computed list of (exponents) for each output column.
56    /// Each entry is a Vec of (feature_index, power) pairs.
57    combinations: Vec<Vec<(usize, usize)>>,
58    _marker: std::marker::PhantomData<F>,
59}
60
/// Enumerate all combinations of features with total degree up to `max_degree`.
///
/// Each combination is represented as a vec of `(feature_index, power)` pairs
/// where power > 0 and feature indices are strictly increasing. The bias term
/// (degree 0) is an empty vec and is always the first entry.
///
/// Combinations are ordered by total degree, then lexicographically by feature
/// index, matching scikit-learn's convention:
/// - degree 0: `[1]`
/// - degree 1: `[a, b, c, ...]`
/// - degree 2: `[a^2, ab, ac, ..., b^2, bc, ..., c^2, ...]`
///
/// When `interaction_only` is true every power is 1, so a degree that exceeds
/// `n_features` contributes no combinations.
fn enumerate_combinations(
    n_features: usize,
    max_degree: usize,
    interaction_only: bool,
) -> Vec<Vec<(usize, usize)>> {
    /// Append every combination with exactly `target_degree` total power that
    /// extends the prefix in `current` using features `start_feature..n_features`.
    ///
    /// `target_degree` is always positive: the driver loop starts at 1 and the
    /// recursion only continues while some degree is still unallocated.
    fn recurse_exact(
        start_feature: usize,
        target_degree: usize,
        n_features: usize,
        interaction_only: bool,
        current: &mut Vec<(usize, usize)>,
        combos: &mut Vec<Vec<(usize, usize)>>,
    ) {
        debug_assert!(target_degree > 0);

        // With powers capped at 1, each remaining feature can absorb at most one
        // unit of degree; if there are too few features left, nothing below this
        // point can complete, so skip the whole subtree.
        if interaction_only && target_degree > n_features.saturating_sub(start_feature) {
            return;
        }

        // Does not depend on the feature being tried, so compute it once.
        let max_power = if interaction_only { 1 } else { target_degree };

        for feat in start_feature..n_features {
            // Highest power first so that e.g. `a^2` precedes `ab`.
            for power in (1..=max_power).rev() {
                current.push((feat, power));
                // Remaining degree is allocated to features with index > feat.
                let remaining = target_degree - power;
                if remaining == 0 {
                    combos.push(current.clone());
                } else {
                    recurse_exact(
                        feat + 1,
                        remaining,
                        n_features,
                        interaction_only,
                        current,
                        combos,
                    );
                }
                current.pop();
            }
        }
    }

    // Degree 0: the bias term.
    let mut combos: Vec<Vec<(usize, usize)>> = vec![Vec::new()];

    // Generate degree by degree to ensure correct ordering; the scratch prefix
    // is always empty again when a call returns, so one buffer is reused.
    let mut current = Vec::new();
    for d in 1..=max_degree {
        recurse_exact(
            0,
            d,
            n_features,
            interaction_only,
            &mut current,
            &mut combos,
        );
    }

    combos
}
132
133impl<F: Float> FitUnsupervised<F> for PolynomialFeatures {
134    type Fitted = FittedPolynomialFeatures<F>;
135
136    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
137        if x.is_empty() {
138            return Err(RustMlError::EmptyInput("input array is empty".into()));
139        }
140        if self.degree == 0 {
141            return Err(RustMlError::InvalidParameter(
142                "degree must be at least 1".into(),
143            ));
144        }
145
146        let n_features = x.ncols();
147        let combinations = enumerate_combinations(n_features, self.degree, self.interaction_only);
148
149        Ok(FittedPolynomialFeatures {
150            n_features,
151            degree: self.degree,
152            interaction_only: self.interaction_only,
153            combinations,
154            _marker: std::marker::PhantomData,
155        })
156    }
157}
158
159impl<F: Float> Transform<F> for FittedPolynomialFeatures<F> {
160    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
161        if x.ncols() != self.n_features {
162            return Err(RustMlError::ShapeMismatch(format!(
163                "expected {} features, got {}",
164                self.n_features,
165                x.ncols()
166            )));
167        }
168
169        let nrows = x.nrows();
170        let ncols_out = self.combinations.len();
171        let mut result = Array2::<F>::ones((nrows, ncols_out));
172
173        for (out_col, combo) in self.combinations.iter().enumerate() {
174            if combo.is_empty() {
175                // Bias term: already 1.0
176                continue;
177            }
178            for i in 0..nrows {
179                let mut val = F::one();
180                for &(feat, power) in combo {
181                    let base = x[[i, feat]];
182                    for _ in 0..power {
183                        val *= base;
184                    }
185                }
186                result[[i, out_col]] = val;
187            }
188        }
189
190        Ok(result)
191    }
192}
193
194impl<F: Float> FittedPolynomialFeatures<F> {
195    /// Return the number of input features.
196    pub fn n_input_features(&self) -> usize {
197        self.n_features
198    }
199
200    /// Return the number of output features.
201    pub fn n_output_features(&self) -> usize {
202        self.combinations.len()
203    }
204
205    /// Return the degree.
206    pub fn degree(&self) -> usize {
207        self.degree
208    }
209
210    /// Return whether only interactions are generated.
211    pub fn interaction_only(&self) -> bool {
212        self.interaction_only
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_degree2_two_features() {
        // [a, b] -> [1, a, b, a^2, ab, b^2]
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 6);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10); // 1
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10); // a
        assert_abs_diff_eq!(out[[0, 2]], 3.0, epsilon = 1e-10); // b
        assert_abs_diff_eq!(out[[0, 3]], 4.0, epsilon = 1e-10); // a^2
        assert_abs_diff_eq!(out[[0, 4]], 6.0, epsilon = 1e-10); // ab
        assert_abs_diff_eq!(out[[0, 5]], 9.0, epsilon = 1e-10); // b^2
    }

    #[test]
    fn test_interaction_only_degree2() {
        // [a, b] -> [1, a, b, ab]
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new().with_interaction_only(true);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 4);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10); // 1
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10); // a
        assert_abs_diff_eq!(out[[0, 2]], 3.0, epsilon = 1e-10); // b
        assert_abs_diff_eq!(out[[0, 3]], 6.0, epsilon = 1e-10); // ab
    }

    #[test]
    fn test_degree3_single_feature() {
        // [a] -> [1, a, a^2, a^3]
        let x = array![[3.0]];
        let poly = PolynomialFeatures::new().with_degree(3);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 4);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10); // 1
        assert_abs_diff_eq!(out[[0, 1]], 3.0, epsilon = 1e-10); // a
        assert_abs_diff_eq!(out[[0, 2]], 9.0, epsilon = 1e-10); // a^2
        assert_abs_diff_eq!(out[[0, 3]], 27.0, epsilon = 1e-10); // a^3
    }

    #[test]
    fn test_degree3_two_features() {
        // [a, b] -> [1, a, b, a^2, ab, b^2, a^3, a^2 b, a b^2, b^3]
        // Exercises the mixed-power terms (a^2 b, a b^2) of the enumeration.
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new().with_degree(3);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 10);
        assert_abs_diff_eq!(out[[0, 6]], 8.0, epsilon = 1e-10); // a^3
        assert_abs_diff_eq!(out[[0, 7]], 12.0, epsilon = 1e-10); // a^2 b
        assert_abs_diff_eq!(out[[0, 8]], 18.0, epsilon = 1e-10); // a b^2
        assert_abs_diff_eq!(out[[0, 9]], 27.0, epsilon = 1e-10); // b^3
    }

    #[test]
    fn test_degree1() {
        // degree=1: [a, b] -> [1, a, b]
        let x = array![[2.0, 3.0]];
        let poly = PolynomialFeatures::new().with_degree(1);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 3);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_degree0_error() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new().with_degree(0);
        let result = FitUnsupervised::<f64>::fit(&poly, &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_multiple_rows() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.nrows(), 2);
        assert_eq!(out.ncols(), 6);

        // Row 0: [1, 1, 2, 1, 2, 4]
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 2]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 3]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 4]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 5]], 4.0, epsilon = 1e-10);

        // Row 1: [1, 3, 4, 9, 12, 16]
        assert_abs_diff_eq!(out[[1, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 2]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 3]], 9.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 4]], 12.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 5]], 16.0, epsilon = 1e-10);
    }

    #[test]
    fn test_three_features_degree2() {
        // [a, b, c] -> [1, a, b, c, a^2, ab, ac, b^2, bc, c^2]
        let x = array![[1.0, 2.0, 3.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 10);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10); // 1
        assert_abs_diff_eq!(out[[0, 1]], 1.0, epsilon = 1e-10); // a
        assert_abs_diff_eq!(out[[0, 2]], 2.0, epsilon = 1e-10); // b
        assert_abs_diff_eq!(out[[0, 3]], 3.0, epsilon = 1e-10); // c
        assert_abs_diff_eq!(out[[0, 4]], 1.0, epsilon = 1e-10); // a^2
        assert_abs_diff_eq!(out[[0, 5]], 2.0, epsilon = 1e-10); // ab
        assert_abs_diff_eq!(out[[0, 6]], 3.0, epsilon = 1e-10); // ac
        assert_abs_diff_eq!(out[[0, 7]], 4.0, epsilon = 1e-10); // b^2
        assert_abs_diff_eq!(out[[0, 8]], 6.0, epsilon = 1e-10); // bc
        assert_abs_diff_eq!(out[[0, 9]], 9.0, epsilon = 1e-10); // c^2
    }

    #[test]
    fn test_three_features_interaction_only() {
        // [a, b, c] -> [1, a, b, c, ab, ac, bc]
        let x = array![[2.0, 3.0, 5.0]];
        let poly = PolynomialFeatures::new().with_interaction_only(true);
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 7);
        assert_abs_diff_eq!(out[[0, 0]], 1.0, epsilon = 1e-10); // 1
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10); // a
        assert_abs_diff_eq!(out[[0, 2]], 3.0, epsilon = 1e-10); // b
        assert_abs_diff_eq!(out[[0, 3]], 5.0, epsilon = 1e-10); // c
        assert_abs_diff_eq!(out[[0, 4]], 6.0, epsilon = 1e-10); // ab
        assert_abs_diff_eq!(out[[0, 5]], 10.0, epsilon = 1e-10); // ac
        assert_abs_diff_eq!(out[[0, 6]], 15.0, epsilon = 1e-10); // bc
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let poly = PolynomialFeatures::new();
        assert!(FitUnsupervised::<f64>::fit(&poly, &x).is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();

        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_wrong).is_err());
    }

    #[test]
    fn test_bias_column_all_ones() {
        let x = array![[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        // First column (bias) should be all 1.0
        for i in 0..3 {
            assert_abs_diff_eq!(out[[i, 0]], 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_n_output_features() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();

        assert_eq!(fitted.n_input_features(), 2);
        assert_eq!(fitted.n_output_features(), 6);
        assert_eq!(fitted.degree(), 2);
        assert!(!fitted.interaction_only());
    }

    #[test]
    fn test_f32() {
        let x = array![[2.0f32, 3.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f32>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 6);
        assert_abs_diff_eq!(out[[0, 3]], 4.0f32, epsilon = 1e-5); // a^2
        assert_abs_diff_eq!(out[[0, 4]], 6.0f32, epsilon = 1e-5); // ab
        assert_abs_diff_eq!(out[[0, 5]], 9.0f32, epsilon = 1e-5); // b^2
    }

    #[test]
    fn test_default() {
        let poly = PolynomialFeatures::default();
        assert_eq!(poly.degree, 2);
        assert!(!poly.interaction_only);
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        /// Build a `rows x cols` matrix of values in `[-2, 2]` derived from
        /// hashing `seed` together with each element's flat index.
        fn make_data(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};
            let mut values = Vec::with_capacity(rows * cols);
            for i in 0..(rows * cols) {
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                let bits = h.finish();
                let v = (bits as f64 / u64::MAX as f64) * 4.0 - 2.0;
                values.push(v);
            }
            Array2::from_shape_vec((rows, cols), values).unwrap()
        }

        proptest! {
            #[test]
            fn poly_bias_column_all_ones(
                rows in 1..20usize,
                cols in 1..5usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let poly = PolynomialFeatures::new();
                let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
                let out = fitted.transform(&x).unwrap();

                for i in 0..rows {
                    prop_assert!((out[[i, 0]] - 1.0).abs() < 1e-10,
                        "bias column should be 1.0, got {}", out[[i, 0]]);
                }
            }

            #[test]
            fn poly_original_features_preserved(
                rows in 1..20usize,
                cols in 1..5usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let poly = PolynomialFeatures::new();
                let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
                let out = fitted.transform(&x).unwrap();

                // Columns 1..=cols should be the original features
                for i in 0..rows {
                    for j in 0..cols {
                        prop_assert!((out[[i, 1 + j]] - x[[i, j]]).abs() < 1e-10,
                            "original feature not preserved at ({}, {})", i, j);
                    }
                }
            }
        }
    }
}