use std::collections::HashSet;
use super::data::DataType;
use super::field_value::{DataRow, DataTable, FieldValue};
/// Result of type inference for one column, computed from a sample of its rows.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InferredColumnType {
/// Physical representation chosen for the column's values.
pub storage_type: FieldValueType,
/// Semantic type derived from the storage type and the sample statistics
/// (this module produces nominal, ordinal, quantitative or temporal).
pub semantic_type: DataType,
/// Share of sampled non-null values consistent with `storage_type`, reduced when the
/// column mixes several kinds of value; 0.0 when no non-null value was sampled.
pub confidence: f64,
/// Number of distinct non-null values in the sample.
pub cardinality: usize,
/// Number of sampled rows holding `FieldValue::Null` for this column (rows that lack
/// the column entirely are not counted).
pub null_count: usize,
/// Number of rows examined (at most `InferenceConfig::sample_size`).
pub sample_size: usize,
/// Additional hints gathered while scanning the sample.
pub metadata: TypeMetadata,
}
/// Physical storage class inferred for a column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum FieldValueType {
/// Numbers (integers or floats).
Numeric,
/// Free text; also the fallback when no other kind dominates.
Text,
/// Native timestamps or text that looks like dates/datetimes.
Timestamp,
/// True/false values.
Boolean,
}
/// Secondary observations about a column's sampled values.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct TypeMetadata {
/// True when at least one numeric value was sampled and none had a fractional part.
pub is_integer: bool,
/// True when any sampled text value contains `$`, `€`, `£` or `¥`.
pub has_currency_symbols: bool,
/// Detected component order of date-like text values; `None` when the sample
/// contains no date-like text.
pub date_format: Option<DateFormat>,
/// True when the number of distinct sampled values is at most
/// `InferenceConfig::categorical_threshold`.
pub is_low_cardinality: bool,
}
/// Component order detected for date-like text values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum DateFormat {
/// Year first, `-`-separated (e.g. `2024-01-15`).
ISO8601,
/// Day before month (e.g. `15/01/2024`).
DMY,
/// Month before day (e.g. `01/15/2024`).
MDY,
/// Year first, `/`-separated (e.g. `2024/01/15`).
YMD,
/// Date-like values were found but their component order could not be determined.
Ambiguous,
}
/// Tuning knobs for [`infer_column_types`].
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Maximum number of leading rows examined per column.
    pub sample_size: usize,
    /// NOTE(review): not read by any of the inference code in this file.
    pub confidence_threshold: f64,
    /// A column with at most this many distinct sampled values is treated as
    /// low-cardinality (and low-cardinality whole numbers as ordinal).
    pub categorical_threshold: usize,
}

impl Default for InferenceConfig {
    /// 500-row sample, 0.9 confidence threshold, categorical cut-off of 100 values.
    fn default() -> Self {
        const DEFAULT_SAMPLE_SIZE: usize = 500;
        const DEFAULT_CONFIDENCE_THRESHOLD: f64 = 0.9;
        const DEFAULT_CATEGORICAL_THRESHOLD: usize = 100;
        InferenceConfig {
            sample_size: DEFAULT_SAMPLE_SIZE,
            confidence_threshold: DEFAULT_CONFIDENCE_THRESHOLD,
            categorical_threshold: DEFAULT_CATEGORICAL_THRESHOLD,
        }
    }
}
/// Infers storage type, semantic type and metadata for each of `columns` in `table`.
///
/// Results are returned in the same order as `columns`. Every column is analysed
/// independently with the same `config`.
pub fn infer_column_types(
    table: &DataTable,
    columns: &[String],
    config: InferenceConfig,
) -> Vec<InferredColumnType> {
    let mut results = Vec::with_capacity(columns.len());
    for column in columns {
        results.push(infer_column_type(table, column, &config));
    }
    results
}
/// Running tallies collected while scanning the sampled values of one column.
#[derive(Debug, Default)]
struct PatternStats {
    // Non-null values per variant.
    numeric_count: usize,
    text_count: usize,
    timestamp_count: usize,
    bool_count: usize,
    // Null values (excluded from `total`).
    null_count: usize,
    // Numeric values with no fractional part.
    integer_count: usize,
    // Text values shaped like "date time" / "dateTtime".
    datetime_pattern_count: usize,
    // Text values shaped like a bare date.
    date_pattern_count: usize,
    // Keys of all distinct non-null values seen.
    distinct_values: HashSet<String>,
    // Set once any text value contains a currency symbol.
    has_currency: bool,
}

impl PatternStats {
    /// Number of non-null values observed.
    fn total(&self) -> usize {
        [
            self.numeric_count,
            self.text_count,
            self.timestamp_count,
            self.bool_count,
        ]
        .iter()
        .sum()
    }

    /// Number of distinct non-null values observed.
    fn cardinality(&self) -> usize {
        self.distinct_values.len()
    }

    /// True when at least one number was seen and every number was whole.
    fn all_integers(&self) -> bool {
        if self.numeric_count == 0 {
            return false;
        }
        self.integer_count == self.numeric_count
    }

    /// True when values of two or more variants were observed.
    fn is_mixed(&self) -> bool {
        let per_kind = [
            self.numeric_count,
            self.text_count,
            self.timestamp_count,
            self.bool_count,
        ];
        per_kind.iter().filter(|&&count| count > 0).count() >= 2
    }
}
/// Infers the storage type, semantic type and metadata of one column.
///
/// Only the first `config.sample_size` rows are examined. Rows that do not contain
/// `col_name` at all are skipped rather than counted as nulls.
fn infer_column_type(
    table: &DataTable,
    col_name: &str,
    config: &InferenceConfig,
) -> InferredColumnType {
    let rows = sample_rows(table, config.sample_size);

    let mut stats = PatternStats::default();
    for value in rows.iter().filter_map(|row| row.get(col_name)) {
        analyze_value(value, &mut stats);
    }

    let storage_type = detect_storage_type(&stats);
    let cardinality = stats.cardinality();
    let metadata = TypeMetadata {
        is_integer: stats.all_integers(),
        has_currency_symbols: stats.has_currency,
        date_format: detect_date_format(&stats, &rows, col_name),
        is_low_cardinality: cardinality <= config.categorical_threshold,
    };

    InferredColumnType {
        semantic_type: detect_semantic_type(&stats, &storage_type, config),
        confidence: calculate_confidence(&stats, &storage_type),
        storage_type,
        cardinality,
        null_count: stats.null_count,
        sample_size: rows.len(),
        metadata,
    }
}
/// Returns at most `sample_size` rows from the front of `table`, in table order.
///
/// This is a head-of-table sample, not a random one.
fn sample_rows(table: &DataTable, sample_size: usize) -> Vec<&DataRow> {
    let leading_rows = table.rows().iter().take(sample_size);
    leading_rows.collect()
}
/// Folds one sampled value into `stats`.
///
/// Nulls only bump `null_count`. Every other value is tallied under its variant and its
/// [`value_to_key`] string is added to the distinct-value set. Text values are also
/// checked for datetime/date shapes (datetime takes precedence) and currency symbols.
fn analyze_value(value: &FieldValue, stats: &mut PatternStats) {
    match value {
        FieldValue::Null => {
            stats.null_count += 1;
            return;
        }
        FieldValue::Bool(_) => stats.bool_count += 1,
        FieldValue::Timestamp(_) => stats.timestamp_count += 1,
        FieldValue::Numeric(n) => {
            stats.numeric_count += 1;
            if n.fract() == 0.0 {
                stats.integer_count += 1;
            }
        }
        FieldValue::Text(s) => {
            stats.text_count += 1;
            if is_datetime_pattern(s) {
                stats.datetime_pattern_count += 1;
            } else if is_date_pattern(s) {
                stats.date_pattern_count += 1;
            }
            stats.has_currency |= has_currency_symbol(s);
        }
    }
    stats.distinct_values.insert(value_to_key(value));
}
/// Renders a value as the string used when counting distinct values.
///
/// NOTE(review): keys are not tagged by variant, so e.g. `Numeric(1.0)` and
/// `Text("1")` both map to `"1"`; mixed-type columns may report a lower cardinality.
fn value_to_key(value: &FieldValue) -> String {
    match value {
        FieldValue::Text(s) => s.to_owned(),
        FieldValue::Null => String::from("null"),
        FieldValue::Bool(b) => b.to_string(),
        FieldValue::Timestamp(ts) => ts.to_string(),
        FieldValue::Numeric(n) => n.to_string(),
    }
}
/// Numeric components of a date string, in the order they appear, plus the separator.
#[derive(Debug, Clone, Copy)]
struct DateParts {
    part1: u32,
    part2: u32,
    _part3: u32,
    separator: char,
}

/// Splits `date_str` (after trimming) into three unsigned integers.
///
/// The separator is `-` if it occurs anywhere in the string, otherwise `/`. Returns
/// `None` when neither separator occurs, when the split does not yield exactly three
/// pieces, or when any piece fails to parse as `u32`.
fn parse_date_parts(date_str: &str) -> Option<DateParts> {
    let trimmed = date_str.trim();
    let separator = ['-', '/']
        .iter()
        .copied()
        .find(|&sep| trimmed.contains(sep))?;
    let numbers = trimmed
        .split(separator)
        .map(|piece| piece.parse::<u32>().ok())
        .collect::<Option<Vec<u32>>>()?;
    match numbers.as_slice() {
        &[part1, part2, part3] => Some(DateParts {
            part1,
            part2,
            _part3: part3,
            separator,
        }),
        _ => None,
    }
}
/// Collects the date portion of date-like text values in `col_name`, in row order.
///
/// Bare dates are kept as-is (trimmed). For datetimes, the text before the `T`
/// separator is kept, or the text before the first space when there is no `T`.
/// Values are trimmed first, so surrounding whitespace cannot leave an empty date
/// portion. Collection stops after `MAX_DATE_SAMPLES` strings.
fn extract_date_samples(sample_rows: &[&DataRow], col_name: &str) -> Vec<String> {
    // Enough samples to vote on a format without scanning the whole sample.
    const MAX_DATE_SAMPLES: usize = 50;

    // Returns the date portion of a date or datetime string, or `None` otherwise.
    fn date_portion(raw: &str) -> Option<String> {
        let s = raw.trim();
        if is_date_pattern(s) {
            Some(s.to_string())
        } else if is_datetime_pattern(s) {
            let separator = if s.contains('T') { 'T' } else { ' ' };
            s.split(separator).next().map(str::to_string)
        } else {
            None
        }
    }

    sample_rows
        .iter()
        .filter_map(|row| match row.get(col_name) {
            Some(FieldValue::Text(s)) => date_portion(s),
            _ => None,
        })
        .take(MAX_DATE_SAMPLES)
        .collect()
}
/// Guesses the component order of the date-like text values in `col_name`.
///
/// Returns `None` when `stats` recorded no date or datetime text. Otherwise the date
/// portions gathered by `extract_date_samples` (a capped number of them) are parsed
/// into numeric components and classified. The first matching rule wins:
///
/// 1. no sample parses → [`DateFormat::Ambiguous`];
/// 2. more than 80% year-first (first component >= 1000) with `-` → [`DateFormat::ISO8601`];
/// 3. more than 80% year-first with `/` → [`DateFormat::YMD`];
/// 4. among non-year-first dates, a first component > 12 votes DMY; otherwise a second
///    component > 12 votes MDY. A side with at least one vote and at least three times
///    the other side's votes wins → [`DateFormat::DMY`] / [`DateFormat::MDY`];
/// 5. no votes at all → [`DateFormat::Ambiguous`];
/// 6. conflicting votes with more than 70% `/` separators → [`DateFormat::DMY`];
/// 7. otherwise → [`DateFormat::Ambiguous`].
fn detect_date_format(
stats: &PatternStats,
sample_rows: &[&DataRow],
col_name: &str,
) -> Option<DateFormat> {
let date_count = stats.date_pattern_count + stats.datetime_pattern_count;
if date_count == 0 {
return None;
}
let date_samples = extract_date_samples(sample_rows, col_name);
if date_samples.is_empty() {
return None;
}
// Samples that do not split into three numeric components are dropped here.
let parsed: Vec<DateParts> = date_samples
.iter()
.filter_map(|s| parse_date_parts(s))
.collect();
if parsed.is_empty() {
return Some(DateFormat::Ambiguous);
}
// All ratios below are taken over every parsed sample, whatever its separator.
let total = parsed.len();
let iso8601_count = parsed
.iter()
.filter(|p| p.part1 >= 1000 && p.separator == '-')
.count();
if (iso8601_count as f64 / total as f64) > 0.8 {
return Some(DateFormat::ISO8601);
}
let ymd_count = parsed
.iter()
.filter(|p| p.part1 >= 1000 && p.separator == '/')
.count();
if (ymd_count as f64 / total as f64) > 0.8 {
return Some(DateFormat::YMD);
}
// Vote on day/month order using components that cannot be a month (> 12).
let mut dmy_votes = 0;
let mut mdy_votes = 0;
for parts in &parsed {
// Year-first dates carry no day/month ordering evidence for this vote.
if parts.part1 >= 1000 {
continue;
}
// A first component above 12 cannot be a month, so the day comes first.
if parts.part1 > 12 {
dmy_votes += 1;
}
// Otherwise a second component above 12 means the month comes first.
else if parts.part2 > 12 {
mdy_votes += 1;
}
}
if dmy_votes >= 3 * mdy_votes && dmy_votes > 0 {
return Some(DateFormat::DMY);
}
if mdy_votes >= 3 * dmy_votes && mdy_votes > 0 {
return Some(DateFormat::MDY);
}
// Every day and month component was <= 12, so the order cannot be told apart.
if dmy_votes == 0 && mdy_votes == 0 {
return Some(DateFormat::Ambiguous);
}
// The evidence points both ways. NOTE(review): this settles the conflict as DMY
// whenever `/` separators dominate; confirm that default is intended.
let slash_count = parsed.iter().filter(|p| p.separator == '/').count();
if (slash_count as f64 / total as f64) > 0.7 {
return Some(DateFormat::DMY);
}
Some(DateFormat::Ambiguous)
}
/// Chooses the physical storage type for a column from its sample statistics.
///
/// Nulls are ignored. A kind wins when it accounts for more than 60% of the sampled
/// non-null values. Kinds are checked in the order Timestamp, Numeric, Boolean. Text is
/// the fallback, including for columns with no non-null values at all.
///
/// For the Timestamp check, native timestamps and text that looks like a date or
/// datetime are pooled and measured against *all* non-null values. Measuring date-like
/// text against text values alone would let a couple of date strings in an otherwise
/// numeric column claim the whole column as Timestamp.
fn detect_storage_type(stats: &PatternStats) -> FieldValueType {
    const THRESHOLD: f64 = 0.6;

    let total = stats.total();
    if total == 0 {
        return FieldValueType::Text;
    }
    let share = |count: usize| count as f64 / total as f64;

    let temporal_count =
        stats.timestamp_count + stats.datetime_pattern_count + stats.date_pattern_count;
    if share(temporal_count) > THRESHOLD {
        FieldValueType::Timestamp
    } else if share(stats.numeric_count) > THRESHOLD {
        FieldValueType::Numeric
    } else if share(stats.bool_count) > THRESHOLD {
        FieldValueType::Boolean
    } else {
        FieldValueType::Text
    }
}
/// Maps a storage type (plus sample statistics) onto a semantic [`DataType`].
///
/// - `Timestamp` → `Temporal`.
/// - `Numeric` → `Ordinal` when every sampled number is whole and there are at most
///   `config.categorical_threshold` distinct values (e.g. zip codes); otherwise
///   `Quantitative`.
/// - `Text` and `Boolean` → `Nominal`, regardless of cardinality.
fn detect_semantic_type(
    stats: &PatternStats,
    storage: &FieldValueType,
    config: &InferenceConfig,
) -> DataType {
    match storage {
        FieldValueType::Timestamp => DataType::Temporal,
        FieldValueType::Numeric => {
            let few_whole_values =
                stats.all_integers() && stats.cardinality() <= config.categorical_threshold;
            if few_whole_values {
                DataType::Ordinal
            } else {
                DataType::Quantitative
            }
        }
        // Text is nominal at any cardinality, so there is no cardinality split here.
        FieldValueType::Text | FieldValueType::Boolean => DataType::Nominal,
    }
}
/// Estimates how strongly the sample supports `inferred_storage`.
///
/// Returns the share of sampled non-null values consistent with the inferred storage
/// type. The share is scaled by 0.8 when more than one kind of value was observed.
/// Returns 0.0 when the sample held no non-null values.
fn calculate_confidence(stats: &PatternStats, inferred_storage: &FieldValueType) -> f64 {
    // Discount applied when the column mixes several kinds of value.
    const MIXED_TYPE_PENALTY: f64 = 0.8;

    let total = stats.total();
    if total == 0 {
        return 0.0;
    }
    let matching_count = match inferred_storage {
        FieldValueType::Numeric => stats.numeric_count,
        FieldValueType::Text => stats.text_count,
        // Timestamp storage may be backed by native timestamps or by text shaped like
        // dates/datetimes (`detect_storage_type` accepts both). Counting only native
        // timestamps would give text-date columns a confidence of 0.
        FieldValueType::Timestamp => {
            stats.timestamp_count + stats.datetime_pattern_count + stats.date_pattern_count
        }
        FieldValueType::Boolean => stats.bool_count,
    };
    let match_rate = matching_count as f64 / total as f64;
    if stats.is_mixed() {
        match_rate * MIXED_TYPE_PENALTY
    } else {
        match_rate
    }
}
/// Returns `true` when `s` (trimmed) looks like a date followed by a clock time,
/// joined by `T` or by a single space.
///
/// URLs (`://`, `www.`) are rejected. So is anything containing `x` or `(` unless it
/// also contains `T`, which filters out phone numbers with extensions. Both a date
/// separator (`-` or `/`) and a `:` must be present.
fn is_datetime_pattern(s: &str) -> bool {
    let text = s.trim();
    let has_t = text.contains('T');

    let is_url = text.contains("://") || text.contains("www.");
    let looks_like_phone = !has_t && (text.contains('x') || text.contains('('));
    if is_url || looks_like_phone {
        return false;
    }

    let has_date_sep = text.contains('-') || text.contains('/');
    if !(has_date_sep && text.contains(':')) {
        return false;
    }

    let joiner = if has_t { 'T' } else { ' ' };
    let mut halves = text.split(joiner);
    match (halves.next(), halves.next(), halves.next()) {
        (Some(date), Some(time), None) => {
            is_date_like_structure(date) && is_time_like_structure(time)
        }
        _ => false,
    }
}
/// Returns `true` when `s` (trimmed) looks like a bare date such as `2024-01-15` or
/// `15/01/2024`.
///
/// Quick rejections come first: URL markers, phone-number characters (`x`, `(`, `)`),
/// any `:` (a time component), and strings without a `-` or `/` separator. The
/// remaining candidates must pass `is_date_like_structure`.
fn is_date_pattern(s: &str) -> bool {
    const URL_MARKERS: [&str; 4] = ["://", "www.", ".com", ".org"];

    let candidate = s.trim();
    let looks_like_url = URL_MARKERS.iter().any(|marker| candidate.contains(marker));
    let looks_like_phone = candidate.chars().any(|c| matches!(c, 'x' | '(' | ')'));
    let has_time = candidate.contains(':');
    let has_separator = candidate.chars().any(|c| c == '-' || c == '/');

    !looks_like_url
        && !looks_like_phone
        && !has_time
        && has_separator
        && is_date_like_structure(candidate)
}
/// Returns `true` when `s` is three all-digit fields joined by one separator whose
/// widths match a 4-digit-year layout.
///
/// The separator is `-` if present, otherwise `/`. Accepted width patterns are
/// year-first (4-2-2, 4-1-1, 4-2-1, 4-1-2) and year-last (2-2-4, 1-1-4, 2-1-4, 1-2-4).
fn is_date_like_structure(s: &str) -> bool {
    let separator = match (s.contains('-'), s.contains('/')) {
        (true, _) => '-',
        (false, true) => '/',
        (false, false) => return false,
    };

    let fields: Vec<&str> = s.split(separator).collect();
    let (first, second, third) = match fields.as_slice() {
        [a, b, c] => (*a, *b, *c),
        _ => return false,
    };

    let all_digits = [first, second, third]
        .iter()
        .all(|field| field.bytes().all(|byte| byte.is_ascii_digit()));

    all_digits
        && matches!(
            (first.len(), second.len(), third.len()),
            (4, 2, 2)
                | (2, 2, 4)
                | (4, 1, 1)
                | (1, 1, 4)
                | (2, 1, 4)
                | (1, 2, 4)
                | (4, 2, 1)
                | (4, 1, 2)
        )
}
/// Returns `true` when `s` looks like a clock time.
///
/// Accepted shapes are `h:mm` or `hh:mm:ss`, where each field has one or two digits.
/// Seconds may carry a non-empty run of fractional digits (`10:30:00.123`). The whole
/// time may end with an ISO 8601 zone designator: `Z`, or a `+`/`-` offset written as
/// `hh`, `hhmm` or `hh:mm` (e.g. `10:30:00Z`, `10:30:00-05:00`).
fn is_time_like_structure(s: &str) -> bool {
    // One or two ASCII digits, e.g. `7` or `07`.
    fn is_clock_field(field: &str) -> bool {
        (1..=2).contains(&field.len()) && field.bytes().all(|b| b.is_ascii_digit())
    }

    // Seconds, optionally followed by `.` and one or more fractional digits.
    fn is_seconds_field(field: &str) -> bool {
        match field.split_once('.') {
            Some((whole, fraction)) => {
                is_clock_field(whole)
                    && !fraction.is_empty()
                    && fraction.bytes().all(|b| b.is_ascii_digit())
            }
            None => is_clock_field(field),
        }
    }

    // A numeric UTC offset with its sign removed: `hh`, `hhmm` or `hh:mm`.
    fn is_utc_offset(offset: &str) -> bool {
        let digits =
            |part: &str, width: usize| part.len() == width && part.bytes().all(|b| b.is_ascii_digit());
        match offset.split_once(':') {
            Some((hours, minutes)) => digits(hours, 2) && digits(minutes, 2),
            None => digits(offset, 2) || digits(offset, 4),
        }
    }

    // Strip an optional zone designator; `+` and `-` are single-byte, so slicing at the
    // sign position is always on a char boundary.
    let clock = if let Some(rest) = s.strip_suffix('Z') {
        rest
    } else if let Some(sign_at) = s.find(|c: char| c == '+' || c == '-') {
        if !is_utc_offset(&s[sign_at + 1..]) {
            return false;
        }
        &s[..sign_at]
    } else {
        s
    };

    let fields: Vec<&str> = clock.split(':').collect();
    match fields.as_slice() {
        [hours, minutes] => is_clock_field(hours) && is_clock_field(minutes),
        [hours, minutes, seconds] => {
            is_clock_field(hours) && is_clock_field(minutes) && is_seconds_field(seconds)
        }
        _ => false,
    }
}
/// Returns `true` when `s` contains any of the currency symbols `$`, `€`, `£` or `¥`.
fn has_currency_symbol(s: &str) -> bool {
    const CURRENCY_SYMBOLS: [char; 4] = ['$', '€', '£', '¥'];
    s.chars().any(|c| CURRENCY_SYMBOLS.contains(&c))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a table with one row per value, each holding that value under `column`.
    fn single_column_table(column: &str, values: Vec<FieldValue>) -> DataTable {
        let rows = values
            .into_iter()
            .map(|value| HashMap::from([(column.to_string(), value)]))
            .collect::<Vec<HashMap<String, FieldValue>>>();
        DataTable::new(rows)
    }

    /// Infers `column` with the default config, asserting exactly one result comes back.
    fn infer_single(column: &str, values: Vec<FieldValue>) -> InferredColumnType {
        let table = single_column_table(column, values);
        let columns = [column.to_string()];
        let mut inferred = infer_column_types(&table, &columns, InferenceConfig::default());
        assert_eq!(inferred.len(), 1);
        inferred.remove(0)
    }

    /// Shorthand for a text field value.
    fn text(s: &str) -> FieldValue {
        FieldValue::Text(s.to_string())
    }

    #[test]
    fn test_infer_quantitative() {
        let col = infer_single(
            "revenue",
            vec![
                FieldValue::Numeric(1250.50),
                FieldValue::Numeric(980.00),
                FieldValue::Numeric(1450.75),
            ],
        );
        assert_eq!(col.storage_type, FieldValueType::Numeric);
        assert_eq!(col.semantic_type, DataType::Quantitative);
        assert!(col.confidence > 0.9);
    }

    #[test]
    fn test_infer_nominal_low_cardinality() {
        let col = infer_single("region", vec![text("North"), text("South"), text("North")]);
        assert_eq!(col.storage_type, FieldValueType::Text);
        assert_eq!(col.semantic_type, DataType::Nominal);
        assert_eq!(col.cardinality, 2);
        assert!(col.metadata.is_low_cardinality);
    }

    #[test]
    fn test_infer_ordinal_zip_codes() {
        let col = infer_single(
            "zip",
            vec![
                FieldValue::Numeric(10001.0),
                FieldValue::Numeric(10002.0),
                FieldValue::Numeric(90210.0),
            ],
        );
        assert_eq!(col.storage_type, FieldValueType::Numeric);
        assert_eq!(col.semantic_type, DataType::Ordinal);
        assert!(col.metadata.is_integer);
    }

    #[test]
    fn test_infer_mixed_types_low_confidence() {
        let col = infer_single(
            "value",
            vec![
                FieldValue::Numeric(100.0),
                text("abc"),
                FieldValue::Numeric(200.0),
                text("xyz"),
            ],
        );
        assert_eq!(col.storage_type, FieldValueType::Text);
        assert!(col.confidence < 0.7);
    }

    #[test]
    fn test_pattern_detection_datetime() {
        assert!(is_datetime_pattern("2024-01-15 10:30:00"));
        assert!(is_datetime_pattern("2024/01/15 10:30:00"));
        assert!(is_datetime_pattern("2024-01-15T10:30:00"));
        assert!(!is_datetime_pattern("2024-01-15"));
        assert!(!is_datetime_pattern("10:30:00"));
        assert!(!is_datetime_pattern("http://www.example.com/"));
        assert!(!is_datetime_pattern("https://example.com:8080/path"));
        assert!(!is_datetime_pattern("846-790-4623x4715"));
    }

    #[test]
    fn test_pattern_detection_date() {
        assert!(is_date_pattern("2024-01-15"));
        assert!(is_date_pattern("01/15/2024"));
        assert!(is_date_pattern("15-01-2024"));
        assert!(is_date_pattern("2024/1/5"));
        assert!(!is_date_pattern("2024-01-15 10:30:00"));
        assert!(!is_date_pattern("revenue"));
        assert!(!is_date_pattern("20240115"));
        assert!(!is_date_pattern("http://www.shea.biz/"));
        assert!(!is_date_pattern("www.example.com/path"));
        assert!(!is_date_pattern("example.com"));
        assert!(!is_date_pattern("846-790-4623x4715"));
        assert!(!is_date_pattern("(335)987-3085x3780"));
        assert!(!is_date_pattern("124-597-8652"));
        assert!(!is_date_pattern("2024-01"));
        assert!(!is_date_pattern("2024-01-15-10"));
    }

    #[test]
    fn test_date_vs_phone_distinction() {
        assert!(!is_date_pattern("846-790-4623x4715"));
        assert!(!is_date_pattern("(335)987-3085x3780"));
        assert!(!is_date_pattern("555-1234"));
        assert!(is_date_pattern("2024-01-15"));
        assert!(is_date_pattern("01-15-2024"));
    }

    #[test]
    fn test_date_vs_url_distinction() {
        assert!(!is_date_pattern("http://www.shea.biz/"));
        assert!(!is_date_pattern("https://example.com/path"));
        assert!(!is_date_pattern("www.example.com"));
        assert!(!is_date_pattern("example.com"));
        assert!(is_date_pattern("2024/01/15"));
        assert!(is_date_pattern("01/15/2024"));
    }

    #[test]
    fn test_currency_detection() {
        assert!(has_currency_symbol("$100.00"));
        assert!(has_currency_symbol("€50"));
        assert!(has_currency_symbol("£25.99"));
        assert!(!has_currency_symbol("100"));
    }

    #[test]
    fn test_date_format_iso8601() {
        let col = infer_single(
            "date",
            vec![text("2024-01-15"), text("2024-03-20"), text("2024-12-31")],
        );
        assert_eq!(col.metadata.date_format, Some(DateFormat::ISO8601));
    }

    #[test]
    fn test_date_format_ymd() {
        let col = infer_single(
            "date",
            vec![text("2024/01/15"), text("2024/03/20"), text("2024/12/31")],
        );
        assert_eq!(col.metadata.date_format, Some(DateFormat::YMD));
    }

    #[test]
    fn test_date_format_dmy_unambiguous() {
        let col = infer_single(
            "date",
            vec![text("15/01/2024"), text("25/03/2024"), text("31/12/2024")],
        );
        assert_eq!(col.metadata.date_format, Some(DateFormat::DMY));
    }

    #[test]
    fn test_date_format_mdy_unambiguous() {
        let col = infer_single(
            "date",
            vec![text("01/15/2024"), text("03/25/2024"), text("12/31/2024")],
        );
        assert_eq!(col.metadata.date_format, Some(DateFormat::MDY));
    }

    #[test]
    fn test_date_format_ambiguous() {
        let col = infer_single(
            "date",
            vec![text("01/02/2024"), text("03/04/2024"), text("05/06/2024")],
        );
        assert_eq!(col.metadata.date_format, Some(DateFormat::Ambiguous));
    }

    #[test]
    fn test_date_format_datetime_strips_time() {
        let col = infer_single(
            "datetime",
            vec![text("2024-01-15 10:30:00"), text("2024-03-20T14:45:00")],
        );
        assert_eq!(col.metadata.date_format, Some(DateFormat::ISO8601));
    }

    #[test]
    fn test_date_format_non_temporal_column() {
        let col = infer_single(
            "revenue",
            vec![FieldValue::Numeric(1250.50), FieldValue::Numeric(980.00)],
        );
        assert_eq!(col.metadata.date_format, None);
    }
}