use std::collections::HashSet;
use std::path::Path;
use super::config::InferenceConfig;
use super::schema::InferredSchema;
use super::stats::ColumnStats;
use super::types::FeatureType;
/// Reports whether a column name matches any configured target name.
///
/// Each entry of `target_columns` is lowercased before comparison, while
/// `name_lower` is compared as given (the caller is expected to pass an
/// already-lowercased name). A target `t` matches when the name is exactly
/// `t`, ends with `_t`, or starts with `t_`.
fn is_target_column(name_lower: &str, target_columns: &[String]) -> bool {
    for target in target_columns {
        let needle = target.to_lowercase();
        if name_lower == needle {
            return true;
        }
        // `<anything>_<target>`: peel the target off the end, then require
        // an underscore immediately before it.
        let has_target_suffix = name_lower
            .strip_suffix(needle.as_str())
            .map_or(false, |head| head.ends_with('_'));
        if has_target_suffix {
            return true;
        }
        // `<target>_<anything>`: peel the target off the front, then require
        // an underscore immediately after it.
        let has_target_prefix = name_lower
            .strip_prefix(needle.as_str())
            .map_or(false, |tail| tail.starts_with('_'));
        if has_target_prefix {
            return true;
        }
    }
    false
}
/// Chooses the target flavour for a numeric target column.
///
/// Integer-only columns with exactly two distinct values are binary targets,
/// integer-only columns with at most 100 distinct values are multi-class
/// targets, and everything else is treated as a regression target.
fn infer_target_type(stats: &ColumnStats) -> FeatureType {
    if !stats.all_integers {
        return FeatureType::RegressionTarget;
    }
    match stats.unique_count {
        2 => FeatureType::BinaryTarget,
        distinct if distinct <= 100 => FeatureType::MultiClassTarget,
        _ => FeatureType::RegressionTarget,
    }
}
/// Classifies a numeric, non-target column.
///
/// Integer-only columns whose cardinality ratio is below
/// `config.categorical_threshold` are reported as categorical; every other
/// numeric column is reported as plain numeric.
fn infer_numeric_type(stats: &ColumnStats, config: &InferenceConfig) -> FeatureType {
    match stats.all_integers {
        // The ratio is only consulted for integer-only columns.
        true if stats.cardinality_ratio() < config.categorical_threshold => {
            FeatureType::Categorical
        }
        _ => FeatureType::Numeric,
    }
}
/// Returns `true` when at least one sample contains more than five
/// whitespace-separated tokens.
fn has_token_sequences(sample_values: &[String]) -> bool {
    for sample in sample_values {
        // A sixth token (index 5) exists exactly when the token count exceeds five.
        if sample.split_whitespace().nth(5).is_some() {
            return true;
        }
    }
    false
}
/// Classifies a string column given its average string length.
///
/// Rules are applied in order:
/// 1. an average length of at least `config.text_min_avg_len` means text;
/// 2. a cardinality ratio below `config.categorical_threshold` means categorical;
/// 3. samples that contain token sequences mean a token-sequence feature;
/// 4. anything left over falls back to text.
fn infer_string_type(stats: &ColumnStats, config: &InferenceConfig, avg_len: f32) -> FeatureType {
    if avg_len >= config.text_min_avg_len {
        FeatureType::Text
    } else if stats.cardinality_ratio() < config.categorical_threshold {
        FeatureType::Categorical
    } else if has_token_sequences(&stats.sample_values) {
        FeatureType::TokenSequence
    } else {
        FeatureType::Text
    }
}
/// Infers the feature type of a single column from its collected statistics.
///
/// Checks run in priority order: explicitly excluded columns are `Unknown`,
/// arrays with a known length are embeddings, datetime-looking columns are
/// `DateTime`, numeric columns are split into target / non-target handling,
/// and columns with an average string length go through string inference.
/// Columns matching none of these are `Unknown`.
pub fn infer_type(stats: &ColumnStats, config: &InferenceConfig) -> FeatureType {
    if config.exclude_columns.contains(&stats.name) {
        return FeatureType::Unknown;
    }
    if stats.is_array && stats.array_len.is_some() {
        return FeatureType::Embedding;
    }
    if stats.looks_like_datetime {
        return FeatureType::DateTime;
    }
    if !stats.all_numeric {
        return match stats.avg_str_len {
            Some(avg_len) => infer_string_type(stats, config, avg_len),
            None => FeatureType::Unknown,
        };
    }
    // Target matching only influences numeric columns, so the (pure) name
    // check is evaluated here rather than up front.
    let column_name = stats.name.to_lowercase();
    if is_target_column(&column_name, &config.target_columns) {
        infer_target_type(stats)
    } else {
        infer_numeric_type(stats, config)
    }
}
/// Builds an [`InferredSchema`] from per-column statistics.
///
/// For every column, the inferred feature type is inserted into
/// `schema.features` and the statistics themselves are moved into
/// `schema.stats`, both under the column's name.
pub fn infer_schema(stats: Vec<ColumnStats>, config: &InferenceConfig) -> InferredSchema {
    let mut schema = InferredSchema::default();
    for column in stats {
        let inferred = infer_type(&column, config);
        let column_name = column.name.clone();
        schema.features.insert(column_name.clone(), inferred);
        schema.stats.insert(column_name, column);
    }
    schema
}
/// Tests whether `s` parses as a number.
///
/// Returns `(is_float, is_int)`: the first flag is set when `s` parses as an
/// `f64`, the second when it parses as an `i64`. No trimming is performed, so
/// surrounding whitespace makes both flags `false`.
fn is_numeric_string(s: &str) -> (bool, bool) {
    let parses_as_int = s.parse::<i64>().is_ok();
    // Any string accepted by the `i64` parser (optional sign plus digits) is
    // also accepted by the `f64` parser, so the float parse is only needed
    // when the integer parse fails.
    let parses_as_float = parses_as_int || s.parse::<f64>().is_ok();
    (parses_as_float, parses_as_int)
}
/// Heuristically decides whether `s` looks like a date or timestamp.
///
/// A string qualifies when it is 10 to 30 bytes long, contains a `-`, holds
/// at least eight ASCII digits, and is *not* a plain number.
///
/// The last condition matters because negative or scientific-notation numbers
/// such as `-1234567890` or `1.23456789e-10` satisfy the other three checks.
/// Without it, a purely numeric column of such values would be flagged as
/// datetime.
fn looks_like_datetime(s: &str) -> bool {
    let plausible_len = (10..=30).contains(&s.len());
    plausible_len
        && s.contains('-')
        && s.chars().filter(char::is_ascii_digit).count() >= 8
        // Cheapest checks first; the parse runs only for surviving candidates.
        && s.parse::<f64>().is_err()
}
/// Running aggregates over the non-null string values of one column.
///
/// Borrowed strings are stored in `unique`, so the accumulator cannot outlive
/// the values fed to it (lifetime `'a`).
struct StatsAccumulator<'a> {
/// Distinct values seen so far.
unique: HashSet<&'a str>,
/// Sum of the byte lengths of all processed values.
total_len: usize,
/// Smallest byte length seen; starts at `usize::MAX` until a value arrives.
min_len: usize,
/// Largest byte length seen.
max_len: usize,
/// `true` while every processed value has parsed as a float.
all_numeric: bool,
/// `true` while every processed value has parsed as an integer.
all_integers: bool,
/// Number of processed values that passed the datetime heuristic.
datetime_count: usize,
}
impl<'a> StatsAccumulator<'a> {
    /// Creates an accumulator in its "nothing seen yet" state: the numeric
    /// flags start `true`, counters start at zero, and `min_len` starts at
    /// `usize::MAX` so the first value always lowers it.
    fn new() -> Self {
        Self {
            all_integers: true,
            all_numeric: true,
            datetime_count: 0,
            max_len: 0,
            min_len: usize::MAX,
            total_len: 0,
            unique: HashSet::new(),
        }
    }

    /// Folds one non-null value into the running aggregates.
    fn process(&mut self, s: &'a str) {
        // Lengths are measured in bytes, not characters.
        let byte_len = s.len();
        self.total_len += byte_len;
        self.max_len = self.max_len.max(byte_len);
        self.min_len = self.min_len.min(byte_len);

        // A value that is not a float clears both flags; a float that is not
        // an integer clears only the integer flag.
        let (parses_as_float, parses_as_int) = is_numeric_string(s);
        self.all_numeric &= parses_as_float;
        self.all_integers &= parses_as_float && parses_as_int;

        self.datetime_count += usize::from(looks_like_datetime(s));
        self.unique.insert(s);
    }
}
/// Copies the accumulated aggregates into `stats`.
///
/// `stats.count` and `stats.null_count` must already be set. The numeric
/// flags are only reported as `true` when at least one non-null value exists.
/// The length statistics and the datetime flag are only written in that case
/// too; the datetime flag is set when strictly more than half of the non-null
/// values looked like datetimes.
fn finalize_stats(stats: &mut ColumnStats, acc: &StatsAccumulator<'_>) {
    let has_values = stats.null_count < stats.count;
    stats.unique_count = acc.unique.len();
    stats.all_numeric = acc.all_numeric && has_values;
    stats.all_integers = acc.all_integers && has_values;

    let non_null = stats.count - stats.null_count;
    if non_null == 0 {
        return;
    }
    let denominator = non_null as f32;
    stats.min_str_len = Some(acc.min_len);
    stats.max_str_len = Some(acc.max_len);
    stats.avg_str_len = Some(acc.total_len as f32 / denominator);
    stats.looks_like_datetime = acc.datetime_count as f32 / denominator > 0.5;
}
/// Computes [`ColumnStats`] for a column from sampled string values.
///
/// `None` entries are counted as nulls. Every `Some` value is fed into the
/// accumulator, and the first ten non-null values are also kept verbatim in
/// `sample_values`.
pub fn collect_stats_from_samples(name: &str, values: &[Option<&str>]) -> ColumnStats {
    // Upper bound on how many raw values are retained as examples.
    const MAX_SAMPLES: usize = 10;

    let mut stats = ColumnStats::new(name);
    stats.count = values.len();

    let mut acc = StatsAccumulator::new();
    for entry in values.iter().copied() {
        if let Some(text) = entry {
            acc.process(text);
            if stats.sample_values.len() < MAX_SAMPLES {
                stats.sample_values.push(text.to_owned());
            }
        } else {
            stats.null_count += 1;
        }
    }

    finalize_stats(&mut stats, &acc);
    stats
}
/// Infers a schema for the data at `_path`.
///
/// Currently a stub: both arguments are ignored and an empty default schema
/// is always returned, so this never produces an `Err`.
pub fn infer_schema_from_path(
    _path: &Path,
    _config: &InferenceConfig,
) -> Result<InferredSchema, std::io::Error> {
    let empty_schema = InferredSchema::default();
    Ok(empty_schema)
}