1use ferrolearn_core::error::FerroError;
9use ferrolearn_core::traits::{Fit, FitTransform, Transform};
10use ndarray::Array2;
11use std::collections::HashMap;
12
/// Unfitted encoder that turns categorical string features into integer codes.
///
/// Fitting it on a string matrix learns the distinct categories of every
/// column and yields a [`FittedOrdinalEncoder`], which performs the encoding.
#[derive(Debug, Clone, Default)]
pub struct OrdinalEncoder;

impl OrdinalEncoder {
    /// Creates a fresh, unfitted encoder (same value as `OrdinalEncoder::default()`).
    #[must_use]
    pub fn new() -> Self {
        OrdinalEncoder
    }
}
55
/// An ordinal encoder whose per-column categories have already been learned.
///
/// Holds, for every fitted column, the list of distinct categories and a
/// lookup table from category string to integer code.
#[derive(Debug, Clone)]
pub struct FittedOrdinalEncoder {
    /// Distinct categories of each column; one inner list per column.
    pub(crate) categories: Vec<Vec<String>>,
    /// Per-column map from a category string to its integer code.
    pub(crate) category_to_index: Vec<HashMap<String, usize>>,
}

impl FittedOrdinalEncoder {
    /// Returns the learned categories, one list per fitted column.
    #[must_use]
    pub fn categories(&self) -> &[Vec<String>] {
        self.categories.as_slice()
    }

    /// Returns how many columns the encoder was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.categories().len()
    }
}
86
87impl Fit<Array2<String>, ()> for OrdinalEncoder {
92 type Fitted = FittedOrdinalEncoder;
93 type Error = FerroError;
94
95 fn fit(&self, x: &Array2<String>, _y: &()) -> Result<FittedOrdinalEncoder, FerroError> {
104 let n_samples = x.nrows();
105 if n_samples == 0 {
106 return Err(FerroError::InsufficientSamples {
107 required: 1,
108 actual: 0,
109 context: "OrdinalEncoder::fit".into(),
110 });
111 }
112
113 let n_features = x.ncols();
114 let mut categories = Vec::with_capacity(n_features);
115 let mut category_to_index = Vec::with_capacity(n_features);
116
117 for j in 0..n_features {
118 let mut unique: Vec<String> = Vec::new();
123 let mut seen_set: std::collections::HashSet<String> =
124 std::collections::HashSet::new();
125 for i in 0..n_samples {
126 let cat = &x[[i, j]];
127 if seen_set.insert(cat.clone()) {
128 unique.push(cat.clone());
129 }
130 }
131 unique.sort();
132
133 let map: HashMap<String, usize> = unique
134 .iter()
135 .enumerate()
136 .map(|(idx, s)| (s.clone(), idx))
137 .collect();
138
139 categories.push(unique);
140 category_to_index.push(map);
141 }
142
143 Ok(FittedOrdinalEncoder {
144 categories,
145 category_to_index,
146 })
147 }
148}
149
150impl Transform<Array2<String>> for FittedOrdinalEncoder {
151 type Output = Array2<usize>;
152 type Error = FerroError;
153
154 fn transform(&self, x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
164 let n_features = self.categories.len();
165 if x.ncols() != n_features {
166 return Err(FerroError::ShapeMismatch {
167 expected: vec![x.nrows(), n_features],
168 actual: vec![x.nrows(), x.ncols()],
169 context: "FittedOrdinalEncoder::transform".into(),
170 });
171 }
172
173 let n_samples = x.nrows();
174 let mut out = Array2::zeros((n_samples, n_features));
175
176 for j in 0..n_features {
177 let map = &self.category_to_index[j];
178 for i in 0..n_samples {
179 let cat = &x[[i, j]];
180 match map.get(cat) {
181 Some(&idx) => out[[i, j]] = idx,
182 None => {
183 return Err(FerroError::InvalidParameter {
184 name: format!("x[{i},{j}]"),
185 reason: format!("unknown category \"{cat}\" in column {j}"),
186 });
187 }
188 }
189 }
190 }
191
192 Ok(out)
193 }
194}
195
196impl Transform<Array2<String>> for OrdinalEncoder {
199 type Output = Array2<usize>;
200 type Error = FerroError;
201
202 fn transform(&self, _x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
204 Err(FerroError::InvalidParameter {
205 name: "OrdinalEncoder".into(),
206 reason: "encoder must be fitted before calling transform; use fit() first".into(),
207 })
208 }
209}
210
211impl FitTransform<Array2<String>> for OrdinalEncoder {
212 type FitError = FerroError;
213
214 fn fit_transform(&self, x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
220 let fitted = self.fit(x, &())?;
221 fitted.transform(x)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds an `(n, 2)` string matrix from `(column 0, column 1)` row pairs.
    fn make_2col(rows: &[(&str, &str)]) -> Array2<String> {
        let flat: Vec<String> = rows
            .iter()
            .flat_map(|(a, b)| [a.to_string(), b.to_string()])
            .collect();
        Array2::from_shape_vec((rows.len(), 2), flat).unwrap()
    }

    /// Builds an `(n, 1)` string matrix from a list of values.
    fn make_1col(values: &[&str]) -> Array2<String> {
        let flat: Vec<String> = values.iter().map(|v| v.to_string()).collect();
        Array2::from_shape_vec((values.len(), 1), flat).unwrap()
    }

    #[test]
    fn test_ordinal_encoder_basic() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[
            ("cat", "small"),
            ("dog", "large"),
            ("cat", "medium"),
            ("bird", "small"),
        ]);
        let fitted = enc.fit(&x, &()).unwrap();

        assert_eq!(fitted.categories()[0], vec!["bird", "cat", "dog"]);
        assert_eq!(fitted.categories()[1], vec!["large", "medium", "small"]);

        let encoded = fitted.transform(&x).unwrap();
        // Column 0: bird=0, cat=1, dog=2.
        assert_eq!(encoded.column(0).to_vec(), vec![1, 2, 1, 0]);
        // Column 1: large=0, medium=1, small=2.
        assert_eq!(encoded.column(1).to_vec(), vec![2, 0, 1, 2]);
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x"), ("b", "y"), ("a", "z")]);
        let via_ft = enc.fit_transform(&x).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        let via_sep = fitted.transform(&x).unwrap();
        assert_eq!(via_ft, via_sep);
    }

    #[test]
    fn test_transform_uses_fitted_codes() {
        // Codes come from the training data, not from the matrix being encoded.
        let fitted = OrdinalEncoder::new()
            .fit(&make_2col(&[("b", "y"), ("a", "x")]), &())
            .unwrap();
        let encoded = fitted.transform(&make_2col(&[("a", "y")])).unwrap();
        assert_eq!(encoded.shape(), &[1, 2]);
        assert_eq!(encoded[[0, 0]], 0);
        assert_eq!(encoded[[0, 1]], 1);
    }

    #[test]
    fn test_duplicate_categories_collapsed() {
        let x = make_1col(&["b", "a", "b", "b", "a"]);
        let fitted = OrdinalEncoder::new().fit(&x, &()).unwrap();
        assert_eq!(fitted.categories()[0], vec!["a", "b"]);
        let encoded = fitted.transform(&x).unwrap();
        assert_eq!(encoded.column(0).to_vec(), vec![1, 0, 1, 1, 0]);
    }

    #[test]
    fn test_unknown_category_error() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("cat", "small"), ("dog", "large")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_test = make_2col(&[("fish", "small")]);
        assert!(matches!(
            fitted.transform(&x_test),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_shape_mismatch_error() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("a", "x")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        // Fitted on 2 columns, asked to transform 1.
        let x_bad = make_1col(&["a"]);
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_insufficient_samples_error() {
        let enc = OrdinalEncoder::new();
        let x: Array2<String> = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        assert!(matches!(
            enc.fit(&x, &()),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_unfitted_transform_error() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x")]);
        assert!(matches!(
            enc.transform(&x),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_single_column() {
        let enc = OrdinalEncoder::new();
        let x = make_1col(&["red", "green", "blue", "red"]);
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.categories()[0], vec!["blue", "green", "red"]);
        let encoded = fitted.transform(&x).unwrap();
        // blue=0, green=1, red=2.
        assert_eq!(encoded.column(0).to_vec(), vec![2, 1, 0, 2]);
    }

    #[test]
    fn test_n_features() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x")]);
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_features(), 2);
    }

    #[test]
    fn test_lexicographic_order() {
        let enc = OrdinalEncoder::new();
        let x = make_1col(&["zebra", "ant", "moose"]);
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.categories()[0], vec!["ant", "moose", "zebra"]);
    }
}