// anofox_ml_preprocessing/one_hot_encoder.rs

1use anofox_ml_core::{Result, RustMlError};
2use ndarray::Array2;
3
4/// One-hot encoder for integer-encoded categorical features.
5///
6/// Transforms integer-encoded columns into binary indicator columns.
7/// For example, a column with values [0, 1, 2] becomes three binary columns:
8/// ```text
9/// [0] -> [1, 0, 0]
10/// [1] -> [0, 1, 0]
11/// [2] -> [0, 0, 1]
12/// ```
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct OneHotEncoder;
15
16impl OneHotEncoder {
17    /// Create a new `OneHotEncoder`.
18    pub fn new() -> Self {
19        Self
20    }
21
22    /// Fit the encoder on integer-encoded data.
23    ///
24    /// Learns the number of unique categories per column.
25    pub fn fit(&self, x: &Array2<usize>) -> Result<FittedOneHotEncoder> {
26        if x.is_empty() {
27            return Err(RustMlError::EmptyInput("input array is empty".into()));
28        }
29
30        let ncols = x.ncols();
31        let mut categories = Vec::with_capacity(ncols);
32
33        for j in 0..ncols {
34            let col = x.column(j);
35            let max_val = col.iter().copied().max().unwrap_or(0);
36            categories.push(max_val + 1);
37        }
38
39        Ok(FittedOneHotEncoder { categories })
40    }
41}
42
43impl Default for OneHotEncoder {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49/// Fitted OneHotEncoder — holds the number of unique categories per column.
50#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
51pub struct FittedOneHotEncoder {
52    categories: Vec<usize>,
53}
54
55impl FittedOneHotEncoder {
56    /// Transform integer-encoded data into one-hot encoded data.
57    ///
58    /// Each input column of `k` categories is expanded into `k` binary columns.
59    pub fn transform(&self, x: &Array2<usize>) -> Result<Array2<f64>> {
60        if x.ncols() != self.categories.len() {
61            return Err(RustMlError::ShapeMismatch(format!(
62                "expected {} columns, got {}",
63                self.categories.len(),
64                x.ncols()
65            )));
66        }
67
68        let total_out_cols: usize = self.categories.iter().sum();
69        let nrows = x.nrows();
70        let mut result = Array2::<f64>::zeros((nrows, total_out_cols));
71
72        for i in 0..nrows {
73            let mut col_offset = 0;
74            for j in 0..x.ncols() {
75                let val = x[[i, j]];
76                if val >= self.categories[j] {
77                    return Err(RustMlError::InvalidParameter(format!(
78                        "value {} in column {} exceeds number of categories {}",
79                        val, j, self.categories[j]
80                    )));
81                }
82                result[[i, col_offset + val]] = 1.0;
83                col_offset += self.categories[j];
84            }
85        }
86
87        Ok(result)
88    }
89
90    /// Return the number of categories per original column.
91    pub fn categories(&self) -> &[usize] {
92        &self.categories
93    }
94
95    /// Return the total number of output columns after one-hot encoding.
96    pub fn n_output_features(&self) -> usize {
97        self.categories.iter().sum()
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fits a fresh encoder on `x` and transforms that same data.
    fn fit_and_encode(x: &Array2<usize>) -> (FittedOneHotEncoder, Array2<f64>) {
        let fitted = OneHotEncoder::new().fit(x).unwrap();
        let encoded = fitted.transform(x).unwrap();
        (fitted, encoded)
    }

    #[test]
    fn test_single_column() {
        let (_, encoded) = fit_and_encode(&array![[0usize], [1], [2]]);

        assert_eq!(encoded.shape(), &[3, 3]);
        // Row r is the r-th row of the 3x3 identity matrix.
        let expected = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        for (got, want) in encoded.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want);
        }
    }

    #[test]
    fn test_multiple_columns() {
        // Column 0 has 2 categories (0, 1), column 1 has 3 categories (0, 1, 2)
        let (fitted, encoded) = fit_and_encode(&array![[0usize, 2], [1, 0], [0, 1]]);

        assert_eq!(encoded.shape(), &[3, 5]); // 2 + 3 = 5 output columns
        assert_eq!(fitted.n_output_features(), 5);

        let expected_rows = [
            // Row 0: col0=0 -> [1,0], col1=2 -> [0,0,1] => [1,0,0,0,1]
            (0, [1.0, 0.0, 0.0, 0.0, 1.0]),
            // Row 1: col0=1 -> [0,1], col1=0 -> [1,0,0] => [0,1,1,0,0]
            (1, [0.0, 1.0, 1.0, 0.0, 0.0]),
        ];
        for (i, want) in expected_rows.iter() {
            for (k, &w) in want.iter().enumerate() {
                assert_abs_diff_eq!(encoded[[*i, k]], w);
            }
        }
    }

    #[test]
    fn test_binary_column() {
        let (fitted, encoded) = fit_and_encode(&array![[0usize], [1], [1], [0]]);

        assert_eq!(encoded.shape(), &[4, 2]);
        assert_eq!(fitted.categories(), &[2]);
    }

    #[test]
    fn test_empty_input() {
        let empty: Array2<usize> = Array2::zeros((0, 0));
        assert!(OneHotEncoder::new().fit(&empty).is_err());
    }

    #[test]
    fn test_shape_mismatch() {
        let fitted = OneHotEncoder::new()
            .fit(&array![[0usize, 1], [1, 0]])
            .unwrap();

        // Three columns where the encoder was fitted on two.
        assert!(fitted.transform(&array![[0usize, 1, 2]]).is_err());
    }

    #[test]
    fn test_unknown_category_in_transform() {
        let fitted = OneHotEncoder::new().fit(&array![[0usize], [1]]).unwrap();

        // Value 5 was never seen during fit (max was 1, so categories = 2)
        assert!(fitted.transform(&array![[5usize]]).is_err());
    }

    #[test]
    fn test_all_zeros() {
        let (_, encoded) = fit_and_encode(&array![[0usize, 0], [0, 0], [0, 0]]);

        // 1 category per column -> 2 output columns
        assert_eq!(encoded.shape(), &[3, 2]);
        // Every row is [1, 1], so every entry is 1.
        for &value in encoded.iter() {
            assert_abs_diff_eq!(value, 1.0);
        }
    }

    #[test]
    fn test_row_sums() {
        // Each one-hot block should have exactly one 1 per row per original column
        let (_, encoded) = fit_and_encode(&array![[0usize, 2, 1], [2, 0, 0], [1, 1, 2]]);

        // Total output columns = 3 + 3 + 3 = 9
        assert_eq!(encoded.shape(), &[3, 9]);

        // Each row should sum to number of original columns (3)
        for row in encoded.outer_iter() {
            assert_abs_diff_eq!(row.sum(), 3.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_default() {
        let fitted = OneHotEncoder::default()
            .fit(&array![[0usize], [1]])
            .unwrap();
        assert_eq!(fitted.categories(), &[2]);
    }
}