Skip to main content

ferrolearn_preprocess/
ordinal_encoder.rs

//! Ordinal encoder: map string categories to integer indices.
//!
//! Each column's categories are mapped to integers `0, 1, 2, ...` in **order
//! of first appearance** in the training data. Categories that were not seen
//! during `fit` cause `transform` to return an error.
6
7use ferrolearn_core::error::FerroError;
8use ferrolearn_core::traits::{Fit, FitTransform, Transform};
9use ndarray::Array2;
10use std::collections::HashMap;
11
// ---------------------------------------------------------------------------
// OrdinalEncoder (unfitted)
// ---------------------------------------------------------------------------

/// An unfitted ordinal encoder.
///
/// Calling [`Fit::fit`] on an `Array2<String>` learns, for each column, a
/// mapping from the unique string categories (in order of first appearance)
/// to consecutive integers `0, 1, 2, ...`, and returns a
/// [`FittedOrdinalEncoder`].
///
/// The encoder carries no hyperparameters (it is a unit struct), so all
/// instances are interchangeable; it is comparable with `==` and can be
/// constructed in `const` contexts via [`OrdinalEncoder::new`].
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::ordinal_encoder::OrdinalEncoder;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::Array2;
///
/// let enc = OrdinalEncoder::new();
/// let data = Array2::from_shape_vec(
///     (3, 2),
///     vec![
///         "cat".to_string(), "small".to_string(),
///         "dog".to_string(), "large".to_string(),
///         "cat".to_string(), "small".to_string(),
///     ],
/// ).unwrap();
/// let fitted = enc.fit(&data, &()).unwrap();
/// let encoded = fitted.transform(&data).unwrap();
/// assert_eq!(encoded[[0, 0]], 0); // "cat" is index 0 in col 0
/// assert_eq!(encoded[[1, 0]], 1); // "dog" is index 1 in col 0
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OrdinalEncoder;

impl OrdinalEncoder {
    /// Create a new `OrdinalEncoder`.
    ///
    /// Equivalent to [`OrdinalEncoder::default`]. Being a `const fn`, it can
    /// be used to initialise a `const` or `static` encoder.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
54
// ---------------------------------------------------------------------------
// FittedOrdinalEncoder
// ---------------------------------------------------------------------------

/// The result of fitting an [`OrdinalEncoder`]: one category table per column.
///
/// Produced by [`Fit::fit`]. For every column `j`, `categories[j]` lists the
/// distinct categories in the order they were assigned integers, and
/// `category_to_index[j]` is the reverse lookup consulted by `transform`.
#[derive(Clone, Debug)]
pub struct FittedOrdinalEncoder {
    /// `categories[j][i]` is the category encoded as `i` in column `j`.
    pub(crate) categories: Vec<Vec<String>>,
    /// Reverse lookup: `category_to_index[j][cat]` is the integer assigned to
    /// `cat` in column `j`.
    pub(crate) category_to_index: Vec<HashMap<String, usize>>,
}

impl FittedOrdinalEncoder {
    /// Per-column category lists, indexed by encoded value.
    ///
    /// `categories()[j][i]` is the category that maps to integer `i` in column `j`.
    #[must_use]
    pub fn categories(&self) -> &[Vec<String>] {
        self.categories.as_slice()
    }

    /// Number of input columns (features) the encoder was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.categories().len()
    }
}
85
86// ---------------------------------------------------------------------------
87// Trait implementations
88// ---------------------------------------------------------------------------
89
90impl Fit<Array2<String>, ()> for OrdinalEncoder {
91    type Fitted = FittedOrdinalEncoder;
92    type Error = FerroError;
93
94    /// Fit the encoder by building per-column category-to-index mappings.
95    ///
96    /// Categories are recorded in **order of first appearance** in each column.
97    ///
98    /// # Errors
99    ///
100    /// Returns [`FerroError::InsufficientSamples`] if the input has zero rows.
101    fn fit(&self, x: &Array2<String>, _y: &()) -> Result<FittedOrdinalEncoder, FerroError> {
102        let n_samples = x.nrows();
103        if n_samples == 0 {
104            return Err(FerroError::InsufficientSamples {
105                required: 1,
106                actual: 0,
107                context: "OrdinalEncoder::fit".into(),
108            });
109        }
110
111        let n_features = x.ncols();
112        let mut categories = Vec::with_capacity(n_features);
113        let mut category_to_index = Vec::with_capacity(n_features);
114
115        for j in 0..n_features {
116            let mut seen: Vec<String> = Vec::new();
117            let mut map: HashMap<String, usize> = HashMap::new();
118
119            for i in 0..n_samples {
120                let cat = x[[i, j]].clone();
121                if !map.contains_key(&cat) {
122                    let idx = seen.len();
123                    map.insert(cat.clone(), idx);
124                    seen.push(cat);
125                }
126            }
127
128            categories.push(seen);
129            category_to_index.push(map);
130        }
131
132        Ok(FittedOrdinalEncoder {
133            categories,
134            category_to_index,
135        })
136    }
137}
138
139impl Transform<Array2<String>> for FittedOrdinalEncoder {
140    type Output = Array2<usize>;
141    type Error = FerroError;
142
143    /// Transform string categories to integer indices.
144    ///
145    /// # Errors
146    ///
147    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
148    /// match the number of features seen during fitting.
149    ///
150    /// Returns [`FerroError::InvalidParameter`] if any category was not seen
151    /// during fitting.
152    fn transform(&self, x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
153        let n_features = self.categories.len();
154        if x.ncols() != n_features {
155            return Err(FerroError::ShapeMismatch {
156                expected: vec![x.nrows(), n_features],
157                actual: vec![x.nrows(), x.ncols()],
158                context: "FittedOrdinalEncoder::transform".into(),
159            });
160        }
161
162        let n_samples = x.nrows();
163        let mut out = Array2::zeros((n_samples, n_features));
164
165        for j in 0..n_features {
166            let map = &self.category_to_index[j];
167            for i in 0..n_samples {
168                let cat = &x[[i, j]];
169                match map.get(cat) {
170                    Some(&idx) => out[[i, j]] = idx,
171                    None => {
172                        return Err(FerroError::InvalidParameter {
173                            name: format!("x[{i},{j}]"),
174                            reason: format!("unknown category \"{cat}\" in column {j}"),
175                        });
176                    }
177                }
178            }
179        }
180
181        Ok(out)
182    }
183}
184
185/// Implement `Transform` on the unfitted encoder to satisfy the
186/// `FitTransform: Transform` supertrait bound.
187impl Transform<Array2<String>> for OrdinalEncoder {
188    type Output = Array2<usize>;
189    type Error = FerroError;
190
191    /// Always returns an error — the encoder must be fitted first.
192    fn transform(&self, _x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
193        Err(FerroError::InvalidParameter {
194            name: "OrdinalEncoder".into(),
195            reason: "encoder must be fitted before calling transform; use fit() first".into(),
196        })
197    }
198}
199
200impl FitTransform<Array2<String>> for OrdinalEncoder {
201    type FitError = FerroError;
202
203    /// Fit the encoder on `x` and return the encoded output in one step.
204    ///
205    /// # Errors
206    ///
207    /// Returns an error if fitting or transformation fails.
208    fn fit_transform(&self, x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
209        let fitted = self.fit(x, &())?;
210        fitted.transform(x)
211    }
212}
213
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build an `n x 2` string matrix from `(col0, col1)` row tuples.
    fn make_2col(rows: &[(&str, &str)]) -> Array2<String> {
        let flat: Vec<String> = rows
            .iter()
            .flat_map(|(a, b)| [a.to_string(), b.to_string()])
            .collect();
        Array2::from_shape_vec((rows.len(), 2), flat).unwrap()
    }

    #[test]
    fn test_ordinal_encoder_basic() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[
            ("cat", "small"),
            ("dog", "large"),
            ("cat", "medium"),
            ("bird", "small"),
        ]);
        let fitted = enc.fit(&x, &()).unwrap();

        // Categories should be in order of first appearance
        assert_eq!(fitted.categories()[0], vec!["cat", "dog", "bird"]);
        assert_eq!(fitted.categories()[1], vec!["small", "large", "medium"]);

        let encoded = fitted.transform(&x).unwrap();
        assert_eq!(encoded[[0, 0]], 0); // "cat" -> 0
        assert_eq!(encoded[[1, 0]], 1); // "dog" -> 1
        assert_eq!(encoded[[2, 0]], 0); // "cat" -> 0
        assert_eq!(encoded[[3, 0]], 2); // "bird" -> 2
        assert_eq!(encoded[[0, 1]], 0); // "small" -> 0
        assert_eq!(encoded[[1, 1]], 1); // "large" -> 1
        assert_eq!(encoded[[2, 1]], 2); // "medium" -> 2
        assert_eq!(encoded[[3, 1]], 0); // "small" -> 0
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x"), ("b", "y"), ("a", "z")]);
        let via_ft = enc.fit_transform(&x).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        let via_sep = fitted.transform(&x).unwrap();
        assert_eq!(via_ft, via_sep);
    }

    #[test]
    fn test_unknown_category_error() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("cat", "small"), ("dog", "large")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_test = make_2col(&[("fish", "small")]);
        assert!(fitted.transform(&x_test).is_err());
    }

    #[test]
    fn test_shape_mismatch_error() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("a", "x")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        // Single-column input when 2 cols expected
        let x_bad = Array2::from_shape_vec((1, 1), vec!["a".to_string()]).unwrap();
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_insufficient_samples_error() {
        let enc = OrdinalEncoder::new();
        let x: Array2<String> = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        assert!(enc.fit(&x, &()).is_err());
    }

    #[test]
    fn test_unfitted_transform_error() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x")]);
        assert!(enc.transform(&x).is_err());
    }

    #[test]
    fn test_single_column() {
        let enc = OrdinalEncoder::new();
        let flat = vec![
            "red".to_string(),
            "green".to_string(),
            "blue".to_string(),
            "red".to_string(),
        ];
        let x = Array2::from_shape_vec((4, 1), flat).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.categories()[0], vec!["red", "green", "blue"]);
        let encoded = fitted.transform(&x).unwrap();
        assert_eq!(encoded[[0, 0]], 0);
        assert_eq!(encoded[[1, 0]], 1);
        assert_eq!(encoded[[2, 0]], 2);
        assert_eq!(encoded[[3, 0]], 0);
    }

    #[test]
    fn test_n_features() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x")]);
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_features(), 2);
    }

    #[test]
    fn test_first_appearance_order() {
        // Categories must be in first-appearance order, NOT alphabetical
        let enc = OrdinalEncoder::new();
        let flat = vec!["zebra".to_string(), "ant".to_string(), "moose".to_string()];
        let x = Array2::from_shape_vec((3, 1), flat).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        // zebra first, then ant, then moose
        assert_eq!(fitted.categories()[0][0], "zebra");
        assert_eq!(fitted.categories()[0][1], "ant");
        assert_eq!(fitted.categories()[0][2], "moose");
    }

    #[test]
    fn test_transform_uses_training_mapping() {
        // Indices come from the training data, not from the row order of the
        // matrix passed to `transform`.
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("a", "x"), ("b", "y")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_test = make_2col(&[("b", "y"), ("a", "x"), ("b", "x")]);
        let encoded = fitted.transform(&x_test).unwrap();
        assert_eq!(encoded.row(0).to_vec(), vec![1, 1]);
        assert_eq!(encoded.row(1).to_vec(), vec![0, 0]);
        assert_eq!(encoded.row(2).to_vec(), vec![1, 0]);
    }

    #[test]
    fn test_transform_zero_rows() {
        let enc = OrdinalEncoder::new();
        let fitted = enc.fit(&make_2col(&[("a", "x")]), &()).unwrap();
        let empty: Array2<String> = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        let encoded = fitted.transform(&empty).unwrap();
        assert_eq!(encoded.dim(), (0, 2));
    }

    #[test]
    fn test_categories_are_case_sensitive() {
        let enc = OrdinalEncoder::new();
        let flat = vec!["Cat".to_string(), "cat".to_string(), "Cat".to_string()];
        let x = Array2::from_shape_vec((3, 1), flat).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.categories()[0], vec!["Cat", "cat"]);
        let encoded = fitted.transform(&x).unwrap();
        assert_eq!(encoded.column(0).to_vec(), vec![0, 1, 0]);
    }

    #[test]
    fn test_error_variants() {
        let enc = OrdinalEncoder::new();
        let fitted = enc.fit(&make_2col(&[("a", "x")]), &()).unwrap();

        // Unknown category in column 1.
        let unknown = fitted.transform(&make_2col(&[("a", "zzz")])).unwrap_err();
        assert!(matches!(unknown, FerroError::InvalidParameter { .. }));

        // Wrong number of columns.
        let narrow = Array2::from_shape_vec((1, 1), vec!["a".to_string()]).unwrap();
        let mismatch = fitted.transform(&narrow).unwrap_err();
        assert!(matches!(mismatch, FerroError::ShapeMismatch { .. }));

        // Zero rows at fit time, both directly and through fit_transform.
        let empty: Array2<String> = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        assert!(matches!(
            enc.fit(&empty, &()).unwrap_err(),
            FerroError::InsufficientSamples { .. }
        ));
        assert!(enc.fit_transform(&empty).is_err());

        // Transform on the unfitted encoder.
        assert!(matches!(
            enc.transform(&make_2col(&[("a", "x")])).unwrap_err(),
            FerroError::InvalidParameter { .. }
        ));
    }
}