use crate::analysis::profiler::{ColumnDataType, ColumnProfile, DataFrameProfile};
use crate::analysis::DatasetAnalysis;
use crate::defaults::features as feature_defaults;
use std::collections::HashSet;
/// A plan listing the engineered features to derive from a profiled dataset.
///
/// The plan only names source columns; it does not compute any values itself.
#[derive(Debug, Clone)]
pub struct FeaturePlan {
    /// Column names selected for polynomial expansion.
    pub polynomial_features: Vec<String>,
    /// Column-name pairs selected for ratio features (first / second).
    pub ratio_pairs: Vec<(String, String)>,
    /// Column-name pairs selected for pairwise interaction features.
    pub interaction_pairs: Vec<(String, String)>,
    /// `(column name, component)` entries, one per requested time feature.
    pub time_features: Vec<(String, TimeFeatureType)>,
    /// Human-readable log of the decisions taken while building the plan.
    pub reasoning: Vec<String>,
}

impl FeaturePlan {
    /// Creates a plan with no features and an empty decision log.
    pub fn new() -> Self {
        Self {
            polynomial_features: vec![],
            ratio_pairs: vec![],
            interaction_pairs: vec![],
            time_features: vec![],
            reasoning: vec![],
        }
    }

    /// Returns `true` when no feature of any kind is planned.
    ///
    /// The `reasoning` log is deliberately ignored: a plan that only explains why it added
    /// nothing is still empty.
    pub fn is_empty(&self) -> bool {
        [
            self.polynomial_features.len(),
            self.ratio_pairs.len(),
            self.interaction_pairs.len(),
            self.time_features.len(),
        ]
        .iter()
        .all(|&count| count == 0)
    }

    /// Rough estimate of how many new columns executing the plan would add.
    ///
    /// Weights: 2 per polynomial column, 1 per ratio pair, 1 per interaction pair and 4 per
    /// `time_features` entry.
    ///
    /// NOTE(review): each `time_features` entry already names a single component, so a weight
    /// of 4 per entry likely overcounts. The polynomial weight of 2 is also lower than the three
    /// transforms (x², sqrt, log) that the planner's reasoning mentions. Confirm both weights
    /// against the feature generator.
    pub fn estimated_feature_count(&self) -> usize {
        const PER_POLYNOMIAL: usize = 2;
        const PER_RATIO: usize = 1;
        const PER_INTERACTION: usize = 1;
        const PER_TIME_ENTRY: usize = 4;
        PER_POLYNOMIAL * self.polynomial_features.len()
            + PER_RATIO * self.ratio_pairs.len()
            + PER_INTERACTION * self.interaction_pairs.len()
            + PER_TIME_ENTRY * self.time_features.len()
    }
}

impl Default for FeaturePlan {
    /// Same as [`FeaturePlan::new`].
    fn default() -> Self {
        FeaturePlan::new()
    }
}

/// Calendar component to derive from a datetime column.
///
/// NOTE(review): the `Cyclical*` variants presumably request a periodic (e.g. sin/cos)
/// encoding; confirm against the feature generator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimeFeatureType {
    Hour,
    DayOfWeek,
    DayOfMonth,
    Month,
    Year,
    IsWeekend,
    CyclicalHour,
    CyclicalDayOfWeek,
    CyclicalMonth,
}
/// Feature plan for the LTT dual-phase pipeline: a linear phase followed by a tree phase.
#[derive(Debug, Clone)]
pub struct LttFeaturePlan {
    /// Features planned for the linear phase.
    pub linear_features: FeaturePlan,
    /// Features planned for the tree phase.
    pub tree_features: FeaturePlan,
    /// Names of columns whose features are shared by both phases.
    pub shared_features: Vec<String>,
    /// Plan-level decision log. Phase-specific notes live in each phase's own `reasoning`.
    pub reasoning: Vec<String>,
}

impl LttFeaturePlan {
    /// Creates a two-phase plan in which both phases and all logs are empty.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for LttFeaturePlan {
    fn default() -> Self {
        Self {
            linear_features: FeaturePlan::default(),
            tree_features: FeaturePlan::default(),
            shared_features: Vec::new(),
            reasoning: Vec::new(),
        }
    }
}
/// Tuning knobs for [`SmartFeatureEngine`] planning.
#[derive(Debug, Clone)]
pub struct SmartFeatureConfig {
    /// Whether polynomial expansion may be planned.
    pub enable_polynomial: bool,
    /// Whether ratio pairs may be planned.
    pub enable_ratios: bool,
    /// Whether interaction pairs may be planned.
    pub enable_interactions: bool,
    /// Whether calendar features for datetime columns may be planned.
    pub enable_time_features: bool,
    /// Soft limit on the estimated number of new features. Exceeding it only adds a warning
    /// to the plan's reasoning.
    pub max_new_features: usize,
    /// In the single-phase plan, interactions are added only when the linear R² is below this.
    pub low_linear_r2_threshold: f32,
    /// Absolute target correlation a column must exceed to take part in ratio features.
    pub ratio_correlation_threshold: f32,
    /// Number of top-ranked numeric columns considered for polynomial expansion.
    pub top_n_polynomial: usize,
    /// Maximum number of interaction pairs. The single-phase plan also uses it to cap ratio pairs.
    pub top_n_interactions: usize,
    /// Column names excluded from numeric feature planning.
    pub skip_features: HashSet<String>,
}

/// Named bundles of settings applied by [`SmartFeatureConfig::with_preset`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SmartFeaturePreset {
    Standard,
    Minimal,
    Aggressive,
}

impl Default for SmartFeatureConfig {
    /// All feature families enabled; limits and thresholds come from the crate's
    /// feature defaults. No columns are skipped.
    fn default() -> Self {
        Self {
            skip_features: HashSet::new(),
            top_n_interactions: feature_defaults::TOP_N_INTERACTIONS,
            top_n_polynomial: feature_defaults::TOP_N_POLYNOMIAL,
            ratio_correlation_threshold: feature_defaults::RATIO_CORRELATION_THRESHOLD,
            low_linear_r2_threshold: feature_defaults::LOW_LINEAR_R2_THRESHOLD,
            max_new_features: feature_defaults::DEFAULT_MAX_NEW_FEATURES,
            enable_time_features: true,
            enable_interactions: true,
            enable_ratios: true,
            enable_polynomial: true,
        }
    }
}

impl SmartFeatureConfig {
    /// Overlays a preset on this configuration and returns the result.
    ///
    /// - `Standard` leaves every field untouched.
    /// - `Minimal` keeps polynomials and ratios, turns off interactions and time features,
    ///   and lowers the feature budget.
    /// - `Aggressive` turns on every family and switches budget, thresholds and top-N limits
    ///   to their aggressive defaults.
    ///
    /// Fields a preset does not mention (including `skip_features`) are carried over unchanged.
    pub fn with_preset(self, preset: SmartFeaturePreset) -> Self {
        match preset {
            SmartFeaturePreset::Standard => self,
            SmartFeaturePreset::Minimal => Self {
                enable_polynomial: true,
                enable_ratios: true,
                enable_interactions: false,
                enable_time_features: false,
                max_new_features: feature_defaults::MINIMAL_MAX_NEW_FEATURES,
                ..self
            },
            SmartFeaturePreset::Aggressive => Self {
                enable_polynomial: true,
                enable_ratios: true,
                enable_interactions: true,
                enable_time_features: true,
                max_new_features: feature_defaults::AGGRESSIVE_MAX_NEW_FEATURES,
                low_linear_r2_threshold: feature_defaults::AGGRESSIVE_LOW_LINEAR_R2_THRESHOLD,
                ratio_correlation_threshold:
                    feature_defaults::AGGRESSIVE_RATIO_CORRELATION_THRESHOLD,
                top_n_polynomial: feature_defaults::AGGRESSIVE_TOP_N_POLYNOMIAL,
                top_n_interactions: feature_defaults::AGGRESSIVE_TOP_N_INTERACTIONS,
                ..self
            },
        }
    }
}
/// Feature planner that holds a [`SmartFeatureConfig`].
///
/// NOTE(review): the `infer*` planning functions are associated functions and do not read
/// the `config` stored here. Pass a configuration to the `*_with_config` variants explicitly.
#[derive(Debug, Clone)]
pub struct SmartFeatureEngine {
    /// Configuration held by this engine instance.
    pub config: SmartFeatureConfig,
}
impl SmartFeatureEngine {
pub fn new() -> Self {
Self {
config: SmartFeatureConfig::default(),
}
}
pub fn with_config(config: SmartFeatureConfig) -> Self {
Self { config }
}
pub fn infer(profile: &DataFrameProfile, analysis: Option<&DatasetAnalysis>) -> FeaturePlan {
let config = SmartFeatureConfig::default();
Self::infer_with_config(profile, analysis, &config)
}
pub fn infer_with_config(
profile: &DataFrameProfile,
analysis: Option<&DatasetAnalysis>,
config: &SmartFeatureConfig,
) -> FeaturePlan {
let mut plan = FeaturePlan::new();
let mut numeric_cols: Vec<&ColumnProfile> = profile
.columns
.iter()
.filter(|c| c.dtype == ColumnDataType::Numeric)
.filter(|c| !config.skip_features.contains(&c.name))
.collect();
numeric_cols.sort_by(|a, b| {
let corr_a = a.target_correlation.map(|c| c.abs()).unwrap_or(0.0);
let corr_b = b.target_correlation.map(|c| c.abs()).unwrap_or(0.0);
corr_b
.partial_cmp(&corr_a)
.unwrap_or(std::cmp::Ordering::Equal)
});
if config.enable_polynomial {
Self::add_polynomial_features(&mut plan, &numeric_cols, config);
}
if config.enable_ratios {
Self::add_ratio_features(&mut plan, &numeric_cols, config);
}
if config.enable_interactions {
let linear_r2 = analysis.map(|a| a.linear_r2).unwrap_or(0.0);
if linear_r2 < config.low_linear_r2_threshold {
Self::add_interaction_features(&mut plan, &numeric_cols, config);
plan.reasoning.push(format!(
"Adding interactions: Linear R²={:.3} < {:.2} threshold",
linear_r2, config.low_linear_r2_threshold
));
} else {
plan.reasoning.push(format!(
"Skipping interactions: Linear R²={:.3} >= {:.2} threshold",
linear_r2, config.low_linear_r2_threshold
));
}
}
if config.enable_time_features {
Self::add_time_features(&mut plan, profile, config);
}
if plan.estimated_feature_count() > config.max_new_features {
plan.reasoning.push(format!(
"Warning: Estimated {} features exceeds max {} - consider reducing",
plan.estimated_feature_count(),
config.max_new_features
));
}
plan
}
fn add_polynomial_features(
plan: &mut FeaturePlan,
numeric_cols: &[&ColumnProfile],
config: &SmartFeatureConfig,
) {
let top_n = config.top_n_polynomial.min(numeric_cols.len());
for col in numeric_cols.iter().take(top_n) {
if col.has_negative {
plan.reasoning.push(format!(
"{}: Skip polynomial (has negative values)",
col.name
));
continue;
}
plan.polynomial_features.push(col.name.clone());
let corr = col.target_correlation.unwrap_or(0.0);
plan.reasoning.push(format!(
"{}: Add polynomial (x², sqrt, log) - correlation={:.3}",
col.name, corr
));
}
}
fn add_ratio_features(
plan: &mut FeaturePlan,
numeric_cols: &[&ColumnProfile],
config: &SmartFeatureConfig,
) {
let high_corr_cols: Vec<&ColumnProfile> = numeric_cols
.iter()
.filter(|c| {
c.target_correlation
.map(|r| r.abs() > config.ratio_correlation_threshold)
.unwrap_or(false)
})
.copied()
.collect();
for (i, col_a) in high_corr_cols.iter().enumerate() {
for col_b in high_corr_cols.iter().skip(i + 1) {
if col_b.min.map(|v| v.abs() > 0.01).unwrap_or(false) {
plan.ratio_pairs
.push((col_a.name.clone(), col_b.name.clone()));
plan.reasoning.push(format!(
"Ratio: {} / {} (both highly correlated with target)",
col_a.name, col_b.name
));
if plan.ratio_pairs.len() >= config.top_n_interactions {
break;
}
}
}
if plan.ratio_pairs.len() >= config.top_n_interactions {
break;
}
}
}
fn add_interaction_features(
plan: &mut FeaturePlan,
numeric_cols: &[&ColumnProfile],
config: &SmartFeatureConfig,
) {
let max_pairs = if numeric_cols.len() >= 2 {
numeric_cols.len() * (numeric_cols.len() - 1) / 2
} else {
0
};
let top_n = config.top_n_interactions.min(max_pairs);
let mut pair_count = 0;
for (i, col_a) in numeric_cols.iter().enumerate() {
for col_b in numeric_cols.iter().skip(i + 1) {
plan.interaction_pairs
.push((col_a.name.clone(), col_b.name.clone()));
plan.reasoning.push(format!(
"Interaction: {} × {} (top correlated features)",
col_a.name, col_b.name
));
pair_count += 1;
if pair_count >= top_n {
break;
}
}
if pair_count >= top_n {
break;
}
}
}
fn add_time_features(
plan: &mut FeaturePlan,
profile: &DataFrameProfile,
_config: &SmartFeatureConfig,
) {
for col in &profile.columns {
if col.dtype == ColumnDataType::DateTime {
plan.time_features
.push((col.name.clone(), TimeFeatureType::Hour));
plan.time_features
.push((col.name.clone(), TimeFeatureType::DayOfWeek));
plan.time_features
.push((col.name.clone(), TimeFeatureType::Month));
plan.time_features
.push((col.name.clone(), TimeFeatureType::IsWeekend));
plan.time_features
.push((col.name.clone(), TimeFeatureType::CyclicalHour));
plan.time_features
.push((col.name.clone(), TimeFeatureType::CyclicalDayOfWeek));
plan.reasoning.push(format!(
"{}: Add time features (hour, day_of_week, month, is_weekend, cyclical)",
col.name
));
}
}
}
pub fn infer_ltt(
profile: &DataFrameProfile,
analysis: Option<&DatasetAnalysis>,
) -> LttFeaturePlan {
let config = SmartFeatureConfig::default();
Self::infer_ltt_with_config(profile, analysis, &config)
}
pub fn infer_ltt_with_config(
profile: &DataFrameProfile,
analysis: Option<&DatasetAnalysis>,
config: &SmartFeatureConfig,
) -> LttFeaturePlan {
let mut ltt_plan = LttFeaturePlan::new();
ltt_plan
.reasoning
.push("=== LTT Dual-Phase Feature Engineering ===".to_string());
ltt_plan
.reasoning
.push("Phase 1 (Linear): Polynomial features extend linear model's reach".to_string());
ltt_plan.reasoning.push(
"Phase 2 (Tree): Interaction features capture what trees struggle with".to_string(),
);
let mut numeric_cols: Vec<&ColumnProfile> = profile
.columns
.iter()
.filter(|c| c.dtype == ColumnDataType::Numeric)
.filter(|c| !config.skip_features.contains(&c.name))
.collect();
numeric_cols.sort_by(|a, b| {
let corr_a = a.target_correlation.map(|c| c.abs()).unwrap_or(0.0);
let corr_b = b.target_correlation.map(|c| c.abs()).unwrap_or(0.0);
corr_b
.partial_cmp(&corr_a)
.unwrap_or(std::cmp::Ordering::Equal)
});
for col in numeric_cols.iter().take(config.top_n_polynomial) {
if !col.has_negative {
ltt_plan
.linear_features
.polynomial_features
.push(col.name.clone());
ltt_plan.linear_features.reasoning.push(format!(
"{}: Polynomial for linear phase (correlation={:.3})",
col.name,
col.target_correlation.unwrap_or(0.0)
));
}
}
let linear_r2 = analysis.map(|a| a.linear_r2).unwrap_or(0.0);
if linear_r2 < 0.5 && config.enable_interactions {
for (i, col_a) in numeric_cols.iter().enumerate().take(3) {
for col_b in numeric_cols.iter().skip(i + 1).take(3) {
ltt_plan
.linear_features
.interaction_pairs
.push((col_a.name.clone(), col_b.name.clone()));
}
}
ltt_plan.linear_features.reasoning.push(format!(
"Adding limited interactions: Linear R²={:.3} < 0.5",
linear_r2
));
}
let top_n = config
.top_n_interactions
.min(numeric_cols.len() * (numeric_cols.len() - 1) / 2);
let mut pair_count = 0;
for (i, col_a) in numeric_cols.iter().enumerate() {
for col_b in numeric_cols.iter().skip(i + 1) {
ltt_plan
.tree_features
.interaction_pairs
.push((col_a.name.clone(), col_b.name.clone()));
pair_count += 1;
if pair_count >= top_n {
break;
}
}
if pair_count >= top_n {
break;
}
}
ltt_plan.tree_features.reasoning.push(format!(
"Added {} interaction pairs for tree phase",
ltt_plan.tree_features.interaction_pairs.len()
));
let high_corr_cols: Vec<&ColumnProfile> = numeric_cols
.iter()
.filter(|c| {
c.target_correlation
.map(|r| r.abs() > config.ratio_correlation_threshold)
.unwrap_or(false)
})
.copied()
.collect();
for (i, col_a) in high_corr_cols.iter().enumerate().take(5) {
for col_b in high_corr_cols.iter().skip(i + 1).take(5) {
if col_b.min.map(|v| v.abs() > 0.01).unwrap_or(false) {
ltt_plan
.tree_features
.ratio_pairs
.push((col_a.name.clone(), col_b.name.clone()));
}
}
}
if !ltt_plan.tree_features.ratio_pairs.is_empty() {
ltt_plan.tree_features.reasoning.push(format!(
"Added {} ratio pairs for tree phase (scale-free)",
ltt_plan.tree_features.ratio_pairs.len()
));
}
for col in &profile.columns {
if col.dtype == ColumnDataType::DateTime {
ltt_plan.shared_features.push(col.name.clone());
ltt_plan.reasoning.push(format!(
"{}: DateTime features shared between phases",
col.name
));
}
}
ltt_plan
}
pub fn summarize(plan: &FeaturePlan) -> String {
let mut summary = String::new();
summary.push_str("Feature Generation Plan:\n");
summary.push_str(&format!(
" Polynomial features: {}\n",
plan.polynomial_features.len()
));
summary.push_str(&format!(" Ratio pairs: {}\n", plan.ratio_pairs.len()));
summary.push_str(&format!(
" Interaction pairs: {}\n",
plan.interaction_pairs.len()
));
summary.push_str(&format!(" Time features: {}\n", plan.time_features.len()));
summary.push_str(&format!(
" Estimated total: {} new features\n",
plan.estimated_feature_count()
));
if !plan.reasoning.is_empty() {
summary.push_str("\nDecisions:\n");
for reason in &plan.reasoning {
summary.push_str(&format!(" - {}\n", reason));
}
}
summary
}
pub fn summarize_ltt(plan: &LttFeaturePlan) -> String {
let mut summary = String::new();
summary.push_str("=== LTT Feature Engineering Plan ===\n\n");
summary.push_str("Linear Phase Features:\n");
summary.push_str(&format!(
" Polynomial: {}\n",
plan.linear_features.polynomial_features.len()
));
summary.push_str(&format!(
" Interactions: {}\n",
plan.linear_features.interaction_pairs.len()
));
summary.push_str("\nTree Phase Features:\n");
summary.push_str(&format!(
" Interactions: {}\n",
plan.tree_features.interaction_pairs.len()
));
summary.push_str(&format!(
" Ratios: {}\n",
plan.tree_features.ratio_pairs.len()
));
summary.push_str(&format!(
"\nShared Features: {}\n",
plan.shared_features.len()
));
if !plan.reasoning.is_empty() {
summary.push_str("\nDecisions:\n");
for reason in &plan.reasoning {
summary.push_str(&format!(" - {}\n", reason));
}
}
summary
}
}
impl Default for SmartFeatureEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;

    /// Three numeric features plus a numeric `target`, all strictly positive.
    fn create_test_profile() -> DataFrameProfile {
        let df = DataFrame::new(vec![
            Series::new("feature1".into(), vec![1.0f64, 2.0, 3.0, 4.0, 5.0]).into(),
            Series::new("feature2".into(), vec![2.0f64, 4.0, 6.0, 8.0, 10.0]).into(),
            Series::new("feature3".into(), vec![5.0f64, 4.0, 3.0, 2.0, 1.0]).into(),
            Series::new("target".into(), vec![1.0f64, 2.0, 3.0, 4.0, 5.0]).into(),
        ])
        .unwrap();
        DataFrameProfile::analyze(&df, "target").unwrap()
    }

    #[test]
    fn test_infer_feature_plan() {
        let profile = create_test_profile();
        let plan = SmartFeatureEngine::infer(&profile, None);
        assert!(!plan.polynomial_features.is_empty());
    }

    #[test]
    fn test_infer_ltt_plan() {
        let profile = create_test_profile();
        let ltt_plan = SmartFeatureEngine::infer_ltt(&profile, None);
        assert!(!ltt_plan.linear_features.polynomial_features.is_empty());
        assert!(!ltt_plan.tree_features.interaction_pairs.is_empty());
        assert!(ltt_plan.tree_features.polynomial_features.is_empty());
    }

    #[test]
    fn test_feature_plan_estimation() {
        let mut plan = FeaturePlan::new();
        plan.polynomial_features.push("f1".to_string());
        plan.polynomial_features.push("f2".to_string());
        plan.interaction_pairs
            .push(("f1".to_string(), "f2".to_string()));
        assert!(plan.estimated_feature_count() >= 4);
    }

    #[test]
    fn test_skip_negative_polynomial() {
        let df = DataFrame::new(vec![
            Series::new(
                "negative_feature".into(),
                vec![-1.0f64, -2.0, 3.0, 4.0, 5.0],
            )
            .into(),
            Series::new("positive_feature".into(), vec![1.0f64, 2.0, 3.0, 4.0, 5.0]).into(),
            Series::new("target".into(), vec![1.0f64, 2.0, 3.0, 4.0, 5.0]).into(),
        ])
        .unwrap();
        let profile = DataFrameProfile::analyze(&df, "target").unwrap();
        let plan = SmartFeatureEngine::infer(&profile, None);
        assert!(!plan
            .polynomial_features
            .contains(&"negative_feature".to_string()));
    }

    /// Skipping every numeric column must leave no numeric features in the plan.
    #[test]
    fn test_skip_features_excludes_columns() {
        let profile = create_test_profile();
        let config = SmartFeatureConfig {
            skip_features: ["feature1", "feature2", "feature3", "target"]
                .iter()
                .map(|s| s.to_string())
                .collect(),
            ..SmartFeatureConfig::default()
        };
        let plan = SmartFeatureEngine::infer_with_config(&profile, None, &config);
        assert!(plan.polynomial_features.is_empty());
        assert!(plan.ratio_pairs.is_empty());
        assert!(plan.interaction_pairs.is_empty());
    }

    /// The summary reports per-family counts and echoes the decision log.
    #[test]
    fn test_summarize_reports_counts_and_decisions() {
        let mut plan = FeaturePlan::new();
        plan.polynomial_features.push("f1".to_string());
        plan.reasoning.push("because".to_string());
        let summary = SmartFeatureEngine::summarize(&plan);
        assert!(summary.starts_with("Feature Generation Plan:\n"));
        assert!(summary.contains("Polynomial features: 1"));
        assert!(summary.contains("Decisions:"));
        assert!(summary.contains("because"));
    }
}