Skip to main content

ferrolearn_preprocess/
label_encoder.rs

1//! Label encoder: maps string labels to integer indices.
2//!
//! Learns a mapping from the unique string labels, sorted in byte-wise
//! lexicographic order, to consecutive integers `0, 1, ..., n_classes - 1`.
//! Supports forward (`label → int`) and reverse (`int → label`) transformation.
6
7use ferrolearn_core::error::FerroError;
8use ferrolearn_core::traits::{Fit, FitTransform, Transform};
9use ndarray::Array1;
10use std::collections::HashMap;
11
12// ---------------------------------------------------------------------------
13// LabelEncoder (unfitted)
14// ---------------------------------------------------------------------------
15
/// An unfitted label encoder.
///
/// Calling [`Fit::fit`] on an `Array1<String>` learns a sorted mapping from
/// unique string labels to integer indices `0, 1, ..., n_classes - 1` and
/// returns a [`FittedLabelEncoder`].
///
/// Labels are ordered by [`String`]'s `Ord` implementation, i.e. byte-wise
/// lexicographic order. This is case-sensitive and not locale-aware, so for
/// example `"Zebra"` sorts (and therefore encodes) before `"apple"`.
///
/// The encoder carries no configuration; every instance is equal to every
/// other instance.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::LabelEncoder;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let enc = LabelEncoder::new();
/// let labels = array!["cat".to_string(), "dog".to_string(), "cat".to_string()];
/// let fitted = enc.fit(&labels, &()).unwrap();
/// let encoded = fitted.transform(&labels).unwrap();
/// assert_eq!(encoded[0], 0); // "cat" → 0
/// assert_eq!(encoded[1], 1); // "dog" → 1
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LabelEncoder;

impl LabelEncoder {
    /// Create a new `LabelEncoder`.
    ///
    /// Equivalent to [`LabelEncoder::default`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
46
47// ---------------------------------------------------------------------------
48// FittedLabelEncoder
49// ---------------------------------------------------------------------------
50
/// A fitted label encoder holding the bidirectional label-to-index mapping.
///
/// Created by calling [`Fit::fit`] on a [`LabelEncoder`].
///
/// The two fields describe the same mapping from opposite directions: for
/// every `i < classes.len()`, `label_to_index[&classes[i]] == i`, and the map
/// holds no other keys. `fit` builds both from one sorted list of unique
/// labels. Any other code constructing this type directly (the fields are
/// `pub(crate)`) must uphold the same invariant.
#[derive(Debug, Clone)]
pub struct FittedLabelEncoder {
    /// Ordered list of unique class labels (index = class integer).
    ///
    /// As produced by `fit`, sorted ascending by `String`'s `Ord`
    /// (byte-wise lexicographic) with no duplicates.
    pub(crate) classes: Vec<String>,
    /// Map from label string to integer index; the inverse of `classes`.
    pub(crate) label_to_index: HashMap<String, usize>,
}
61
62impl FittedLabelEncoder {
63    /// Return the ordered list of class labels.
64    ///
65    /// `classes[i]` is the label corresponding to integer `i`.
66    #[must_use]
67    pub fn classes(&self) -> &[String] {
68        &self.classes
69    }
70
71    /// Return the number of unique classes.
72    #[must_use]
73    pub fn n_classes(&self) -> usize {
74        self.classes.len()
75    }
76
77    /// Map integer indices back to the original string labels.
78    ///
79    /// # Errors
80    ///
81    /// Returns [`FerroError::InvalidParameter`] if any index is out of range.
82    pub fn inverse_transform(&self, y: &Array1<usize>) -> Result<Array1<String>, FerroError> {
83        let n_classes = self.classes.len();
84        let mut out = Vec::with_capacity(y.len());
85        for (i, &idx) in y.iter().enumerate() {
86            if idx >= n_classes {
87                return Err(FerroError::InvalidParameter {
88                    name: format!("y[{i}]"),
89                    reason: format!("index {idx} is out of range (n_classes = {n_classes})"),
90                });
91            }
92            out.push(self.classes[idx].clone());
93        }
94        Ok(Array1::from_vec(out))
95    }
96}
97
98// ---------------------------------------------------------------------------
99// Trait implementations
100// ---------------------------------------------------------------------------
101
102impl Fit<Array1<String>, ()> for LabelEncoder {
103    type Fitted = FittedLabelEncoder;
104    type Error = FerroError;
105
106    /// Fit the encoder by learning the sorted set of unique labels.
107    ///
108    /// Labels are sorted alphabetically; the first label maps to `0`.
109    ///
110    /// # Errors
111    ///
112    /// Returns [`FerroError::InsufficientSamples`] if the input is empty.
113    fn fit(&self, x: &Array1<String>, _y: &()) -> Result<FittedLabelEncoder, FerroError> {
114        if x.is_empty() {
115            return Err(FerroError::InsufficientSamples {
116                required: 1,
117                actual: 0,
118                context: "LabelEncoder::fit".into(),
119            });
120        }
121
122        let mut unique: Vec<String> = x
123            .iter()
124            .cloned()
125            .collect::<std::collections::HashSet<_>>()
126            .into_iter()
127            .collect();
128        unique.sort();
129
130        let label_to_index: HashMap<String, usize> = unique
131            .iter()
132            .enumerate()
133            .map(|(i, label)| (label.clone(), i))
134            .collect();
135
136        Ok(FittedLabelEncoder {
137            classes: unique,
138            label_to_index,
139        })
140    }
141}
142
143impl Transform<Array1<String>> for FittedLabelEncoder {
144    type Output = Array1<usize>;
145    type Error = FerroError;
146
147    /// Transform string labels to integer indices.
148    ///
149    /// # Errors
150    ///
151    /// Returns [`FerroError::InvalidParameter`] if any label was not seen during fitting.
152    fn transform(&self, x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
153        let mut out = Vec::with_capacity(x.len());
154        for (i, label) in x.iter().enumerate() {
155            match self.label_to_index.get(label) {
156                Some(&idx) => out.push(idx),
157                None => {
158                    return Err(FerroError::InvalidParameter {
159                        name: format!("x[{i}]"),
160                        reason: format!("unknown label \"{label}\""),
161                    });
162                }
163            }
164        }
165        Ok(Array1::from_vec(out))
166    }
167}
168
169/// Implement `Transform` on the unfitted encoder to satisfy the `FitTransform: Transform`
170/// supertrait bound. Calling `transform` on an unfitted encoder always returns an error.
171impl Transform<Array1<String>> for LabelEncoder {
172    type Output = Array1<usize>;
173    type Error = FerroError;
174
175    /// Always returns an error — the encoder must be fitted first.
176    ///
177    /// Use [`Fit::fit`] to produce a [`FittedLabelEncoder`], then call
178    /// [`Transform::transform`] on that.
179    fn transform(&self, _x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
180        Err(FerroError::InvalidParameter {
181            name: "LabelEncoder".into(),
182            reason: "encoder must be fitted before calling transform; use fit() first".into(),
183        })
184    }
185}
186
187impl FitTransform<Array1<String>> for LabelEncoder {
188    type FitError = FerroError;
189
190    /// Fit the encoder on `x` and return the encoded output in one step.
191    ///
192    /// # Errors
193    ///
194    /// Returns an error if fitting or transformation fails.
195    fn fit_transform(&self, x: &Array1<String>) -> Result<Array1<usize>, FerroError> {
196        let fitted = self.fit(x, &())?;
197        fitted.transform(x)
198    }
199}
200
201// ---------------------------------------------------------------------------
202// Tests
203// ---------------------------------------------------------------------------
204
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Build an owned `Array1<String>` from string slices.
    fn str_arr(v: &[&str]) -> Array1<String> {
        Array1::from_vec(v.iter().map(|s| s.to_string()).collect())
    }

    #[test]
    fn test_label_encoder_basic() {
        let enc = LabelEncoder::new();
        let labels = str_arr(&["cat", "dog", "cat", "bird"]);
        let fitted = enc.fit(&labels, &()).unwrap();

        // Classes should be sorted (byte-wise lexicographic order)
        assert_eq!(fitted.classes(), &["bird", "cat", "dog"]);
        assert_eq!(fitted.n_classes(), 3);

        let encoded = fitted.transform(&labels).unwrap();
        assert_eq!(encoded[0], 1); // "cat" → 1
        assert_eq!(encoded[1], 2); // "dog" → 2
        assert_eq!(encoded[2], 1); // "cat" → 1
        assert_eq!(encoded[3], 0); // "bird" → 0
    }

    #[test]
    fn test_ordering_is_bytewise_case_sensitive() {
        let labels = str_arr(&["apple", "Zebra", "banana"]);
        let fitted = LabelEncoder::new().fit(&labels, &()).unwrap();
        // Uppercase ASCII sorts before lowercase under `String`'s `Ord`.
        assert_eq!(fitted.classes(), &["Zebra", "apple", "banana"]);
    }

    #[test]
    fn test_inverse_transform_roundtrip() {
        let enc = LabelEncoder::new();
        let labels = str_arr(&["a", "b", "c", "a", "b"]);
        let fitted = enc.fit(&labels, &()).unwrap();
        let encoded = fitted.transform(&labels).unwrap();
        let recovered = fitted.inverse_transform(&encoded).unwrap();
        assert_eq!(recovered, labels);
    }

    #[test]
    fn test_unknown_label_error() {
        let enc = LabelEncoder::new();
        let labels = str_arr(&["a", "b"]);
        let fitted = enc.fit(&labels, &()).unwrap();
        let unknown = str_arr(&["c"]);
        assert!(fitted.transform(&unknown).is_err());
    }

    #[test]
    fn test_unknown_label_error_reports_position() {
        let fitted = LabelEncoder::new().fit(&str_arr(&["a", "b"]), &()).unwrap();
        let err = fitted.transform(&str_arr(&["a", "zzz"])).unwrap_err();
        assert!(matches!(
            err,
            FerroError::InvalidParameter { ref name, .. } if name.as_str() == "x[1]"
        ));
    }

    #[test]
    fn test_inverse_transform_out_of_range() {
        let enc = LabelEncoder::new();
        let labels = str_arr(&["x", "y"]);
        let fitted = enc.fit(&labels, &()).unwrap();
        let bad_indices = array![5usize];
        assert!(fitted.inverse_transform(&bad_indices).is_err());
    }

    #[test]
    fn test_inverse_transform_error_reports_position() {
        let fitted = LabelEncoder::new().fit(&str_arr(&["x", "y"]), &()).unwrap();
        let err = fitted.inverse_transform(&array![0usize, 2usize]).unwrap_err();
        assert!(matches!(
            err,
            FerroError::InvalidParameter { ref name, .. } if name.as_str() == "y[1]"
        ));
    }

    #[test]
    fn test_fit_transform_equivalence() {
        let enc = LabelEncoder::new();
        let labels = str_arr(&["foo", "bar", "foo", "baz"]);
        let via_fit_transform = enc.fit_transform(&labels).unwrap();
        let fitted = enc.fit(&labels, &()).unwrap();
        let via_separate = fitted.transform(&labels).unwrap();
        assert_eq!(via_fit_transform, via_separate);
    }

    #[test]
    fn test_empty_input_error() {
        let enc = LabelEncoder::new();
        let empty: Array1<String> = Array1::from_vec(vec![]);
        assert!(enc.fit(&empty, &()).is_err());
        assert!(enc.fit_transform(&empty).is_err());
    }

    #[test]
    fn test_empty_transform_and_inverse_after_fit() {
        let fitted = LabelEncoder::new().fit(&str_arr(&["a"]), &()).unwrap();
        let no_labels: Array1<String> = Array1::from_vec(vec![]);
        assert!(fitted.transform(&no_labels).unwrap().is_empty());
        let no_indices: Array1<usize> = Array1::from_vec(vec![]);
        assert!(fitted.inverse_transform(&no_indices).unwrap().is_empty());
    }

    #[test]
    fn test_unfitted_transform_errors() {
        let enc = LabelEncoder::new();
        let labels = str_arr(&["a"]);
        assert!(enc.transform(&labels).is_err());
    }

    #[test]
    fn test_single_class() {
        let enc = LabelEncoder::new();
        let labels = str_arr(&["only", "only", "only"]);
        let fitted = enc.fit(&labels, &()).unwrap();
        assert_eq!(fitted.n_classes(), 1);
        let encoded = fitted.transform(&labels).unwrap();
        assert!(encoded.iter().all(|&v| v == 0));
    }
}