1use anofox_ml_core::{Result, RustMlError};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10pub struct OrdinalEncoder;
11
12impl OrdinalEncoder {
13 pub fn new() -> Self {
15 Self
16 }
17
18 pub fn fit(&self, columns: &[Vec<String>]) -> Result<FittedOrdinalEncoder> {
23 if columns.is_empty() {
24 return Err(RustMlError::EmptyInput("columns slice is empty".into()));
25 }
26
27 let nrows = columns[0].len();
28 if nrows == 0 {
29 return Err(RustMlError::EmptyInput("columns contain no rows".into()));
30 }
31
32 for (j, col) in columns.iter().enumerate() {
33 if col.len() != nrows {
34 return Err(RustMlError::ShapeMismatch(format!(
35 "column {} has {} rows, expected {}",
36 j,
37 col.len(),
38 nrows
39 )));
40 }
41 }
42
43 let mut vocabularies = Vec::with_capacity(columns.len());
44 let mut mappings = Vec::with_capacity(columns.len());
45
46 for col in columns {
47 let mut vocab: Vec<String> = col.iter().cloned().collect();
48 vocab.sort();
49 vocab.dedup();
50
51 let mapping: HashMap<String, usize> = vocab
52 .iter()
53 .enumerate()
54 .map(|(i, s)| (s.clone(), i))
55 .collect();
56
57 vocabularies.push(vocab);
58 mappings.push(mapping);
59 }
60
61 Ok(FittedOrdinalEncoder {
62 vocabularies,
63 mappings,
64 })
65 }
66}
67
68impl Default for OrdinalEncoder {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
76pub struct FittedOrdinalEncoder {
77 vocabularies: Vec<Vec<String>>,
78 mappings: Vec<HashMap<String, usize>>,
79}
80
81impl FittedOrdinalEncoder {
82 pub fn transform(&self, columns: &[Vec<String>]) -> Result<Vec<Vec<usize>>> {
86 if columns.len() != self.vocabularies.len() {
87 return Err(RustMlError::ShapeMismatch(format!(
88 "expected {} columns, got {}",
89 self.vocabularies.len(),
90 columns.len()
91 )));
92 }
93
94 let mut result = Vec::with_capacity(columns.len());
95
96 for (j, col) in columns.iter().enumerate() {
97 let mapping = &self.mappings[j];
98 let mut encoded = Vec::with_capacity(col.len());
99 for val in col {
100 match mapping.get(val) {
101 Some(&idx) => encoded.push(idx),
102 None => {
103 return Err(RustMlError::InvalidParameter(format!(
104 "unknown category '{}' in column {}",
105 val, j
106 )));
107 }
108 }
109 }
110 result.push(encoded);
111 }
112
113 Ok(result)
114 }
115
116 pub fn inverse_transform(&self, columns: &[Vec<usize>]) -> Result<Vec<Vec<String>>> {
118 if columns.len() != self.vocabularies.len() {
119 return Err(RustMlError::ShapeMismatch(format!(
120 "expected {} columns, got {}",
121 self.vocabularies.len(),
122 columns.len()
123 )));
124 }
125
126 let mut result = Vec::with_capacity(columns.len());
127
128 for (j, col) in columns.iter().enumerate() {
129 let vocab = &self.vocabularies[j];
130 let mut decoded = Vec::with_capacity(col.len());
131 for &idx in col {
132 if idx >= vocab.len() {
133 return Err(RustMlError::InvalidParameter(format!(
134 "encoded index {} is out of range for column {} (vocabulary size {})",
135 idx,
136 j,
137 vocab.len()
138 )));
139 }
140 decoded.push(vocab[idx].clone());
141 }
142 result.push(decoded);
143 }
144
145 Ok(result)
146 }
147
148 pub fn vocabulary(&self, column: usize) -> Option<&[String]> {
150 self.vocabularies.get(column).map(|v| v.as_slice())
151 }
152
153 pub fn n_columns(&self) -> usize {
155 self.vocabularies.len()
156 }
157
158 pub fn n_categories(&self) -> Vec<usize> {
160 self.vocabularies.iter().map(|v| v.len()).collect()
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building owned `String` fixtures.
    fn s(val: &str) -> String {
        val.to_string()
    }

    #[test]
    fn test_fit_transform_single_column() {
        let columns = vec![vec![s("cat"), s("dog"), s("cat"), s("bird")]];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();
        let encoded = fitted.transform(&columns).unwrap();

        // Sorted vocabulary: bird=0, cat=1, dog=2.
        assert_eq!(encoded, vec![vec![1, 2, 1, 0]]);
    }

    #[test]
    fn test_fit_transform_multiple_columns() {
        let columns = vec![
            vec![s("red"), s("blue"), s("green")],
            vec![s("small"), s("large"), s("small")],
        ];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();
        let encoded = fitted.transform(&columns).unwrap();

        // Column 0: blue=0, green=1, red=2.
        assert_eq!(encoded[0], vec![2, 0, 1]);
        // Column 1: large=0, small=1.
        assert_eq!(encoded[1], vec![1, 0, 1]);
    }

    #[test]
    fn test_transform_unseen_rows_with_known_categories() {
        let train = vec![vec![s("low"), s("high"), s("mid")]];
        let fitted = OrdinalEncoder::new().fit(&train).unwrap();

        // Vocabulary: high=0, low=1, mid=2. New rows may repeat or omit categories.
        let rows = vec![vec![s("mid"), s("mid"), s("low")]];
        assert_eq!(fitted.transform(&rows).unwrap(), vec![vec![2, 2, 1]]);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let columns = vec![
            vec![s("apple"), s("banana"), s("cherry")],
            vec![s("x"), s("y"), s("z")],
        ];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();
        let encoded = fitted.transform(&columns).unwrap();
        let recovered = fitted.inverse_transform(&encoded).unwrap();

        assert_eq!(recovered, columns);
    }

    #[test]
    fn test_unknown_category() {
        let columns = vec![vec![s("cat"), s("dog")]];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();

        let unknown = vec![vec![s("fish")]];
        assert!(matches!(
            fitted.transform(&unknown),
            Err(RustMlError::InvalidParameter(..))
        ));
    }

    #[test]
    fn test_unknown_category_in_later_column() {
        let columns = vec![vec![s("a"), s("b")], vec![s("x"), s("y")]];
        let fitted = OrdinalEncoder::new().fit(&columns).unwrap();

        // Column 0 is valid; column 1 contains a value never seen during fit.
        let input = vec![vec![s("a")], vec![s("q")]];
        assert!(matches!(
            fitted.transform(&input),
            Err(RustMlError::InvalidParameter(..))
        ));
    }

    #[test]
    fn test_out_of_range_index() {
        let columns = vec![vec![s("a"), s("b")]];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();

        let bad = vec![vec![99]];
        assert!(matches!(
            fitted.inverse_transform(&bad),
            Err(RustMlError::InvalidParameter(..))
        ));
    }

    #[test]
    fn test_empty_columns() {
        let columns: Vec<Vec<String>> = vec![];
        let encoder = OrdinalEncoder::new();
        assert!(matches!(
            encoder.fit(&columns),
            Err(RustMlError::EmptyInput(..))
        ));
    }

    #[test]
    fn test_empty_rows() {
        let columns: Vec<Vec<String>> = vec![vec![]];
        let encoder = OrdinalEncoder::new();
        assert!(matches!(
            encoder.fit(&columns),
            Err(RustMlError::EmptyInput(..))
        ));
    }

    #[test]
    fn test_column_length_mismatch() {
        let columns = vec![vec![s("a"), s("b")], vec![s("x")]];
        let encoder = OrdinalEncoder::new();
        assert!(matches!(
            encoder.fit(&columns),
            Err(RustMlError::ShapeMismatch(..))
        ));
    }

    #[test]
    fn test_shape_mismatch_transform() {
        let columns = vec![vec![s("a"), s("b")]];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();

        // Fitted on one column, asked to transform two.
        let wrong = vec![vec![s("a")], vec![s("b")]];
        assert!(matches!(
            fitted.transform(&wrong),
            Err(RustMlError::ShapeMismatch(..))
        ));
    }

    #[test]
    fn test_shape_mismatch_inverse() {
        let columns = vec![vec![s("a"), s("b")]];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();

        let wrong = vec![vec![0], vec![1]];
        assert!(matches!(
            fitted.inverse_transform(&wrong),
            Err(RustMlError::ShapeMismatch(..))
        ));
    }

    #[test]
    fn test_vocabulary_accessor() {
        let columns = vec![
            vec![s("z"), s("a"), s("m")],
            vec![s("big"), s("small"), s("big")],
        ];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();

        assert_eq!(fitted.vocabulary(0).unwrap(), &[s("a"), s("m"), s("z")]);
        assert_eq!(fitted.vocabulary(1).unwrap(), &[s("big"), s("small")]);
        assert!(fitted.vocabulary(5).is_none());
    }

    #[test]
    fn test_n_categories() {
        let columns = vec![vec![s("a"), s("b"), s("c")], vec![s("x"), s("y"), s("x")]];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();

        assert_eq!(fitted.n_columns(), 2);
        assert_eq!(fitted.n_categories(), vec![3, 2]);
    }

    #[test]
    fn test_default() {
        let encoder = OrdinalEncoder::default();
        let columns = vec![vec![s("a")]];
        let fitted = encoder.fit(&columns).unwrap();
        assert_eq!(fitted.n_columns(), 1);
    }

    #[test]
    fn test_sorted_vocabulary() {
        let columns = vec![vec![s("zebra"), s("apple"), s("mango")]];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();

        assert_eq!(
            fitted.vocabulary(0).unwrap(),
            &[s("apple"), s("mango"), s("zebra")]
        );
    }

    #[test]
    fn test_duplicate_values() {
        let columns = vec![vec![s("a"), s("a"), s("b"), s("b"), s("a")]];
        let encoder = OrdinalEncoder::new();
        let fitted = encoder.fit(&columns).unwrap();
        let encoded = fitted.transform(&columns).unwrap();

        assert_eq!(encoded[0], vec![0, 0, 1, 1, 0]);
    }
}