anofox_ml_preprocessing/
label_encoder.rs1use anofox_ml_core::{Result, RustMlError};
2use ndarray::Array1;
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct LabelEncoder;
14
15impl LabelEncoder {
16 pub fn new() -> Self {
18 Self
19 }
20
21 pub fn fit(&self, labels: &[String]) -> Result<FittedLabelEncoder> {
23 if labels.is_empty() {
24 return Err(RustMlError::EmptyInput("labels slice is empty".into()));
25 }
26
27 let mut vocab: Vec<String> = labels.iter().cloned().collect();
28 vocab.sort();
29 vocab.dedup();
30
31 let label_to_index: HashMap<String, usize> = vocab
32 .iter()
33 .enumerate()
34 .map(|(i, s)| (s.clone(), i))
35 .collect();
36
37 Ok(FittedLabelEncoder {
38 vocab,
39 label_to_index,
40 })
41 }
42}
43
44impl Default for LabelEncoder {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct FittedLabelEncoder {
53 vocab: Vec<String>,
54 label_to_index: HashMap<String, usize>,
55}
56
57impl FittedLabelEncoder {
58 pub fn transform(&self, labels: &[String]) -> Result<Array1<usize>> {
60 let mut encoded = Vec::with_capacity(labels.len());
61 for label in labels {
62 match self.label_to_index.get(label) {
63 Some(&idx) => encoded.push(idx),
64 None => {
65 return Err(RustMlError::InvalidParameter(format!(
66 "unknown label: '{}'",
67 label
68 )));
69 }
70 }
71 }
72 Ok(Array1::from_vec(encoded))
73 }
74
75 pub fn inverse_transform(&self, encoded: &Array1<usize>) -> Result<Vec<String>> {
77 let mut labels = Vec::with_capacity(encoded.len());
78 for &idx in encoded.iter() {
79 if idx >= self.vocab.len() {
80 return Err(RustMlError::InvalidParameter(format!(
81 "encoded index {} is out of range (vocabulary size {})",
82 idx,
83 self.vocab.len()
84 )));
85 }
86 labels.push(self.vocab[idx].clone());
87 }
88 Ok(labels)
89 }
90
91 pub fn vocab(&self) -> &[String] {
93 &self.vocab
94 }
95
96 pub fn num_classes(&self) -> usize {
98 self.vocab.len()
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Shorthand for building owned labels in test fixtures.
    fn s(val: &str) -> String {
        val.to_string()
    }

    #[test]
    fn test_fit_transform() {
        let labels = vec![s("cat"), s("dog"), s("cat"), s("bird")];
        let encoder = LabelEncoder::new();
        let fitted = encoder.fit(&labels).unwrap();
        let encoded = fitted.transform(&labels).unwrap();

        assert_eq!(fitted.vocab(), &[s("bird"), s("cat"), s("dog")]);
        assert_eq!(encoded, array![1, 2, 1, 0]);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let labels = vec![
            s("apple"),
            s("banana"),
            s("cherry"),
            s("banana"),
            s("apple"),
        ];
        let encoder = LabelEncoder::new();
        let fitted = encoder.fit(&labels).unwrap();
        let encoded = fitted.transform(&labels).unwrap();
        let recovered = fitted.inverse_transform(&encoded).unwrap();

        assert_eq!(recovered, labels);
    }

    #[test]
    fn test_unknown_label() {
        let labels = vec![s("cat"), s("dog")];
        let encoder = LabelEncoder::new();
        let fitted = encoder.fit(&labels).unwrap();

        let unknown = vec![s("fish")];
        assert!(matches!(
            fitted.transform(&unknown),
            Err(RustMlError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_unknown_label_in_batch_fails_whole_batch() {
        let labels = vec![s("cat"), s("dog")];
        let fitted = LabelEncoder::new().fit(&labels).unwrap();

        // One unknown label among known ones must reject the entire batch.
        let mixed = vec![s("cat"), s("fish"), s("dog")];
        assert!(matches!(
            fitted.transform(&mixed),
            Err(RustMlError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_out_of_range_index() {
        let labels = vec![s("a"), s("b")];
        let encoder = LabelEncoder::new();
        let fitted = encoder.fit(&labels).unwrap();

        let bad_encoded = array![0, 5];
        assert!(matches!(
            fitted.inverse_transform(&bad_encoded),
            Err(RustMlError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_empty_labels() {
        let labels: Vec<String> = vec![];
        let encoder = LabelEncoder::new();
        assert!(matches!(
            encoder.fit(&labels),
            Err(RustMlError::EmptyInput(_))
        ));
    }

    #[test]
    fn test_empty_batch_transform_and_inverse() {
        let labels = vec![s("a"), s("b")];
        let fitted = LabelEncoder::new().fit(&labels).unwrap();

        // Fitting requires data, but encoding/decoding an empty batch is valid.
        let empty: Vec<String> = Vec::new();
        let encoded = fitted.transform(&empty).unwrap();
        assert_eq!(encoded.len(), 0);

        let decoded = fitted.inverse_transform(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_single_label() {
        let labels = vec![s("only")];
        let encoder = LabelEncoder::new();
        let fitted = encoder.fit(&labels).unwrap();
        let encoded = fitted.transform(&labels).unwrap();

        assert_eq!(encoded, array![0]);
        assert_eq!(fitted.num_classes(), 1);
    }

    #[test]
    fn test_duplicate_labels() {
        let labels = vec![s("x"), s("x"), s("x"), s("y"), s("y")];
        let encoder = LabelEncoder::new();
        let fitted = encoder.fit(&labels).unwrap();

        assert_eq!(fitted.num_classes(), 2);
        assert_eq!(fitted.vocab(), &[s("x"), s("y")]);

        let encoded = fitted.transform(&labels).unwrap();
        assert_eq!(encoded, array![0, 0, 0, 1, 1]);
    }

    #[test]
    fn test_sorted_vocabulary() {
        let labels = vec![s("zebra"), s("apple"), s("mango"), s("banana")];
        let encoder = LabelEncoder::new();
        let fitted = encoder.fit(&labels).unwrap();

        assert_eq!(
            fitted.vocab(),
            &[s("apple"), s("banana"), s("mango"), s("zebra")]
        );
    }

    #[test]
    fn test_vocabulary_ordering_is_case_sensitive() {
        // String ordering is byte-wise, so uppercase sorts before lowercase.
        let labels = vec![s("b"), s("B"), s("a")];
        let fitted = LabelEncoder::new().fit(&labels).unwrap();

        assert_eq!(fitted.vocab(), &[s("B"), s("a"), s("b")]);
        assert_eq!(fitted.transform(&labels).unwrap(), array![2, 0, 1]);
    }

    #[test]
    fn test_inverse_of_every_index_yields_vocab() {
        let labels = vec![s("z"), s("x"), s("y"), s("x")];
        let fitted = LabelEncoder::new().fit(&labels).unwrap();

        let all_indices = Array1::from_vec((0..fitted.num_classes()).collect());
        let decoded = fitted.inverse_transform(&all_indices).unwrap();
        assert_eq!(decoded.as_slice(), fitted.vocab());
    }

    #[test]
    fn test_default() {
        let encoder = LabelEncoder::default();
        let labels = vec![s("a"), s("b")];
        let fitted = encoder.fit(&labels).unwrap();
        assert_eq!(fitted.num_classes(), 2);
    }

    #[test]
    fn test_many_classes() {
        let labels: Vec<String> = (0..100).map(|i| format!("class_{:03}", i)).collect();
        let encoder = LabelEncoder::new();
        let fitted = encoder.fit(&labels).unwrap();
        let encoded = fitted.transform(&labels).unwrap();
        let recovered = fitted.inverse_transform(&encoded).unwrap();

        assert_eq!(fitted.num_classes(), 100);
        assert_eq!(recovered, labels);
    }
}