ferrolearn_preprocess/
label_encoder.rs1use ferrolearn_core::error::FerroError;
8use ferrolearn_core::traits::{Fit, FitTransform, Transform};
9use ndarray::Array1;
10use std::collections::HashMap;
11
/// An unfitted label encoder.
///
/// The type carries no state of its own. Fitting it on an array of string
/// labels (see the `Fit` implementation in this file) yields a
/// [`FittedLabelEncoder`] that holds the learned label-to-index mapping.
#[derive(Debug, Clone, Default)]
pub struct LabelEncoder;

impl LabelEncoder {
    /// Creates a new, unfitted encoder.
    ///
    /// Equivalent to [`LabelEncoder::default`], since the type has no fields.
    #[must_use]
    pub fn new() -> Self {
        LabelEncoder
    }
}
46
47#[derive(Debug, Clone)]
55pub struct FittedLabelEncoder {
56 pub(crate) classes: Vec<String>,
58 pub(crate) label_to_index: HashMap<String, usize>,
60}
61
62impl FittedLabelEncoder {
63 #[must_use]
67 pub fn classes(&self) -> &[String] {
68 &self.classes
69 }
70
71 #[must_use]
73 pub fn n_classes(&self) -> usize {
74 self.classes.len()
75 }
76
77 pub fn inverse_transform(&self, y: &Array1<usize>) -> Result<Array1<String>, FerroError> {
83 let n_classes = self.classes.len();
84 let mut out = Vec::with_capacity(y.len());
85 for (i, &idx) in y.iter().enumerate() {
86 if idx >= n_classes {
87 return Err(FerroError::InvalidParameter {
88 name: format!("y[{i}]"),
89 reason: format!("index {idx} is out of range (n_classes = {n_classes})"),
90 });
91 }
92 out.push(self.classes[idx].clone());
93 }
94 Ok(Array1::from_vec(out))
95 }
96}
97
98impl Fit<Array1<String>, ()> for LabelEncoder {
103 type Fitted = FittedLabelEncoder;
104 type Error = FerroError;
105
106 fn fit(&self, x: &Array1<String>, _y: &()) -> Result<FittedLabelEncoder, FerroError> {
114 if x.is_empty() {
115 return Err(FerroError::InsufficientSamples {
116 required: 1,
117 actual: 0,
118 context: "LabelEncoder::fit".into(),
119 });
120 }
121
122 let mut unique: Vec<String> = x
123 .iter()
124 .cloned()
125 .collect::<std::collections::HashSet<_>>()
126 .into_iter()
127 .collect();
128 unique.sort();
129
130 let label_to_index: HashMap<String, usize> = unique
131 .iter()
132 .enumerate()
133 .map(|(i, label)| (label.clone(), i))
134 .collect();
135
136 Ok(FittedLabelEncoder {
137 classes: unique,
138 label_to_index,
139 })
140 }
141}
142
143impl Transform<Array1<String>> for FittedLabelEncoder {
144 type Output = Array1<usize>;
145 type Error = FerroError;
146
147 fn transform(&self, x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
153 let mut out = Vec::with_capacity(x.len());
154 for (i, label) in x.iter().enumerate() {
155 match self.label_to_index.get(label) {
156 Some(&idx) => out.push(idx),
157 None => {
158 return Err(FerroError::InvalidParameter {
159 name: format!("x[{i}]"),
160 reason: format!("unknown label \"{label}\""),
161 });
162 }
163 }
164 }
165 Ok(Array1::from_vec(out))
166 }
167}
168
169impl Transform<Array1<String>> for LabelEncoder {
172 type Output = Array1<usize>;
173 type Error = FerroError;
174
175 fn transform(&self, _x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
180 Err(FerroError::InvalidParameter {
181 name: "LabelEncoder".into(),
182 reason: "encoder must be fitted before calling transform; use fit() first".into(),
183 })
184 }
185}
186
187impl FitTransform<Array1<String>> for LabelEncoder {
188 type FitError = FerroError;
189
190 fn fit_transform(&self, x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
196 let fitted = self.fit(x, &())?;
197 fitted.transform(x)
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Builds an owned `Array1<String>` from a slice of string slices.
    fn str_arr(v: &[&str]) -> Array1<String> {
        let owned: Vec<String> = v.iter().copied().map(String::from).collect();
        Array1::from_vec(owned)
    }

    /// Fits a fresh encoder on `labels`, panicking if fitting fails.
    fn fit_on(labels: &Array1<String>) -> FittedLabelEncoder {
        LabelEncoder::new().fit(labels, &()).unwrap()
    }

    #[test]
    fn test_label_encoder_basic() {
        let labels = str_arr(&["cat", "dog", "cat", "bird"]);
        let fitted = fit_on(&labels);

        // Classes come back sorted, so bird = 0, cat = 1, dog = 2.
        assert_eq!(fitted.classes(), &["bird", "cat", "dog"]);
        assert_eq!(fitted.n_classes(), 3);

        let encoded = fitted.transform(&labels).unwrap();
        for (position, expected) in [1usize, 2, 1, 0].into_iter().enumerate() {
            assert_eq!(encoded[position], expected);
        }
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let labels = str_arr(&["a", "b", "c", "a", "b"]);
        let fitted = fit_on(&labels);

        let recovered = fitted
            .transform(&labels)
            .and_then(|codes| fitted.inverse_transform(&codes))
            .unwrap();

        labels
            .iter()
            .zip(recovered.iter())
            .for_each(|(orig, rec)| assert_eq!(orig, rec));
    }

    #[test]
    fn test_unknown_label_error() {
        let fitted = fit_on(&str_arr(&["a", "b"]));
        let outcome = fitted.transform(&str_arr(&["c"]));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_inverse_transform_out_of_range() {
        let fitted = fit_on(&str_arr(&["x", "y"]));
        // Only indices 0 and 1 are valid for two classes.
        let outcome = fitted.inverse_transform(&array![5usize]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let labels = str_arr(&["foo", "bar", "foo", "baz"]);
        let enc = LabelEncoder::new();

        let via_fit_transform = enc.fit_transform(&labels).unwrap();
        let via_separate = enc.fit(&labels, &()).unwrap().transform(&labels).unwrap();

        assert_eq!(via_fit_transform, via_separate);
    }

    #[test]
    fn test_empty_input_error() {
        let empty: Array1<String> = Array1::from_vec(Vec::new());
        let outcome = LabelEncoder::new().fit(&empty, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_single_class() {
        let labels = str_arr(&["only", "only", "only"]);
        let fitted = fit_on(&labels);
        assert_eq!(fitted.n_classes(), 1);

        let encoded = fitted.transform(&labels).unwrap();
        assert!(!encoded.iter().any(|&code| code != 0));
    }
}