1use ferrolearn_core::error::FerroError;
20use ferrolearn_core::traits::{Fit, Transform};
21use ndarray::{Array1, Array2};
22use num_traits::Float;
23use std::collections::HashMap;
24
/// Unfitted target encoder: carries only the smoothing hyper-parameter.
///
/// Fitting maps every category of every feature column to a blend of that
/// category's mean target and the global mean target; `smooth` is the weight
/// given to the global mean in that blend.
#[derive(Clone, Debug)]
#[must_use]
pub struct TargetEncoder<F> {
    /// Weight of the global target mean in the blended encoding.
    smooth: F,
}
60
61impl<F: Float + Send + Sync + 'static> TargetEncoder<F> {
62 pub fn new(smooth: F) -> Self {
64 Self { smooth }
65 }
66
67 #[must_use]
69 pub fn smooth(&self) -> F {
70 self.smooth
71 }
72}
73
74impl<F: Float + Send + Sync + 'static> Default for TargetEncoder<F> {
75 fn default() -> Self {
76 Self::new(F::one())
77 }
78}
79
/// Result of fitting a target encoder: the learned per-column encodings.
#[derive(Clone, Debug)]
pub struct FittedTargetEncoder<F> {
    /// One map per feature column, from category id to its encoded value.
    category_maps: Vec<HashMap<usize, F>>,
    /// Mean of the training targets; used for categories missing from a map.
    global_mean: F,
}
94
95impl<F: Float + Send + Sync + 'static> FittedTargetEncoder<F> {
96 #[must_use]
98 pub fn category_maps(&self) -> &[HashMap<usize, F>] {
99 &self.category_maps
100 }
101
102 #[must_use]
104 pub fn global_mean(&self) -> F {
105 self.global_mean
106 }
107}
108
109impl<F: Float + Send + Sync + 'static> Fit<Array2<usize>, Array1<F>> for TargetEncoder<F> {
114 type Fitted = FittedTargetEncoder<F>;
115 type Error = FerroError;
116
117 fn fit(&self, x: &Array2<usize>, y: &Array1<F>) -> Result<FittedTargetEncoder<F>, FerroError> {
125 let n_samples = x.nrows();
126 if n_samples == 0 {
127 return Err(FerroError::InsufficientSamples {
128 required: 1,
129 actual: 0,
130 context: "TargetEncoder::fit".into(),
131 });
132 }
133 if y.len() != n_samples {
134 return Err(FerroError::ShapeMismatch {
135 expected: vec![n_samples],
136 actual: vec![y.len()],
137 context: "TargetEncoder::fit — y must have same length as x rows".into(),
138 });
139 }
140 if self.smooth < F::zero() {
141 return Err(FerroError::InvalidParameter {
142 name: "smooth".into(),
143 reason: "smoothing factor must be non-negative".into(),
144 });
145 }
146
147 let n_features = x.ncols();
148 let global_mean = y.iter().copied().fold(F::zero(), |a, v| a + v)
149 / F::from(n_samples).unwrap_or(F::one());
150
151 let mut category_maps = Vec::with_capacity(n_features);
152
153 for j in 0..n_features {
154 let mut cat_stats: HashMap<usize, (F, usize)> = HashMap::new();
156 for i in 0..n_samples {
157 let cat = x[[i, j]];
158 let entry = cat_stats.entry(cat).or_insert((F::zero(), 0));
159 entry.0 = entry.0 + y[i];
160 entry.1 += 1;
161 }
162
163 let mut cat_map: HashMap<usize, F> = HashMap::new();
165 for (&cat, &(sum, count)) in &cat_stats {
166 let count_f = F::from(count).unwrap_or(F::one());
167 let cat_mean = sum / count_f;
168 let encoded =
169 (count_f * cat_mean + self.smooth * global_mean) / (count_f + self.smooth);
170 cat_map.insert(cat, encoded);
171 }
172
173 category_maps.push(cat_map);
174 }
175
176 Ok(FittedTargetEncoder {
177 category_maps,
178 global_mean,
179 })
180 }
181}
182
183impl<F: Float + Send + Sync + 'static> Transform<Array2<usize>> for FittedTargetEncoder<F> {
184 type Output = Array2<F>;
185 type Error = FerroError;
186
187 fn transform(&self, x: &Array2<usize>) -> Result<Array2<F>, FerroError> {
196 let n_features = self.category_maps.len();
197 if x.ncols() != n_features {
198 return Err(FerroError::ShapeMismatch {
199 expected: vec![x.nrows(), n_features],
200 actual: vec![x.nrows(), x.ncols()],
201 context: "FittedTargetEncoder::transform".into(),
202 });
203 }
204
205 let n_samples = x.nrows();
206 let mut out = Array2::zeros((n_samples, n_features));
207
208 for j in 0..n_features {
209 let cat_map = &self.category_maps[j];
210 for i in 0..n_samples {
211 let cat = x[[i, j]];
212 out[[i, j]] = *cat_map.get(&cat).unwrap_or(&self.global_mean);
213 }
214 }
215
216 Ok(out)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// With no smoothing each category is encoded as its plain mean target.
    #[test]
    fn test_target_encoder_basic() {
        let enc = TargetEncoder::<f64>::new(0.0);
        let x = array![[0usize], [0], [1], [1]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 0]], 3.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[3, 0]], 3.5, epsilon = 1e-10);
    }

    /// Smoothing pulls each category mean towards the global mean (3.0 here).
    #[test]
    fn test_target_encoder_smoothing() {
        let enc = TargetEncoder::<f64>::new(2.0);
        let x = array![[0usize], [1], [1]];
        let y = array![1.0, 3.0, 5.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Category 0: one row with mean 1.0.
        let expected_0 = (1.0 * 1.0 + 2.0 * 3.0) / (1.0 + 2.0);
        assert_abs_diff_eq!(out[[0, 0]], expected_0, epsilon = 1e-10);
        // Category 1: two rows with mean 4.0.
        let expected_1 = (2.0 * 4.0 + 2.0 * 3.0) / (2.0 + 2.0);
        assert_abs_diff_eq!(out[[1, 0]], expected_1, epsilon = 1e-10);
    }

    /// A category absent from the training data is encoded as the global mean.
    #[test]
    fn test_target_encoder_unseen_category() {
        let enc = TargetEncoder::<f64>::new(1.0);
        let x = array![[0usize], [0], [1], [1]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let x_new = array![[2usize]];
        let out = fitted.transform(&x_new).unwrap();
        assert_abs_diff_eq!(out[[0, 0]], 2.5, epsilon = 1e-10);
    }

    /// Each column is encoded independently from its own category statistics.
    #[test]
    fn test_target_encoder_multi_feature() {
        let enc = TargetEncoder::<f64>::new(0.0);
        let x = array![[0usize, 1], [0, 0], [1, 1], [1, 0]];
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[4, 2]);
        // Column 0: category 0 -> mean(1, 2), category 1 -> mean(3, 4).
        assert_abs_diff_eq!(out[[0, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 0]], 3.5, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[3, 0]], 3.5, epsilon = 1e-10);
        // Column 1: category 1 -> mean(1, 3), category 0 -> mean(2, 4).
        assert_abs_diff_eq!(out[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[1, 1]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[3, 1]], 3.0, epsilon = 1e-10);
    }

    /// Fitting on an input with no rows is rejected.
    #[test]
    fn test_target_encoder_zero_rows_error() {
        let enc = TargetEncoder::<f64>::new(1.0);
        let x: Array2<usize> = Array2::zeros((0, 2));
        let y: Array1<f64> = Array1::zeros(0);
        assert!(enc.fit(&x, &y).is_err());
    }

    /// `y` must have exactly one entry per row of `x`.
    #[test]
    fn test_target_encoder_shape_mismatch_fit() {
        let enc = TargetEncoder::<f64>::new(1.0);
        let x = array![[0usize], [1]];
        let y = array![1.0];
        assert!(enc.fit(&x, &y).is_err());
    }

    /// Transform input must have the same number of columns as the fit input.
    #[test]
    fn test_target_encoder_shape_mismatch_transform() {
        let enc = TargetEncoder::<f64>::new(1.0);
        let x = array![[0usize, 1], [1, 0]];
        let y = array![1.0, 2.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let x_bad = array![[0usize]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    /// A negative smoothing factor is rejected at fit time.
    #[test]
    fn test_target_encoder_negative_smooth_error() {
        let enc = TargetEncoder::<f64>::new(-1.0);
        let x = array![[0usize]];
        let y = array![1.0];
        assert!(enc.fit(&x, &y).is_err());
    }

    /// The default encoder uses a smoothing factor of one.
    #[test]
    fn test_target_encoder_default() {
        let enc = TargetEncoder::<f64>::default();
        assert_abs_diff_eq!(enc.smooth(), 1.0, epsilon = 1e-10);
    }

    /// The fitted encoder exposes the mean of the training targets.
    #[test]
    fn test_target_encoder_global_mean_accessor() {
        let enc = TargetEncoder::<f64>::new(0.0);
        let x = array![[0usize], [1]];
        let y = array![2.0, 4.0];
        let fitted = enc.fit(&x, &y).unwrap();
        assert_abs_diff_eq!(fitted.global_mean(), 3.0, epsilon = 1e-10);
    }

    /// The encoder also works with single-precision floats.
    #[test]
    fn test_target_encoder_f32() {
        let enc = TargetEncoder::<f32>::new(1.0f32);
        let x = array![[0usize], [0], [1]];
        let y: Array1<f32> = array![1.0f32, 2.0, 3.0];
        let fitted = enc.fit(&x, &y).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert!(!out[[0, 0]].is_nan());
        // Global mean is 2.0; category 0 has sum 3 over 2 rows, category 1 has sum 3 over 1 row.
        assert_abs_diff_eq!(out[[0, 0]], 5.0f32 / 3.0, epsilon = 1e-5);
        assert_abs_diff_eq!(out[[2, 0]], 2.5f32, epsilon = 1e-5);
    }
}