Skip to main content

entrenar/config/infer/
inference.rs

1//! Type inference functions
2
3use std::collections::HashSet;
4use std::path::Path;
5
6use super::config::InferenceConfig;
7use super::schema::InferredSchema;
8use super::stats::ColumnStats;
9use super::types::FeatureType;
10
/// Returns `true` when `name_lower` matches any pattern in `target_columns`.
///
/// Each pattern is lowercased before comparison; `name_lower` is compared as
/// given. A name matches a pattern `t` when it equals `t`, ends with `_t`, or
/// starts with `t_`.
fn is_target_column(name_lower: &str, target_columns: &[String]) -> bool {
    for target in target_columns {
        let pattern = target.to_lowercase();
        let pattern = pattern.as_str();

        if name_lower == pattern {
            return true;
        }
        // `<anything>_<pattern>`: strip the pattern off the end, then require `_`.
        let suffix_hit = matches!(
            name_lower.strip_suffix(pattern),
            Some(head) if head.ends_with('_')
        );
        // `<pattern>_<anything>`: strip the pattern off the front, then require `_`.
        let prefix_hit = matches!(
            name_lower.strip_prefix(pattern),
            Some(tail) if tail.starts_with('_')
        );
        if suffix_hit || prefix_hit {
            return true;
        }
    }
    false
}
20
21/// Infer type for a numeric target column
22fn infer_target_type(stats: &ColumnStats) -> FeatureType {
23    if stats.all_integers && stats.unique_count == 2 {
24        FeatureType::BinaryTarget
25    } else if stats.all_integers && stats.unique_count <= 100 {
26        FeatureType::MultiClassTarget
27    } else {
28        FeatureType::RegressionTarget
29    }
30}
31
32/// Infer type for a numeric non-target column
33fn infer_numeric_type(stats: &ColumnStats, config: &InferenceConfig) -> FeatureType {
34    if stats.all_integers && stats.cardinality_ratio() < config.categorical_threshold {
35        FeatureType::Categorical
36    } else {
37        FeatureType::Numeric
38    }
39}
40
/// Returns `true` if any sample value splits into more than five
/// whitespace-separated tokens.
fn has_token_sequences(sample_values: &[String]) -> bool {
    // Index of the sixth token. If it exists, the value has more than five tokens.
    const SIXTH_TOKEN: usize = 5;
    sample_values
        .iter()
        .any(|value| value.split_whitespace().nth(SIXTH_TOKEN).is_some())
}
45
46/// Infer type for a string column
47fn infer_string_type(stats: &ColumnStats, config: &InferenceConfig, avg_len: f32) -> FeatureType {
48    if avg_len >= config.text_min_avg_len {
49        return FeatureType::Text;
50    }
51    if stats.cardinality_ratio() < config.categorical_threshold {
52        return FeatureType::Categorical;
53    }
54    if has_token_sequences(&stats.sample_values) {
55        return FeatureType::TokenSequence;
56    }
57    FeatureType::Text
58}
59
60/// Infer feature type from column statistics
61pub fn infer_type(stats: &ColumnStats, config: &InferenceConfig) -> FeatureType {
62    let name_lower = stats.name.to_lowercase();
63    let is_target = is_target_column(&name_lower, &config.target_columns);
64
65    if config.exclude_columns.contains(&stats.name) {
66        return FeatureType::Unknown;
67    }
68
69    if stats.is_array && stats.array_len.is_some() {
70        return FeatureType::Embedding;
71    }
72
73    if stats.looks_like_datetime {
74        return FeatureType::DateTime;
75    }
76
77    if stats.all_numeric {
78        return if is_target {
79            infer_target_type(stats)
80        } else {
81            infer_numeric_type(stats, config)
82        };
83    }
84
85    if let Some(avg_len) = stats.avg_str_len {
86        return infer_string_type(stats, config, avg_len);
87    }
88
89    FeatureType::Unknown
90}
91
92/// Infer schema from column statistics
93pub fn infer_schema(stats: Vec<ColumnStats>, config: &InferenceConfig) -> InferredSchema {
94    let mut schema = InferredSchema::default();
95
96    for col_stats in stats {
97        let feature_type = infer_type(&col_stats, config);
98        schema.features.insert(col_stats.name.clone(), feature_type);
99        schema.stats.insert(col_stats.name.clone(), col_stats);
100    }
101
102    schema
103}
104
/// Report how `s` parses as a number.
///
/// Returns `(is_float, is_int)`: whether `s` parses as an `f64`, and whether
/// it parses as an `i64`. Any string accepted by the `i64` parser (an
/// optional sign followed by digits) is also accepted by the `f64` parser.
/// The integer case therefore short-circuits.
fn is_numeric_string(s: &str) -> (bool, bool) {
    if s.parse::<i64>().is_ok() {
        return (true, true);
    }
    (s.parse::<f64>().is_ok(), false)
}
111
/// Heuristically decide whether `s` looks like a date or timestamp.
///
/// A value qualifies when all of these hold:
/// - it is 10 to 30 bytes long,
/// - it contains a `-` separator,
/// - it has at least eight ASCII digits,
/// - it does not itself parse as a number.
///
/// The last condition rejects numeric values that merely contain a minus
/// sign, such as negative floats (`-0.123456789`), negative integers
/// (`-1234567890`), or scientific notation (`1.23456789e-05`). Without it,
/// those values satisfy the shape checks, and whole numeric columns get
/// classified as dates.
fn looks_like_datetime(s: &str) -> bool {
    (10..=30).contains(&s.len())
        && s.contains('-')
        // ASCII digits are single bytes in UTF-8, so counting bytes is exact.
        && s.bytes().filter(u8::is_ascii_digit).count() >= 8
        // Checked last because it is the most expensive test.
        && s.parse::<f64>().is_err()
}
119
120/// Accumulator for collecting string statistics
121struct StatsAccumulator<'a> {
122    unique: HashSet<&'a str>,
123    total_len: usize,
124    min_len: usize,
125    max_len: usize,
126    all_numeric: bool,
127    all_integers: bool,
128    datetime_count: usize,
129}
130
131impl<'a> StatsAccumulator<'a> {
132    fn new() -> Self {
133        Self {
134            unique: HashSet::new(),
135            total_len: 0,
136            min_len: usize::MAX,
137            max_len: 0,
138            all_numeric: true,
139            all_integers: true,
140            datetime_count: 0,
141        }
142    }
143
144    fn process(&mut self, s: &'a str) {
145        self.unique.insert(s);
146        let len = s.len();
147        self.total_len += len;
148        self.min_len = self.min_len.min(len);
149        self.max_len = self.max_len.max(len);
150
151        let (is_float, is_int) = is_numeric_string(s);
152        if !is_float {
153            self.all_numeric = false;
154            self.all_integers = false;
155        } else if !is_int {
156            self.all_integers = false;
157        }
158
159        if looks_like_datetime(s) {
160            self.datetime_count += 1;
161        }
162    }
163}
164
165/// Finalize stats from accumulator
166fn finalize_stats(stats: &mut ColumnStats, acc: &StatsAccumulator<'_>) {
167    stats.unique_count = acc.unique.len();
168    stats.all_numeric = acc.all_numeric && stats.null_count < stats.count;
169    stats.all_integers = acc.all_integers && stats.null_count < stats.count;
170
171    let non_null = stats.count - stats.null_count;
172    if non_null > 0 {
173        stats.min_str_len = Some(acc.min_len);
174        stats.max_str_len = Some(acc.max_len);
175        stats.avg_str_len = Some(acc.total_len as f32 / non_null as f32);
176        stats.looks_like_datetime = acc.datetime_count as f32 / non_null as f32 > 0.5;
177    }
178}
179
180/// Collect statistics from sample values (simplified in-memory analysis)
181pub fn collect_stats_from_samples(name: &str, values: &[Option<&str>]) -> ColumnStats {
182    let mut stats = ColumnStats::new(name);
183    stats.count = values.len();
184
185    let mut acc = StatsAccumulator::new();
186
187    for val in values {
188        match val {
189            Some(s) => {
190                acc.process(s);
191                if stats.sample_values.len() < 10 {
192                    stats.sample_values.push((*s).to_string());
193                }
194            }
195            None => {
196                stats.null_count += 1;
197            }
198        }
199    }
200
201    finalize_stats(&mut stats, &acc);
202    stats
203}
204
205/// Placeholder: Load stats from Parquet file
206/// Real implementation would use arrow-rs/parquet crate
207pub fn infer_schema_from_path(
208    _path: &Path,
209    _config: &InferenceConfig,
210) -> Result<InferredSchema, std::io::Error> {
211    // In a real implementation, this would:
212    // 1. Open the Parquet file
213    // 2. Read schema metadata
214    // 3. Sample rows for statistics
215    // 4. Call infer_schema()
216
217    // For now, return empty schema
218    Ok(InferredSchema::default())
219}