use crate::naming;
use crate::plot::scale::coerce_dtypes;
use crate::plot::{CastTargetType, Plot};
use crate::reader::SqlDialect;
use arrow::datatypes::{DataType, TimeUnit};
use std::collections::HashMap;
use super::schema::TypeInfo;
/// A single column cast that must be applied in SQL for a layer.
///
/// Instances are produced by `determine_type_requirements` and consumed by
/// `update_type_info_for_casting`.
#[derive(Debug, Clone)]
pub struct TypeRequirement {
/// Name of the column that needs to be cast.
pub column: String,
/// Logical type the column should be cast to.
pub target_type: CastTargetType,
/// Type name for `target_type` as reported by the SQL dialect
/// (the value returned by `SqlDialect::type_name_for`).
pub sql_type_name: String,
}
pub fn determine_type_requirements(
spec: &Plot,
layer_type_info: &[Vec<TypeInfo>],
dialect: &dyn SqlDialect,
) -> Vec<Vec<TypeRequirement>> {
use crate::plot::scale::TransformKind;
let mut layer_requirements: Vec<Vec<TypeRequirement>> = Vec::new();
for (layer_idx, layer) in spec.layers.iter().enumerate() {
let mut requirements: Vec<TypeRequirement> = Vec::new();
let type_info = &layer_type_info[layer_idx];
let column_dtypes: HashMap<&str, &DataType> = type_info
.iter()
.map(|(name, dtype, _)| (name.as_str(), dtype))
.collect();
for (aesthetic, value) in &layer.mappings.aesthetics {
let col_name = match value.column_name() {
Some(name) => name,
None => continue, };
if naming::is_synthetic_column(col_name) {
continue;
}
let col_dtype = match column_dtypes.get(col_name) {
Some(dtype) => *dtype,
None => continue, };
let scale = match spec.scales.iter().find(|s| s.aesthetic == *aesthetic) {
Some(s) => s,
None => continue, };
let scale_type = match &scale.scale_type {
Some(st) => st,
None => continue, };
let all_dtypes: Vec<DataType> = layer_type_info
.iter()
.zip(spec.layers.iter())
.filter_map(|(info, l)| {
l.mappings
.get(aesthetic)
.and_then(|v| v.column_name())
.and_then(|name| info.iter().find(|(n, _, _)| n == name))
.map(|(_, dtype, _)| dtype.clone())
})
.collect();
let target_dtype = match coerce_dtypes(&all_dtypes) {
Ok(dt) => dt,
Err(_) => continue, };
if let Some(cast_target) = scale_type.required_cast_type(col_dtype, &target_dtype) {
if let Some(sql_type) = dialect.type_name_for(cast_target) {
if !requirements.iter().any(|r| r.column == col_name) {
requirements.push(TypeRequirement {
column: col_name.to_string(),
target_type: cast_target,
sql_type_name: sql_type.to_string(),
});
}
}
}
if let Some(ref transform) = scale.transform {
if transform.transform_kind() == TransformKind::Integer {
let needs_int_cast = match col_dtype {
DataType::Float32 | DataType::Float64 => true,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => false,
_ => false,
};
if needs_int_cast {
if let Some(sql_type) = dialect.type_name_for(CastTargetType::Integer) {
if !requirements.iter().any(|r| r.column == col_name) {
requirements.push(TypeRequirement {
column: col_name.to_string(),
target_type: CastTargetType::Integer,
sql_type_name: sql_type.to_string(),
});
}
}
}
}
}
}
layer_requirements.push(requirements);
}
layer_requirements
}
/// Rewrites `type_info` in place so it describes the columns after casting.
///
/// For every requirement, the first entry whose name equals the requirement's
/// column gets its dtype replaced by the Arrow type corresponding to the cast
/// target. The entry's boolean flag is then set to `true` only when the new
/// dtype is `Utf8` or `Boolean`. Requirements naming a column that is not
/// present in `type_info` are ignored.
pub fn update_type_info_for_casting(type_info: &mut [TypeInfo], requirements: &[TypeRequirement]) {
    for requirement in requirements {
        let slot = match type_info
            .iter_mut()
            .find(|info| info.0 == requirement.column)
        {
            Some(found) => found,
            None => continue,
        };

        // Arrow dtype the column will have once the SQL cast has been applied.
        let cast_dtype = match requirement.target_type {
            CastTargetType::Number => DataType::Float64,
            CastTargetType::Integer => DataType::Int64,
            CastTargetType::Date => DataType::Date32,
            CastTargetType::DateTime => DataType::Timestamp(TimeUnit::Microsecond, None),
            CastTargetType::Time => DataType::Time64(TimeUnit::Nanosecond),
            CastTargetType::String => DataType::Utf8,
            CastTargetType::Boolean => DataType::Boolean,
        };

        slot.2 = matches!(cast_dtype, DataType::Utf8 | DataType::Boolean);
        slot.1 = cast_dtype;
    }
}