// datafusion_table_providers/sql/db_connection_pool/dbconnection/odbcconn.rs
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

use crate::sql::db_connection_pool::{
    dbconnection::{self, AsyncDbConnection, DbConnection, GenericError},
    runtime::run_async_with_tokio,
    DbConnectionPool,
};
use arrow_odbc::arrow_schema_from;
use arrow_odbc::OdbcReader;
use arrow_odbc::OdbcReaderBuilder;
use async_stream::stream;
use async_trait::async_trait;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::sql::TableReference;
use dyn_clone::DynClone;
use futures::lock::Mutex;
use odbc_api::handles::SqlResult;
use odbc_api::handles::Statement;
use odbc_api::handles::StatementImpl;
use odbc_api::parameter::InputParameter;
use odbc_api::Cursor;
use odbc_api::CursorImpl;
use secrecy::{ExposeSecret, SecretBox, SecretString};
use snafu::prelude::*;
use snafu::Snafu;
use tokio::runtime::Handle;

use odbc_api::Connection;
use tokio::sync::mpsc::Sender;

54type Result<T, E = GenericError> = std::result::Result<T, E>;
55
56pub trait ODBCSyncParameter: InputParameter + Sync + Send + DynClone {
57 fn as_input_parameter(&self) -> &dyn InputParameter;
58}
59
60impl<T: InputParameter + Sync + Send + DynClone> ODBCSyncParameter for T {
61 fn as_input_parameter(&self) -> &dyn InputParameter {
62 self
63 }
64}
65
66dyn_clone::clone_trait_object!(ODBCSyncParameter);
67
68pub type ODBCParameter = Box<dyn ODBCSyncParameter>;
69pub type ODBCDbConnection<'a> = (dyn DbConnection<Connection<'a>, ODBCParameter>);
70pub type ODBCDbConnectionPool<'a> =
71 dyn DbConnectionPool<Connection<'a>, ODBCParameter> + Sync + Send;
72
73#[derive(Debug, Snafu)]
74pub enum Error {
75 #[snafu(display("Failed to convert query result to Arrow: {source}"))]
76 ArrowError {
77 source: datafusion::arrow::error::ArrowError,
78 },
79 #[snafu(display("arrow_odbc error: {source}"))]
80 ArrowODBCError { source: arrow_odbc::Error },
81 #[snafu(display("odbc_api Error: {source}"))]
82 ODBCAPIError { source: odbc_api::Error },
83 #[snafu(display("odbc_api Error: {message}"))]
84 ODBCAPIErrorNoSource { message: String },
85 #[snafu(display("Failed to convert query result to Arrow: {source}"))]
86 TryFromError { source: std::num::TryFromIntError },
87 #[snafu(display("Unable to bind integer parameter: {source}"))]
88 UnableToBindIntParameter { source: std::num::TryFromIntError },
89 #[snafu(display("Internal communication channel error: {message}"))]
90 ChannelError { message: String },
91}
92
93pub struct ODBCConnection<'a> {
94 pub conn: Arc<Mutex<Connection<'a>>>,
95 pub params: Arc<HashMap<String, SecretString>>,
96}
97
98impl<'a> DbConnection<Connection<'a>, ODBCParameter> for ODBCConnection<'a>
99where
100 'a: 'static,
101{
102 fn as_any(&self) -> &dyn Any {
103 self
104 }
105
106 fn as_any_mut(&mut self) -> &mut dyn Any {
107 self
108 }
109
110 fn as_async(&self) -> Option<&dyn AsyncDbConnection<Connection<'a>, ODBCParameter>> {
111 Some(self)
112 }
113}
114
115fn blocking_channel_send<T>(channel: &Sender<T>, item: T) -> Result<()> {
116 match channel.blocking_send(item) {
117 Ok(()) => Ok(()),
118 Err(e) => Err(Error::ChannelError {
119 message: format!("{e}"),
120 }
121 .into()),
122 }
123}
124
125#[async_trait]
126impl<'a> AsyncDbConnection<Connection<'a>, ODBCParameter> for ODBCConnection<'a>
127where
128 'a: 'static,
129{
130 fn new(conn: Connection<'a>) -> Self {
131 ODBCConnection {
132 conn: Arc::new(conn.into()),
133 params: Arc::new(HashMap::new()),
134 }
135 }
136
137 async fn tables(&self, _schema: &str) -> Result<Vec<String>, super::Error> {
138 unimplemented!()
139 }
140
141 async fn schemas(&self) -> Result<Vec<String>, super::Error> {
142 unimplemented!()
143 }
144
145 #[must_use]
146 async fn get_schema(
147 &self,
148 table_reference: &TableReference,
149 ) -> Result<SchemaRef, dbconnection::Error> {
150 let cxn = self.conn.lock().await;
151
152 let mut prepared = cxn
153 .prepare(&format!(
154 "SELECT * FROM {} LIMIT 1",
155 table_reference.to_quoted_string()
156 ))
157 .boxed()
158 .map_err(|e| dbconnection::Error::UnableToGetSchema { source: e })?;
159
160 let schema = Arc::new(
161 arrow_schema_from(&mut prepared, false)
162 .boxed()
163 .map_err(|e| dbconnection::Error::UnableToGetSchema { source: e })?,
164 );
165
166 Ok(schema)
167 }
168
169 async fn query_arrow(
170 &self,
171 sql: &str,
172 params: &[ODBCParameter],
173 _projected_schema: Option<SchemaRef>,
174 ) -> Result<SendableRecordBatchStream> {
175 let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
177 let (schema_tx, mut schema_rx) = tokio::sync::mpsc::channel::<Arc<Schema>>(1);
178
179 let conn = Arc::clone(&self.conn); let sql = sql.to_string();
182
183 let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();
186 let secrets = Arc::clone(&self.params);
187
188 let create_stream = async || -> Result<SendableRecordBatchStream> {
189 let join_handle = tokio::task::spawn_blocking(move || {
190 let handle = Handle::current();
191 let cxn = handle.block_on(async { conn.lock().await });
192
193 let mut prepared = cxn.prepare(&sql)?;
194 let schema = Arc::new(arrow_schema_from(&mut prepared, false)?);
195 blocking_channel_send(&schema_tx, Arc::clone(&schema))?;
196
197 let mut statement = prepared.into_statement();
198
199 bind_parameters(&mut statement, ¶ms)?;
200
201 let cursor = unsafe {
203 if let SqlResult::Error { function } = statement.execute() {
204 return Err(Error::ODBCAPIErrorNoSource {
205 message: function.to_string(),
206 }
207 .into());
208 }
209
210 Ok::<_, GenericError>(CursorImpl::new(statement.as_stmt_ref()))
211 }?;
212
213 let reader = build_odbc_reader(cursor, &schema, &secrets)?;
214 for batch in reader {
215 blocking_channel_send(&batch_tx, batch.context(ArrowSnafu)?)?;
216 }
217
218 Ok::<_, GenericError>(())
219 });
220
221 let Some(schema) = schema_rx.recv().await else {
223 if !join_handle.is_finished() {
225 unreachable!("Schema channel should not have dropped before the task finished");
226 }
227
228 let result = join_handle.await?;
229 let Err(err) = result else {
230 unreachable!("Task should have errored");
231 };
232
233 return Err(err);
234 };
235
236 let output_stream = stream! {
237 while let Some(batch) = batch_rx.recv().await {
238 yield Ok(batch);
239 }
240
241 if let Err(e) = join_handle.await {
242 yield Err(DataFusionError::Execution(format!(
243 "Failed to execute ODBC query: {e}"
244 )))
245 }
246 };
247
248 let result: SendableRecordBatchStream =
249 Box::pin(RecordBatchStreamAdapter::new(schema, output_stream));
250 Ok(result)
251 };
252 run_async_with_tokio(create_stream).await
253 }
254
255 async fn execute(&self, query: &str, params: &[ODBCParameter]) -> Result<u64> {
256 let cxn = self.conn.lock().await;
257 let prepared = cxn.prepare(query)?;
258 let mut statement = prepared.into_statement();
259
260 bind_parameters(&mut statement, params)?;
261
262 let row_count = unsafe {
263 statement.execute().unwrap();
264 statement.row_count()
265 };
266
267 Ok(row_count.unwrap().try_into().context(TryFromSnafu)?)
268 }
269}
270
271fn build_odbc_reader<C: Cursor>(
272 cursor: C,
273 schema: &Arc<Schema>,
274 params: &HashMap<String, SecretString>,
275) -> Result<OdbcReader<C>, Error> {
276 let mut builder = OdbcReaderBuilder::new();
277 builder.with_schema(Arc::clone(schema));
278
279 let bind_as_usize = |k: &str, default: Option<usize>, f: &mut dyn FnMut(usize)| {
280 params
281 .get(k)
282 .map(SecretBox::expose_secret)
283 .and_then(|s| s.parse::<usize>().ok())
284 .or(default)
285 .into_iter()
286 .for_each(f);
287 };
288
289 bind_as_usize("max_binary_size", None, &mut |s| {
290 builder.with_max_binary_size(s);
291 });
292 bind_as_usize("max_text_size", None, &mut |s| {
293 builder.with_max_text_size(s);
294 });
295 bind_as_usize("max_bytes_per_batch", Some(512_000_000), &mut |s| {
296 builder.with_max_bytes_per_batch(s);
297 });
298
299 bind_as_usize("max_num_rows_per_batch", Some(4000), &mut |s| {
302 builder.with_max_num_rows_per_batch(s);
303 });
304
305 builder.build(cursor).context(ArrowODBCSnafu)
306}
307
308fn bind_parameters(statement: &mut StatementImpl, params: &[ODBCParameter]) -> Result<(), Error> {
312 for (i, param) in params.iter().enumerate() {
313 unsafe {
314 statement
315 .bind_input_parameter(
316 (i + 1).try_into().context(UnableToBindIntParameterSnafu)?,
317 param.as_input_parameter(),
318 )
319 .unwrap();
320 }
321 }
322
323 Ok(())
324}