use std::{borrow::Cow, cmp::min, sync::Arc};
use thiserror::Error;
use arrow::{
array::Array,
datatypes::{
DataType, Date32Type, Date64Type, Field, Float16Type, Float32Type, Float64Type, Int8Type,
Int16Type, Int32Type, Int64Type, Schema, Time32MillisecondType, Time32SecondType,
Time64MicrosecondType, Time64NanosecondType, TimeUnit, UInt8Type,
},
error::ArrowError,
record_batch::{RecordBatch, RecordBatchReader},
};
use odbc_api::{
BindParamDesc, ColumnarBulkInserter, Connection, ConnectionTransitions, Prepared,
buffers::{AnyBuffer, AnySliceMut},
handles::{AsStatementRef, StatementConnection, StatementImpl, StatementParent},
parameter::WithDataType,
};
use crate::{
date_time::{NullableTimeAsText, epoch_to_date, sec_since_midnight_to_time},
decimal::{NullableDecimal128AsText, NullableDecimal256AsText},
odbc_writer::timestamp::insert_timestamp_strategy,
};
use self::{
binary::VariadicBinary,
boolean::boolean_to_bit,
map_arrow_to_odbc::MapArrowToOdbc,
text::{LargeUtf8ToNativeText, Utf8ToNativeText},
};
mod binary;
mod boolean;
mod map_arrow_to_odbc;
mod text;
mod timestamp;
/// Reads all record batches from `batches` and inserts them into the database
/// table named `table_name`, using `connection`.
///
/// `batch_size` is the maximum number of rows transmitted to the database per
/// roundtrip. Returns the first error encountered while preparing the
/// statement, reading a batch, or sending data.
pub fn insert_into_table(
    connection: &Connection,
    batches: &mut impl RecordBatchReader,
    table_name: &str,
    batch_size: usize,
) -> Result<(), WriterError> {
    let schema = batches.schema();
    let mut writer = OdbcWriter::with_connection(
        connection,
        schema.as_ref(),
        table_name,
        batch_size,
    )?;
    writer.write_all(batches)
}
/// Builds an `INSERT INTO <table> (<columns>) VALUES (?, …)` statement with one
/// positional placeholder per column. Column names are quoted if required.
fn insert_statement_text(table: &str, column_names: &[&'_ str]) -> String {
    let column_names = column_names
        .iter()
        .map(|cn| quote_column_name(cn))
        .collect::<Vec<_>>();
    let columns = column_names.join(", ");
    // One `?` placeholder per column.
    let values = vec!["?"; column_names.len()].join(", ");
    format!("INSERT INTO {table} ({columns}) VALUES ({values})")
}

/// Wraps `column_name` in double quotes if it contains characters which are not
/// valid in an unquoted SQL identifier, unless it already carries its own
/// quoting (`"…"`, `[…]` or `` `…` ``).
///
/// Fix: embedded double quotes are now escaped by doubling them, per the SQL
/// standard for delimited identifiers. Previously `a"b` produced the invalid
/// identifier `"a"b"`; it now produces `"a""b"`.
fn quote_column_name(column_name: &str) -> Cow<'_, str> {
    let is_already_quoted = || {
        (column_name.starts_with('"') && column_name.ends_with('"'))
            || column_name.starts_with('[') && column_name.ends_with(']')
            || column_name.starts_with('`') && column_name.ends_with('`')
    };
    let contains_invalid_characters = || column_name.contains(|c| !valid_in_column_name(c));
    let needs_quotes = contains_invalid_characters() && !is_already_quoted();
    if needs_quotes {
        // Double embedded quotes so the delimited identifier stays well formed.
        Cow::Owned(format!("\"{}\"", column_name.replace('"', "\"\"")))
    } else {
        Cow::Borrowed(column_name)
    }
}

/// `true` for characters which may appear in an unquoted SQL column name.
fn valid_in_column_name(c: char) -> bool {
    c.is_alphanumeric() || c == '@' || c == '$' || c == '#' || c == '_'
}
/// Generates an INSERT statement for `table_name` covering every column of
/// `schema`, in schema order.
pub fn insert_statement_from_schema(schema: &Schema, table_name: &str) -> String {
    // Iterate the fields directly instead of indexing `fields[i]` over
    // `0..len` — clearer and avoids the per-element bounds check.
    let column_names: Vec<&str> = schema
        .fields()
        .iter()
        .map(|field| field.name().as_str())
        .collect();
    insert_statement_text(table_name, &column_names)
}
#[derive(Debug, Error)]
pub enum WriterError {
#[error("Failure to bind the array parameter buffers to the statement.\n{0}")]
BindParameterBuffers(#[source] odbc_api::Error),
#[error("Failure to execute the sql statement, sending the data to the database.\n{0}")]
ExecuteStatment(#[source] odbc_api::Error),
#[error("An error occured rebinding a parameter buffer to the sql statement.\n{0}")]
RebindBuffer(#[source] odbc_api::Error),
#[error("The arrow data type {0} is not supported for insertion.")]
UnsupportedArrowDataType(DataType),
#[error("An error occured extracting a record batch from an error reader.\n{0}")]
ReadingRecordBatch(#[source] ArrowError),
#[error("Unable to parse '{time_zone}' into a valid IANA time zone.")]
InvalidTimeZone { time_zone: Arc<str> },
#[error("An error occurred preparing SQL statement. SQL:\n{sql}\n{source}")]
PreparingInsertStatement {
#[source]
source: odbc_api::Error,
sql: String,
},
}
/// Writes arrow record batches into a database table through a prepared ODBC
/// INSERT statement, buffering up to a fixed number of rows per roundtrip.
pub struct OdbcWriter<S> {
/// Prepared statement with bound array parameter buffers, one per column.
inserter: ColumnarBulkInserter<S, WithDataType<AnyBuffer>>,
/// One strategy per column, translating arrow arrays into the bound buffers.
strategies: Vec<Box<dyn WriteStrategy>>,
}
impl<S> OdbcWriter<S>
where
    S: AsStatementRef,
{
    /// Constructs a writer from an already prepared INSERT statement.
    ///
    /// Chooses a write strategy for every field of `schema` and binds one
    /// parameter buffer per column, each able to hold `row_capacity` rows
    /// before a roundtrip to the database is required.
    pub fn new(
        row_capacity: usize,
        schema: &Schema,
        statement: Prepared<S>,
    ) -> Result<Self, WriterError> {
        let write_strategies = schema
            .fields()
            .iter()
            .map(|field| field_to_write_strategy(field.as_ref()))
            .collect::<Result<Vec<_>, _>>()?;
        let buffer_descriptions = write_strategies.iter().map(|strategy| strategy.buffer_desc());
        let inserter = statement
            .into_column_inserter(row_capacity, buffer_descriptions)
            .map_err(WriterError::BindParameterBuffers)?;
        Ok(Self {
            inserter,
            strategies: write_strategies,
        })
    }

    /// Drains `reader`, writing every record batch, then flushes whatever rows
    /// remain buffered.
    pub fn write_all(
        &mut self,
        reader: impl Iterator<Item = Result<RecordBatch, ArrowError>>,
    ) -> Result<(), WriterError> {
        for batch in reader {
            let batch = batch.map_err(WriterError::ReadingRecordBatch)?;
            self.write_batch(&batch)?;
        }
        self.flush()
    }

    /// Copies `record_batch` into the parameter buffers, sending data to the
    /// database whenever the buffers reach capacity. Rows may remain buffered
    /// afterwards; call [`Self::flush`] to force them out.
    pub fn write_batch(&mut self, record_batch: &RecordBatch) -> Result<(), WriterError> {
        let capacity = self.inserter.capacity();
        let total_rows = record_batch.num_rows();
        let mut rows_written = 0;
        while rows_written < total_rows {
            // Fill at most the free space left in the buffers.
            let buffered = self.inserter.num_rows();
            let chunk_len = min(capacity - buffered, total_rows - rows_written);
            self.inserter.set_num_rows(buffered + chunk_len);
            let chunk = record_batch.slice(rows_written, chunk_len);
            for (column_index, (column, strategy)) in
                chunk.columns().iter().zip(&self.strategies).enumerate()
            {
                strategy.write_rows(buffered, self.inserter.column_mut(column_index), column)?
            }
            // Buffers full: push the rows to the database before continuing.
            if self.inserter.num_rows() == capacity {
                self.flush()?;
            }
            rows_written += chunk_len;
        }
        Ok(())
    }

    /// Executes the statement with all currently buffered rows and clears the
    /// buffers for reuse.
    pub fn flush(&mut self) -> Result<(), WriterError> {
        let execution = self.inserter.execute();
        execution.map_err(WriterError::ExecuteStatment)?;
        self.inserter.clear();
        Ok(())
    }
}
impl<C> OdbcWriter<StatementConnection<C>>
where
    C: StatementParent,
{
    /// Consumes `connection` and creates a writer which owns its statement
    /// handle, inserting into `table_name` with buffers sized for
    /// `row_capacity` rows per database roundtrip.
    pub fn from_connection<C2>(
        connection: C2,
        schema: &Schema,
        table_name: &str,
        row_capacity: usize,
    ) -> Result<Self, WriterError>
    where
        C2: ConnectionTransitions<StatementParent = C>,
    {
        let sql = insert_statement_from_schema(schema, table_name);
        match connection.into_prepared(&sql) {
            Ok(statement) => Self::new(row_capacity, schema, statement),
            Err(source) => Err(WriterError::PreparingInsertStatement { source, sql }),
        }
    }
}
impl<'o> OdbcWriter<StatementImpl<'o>> {
    /// Creates a writer borrowing `connection`; the prepared statement lives
    /// no longer than that borrow. Inserts into `table_name` with buffers
    /// sized for `row_capacity` rows per database roundtrip.
    pub fn with_connection(
        connection: &'o Connection<'o>,
        schema: &Schema,
        table_name: &str,
        row_capacity: usize,
    ) -> Result<Self, WriterError> {
        let sql = insert_statement_from_schema(schema, table_name);
        match connection.prepare(&sql) {
            Ok(prepared) => Self::new(row_capacity, schema, prepared),
            Err(source) => Err(WriterError::PreparingInsertStatement { source, sql }),
        }
    }
}
/// Transfers values of one column from an arrow array into an ODBC parameter
/// buffer. One implementation exists per supported arrow data type.
pub trait WriteStrategy {
/// Description of the parameter buffer this strategy writes into; used to
/// allocate and bind the buffer when the statement is prepared.
fn buffer_desc(&self) -> BindParamDesc;
/// Copies all elements of `array` into `column_buf`, starting at row
/// `param_offset` (rows already buffered from previous batches).
fn write_rows(
&self,
param_offset: usize,
column_buf: AnySliceMut<'_>,
array: &dyn Array,
) -> Result<(), WriterError>;
}
/// Chooses the [`WriteStrategy`] used to transfer values of `field` into an
/// ODBC parameter buffer, based on the field's arrow data type.
///
/// Returns [`WriterError::UnsupportedArrowDataType`] for types without a
/// strategy.
fn field_to_write_strategy(field: &Field) -> Result<Box<dyn WriteStrategy>, WriterError> {
let is_nullable = field.is_nullable();
let strategy = match field.data_type() {
DataType::Utf8 => Box::new(Utf8ToNativeText {}),
DataType::Boolean => boolean_to_bit(is_nullable),
DataType::LargeUtf8 => Box::new(LargeUtf8ToNativeText {}),
// Integer and float types are copied without conversion.
DataType::Int8 => Int8Type::identical(is_nullable),
DataType::Int16 => Int16Type::identical(is_nullable),
DataType::Int32 => Int32Type::identical(is_nullable),
DataType::Int64 => Int64Type::identical(is_nullable),
DataType::UInt8 => UInt8Type::identical(is_nullable),
// ODBC has no half-float type; widen f16 to f32 element-wise.
DataType::Float16 => Float16Type::map_with(is_nullable, |half| half.to_f32()),
DataType::Float32 => Float32Type::identical(is_nullable),
DataType::Float64 => Float64Type::identical(is_nullable),
DataType::Timestamp(time_unit, time_zone) => {
insert_timestamp_strategy(is_nullable, &time_unit, time_zone.clone())?
}
DataType::Date32 => Date32Type::map_with(is_nullable, epoch_to_date),
// Date64 is milliseconds since epoch; presumably upstream guarantees the
// value is representable as days — TODO confirm, `try_into` panics otherwise.
DataType::Date64 => Date64Type::map_with(is_nullable, |days_since_epoch| {
epoch_to_date(days_since_epoch.try_into().unwrap())
}),
DataType::Time32(TimeUnit::Second) => {
Time32SecondType::map_with(is_nullable, sec_since_midnight_to_time)
}
// Sub-second times are transferred as text, since SQL TIME has no
// standard sub-second binary representation in these buffers.
DataType::Time32(TimeUnit::Millisecond) => {
Box::new(NullableTimeAsText::<Time32MillisecondType>::new())
}
DataType::Time64(TimeUnit::Microsecond) => {
Box::new(NullableTimeAsText::<Time64MicrosecondType>::new())
}
DataType::Time64(TimeUnit::Nanosecond) => {
Box::new(NullableTimeAsText::<Time64NanosecondType>::new())
}
// Variadic binary buffer starts with capacity for 1 byte per element and
// grows as needed.
DataType::Binary => Box::new(VariadicBinary::new(1)),
DataType::FixedSizeBinary(length) => {
Box::new(VariadicBinary::new((*length).try_into().unwrap()))
}
// Decimals are rendered as text to preserve precision/scale exactly.
DataType::Decimal128(precision, scale) => {
Box::new(NullableDecimal128AsText::new(*precision, *scale))
}
DataType::Decimal256(precision, scale) => {
Box::new(NullableDecimal256AsText::new(*precision, *scale))
}
unsupported => return Err(WriterError::UnsupportedArrowDataType(unsupported.clone())),
};
Ok(strategy)
}