use std::{
collections::HashSet,
io::{Read, Seek},
};
use super::{ByteRecord, Reader};
use crate::datatypes::{DataType, Field, Schema, TimeUnit};
use crate::error::Result;
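/// The [`chrono::format::strftime`] pattern used to detect RFC 3339 date-times with an
/// explicit UTC offset, e.g. `2021-01-01T12:00:00.123+02:00`.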
pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";
fn is_boolean(bytes: &[u8]) -> bool {
    bytes.eq_ignore_ascii_case(b"true") || bytes.eq_ignore_ascii_case(b"false")
}
fn is_float(bytes: &[u8]) -> bool {
lexical_core::parse::<f64>(bytes).is_ok()
}
fn is_integer(bytes: &[u8]) -> bool {
lexical_core::parse::<i64>(bytes).is_ok()
}
fn is_date(string: &str) -> bool {
string.parse::<chrono::NaiveDate>().is_ok()
}
fn is_time(string: &str) -> bool {
string.parse::<chrono::NaiveTime>().is_ok()
}
fn is_naive_datetime(string: &str) -> bool {
string.parse::<chrono::NaiveDateTime>().is_ok()
}
fn is_datetime(string: &str) -> Option<String> {
    let mut parsed = chrono::format::Parsed::new();
    let fmt = chrono::format::StrftimeItems::new(RFC3339);
    if chrono::format::parse(&mut parsed, string, fmt).is_ok() {
        // A successful parse yields the UTC offset in seconds; render it as a signed
        // "+HH:MM"/"-HH:MM" string usable as the timestamp's timezone.
        parsed.offset.map(|offset| {
            let sign = if offset < 0 { '-' } else { '+' };
            let offset = offset.abs();
            format!("{}{:02}:{:02}", sign, offset / 3600, (offset / 60) % 60)
        })
    } else {
        None
    }
}
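/// Infers the [`DataType`] of a single CSV field given as raw bytes.
///
/// Candidates are tried from most to least specific: boolean, integer, float, then
/// (for valid UTF-8) date, time, naive datetime, datetime with a UTC offset, and plain
/// text; bytes that are not valid UTF-8 are reported as `Binary`.
///
/// Illustrative only (not compiled here; assumes `infer` and `DataType` are in scope):
///
/// ```ignore
/// assert_eq!(infer(b"true"), DataType::Boolean);
/// assert_eq!(infer(b"10"), DataType::Int64);
/// assert_eq!(infer(b"10.5"), DataType::Float64);
/// assert_eq!(infer(b"2021-01-01"), DataType::Date32);
/// assert_eq!(infer(b"hello"), DataType::Utf8);
/// ```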
pub fn infer(bytes: &[u8]) -> DataType {
if is_boolean(bytes) {
DataType::Boolean
} else if is_integer(bytes) {
DataType::Int64
} else if is_float(bytes) {
DataType::Float64
} else if let Ok(string) = simdutf8::basic::from_utf8(bytes) {
if is_date(string) {
DataType::Date32
} else if is_time(string) {
DataType::Time32(TimeUnit::Millisecond)
} else if is_naive_datetime(string) {
DataType::Timestamp(TimeUnit::Millisecond, None)
} else if let Some(offset) = is_datetime(string) {
DataType::Timestamp(TimeUnit::Millisecond, Some(offset))
} else {
DataType::Utf8
}
} else {
DataType::Binary
}
}
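/// Infers a [`Schema`] from a CSV `reader` by applying `infer` to every field of up to
/// `max_rows` records (all remaining records when `None`).
///
/// When `has_header` is `false`, columns are named `column_1`, `column_2`, ...
/// Columns observed as both `Int64` and `Float64` are widened to `Float64`; any other
/// mix of types falls back to `Utf8`. All fields are created as nullable, and the
/// reader is seeked back to the position it had before the records were scanned.
///
/// Sketch of a call site (not compiled here; assumes `Reader` is the `csv` crate's
/// reader and that `infer_schema`/`infer` are in scope):
///
/// ```ignore
/// let data = "a,b\n1,x\n2.5,y\n";
/// let mut reader = csv::ReaderBuilder::new().from_reader(std::io::Cursor::new(data));
/// let schema = infer_schema(&mut reader, None, true, &infer).unwrap();
/// // column "a" unifies Int64 + Float64 into Float64; column "b" stays Utf8
/// ```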
pub fn infer_schema<R: Read + Seek, F: Fn(&[u8]) -> DataType>(
reader: &mut Reader<R>,
max_rows: Option<usize>,
has_header: bool,
infer: &F,
) -> Result<Schema> {
    let headers: Vec<String> = if has_header {
        reader.headers()?.iter().map(|s| s.to_string()).collect()
    } else {
        // Without a header row, synthesize "column_1", "column_2", ... names from the
        // number of fields in the first record.
        let first_record_count = reader.headers()?.len();
        (0..first_record_count)
            .map(|i| format!("column_{}", i + 1))
            .collect()
    };
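    // Remember where the reader is before scanning records so it can be rewound later.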
let position = reader.position().clone();
let header_length = headers.len();
let mut column_types: Vec<HashSet<DataType>> = vec![HashSet::new(); header_length];
let mut records_count = 0;
let mut record = ByteRecord::new();
let max_records = max_rows.unwrap_or(usize::MAX);
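    // Read up to `max_records` records, collecting the set of types inferred per column.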
while records_count < max_records {
if !reader.read_byte_record(&mut record)? {
break;
}
records_count += 1;
for (i, column) in column_types.iter_mut().enumerate() {
if let Some(string) = record.get(i) {
column.insert(infer(string));
}
}
}
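    // Unify the types seen in each column: a single type wins, Int64 + Float64 widens
    // to Float64, and any other mix falls back to Utf8.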
let fields = headers
.iter()
.zip(column_types.into_iter())
.map(|(field_name, mut possibilities)| {
let data_type = match possibilities.len() {
1 => possibilities.drain().next().unwrap(),
2 => {
if possibilities.contains(&DataType::Int64)
&& possibilities.contains(&DataType::Float64)
{
DataType::Float64
} else {
DataType::Utf8
}
}
_ => DataType::Utf8,
};
Field::new(field_name, data_type, true)
})
.collect();
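    // Rewind the reader so the records consumed during inference can be read again.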
reader.seek(position)?;
Ok(Schema::new(fields))
}