use std::collections::HashMap;
use crate::preprocessing::incremental::IncrementalEncoder;
use crate::{Result, TreeBoostError};
/// Encodes categorical values by how often each category was observed.
///
/// Depending on `normalize`, a known category maps either to its raw
/// occurrence count or to that count divided by the total number of samples.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FrequencyEncoder {
/// Number of occurrences recorded for each category.
counts: HashMap<String, usize>,
/// Total number of samples seen: set by `fit`, incremented by `partial_fit`.
total_count: usize,
/// Value returned for categories missing from `counts`; when `None`, the
/// batch `transform` reports an error for such categories.
unknown_value: Option<f32>,
/// When `true`, counts are divided by `total_count` before being returned.
normalize: bool,
/// Set once `fit` or `partial_fit` has run.
fitted: bool,
}
impl FrequencyEncoder {
pub fn new() -> Self {
Self {
counts: HashMap::new(),
total_count: 0,
unknown_value: Some(0.0), normalize: false,
fitted: false,
}
}
pub fn with_unknown_value(mut self, value: Option<f32>) -> Self {
self.unknown_value = value;
self
}
pub fn with_normalize(mut self, normalize: bool) -> Self {
self.normalize = normalize;
self
}
pub fn fit(&mut self, categories: &[impl AsRef<str>]) {
self.counts.clear();
self.total_count = categories.len();
for cat in categories {
*self.counts.entry(cat.as_ref().to_string()).or_insert(0) += 1;
}
self.fitted = true;
}
pub fn transform_single(&self, category: &str) -> Option<f32> {
if !self.fitted {
return None;
}
match self.counts.get(category) {
Some(&count) => {
let value = if self.normalize {
count as f32 / self.total_count as f32
} else {
count as f32
};
Some(value)
}
None => self.unknown_value,
}
}
pub fn transform(&self, categories: &[impl AsRef<str>]) -> Result<Vec<f32>> {
if !self.fitted {
return Err(TreeBoostError::Data(
"FrequencyEncoder not fitted. Call fit() first.".into(),
));
}
let mut result = Vec::with_capacity(categories.len());
for (i, cat) in categories.iter().enumerate() {
match self.transform_single(cat.as_ref()) {
Some(value) => result.push(value),
None => {
return Err(TreeBoostError::Data(format!(
"Unknown category '{}' at index {} and no unknown_value set",
cat.as_ref(),
i
)));
}
}
}
Ok(result)
}
pub fn fit_transform(&mut self, categories: &[impl AsRef<str>]) -> Result<Vec<f32>> {
self.fit(categories);
self.transform(categories)
}
pub fn is_fitted(&self) -> bool {
self.fitted
}
pub fn num_categories(&self) -> usize {
self.counts.len()
}
pub fn counts(&self) -> &HashMap<String, usize> {
&self.counts
}
}
impl Default for FrequencyEncoder {
fn default() -> Self {
Self::new()
}
}
impl IncrementalEncoder for FrequencyEncoder {
    /// Adds the occurrences in `categories` on top of the statistics already
    /// collected and marks the encoder as fitted.
    ///
    /// This implementation has no failure path and always returns `Ok(())`.
    fn partial_fit(&mut self, categories: &[&str]) -> Result<()> {
        for &name in categories {
            *self.counts.entry(name.to_string()).or_default() += 1;
        }
        // One bump per batch is equivalent to incrementing once per item.
        self.total_count += categories.len();
        self.fitted = true;
        Ok(())
    }

    /// Total number of samples seen so far.
    fn n_samples(&self) -> u64 {
        self.total_count as u64
    }
}
/// Maps each distinct category to an integer label.
///
/// Labels are assigned by `fit` in the sorted order of the distinct
/// categories, starting at 0.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LabelEncoder {
/// Category -> label lookup built by `fit`.
mapping: HashMap<String, u32>,
/// Label -> category lookup; index `i` holds the category with label `i`.
inverse_mapping: Vec<String>,
/// Label returned for categories missing from `mapping`; when `None`, the
/// batch `transform` reports an error for such categories.
unknown_label: Option<u32>,
/// Set once `fit` has run.
fitted: bool,
}
impl LabelEncoder {
pub fn new() -> Self {
Self {
mapping: HashMap::new(),
inverse_mapping: Vec::new(),
unknown_label: None,
fitted: false,
}
}
pub fn with_unknown_label(mut self, label: Option<u32>) -> Self {
self.unknown_label = label;
self
}
pub fn fit(&mut self, categories: &[impl AsRef<str>]) {
let mut unique: Vec<String> = categories
.iter()
.map(|c| c.as_ref().to_string())
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
unique.sort();
self.mapping.clear();
self.inverse_mapping = unique.clone();
for (label, category) in unique.into_iter().enumerate() {
self.mapping.insert(category, label as u32);
}
self.fitted = true;
}
pub fn transform_single(&self, category: &str) -> Option<u32> {
if !self.fitted {
return None;
}
match self.mapping.get(category) {
Some(&label) => Some(label),
None => self.unknown_label,
}
}
pub fn transform(&self, categories: &[impl AsRef<str>]) -> Result<Vec<u32>> {
if !self.fitted {
return Err(TreeBoostError::Data(
"LabelEncoder not fitted. Call fit() first.".into(),
));
}
let mut result = Vec::with_capacity(categories.len());
for (i, cat) in categories.iter().enumerate() {
match self.transform_single(cat.as_ref()) {
Some(label) => result.push(label),
None => {
return Err(TreeBoostError::Data(format!(
"Unknown category '{}' at index {} and no unknown_label set",
cat.as_ref(),
i
)));
}
}
}
Ok(result)
}
pub fn transform_f32(&self, categories: &[impl AsRef<str>]) -> Result<Vec<f32>> {
self.transform(categories)
.map(|labels| labels.into_iter().map(|l| l as f32).collect())
}
pub fn fit_transform(&mut self, categories: &[impl AsRef<str>]) -> Result<Vec<u32>> {
self.fit(categories);
self.transform(categories)
}
pub fn inverse_transform(&self, labels: &[u32]) -> Result<Vec<String>> {
if !self.fitted {
return Err(TreeBoostError::Data(
"LabelEncoder not fitted. Call fit() first.".into(),
));
}
let mut result = Vec::with_capacity(labels.len());
for (i, &label) in labels.iter().enumerate() {
if (label as usize) < self.inverse_mapping.len() {
result.push(self.inverse_mapping[label as usize].clone());
} else {
return Err(TreeBoostError::Data(format!(
"Unknown label {} at index {}",
label, i
)));
}
}
Ok(result)
}
pub fn is_fitted(&self) -> bool {
self.fitted
}
pub fn num_classes(&self) -> usize {
self.mapping.len()
}
pub fn get_category(&self, label: u32) -> Option<&str> {
self.inverse_mapping.get(label as usize).map(|s| s.as_str())
}
pub fn get_label(&self, category: &str) -> Option<u32> {
self.mapping.get(category).copied()
}
pub fn classes(&self) -> &[String] {
&self.inverse_mapping
}
}
impl Default for LabelEncoder {
fn default() -> Self {
Self::new()
}
}
/// How `OneHotEncoder` treats a category that was not present during fitting.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub enum UnknownStrategy {
/// Encode the unknown category as a row of zeros.
AllZeros,
/// Fail the transform with a `TreeBoostError::Data` error.
Error,
}
/// One-hot encoder: each fitted category becomes its own 0/1 output column.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OneHotEncoder {
/// Distinct fitted categories in sorted order; position is the column index.
categories: Vec<String>,
/// Category -> position in `categories`.
category_to_idx: HashMap<String, usize>,
/// Behaviour for categories that are missing from `category_to_idx`.
handle_unknown: UnknownStrategy,
/// When `true`, the first (sorted) category gets no column of its own and
/// is encoded as all zeros.
drop_first: bool,
/// Upper bound on distinct categories accepted by `fit`; `0` disables the check.
max_categories: usize,
/// Set once `fit` has succeeded.
fitted: bool,
}
impl OneHotEncoder {
pub fn new() -> Self {
Self {
categories: Vec::new(),
category_to_idx: HashMap::new(),
handle_unknown: UnknownStrategy::AllZeros,
drop_first: false,
max_categories: 100, fitted: false,
}
}
pub fn with_unknown_strategy(mut self, strategy: UnknownStrategy) -> Self {
self.handle_unknown = strategy;
self
}
pub fn with_drop_first(mut self, drop: bool) -> Self {
self.drop_first = drop;
self
}
pub fn with_max_categories(mut self, max: usize) -> Self {
self.max_categories = max;
self
}
pub fn max_categories(&self) -> usize {
self.max_categories
}
pub fn fit(&mut self, categories: &[impl AsRef<str>]) -> Result<()> {
let mut unique: Vec<String> = categories
.iter()
.map(|c| c.as_ref().to_string())
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
if self.max_categories > 0 && unique.len() > self.max_categories {
return Err(TreeBoostError::Config(format!(
"OneHotEncoder: {} unique categories exceeds max_categories limit of {}. \
High-cardinality one-hot encoding can cause memory explosion. \
Consider using TargetEncoder or increasing max_categories with with_max_categories().",
unique.len(),
self.max_categories
)));
}
unique.sort();
self.category_to_idx.clear();
for (idx, cat) in unique.iter().enumerate() {
self.category_to_idx.insert(cat.clone(), idx);
}
self.categories = unique;
self.fitted = true;
Ok(())
}
pub fn num_columns(&self) -> usize {
if self.drop_first && !self.categories.is_empty() {
self.categories.len() - 1
} else {
self.categories.len()
}
}
pub fn get_feature_names(&self, prefix: &str) -> Vec<String> {
let start_idx = if self.drop_first { 1 } else { 0 };
self.categories[start_idx..]
.iter()
.map(|cat| format!("{}_{}", prefix, cat))
.collect()
}
pub fn transform_single(&self, category: &str) -> Result<Vec<f32>> {
if !self.fitted {
return Err(TreeBoostError::Data(
"OneHotEncoder not fitted. Call fit() first.".into(),
));
}
let num_cols = self.num_columns();
let mut result = vec![0.0; num_cols];
match self.category_to_idx.get(category) {
Some(&idx) => {
let adjusted_idx = if self.drop_first {
idx.saturating_sub(1)
} else {
idx
};
if !(self.drop_first && idx == 0) && adjusted_idx < num_cols {
result[adjusted_idx] = 1.0;
}
}
None => {
match self.handle_unknown {
UnknownStrategy::AllZeros => {
}
UnknownStrategy::Error => {
return Err(TreeBoostError::Data(format!(
"Unknown category '{}'",
category
)));
}
}
}
}
Ok(result)
}
pub fn transform(&self, categories: &[impl AsRef<str>]) -> Result<Vec<f32>> {
if !self.fitted {
return Err(TreeBoostError::Data(
"OneHotEncoder not fitted. Call fit() first.".into(),
));
}
let num_cols = self.num_columns();
let mut result = Vec::with_capacity(categories.len() * num_cols);
for cat in categories {
let row = self.transform_single(cat.as_ref())?;
result.extend(row);
}
Ok(result)
}
pub fn fit_transform(&mut self, categories: &[impl AsRef<str>]) -> Result<Vec<f32>> {
self.fit(categories)?;
self.transform(categories)
}
pub fn is_fitted(&self) -> bool {
self.fitted
}
pub fn categories(&self) -> &[String] {
&self.categories
}
}
impl Default for OneHotEncoder {
fn default() -> Self {
Self::new()
}
}
impl OneHotEncoder {
    /// Always fails: the input is ignored and the error produced by
    /// `not_supported_error("OneHotEncoder")` is returned.
    pub fn partial_fit(&mut self, _categories: &[&str]) -> Result<()> {
        let unsupported = crate::preprocessing::incremental::not_supported_error("OneHotEncoder");
        Err(unsupported)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // FrequencyEncoder
    // ---------------------------------------------------------------

    /// Raw occurrence counts are reported after a plain `fit`.
    #[test]
    fn test_frequency_encoder_basic() {
        let mut enc = FrequencyEncoder::new();
        enc.fit(&["A", "B", "A", "C", "A", "B"]);

        assert!(enc.is_fitted());
        assert_eq!(enc.num_categories(), 3);
        for &(category, expected) in &[("A", 3.0), ("B", 2.0), ("C", 1.0)] {
            assert_eq!(enc.transform_single(category), Some(expected));
        }
    }

    /// Unseen categories map to the configured fallback, or to `None`.
    #[test]
    fn test_frequency_encoder_unknown() {
        let training = ["A", "B", "A"];

        let mut lenient = FrequencyEncoder::new();
        lenient.fit(&training);
        assert_eq!(lenient.transform_single("D"), Some(0.0));

        let mut strict = FrequencyEncoder::new().with_unknown_value(None);
        strict.fit(&training);
        assert_eq!(strict.transform_single("D"), None);
    }

    /// With normalization, counts become shares of the total sample count.
    #[test]
    fn test_frequency_encoder_normalize() {
        let mut enc = FrequencyEncoder::new().with_normalize(true);
        enc.fit(&["A", "B", "A", "A", "B"]);

        let share_a = enc.transform_single("A").unwrap();
        let share_b = enc.transform_single("B").unwrap();
        assert!((share_a - 0.6).abs() < 1e-6);
        assert!((share_b - 0.4).abs() < 1e-6);
    }

    /// Batch transform returns one value per input, in input order.
    #[test]
    fn test_frequency_encoder_transform_batch() {
        let mut enc = FrequencyEncoder::new();
        enc.fit(&["A", "B", "A", "C", "A"]);

        let encoded = enc.transform(&["A", "B", "C"]).unwrap();
        assert_eq!(encoded, vec![3.0, 1.0, 1.0]);
    }

    // ---------------------------------------------------------------
    // LabelEncoder
    // ---------------------------------------------------------------

    /// Labels follow the sorted order of the distinct categories.
    #[test]
    fn test_label_encoder_basic() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["red", "blue", "red", "green"]);

        assert!(enc.is_fitted());
        assert_eq!(enc.num_classes(), 3);
        for &(category, label) in &[("blue", 0), ("green", 1), ("red", 2)] {
            assert_eq!(enc.transform_single(category), Some(label));
        }
    }

    /// Unseen categories yield `None` unless a fallback label is configured.
    #[test]
    fn test_label_encoder_unknown() {
        let training = ["A", "B"];

        let mut strict = LabelEncoder::new();
        strict.fit(&training);
        assert_eq!(strict.transform_single("C"), None);

        let mut lenient = LabelEncoder::new().with_unknown_label(Some(999));
        lenient.fit(&training);
        assert_eq!(lenient.transform_single("C"), Some(999));
    }

    /// `inverse_transform` undoes `transform`.
    #[test]
    fn test_label_encoder_inverse_transform() {
        let colors = ["red", "blue", "green"];
        let mut enc = LabelEncoder::new();
        enc.fit(&colors);

        let labels = enc.transform(&colors).unwrap();
        let round_trip = enc.inverse_transform(&labels).unwrap();
        assert_eq!(round_trip, vec!["red", "blue", "green"]);
    }

    /// `classes` lists the distinct categories in sorted order.
    #[test]
    fn test_label_encoder_classes() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["C", "A", "B", "A"]);
        assert_eq!(enc.classes(), &["A", "B", "C"]);
    }

    // ---------------------------------------------------------------
    // OneHotEncoder
    // ---------------------------------------------------------------

    /// Each category lights up exactly its own (sorted) column.
    #[test]
    fn test_onehot_encoder_basic() {
        let mut enc = OneHotEncoder::new();
        enc.fit(&["red", "blue", "green"]).unwrap();

        assert!(enc.is_fitted());
        assert_eq!(enc.num_columns(), 3);
        assert_eq!(enc.transform_single("blue").unwrap(), vec![1.0, 0.0, 0.0]);
        assert_eq!(enc.transform_single("green").unwrap(), vec![0.0, 1.0, 0.0]);
        assert_eq!(enc.transform_single("red").unwrap(), vec![0.0, 0.0, 1.0]);
    }

    /// With `drop_first`, the first sorted category encodes as all zeros.
    #[test]
    fn test_onehot_encoder_drop_first() {
        let mut enc = OneHotEncoder::new().with_drop_first(true);
        enc.fit(&["red", "blue", "green"]).unwrap();

        assert_eq!(enc.num_columns(), 2);
        assert_eq!(enc.transform_single("blue").unwrap(), vec![0.0, 0.0]);
        assert_eq!(enc.transform_single("green").unwrap(), vec![1.0, 0.0]);
        assert_eq!(enc.transform_single("red").unwrap(), vec![0.0, 1.0]);
    }

    /// `AllZeros` turns an unseen category into a zero row.
    #[test]
    fn test_onehot_encoder_unknown_allzeros() {
        let mut enc = OneHotEncoder::new().with_unknown_strategy(UnknownStrategy::AllZeros);
        enc.fit(&["A", "B"]).unwrap();

        let row = enc.transform_single("C").unwrap();
        assert_eq!(row, vec![0.0, 0.0]);
    }

    /// `Error` turns an unseen category into a failure.
    #[test]
    fn test_onehot_encoder_unknown_error() {
        let mut enc = OneHotEncoder::new().with_unknown_strategy(UnknownStrategy::Error);
        enc.fit(&["A", "B"]).unwrap();

        assert!(enc.transform_single("C").is_err());
    }

    /// Feature names are `prefix_category`, in column order.
    #[test]
    fn test_onehot_encoder_feature_names() {
        let mut enc = OneHotEncoder::new();
        enc.fit(&["red", "blue", "green"]).unwrap();

        let names = enc.get_feature_names("color");
        assert_eq!(names, vec!["color_blue", "color_green", "color_red"]);
    }

    /// Batch transform produces a flat row-major buffer.
    #[test]
    fn test_onehot_encoder_batch_transform() {
        let letters = ["A", "B", "C"];
        let mut enc = OneHotEncoder::new();
        enc.fit(&letters).unwrap();

        let flat = enc.transform(&letters).unwrap();
        assert_eq!(flat.len(), 9);
        assert_eq!(flat, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    /// The category limit rejects high-cardinality input unless raised or disabled.
    #[test]
    fn test_onehot_encoder_max_categories_limit() {
        let many: Vec<String> = (0..150).map(|i| format!("cat_{}", i)).collect();

        // Default limit of 100 is exceeded.
        let mut default_limit = OneHotEncoder::new();
        let outcome = default_limit.fit(&many);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("150 unique categories"));
        assert!(message.contains("max_categories limit of 100"));

        // A higher limit accepts the same data.
        let mut raised_limit = OneHotEncoder::new().with_max_categories(200);
        assert!(raised_limit.fit(&many).is_ok());
        assert_eq!(raised_limit.num_columns(), 150);

        // Zero disables the check entirely.
        let mut no_limit = OneHotEncoder::new().with_max_categories(0);
        assert!(no_limit.fit(&many).is_ok());

        // Small inputs pass under the default limit.
        let mut small = OneHotEncoder::new();
        assert!(small.fit(&["A", "B", "C"]).is_ok());
    }

    // ---------------------------------------------------------------
    // Incremental fitting
    // ---------------------------------------------------------------

    /// Counts accumulate across successive `partial_fit` calls.
    #[test]
    fn test_frequency_encoder_partial_fit() {
        let mut enc = FrequencyEncoder::new();

        enc.partial_fit(&["A", "A", "B"]).unwrap();
        assert_eq!(enc.n_samples(), 3);
        assert_eq!(enc.transform_single("A"), Some(2.0));
        assert_eq!(enc.transform_single("B"), Some(1.0));

        enc.partial_fit(&["A", "C"]).unwrap();
        assert_eq!(enc.n_samples(), 5);
        assert_eq!(enc.transform_single("A"), Some(3.0));
        assert_eq!(enc.transform_single("B"), Some(1.0));
        assert_eq!(enc.transform_single("C"), Some(1.0));
        assert_eq!(enc.transform_single("D"), Some(0.0));
    }

    /// Normalized values use the running total as the denominator.
    #[test]
    fn test_frequency_encoder_partial_fit_normalized() {
        let mut enc = FrequencyEncoder::new().with_normalize(true);

        enc.partial_fit(&["A", "A", "B"]).unwrap();
        assert!((enc.transform_single("A").unwrap() - 2.0 / 3.0).abs() < 1e-6);

        enc.partial_fit(&["A", "C"]).unwrap();
        assert!((enc.transform_single("A").unwrap() - 3.0 / 5.0).abs() < 1e-6);
    }

    /// One-hot encoding refuses incremental fitting with an explanatory error.
    #[test]
    fn test_onehot_encoder_partial_fit_not_supported() {
        let mut enc = OneHotEncoder::new();
        let outcome = enc.partial_fit(&["A", "B"]);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("does not support incremental"));
        assert!(message.contains("FrequencyEncoder"));
    }
}