mod append;
mod query;
mod schema;
mod types;
pub use tiberius;
use arrow::datatypes::*;
use futures::{AsyncRead, AsyncWrite};
use itertools::Itertools;
use std::sync::Arc;
use tokio::runtime::Runtime;
use crate::api::Connector;
use crate::ConnectorError;
/// A [`Connector`] backed by a [`tiberius::Client`].
///
/// Bundles the client with a shared handle to a tokio [`Runtime`]
/// (presumably used to drive the async client from synchronous code —
/// confirm against the `query` and `append` submodules).
pub struct TiberiusConnection<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Shared runtime handle; cloned and handed to appenders.
    rt: Arc<Runtime>,
    /// The underlying tiberius client, generic over its transport stream `S`.
    client: tiberius::Client<S>,
}
impl<S> TiberiusConnection<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Builds a connection from a runtime handle and an already-constructed
    /// tiberius client.
    pub fn new(rt: Arc<Runtime>, client: tiberius::Client<S>) -> Self {
        Self { rt, client }
    }

    /// Consumes the connection and hands back its parts as `(runtime, client)`.
    pub fn unwrap(self) -> (Arc<Runtime>, tiberius::Client<S>) {
        let Self { rt, client } = self;
        (rt, client)
    }

    /// Returns mutable references to both parts as `(runtime, client)`.
    pub fn inner_mut(&mut self) -> (&mut Arc<Runtime>, &mut tiberius::Client<S>) {
        let Self { rt, client } = self;
        (rt, client)
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connector for TiberiusConnection<S> {
type Stmt<'conn>
= query::TiberiusStatement<'conn, S>
where
Self: 'conn;
type Append<'conn>
= append::TiberiusAppender<'conn, S>
where
Self: 'conn;
fn query<'a>(&'a mut self, query: &str) -> Result<Self::Stmt<'a>, ConnectorError> {
Ok(query::TiberiusStatement {
conn: self,
query: query.to_string(),
})
}
fn append<'a>(&'a mut self, table_name: &str) -> Result<Self::Append<'a>, ConnectorError> {
append::TiberiusAppender::new(self.rt.clone(), &mut self.client, table_name)
}
#[allow(clippy::get_first)]
fn type_db_into_arrow(ty: &str) -> Option<DataType> {
let ty = ty.to_lowercase();
let (name, args) = if let Some((name, args)) = ty.split_once('(') {
(
name,
args.trim_end_matches(')')
.split(',')
.filter_map(|a| a.trim().parse::<i16>().ok())
.collect_vec(),
)
} else {
(ty.as_str(), vec![])
};
Some(match name {
"null" | "intn" => DataType::Null,
"bit" => DataType::Boolean,
"tinyint" => DataType::UInt8,
"smallint" => DataType::Int16,
"int" => DataType::Int32,
"bigint" => DataType::Int64,
"char" | "nchar" | "varchar" | "nvarchar" | "text" | "ntext" => DataType::Utf8,
"real" | "float" => {
let is_f32 = args
.get(0)
.map(|p| *p <= 24)
.unwrap_or_else(|| name == "real");
if is_f32 {
DataType::Float32
} else {
DataType::Float64
}
}
"decimal" | "numeric" => DataType::Utf8,
_ => return None,
})
}
fn type_arrow_into_db(ty: &DataType) -> Option<String> {
Some(
match ty {
DataType::Null => "tinyint",
DataType::Boolean => "bit",
DataType::Int8 => "smallint",
DataType::Int16 => "smallint",
DataType::Int32 => "int",
DataType::Int64 => "bigint",
DataType::UInt8 => "tinyint",
DataType::UInt16 => "int",
DataType::UInt32 => "bigint",
DataType::UInt64 => "decimal(20, 0)",
DataType::Float32 => "float(24)", DataType::Float64 => "float(53)", DataType::Float16 => "float(24)",
DataType::Timestamp(_, _) => "bigint",
DataType::Utf8 => "nvarchar(max)",
DataType::LargeUtf8 => "nvarchar(max)",
DataType::Decimal128(p, s) | DataType::Decimal256(p, s)
if can_decimal_fit_in_numeric(*p, *s) =>
{
return Some(format!("numeric({p}, {s})"));
}
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => "nvarchar(max)",
_ => return None,
}
.to_string(),
)
}
}
/// Reports whether an Arrow decimal with the given `precision` and `scale`
/// can be stored in a SQL Server `numeric(precision, scale)` column.
///
/// SQL Server accepts a precision of at most 38 (inclusive) and a scale in
/// the range `0..=precision`. Arrow permits negative scales, which `numeric`
/// cannot express, so those are rejected.
fn can_decimal_fit_in_numeric(precision: u8, scale: i8) -> bool {
    // `unsigned_abs` equals the scale itself once it is known to be non-negative.
    precision <= 38 && scale >= 0 && scale.unsigned_abs() <= precision
}