use polars::prelude::*;
use rkyv::Archive;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use crate::defaults::linear_feature as linear_feature_defaults;
use crate::Result;
use crate::TreeBoostError;
/// Semantic category assigned to a DataFrame column.
///
/// The category decides whether a column is eligible as an input to the
/// linear model; see [`ColumnType::detect`] for the exact rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Archive, Serialize, Deserialize)]
pub enum ColumnType {
    /// Integer or floating-point column that is neither constant nor ID-like.
    Numeric,
    /// String/categorical column whose cardinality ratio is at most 0.5.
    Categorical,
    /// String/categorical column with a cardinality ratio above 0.5; also the
    /// fallback for dtypes that `detect` does not recognise.
    Text,
    /// Column with a boolean dtype.
    Boolean,
    /// Date, datetime, time or duration column.
    DateTime,
    /// Column that is almost entirely unique (cardinality ratio above 0.9)
    /// and strictly increasing, e.g. a row identifier.
    IdLike,
    /// Integer or floating-point column holding at most one distinct value.
    Constant,
}
impl ColumnType {
    /// Classifies a column from its dtype and precomputed statistics.
    ///
    /// # Arguments
    /// * `dtype` - the column's polars data type.
    /// * `cardinality_ratio` - number of distinct values divided by row count.
    /// * `is_constant` - whether the column holds at most one distinct value.
    /// * `is_monotonic` - whether the values are strictly increasing.
    ///
    /// # Rules
    /// * Boolean dtype is always [`ColumnType::Boolean`].
    /// * String/categorical dtypes are `IdLike` when the ratio exceeds 0.9 and
    ///   the column is monotonic, `Text` when the ratio exceeds 0.5, and
    ///   `Categorical` otherwise. `is_constant` is not consulted here.
    /// * Date, datetime, time and duration dtypes are `DateTime`.
    /// * Integer and float dtypes are `Constant` when `is_constant`, `IdLike`
    ///   when the ratio exceeds 0.9 and the column is monotonic, and `Numeric`
    ///   otherwise.
    /// * Every other dtype falls back to `Text`.
    pub fn detect(
        dtype: &DataType,
        cardinality_ratio: f32,
        is_constant: bool,
        is_monotonic: bool,
    ) -> Self {
        // Shared by the string and numeric branches: nearly unique AND ordered.
        let id_like = cardinality_ratio > 0.9 && is_monotonic;

        match dtype {
            DataType::Boolean => ColumnType::Boolean,
            DataType::String | DataType::Categorical(_, _) => {
                if id_like {
                    ColumnType::IdLike
                } else if cardinality_ratio > 0.5 {
                    ColumnType::Text
                } else {
                    ColumnType::Categorical
                }
            }
            DataType::Date | DataType::Datetime(_, _) | DataType::Time | DataType::Duration(_) => {
                ColumnType::DateTime
            }
            DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Float32
            | DataType::Float64 => {
                if is_constant {
                    ColumnType::Constant
                } else if id_like {
                    ColumnType::IdLike
                } else {
                    ColumnType::Numeric
                }
            }
            _ => ColumnType::Text,
        }
    }

    /// Computes the statistics needed by [`ColumnType::detect`] from `series`
    /// and returns the resulting classification.
    ///
    /// An empty series is reported as [`ColumnType::Numeric`] regardless of
    /// its dtype. If the distinct count cannot be computed it is taken to be
    /// 1, which makes the column count as constant.
    pub fn from_series(series: &Series) -> Self {
        let len = series.len();
        if len == 0 {
            return ColumnType::Numeric;
        }

        let unique_count = series.n_unique().unwrap_or(1);
        let cardinality_ratio = unique_count as f32 / len as f32;

        Self::detect(
            series.dtype(),
            cardinality_ratio,
            unique_count <= 1,
            Self::is_monotonic(series),
        )
    }

    /// Reports whether `series` is strictly increasing.
    ///
    /// Only `Int32`, `Int64` and `UInt32` columns are inspected; every other
    /// dtype (including floats) yields `false`. Nulls are ignored, so each
    /// value is compared with the closest preceding non-null value.
    fn is_monotonic(series: &Series) -> bool {
        match series.dtype() {
            DataType::Int32 => series
                .i32()
                .is_ok_and(|ca| Self::strictly_increasing(ca.into_iter())),
            DataType::Int64 => series
                .i64()
                .is_ok_and(|ca| Self::strictly_increasing(ca.into_iter())),
            DataType::UInt32 => series
                .u32()
                .is_ok_and(|ca| Self::strictly_increasing(ca.into_iter())),
            _ => false,
        }
    }

    /// Returns `true` when the non-null items of `values` are strictly
    /// increasing. Empty, single-element and all-null inputs are `true`.
    ///
    /// Streams the iterator instead of materialising it, and remembers the
    /// last non-null value so a null gap cannot hide a decrease
    /// (e.g. `[1, null, 0]` is not increasing).
    fn strictly_increasing<T: PartialOrd>(values: impl Iterator<Item = Option<T>>) -> bool {
        let mut previous: Option<T> = None;
        for value in values.flatten() {
            if previous.as_ref().is_some_and(|prev| value <= *prev) {
                return false;
            }
            previous = Some(value);
        }
        true
    }
}
/// Controls which columns are withheld from the linear model's feature matrix.
///
/// Each `exclude_*` flag drops every column whose detected [`ColumnType`]
/// matches; `exclude_columns` drops columns by name regardless of type.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct LinearFeatureConfig {
    /// Column names that are always excluded, independent of detected type.
    pub exclude_columns: HashSet<String>,
    /// Drop columns detected as [`ColumnType::Categorical`].
    pub exclude_categorical: bool,
    /// Drop columns detected as [`ColumnType::IdLike`].
    pub exclude_id: bool,
    /// Drop columns detected as [`ColumnType::Constant`].
    pub exclude_constant: bool,
    /// Drop columns detected as [`ColumnType::Boolean`].
    pub exclude_boolean: bool,
    /// Drop columns detected as [`ColumnType::DateTime`].
    pub exclude_datetime: bool,
    /// Drop columns detected as [`ColumnType::Text`].
    pub exclude_text: bool,
}
impl Default for LinearFeatureConfig {
    /// Builds a configuration with no user-excluded columns and every
    /// per-type flag taken from the crate's `linear_feature` defaults.
    fn default() -> Self {
        Self {
            // No name-based exclusions until the caller adds some.
            exclude_columns: HashSet::new(),
            exclude_categorical: linear_feature_defaults::EXCLUDE_CATEGORICAL,
            exclude_id: linear_feature_defaults::EXCLUDE_ID,
            exclude_constant: linear_feature_defaults::EXCLUDE_CONSTANT,
            exclude_boolean: linear_feature_defaults::EXCLUDE_BOOLEAN,
            exclude_datetime: linear_feature_defaults::EXCLUDE_DATETIME,
            exclude_text: linear_feature_defaults::EXCLUDE_TEXT,
        }
    }
}
impl LinearFeatureConfig {
    /// Creates a configuration identical to [`LinearFeatureConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the set of user-excluded column names with `columns`.
    ///
    /// Any names registered earlier are discarded.
    pub fn with_exclude_columns(self, columns: &[&str]) -> Self {
        Self {
            exclude_columns: columns.iter().map(|&name| String::from(name)).collect(),
            ..self
        }
    }

    /// Adds a single name to the set of user-excluded columns.
    pub fn exclude_column(mut self, column: &str) -> Self {
        self.exclude_columns.insert(column.to_owned());
        self
    }

    /// Sets whether categorical columns are excluded.
    pub fn with_exclude_categorical(self, enable: bool) -> Self {
        Self {
            exclude_categorical: enable,
            ..self
        }
    }

    /// Sets whether ID-like columns are excluded.
    pub fn with_exclude_id(self, enable: bool) -> Self {
        Self {
            exclude_id: enable,
            ..self
        }
    }

    /// Sets whether constant columns are excluded.
    pub fn with_exclude_constant(self, enable: bool) -> Self {
        Self {
            exclude_constant: enable,
            ..self
        }
    }

    /// Sets whether boolean columns are excluded.
    pub fn with_exclude_boolean(self, enable: bool) -> Self {
        Self {
            exclude_boolean: enable,
            ..self
        }
    }

    /// Sets whether date/time columns are excluded.
    pub fn with_exclude_datetime(self, enable: bool) -> Self {
        Self {
            exclude_datetime: enable,
            ..self
        }
    }

    /// Sets whether text columns are excluded.
    pub fn with_exclude_text(self, enable: bool) -> Self {
        Self {
            exclude_text: enable,
            ..self
        }
    }
}
/// Output of [`FeatureExtractor::extract_numeric_features`].
#[derive(Debug, Clone)]
pub struct FeatureExtractionResult {
    /// Flattened feature matrix in row-major order: the value for row `r`
    /// and feature `f` is stored at index `r * num_features + f`.
    pub features: Vec<f32>,
    /// Number of feature columns per row.
    pub num_features: usize,
    /// Names of the selected feature columns, in matrix column order.
    pub feature_names: Vec<String>,
    /// Summary of which columns were kept or dropped.
    pub report: FeatureExtractionReport,
}
/// Record of the column-selection decisions made during feature extraction.
#[derive(Debug, Clone)]
pub struct FeatureExtractionReport {
    /// Width of the input frame (every column, not only the selected ones).
    pub total_columns: usize,
    /// Columns dropped because of their detected type, grouped by that type.
    pub excluded_by_type: HashMap<ColumnType, Vec<String>>,
    /// Columns dropped because they were listed in the user exclusion set.
    pub excluded_by_user: Vec<String>,
    /// Name of the target column as supplied by the caller.
    pub target_column: String,
    /// Names of the columns that ended up in the feature matrix.
    pub final_features: Vec<String>,
}
impl FeatureExtractionReport {
pub fn format(&self) -> String {
let mut output = String::new();
output.push_str("=== Feature Extraction Report ===\n");
output.push_str(&format!("Total columns: {}\n", self.total_columns));
output.push_str(&format!("Target column: {}\n", self.target_column));
output.push_str(&format!(
"Final features: {}\n\n",
self.final_features.len()
));
if !self.excluded_by_user.is_empty() {
output.push_str("Excluded by user:\n");
for col in &self.excluded_by_user {
output.push_str(&format!(" - {}\n", col));
}
output.push('\n');
}
for (col_type, cols) in &self.excluded_by_type {
if !cols.is_empty() {
output.push_str(&format!("Excluded as {:?}:\n", col_type));
for col in cols {
output.push_str(&format!(" - {}\n", col));
}
output.push('\n');
}
}
output.push_str("Final features for linear model:\n");
for (idx, col) in self.final_features.iter().enumerate() {
output.push_str(&format!(" [{}] {}\n", idx, col));
}
output
}
}
/// Selects the columns of a DataFrame that are usable by a linear model and
/// flattens them into an `f32` feature matrix.
#[derive(Debug, Clone, Archive, Serialize, Deserialize)]
pub struct FeatureExtractor {
    /// Exclusion rules applied during column selection.
    config: LinearFeatureConfig,
}
impl Default for FeatureExtractor {
fn default() -> Self {
Self::new()
}
}
impl FeatureExtractor {
pub fn new() -> Self {
Self {
config: LinearFeatureConfig::default(),
}
}
pub fn with_config(config: LinearFeatureConfig) -> Self {
Self { config }
}
pub fn extract_numeric_features(
&self,
df: &DataFrame,
target_col: &str,
) -> Result<FeatureExtractionResult> {
let num_rows = df.height();
let total_columns = df.width();
if num_rows == 0 {
return Ok(FeatureExtractionResult {
features: vec![],
num_features: 0,
feature_names: vec![],
report: FeatureExtractionReport {
total_columns,
excluded_by_type: HashMap::new(),
excluded_by_user: vec![],
target_column: target_col.to_string(),
final_features: vec![],
},
});
}
let mut excluded_by_type: HashMap<ColumnType, Vec<String>> = HashMap::new();
let mut excluded_by_user: Vec<String> = vec![];
let mut feature_names: Vec<String> = vec![];
for col_name in df.get_column_names() {
let col_name_str = col_name.as_str();
if col_name_str == target_col {
continue;
}
if self.config.exclude_columns.contains(col_name_str) {
excluded_by_user.push(col_name_str.to_string());
continue;
}
let col = df.column(col_name).map_err(|e| {
TreeBoostError::Data(format!("Column '{}' not found: {}", col_name, e))
})?;
let series = col.as_materialized_series();
let col_type = ColumnType::from_series(series);
let should_exclude = match col_type {
ColumnType::Numeric => false, ColumnType::Boolean => self.config.exclude_boolean,
ColumnType::Categorical => self.config.exclude_categorical,
ColumnType::IdLike => self.config.exclude_id,
ColumnType::Constant => self.config.exclude_constant,
ColumnType::DateTime => self.config.exclude_datetime,
ColumnType::Text => self.config.exclude_text,
};
if should_exclude {
excluded_by_type
.entry(col_type)
.or_insert_with(Vec::new)
.push(col_name_str.to_string());
continue;
}
feature_names.push(col_name_str.to_string());
}
let num_features = feature_names.len();
let mut features = Vec::with_capacity(num_rows * num_features);
for row_idx in 0..num_rows {
for col_name in &feature_names {
let col = df.column(col_name).map_err(|e| {
TreeBoostError::Data(format!(
"Column '{}' not found during extraction: {}",
col_name, e
))
})?;
let series = col.as_materialized_series();
let val = series.get(row_idx)?;
let f_val = self.anyvalue_to_f32(val, row_idx, col_name);
features.push(f_val);
}
}
let report = FeatureExtractionReport {
total_columns,
excluded_by_type: excluded_by_type.clone(),
excluded_by_user: excluded_by_user.clone(),
target_column: target_col.to_string(),
final_features: feature_names.clone(),
};
Ok(FeatureExtractionResult {
features,
num_features,
feature_names,
report,
})
}
pub fn extract(&self, df: &DataFrame, target_col: &str) -> Result<(Vec<f32>, usize)> {
let result = self.extract_numeric_features(df, target_col)?;
Ok((result.features, result.num_features))
}
fn anyvalue_to_f32(&self, val: AnyValue, _row_idx: usize, _col_name: &str) -> f32 {
match val {
AnyValue::Null => {
0.0
}
AnyValue::Int8(v) => v as f32,
AnyValue::Int16(v) => v as f32,
AnyValue::Int32(v) => v as f32,
AnyValue::Int64(v) => v as f32,
AnyValue::UInt8(v) => v as f32,
AnyValue::UInt16(v) => v as f32,
AnyValue::UInt32(v) => v as f32,
AnyValue::UInt64(v) => {
v.min(u32::MAX as u64) as f32
}
AnyValue::Float32(v) => {
if v.is_finite() {
v
} else {
0.0
}
}
AnyValue::Float64(v) => {
if v.is_finite() {
v as f32
} else {
0.0
}
}
AnyValue::Boolean(v) => {
if v {
1.0
} else {
0.0
}
}
_ => 0.0,
}
}
pub fn config(&self) -> &LinearFeatureConfig {
&self.config
}
pub fn should_exclude_column(&self, df: &DataFrame, col_name: &str, target_col: &str) -> bool {
if col_name == target_col {
return true;
}
if self.config.exclude_columns.contains(col_name) {
return true;
}
if let Ok(col) = df.column(col_name) {
let series = col.as_materialized_series();
let col_type = ColumnType::from_series(series);
match col_type {
ColumnType::Numeric => false,
ColumnType::Boolean => self.config.exclude_boolean,
ColumnType::Categorical => self.config.exclude_categorical,
ColumnType::IdLike => self.config.exclude_id,
ColumnType::Constant => self.config.exclude_constant,
ColumnType::DateTime => self.config.exclude_datetime,
ColumnType::Text => self.config.exclude_text,
}
} else {
true
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a five-row frame covering the main column categories:
    /// three strictly increasing unique `i32` columns (`id`, `year`, `rank`),
    /// two `f32` columns, one constant `i32` column, one low-cardinality
    /// string column and an `f32` target.
    fn create_test_dataframe() -> DataFrame {
        df!(
            "id" => &[1i32, 2, 3, 4, 5],
            "year" => &[2020i32, 2021, 2022, 2023, 2024],
            "rank" => &[1i32, 2, 3, 4, 5],
            "numeric1" => &[1.0f32, 2.0, 3.0, 4.0, 5.0],
            "numeric2" => &[10.0f32, 20.0, 30.0, 40.0, 50.0],
            "constant" => &[5i32, 5, 5, 5, 5],
            "category" => &["A", "B", "A", "B", "A"],
            "target" => &[1.0f32, 2.0, 3.0, 4.0, 5.0],
        )
        .unwrap()
    }

    /// Returns whether `name` is among the selected feature names.
    fn has_feature(result: &FeatureExtractionResult, name: &str) -> bool {
        result.feature_names.iter().any(|n| n == name)
    }

    #[test]
    fn test_exclude_target_column() {
        let df = create_test_dataframe();
        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        assert!(!has_feature(&result, "target"));
        assert!(result.num_features > 0);
    }

    #[test]
    fn test_auto_exclude_categorical() {
        let df = create_test_dataframe();
        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        assert!(!has_feature(&result, "category"));
        assert!(result
            .report
            .excluded_by_type
            .contains_key(&ColumnType::Categorical));
    }

    #[test]
    fn test_auto_exclude_constant() {
        let df = create_test_dataframe();
        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        assert!(!has_feature(&result, "constant"));
        assert!(result
            .report
            .excluded_by_type
            .contains_key(&ColumnType::Constant));
    }

    #[test]
    fn test_include_numeric() {
        let df = create_test_dataframe();
        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        assert!(has_feature(&result, "numeric1"));
        assert!(has_feature(&result, "numeric2"));
    }

    #[test]
    fn test_user_exclude_override() {
        let df = create_test_dataframe();
        let config = LinearFeatureConfig::new().with_exclude_columns(&["numeric1"]);
        let extractor = FeatureExtractor::with_config(config);
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        assert!(!has_feature(&result, "numeric1"));
        assert!(has_feature(&result, "numeric2"));
        assert!(result
            .report
            .excluded_by_user
            .iter()
            .any(|name| name == "numeric1"));
    }

    #[test]
    fn test_row_major_layout() {
        let df = create_test_dataframe();
        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        let num_rows = df.height();
        let num_features = result.num_features;
        assert_eq!(result.features.len(), num_rows * num_features);

        if num_features >= 2 {
            // Row-major: row 0 occupies [0, num_features), row 1 starts at
            // index num_features.
            let row_0_feat_0 = result.features[0];
            let row_0_feat_1 = result.features[1];
            let row_1_feat_0 = result.features[num_features];

            assert_eq!(row_0_feat_0, 1.0);
            assert_eq!(row_0_feat_1, 10.0);
            assert_eq!(row_1_feat_0, 2.0);
        }
    }

    #[test]
    fn test_auto_exclude_id_like() {
        let df = create_test_dataframe();
        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        assert!(!has_feature(&result, "id"));
    }

    #[test]
    fn test_report_formatting() {
        let df = create_test_dataframe();
        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        let report_str = result.report.format();
        assert!(report_str.contains("Feature Extraction Report"));
        assert!(report_str.contains("Total columns"));
        assert!(report_str.contains("Final features"));
    }

    #[test]
    fn test_boolean_handling() {
        let df = df!(
            "feature" => &[1.0f32, 2.0, 3.0],
            "flag" => &[true, false, true],
            "target" => &[1.0f32, 2.0, 3.0],
        )
        .unwrap();

        // Default configuration keeps boolean columns.
        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();
        assert!(has_feature(&result, "flag"));

        // Opting in to the exclusion removes them.
        let config = LinearFeatureConfig::new().with_exclude_boolean(true);
        let extractor = FeatureExtractor::with_config(config);
        let result = extractor.extract_numeric_features(&df, "target").unwrap();
        assert!(!has_feature(&result, "flag"));
    }

    /// Null cells are written to the matrix as `0.0`.
    #[test]
    fn test_nan_handling() {
        let df = df!(
            "feature" => &[Some(1.0f32), None, Some(3.0)],
            "target" => &[Some(1.0f32), Some(2.0), Some(3.0)],
        )
        .unwrap();

        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        assert_eq!(result.features[result.num_features], 0.0);
    }

    /// NaN and infinite cells are written to the matrix as `0.0`.
    #[test]
    fn test_non_finite_values_become_zero() {
        let df = df!(
            "feature" => &[1.0f32, f32::NAN, f32::INFINITY],
            "target" => &[1.0f32, 2.0, 3.0],
        )
        .unwrap();

        let extractor = FeatureExtractor::new();
        let result = extractor.extract_numeric_features(&df, "target").unwrap();

        assert_eq!(result.num_features, 1);
        assert_eq!(result.features, vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_column_type_detection() {
        let s_id = Series::new("test".into(), &[1i32, 2, 3, 4, 5]);
        assert_eq!(ColumnType::from_series(&s_id), ColumnType::IdLike);

        let s_numeric = Series::new("test".into(), &[1i32, 2, 1, 3, 2]);
        assert_eq!(ColumnType::from_series(&s_numeric), ColumnType::Numeric);

        let s_constant = Series::new("test".into(), &[1i32, 1, 1, 1, 1]);
        assert_eq!(ColumnType::from_series(&s_constant), ColumnType::Constant);

        let s_cat = Series::new("test".into(), &["A", "B", "A", "B", "A"]);
        assert_eq!(ColumnType::from_series(&s_cat), ColumnType::Categorical);

        let s_bool = Series::new("test".into(), &[true, false, true, false, true]);
        assert_eq!(ColumnType::from_series(&s_bool), ColumnType::Boolean);
    }

    // NOTE(review): despite its name this test performs no serialization; it
    // only checks that `with_config` stores the configuration unchanged. It
    // is left ignored, as in the original, until a real round-trip is added.
    #[test]
    #[ignore]
    fn test_serialization_roundtrip() {
        let config = LinearFeatureConfig::new()
            .with_exclude_columns(&["col1", "col2"])
            .with_exclude_categorical(false);
        let extractor = FeatureExtractor::with_config(config.clone());

        assert_eq!(
            extractor.config.exclude_categorical,
            config.exclude_categorical
        );
        assert_eq!(extractor.config.exclude_columns, config.exclude_columns);
    }
}