use crate::connection::ODBC_ENV;
use crate::connection::dm::buffer::{buffer_to_batch, build_buffer_desc};
use crate::connection::dm::row::row_to_batch;
use crate::{
Connection, ConnectionOptions, DFResult, DmConnectionOptions, DmType, Literalize, Pool,
PoolState, RemoteDbType, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType,
};
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use async_stream::stream;
use datafusion_common::DataFusionError;
use datafusion_common::project_schema;
use datafusion_execution::SendableRecordBatchStream;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::lock::Mutex;
use log::debug;
use odbc_api::buffers::ColumnarAnyBuffer;
use odbc_api::handles::StatementImpl;
use odbc_api::{Cursor, CursorImpl, Environment, ResultSetMetadata};
use std::any::Any;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::runtime::Handle;
mod buffer;
mod row;
/// A lightweight connection "pool" for DM (Dameng) over ODBC.
///
/// NOTE(review): no connections are actually cached or reused — `Pool::get`
/// opens a fresh ODBC connection every call, and `connections` only counts
/// how many `DmConnection`s are currently alive.
#[derive(Debug)]
pub struct DmPool {
    // Connection parameters (driver, host, port, credentials, optional schema).
    options: DmConnectionOptions,
    // Number of live `DmConnection`s handed out; incremented in `get`,
    // decremented in `DmConnection::drop`.
    connections: Arc<AtomicUsize>,
}
/// Builds a [`DmPool`] from the given connection options.
///
/// No connection is opened here; the options are only validated implicitly
/// when `Pool::get` later connects.
pub(crate) fn connect_dm(options: &DmConnectionOptions) -> DFResult<DmPool> {
    let pool = DmPool {
        options: options.clone(),
        connections: Arc::new(AtomicUsize::new(0)),
    };
    Ok(pool)
}
#[async_trait::async_trait]
impl Pool for DmPool {
    /// Opens a new ODBC connection to DM and wraps it in a [`DmConnection`].
    ///
    /// The process-wide ODBC environment is created lazily on first use.
    /// Credential and schema values are brace-quoted when they contain
    /// characters that are special in ODBC connection strings (`;`, `=`,
    /// `{`, `}`), so that e.g. a password containing `;` no longer corrupts
    /// the connection string; plain values are passed through unchanged.
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        // Quote a connection-string attribute value per the ODBC convention:
        // wrap in `{}` and escape a literal `}` as `}}`. Only applied when the
        // value actually needs it, to stay byte-compatible for plain values.
        fn odbc_quote(value: &str) -> String {
            if value.chars().any(|c| matches!(c, ';' | '=' | '{' | '}')) {
                format!("{{{}}}", value.replace('}', "}}"))
            } else {
                value.to_string()
            }
        }
        let env = ODBC_ENV.get_or_init(|| Environment::new().expect("failed to create ODBC env"));
        let mut connection_str = format!(
            "Driver={{{}}};Server={};TCP_Port={};UID={};PWD={}",
            self.options.driver,
            self.options.host,
            self.options.port,
            odbc_quote(&self.options.username),
            odbc_quote(&self.options.password),
        );
        if let Some(schema) = &self.options.schema {
            connection_str.push_str(&format!(";SCHEMA={}", odbc_quote(schema)));
        }
        let connection = env
            .connect_with_connection_string(&connection_str, odbc_api::ConnectionOptions::default())
            .map_err(|e| {
                DataFusionError::Execution(format!("Failed to create odbc connection: {e:?}"))
            })?;
        // Track live connections; the matching decrement happens in
        // `DmConnection::drop`.
        self.connections.fetch_add(1, Ordering::SeqCst);
        Ok(Arc::new(DmConnection {
            conn: Arc::new(Mutex::new(connection)),
            pool_connections: self.connections.clone(),
        }))
    }
    /// Reports the number of currently-live connections. There is no idle
    /// pool, so `idle_connections` is always 0.
    async fn state(&self) -> DFResult<PoolState> {
        Ok(PoolState {
            connections: self.connections.load(Ordering::SeqCst),
            idle_connections: 0,
        })
    }
}
/// A single live ODBC connection to a DM database.
#[derive(Debug)]
pub struct DmConnection {
    // The underlying ODBC connection. Guarded by an async mutex because
    // `query` locks it from inside a `spawn_blocking` task (via `block_on`)
    // while other methods lock it from async context.
    conn: Arc<Mutex<odbc_api::Connection<'static>>>,
    // Shared live-connection counter from the owning `DmPool`;
    // decremented in `Drop`.
    pool_connections: Arc<AtomicUsize>,
}
impl Drop for DmConnection {
    fn drop(&mut self) {
        // Keep the pool's live-connection count in sync when this
        // connection is released.
        self.pool_connections.fetch_sub(1, Ordering::SeqCst);
    }
}
#[async_trait::async_trait]
impl Connection for DmConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Infers the remote schema by executing the source query (limited to one
    /// row where possible) and reading the result-set metadata from the cursor.
    ///
    /// Returns a plan error when the statement produces no result set, since
    /// there is then no metadata to describe.
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Dm.limit_1_query_if_possible(source);
        let conn = self.conn.lock().await;
        let cursor_opt = conn.execute(&sql, (), None).map_err(|e| {
            DataFusionError::Plan(format!("Failed to execute query {sql} on dm: {e:?}"))
        })?;
        match cursor_opt {
            // `execute` yields `None` when the statement returns no result set.
            None => Err(DataFusionError::Plan(
                "No rows returned to infer schema".to_string(),
            )),
            Some(cursor) => {
                let remote_schema = Arc::new(build_remote_schema(cursor)?);
                Ok(remote_schema)
            }
        }
    }
    /// Executes `source` (with pushed-down filters and limit) against DM and
    /// streams the result as Arrow record batches.
    ///
    /// ODBC is a synchronous API, so fetching runs on a `spawn_blocking`
    /// thread which sends batches through a bounded channel (capacity 4, so
    /// the producer gets backpressure) to the async output stream.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Dm.rewrite_query(source, unparsed_filters, limit);
        debug!("[remote-table] executing dm query: {sql}");
        let chunk_size = conn_options.stream_chunk_size();
        let conn = Arc::clone(&self.conn);
        let projection = projection.cloned();
        let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
        let join_handle = tokio::task::spawn_blocking(move || {
            // On the blocking thread, re-enter the runtime just long enough
            // to acquire the async mutex guarding the connection; the guard
            // is then held for the whole fetch.
            let handle = Handle::current();
            let conn = handle.block_on(async { conn.lock().await });
            let cursor_opt = conn.execute(&sql, (), None).map_err(|e| {
                DataFusionError::Execution(format!("Failed to execute query: {e:?}"))
            })?;
            match cursor_opt {
                // No result set produced: finish with an empty stream.
                None => {}
                Some(mut cursor) => {
                    if contains_large_column(&mut cursor)? {
                        // LongVarchar/LongVarbinary columns are present —
                        // presumably too large for a fixed-size columnar
                        // buffer (TODO confirm) — so fall back to fetching
                        // row by row, emitting one batch per row.
                        while let Some(row) = cursor.next_row().map_err(|e| {
                            DataFusionError::Execution(format!("Failed to fetch row: {e:?}"))
                        })? {
                            let batch = row_to_batch(row, &table_schema, projection.as_ref())?;
                            batch_tx.blocking_send(batch).map_err(|e| {
                                DataFusionError::Execution(format!("Failed to send batch: {e:?}"))
                            })?;
                        }
                    } else {
                        // Fast path: bind a columnar buffer holding
                        // `chunk_size` rows and fetch in blocks.
                        let buffer_descs = table_schema
                            .fields()
                            .iter()
                            .enumerate()
                            .map(|(idx, field)| build_buffer_desc(field, &mut cursor, idx))
                            .collect::<DFResult<Vec<_>>>()?;
                        let row_set_buffer = ColumnarAnyBuffer::try_from_descs(
                            chunk_size,
                            buffer_descs,
                        )
                        .map_err(|e| {
                            DataFusionError::Execution(format!("Failed to create buffer: {e:?}"))
                        })?;
                        let mut block_cursor = cursor.bind_buffer(row_set_buffer).map_err(|e| {
                            DataFusionError::Execution(format!("Failed to bind buffer: {e:?}"))
                        })?;
                        loop {
                            // `true`: error out on value truncation instead of
                            // silently returning clipped data.
                            match block_cursor.fetch_with_truncation_check(true) {
                                Ok(Some(buffer)) => {
                                    let batch = buffer_to_batch(
                                        buffer,
                                        &table_schema,
                                        projection.as_ref(),
                                        chunk_size,
                                    )?;
                                    batch_tx.blocking_send(batch).map_err(|e| {
                                        DataFusionError::Execution(format!(
                                            "Failed to send batch: {e:?}"
                                        ))
                                    })?;
                                }
                                // End of result set.
                                Ok(None) => break,
                                Err(odbc_error) => {
                                    return Err(DataFusionError::External(Box::new(odbc_error)));
                                }
                            }
                        }
                    }
                }
            }
            Ok::<_, DataFusionError>(())
        });
        let output_stream = stream! {
            // Drain every batch first; the channel closes when the blocking
            // task drops its sender (normally or on error).
            while let Some(batch) = batch_rx.recv().await {
                yield Ok(batch);
            }
            // Then surface any error (or panic) from the blocking task.
            match join_handle.await {
                Ok(Ok(())) => {},
                Ok(Err(e)) => yield Err(e),
                Err(e) => yield Err(DataFusionError::Execution(format!(
                    "Failed to execute ODBC query: {e}"
                ))),
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            output_stream,
        )))
    }
    /// Writes are not implemented for DM; always returns an execution error.
    async fn insert(
        &self,
        _conn_options: &ConnectionOptions,
        _literalizer: Arc<dyn Literalize>,
        _table: &[String],
        _remote_schema: RemoteSchemaRef,
        _batch: RecordBatch,
    ) -> DFResult<usize> {
        Err(DataFusionError::Execution(
            "Insert operation is not supported for dm".to_string(),
        ))
    }
}
fn build_remote_schema(mut cursor: CursorImpl<StatementImpl>) -> DFResult<RemoteSchema> {
let col_count = cursor
.num_result_cols()
.map_err(|e| DataFusionError::External(Box::new(e)))? as u16;
let mut remote_fields = vec![];
for i in 1..=col_count {
let col_name = cursor
.col_name(i)
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let col_type = cursor
.col_data_type(i)
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let remote_type = RemoteType::Dm(dm_type_to_remote_type(col_type)?);
let col_nullable = cursor
.col_nullability(i)
.map_err(|e| DataFusionError::External(Box::new(e)))?
.could_be_nullable();
remote_fields.push(RemoteField::new(col_name, remote_type, col_nullable));
}
Ok(RemoteSchema::new(remote_fields))
}
fn contains_large_column(cursor: &mut CursorImpl<StatementImpl>) -> DFResult<bool> {
let col_count = cursor
.num_result_cols()
.map_err(|e| DataFusionError::External(Box::new(e)))? as u16;
for i in 1..=col_count {
let col_type = cursor
.col_data_type(i)
.map_err(|e| DataFusionError::External(Box::new(e)))?;
if matches!(
col_type,
odbc_api::DataType::LongVarchar { length: _ }
| odbc_api::DataType::LongVarbinary { length: _ }
) {
return Ok(true);
}
}
Ok(false)
}
fn dm_type_to_remote_type(data_type: odbc_api::DataType) -> DFResult<DmType> {
match data_type {
odbc_api::DataType::TinyInt => Ok(DmType::TinyInt),
odbc_api::DataType::SmallInt => Ok(DmType::SmallInt),
odbc_api::DataType::Integer => Ok(DmType::Integer),
odbc_api::DataType::BigInt => Ok(DmType::BigInt),
odbc_api::DataType::Real => Ok(DmType::Real),
odbc_api::DataType::Double => Ok(DmType::Double),
odbc_api::DataType::Numeric { precision, scale } => {
assert!(precision >= 1);
assert!(precision <= 38);
assert!(scale <= 38);
Ok(DmType::Numeric(precision as u8, scale as i8))
}
odbc_api::DataType::Decimal { precision, scale } => {
assert!(precision >= 1);
assert!(precision <= 38);
assert!(scale <= 38);
Ok(DmType::Decimal(precision as u8, scale as i8))
}
odbc_api::DataType::Char { length } => Ok(DmType::Char(length.map(|l| l.get() as u16))),
odbc_api::DataType::Varchar { length } => {
Ok(DmType::Varchar(length.map(|l| l.get() as u16)))
}
odbc_api::DataType::Binary { length } => Ok(DmType::Binary(
length.expect("length should not be none").get() as u16,
)),
odbc_api::DataType::Varbinary { length } => {
Ok(DmType::Varbinary(length.map(|l| l.get() as u16)))
}
odbc_api::DataType::LongVarchar { .. } => Ok(DmType::Text),
odbc_api::DataType::LongVarbinary { .. } => Ok(DmType::Image),
odbc_api::DataType::Bit => Ok(DmType::Bit),
odbc_api::DataType::Timestamp { precision } => {
assert!(precision >= 0);
assert!(precision <= 9);
Ok(DmType::Timestamp(precision as u8))
}
odbc_api::DataType::Time { precision } => {
assert!(precision >= 0);
assert!(precision <= 6);
Ok(DmType::Time(precision as u8))
}
odbc_api::DataType::Date => Ok(DmType::Date),
_ => Err(DataFusionError::Execution(format!(
"Unsupported DM type: {data_type:?}"
))),
}
}
/// Converts an ODBC timestamp into whole seconds since the Unix epoch (UTC).
///
/// The sub-second `fraction` field is ignored at this precision. Fails when
/// the date/time fields do not form a valid calendar date-time.
pub(crate) fn seconds_since_epoch(value: &odbc_api::sys::Timestamp) -> DFResult<i64> {
    // Shared error for any invalid component.
    let invalid = || DataFusionError::Execution(format!("Invalid timestamp: {value:?}"));
    let date =
        chrono::NaiveDate::from_ymd_opt(value.year as i32, value.month as u32, value.day as u32)
            .ok_or_else(invalid)?;
    let datetime = date
        .and_hms_opt(value.hour as u32, value.minute as u32, value.second as u32)
        .ok_or_else(invalid)?;
    Ok(datetime.and_utc().timestamp())
}
/// Converts an ODBC timestamp into milliseconds since the Unix epoch (UTC).
///
/// `fraction` is passed to chrono as nanoseconds and thus contributes at
/// millisecond resolution. Fails when the fields do not form a valid
/// calendar date-time.
pub(crate) fn ms_since_epoch(value: &odbc_api::sys::Timestamp) -> DFResult<i64> {
    // Shared error for any invalid component.
    let invalid = || DataFusionError::Execution(format!("Invalid timestamp: {value:?}"));
    let date =
        chrono::NaiveDate::from_ymd_opt(value.year as i32, value.month as u32, value.day as u32)
            .ok_or_else(invalid)?;
    let datetime = date
        .and_hms_nano_opt(
            value.hour as u32,
            value.minute as u32,
            value.second as u32,
            value.fraction,
        )
        .ok_or_else(invalid)?;
    Ok(datetime.and_utc().timestamp_millis())
}
/// Converts an ODBC timestamp into microseconds since the Unix epoch (UTC).
///
/// `fraction` is passed to chrono as nanoseconds and thus contributes at
/// microsecond resolution. Fails when the fields do not form a valid
/// calendar date-time.
pub(crate) fn us_since_epoch(value: &odbc_api::sys::Timestamp) -> DFResult<i64> {
    // Shared error for any invalid component.
    let invalid = || DataFusionError::Execution(format!("Invalid timestamp: {value:?}"));
    let date =
        chrono::NaiveDate::from_ymd_opt(value.year as i32, value.month as u32, value.day as u32)
            .ok_or_else(invalid)?;
    let datetime = date
        .and_hms_nano_opt(
            value.hour as u32,
            value.minute as u32,
            value.second as u32,
            value.fraction,
        )
        .ok_or_else(invalid)?;
    Ok(datetime.and_utc().timestamp_micros())
}
/// Converts an ODBC timestamp into nanoseconds since the Unix epoch (UTC).
///
/// `fraction` is passed to chrono as nanoseconds. Fails when the fields do
/// not form a valid calendar date-time, or when the instant overflows the
/// `i64` nanosecond range (`timestamp_nanos_opt` returns `None`).
pub(crate) fn ns_since_epoch(value: &odbc_api::sys::Timestamp) -> DFResult<i64> {
    // Shared error for any invalid component or range overflow.
    let invalid = || DataFusionError::Execution(format!("Invalid timestamp: {value:?}"));
    let date =
        chrono::NaiveDate::from_ymd_opt(value.year as i32, value.month as u32, value.day as u32)
            .ok_or_else(invalid)?;
    let datetime = date
        .and_hms_nano_opt(
            value.hour as u32,
            value.minute as u32,
            value.second as u32,
            value.fraction,
        )
        .ok_or_else(invalid)?;
    datetime.and_utc().timestamp_nanos_opt().ok_or_else(invalid)
}