1use ferrolearn_core::error::FerroError;
8use ferrolearn_core::traits::{Fit, FitTransform, Transform};
9use ndarray::Array2;
10use std::collections::HashMap;
11
/// Unfitted encoder that maps string categories to integer codes, column by column.
///
/// Call `fit` to learn the categories of each column, then use the returned
/// [`FittedOrdinalEncoder`] to encode data.
#[derive(Debug, Clone, Default)]
pub struct OrdinalEncoder;

impl OrdinalEncoder {
    /// Returns a new, unfitted encoder. It carries no configuration.
    #[must_use]
    pub fn new() -> Self {
        OrdinalEncoder
    }
}
54
/// Ordinal encoder whose per-column categories have been learned by `fit`.
#[derive(Debug, Clone)]
pub struct FittedOrdinalEncoder {
    /// For each column, the learned categories; a category's position in the
    /// inner `Vec` is its integer code.
    pub(crate) categories: Vec<Vec<String>>,
    /// For each column, the reverse lookup from category string to its code.
    pub(crate) category_to_index: Vec<HashMap<String, usize>>,
}

impl FittedOrdinalEncoder {
    /// Learned categories, one list per column; the index within a list is
    /// the code assigned to that category.
    #[must_use]
    pub fn categories(&self) -> &[Vec<String>] {
        self.categories.as_slice()
    }

    /// Number of columns the encoder was fitted on.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.categories().len()
    }
}
85
86impl Fit<Array2<String>, ()> for OrdinalEncoder {
91 type Fitted = FittedOrdinalEncoder;
92 type Error = FerroError;
93
94 fn fit(&self, x: &Array2<String>, _y: &()) -> Result<FittedOrdinalEncoder, FerroError> {
102 let n_samples = x.nrows();
103 if n_samples == 0 {
104 return Err(FerroError::InsufficientSamples {
105 required: 1,
106 actual: 0,
107 context: "OrdinalEncoder::fit".into(),
108 });
109 }
110
111 let n_features = x.ncols();
112 let mut categories = Vec::with_capacity(n_features);
113 let mut category_to_index = Vec::with_capacity(n_features);
114
115 for j in 0..n_features {
116 let mut seen: Vec<String> = Vec::new();
117 let mut map: HashMap<String, usize> = HashMap::new();
118
119 for i in 0..n_samples {
120 let cat = x[[i, j]].clone();
121 if !map.contains_key(&cat) {
122 let idx = seen.len();
123 map.insert(cat.clone(), idx);
124 seen.push(cat);
125 }
126 }
127
128 categories.push(seen);
129 category_to_index.push(map);
130 }
131
132 Ok(FittedOrdinalEncoder {
133 categories,
134 category_to_index,
135 })
136 }
137}
138
139impl Transform<Array2<String>> for FittedOrdinalEncoder {
140 type Output = Array2<usize>;
141 type Error = FerroError;
142
143 fn transform(&self, x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
153 let n_features = self.categories.len();
154 if x.ncols() != n_features {
155 return Err(FerroError::ShapeMismatch {
156 expected: vec![x.nrows(), n_features],
157 actual: vec![x.nrows(), x.ncols()],
158 context: "FittedOrdinalEncoder::transform".into(),
159 });
160 }
161
162 let n_samples = x.nrows();
163 let mut out = Array2::zeros((n_samples, n_features));
164
165 for j in 0..n_features {
166 let map = &self.category_to_index[j];
167 for i in 0..n_samples {
168 let cat = &x[[i, j]];
169 match map.get(cat) {
170 Some(&idx) => out[[i, j]] = idx,
171 None => {
172 return Err(FerroError::InvalidParameter {
173 name: format!("x[{i},{j}]"),
174 reason: format!("unknown category \"{cat}\" in column {j}"),
175 });
176 }
177 }
178 }
179 }
180
181 Ok(out)
182 }
183}
184
185impl Transform<Array2<String>> for OrdinalEncoder {
188 type Output = Array2<usize>;
189 type Error = FerroError;
190
191 fn transform(&self, _x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
193 Err(FerroError::InvalidParameter {
194 name: "OrdinalEncoder".into(),
195 reason: "encoder must be fitted before calling transform; use fit() first".into(),
196 })
197 }
198}
199
200impl FitTransform<Array2<String>> for OrdinalEncoder {
201 type FitError = FerroError;
202
203 fn fit_transform(&self, x: &Array2<String>) -> Result<Array2<usize>, FerroError> {
209 let fitted = self.fit(x, &())?;
210 fitted.transform(x)
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds an `n x 2` string matrix from `(first, second)` row pairs.
    fn make_2col(rows: &[(&str, &str)]) -> Array2<String> {
        let flat: Vec<String> = rows
            .iter()
            .flat_map(|(a, b)| [a.to_string(), b.to_string()])
            .collect();
        Array2::from_shape_vec((rows.len(), 2), flat).unwrap()
    }

    #[test]
    fn test_ordinal_encoder_basic() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[
            ("cat", "small"),
            ("dog", "large"),
            ("cat", "medium"),
            ("bird", "small"),
        ]);
        let fitted = enc.fit(&x, &()).unwrap();

        assert_eq!(fitted.categories()[0], vec!["cat", "dog", "bird"]);
        assert_eq!(fitted.categories()[1], vec!["small", "large", "medium"]);

        let encoded = fitted.transform(&x).unwrap();
        // Column 0: cat=0, dog=1, bird=2.
        assert_eq!(encoded[[0, 0]], 0);
        assert_eq!(encoded[[1, 0]], 1);
        assert_eq!(encoded[[2, 0]], 0);
        assert_eq!(encoded[[3, 0]], 2);
        // Column 1: small=0, large=1, medium=2.
        assert_eq!(encoded[[0, 1]], 0);
        assert_eq!(encoded[[1, 1]], 1);
        assert_eq!(encoded[[2, 1]], 2);
        assert_eq!(encoded[[3, 1]], 0);
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x"), ("b", "y"), ("a", "z")]);
        let via_ft = enc.fit_transform(&x).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        let via_sep = fitted.transform(&x).unwrap();
        assert_eq!(via_ft, via_sep);
    }

    #[test]
    fn test_transform_uses_learned_codes() {
        let enc = OrdinalEncoder::new();
        let fitted = enc.fit(&make_2col(&[("a", "x"), ("b", "y")]), &()).unwrap();
        // Rows presented in a different order must still map to the codes
        // learned during fitting, not to a fresh first-appearance order.
        let encoded = fitted.transform(&make_2col(&[("b", "y"), ("a", "x")])).unwrap();
        assert_eq!(encoded[[0, 0]], 1);
        assert_eq!(encoded[[0, 1]], 1);
        assert_eq!(encoded[[1, 0]], 0);
        assert_eq!(encoded[[1, 1]], 0);
    }

    #[test]
    fn test_unknown_category_error() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("cat", "small"), ("dog", "large")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_test = make_2col(&[("fish", "small")]);
        let err = fitted.transform(&x_test).unwrap_err();
        assert!(matches!(
            &err,
            FerroError::InvalidParameter { name, .. } if name == "x[0,0]"
        ));
    }

    #[test]
    fn test_unknown_category_reports_cell() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("cat", "small"), ("dog", "large")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_test = make_2col(&[("cat", "small"), ("dog", "huge")]);
        let err = fitted.transform(&x_test).unwrap_err();
        assert!(matches!(
            &err,
            FerroError::InvalidParameter { name, .. } if name == "x[1,1]"
        ));
    }

    #[test]
    fn test_shape_mismatch_error() {
        let enc = OrdinalEncoder::new();
        let x_train = make_2col(&[("a", "x")]);
        let fitted = enc.fit(&x_train, &()).unwrap();
        let x_bad = Array2::from_shape_vec((1, 1), vec!["a".to_string()]).unwrap();
        let err = fitted.transform(&x_bad).unwrap_err();
        assert!(matches!(err, FerroError::ShapeMismatch { .. }));
    }

    #[test]
    fn test_insufficient_samples_error() {
        let enc = OrdinalEncoder::new();
        let x: Array2<String> = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        let err = enc.fit(&x, &()).unwrap_err();
        assert!(matches!(err, FerroError::InsufficientSamples { .. }));
    }

    #[test]
    fn test_unfitted_transform_error() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x")]);
        let err = enc.transform(&x).unwrap_err();
        assert!(matches!(err, FerroError::InvalidParameter { .. }));
    }

    #[test]
    fn test_single_column() {
        let enc = OrdinalEncoder::new();
        let flat = vec![
            "red".to_string(),
            "green".to_string(),
            "blue".to_string(),
            "red".to_string(),
        ];
        let x = Array2::from_shape_vec((4, 1), flat).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.categories()[0], vec!["red", "green", "blue"]);
        let encoded = fitted.transform(&x).unwrap();
        assert_eq!(encoded[[0, 0]], 0);
        assert_eq!(encoded[[1, 0]], 1);
        assert_eq!(encoded[[2, 0]], 2);
        assert_eq!(encoded[[3, 0]], 0);
    }

    #[test]
    fn test_n_features() {
        let enc = OrdinalEncoder::new();
        let x = make_2col(&[("a", "x")]);
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_features(), 2);
    }

    #[test]
    fn test_first_appearance_order() {
        let enc = OrdinalEncoder::new();
        // Codes follow first appearance, not lexicographic order.
        let flat = vec!["zebra".to_string(), "ant".to_string(), "moose".to_string()];
        let x = Array2::from_shape_vec((3, 1), flat).unwrap();
        let fitted = enc.fit(&x, &()).unwrap();
        assert_eq!(fitted.categories()[0][0], "zebra");
        assert_eq!(fitted.categories()[0][1], "ant");
        assert_eq!(fitted.categories()[0][2], "moose");
    }
}