Skip to main content

anofox_ml_preprocessing/
ordinal_encoder.rs

1use anofox_ml_core::{Result, RustMlError};
2use std::collections::HashMap;
3
4/// Encodes string categories as ordinal (integer) values per column.
5///
6/// Takes a list of columns (each column is a `Vec<String>`) and maps each
7/// unique category to a sorted integer index. This is useful as a
8/// preprocessing step before one-hot encoding or for ordinal features.
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10pub struct OrdinalEncoder;
11
12impl OrdinalEncoder {
13    /// Create a new `OrdinalEncoder`.
14    pub fn new() -> Self {
15        Self
16    }
17
18    /// Fit the encoder, learning the vocabulary for each column.
19    ///
20    /// `columns` is a list of columns where each column is a `Vec<String>`.
21    /// All columns must have the same length.
22    pub fn fit(&self, columns: &[Vec<String>]) -> Result<FittedOrdinalEncoder> {
23        if columns.is_empty() {
24            return Err(RustMlError::EmptyInput("columns slice is empty".into()));
25        }
26
27        let nrows = columns[0].len();
28        if nrows == 0 {
29            return Err(RustMlError::EmptyInput("columns contain no rows".into()));
30        }
31
32        for (j, col) in columns.iter().enumerate() {
33            if col.len() != nrows {
34                return Err(RustMlError::ShapeMismatch(format!(
35                    "column {} has {} rows, expected {}",
36                    j,
37                    col.len(),
38                    nrows
39                )));
40            }
41        }
42
43        let mut vocabularies = Vec::with_capacity(columns.len());
44        let mut mappings = Vec::with_capacity(columns.len());
45
46        for col in columns {
47            let mut vocab: Vec<String> = col.iter().cloned().collect();
48            vocab.sort();
49            vocab.dedup();
50
51            let mapping: HashMap<String, usize> = vocab
52                .iter()
53                .enumerate()
54                .map(|(i, s)| (s.clone(), i))
55                .collect();
56
57            vocabularies.push(vocab);
58            mappings.push(mapping);
59        }
60
61        Ok(FittedOrdinalEncoder {
62            vocabularies,
63            mappings,
64        })
65    }
66}
67
68impl Default for OrdinalEncoder {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74/// Fitted OrdinalEncoder — holds per-column vocabularies and mappings.
75#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
76pub struct FittedOrdinalEncoder {
77    vocabularies: Vec<Vec<String>>,
78    mappings: Vec<HashMap<String, usize>>,
79}
80
81impl FittedOrdinalEncoder {
82    /// Transform string columns into ordinal-encoded integer columns.
83    ///
84    /// Returns a `Vec<Vec<usize>>` where each inner vec is an encoded column.
85    pub fn transform(&self, columns: &[Vec<String>]) -> Result<Vec<Vec<usize>>> {
86        if columns.len() != self.vocabularies.len() {
87            return Err(RustMlError::ShapeMismatch(format!(
88                "expected {} columns, got {}",
89                self.vocabularies.len(),
90                columns.len()
91            )));
92        }
93
94        let mut result = Vec::with_capacity(columns.len());
95
96        for (j, col) in columns.iter().enumerate() {
97            let mapping = &self.mappings[j];
98            let mut encoded = Vec::with_capacity(col.len());
99            for val in col {
100                match mapping.get(val) {
101                    Some(&idx) => encoded.push(idx),
102                    None => {
103                        return Err(RustMlError::InvalidParameter(format!(
104                            "unknown category '{}' in column {}",
105                            val, j
106                        )));
107                    }
108                }
109            }
110            result.push(encoded);
111        }
112
113        Ok(result)
114    }
115
116    /// Inverse-transform ordinal-encoded columns back to string columns.
117    pub fn inverse_transform(&self, columns: &[Vec<usize>]) -> Result<Vec<Vec<String>>> {
118        if columns.len() != self.vocabularies.len() {
119            return Err(RustMlError::ShapeMismatch(format!(
120                "expected {} columns, got {}",
121                self.vocabularies.len(),
122                columns.len()
123            )));
124        }
125
126        let mut result = Vec::with_capacity(columns.len());
127
128        for (j, col) in columns.iter().enumerate() {
129            let vocab = &self.vocabularies[j];
130            let mut decoded = Vec::with_capacity(col.len());
131            for &idx in col {
132                if idx >= vocab.len() {
133                    return Err(RustMlError::InvalidParameter(format!(
134                        "encoded index {} is out of range for column {} (vocabulary size {})",
135                        idx,
136                        j,
137                        vocab.len()
138                    )));
139                }
140                decoded.push(vocab[idx].clone());
141            }
142            result.push(decoded);
143        }
144
145        Ok(result)
146    }
147
148    /// Return the vocabulary for a specific column.
149    pub fn vocabulary(&self, column: usize) -> Option<&[String]> {
150        self.vocabularies.get(column).map(|v| v.as_slice())
151    }
152
153    /// Return the number of columns.
154    pub fn n_columns(&self) -> usize {
155        self.vocabularies.len()
156    }
157
158    /// Return the number of categories per column.
159    pub fn n_categories(&self) -> Vec<usize> {
160        self.vocabularies.iter().map(|v| v.len()).collect()
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned string column from string literals.
    fn col(values: &[&str]) -> Vec<String> {
        values.iter().map(|v| v.to_string()).collect()
    }

    /// Fit a fresh encoder on `columns`, panicking if fitting fails.
    fn fit_ok(columns: &[Vec<String>]) -> FittedOrdinalEncoder {
        OrdinalEncoder::new()
            .fit(columns)
            .expect("fit should succeed")
    }

    #[test]
    fn test_fit_transform_single_column() {
        let input = vec![col(&["cat", "dog", "cat", "bird"])];
        let codes = fit_ok(&input).transform(&input).unwrap();

        // Sorted vocab: ["bird", "cat", "dog"] -> [0, 1, 2]
        let expected: Vec<Vec<usize>> = vec![vec![1, 2, 1, 0]];
        assert_eq!(codes, expected);
    }

    #[test]
    fn test_fit_transform_multiple_columns() {
        let input = vec![
            col(&["red", "blue", "green"]),
            col(&["small", "large", "small"]),
        ];
        let codes = fit_ok(&input).transform(&input).unwrap();

        // Col 0 vocab: ["blue", "green", "red"] -> [0, 1, 2]
        assert_eq!(codes[0], vec![2, 0, 1]);
        // Col 1 vocab: ["large", "small"] -> [0, 1]
        assert_eq!(codes[1], vec![1, 0, 1]);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let input = vec![col(&["apple", "banana", "cherry"]), col(&["x", "y", "z"])];
        let model = fit_ok(&input);

        let codes = model.transform(&input).unwrap();
        let decoded = model.inverse_transform(&codes).unwrap();

        assert_eq!(decoded, input);
    }

    #[test]
    fn test_unknown_category() {
        let model = fit_ok(&[col(&["cat", "dog"])]);

        let unseen = vec![col(&["fish"])];
        assert!(model.transform(&unseen).is_err());
    }

    #[test]
    fn test_out_of_range_index() {
        let model = fit_ok(&[col(&["a", "b"])]);

        let bad_codes: Vec<Vec<usize>> = vec![vec![99]];
        assert!(model.inverse_transform(&bad_codes).is_err());
    }

    #[test]
    fn test_empty_columns() {
        let no_columns: Vec<Vec<String>> = Vec::new();
        assert!(OrdinalEncoder::new().fit(&no_columns).is_err());
    }

    #[test]
    fn test_empty_rows() {
        let no_rows: Vec<Vec<String>> = vec![Vec::new()];
        assert!(OrdinalEncoder::new().fit(&no_rows).is_err());
    }

    #[test]
    fn test_column_length_mismatch() {
        let ragged = vec![col(&["a", "b"]), col(&["x"])];
        assert!(OrdinalEncoder::new().fit(&ragged).is_err());
    }

    #[test]
    fn test_shape_mismatch_transform() {
        let model = fit_ok(&[col(&["a", "b"])]);

        // Wrong number of columns
        let two_columns = vec![col(&["a"]), col(&["b"])];
        assert!(model.transform(&two_columns).is_err());
    }

    #[test]
    fn test_shape_mismatch_inverse() {
        let model = fit_ok(&[col(&["a", "b"])]);

        let two_columns: Vec<Vec<usize>> = vec![vec![0], vec![1]];
        assert!(model.inverse_transform(&two_columns).is_err());
    }

    #[test]
    fn test_vocabulary_accessor() {
        let model = fit_ok(&[col(&["z", "a", "m"]), col(&["big", "small", "big"])]);

        assert_eq!(
            model.vocabulary(0).unwrap(),
            col(&["a", "m", "z"]).as_slice()
        );
        assert_eq!(
            model.vocabulary(1).unwrap(),
            col(&["big", "small"]).as_slice()
        );
        assert!(model.vocabulary(5).is_none());
    }

    #[test]
    fn test_n_categories() {
        let model = fit_ok(&[col(&["a", "b", "c"]), col(&["x", "y", "x"])]);

        assert_eq!(model.n_columns(), 2);
        assert_eq!(model.n_categories(), vec![3, 2]);
    }

    #[test]
    fn test_default() {
        let input = vec![col(&["a"])];
        let model = OrdinalEncoder::default().fit(&input).unwrap();
        assert_eq!(model.n_columns(), 1);
    }

    #[test]
    fn test_sorted_vocabulary() {
        let model = fit_ok(&[col(&["zebra", "apple", "mango"])]);

        assert_eq!(
            model.vocabulary(0).unwrap(),
            col(&["apple", "mango", "zebra"]).as_slice()
        );
    }

    #[test]
    fn test_duplicate_values() {
        let input = vec![col(&["a", "a", "b", "b", "a"])];
        let codes = fit_ok(&input).transform(&input).unwrap();

        assert_eq!(codes[0], vec![0, 0, 1, 1, 0]);
    }
}