Skip to main content

ferrolearn_preprocess/
ordinal_encoder.rs

//! Ordinal encoder: map string categories to integer indices.
//!
//! Each column's categories are mapped to integers `0, 1, 2, ...` in
//! **lexicographic order** (byte-wise on the UTF-8 encoding, i.e. Unicode
//! code-point order — matching scikit-learn's `OrdinalEncoder`).
//! Unknown categories encountered during `transform` produce an error.
7
8use ferrolearn_core::error::FerroError;
9use ferrolearn_core::traits::{Fit, FitTransform, Transform};
10use ndarray::Array2;
11use std::collections::HashMap;
12
13// ---------------------------------------------------------------------------
14// OrdinalEncoder (unfitted)
15// ---------------------------------------------------------------------------
16
/// Unfitted encoder that turns string categories into ordinal integer codes.
///
/// `OrdinalEncoder` carries no configuration. Fitting it on an
/// `Array2<String>` (see [`Fit::fit`]) collects the distinct values of every
/// column, sorts them lexicographically and numbers them `0, 1, 2, ...`; the
/// result is a [`FittedOrdinalEncoder`] that performs the actual encoding.
///
/// # Examples
///
/// ```
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ferrolearn_preprocess::ordinal_encoder::OrdinalEncoder;
/// use ndarray::Array2;
///
/// let rows = vec![
///     "cat".to_string(), "small".to_string(),
///     "dog".to_string(), "large".to_string(),
///     "cat".to_string(), "small".to_string(),
/// ];
/// let data = Array2::from_shape_vec((3, 2), rows).unwrap();
///
/// let fitted = OrdinalEncoder::new().fit(&data, &()).unwrap();
/// let codes = fitted.transform(&data).unwrap();
///
/// // Column 0 sorts to ["cat", "dog"], column 1 to ["large", "small"].
/// assert_eq!(codes[[0, 0]], 0); // "cat"
/// assert_eq!(codes[[1, 0]], 1); // "dog"
/// assert_eq!(codes[[0, 1]], 1); // "small"
/// assert_eq!(codes[[1, 1]], 0); // "large"
/// ```
#[derive(Clone, Debug, Default)]
pub struct OrdinalEncoder;
47
48impl OrdinalEncoder {
49    /// Create a new `OrdinalEncoder`.
50    #[must_use]
51    pub fn new() -> Self {
52        Self
53    }
54}
55
56// ---------------------------------------------------------------------------
57// FittedOrdinalEncoder
58// ---------------------------------------------------------------------------
59
/// Encoder state learned from training data: one category table per column.
///
/// Produced by [`Fit::fit`] on an [`OrdinalEncoder`]; use it through
/// `Transform::transform` to map strings to their integer codes.
#[derive(Clone, Debug)]
pub struct FittedOrdinalEncoder {
    /// For each column, the known categories; a category's position in its
    /// list is the integer code it is encoded as.
    pub(crate) categories: Vec<Vec<String>>,
    /// For each column, the reverse lookup from category string to its code.
    pub(crate) category_to_index: Vec<HashMap<String, usize>>,
}
70
71impl FittedOrdinalEncoder {
72    /// Return the ordered category list for each column.
73    ///
74    /// `categories()[j][i]` is the category that maps to integer `i` in column `j`.
75    #[must_use]
76    pub fn categories(&self) -> &[Vec<String>] {
77        &self.categories
78    }
79
80    /// Return the number of input columns (features).
81    #[must_use]
82    pub fn n_features(&self) -> usize {
83        self.categories.len()
84    }
85}
86
87// ---------------------------------------------------------------------------
88// Trait implementations
89// ---------------------------------------------------------------------------
90
91impl Fit<Array2<String>, ()> for OrdinalEncoder {
92    type Fitted = FittedOrdinalEncoder;
93    type Error = FerroError;
94
95    /// Fit the encoder by building per-column category-to-index mappings.
96    ///
97    /// Categories are recorded in **lexicographic order** in each column,
98    /// matching scikit-learn's `OrdinalEncoder.categories_`.
99    ///
100    /// # Errors
101    ///
102    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
103    fn fit(&self, x: &Array2<String>, _y: &()) -> Result<FittedOrdinalEncoder, FerroError> {
104        let n_samples = x.nrows();
105        if n_samples == 0 {
106            return Err(FerroError::InsufficientSamples {
107                required: 1,
108                actual: 0,
109                context: "OrdinalEncoder::fit".into(),
110            });
111        }
112
113        let n_features = x.ncols();
114        let mut categories = Vec::with_capacity(n_features);
115        let mut category_to_index = Vec::with_capacity(n_features);
116
117        for j in 0..n_features {
118            // Collect unique categories then sort lexicographically so the
119            // assigned indices match sklearn's `OrdinalEncoder`, which
120            // documents `categories_ = sorted(unique(X[:, j]))`. (Older
121            // ferrolearn versions used first-seen order — #344.)
122            let mut unique: Vec<String> = Vec::new();
123            let mut seen_set: std::collections::HashSet<String> =
124                std::collections::HashSet::new();
125            for i in 0..n_samples {
126                let cat = &x[[i, j]];
127                if seen_set.insert(cat.clone()) {
128                    unique.push(cat.clone());
129                }
130            }
131            unique.sort();
132
133            let map: HashMap<String, usize> = unique
134                .iter()
135                .enumerate()
136                .map(|(idx, s)| (s.clone(), idx))
137                .collect();
138
139            categories.push(unique);
140            category_to_index.push(map);
141        }
142
143        Ok(FittedOrdinalEncoder {
144            categories,
145            category_to_index,
146        })
147    }
148}
149
150impl Transform<Array2<String>> for FittedOrdinalEncoder {
151    type Output = Array2<usize>;
152    type Error = FerroError;
153
154    /// Transform string categories to integer indices.
155    ///
156    /// # Errors
157    ///
158    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
159    /// match the number of features seen during fitting.
160    ///
161    /// Returns [`FerroError::InvalidParameter`] if any category was not seen
162    /// during fitting.
163    fn transform(&self, x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
164        let n_features = self.categories.len();
165        if x.ncols() != n_features {
166            return Err(FerroError::ShapeMismatch {
167                expected: vec![x.nrows(), n_features],
168                actual: vec![x.nrows(), x.ncols()],
169                context: "FittedOrdinalEncoder::transform".into(),
170            });
171        }
172
173        let n_samples = x.nrows();
174        let mut out = Array2::zeros((n_samples, n_features));
175
176        for j in 0..n_features {
177            let map = &self.category_to_index[j];
178            for i in 0..n_samples {
179                let cat = &x[[i, j]];
180                match map.get(cat) {
181                    Some(&idx) => out[[i, j]] = idx,
182                    None => {
183                        return Err(FerroError::InvalidParameter {
184                            name: format!("x[{i},{j}]"),
185                            reason: format!("unknown category \"{cat}\" in column {j}"),
186                        });
187                    }
188                }
189            }
190        }
191
192        Ok(out)
193    }
194}
195
196/// Implement `Transform` on the unfitted encoder to satisfy the
197/// `FitTransform: Transform` supertrait bound.
198impl Transform<Array2<String>> for OrdinalEncoder {
199    type Output = Array2<usize>;
200    type Error = FerroError;
201
202    /// Always returns an error — the encoder must be fitted first.
203    fn transform(&self, _x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
204        Err(FerroError::InvalidParameter {
205            name: "OrdinalEncoder".into(),
206            reason: "encoder must be fitted before calling transform; use fit() first".into(),
207        })
208    }
209}
210
211impl FitTransform<Array2<String>> for OrdinalEncoder {
212    type FitError = FerroError;
213
214    /// Fit the encoder on `x` and return the encoded output in one step.
215    ///
216    /// # Errors
217    ///
218    /// Returns an error if fitting or transformation fails.
219    fn fit_transform(&self, x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
220        let fitted = self.fit(x, &())?;
221        fitted.transform(x)
222    }
223}
224
225// ---------------------------------------------------------------------------
226// Tests
227// ---------------------------------------------------------------------------
228
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build an `(n, 2)` string matrix from `(column 0, column 1)` row pairs.
    fn make_2col(rows: &[(&str, &str)]) -> Array2<String> {
        let flat: Vec<String> = rows
            .iter()
            .flat_map(|(a, b)| [a.to_string(), b.to_string()])
            .collect();
        Array2::from_shape_vec((rows.len(), 2), flat).unwrap()
    }

    /// Build an `(n, 1)` string matrix from a list of values.
    fn make_1col(values: &[&str]) -> Array2<String> {
        let flat: Vec<String> = values.iter().map(|&v| String::from(v)).collect();
        Array2::from_shape_vec((values.len(), 1), flat).unwrap()
    }

    #[test]
    fn test_ordinal_encoder_basic() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[
            ("cat", "small"),
            ("dog", "large"),
            ("cat", "medium"),
            ("bird", "small"),
        ]);
        let fitted = enc.fit(&x, &()).unwrap();

        // Categories are sorted lexicographically (sklearn convention).
        assert_eq!(fitted.categories()[0], vec!["bird", "cat", "dog"]);
        assert_eq!(fitted.categories()[1], vec!["large", "medium", "small"]);

        let encoded = fitted.transform(&x).unwrap();
        assert_eq!(encoded[[0, 0]], 1); // "cat"  -> 1 (lex pos)
        assert_eq!(encoded[[1, 0]], 2); // "dog"  -> 2
        assert_eq!(encoded[[2, 0]], 1); // "cat"  -> 1
        assert_eq!(encoded[[3, 0]], 0); // "bird" -> 0
        assert_eq!(encoded[[0, 1]], 2); // "small"  -> 2
        assert_eq!(encoded[[1, 1]], 0); // "large"  -> 0
        assert_eq!(encoded[[2, 1]], 1); // "medium" -> 1
        assert_eq!(encoded[[3, 1]], 2); // "small"  -> 2
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x"), ("b", "y"), ("a", "z")]);
        let via_ft = enc.fit_transform(&x).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        let via_sep = fitted.transform(&x).unwrap();
        assert_eq!(via_ft, via_sep);
    }

    #[test]
    fn test_unknown_category_error() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("cat", "small"), ("dog", "large")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_test = make_2col(&[("fish", "small")]);
        match fitted.transform(&x_test).unwrap_err() {
            FerroError::InvalidParameter { name, reason } => {
                assert_eq!(name, "x[0,0]");
                assert_eq!(reason, "unknown category \"fish\" in column 0");
            }
            other => panic!("expected InvalidParameter, got {other:?}"),
        }
    }

    #[test]
    fn test_unknown_category_in_later_column_and_row() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("cat", "small"), ("dog", "large")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        // Column 0 is fully known; the unknown value sits at row 1, column 1.
        let x_test = make_2col(&[("cat", "small"), ("dog", "huge")]);
        match fitted.transform(&x_test).unwrap_err() {
            FerroError::InvalidParameter { name, reason } => {
                assert_eq!(name, "x[1,1]");
                assert_eq!(reason, "unknown category \"huge\" in column 1");
            }
            other => panic!("expected InvalidParameter, got {other:?}"),
        }
    }

    #[test]
    fn test_shape_mismatch_error() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("a", "x")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        // Single-column input when 2 cols expected
        let x_bad = make_1col(&["a"]);
        match fitted.transform(&x_bad).unwrap_err() {
            FerroError::ShapeMismatch {
                expected, actual, ..
            } => {
                assert_eq!(expected, vec![1, 2]);
                assert_eq!(actual, vec![1, 1]);
            }
            other => panic!("expected ShapeMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_insufficient_samples_error() {
        let enc = OrdinalEncoder::new();
        let x: Array2<String> = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        let err = enc.fit(&x, &()).unwrap_err();
        assert!(
            matches!(
                err,
                FerroError::InsufficientSamples {
                    required: 1,
                    actual: 0,
                    ..
                }
            ),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_unfitted_transform_error() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x")]);
        let err = enc.transform(&x).unwrap_err();
        assert!(
            matches!(err, FerroError::InvalidParameter { .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_transform_zero_rows() {
        let enc = OrdinalEncoder::new();
        let fitted = enc.fit(&make_2col(&[("a", "x")]), &()).unwrap();
        let x_empty: Array2<String> = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        let out = fitted.transform(&x_empty).unwrap();
        assert_eq!(out.dim(), (0, 2));
    }

    #[test]
    fn test_single_column() {
        let enc = OrdinalEncoder::new();
        let x = make_1col(&["red", "green", "blue", "red"]);
        let fitted = enc.fit(&x, &()).unwrap();
        // Lex order: blue (0), green (1), red (2)
        assert_eq!(fitted.categories()[0], vec!["blue", "green", "red"]);
        let encoded = fitted.transform(&x).unwrap();
        assert_eq!(encoded[[0, 0]], 2); // red
        assert_eq!(encoded[[1, 0]], 1); // green
        assert_eq!(encoded[[2, 0]], 0); // blue
        assert_eq!(encoded[[3, 0]], 2); // red
    }

    #[test]
    fn test_n_features() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x")]);
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_features(), 2);
    }

    #[test]
    fn test_lexicographic_order() {
        // Categories are sorted lexicographically to match sklearn (#344).
        let enc = OrdinalEncoder::new();
        let x = make_1col(&["zebra", "ant", "moose"]);
        let fitted = enc.fit(&x, &()).unwrap();
        // ant < moose < zebra
        assert_eq!(fitted.categories()[0][0], "ant");
        assert_eq!(fitted.categories()[0][1], "moose");
        assert_eq!(fitted.categories()[0][2], "zebra");
    }

    #[test]
    fn test_case_sensitive_code_point_order() {
        // Uppercase sorts before lowercase, exactly like Python's `sorted`.
        let enc = OrdinalEncoder::new();
        let x = make_1col(&["b", "B", "a", "A", "b"]);
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.categories()[0], vec!["A", "B", "a", "b"]);
        let encoded = fitted.transform(&x).unwrap();
        assert_eq!(encoded.column(0).to_vec(), vec![3, 1, 2, 0, 3]);
    }
}