entrenar/config/infer/
inference.rs1use std::collections::HashSet;
4use std::path::Path;
5
6use super::config::InferenceConfig;
7use super::schema::InferredSchema;
8use super::stats::ColumnStats;
9use super::types::FeatureType;
10
/// Reports whether a column name refers to one of the configured target columns.
///
/// Each entry of `target_columns` is lowercased before comparison; `name_lower`
/// is compared as given, so callers are expected to pass it already lowercased.
/// A column matches a target `t` when its name is exactly `t`, ends with `_t`,
/// or starts with `t_`.
fn is_target_column(name_lower: &str, target_columns: &[String]) -> bool {
    target_columns.iter().any(|target| {
        let wanted = target.to_lowercase();
        let wanted = wanted.as_str();

        // Stripping the target and inspecting the neighbouring character checks
        // for the `_t` / `t_` forms without building a temporary string.
        name_lower == wanted
            || matches!(name_lower.strip_suffix(wanted), Some(head) if head.ends_with('_'))
            || matches!(name_lower.strip_prefix(wanted), Some(tail) if tail.starts_with('_'))
    })
}
20
21fn infer_target_type(stats: &ColumnStats) -> FeatureType {
23 if stats.all_integers && stats.unique_count == 2 {
24 FeatureType::BinaryTarget
25 } else if stats.all_integers && stats.unique_count <= 100 {
26 FeatureType::MultiClassTarget
27 } else {
28 FeatureType::RegressionTarget
29 }
30}
31
32fn infer_numeric_type(stats: &ColumnStats, config: &InferenceConfig) -> FeatureType {
34 if stats.all_integers && stats.cardinality_ratio() < config.categorical_threshold {
35 FeatureType::Categorical
36 } else {
37 FeatureType::Numeric
38 }
39}
40
/// Returns `true` when at least one sample contains more than five
/// whitespace-separated tokens.
fn has_token_sequences(sample_values: &[String]) -> bool {
    // A sixth token (index 5) exists exactly when the token count exceeds five.
    sample_values
        .iter()
        .any(|sample| sample.split_whitespace().nth(5).is_some())
}
45
46fn infer_string_type(stats: &ColumnStats, config: &InferenceConfig, avg_len: f32) -> FeatureType {
48 if avg_len >= config.text_min_avg_len {
49 return FeatureType::Text;
50 }
51 if stats.cardinality_ratio() < config.categorical_threshold {
52 return FeatureType::Categorical;
53 }
54 if has_token_sequences(&stats.sample_values) {
55 return FeatureType::TokenSequence;
56 }
57 FeatureType::Text
58}
59
60pub fn infer_type(stats: &ColumnStats, config: &InferenceConfig) -> FeatureType {
62 let name_lower = stats.name.to_lowercase();
63 let is_target = is_target_column(&name_lower, &config.target_columns);
64
65 if config.exclude_columns.contains(&stats.name) {
66 return FeatureType::Unknown;
67 }
68
69 if stats.is_array && stats.array_len.is_some() {
70 return FeatureType::Embedding;
71 }
72
73 if stats.looks_like_datetime {
74 return FeatureType::DateTime;
75 }
76
77 if stats.all_numeric {
78 return if is_target {
79 infer_target_type(stats)
80 } else {
81 infer_numeric_type(stats, config)
82 };
83 }
84
85 if let Some(avg_len) = stats.avg_str_len {
86 return infer_string_type(stats, config, avg_len);
87 }
88
89 FeatureType::Unknown
90}
91
92pub fn infer_schema(stats: Vec<ColumnStats>, config: &InferenceConfig) -> InferredSchema {
94 let mut schema = InferredSchema::default();
95
96 for col_stats in stats {
97 let feature_type = infer_type(&col_stats, config);
98 schema.features.insert(col_stats.name.clone(), feature_type);
99 schema.stats.insert(col_stats.name.clone(), col_stats);
100 }
101
102 schema
103}
104
/// Reports how `s` parses as a number, as `(parses_as_f64, parses_as_i64)`.
///
/// No trimming is performed, so surrounding whitespace makes both flags `false`.
fn is_numeric_string(s: &str) -> (bool, bool) {
    let parses_as_int = s.parse::<i64>().is_ok();
    // Every string accepted by the `i64` parser (optional sign plus digits) is
    // also accepted by the `f64` parser, so the float parse is only needed
    // when the integer parse fails.
    let parses_as_float = parses_as_int || s.parse::<f64>().is_ok();
    (parses_as_float, parses_as_int)
}
111
/// Heuristically decides whether `s` looks like a date or timestamp.
///
/// A value qualifies when it
/// * contains a `-` separator,
/// * is between 10 and 30 bytes long (inclusive),
/// * contains at least eight ASCII digits, and
/// * does not parse as a plain number.
///
/// The last condition keeps negative numbers (`-1234567890`) and values in
/// scientific notation (`1.23456789e-10`) from being mistaken for datetimes;
/// they satisfy every other condition.
fn looks_like_datetime(s: &str) -> bool {
    // Cheap length test first; `len()` counts bytes.
    if !(10..=30).contains(&s.len()) || !s.contains('-') {
        return false;
    }

    s.chars().filter(char::is_ascii_digit).count() >= 8 && s.parse::<f64>().is_err()
}
119
120struct StatsAccumulator<'a> {
122 unique: HashSet<&'a str>,
123 total_len: usize,
124 min_len: usize,
125 max_len: usize,
126 all_numeric: bool,
127 all_integers: bool,
128 datetime_count: usize,
129}
130
131impl<'a> StatsAccumulator<'a> {
132 fn new() -> Self {
133 Self {
134 unique: HashSet::new(),
135 total_len: 0,
136 min_len: usize::MAX,
137 max_len: 0,
138 all_numeric: true,
139 all_integers: true,
140 datetime_count: 0,
141 }
142 }
143
144 fn process(&mut self, s: &'a str) {
145 self.unique.insert(s);
146 let len = s.len();
147 self.total_len += len;
148 self.min_len = self.min_len.min(len);
149 self.max_len = self.max_len.max(len);
150
151 let (is_float, is_int) = is_numeric_string(s);
152 if !is_float {
153 self.all_numeric = false;
154 self.all_integers = false;
155 } else if !is_int {
156 self.all_integers = false;
157 }
158
159 if looks_like_datetime(s) {
160 self.datetime_count += 1;
161 }
162 }
163}
164
165fn finalize_stats(stats: &mut ColumnStats, acc: &StatsAccumulator<'_>) {
167 stats.unique_count = acc.unique.len();
168 stats.all_numeric = acc.all_numeric && stats.null_count < stats.count;
169 stats.all_integers = acc.all_integers && stats.null_count < stats.count;
170
171 let non_null = stats.count - stats.null_count;
172 if non_null > 0 {
173 stats.min_str_len = Some(acc.min_len);
174 stats.max_str_len = Some(acc.max_len);
175 stats.avg_str_len = Some(acc.total_len as f32 / non_null as f32);
176 stats.looks_like_datetime = acc.datetime_count as f32 / non_null as f32 > 0.5;
177 }
178}
179
180pub fn collect_stats_from_samples(name: &str, values: &[Option<&str>]) -> ColumnStats {
182 let mut stats = ColumnStats::new(name);
183 stats.count = values.len();
184
185 let mut acc = StatsAccumulator::new();
186
187 for val in values {
188 match val {
189 Some(s) => {
190 acc.process(s);
191 if stats.sample_values.len() < 10 {
192 stats.sample_values.push((*s).to_string());
193 }
194 }
195 None => {
196 stats.null_count += 1;
197 }
198 }
199 }
200
201 finalize_stats(&mut stats, &acc);
202 stats
203}
204
205pub fn infer_schema_from_path(
208 _path: &Path,
209 _config: &InferenceConfig,
210) -> Result<InferredSchema, std::io::Error> {
211 Ok(InferredSchema::default())
219}