use rkyv::{Archive, Deserialize, Serialize};
use rustc_hash::FxHashMap;
/// Running sum and observation count for a stream of target values.
#[derive(Debug, Clone, Default)]
struct CategoryStats {
    /// Sum of the target values accumulated so far.
    sum: f64,
    /// Number of target values folded into `sum`.
    count: u64,
}

impl CategoryStats {
    /// Returns the arithmetic mean of the accumulated targets.
    ///
    /// An accumulator with no observations reports `0.0` rather than dividing by zero.
    #[inline]
    fn mean(&self) -> f64 {
        match self.count {
            0 => 0.0,
            observed => self.sum / observed as f64,
        }
    }
}
/// Streaming target encoder for categorical values.
///
/// Each row is encoded from the statistics of the rows seen before it; the row's own target is
/// folded into the statistics only after its encoding has been computed.
pub struct OrderedTargetEncoder {
/// Weight given to the global mean when it is blended with a category's own mean.
smoothing: f64,
/// Running sum and count over every target seen so far, regardless of category.
global_stats: CategoryStats,
/// Running sum and count of the targets seen so far for each category.
category_stats: FxHashMap<String, CategoryStats>,
}
impl OrderedTargetEncoder {
pub fn new(smoothing: f64) -> Self {
assert!(smoothing >= 0.0, "smoothing must be non-negative");
Self {
smoothing,
global_stats: CategoryStats::default(),
category_stats: FxHashMap::default(),
}
}
pub fn reset(&mut self) {
self.global_stats = CategoryStats::default();
self.category_stats.clear();
}
pub fn encode_and_update(&mut self, category: &str, target: f64) -> f64 {
let global_mean = self.global_stats.mean();
let cat_stats = self.category_stats.get(category);
let encoded = match cat_stats {
Some(stats) if stats.count > 0 => {
let n = stats.count as f64;
let m = self.smoothing;
(n * stats.mean() + m * global_mean) / (n + m)
}
_ => global_mean, };
self.global_stats.sum += target;
self.global_stats.count += 1;
let cat_stats = self.category_stats.entry(category.to_string()).or_default();
cat_stats.sum += target;
cat_stats.count += 1;
encoded
}
pub fn encode_column(&mut self, categories: &[String], targets: &[f64]) -> Vec<f64> {
assert_eq!(categories.len(), targets.len());
self.reset();
let mut encoded = Vec::with_capacity(categories.len());
for (cat, &target) in categories.iter().zip(targets.iter()) {
encoded.push(self.encode_and_update(cat, target));
}
encoded
}
pub fn encode_inference(&self, category: &str) -> f64 {
let global_mean = self.global_stats.mean();
match self.category_stats.get(category) {
Some(stats) if stats.count > 0 => {
let n = stats.count as f64;
let m = self.smoothing;
(n * stats.mean() + m * global_mean) / (n + m)
}
_ => global_mean,
}
}
pub fn get_encoding_map(&self) -> EncodingMap {
let global_mean = self.global_stats.mean();
let mut encodings: Vec<(String, f64)> = self
.category_stats
.iter()
.map(|(cat, stats)| {
let n = stats.count as f64;
let m = self.smoothing;
let encoded = (n * stats.mean() + m * global_mean) / (n + m);
(cat.clone(), encoded)
})
.collect();
encodings.sort_by(|a, b| a.0.cmp(&b.0));
EncodingMap {
encodings,
default_value: global_mean,
smoothing: self.smoothing,
}
}
}
/// Lookup table of per-category encodings, as produced by
/// `OrderedTargetEncoder::get_encoding_map`.
///
/// Derives the rkyv `Archive`, `Serialize` and `Deserialize` traits.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct EncodingMap {
/// `(category, encoded value)` pairs. `EncodingMap::encode` binary-searches this list by
/// category, so it must be kept sorted by category in ascending order.
pub encodings: Vec<(String, f64)>,
/// Value returned for a category that is not present in `encodings`.
pub default_value: f64,
/// Smoothing factor of the encoder this table was built from.
pub smoothing: f64,
}
impl EncodingMap {
    /// Returns the stored encoding for `category`, or `default_value` when it is absent.
    ///
    /// The lookup is a binary search over `encodings`, so that list is expected to be sorted by
    /// category name.
    pub fn encode(&self, category: &str) -> f64 {
        let lookup = self
            .encodings
            .binary_search_by(|(name, _)| name.as_str().cmp(category));
        if let Ok(index) = lookup {
            self.encodings[index].1
        } else {
            self.default_value
        }
    }

    /// Encodes every entry of `categories`, preserving input order.
    pub fn encode_batch(&self, categories: &[String]) -> Vec<f64> {
        let mut encoded = Vec::with_capacity(categories.len());
        for name in categories {
            encoded.push(self.encode(name));
        }
        encoded
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point encodings.
    const TOLERANCE: f64 = 1e-6;

    /// Builds an owned category column from string literals.
    fn column(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_ordered_target_encoding() {
        let mut encoder = OrderedTargetEncoder::new(1.0);

        // Nothing has been recorded yet, so the global mean is 0.
        let first = encoder.encode_and_update("cat_a", 100.0);
        assert_eq!(first, 0.0);

        // One earlier "cat_a" row (mean 100) blended with a global mean of 100.
        let second = encoder.encode_and_update("cat_a", 200.0);
        assert!((second - 100.0).abs() < TOLERANCE);

        // "cat_b" is unseen, so it gets the global mean of 100 and 200.
        let third = encoder.encode_and_update("cat_b", 50.0);
        assert!((third - 150.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_smoothing_effect() {
        // Same observation sequence for both encoders; only the smoothing differs.
        let encode_after_warmup = |smoothing: f64| {
            let mut encoder = OrderedTargetEncoder::new(smoothing);
            encoder.encode_and_update("global", 0.0);
            encoder.encode_and_update("global", 100.0);
            encoder.encode_and_update("cat_a", 100.0);
            encoder.encode_and_update("cat_a", 100.0)
        };

        let enc_low = encode_after_warmup(0.1);
        let enc_high = encode_after_warmup(10.0);

        // Stronger smoothing pulls the encoding further toward the lower global mean.
        assert!(enc_low > enc_high);
    }

    #[test]
    fn test_encode_column() {
        let mut encoder = OrderedTargetEncoder::new(1.0);
        let categories = column(&["a", "a", "b", "a"]);
        let targets = vec![10.0, 20.0, 30.0, 40.0];

        let encoded = encoder.encode_column(&categories, &targets);

        assert_eq!(encoded.len(), 4);
        assert_eq!(encoded[0], 0.0);
    }

    #[test]
    fn test_encoding_map() {
        let mut encoder = OrderedTargetEncoder::new(1.0);
        let categories = column(&["a", "b", "a", "b"]);
        let targets = vec![10.0, 100.0, 20.0, 200.0];
        encoder.encode_column(&categories, &targets);

        let map = encoder.get_encoding_map();
        let enc_a = map.encode("a");
        let enc_b = map.encode("b");
        let enc_unknown = map.encode("unknown");

        assert!(enc_a < enc_b);
        assert!((enc_unknown - map.default_value).abs() < TOLERANCE);
    }
}