Skip to main content

ferrolearn_preprocess/
one_hot_encoder.rs

1//! One-hot encoder for categorical integer features.
2//!
3//! Transforms a matrix of categorical integer indices into a dense binary
4//! array where each category is represented by a separate column.
5//!
6//! # Example
7//!
8//! ```text
9//! Input column with categories {0, 1, 2}:
10//!   [0, 1, 2, 1]  →  [[1,0,0],[0,1,0],[0,0,1],[0,1,0]]
11//! ```
12
13use ferrolearn_core::error::FerroError;
14use ferrolearn_core::traits::{Fit, FitTransform, Transform};
15use ndarray::Array2;
16use num_traits::Float;
17
18// ---------------------------------------------------------------------------
19// OneHotEncoder (unfitted)
20// ---------------------------------------------------------------------------
21
/// One-hot encoder in its unfitted state, for multi-column categorical data.
///
/// The expected input is an `Array2<usize>` in which every column holds
/// non-negative integer category indices. The encoder itself carries no
/// learned state; `F` only selects the floating-point element type of the
/// encoded output. Call [`Fit::fit`] to learn the per-column category counts
/// and obtain a [`FittedOneHotEncoder`].
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::OneHotEncoder;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let enc = OneHotEncoder::<f64>::new();
/// let x = array![[0usize, 1], [1, 0], [2, 1]];
/// let fitted = enc.fit(&x, &()).unwrap();
/// let encoded = fitted.transform(&x).unwrap();
/// ```
#[derive(Clone, Debug)]
pub struct OneHotEncoder<F> {
    /// Zero-sized marker tying the encoder to its output float type `F`.
    _marker: std::marker::PhantomData<F>,
}
44
45impl<F: Float + Send + Sync + 'static> OneHotEncoder<F> {
46    /// Create a new `OneHotEncoder`.
47    #[must_use]
48    pub fn new() -> Self {
49        Self {
50            _marker: std::marker::PhantomData,
51        }
52    }
53}
54
55impl<F: Float + Send + Sync + 'static> Default for OneHotEncoder<F> {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61// ---------------------------------------------------------------------------
62// FittedOneHotEncoder
63// ---------------------------------------------------------------------------
64
/// One-hot encoder after fitting: stores how many categories each input
/// column has.
///
/// Obtained by calling [`Fit::fit`] on a [`OneHotEncoder`].
#[derive(Clone, Debug)]
pub struct FittedOneHotEncoder<F> {
    /// Category count per input column, in input-column order.
    pub(crate) n_categories: Vec<usize>,
    /// Zero-sized marker tying the encoder to its output float type `F`.
    _marker: std::marker::PhantomData<F>,
}
74
75impl<F: Float + Send + Sync + 'static> FittedOneHotEncoder<F> {
76    /// Return the number of categories for each input feature column.
77    #[must_use]
78    pub fn n_categories(&self) -> &[usize] {
79        &self.n_categories
80    }
81
82    /// Return the total number of output columns.
83    #[must_use]
84    pub fn n_output_features(&self) -> usize {
85        self.n_categories.iter().sum()
86    }
87}
88
89// ---------------------------------------------------------------------------
90// Trait implementations
91// ---------------------------------------------------------------------------
92
93impl<F: Float + Send + Sync + 'static> Fit<Array2<usize>, ()> for OneHotEncoder<F> {
94    type Fitted = FittedOneHotEncoder<F>;
95    type Error = FerroError;
96
97    /// Fit the encoder by determining the number of unique categories per column.
98    ///
99    /// The number of categories for column `j` is `max(x[:, j]) + 1`.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
104    fn fit(&self, x: &Array2<usize>, _y: &()) -> Result<FittedOneHotEncoder<F>, FerroError> {
105        let n_samples = x.nrows();
106        if n_samples == 0 {
107            return Err(FerroError::InsufficientSamples {
108                required: 1,
109                actual: 0,
110                context: "OneHotEncoder::fit".into(),
111            });
112        }
113
114        let n_features = x.ncols();
115        let mut n_categories = Vec::with_capacity(n_features);
116
117        for j in 0..n_features {
118            let col = x.column(j);
119            let max_cat = col.iter().copied().max().unwrap_or(0);
120            n_categories.push(max_cat + 1);
121        }
122
123        Ok(FittedOneHotEncoder {
124            n_categories,
125            _marker: std::marker::PhantomData,
126        })
127    }
128}
129
130impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for FittedOneHotEncoder<F> {
131    type Output = Array2<F>;
132    type Error = FerroError;
133
134    /// Transform categorical data into a dense one-hot encoded matrix.
135    ///
136    /// # Errors
137    ///
138    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
139    /// match the number of features seen during fitting.
140    ///
141    /// Returns [`FerroError::InvalidParameter`] if any category value exceeds
142    /// the maximum seen during fitting.
143    fn transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
144        let n_features = self.n_categories.len();
145        if x.ncols() != n_features {
146            return Err(FerroError::ShapeMismatch {
147                expected: vec![x.nrows(), n_features],
148                actual: vec![x.nrows(), x.ncols()],
149                context: "FittedOneHotEncoder::transform".into(),
150            });
151        }
152
153        let n_out_cols = self.n_output_features();
154        let n_samples = x.nrows();
155        let mut out = Array2::zeros((n_samples, n_out_cols));
156
157        let mut col_offset = 0;
158        for j in 0..n_features {
159            let n_cats = self.n_categories[j];
160            for i in 0..n_samples {
161                let cat = x[[i, j]];
162                if cat >= n_cats {
163                    return Err(FerroError::InvalidParameter {
164                        name: format!("x[{i},{j}]"),
165                        reason: format!(
166                            "category {cat} exceeds max seen during fitting ({})",
167                            n_cats - 1
168                        ),
169                    });
170                }
171                out[[i, col_offset + cat]] = F::one();
172            }
173            col_offset += n_cats;
174        }
175
176        Ok(out)
177    }
178}
179
180/// Implement `Transform` on the unfitted encoder to satisfy the `FitTransform: Transform`
181/// supertrait bound. Calling `transform` on an unfitted encoder always returns an error.
182impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for OneHotEncoder<F> {
183    type Output = Array2<F>;
184    type Error = FerroError;
185
186    /// Always returns an error — the encoder must be fitted first.
187    ///
188    /// Use [`Fit::fit`] to produce a [`FittedOneHotEncoder`], then call
189    /// [`Transform::transform`] on that.
190    fn transform(&self, _x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
191        Err(FerroError::InvalidParameter {
192            name: "OneHotEncoder".into(),
193            reason: "encoder must be fitted before calling transform; use fit() first".into(),
194        })
195    }
196}
197
198impl<F: Float + Send + Sync + 'static> FitTransform<Array2<usize>> for OneHotEncoder<F> {
199    type FitError = FerroError;
200
201    /// Fit the encoder on `x` and return the one-hot encoded output in one step.
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if fitting or transformation fails.
206    fn fit_transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
207        let fitted = self.fit(x, &())?;
208        fitted.transform(x)
209    }
210}
211
212/// Convenience: encode a 1-D array of categorical integers.
213///
214/// This wraps the input in a single-column `Array2<usize>` and returns the
215/// encoded result with one-hot columns for that single feature.
216impl<F: Float + Send + Sync + 'static> FittedOneHotEncoder<F> {
217    /// Transform a 1-D slice of category indices.
218    ///
219    /// # Errors
220    ///
221    /// Returns an error if any category value is out-of-range.
222    pub fn transform_1d(&self, x: &[usize]) -> Result<Array2<F>, FerroError> {
223        if self.n_categories.len() != 1 {
224            return Err(FerroError::InvalidParameter {
225                name: "transform_1d".into(),
226                reason: "encoder was fitted on more than one column; use transform instead".into(),
227            });
228        }
229        let col = Array2::from_shape_vec((x.len(), 1), x.to_vec()).map_err(|e| {
230            FerroError::InvalidParameter {
231                name: "x".into(),
232                reason: e.to_string(),
233            }
234        })?;
235        self.transform(&col)
236    }
237}
238
239// ---------------------------------------------------------------------------
240// Tests
241// ---------------------------------------------------------------------------
242
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// A single column with categories {0, 1, 2} encodes to the 3x3 identity.
    #[test]
    fn test_one_hot_single_column() {
        let enc = OneHotEncoder::<f64>::new();
        let x = array![[0usize], [1], [2]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_categories(), &[3]);
        assert_eq!(fitted.n_output_features(), 3);

        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[3, 3]);
        // Row 0: category 0 → [1, 0, 0]
        assert_eq!(out.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
        // Row 1: category 1 → [0, 1, 0]
        assert_eq!(out.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
        // Row 2: category 2 → [0, 0, 1]
        assert_eq!(out.row(2).to_vec(), vec![0.0, 0.0, 1.0]);
    }

    /// Each input column gets its own contiguous group of output columns.
    #[test]
    fn test_one_hot_multi_column() {
        let enc = OneHotEncoder::<f64>::new();
        // Two columns: col0 has 3 categories, col1 has 2 categories
        let x = array![[0usize, 0], [1, 1], [2, 0]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_categories(), &[3, 2]);
        assert_eq!(fitted.n_output_features(), 5);

        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[3, 5]);
        // Row 0: (0, 0) → [1,0,0, 1,0]
        assert_eq!(out.row(0).to_vec(), vec![1.0, 0.0, 0.0, 1.0, 0.0]);
        // Row 1: (1, 1) → [0,1,0, 0,1]
        assert_eq!(out.row(1).to_vec(), vec![0.0, 1.0, 0.0, 0.0, 1.0]);
        // Row 2: (2, 0) → [0,0,1, 1,0]
        assert_eq!(out.row(2).to_vec(), vec![0.0, 0.0, 1.0, 1.0, 0.0]);
    }

    /// A category above the fitted maximum is rejected as an invalid parameter.
    #[test]
    fn test_out_of_range_category_error() {
        let enc = OneHotEncoder::<f64>::new();
        let x_train = array![[0usize], [1]];
        let fitted = enc.fit(&x_train, &()).unwrap();
        // Category 2 was not seen during fitting
        let x_bad = array![[2usize]];
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    /// `fit_transform` must match `fit` followed by `transform`.
    #[test]
    fn test_fit_transform_equivalence() {
        let enc = OneHotEncoder::<f64>::new();
        let x = array![[0usize, 1], [1, 0], [2, 1]];
        let via_fit_transform: Array2<f64> = enc.fit_transform(&x).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        let via_separate = fitted.transform(&x).unwrap();
        // Array equality compares shape as well as elements, so a shape
        // mismatch cannot slip through the way it could with a zipped loop.
        // The values are exactly 0.0 or 1.0, so exact comparison is safe.
        assert_eq!(via_fit_transform, via_separate);
    }

    /// Transforming data with a different column count is a shape mismatch.
    #[test]
    fn test_shape_mismatch_error() {
        let enc = OneHotEncoder::<f64>::new();
        let x_train = array![[0usize, 1], [1, 0]];
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_bad = array![[0usize]];
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    /// Fitting on zero rows reports insufficient samples.
    #[test]
    fn test_fit_empty_input_error() {
        let enc = OneHotEncoder::<f64>::new();
        let x = Array2::<usize>::zeros((0, 2));
        assert!(matches!(
            enc.fit(&x, &()),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    /// Calling `transform` on the unfitted encoder always fails.
    #[test]
    fn test_unfitted_transform_error() {
        let enc = OneHotEncoder::<f64>::new();
        let x = array![[0usize], [1]];
        assert!(matches!(
            enc.transform(&x),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    /// `transform_1d` encodes a slice against a single-column encoder.
    #[test]
    fn test_transform_1d() {
        let enc = OneHotEncoder::<f64>::new();
        let x = array![[0usize], [1], [2]];
        let fitted = enc.fit(&x, &()).unwrap();

        let out = fitted.transform_1d(&[2, 0]).unwrap();
        assert_eq!(out.shape(), &[2, 3]);
        assert_eq!(out.row(0).to_vec(), vec![0.0, 0.0, 1.0]);
        assert_eq!(out.row(1).to_vec(), vec![1.0, 0.0, 0.0]);
    }

    /// `transform_1d` refuses an encoder fitted on several columns.
    #[test]
    fn test_transform_1d_rejects_multi_column_encoder() {
        let enc = OneHotEncoder::<f64>::new();
        let x = array![[0usize, 1], [1, 0]];
        let fitted = enc.fit(&x, &()).unwrap();
        assert!(matches!(
            fitted.transform_1d(&[0, 1]),
            Err(FerroError::InvalidParameter { .. })
        ));
    }
}
321}