// scry_learn/preprocess/polynomial.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Polynomial and interaction feature expansion.

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;

/// Generate polynomial and interaction features.
///
/// Transforms an input feature set `[x1, x2, …]` into all polynomial
/// combinations up to a given `degree`. For example, with `degree=2` and
/// 2 features:
///
/// - `include_bias=true`:  `[1, x1, x2, x1², x1·x2, x2²]`
/// - `interaction_only=true`: `[1, x1, x2, x1·x2]` (no self-powers)
///
/// # Example
///
/// ```ignore
/// let mut poly = PolynomialFeatures::new().degree(2);
/// poly.fit_transform(&mut ds)?;
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct PolynomialFeatures {
    /// Maximum total degree of the generated monomials (default: 2).
    degree: usize,
    /// When true, skip self-powers like x²; only products of distinct features.
    interaction_only: bool,
    /// When true, include an all-ones bias column as the first output.
    include_bias: bool,
    /// Stored combo descriptors: each is a Vec of (original_col_idx, power) pairs.
    /// Populated by `fit`; one entry per output column, in output order.
    combos: Vec<Vec<(usize, usize)>>,
    /// Set by `fit`; `transform` returns `NotFitted` until this is true.
    fitted: bool,
    /// Schema version stamped at construction and checked in `transform`.
    /// NOTE(review): `serde(default)` makes a missing field deserialize to 0 —
    /// confirm `check_schema_version` accepts that for older payloads.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

37impl PolynomialFeatures {
38    /// Create a new `PolynomialFeatures` with default settings (degree=2).
39    pub fn new() -> Self {
40        Self {
41            degree: 2,
42            interaction_only: false,
43            include_bias: true,
44            combos: Vec::new(),
45            fitted: false,
46            _schema_version: crate::version::SCHEMA_VERSION,
47        }
48    }
49
50    /// Set the maximum polynomial degree.
51    pub fn degree(mut self, degree: usize) -> Self {
52        self.degree = degree;
53        self
54    }
55
56    /// If true, only interaction features are produced (no self-powers like x²).
57    pub fn interaction_only(mut self, v: bool) -> Self {
58        self.interaction_only = v;
59        self
60    }
61
62    /// If true, include a bias (all-ones) column.
63    pub fn include_bias(mut self, v: bool) -> Self {
64        self.include_bias = v;
65        self
66    }
67
68    /// Number of output features after transform.
69    pub fn n_output_features(&self) -> usize {
70        self.combos.len()
71    }
72}
73
74impl Default for PolynomialFeatures {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
/// Recursively enumerate every monomial of exactly `remaining_deg` total
/// degree over `n_features` variables, using only feature indices >= `start`.
///
/// Each completed monomial is appended to `out` as a list of
/// `(feature_index, power)` pairs with strictly increasing feature indices.
/// `current` holds the partial monomial for the recursion in progress and is
/// restored to its input state before the call returns.
///
/// Powers for a given feature are tried from highest to lowest, so e.g.
/// `x1^2` is emitted before `x1*x2` (graded lexicographic order).
fn gen_combos(
    n_features: usize,
    remaining_deg: usize,
    start: usize,
    interaction_only: bool,
    current: &mut Vec<(usize, usize)>,
    out: &mut Vec<Vec<(usize, usize)>>,
) {
    if remaining_deg == 0 {
        out.push(current.clone());
        return;
    }
    for col in start..n_features {
        // With interaction_only, each feature may appear at most once
        // (power 1); otherwise it may absorb any share of the remaining
        // degree. Either way max_power <= remaining_deg, so no per-power
        // bounds check is needed inside the loop.
        let max_power = if interaction_only { 1 } else { remaining_deg };
        for power in (1..=max_power).rev() {
            current.push((col, power));
            gen_combos(
                n_features,
                remaining_deg - power,
                col + 1,
                interaction_only,
                current,
                out,
            );
            current.pop();
        }
    }
}

117fn enumerate_combos(
118    n_features: usize,
119    degree: usize,
120    interaction_only: bool,
121    include_bias: bool,
122) -> Vec<Vec<(usize, usize)>> {
123    let mut result = Vec::new();
124
125    for deg in 0..=degree {
126        if deg == 0 {
127            if include_bias {
128                result.push(Vec::new()); // bias term
129            }
130        } else if deg == 1 {
131            for col in 0..n_features {
132                result.push(vec![(col, 1)]);
133            }
134        } else {
135            let mut current = Vec::new();
136            gen_combos(
137                n_features,
138                deg,
139                0,
140                interaction_only,
141                &mut current,
142                &mut result,
143            );
144        }
145    }
146
147    result
148}
149
150impl Transformer for PolynomialFeatures {
151    fn fit(&mut self, data: &Dataset) -> Result<()> {
152        data.validate_finite()?;
153        if data.n_samples() == 0 {
154            return Err(ScryLearnError::EmptyDataset);
155        }
156        self.combos = enumerate_combos(
157            data.n_features(),
158            self.degree,
159            self.interaction_only,
160            self.include_bias,
161        );
162        self.fitted = true;
163        Ok(())
164    }
165
166    fn transform(&self, data: &mut Dataset) -> Result<()> {
167        crate::version::check_schema_version(self._schema_version)?;
168        if !self.fitted {
169            return Err(ScryLearnError::NotFitted);
170        }
171        let n = data.n_samples();
172        let old_features = data.features.clone();
173
174        let mut new_features: Vec<Vec<f64>> = Vec::with_capacity(self.combos.len());
175        let mut new_names: Vec<String> = Vec::with_capacity(self.combos.len());
176
177        for combo in &self.combos {
178            let mut col = vec![1.0; n];
179            let mut name_parts = Vec::new();
180
181            for &(feat_idx, power) in combo {
182                #[allow(clippy::cast_possible_wrap)]
183                let exp = power as i32;
184                for (i, val) in col.iter_mut().enumerate() {
185                    *val *= old_features[feat_idx][i].powi(exp);
186                }
187                let fname = data
188                    .feature_names
189                    .get(feat_idx)
190                    .cloned()
191                    .unwrap_or_else(|| format!("x{feat_idx}"));
192                if power == 1 {
193                    name_parts.push(fname);
194                } else {
195                    name_parts.push(format!("{fname}^{power}"));
196                }
197            }
198
199            if name_parts.is_empty() {
200                new_names.push("1".into());
201            } else {
202                new_names.push(name_parts.join("*"));
203            }
204            new_features.push(col);
205        }
206
207        data.features = new_features;
208        data.feature_names = new_names;
209        data.sync_matrix();
210        Ok(())
211    }
212
213    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
214        Err(ScryLearnError::InvalidParameter(
215            "PolynomialFeatures is not invertible".into(),
216        ))
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect sample `i` across all feature columns (features are stored
    /// column-major: `features[col][sample]`).
    fn sample_row(ds: &Dataset, i: usize) -> Vec<f64> {
        ds.features.iter().map(|c| c[i]).collect()
    }

    #[test]
    fn test_poly_degree2_basic() {
        // Two features, two samples (column-major): rows are [1,2] and [3,4].
        let mut ds = Dataset::new(
            vec![vec![1.0, 3.0], vec![2.0, 4.0]],
            vec![0.0, 1.0],
            vec!["x1".into(), "x2".into()],
            "y",
        );
        let mut poly = PolynomialFeatures::new().degree(2).include_bias(true);
        poly.fit_transform(&mut ds).unwrap();

        // Output column order: [1, x1, x2, x1^2, x1*x2, x2^2].
        assert_eq!(ds.n_features(), 6);
        assert_eq!(sample_row(&ds, 0), vec![1.0, 1.0, 2.0, 1.0, 2.0, 4.0]);
        assert_eq!(sample_row(&ds, 1), vec![1.0, 3.0, 4.0, 9.0, 12.0, 16.0]);
    }

    #[test]
    fn test_poly_interaction_only() {
        let mut ds = Dataset::new(
            vec![vec![1.0, 3.0], vec![2.0, 4.0]],
            vec![0.0, 1.0],
            vec!["x1".into(), "x2".into()],
            "y",
        );
        let mut poly = PolynomialFeatures::new()
            .degree(2)
            .interaction_only(true)
            .include_bias(true);
        poly.fit_transform(&mut ds).unwrap();

        // Self-powers x1^2 and x2^2 are excluded: [1, x1, x2, x1*x2].
        assert_eq!(ds.n_features(), 4);
        assert_eq!(sample_row(&ds, 0), vec![1.0, 1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_poly_no_bias() {
        let mut ds = Dataset::new(
            vec![vec![2.0], vec![3.0]],
            vec![0.0],
            vec!["a".into(), "b".into()],
            "y",
        );
        let mut poly = PolynomialFeatures::new().degree(2).include_bias(false);
        poly.fit_transform(&mut ds).unwrap();

        // Without the bias term, the leading column is the first raw feature.
        let first_col = &ds.features[0];
        assert!((first_col[0] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_poly_degree3() {
        let mut ds = Dataset::new(vec![vec![2.0]], vec![0.0], vec!["x".into()], "y");
        let mut poly = PolynomialFeatures::new().degree(3).include_bias(true);
        poly.fit_transform(&mut ds).unwrap();

        // [1, x, x^2, x^3] evaluated at x = 2 gives [1, 2, 4, 8].
        assert_eq!(ds.n_features(), 4);
        assert_eq!(sample_row(&ds, 0), vec![1.0, 2.0, 4.0, 8.0]);
    }

    #[test]
    fn test_poly_not_fitted() {
        // transform before fit must error, not panic.
        let poly = PolynomialFeatures::new();
        let mut ds = Dataset::new(vec![vec![1.0]], vec![0.0], vec!["x".into()], "y");
        assert!(poly.transform(&mut ds).is_err());
    }
}
303}