Skip to main content

anofox_ml_preprocessing/
label_encoder.rs

1use anofox_ml_core::{Result, RustMlError};
2use ndarray::Array1;
3use std::collections::HashMap;
4
5/// Encodes string labels as integer indices.
6///
7/// Maps each unique label to a unique integer in sorted order. This is useful
8/// for converting categorical target labels to numeric form for model training.
9///
10/// Unlike the numeric transformers, `LabelEncoder` works on string slices
11/// rather than float arrays, so it does not implement `FitUnsupervised`.
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct LabelEncoder;
14
15impl LabelEncoder {
16    /// Create a new `LabelEncoder`.
17    pub fn new() -> Self {
18        Self
19    }
20
21    /// Fit the encoder on the given labels, learning the vocabulary.
22    pub fn fit(&self, labels: &[String]) -> Result<FittedLabelEncoder> {
23        if labels.is_empty() {
24            return Err(RustMlError::EmptyInput("labels slice is empty".into()));
25        }
26
27        let mut vocab: Vec<String> = labels.iter().cloned().collect();
28        vocab.sort();
29        vocab.dedup();
30
31        let label_to_index: HashMap<String, usize> = vocab
32            .iter()
33            .enumerate()
34            .map(|(i, s)| (s.clone(), i))
35            .collect();
36
37        Ok(FittedLabelEncoder {
38            vocab,
39            label_to_index,
40        })
41    }
42}
43
44impl Default for LabelEncoder {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50/// Fitted LabelEncoder — holds the learned vocabulary and mapping.
51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct FittedLabelEncoder {
53    vocab: Vec<String>,
54    label_to_index: HashMap<String, usize>,
55}
56
57impl FittedLabelEncoder {
58    /// Transform string labels into integer-encoded values.
59    pub fn transform(&self, labels: &[String]) -> Result<Array1<usize>> {
60        let mut encoded = Vec::with_capacity(labels.len());
61        for label in labels {
62            match self.label_to_index.get(label) {
63                Some(&idx) => encoded.push(idx),
64                None => {
65                    return Err(RustMlError::InvalidParameter(format!(
66                        "unknown label: '{}'",
67                        label
68                    )));
69                }
70            }
71        }
72        Ok(Array1::from_vec(encoded))
73    }
74
75    /// Inverse-transform integer-encoded values back to string labels.
76    pub fn inverse_transform(&self, encoded: &Array1<usize>) -> Result<Vec<String>> {
77        let mut labels = Vec::with_capacity(encoded.len());
78        for &idx in encoded.iter() {
79            if idx >= self.vocab.len() {
80                return Err(RustMlError::InvalidParameter(format!(
81                    "encoded index {} is out of range (vocabulary size {})",
82                    idx,
83                    self.vocab.len()
84                )));
85            }
86            labels.push(self.vocab[idx].clone());
87        }
88        Ok(labels)
89    }
90
91    /// Return the learned vocabulary (sorted).
92    pub fn vocab(&self) -> &[String] {
93        &self.vocab
94    }
95
96    /// Return the number of classes.
97    pub fn num_classes(&self) -> usize {
98        self.vocab.len()
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Build an owned label vector from string literals.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|&item| item.to_owned()).collect()
    }

    /// Fit a fresh encoder on `labels`, panicking if fitting fails.
    fn fit_on(labels: &[String]) -> FittedLabelEncoder {
        LabelEncoder::new().fit(labels).unwrap()
    }

    #[test]
    fn test_fit_transform() {
        let input = owned(&["cat", "dog", "cat", "bird"]);
        let fitted = fit_on(&input);
        let codes = fitted.transform(&input).unwrap();

        // Vocabulary is sorted: bird < cat < dog.
        let expected_vocab = owned(&["bird", "cat", "dog"]);
        assert_eq!(fitted.vocab(), expected_vocab.as_slice());
        assert_eq!(codes, array![1, 2, 1, 0]);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let input = owned(&["apple", "banana", "cherry", "banana", "apple"]);
        let fitted = fit_on(&input);

        let codes = fitted.transform(&input).unwrap();
        let decoded = fitted.inverse_transform(&codes).unwrap();

        assert_eq!(decoded, input);
    }

    #[test]
    fn test_unknown_label() {
        let fitted = fit_on(&owned(&["cat", "dog"]));

        let unseen = owned(&["fish"]);
        assert!(fitted.transform(&unseen).is_err());
    }

    #[test]
    fn test_out_of_range_index() {
        let fitted = fit_on(&owned(&["a", "b"]));

        // Index 5 lies beyond the two-entry vocabulary.
        let invalid_codes = array![0, 5];
        assert!(fitted.inverse_transform(&invalid_codes).is_err());
    }

    #[test]
    fn test_empty_labels() {
        let nothing: Vec<String> = Vec::new();
        assert!(LabelEncoder::new().fit(&nothing).is_err());
    }

    #[test]
    fn test_single_label() {
        let input = owned(&["only"]);
        let fitted = fit_on(&input);
        let codes = fitted.transform(&input).unwrap();

        assert_eq!(codes, array![0]);
        assert_eq!(fitted.num_classes(), 1);
    }

    #[test]
    fn test_duplicate_labels() {
        let input = owned(&["x", "x", "x", "y", "y"]);
        let fitted = fit_on(&input);

        assert_eq!(fitted.num_classes(), 2);
        let expected_vocab = owned(&["x", "y"]);
        assert_eq!(fitted.vocab(), expected_vocab.as_slice());

        let codes = fitted.transform(&input).unwrap();
        assert_eq!(codes, array![0, 0, 0, 1, 1]);
    }

    #[test]
    fn test_sorted_vocabulary() {
        let fitted = fit_on(&owned(&["zebra", "apple", "mango", "banana"]));

        let expected_vocab = owned(&["apple", "banana", "mango", "zebra"]);
        assert_eq!(fitted.vocab(), expected_vocab.as_slice());
    }

    #[test]
    fn test_default() {
        let input = owned(&["a", "b"]);
        let fitted = LabelEncoder::default().fit(&input).unwrap();

        assert_eq!(fitted.num_classes(), 2);
    }

    #[test]
    fn test_many_classes() {
        // Zero-padded names keep the generated order equal to the sorted order.
        let input: Vec<String> = (0..100).map(|i| format!("class_{:03}", i)).collect();
        let fitted = fit_on(&input);

        let codes = fitted.transform(&input).unwrap();
        let decoded = fitted.inverse_transform(&codes).unwrap();

        assert_eq!(fitted.num_classes(), 100);
        assert_eq!(decoded, input);
    }
}