use std::collections::HashMap;
/// Online (streaming) target encoder for mixed numeric/categorical features.
///
/// Categorical values are identified by the bit pattern of their `f64`
/// representation (`f64::to_bits`), so `0.0` vs `-0.0` and different NaN
/// payloads count as *distinct* categories — callers should canonicalize
/// category codes if that matters.
///
/// Encoding uses additive smoothing toward the running global target mean:
/// `(cat_sum + smoothing * global_mean) / (cat_count + smoothing)`.
#[derive(Clone, Debug)]
pub struct TargetEncoder {
    // Positions in the feature vector that hold categorical values.
    categorical_indices: Vec<usize>,
    // One map per categorical feature (parallel to `categorical_indices`):
    // category bit pattern -> (running target sum, observation count).
    category_stats: Vec<HashMap<u64, (f64, u64)>>,
    // Running sum of every target seen by `update`.
    global_sum: f64,
    // Number of targets seen by `update`.
    global_count: u64,
    // Smoothing strength; larger pulls encodings toward the global mean.
    smoothing: f64,
}

impl TargetEncoder {
    /// Smoothing strength used by [`TargetEncoder::new`].
    pub const DEFAULT_SMOOTHING: f64 = 10.0;

    /// Creates an encoder with the default smoothing
    /// ([`Self::DEFAULT_SMOOTHING`] = 10.0).
    pub fn new(categorical_indices: Vec<usize>) -> Self {
        // Delegate so construction logic lives in exactly one place.
        Self::with_smoothing(categorical_indices, Self::DEFAULT_SMOOTHING)
    }

    /// Creates an encoder with an explicit smoothing strength.
    ///
    /// # Panics
    /// Panics if `smoothing` is negative.
    pub fn with_smoothing(categorical_indices: Vec<usize>, smoothing: f64) -> Self {
        assert!(
            smoothing >= 0.0,
            "TargetEncoder: smoothing must be >= 0.0, got {}",
            smoothing
        );
        let n_cats = categorical_indices.len();
        Self {
            categorical_indices,
            category_stats: vec![HashMap::new(); n_cats],
            global_sum: 0.0,
            global_count: 0,
            smoothing,
        }
    }

    /// Indices of the feature-vector slots treated as categorical.
    pub fn categorical_indices(&self) -> &[usize] {
        &self.categorical_indices
    }

    /// The smoothing strength.
    pub fn smoothing(&self) -> f64 {
        self.smoothing
    }

    /// Mean of all targets observed so far; `0.0` before any update.
    pub fn global_mean(&self) -> f64 {
        if self.global_count == 0 {
            0.0
        } else {
            self.global_sum / self.global_count as f64
        }
    }

    /// Number of distinct categories seen for the `cat_idx`-th categorical
    /// feature. Note `cat_idx` indexes into `categorical_indices`, not into
    /// the feature vector.
    ///
    /// # Panics
    /// Panics if `cat_idx >= self.categorical_indices().len()`.
    pub fn n_categories(&self, cat_idx: usize) -> usize {
        self.category_stats[cat_idx].len()
    }

    /// Folds one `(features, target)` observation into the running statistics.
    ///
    /// # Panics
    /// Panics if `features` is too short to cover every categorical index.
    pub fn update(&mut self, features: &[f64], target: f64) {
        self.global_sum += target;
        self.global_count += 1;
        for (i, &feat_idx) in self.categorical_indices.iter().enumerate() {
            let key = features[feat_idx].to_bits();
            // Entry API: single hash lookup for insert-or-update.
            let (sum, count) = self.category_stats[i].entry(key).or_insert((0.0, 0));
            *sum += target;
            *count += 1;
        }
    }

    /// Returns a copy of `features` with every categorical slot replaced by
    /// its smoothed target encoding. Unseen categories (and all categories
    /// before any update) fall back to the global mean.
    ///
    /// # Panics
    /// Panics if `features` is too short to cover every categorical index.
    pub fn transform(&self, features: &[f64]) -> Vec<f64> {
        let global_mean = self.global_mean();
        let mut out = features.to_vec();
        for (i, &feat_idx) in self.categorical_indices.iter().enumerate() {
            let key = features[feat_idx].to_bits();
            out[feat_idx] = match self.category_stats[i].get(&key) {
                Some(&(cat_sum, cat_count)) if cat_count > 0 => {
                    // cat_sum already equals cat_count * cat_mean, so use it
                    // directly instead of dividing and re-multiplying.
                    (cat_sum + self.smoothing * global_mean)
                        / (cat_count as f64 + self.smoothing)
                }
                _ => global_mean,
            };
        }
        out
    }

    /// Convenience wrapper: `update` with the observation, then `transform`
    /// the same feature vector (post-update statistics are used).
    pub fn update_and_transform(&mut self, features: &[f64], target: f64) -> Vec<f64> {
        self.update(features, target);
        self.transform(features)
    }

    /// Clears all accumulated statistics. Configuration (categorical indices
    /// and smoothing) is retained.
    pub fn reset(&mut self) {
        for map in &mut self.category_stats {
            map.clear();
        }
        self.global_sum = 0.0;
        self.global_count = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-9;

    /// With no categorical indices configured, transform is the identity.
    #[test]
    fn passthrough_non_categorical() {
        let mut enc = TargetEncoder::new(vec![]);
        enc.update(&[1.0, 2.5, 3.0], 10.0);
        enc.update(&[4.0, 5.5, 6.0], 20.0);
        let out = enc.transform(&[7.0, 8.5, 9.0]);
        assert!((out[0] - 7.0).abs() < EPS);
        assert!((out[1] - 8.5).abs() < EPS);
        assert!((out[2] - 9.0).abs() < EPS);
    }

    /// With many observations per category, the encoding approaches the
    /// per-category target mean despite the default smoothing.
    #[test]
    fn encodes_with_target_mean() {
        let mut enc = TargetEncoder::new(vec![0]);
        for _ in 0..1000 {
            enc.update(&[1.0, 5.0], 10.0);
        }
        for _ in 0..1000 {
            enc.update(&[2.0, 5.0], 50.0);
        }
        let out1 = enc.transform(&[1.0, 5.0]);
        let out2 = enc.transform(&[2.0, 5.0]);
        assert!(
            (out1[0] - 10.0).abs() < 0.5,
            "expected ~10.0 for cat 1, got {}",
            out1[0]
        );
        assert!(
            (out2[0] - 50.0).abs() < 0.5,
            "expected ~50.0 for cat 2, got {}",
            out2[0]
        );
        // Numeric (non-categorical) feature passes through untouched.
        assert!((out1[1] - 5.0).abs() < EPS);
    }

    /// A rarely-seen category is shrunk strictly between the global mean and
    /// its own raw category mean.
    #[test]
    fn smoothing_pulls_toward_global_mean() {
        let mut enc = TargetEncoder::with_smoothing(vec![0], 10.0);
        for _ in 0..100 {
            enc.update(&[0.0], 0.0);
        }
        enc.update(&[1.0], 100.0);
        enc.update(&[1.0], 100.0);
        let global_mean = enc.global_mean();
        let cat_mean = 100.0;
        let out = enc.transform(&[1.0]);
        assert!(
            out[0] > global_mean && out[0] < cat_mean,
            "encoded {} should be between global_mean {} and cat_mean {}",
            out[0],
            global_mean,
            cat_mean
        );
    }

    /// A category never seen in training encodes to the global mean.
    #[test]
    fn unknown_category_uses_global_mean() {
        let mut enc = TargetEncoder::new(vec![0]);
        enc.update(&[0.0, 1.0], 10.0);
        enc.update(&[1.0, 2.0], 20.0);
        let global_mean = enc.global_mean();
        let out = enc.transform(&[99.0, 3.0]);
        assert!(
            (out[0] - global_mean).abs() < EPS,
            "unknown category should get global mean {}, got {}",
            global_mean,
            out[0]
        );
        assert!((out[1] - 3.0).abs() < EPS);
    }

    /// `update_and_transform` must equal a separate `update` then `transform`.
    #[test]
    fn update_and_transform_consistency() {
        let features = [2.0, 7.5];
        let target = 42.0;
        let mut enc_a = TargetEncoder::with_smoothing(vec![0], 5.0);
        enc_a.update(&[0.0, 1.0], 10.0);
        enc_a.update(&[1.0, 2.0], 20.0);
        let out_a = enc_a.update_and_transform(&features, target);
        let mut enc_b = TargetEncoder::with_smoothing(vec![0], 5.0);
        enc_b.update(&[0.0, 1.0], 10.0);
        enc_b.update(&[1.0, 2.0], 20.0);
        enc_b.update(&features, target);
        let out_b = enc_b.transform(&features);
        assert_eq!(out_a.len(), out_b.len());
        for (a, b) in out_a.iter().zip(out_b.iter()) {
            assert!(
                (a - b).abs() < EPS,
                "mismatch: update_and_transform={}, update+transform={}",
                a,
                b
            );
        }
    }

    /// `reset` wipes both global stats and per-category maps.
    #[test]
    fn reset_clears_all_stats() {
        let mut enc = TargetEncoder::new(vec![0]);
        enc.update(&[1.0, 2.0], 10.0);
        enc.update(&[2.0, 3.0], 20.0);
        assert!(enc.global_mean() != 0.0);
        assert!(enc.n_categories(0) > 0);
        enc.reset();
        assert_eq!(enc.global_mean(), 0.0);
        assert_eq!(enc.n_categories(0), 0);
        assert!(enc.category_stats[0].is_empty());
    }

    /// Two categorical slots are tracked independently; the numeric slot in
    /// between passes through.
    #[test]
    fn multiple_categorical_features() {
        let mut enc = TargetEncoder::with_smoothing(vec![0, 2], 1.0);
        for _ in 0..50 {
            enc.update(&[0.0, 5.0, 10.0], 100.0);
        }
        for _ in 0..50 {
            enc.update(&[1.0, 5.0, 20.0], 200.0);
        }
        let out = enc.transform(&[0.0, 99.0, 20.0]);
        assert!(
            (out[1] - 99.0).abs() < EPS,
            "numeric feature should be 99.0, got {}",
            out[1]
        );
        assert!(
            (out[0] - 100.0).abs() < 5.0,
            "cat_a=0 expected ~100.0, got {}",
            out[0]
        );
        assert!(
            (out[2] - 200.0).abs() < 5.0,
            "cat_b=20 expected ~200.0, got {}",
            out[2]
        );
        assert_eq!(enc.n_categories(0), 2);
        assert_eq!(enc.n_categories(1), 2);
    }

    /// Smoothing of exactly 0.0 yields the raw (unshrunk) category mean.
    #[test]
    fn zero_smoothing_uses_pure_category_mean() {
        let mut enc = TargetEncoder::with_smoothing(vec![0], 0.0);
        enc.update(&[1.0], 10.0);
        enc.update(&[1.0], 30.0);
        enc.update(&[2.0], 100.0);
        let out = enc.transform(&[1.0]);
        assert!(
            (out[0] - 20.0).abs() < EPS,
            "with zero smoothing expected exact category mean 20.0, got {}",
            out[0]
        );
        let out2 = enc.transform(&[2.0]);
        assert!(
            (out2[0] - 100.0).abs() < EPS,
            "with zero smoothing expected exact category mean 100.0, got {}",
            out2[0]
        );
    }

    /// The constructor precondition on smoothing is enforced with a panic.
    #[test]
    #[should_panic(expected = "smoothing must be >= 0.0")]
    fn negative_smoothing_panics() {
        TargetEncoder::with_smoothing(vec![0], -1.0);
    }
}