1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::Array2;
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct PolynomialFeatures {
14 pub degree: usize,
16 pub interaction_only: bool,
18}
19
20impl PolynomialFeatures {
21 pub fn new() -> Self {
23 Self {
24 degree: 2,
25 interaction_only: false,
26 }
27 }
28
29 pub fn with_degree(mut self, degree: usize) -> Self {
31 self.degree = degree;
32 self
33 }
34
35 pub fn with_interaction_only(mut self, interaction_only: bool) -> Self {
37 self.interaction_only = interaction_only;
38 self
39 }
40}
41
42impl Default for PolynomialFeatures {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
50#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
51pub struct FittedPolynomialFeatures<F: Float> {
52 n_features: usize,
53 degree: usize,
54 interaction_only: bool,
55 combinations: Vec<Vec<(usize, usize)>>,
58 _marker: std::marker::PhantomData<F>,
59}
60
/// Lists every monomial over `n_features` inputs with total degree `0..=max_degree`.
///
/// Each monomial is a `Vec` of `(feature_index, power)` pairs with strictly increasing
/// feature indices and positive powers. The empty `Vec` (degree 0) always comes first.
/// Monomials are grouped by ascending total degree. Within one degree they appear in
/// descending lexicographic order of their exponent vectors; for two features and degree 2
/// that is `1, x0, x1, x0^2, x0*x1, x1^2`.
///
/// With `interaction_only`, every power is 1, so only products of distinct features appear.
fn enumerate_combinations(
    n_features: usize,
    max_degree: usize,
    interaction_only: bool,
) -> Vec<Vec<(usize, usize)>> {
    /// Appends to `out` every completion of `prefix` that uses only features
    /// `first..n_features` and adds exactly `remaining` to the total degree.
    fn extend_exact(
        first: usize,
        remaining: usize,
        n_features: usize,
        interaction_only: bool,
        prefix: &mut Vec<(usize, usize)>,
        out: &mut Vec<Vec<(usize, usize)>>,
    ) {
        if remaining == 0 {
            out.push(prefix.clone());
            return;
        }
        // The largest power a single feature may take is the same for every feature.
        let highest = if interaction_only { 1 } else { remaining };
        for feature in first..n_features {
            // High powers first, so e.g. x0^2 precedes x0*x1.
            for power in (1..=highest).rev() {
                prefix.push((feature, power));
                extend_exact(
                    feature + 1,
                    remaining - power,
                    n_features,
                    interaction_only,
                    prefix,
                    out,
                );
                prefix.pop();
            }
        }
    }

    // The degree-0 (bias) monomial leads the list.
    let mut out: Vec<Vec<(usize, usize)>> = vec![Vec::new()];
    // `extend_exact` restores `prefix` to empty before returning, so it can be reused.
    let mut prefix = Vec::new();
    for degree in 1..=max_degree {
        extend_exact(0, degree, n_features, interaction_only, &mut prefix, &mut out);
    }
    out
}
132
133impl<F: Float> FitUnsupervised<F> for PolynomialFeatures {
134 type Fitted = FittedPolynomialFeatures<F>;
135
136 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
137 if x.is_empty() {
138 return Err(RustMlError::EmptyInput("input array is empty".into()));
139 }
140 if self.degree == 0 {
141 return Err(RustMlError::InvalidParameter(
142 "degree must be at least 1".into(),
143 ));
144 }
145
146 let n_features = x.ncols();
147 let combinations = enumerate_combinations(n_features, self.degree, self.interaction_only);
148
149 Ok(FittedPolynomialFeatures {
150 n_features,
151 degree: self.degree,
152 interaction_only: self.interaction_only,
153 combinations,
154 _marker: std::marker::PhantomData,
155 })
156 }
157}
158
159impl<F: Float> Transform<F> for FittedPolynomialFeatures<F> {
160 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
161 if x.ncols() != self.n_features {
162 return Err(RustMlError::ShapeMismatch(format!(
163 "expected {} features, got {}",
164 self.n_features,
165 x.ncols()
166 )));
167 }
168
169 let nrows = x.nrows();
170 let ncols_out = self.combinations.len();
171 let mut result = Array2::<F>::ones((nrows, ncols_out));
172
173 for (out_col, combo) in self.combinations.iter().enumerate() {
174 if combo.is_empty() {
175 continue;
177 }
178 for i in 0..nrows {
179 let mut val = F::one();
180 for &(feat, power) in combo {
181 let base = x[[i, feat]];
182 for _ in 0..power {
183 val *= base;
184 }
185 }
186 result[[i, out_col]] = val;
187 }
188 }
189
190 Ok(result)
191 }
192}
193
194impl<F: Float> FittedPolynomialFeatures<F> {
195 pub fn n_input_features(&self) -> usize {
197 self.n_features
198 }
199
200 pub fn n_output_features(&self) -> usize {
202 self.combinations.len()
203 }
204
205 pub fn degree(&self) -> usize {
207 self.degree
208 }
209
210 pub fn interaction_only(&self) -> bool {
212 self.interaction_only
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fits `poly` on `x` and transforms the same data.
    fn fit_transform(poly: &PolynomialFeatures, x: &Array2<f64>) -> Array2<f64> {
        let fitted = FitUnsupervised::<f64>::fit(poly, x).unwrap();
        fitted.transform(x).unwrap()
    }

    /// Asserts that `out` has `expected.len()` columns and that row `row` matches
    /// `expected` element-wise within 1e-10.
    fn assert_row(out: &Array2<f64>, row: usize, expected: &[f64]) {
        assert_eq!(out.ncols(), expected.len(), "unexpected number of output columns");
        for (col, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(out[[row, col]], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_degree2_two_features() {
        // [1, a, b, a^2, ab, b^2]
        let out = fit_transform(&PolynomialFeatures::new(), &array![[2.0, 3.0]]);
        assert_row(&out, 0, &[1.0, 2.0, 3.0, 4.0, 6.0, 9.0]);
    }

    #[test]
    fn test_interaction_only_degree2() {
        // [1, a, b, ab]
        let poly = PolynomialFeatures::new().with_interaction_only(true);
        let out = fit_transform(&poly, &array![[2.0, 3.0]]);
        assert_row(&out, 0, &[1.0, 2.0, 3.0, 6.0]);
    }

    #[test]
    fn test_degree3_single_feature() {
        // [1, a, a^2, a^3]
        let poly = PolynomialFeatures::new().with_degree(3);
        let out = fit_transform(&poly, &array![[3.0]]);
        assert_row(&out, 0, &[1.0, 3.0, 9.0, 27.0]);
    }

    #[test]
    fn test_degree3_two_features() {
        // [1, a, b, a^2, ab, b^2, a^3, a^2 b, a b^2, b^3]
        let poly = PolynomialFeatures::new().with_degree(3);
        let out = fit_transform(&poly, &array![[2.0, 3.0]]);
        assert_row(
            &out,
            0,
            &[1.0, 2.0, 3.0, 4.0, 6.0, 9.0, 8.0, 12.0, 18.0, 27.0],
        );
    }

    #[test]
    fn test_degree1() {
        let poly = PolynomialFeatures::new().with_degree(1);
        let out = fit_transform(&poly, &array![[2.0, 3.0]]);
        assert_row(&out, 0, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_degree0_error() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new().with_degree(0);
        assert!(FitUnsupervised::<f64>::fit(&poly, &x).is_err());
    }

    #[test]
    fn test_multiple_rows() {
        let out = fit_transform(&PolynomialFeatures::new(), &array![[1.0, 2.0], [3.0, 4.0]]);
        assert_eq!(out.nrows(), 2);
        assert_row(&out, 0, &[1.0, 1.0, 2.0, 1.0, 2.0, 4.0]);
        assert_row(&out, 1, &[1.0, 3.0, 4.0, 9.0, 12.0, 16.0]);
    }

    #[test]
    fn test_three_features_degree2() {
        // [1, a, b, c, a^2, ab, ac, b^2, bc, c^2]
        let out = fit_transform(&PolynomialFeatures::new(), &array![[1.0, 2.0, 3.0]]);
        assert_row(
            &out,
            0,
            &[1.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0],
        );
    }

    #[test]
    fn test_three_features_interaction_only() {
        // [1, a, b, c, ab, ac, bc]
        let poly = PolynomialFeatures::new().with_interaction_only(true);
        let out = fit_transform(&poly, &array![[2.0, 3.0, 5.0]]);
        assert_row(&out, 0, &[1.0, 2.0, 3.0, 5.0, 6.0, 10.0, 15.0]);
    }

    #[test]
    fn test_interaction_only_degree_exceeds_features() {
        // With two features there is no degree-3 product of distinct features.
        let poly = PolynomialFeatures::new()
            .with_degree(3)
            .with_interaction_only(true);
        let out = fit_transform(&poly, &array![[2.0, 3.0]]);
        assert_row(&out, 0, &[1.0, 2.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transform_other_rows() {
        // The row count seen at fit time must not constrain transform.
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &array![[1.0, 2.0]]).unwrap();
        let out = fitted
            .transform(&array![[0.0, 0.0], [1.0, 1.0], [2.0, -1.0]])
            .unwrap();
        assert_eq!(out.nrows(), 3);
        assert_row(&out, 0, &[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        assert_row(&out, 1, &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
        assert_row(&out, 2, &[1.0, 2.0, -1.0, 4.0, -2.0, 1.0]);
    }

    #[test]
    fn test_empty_input() {
        let x: Array2<f64> = Array2::zeros((0, 0));
        let poly = PolynomialFeatures::new();
        assert!(FitUnsupervised::<f64>::fit(&poly, &x).is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();

        let x_wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_wrong).is_err());
    }

    #[test]
    fn test_bias_column_all_ones() {
        let x = array![[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]];
        let out = fit_transform(&PolynomialFeatures::new(), &x);

        for i in 0..3 {
            assert_abs_diff_eq!(out[[i, 0]], 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_n_output_features() {
        let x = array![[1.0, 2.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();

        assert_eq!(fitted.n_input_features(), 2);
        assert_eq!(fitted.n_output_features(), 6);
        assert_eq!(fitted.degree(), 2);
        assert!(!fitted.interaction_only());
    }

    #[test]
    fn test_f32() {
        let x = array![[2.0f32, 3.0]];
        let poly = PolynomialFeatures::new();
        let fitted = FitUnsupervised::<f32>::fit(&poly, &x).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.ncols(), 6);
        assert_abs_diff_eq!(out[[0, 3]], 4.0f32, epsilon = 1e-5);
        assert_abs_diff_eq!(out[[0, 4]], 6.0f32, epsilon = 1e-5);
        assert_abs_diff_eq!(out[[0, 5]], 9.0f32, epsilon = 1e-5);
    }

    #[test]
    fn test_default() {
        let poly = PolynomialFeatures::default();
        assert_eq!(poly.degree, 2);
        assert!(!poly.interaction_only);
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        /// Deterministic pseudo-random data in [-2, 2], derived from `seed`.
        fn make_data(rows: usize, cols: usize, seed: u64) -> Array2<f64> {
            use std::collections::hash_map::DefaultHasher;
            use std::hash::{Hash, Hasher};
            let mut values = Vec::with_capacity(rows * cols);
            for i in 0..(rows * cols) {
                let mut h = DefaultHasher::new();
                seed.hash(&mut h);
                (i as u64).hash(&mut h);
                let bits = h.finish();
                let v = (bits as f64 / u64::MAX as f64) * 4.0 - 2.0;
                values.push(v);
            }
            Array2::from_shape_vec((rows, cols), values).unwrap()
        }

        proptest! {
            #[test]
            fn poly_bias_column_all_ones(
                rows in 1..20usize,
                cols in 1..5usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let poly = PolynomialFeatures::new();
                let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
                let out = fitted.transform(&x).unwrap();

                for i in 0..rows {
                    prop_assert!((out[[i, 0]] - 1.0).abs() < 1e-10,
                        "bias column should be 1.0, got {}", out[[i, 0]]);
                }
            }

            #[test]
            fn poly_original_features_preserved(
                rows in 1..20usize,
                cols in 1..5usize,
                seed in 0u64..10000,
            ) {
                let x = make_data(rows, cols, seed);
                let poly = PolynomialFeatures::new();
                let fitted = FitUnsupervised::<f64>::fit(&poly, &x).unwrap();
                let out = fitted.transform(&x).unwrap();

                for i in 0..rows {
                    // Columns 1..=cols hold the degree-1 terms, i.e. the inputs themselves.
                    for j in 0..cols {
                        prop_assert!((out[[i, 1 + j]] - x[[i, j]]).abs() < 1e-10,
                            "original feature not preserved at ({}, {})", i, j);
                    }
                }
            }
        }
    }
}