use treeboost::encoding::{CategoryFilter, CategoryMapping, OrderedTargetEncoder};
/// Verifies that `CategoryFilter` keeps frequently observed categories and
/// collapses rare or never-seen ones to `"unknown"`.
///
/// Observed counts fed to the filter: `frequent_a` and `frequent_b` x100,
/// `medium` x10, `rare` x2, `very_rare` x1. The expectations (keep `medium`,
/// drop `rare`) are consistent with the last argument of
/// `CategoryFilter::new(0.01, 0.99, 5)` acting as a minimum count of 5 —
/// NOTE(review): confirm that reading against the library.
#[test]
fn test_category_filter() {
    let mut filter = CategoryFilter::new(0.01, 0.99, 5);

    for _ in 0..100 {
        filter.count("frequent_a");
        filter.count("frequent_b");
    }
    for _ in 0..10 {
        filter.count("medium");
    }
    for _ in 0..2 {
        filter.count("rare");
    }
    filter.count("very_rare");

    filter.finalize(vec![
        "frequent_a".to_string(),
        "frequent_b".to_string(),
        "medium".to_string(),
        "rare".to_string(),
        "very_rare".to_string(),
    ]);

    // Categories observed often enough are retained.
    assert!(filter.is_frequent("frequent_a"));
    assert!(filter.is_frequent("frequent_b"));
    assert!(filter.is_frequent("medium"));

    // Rare and never-counted categories are not. One assertion per line so a
    // failure's reported line number identifies the offending category.
    assert!(!filter.is_frequent("rare"));
    assert!(!filter.is_frequent("very_rare"));
    assert!(!filter.is_frequent("unseen"));

    // `filter` passes retained values through and maps the rest to "unknown".
    assert_eq!(filter.filter("frequent_a"), "frequent_a");
    assert_eq!(filter.filter("rare"), "unknown");
    assert_eq!(filter.filter("unseen"), "unknown");
}
/// Verifies that `CategoryMapping::from_filter` gives each retained category
/// its own index and routes filtered-out and unseen categories to the shared
/// `unknown_idx` bucket.
///
/// `cat_a`, `cat_b` and `cat_c` are each counted 10 times and `rare` once,
/// against `CategoryFilter::new(0.01, 0.99, 3)`.
#[test]
fn test_category_mapping() {
    let mut filter = CategoryFilter::new(0.01, 0.99, 3);
    for _ in 0..10 {
        filter.count("cat_a");
        filter.count("cat_b");
        filter.count("cat_c");
    }
    filter.count("rare");
    filter.finalize(vec![
        "cat_a".to_string(),
        "cat_b".to_string(),
        "cat_c".to_string(),
        "rare".to_string(),
    ]);

    let mapping = CategoryMapping::from_filter(&filter);

    // Three retained categories plus the unknown bucket.
    assert_eq!(mapping.num_categories(), 4);

    let idx_a = mapping.get_index("cat_a");
    let idx_b = mapping.get_index("cat_b");
    let idx_c = mapping.get_index("cat_c");
    let idx_rare = mapping.get_index("rare");

    // Retained categories occupy the first three slots...
    assert!(idx_a < 3);
    assert!(idx_b < 3);
    assert!(idx_c < 3);

    // ...no two of them share an index...
    assert_ne!(idx_a, idx_b);
    assert_ne!(idx_b, idx_c);
    assert_ne!(idx_a, idx_c);

    // ...and none of them is routed to the unknown bucket. Without these
    // checks a mapping that sent a frequent category to `unknown_idx` could
    // still pass, because nothing else ties `unknown_idx` to the upper slot.
    assert_ne!(idx_a, mapping.unknown_idx);
    assert_ne!(idx_b, mapping.unknown_idx);
    assert_ne!(idx_c, mapping.unknown_idx);

    // Filtered-out and never-seen categories share the unknown bucket.
    assert_eq!(idx_rare, mapping.unknown_idx);
    assert_eq!(mapping.get_index("unseen"), mapping.unknown_idx);
}
/// Exercises `OrderedTargetEncoder::encode_column` on a six-row column.
///
/// Rows (category = target): A=10, B=20, A=12, B=22, A=14, C=50.
/// The expectations below assume each row is encoded only from the targets of
/// rows that precede it (ordered target statistics), with a first-seen
/// category falling back to a prior built from earlier rows —
/// NOTE(review): confirm against the encoder's documentation.
#[test]
fn test_ordered_target_encoder() {
    let categories: Vec<String> = ["A", "B", "A", "B", "A", "C"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    let targets = vec![10.0, 20.0, 12.0, 22.0, 14.0, 50.0];

    let mut encoder = OrderedTargetEncoder::new(5.0);
    let encoded = encoder.encode_column(&categories, &targets);

    // One encoded value per input row, none of them NaN or infinite.
    assert_eq!(encoded.len(), categories.len());
    assert!(
        encoded.iter().all(|value| value.is_finite()),
        "Encoded value should be finite"
    );

    // Row 0 has no earlier rows to draw from.
    assert_eq!(encoded[0], 0.0, "First element should be 0 (no prior data)");

    // Row 1 is the first `B`; the only earlier target is row 0's 10.0.
    assert!((encoded[1] - 10.0).abs() < 0.01, "Second should be ~10.0");

    // Row 5 is the first `C`; every earlier target is positive.
    assert!(encoded[5] > 0.0, "Later values should be positive");
}