use std::{cmp::min, convert::TryInto, num::NonZeroUsize};
use anyhow::Error;
use log::{debug, info};
use odbc_api::{
buffers::{AnySlice, BufferDesc},
sys::SqlDataType,
DataType, Nullability, ResultSetMetadata,
};
use parquet::{
basic::{LogicalType, Repetition},
column::writer::ColumnWriter,
data_type::{
ByteArrayType, DoubleType, FixedLenByteArrayType, FloatType, Int32Type, Int64Type,
},
schema::types::Type,
};
use crate::{
parquet_buffer::ParquetBuffer,
query::{
binary::Binary,
boolean::Boolean,
date::Date,
decimal::decimal_fetch_strategy,
identical::{fetch_identical, fetch_identical_with_logical_type},
text::text_strategy,
time::time_from_text,
timestamp::timestamp_without_tz,
timestamp_tz::timestamp_tz,
},
};
/// Describes how a single result set column is moved from an ODBC buffer into a Parquet column.
///
/// One boxed strategy object is produced per column by [`strategy_from_column_description`].
pub trait ColumnStrategy {
    /// Description of the ODBC buffer the column is bound to while fetching.
    fn buffer_desc(&self) -> BufferDesc;

    /// Parquet schema node for this column, using `name` as the field name.
    fn parquet_type(&self, name: &str) -> Type;

    /// Transfers the values visible through `column_view` into `column_writer`.
    ///
    /// `parquet_buffer` is handed in by the caller — presumably as reusable intermediate
    /// storage for values that need conversion before being written (TODO confirm).
    ///
    /// # Errors
    ///
    /// Implementations report any failure while converting or writing the values.
    fn copy_odbc_to_parquet(
        &self,
        parquet_buffer: &mut ParquetBuffer,
        column_writer: &mut ColumnWriter,
        column_view: AnySlice,
    ) -> Result<(), Error>;
}
/// User and driver dependent settings which influence how relational column types are mapped
/// onto Parquet column strategies.
#[derive(Clone, Copy, Debug)]
pub struct MappingOptions<'a> {
    /// Name of the database management system. Vendor specific types are only interpreted if
    /// this equals `"Microsoft SQL Server"`.
    pub db_name: &'a str,
    /// Size and fetch character columns as UTF-16 rather than UTF-8.
    pub use_utf16: bool,
    /// Map fixed length `BINARY` columns to variable length byte arrays instead of fixed
    /// length byte arrays.
    pub prefer_varbinary: bool,
    /// Forwarded to the decimal fetch strategy; presumably avoids emitting the Parquet
    /// `DECIMAL` logical type (TODO confirm).
    pub avoid_decimal: bool,
    /// Forwarded to the decimal fetch strategy; presumably states whether the driver can
    /// fetch 64-bit integers (TODO confirm).
    pub driver_does_support_i64: bool,
    /// Upper bound for buffer lengths of text, binary and unknown columns. Also used as the
    /// length if the driver does not report one.
    pub column_length_limit: usize,
}
pub fn strategy_from_column_description(
name: &str,
data_type: DataType,
nullability: Nullability,
mapping_options: MappingOptions,
cursor: &mut impl ResultSetMetadata,
index: i16,
) -> Result<Box<dyn ColumnStrategy>, Error> {
let MappingOptions {
db_name,
use_utf16,
prefer_varbinary,
avoid_decimal,
driver_does_support_i64,
column_length_limit,
} = mapping_options;
let is_optional = nullability.could_be_nullable();
let repetition = if is_optional {
Repetition::OPTIONAL
} else {
Repetition::REQUIRED
};
let apply_length_limit = |reported_length: Option<NonZeroUsize>| {
min(
reported_length
.map(NonZeroUsize::get)
.unwrap_or(column_length_limit),
column_length_limit,
)
};
let strategy: Box<dyn ColumnStrategy> = match data_type {
DataType::Float { precision: 0..=24 } | DataType::Real => {
fetch_identical::<FloatType>(is_optional)
}
DataType::Float { precision: _ } => fetch_identical::<DoubleType>(is_optional),
DataType::Double => fetch_identical::<DoubleType>(is_optional),
DataType::SmallInt => fetch_identical_with_logical_type::<Int32Type>(
is_optional,
LogicalType::Integer {
bit_width: 16,
is_signed: true,
},
),
DataType::Integer => fetch_identical_with_logical_type::<Int32Type>(
is_optional,
LogicalType::Integer {
bit_width: 32,
is_signed: true,
},
),
DataType::Date => Box::new(Date::new(repetition)),
DataType::Numeric { scale, precision } | DataType::Decimal { scale, precision } => {
decimal_fetch_strategy(
is_optional,
scale as i32,
precision.try_into().unwrap(),
avoid_decimal,
driver_does_support_i64,
)
}
DataType::Timestamp { precision } => {
timestamp_without_tz(repetition, precision.try_into().unwrap())
}
DataType::BigInt => fetch_identical::<Int64Type>(is_optional),
DataType::Bit => Box::new(Boolean::new(repetition)),
DataType::TinyInt => {
let is_signed = !cursor.column_is_unsigned(index.try_into().unwrap())?;
fetch_identical_with_logical_type::<Int32Type>(
is_optional,
LogicalType::Integer {
bit_width: 8,
is_signed,
},
)
}
DataType::Binary { length } => {
let length = apply_length_limit(length);
if prefer_varbinary {
Box::new(Binary::<ByteArrayType>::new(repetition, length))
} else {
Box::new(Binary::<FixedLenByteArrayType>::new(repetition, length))
}
}
DataType::Varbinary { length } | DataType::LongVarbinary { length } => {
let length = apply_length_limit(length);
Box::new(Binary::<ByteArrayType>::new(repetition, length))
}
dt @ (DataType::Char { length: _ }
| DataType::Varchar { length: _ }
| DataType::WVarchar { length: _ }
| DataType::WLongVarchar { length: _ }
| DataType::LongVarchar { length: _ }
| DataType::WChar { length: _ }) => {
let len_in_chars = if use_utf16 {
dt.utf16_len()
} else {
dt.utf8_len()
};
let length = apply_length_limit(len_in_chars);
text_strategy(use_utf16, repetition, length)
}
DataType::Other {
data_type: SqlDataType(-154),
column_size: _,
decimal_digits: precision,
} => {
if db_name == "Microsoft SQL Server" {
time_from_text(repetition, precision.try_into().unwrap())
} else {
unknown_non_char_type(&data_type, cursor, index, repetition, apply_length_limit)?
}
}
DataType::Other {
data_type: SqlDataType(-155),
column_size: _,
decimal_digits: precision,
} => {
if db_name == "Microsoft SQL Server" {
info!(
"Detected Timestamp type with time zone. Applying instant semantics for \
column {name}."
);
timestamp_tz(precision.try_into().unwrap(), repetition)?
} else {
unknown_non_char_type(&data_type, cursor, index, repetition, apply_length_limit)?
}
}
DataType::Unknown | DataType::Time { .. } | DataType::Other { .. } => {
unknown_non_char_type(&data_type, cursor, index, repetition, apply_length_limit)?
}
};
let desc = strategy.buffer_desc();
debug!("ODBC buffer description for column at index {index}: {desc:?}",);
Ok(strategy)
}
/// Fallback strategy for column types without a dedicated mapping: fetch them as text.
///
/// The buffer length is the UTF-8 length implied by `data_type`. If the type does not imply
/// one, the display size of the column at `index` is queried from `cursor`. Either way the
/// value is passed through `apply_length_limit` before use. Text is always fetched as narrow
/// (non UTF-16) characters here.
///
/// # Errors
///
/// Returns an error if `index` does not convert into a column number accepted by `cursor`,
/// or if querying the display size fails.
fn unknown_non_char_type(
    data_type: &DataType,
    cursor: &mut impl ResultSetMetadata,
    index: i16,
    repetition: Repetition,
    apply_length_limit: impl FnOnce(Option<NonZeroUsize>) -> usize,
) -> Result<Box<dyn ColumnStrategy>, Error> {
    // Only ask the driver for the display size if the data type itself does not tell us.
    let reported_length = match data_type.utf8_len() {
        known @ Some(_) => known,
        None => cursor.col_display_size(index.try_into()?)?,
    };
    let length = apply_length_limit(reported_length);
    // Unknown types are fetched as narrow text, independent of any UTF-16 mapping option.
    let use_utf16 = false;
    Ok(text_strategy(use_utf16, repetition, length))
}