Skip to main content

ferrolearn_preprocess/
target_encoder.rs

1//! Target encoder: encode categorical features using target statistics.
2//!
3//! [`TargetEncoder`] replaces each category with the mean of the target variable
4//! for that category, regularised toward the global mean using smoothing.
5//!
6//! This is especially useful for high-cardinality categorical features where
7//! one-hot encoding would produce too many columns.
8//!
9//! # Smoothing
10//!
11//! The encoded value for category `c` is:
12//!
13//! ```text
14//! encoded(c) = (count(c) * mean_c + smooth * global_mean) / (count(c) + smooth)
15//! ```
16//!
17//! where `smooth` controls the degree of regularisation.
18
19use ferrolearn_core::error::FerroError;
20use ferrolearn_core::traits::{Fit, Transform};
21use ndarray::{Array1, Array2};
22use num_traits::Float;
23use std::collections::HashMap;
24
25// ---------------------------------------------------------------------------
26// TargetEncoder (unfitted)
27// ---------------------------------------------------------------------------
28
/// Configuration of a target encoder that has not been fitted yet.
///
/// At fit time the encoder is given a matrix of integer category codes
/// together with a continuous (or binary) target vector. Every category is
/// then mapped to the smoothed mean of the target values observed for it.
///
/// # Parameters
///
/// - `smooth` — smoothing factor (the [`Default`] value is 1.0). Larger values
///   pull each category's encoding closer to the global target mean; 0
///   disables smoothing.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::target_encoder::TargetEncoder;
/// use ferrolearn_core::traits::{Fit, Transform};
/// use ndarray::array;
///
/// let enc = TargetEncoder::<f64>::new(1.0);
/// let x = array![[0usize, 1], [0, 0], [1, 1], [1, 0]];
/// let y = array![1.0, 2.0, 3.0, 4.0];
/// let fitted = enc.fit(&x, &y).unwrap();
/// let out = fitted.transform(&x).unwrap();
/// assert_eq!(out.shape(), &[4, 2]);
/// ```
#[derive(Clone, Debug)]
#[must_use]
pub struct TargetEncoder<F> {
    /// Weight given to the global target mean when blending it with a
    /// category's own mean.
    smooth: F,
}
60
61impl<F: Float + Send + Sync + 'static> TargetEncoder<F> {
62    /// Create a new `TargetEncoder` with the given smoothing factor.
63    pub fn new(smooth: F) -> Self {
64        Self { smooth }
65    }
66
67    /// Return the smoothing factor.
68    #[must_use]
69    pub fn smooth(&self) -> F {
70        self.smooth
71    }
72}
73
74impl<F: Float + Send + Sync + 'static> Default for TargetEncoder<F> {
75    fn default() -> Self {
76        Self::new(F::one())
77    }
78}
79
80// ---------------------------------------------------------------------------
81// FittedTargetEncoder
82// ---------------------------------------------------------------------------
83
/// The result of fitting a target encoder: one lookup table of encoded values
/// per input feature, plus the global target mean.
///
/// Obtained by calling [`Fit::fit`] on a [`TargetEncoder`].
#[derive(Clone, Debug)]
pub struct FittedTargetEncoder<F> {
    /// One map per feature column, taking a category code to its encoded value.
    category_maps: Vec<HashMap<usize, F>>,
    /// Mean of the whole target vector; the encoding used for categories that
    /// were not seen during fitting.
    global_mean: F,
}
94
95impl<F: Float + Send + Sync + 'static> FittedTargetEncoder<F> {
96    /// Return the encoding maps per feature.
97    #[must_use]
98    pub fn category_maps(&self) -> &[HashMap<usize, F>] {
99        &self.category_maps
100    }
101
102    /// Return the global target mean.
103    #[must_use]
104    pub fn global_mean(&self) -> F {
105        self.global_mean
106    }
107}
108
109// ---------------------------------------------------------------------------
110// Trait implementations
111// ---------------------------------------------------------------------------
112
113impl<F: Float + Send + Sync + 'static> Fit<Array2<usize>, Array1<F>> for TargetEncoder<F> {
114    type Fitted = FittedTargetEncoder<F>;
115    type Error = FerroError;
116
117    /// Fit the encoder by computing smoothed target means per category.
118    ///
119    /// # Errors
120    ///
121    /// - [`FerroError::InsufficientSamples`] if the input has zero rows.
122    /// - [`FerroError::ShapeMismatch`] if `x` rows and `y` length differ.
123    /// - [`FerroError::InvalidParameter`] if `smooth` is negative.
124    fn fit(&self, x: &Array2<usize>, y: &Array1<F>) -> Result<FittedTargetEncoder<F>, FerroError> {
125        let n_samples = x.nrows();
126        if n_samples == 0 {
127            return Err(FerroError::InsufficientSamples {
128                required: 1,
129                actual: 0,
130                context: "TargetEncoder::fit".into(),
131            });
132        }
133        if y.len() != n_samples {
134            return Err(FerroError::ShapeMismatch {
135                expected: vec![n_samples],
136                actual: vec![y.len()],
137                context: "TargetEncoder::fit — y must have same length as x rows".into(),
138            });
139        }
140        if self.smooth < F::zero() {
141            return Err(FerroError::InvalidParameter {
142                name: "smooth".into(),
143                reason: "smoothing factor must be non-negative".into(),
144            });
145        }
146
147        let n_features = x.ncols();
148        let global_mean = y.iter().copied().fold(F::zero(), |a, v| a + v)
149            / F::from(n_samples).unwrap_or(F::one());
150
151        let mut category_maps = Vec::with_capacity(n_features);
152
153        for j in 0..n_features {
154            // Collect (sum, count) per category
155            let mut cat_stats: HashMap<usize, (F, usize)> = HashMap::new();
156            for i in 0..n_samples {
157                let cat = x[[i, j]];
158                let entry = cat_stats.entry(cat).or_insert((F::zero(), 0));
159                entry.0 = entry.0 + y[i];
160                entry.1 += 1;
161            }
162
163            // Compute smoothed mean per category
164            let mut cat_map: HashMap<usize, F> = HashMap::new();
165            for (&cat, &(sum, count)) in &cat_stats {
166                let count_f = F::from(count).unwrap_or(F::one());
167                let cat_mean = sum / count_f;
168                let encoded =
169                    (count_f * cat_mean + self.smooth * global_mean) / (count_f + self.smooth);
170                cat_map.insert(cat, encoded);
171            }
172
173            category_maps.push(cat_map);
174        }
175
176        Ok(FittedTargetEncoder {
177            category_maps,
178            global_mean,
179        })
180    }
181}
182
183impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for FittedTargetEncoder<F> {
184    type Output = Array2<F>;
185    type Error = FerroError;
186
187    /// Encode categorical features using the learned target statistics.
188    ///
189    /// Unseen categories are encoded as the global target mean.
190    ///
191    /// # Errors
192    ///
193    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
194    /// from the number of features seen during fitting.
195    fn transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
196        let n_features = self.category_maps.len();
197        if x.ncols() != n_features {
198            return Err(FerroError::ShapeMismatch {
199                expected: vec![x.nrows(), n_features],
200                actual: vec![x.nrows(), x.ncols()],
201                context: "FittedTargetEncoder::transform".into(),
202            });
203        }
204
205        let n_samples = x.nrows();
206        let mut out = Array2::zeros((n_samples, n_features));
207
208        for j in 0..n_features {
209            let cat_map = &self.category_maps[j];
210            for i in 0..n_samples {
211                let cat = x[[i, j]];
212                out[[i, j]] = *cat_map.get(&cat).unwrap_or(&self.global_mean);
213            }
214        }
215
216        Ok(out)
217    }
218}
219
220// ---------------------------------------------------------------------------
221// Tests
222// ---------------------------------------------------------------------------
223
#[cfg(test)]
mod tests {
    //! Unit tests for [`TargetEncoder`] and [`FittedTargetEncoder`].
    //!
    //! Error-path tests assert the specific [`FerroError`] variant so that a
    //! failure raised for the wrong reason cannot pass unnoticed, and the
    //! value tests pin the exact encodings rather than only the output shape.

    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_target_encoder_basic() {
        let enc = TargetEncoder::<f64>::new(0.0); // no smoothing
        // Category 0: targets [1.0, 2.0], mean = 1.5
        // Category 1: targets [3.0, 4.0], mean = 3.5
        let x = array![[0usize], [0], [1], [1]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 0]], 3.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[3, 0]], 3.5, epsilon = 1e-10);
    }

    #[test]
    fn test_target_encoder_smoothing() {
        let enc = TargetEncoder::<f64>::new(2.0);
        // Category 0: targets [1.0], mean = 1.0, count = 1
        // Category 1: targets [3.0, 5.0], mean = 4.0, count = 2
        // Global mean = (1 + 3 + 5) / 3 = 3.0
        let x = array![[0usize], [1], [1]];
        let y = array![1.0, 3.0, 5.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Cat 0: (1 * 1.0 + 2 * 3.0) / (1 + 2) = 7/3 ≈ 2.333
        let expected_0 = (1.0 * 1.0 + 2.0 * 3.0) / (1.0 + 2.0);
        assert_abs_diff_eq!(out[[0, 0]], expected_0, epsilon = 1e-10);
        // Cat 1: (2 * 4.0 + 2 * 3.0) / (2 + 2) = 14/4 = 3.5
        let expected_1 = (2.0 * 4.0 + 2.0 * 3.0) / (2.0 + 2.0);
        assert_abs_diff_eq!(out[[1, 0]], expected_1, epsilon = 1e-10);
        // Both rows of category 1 must receive the same encoding.
        assert_abs_diff_eq!(out[[2, 0]], expected_1, epsilon = 1e-10);
    }

    #[test]
    fn test_target_encoder_unseen_category() {
        let enc = TargetEncoder::<f64>::new(1.0);
        let x = array![[0usize], [0], [1], [1]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = enc.fit(&x, &y).unwrap();
        // Transform with unseen category 2
        let x_new = array![[2usize]];
        let out = fitted.transform(&x_new).unwrap();
        assert_eq!(out.shape(), &[1, 1]);
        // Unseen category → global mean = 2.5
        assert_abs_diff_eq!(out[[0, 0]], 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_target_encoder_multi_feature() {
        let enc = TargetEncoder::<f64>::new(0.0);
        let x = array![[0usize, 1], [0, 0], [1, 1], [1, 0]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = enc.fit(&x, &y).unwrap();
        assert_eq!(fitted.category_maps().len(), 2);

        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[4, 2]);

        // Feature 0: cat 0 → mean(1, 2) = 1.5, cat 1 → mean(3, 4) = 3.5
        // Feature 1: cat 1 → mean(1, 3) = 2.0, cat 0 → mean(2, 4) = 3.0
        let expected: Array2<f64> = array![[1.5, 2.0], [1.5, 3.0], [3.5, 2.0], [3.5, 3.0]];
        for ((i, j), &want) in expected.indexed_iter() {
            assert_abs_diff_eq!(out[[i, j]], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_target_encoder_zero_rows_error() {
        let enc = TargetEncoder::<f64>::new(1.0);
        let x: Array2<usize> = Array2::zeros((0, 2));
        let y: Array1<f64> = Array1::zeros(0);
        assert!(matches!(
            enc.fit(&x, &y),
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_target_encoder_shape_mismatch_fit() {
        let enc = TargetEncoder::<f64>::new(1.0);
        let x = array![[0usize], [1]];
        let y = array![1.0]; // wrong length
        assert!(matches!(
            enc.fit(&x, &y),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_target_encoder_shape_mismatch_transform() {
        let enc = TargetEncoder::<f64>::new(1.0);
        let x = array![[0usize, 1], [1, 0]];
        let y = array![1.0, 2.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let x_bad = array![[0usize]]; // wrong number of columns
        assert!(matches!(
            fitted.transform(&x_bad),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_target_encoder_negative_smooth_error() {
        let enc = TargetEncoder::<f64>::new(-1.0);
        let x = array![[0usize]];
        let y = array![1.0];
        assert!(matches!(
            enc.fit(&x, &y),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_target_encoder_default() {
        let enc = TargetEncoder::<f64>::default();
        assert_abs_diff_eq!(enc.smooth(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_target_encoder_global_mean_accessor() {
        let enc = TargetEncoder::<f64>::new(0.0);
        let x = array![[0usize], [1]];
        let y = array![2.0, 4.0];
        let fitted = enc.fit(&x, &y).unwrap();
        assert_abs_diff_eq!(fitted.global_mean(), 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_target_encoder_f32() {
        let enc = TargetEncoder::<f32>::new(1.0f32);
        let x = array![[0usize], [0], [1]];
        let y: Array1<f32> = array![1.0f32, 2.0, 3.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Global mean = 2.0
        // Cat 0: (2 * 1.5 + 1 * 2.0) / (2 + 1) = 5/3
        // Cat 1: (1 * 3.0 + 1 * 2.0) / (1 + 1) = 2.5
        assert_abs_diff_eq!(out[[0, 0]], 5.0f32 / 3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[1, 0]], 5.0f32 / 3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[[2, 0]], 2.5f32, epsilon = 1e-6);
    }
}
340        let fitted = enc.fit(&x, &y).unwrap();
341        let out = fitted.transform(&x).unwrap();
342        assert!(!out[[0, 0]].is_nan());
343    }
344}