use std::borrow::Cow;
use std::sync::Arc;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
use futures::{AsyncRead, AsyncWrite};
use itertools::{zip_eq, Itertools};
use tiberius::numeric::Numeric;
use tiberius::{BulkLoadRequest, Client, ColumnData, TokenRow};
use tokio::runtime::Runtime;
use crate::api::Append;
use crate::types::{FixedSizeBinaryType, NullType};
use crate::util::escape::escaped_ident;
use crate::util::transport::{Consume, ConsumeTy};
use crate::util::ArrayCellRef;
use crate::{impl_consume_unsupported, ConnectorError};
/// Appender that forwards Arrow record batches to a tiberius bulk-load request.
///
/// `'c` is the lifetime of the mutable borrow of the tiberius client that the
/// bulk-load request was created from; `S` is the client's underlying stream.
pub struct TiberiusAppender<'c, S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Runtime used to drive tiberius futures to completion with `block_on`.
    rt: Arc<Runtime>,
    /// The in-flight bulk insert that rows are sent to.
    bulk_load: BulkLoadRequest<'c, S>,
}
impl<'conn, S> TiberiusAppender<'conn, S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Starts a bulk insert into `table_name` on `client` and wraps the
    /// resulting request in an appender.
    ///
    /// The table name is passed through `escaped_ident` before it is handed to
    /// tiberius. The `bulk_insert` future is driven with `rt.block_on`, so this
    /// call blocks the current thread until the request has been set up.
    ///
    /// # Errors
    ///
    /// Returns an error if tiberius fails to start the bulk insert.
    pub fn new(
        rt: Arc<Runtime>,
        client: &'conn mut Client<S>,
        table_name: &str,
    ) -> Result<Self, ConnectorError> {
        // Owned copy of the escaped name; it lives until this function returns,
        // which is after the `block_on` call below has completed.
        let escaped = escaped_ident(table_name).to_string();

        // SAFETY: NOTE(review) — this stretches the borrow of `escaped` to
        // `'conn`, presumably because `bulk_insert` ties the table name to the
        // client borrow. It is sound only if tiberius reads the name while the
        // future runs and does not keep the `&str` inside the returned
        // `BulkLoadRequest`, since `escaped` is dropped when `new` returns.
        // Confirm against the tiberius version in use.
        let table_ref: &'conn str =
            unsafe { std::mem::transmute::<&str, &'conn str>(escaped.as_str()) };

        let bulk_load = rt.block_on(client.bulk_insert(table_ref))?;
        Ok(Self { rt, bulk_load })
    }
}
impl<'conn, S: AsyncRead + AsyncWrite + Unpin + Send> Append<'conn> for TiberiusAppender<'conn, S> {
fn append(&mut self, batch: RecordBatch) -> Result<(), ConnectorError> {
let schema = batch.schema();
let mut row_ref = zip_eq(batch.columns(), schema.fields())
.map(|(array, field)| ArrayCellRef {
array,
field,
row_number: 0,
})
.collect_vec();
for row_number in 0..batch.num_rows() {
let mut tb_row = TokenRow::with_capacity(row_ref.len());
let mut buffer = Vec::with_capacity(1);
for cell_ref in &mut row_ref {
cell_ref.row_number = row_number;
crate::util::transport::transport(cell_ref.field, &*cell_ref, &mut buffer)?;
tb_row.push(buffer.pop().unwrap());
}
let f = self.bulk_load.send(tb_row);
self.rt.block_on(f)?;
}
Ok(())
}
fn finish(self) -> Result<(), ConnectorError> {
let res = self.bulk_load.finalize();
self.rt.block_on(res)?;
Ok(())
}
}
/// Makes `Vec<ColumnData<'static>>` usable as the consumer passed to
/// `transport`: converted cells are pushed onto the vector.
///
/// The impl body is empty; the per-type behavior lives in the `ConsumeTy`
/// impls in this file.
/// NOTE(review): presumably `Consume` only bundles those `ConsumeTy` impls —
/// confirm in `util::transport`.
impl Consume for Vec<ColumnData<'static>> {}
/// Generates a `ConsumeTy` impl for `Vec<ColumnData<'static>>`.
///
/// * `impl_consume_ty!(ArrowTy, Variant, convert)` pushes
///   `ColumnData::Variant(Some(convert(value)))` for a value and
///   `ColumnData::Variant(None)` for a null.
/// * `impl_consume_ty!(ArrowTy, Variant)` is the same with
///   `std::convert::identity` as the conversion.
macro_rules! impl_consume_ty {
    // Full form: the native value is mapped through `$convert` before being
    // wrapped in the chosen `ColumnData` variant.
    ($arrow_ty: ty, $variant: ident, $convert: expr) => {
        impl ConsumeTy<$arrow_ty> for Vec<ColumnData<'static>> {
            fn consume(
                &mut self,
                _ty: &DataType,
                value: <$arrow_ty as crate::types::ArrowType>::Native,
            ) {
                let cell = ColumnData::$variant(Some(($convert)(value)));
                self.push(cell)
            }
            fn consume_null(&mut self, _ty: &DataType) {
                let cell = ColumnData::$variant(None);
                self.push(cell)
            }
        }
    };
    // Short form: the native value already has the variant's payload type.
    ($arrow_ty: ty, $variant: ident) => {
        impl_consume_ty!($arrow_ty, $variant, std::convert::identity);
    };
}
/// The null type has `()` as its value, so a "value" and a null are written
/// identically: both become a `ColumnData::U8(None)` cell.
impl ConsumeTy<NullType> for Vec<ColumnData<'static>> {
    /// Ignores the unit value and emits the same cell as `consume_null`.
    fn consume(&mut self, ty: &DataType, _: ()) {
        <Self as ConsumeTy<NullType>>::consume_null(self, ty)
    }

    /// Pushes a `U8` cell holding `None`.
    fn consume_null(&mut self, _ty: &DataType) {
        let null_cell = ColumnData::U8(None);
        self.push(null_cell)
    }
}
// Per-type conversions from Arrow native values to `ColumnData` cells.
// The two-argument form passes the value through unchanged; the three-argument
// form applies the given conversion first.

// Booleans map directly onto the `Bit` variant.
impl_consume_ty!(BooleanType, Bit);
// Int8 is widened to `I16` via `i16::from`; the wider signed integers map
// onto the variant of the same width.
impl_consume_ty!(Int8Type, I16, i16::from);
impl_consume_ty!(Int16Type, I16);
impl_consume_ty!(Int32Type, I32);
impl_consume_ty!(Int64Type, I64);
// UInt8 maps directly onto `U8`. UInt16 and UInt32 are widened to the next
// larger signed variant, and UInt64 becomes a scale-0 `Numeric`
// (see `u64_to_numeric`).
impl_consume_ty!(UInt8Type, U8);
impl_consume_ty!(UInt16Type, I32, i32::from);
impl_consume_ty!(UInt32Type, I64, i64::from);
impl_consume_ty!(UInt64Type, Numeric, u64_to_numeric);
// Float16 is widened to `F32`; the other floats map directly.
impl_consume_ty!(Float16Type, F32, f32::from);
impl_consume_ty!(Float32Type, F32);
impl_consume_ty!(Float64Type, F64);
// Both string types become `ColumnData::String` via `Cow::from`.
impl_consume_ty!(Utf8Type, String, Cow::from);
impl_consume_ty!(LargeUtf8Type, String, Cow::from);
// Timestamps of every unit are forwarded as their raw i64 value; the unit is
// not carried in the cell.
impl_consume_ty!(TimestampSecondType, I64);
impl_consume_ty!(TimestampMillisecondType, I64);
impl_consume_ty!(TimestampMicrosecondType, I64);
impl_consume_ty!(TimestampNanosecondType, I64);
// Arrow types that this consumer hands to `impl_consume_unsupported!` instead
// of converting: dates, times, durations, intervals and the binary types.
// NOTE(review): presumably the macro generates `ConsumeTy` impls that reject
// these types at runtime — confirm in the macro definition.
impl_consume_unsupported!(
Vec<ColumnData<'static>>,
(
Date32Type,
Date64Type,
Time32SecondType,
Time32MillisecondType,
Time64MicrosecondType,
Time64NanosecondType,
DurationSecondType,
DurationMillisecondType,
DurationMicrosecondType,
DurationNanosecondType,
IntervalYearMonthType,
IntervalDayTimeType,
IntervalMonthDayNanoType,
BinaryType,
LargeBinaryType,
FixedSizeBinaryType,
)
);
/// Converts a `u64` into a `Numeric` with scale 0.
///
/// The value is first widened to `i128`, which is lossless for every `u64`.
fn u64_to_numeric(val: u64) -> Numeric {
    let widened: i128 = val.into();
    Numeric::new_with_scale(widened, 0)
}
/// Decimal128 cells are written as `Numeric` when
/// `can_decimal_fit_in_numeric` accepts the column's precision and scale, and
/// as their formatted decimal string otherwise. Nulls use the same variant as
/// values so that a column is consistent across rows.
impl ConsumeTy<Decimal128Type> for Vec<ColumnData<'static>> {
    /// Pushes `value` as a `Numeric` or `String` cell depending on `ty`.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not `DataType::Decimal128`.
    fn consume(&mut self, ty: &DataType, value: i128) {
        let DataType::Decimal128(precision, scale) = ty else {
            panic!("Decimal128 consumer called with non-Decimal128 data type {ty:?}")
        };
        let cell = if super::can_decimal_fit_in_numeric(*precision, *scale) {
            // NOTE(review): the `as u8` cast assumes the fit check rejects
            // negative scales — confirm in `can_decimal_fit_in_numeric`.
            ColumnData::Numeric(Some(Numeric::new_with_scale(value, *scale as u8)))
        } else {
            let text = Decimal128Type::format_decimal(value, *precision, *scale);
            ColumnData::String(Some(text.into()))
        };
        self.push(cell);
    }

    /// Pushes a null cell of the variant that `consume` would use for `ty`.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not `DataType::Decimal128`.
    fn consume_null(&mut self, ty: &DataType) {
        let DataType::Decimal128(precision, scale) = ty else {
            panic!("Decimal128 consumer called with non-Decimal128 data type {ty:?}")
        };
        let cell = if super::can_decimal_fit_in_numeric(*precision, *scale) {
            ColumnData::Numeric(None)
        } else {
            ColumnData::String(None)
        };
        self.push(cell);
    }
}
/// Decimal256 cells are written as `Numeric` when
/// `can_decimal_fit_in_numeric` accepts the column's precision and scale, and
/// as their formatted decimal string otherwise. Nulls use the same variant as
/// values so that a column is consistent across rows.
impl ConsumeTy<Decimal256Type> for Vec<ColumnData<'static>> {
    /// Pushes `value` as a `Numeric` or `String` cell depending on `ty`.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not `DataType::Decimal256`.
    fn consume(&mut self, ty: &DataType, value: i256) {
        let DataType::Decimal256(precision, scale) = ty else {
            panic!("Decimal256 consumer called with non-Decimal256 data type {ty:?}")
        };
        let cell = if super::can_decimal_fit_in_numeric(*precision, *scale) {
            // NOTE(review): `as_i128` narrows the 256-bit value and `as u8`
            // reinterprets the scale; both presumably rely on the fit check
            // bounding precision and rejecting negative scales — confirm.
            ColumnData::Numeric(Some(Numeric::new_with_scale(
                value.as_i128(),
                *scale as u8,
            )))
        } else {
            let text = Decimal256Type::format_decimal(value, *precision, *scale);
            ColumnData::String(Some(text.into()))
        };
        self.push(cell);
    }

    /// Pushes a null cell of the variant that `consume` would use for `ty`.
    ///
    /// # Panics
    ///
    /// Panics if `ty` is not `DataType::Decimal256`.
    fn consume_null(&mut self, ty: &DataType) {
        let DataType::Decimal256(precision, scale) = ty else {
            panic!("Decimal256 consumer called with non-Decimal256 data type {ty:?}")
        };
        let cell = if super::can_decimal_fit_in_numeric(*precision, *scale) {
            ColumnData::Numeric(None)
        } else {
            ColumnData::String(None)
        };
        self.push(cell);
    }
}