datafusion_table_providers/sql/db_connection_pool/dbconnection/odbcconn.rs

/*
Copyright 2024 The Spice.ai OSS Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
17use std::any::Any;
18use std::collections::HashMap;
19use std::sync::Arc;
20
21use crate::sql::db_connection_pool::{
22    dbconnection::{self, AsyncDbConnection, DbConnection, GenericError},
23    runtime::run_async_with_tokio,
24    DbConnectionPool,
25};
26use arrow_odbc::arrow_schema_from;
27use arrow_odbc::OdbcReader;
28use arrow_odbc::OdbcReaderBuilder;
29use async_stream::stream;
30use async_trait::async_trait;
31use datafusion::arrow::datatypes::Schema;
32use datafusion::arrow::datatypes::SchemaRef;
33use datafusion::arrow::record_batch::RecordBatch;
34use datafusion::error::DataFusionError;
35use datafusion::execution::SendableRecordBatchStream;
36use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
37use datafusion::sql::TableReference;
38use dyn_clone::DynClone;
39use futures::lock::Mutex;
40use odbc_api::handles::SqlResult;
41use odbc_api::handles::Statement;
42use odbc_api::handles::StatementImpl;
43use odbc_api::parameter::InputParameter;
44use odbc_api::Cursor;
45use odbc_api::CursorImpl;
46use secrecy::{ExposeSecret, SecretBox, SecretString};
47use snafu::prelude::*;
48use snafu::Snafu;
49use tokio::runtime::Handle;
50
51use odbc_api::Connection;
52use tokio::sync::mpsc::Sender;
53
/// Module-local result alias; errors default to the boxed [`GenericError`].
type Result<T, E = GenericError> = std::result::Result<T, E>;

/// An ODBC input parameter that is also `Sync + Send` and object-safely
/// clonable, so boxed parameters can be cloned and moved into worker tasks.
pub trait ODBCSyncParameter: InputParameter + Sync + Send + DynClone {
    /// Upcasts to the plain `odbc_api` parameter trait for binding.
    fn as_input_parameter(&self) -> &dyn InputParameter;
}

// Blanket impl: any clonable, thread-safe `InputParameter` qualifies.
impl<T: InputParameter + Sync + Send + DynClone> ODBCSyncParameter for T {
    fn as_input_parameter(&self) -> &dyn InputParameter {
        self
    }
}

// Makes `Box<dyn ODBCSyncParameter>` implement `Clone` via DynClone.
dyn_clone::clone_trait_object!(ODBCSyncParameter);

/// A boxed, clonable ODBC query parameter.
pub type ODBCParameter = Box<dyn ODBCSyncParameter>;
/// Trait-object alias for an ODBC-backed [`DbConnection`].
pub type ODBCDbConnection<'a> = (dyn DbConnection<Connection<'a>, ODBCParameter>);
/// Trait-object alias for a pool producing ODBC connections.
pub type ODBCDbConnectionPool<'a> =
    dyn DbConnectionPool<Connection<'a>, ODBCParameter> + Sync + Send;
72
/// Errors produced while executing ODBC queries and converting results to Arrow.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A record batch failed to convert while reading query results.
    #[snafu(display("Failed to convert query result to Arrow: {source}"))]
    ArrowError {
        source: datafusion::arrow::error::ArrowError,
    },
    /// Error raised by the `arrow_odbc` reader/builder.
    #[snafu(display("arrow_odbc error: {source}"))]
    ArrowODBCError { source: arrow_odbc::Error },
    /// Error raised by `odbc_api` with a typed source.
    #[snafu(display("odbc_api Error: {source}"))]
    ODBCAPIError { source: odbc_api::Error },
    /// Driver-level failure reported only by the ODBC function name
    /// (from `SqlResult::Error { function }`), with no source error value.
    #[snafu(display("odbc_api Error: {message}"))]
    ODBCAPIErrorNoSource { message: String },
    /// Row count (or similar integer) did not fit the target integer type.
    #[snafu(display("Failed to convert query result to Arrow: {source}"))]
    TryFromError { source: std::num::TryFromIntError },
    /// A parameter ordinal exceeded the ODBC bind-index integer type.
    #[snafu(display("Unable to bind integer parameter: {source}"))]
    UnableToBindIntParameter { source: std::num::TryFromIntError },
    /// A tokio channel between the blocking query task and the async
    /// consumer closed unexpectedly.
    #[snafu(display("Internal communication channel error: {message}"))]
    ChannelError { message: String },
}
92
/// An ODBC database connection plus its configuration parameters.
pub struct ODBCConnection<'a> {
    // Async mutex so the guard can be held across `.await` points; cloning
    // the `Arc` shares the same underlying connection with worker tasks.
    pub conn: Arc<Mutex<Connection<'a>>>,
    // String options (kept as secrets) consulted for reader tuning such as
    // `max_binary_size` / `max_num_rows_per_batch`; presumably populated by
    // the connection pool — values beyond the reader keys are not read here.
    pub params: Arc<HashMap<String, SecretString>>,
}
97
impl<'a> DbConnection<Connection<'a>, ODBCParameter> for ODBCConnection<'a>
where
    'a: 'static,
{
    // Enables downcasting to the concrete connection type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    // This connection is always async-capable, so expose the async interface.
    fn as_async(&self) -> Option<&dyn AsyncDbConnection<Connection<'a>, ODBCParameter>> {
        Some(self)
    }
}
114
115fn blocking_channel_send<T>(channel: &Sender<T>, item: T) -> Result<()> {
116    match channel.blocking_send(item) {
117        Ok(()) => Ok(()),
118        Err(e) => Err(Error::ChannelError {
119            message: format!("{e}"),
120        }
121        .into()),
122    }
123}
124
125#[async_trait]
126impl<'a> AsyncDbConnection<Connection<'a>, ODBCParameter> for ODBCConnection<'a>
127where
128    'a: 'static,
129{
130    fn new(conn: Connection<'a>) -> Self {
131        ODBCConnection {
132            conn: Arc::new(conn.into()),
133            params: Arc::new(HashMap::new()),
134        }
135    }
136
137    async fn tables(&self, _schema: &str) -> Result<Vec<String>, super::Error> {
138        unimplemented!()
139    }
140
141    async fn schemas(&self) -> Result<Vec<String>, super::Error> {
142        unimplemented!()
143    }
144
145    #[must_use]
146    async fn get_schema(
147        &self,
148        table_reference: &TableReference,
149    ) -> Result<SchemaRef, dbconnection::Error> {
150        let cxn = self.conn.lock().await;
151
152        let mut prepared = cxn
153            .prepare(&format!(
154                "SELECT * FROM {} LIMIT 1",
155                table_reference.to_quoted_string()
156            ))
157            .boxed()
158            .map_err(|e| dbconnection::Error::UnableToGetSchema { source: e })?;
159
160        let schema = Arc::new(
161            arrow_schema_from(&mut prepared, None, false)
162                .boxed()
163                .map_err(|e| dbconnection::Error::UnableToGetSchema { source: e })?,
164        );
165
166        Ok(schema)
167    }
168
169    async fn query_arrow(
170        &self,
171        sql: &str,
172        params: &[ODBCParameter],
173        _projected_schema: Option<SchemaRef>,
174    ) -> Result<SendableRecordBatchStream> {
175        // prepare some tokio channels to communicate query results back from the thread
176        let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
177        let (schema_tx, mut schema_rx) = tokio::sync::mpsc::channel::<Arc<Schema>>(1);
178
179        // clone internals and parameters to let the thread own them
180        let conn = Arc::clone(&self.conn); // clones the mutex not the connection, so we can .lock a connection inside the thread
181        let sql = sql.to_string();
182
183        // ODBCParameter is a dynamic trait object, so we can't use std::clone::Clone because it's not object safe
184        // DynClone provides an object-safe clone trait, which we use to clone the boxed parameters
185        let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();
186        let secrets = Arc::clone(&self.params);
187
188        let create_stream = async || -> Result<SendableRecordBatchStream> {
189            let join_handle = tokio::task::spawn_blocking(move || {
190                let handle = Handle::current();
191                let cxn = handle.block_on(async { conn.lock().await });
192
193                let mut prepared = cxn.prepare(&sql)?;
194                let schema = Arc::new(arrow_schema_from(&mut prepared, None, false)?);
195                blocking_channel_send(&schema_tx, Arc::clone(&schema))?;
196
197                let mut statement = prepared.into_handle();
198
199                bind_parameters(&mut statement, &params)?;
200
201                // StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
202                let cursor = unsafe {
203                    if let SqlResult::Error { function } = statement.execute() {
204                        return Err(Error::ODBCAPIErrorNoSource {
205                            message: function.to_string(),
206                        }
207                        .into());
208                    }
209
210                    Ok::<_, GenericError>(CursorImpl::new(statement.as_stmt_ref()))
211                }?;
212
213                let reader = build_odbc_reader(cursor, &schema, &secrets)?;
214                for batch in reader {
215                    blocking_channel_send(&batch_tx, batch.context(ArrowSnafu)?)?;
216                }
217
218                Ok::<_, GenericError>(())
219            });
220
221            // we need to wait for the schema first before we can build our RecordBatchStreamAdapter
222            let Some(schema) = schema_rx.recv().await else {
223                // if the channel drops, the task errored
224                if !join_handle.is_finished() {
225                    unreachable!("Schema channel should not have dropped before the task finished");
226                }
227
228                let result = join_handle.await?;
229                let Err(err) = result else {
230                    unreachable!("Task should have errored");
231                };
232
233                return Err(err);
234            };
235
236            let output_stream = stream! {
237                while let Some(batch) = batch_rx.recv().await {
238                    yield Ok(batch);
239                }
240
241                if let Err(e) = join_handle.await {
242                    yield Err(DataFusionError::Execution(format!(
243                        "Failed to execute ODBC query: {e}"
244                    )))
245                }
246            };
247
248            let result: SendableRecordBatchStream =
249                Box::pin(RecordBatchStreamAdapter::new(schema, output_stream));
250            Ok(result)
251        };
252        run_async_with_tokio(create_stream).await
253    }
254
255    async fn execute(&self, query: &str, params: &[ODBCParameter]) -> Result<u64> {
256        let cxn = self.conn.lock().await;
257        let prepared = cxn.prepare(query)?;
258        let mut statement = prepared.into_handle();
259
260        bind_parameters(&mut statement, params)?;
261
262        let row_count = unsafe {
263            statement.execute().unwrap();
264            statement.row_count()
265        };
266
267        Ok(row_count.unwrap().try_into().context(TryFromSnafu)?)
268    }
269}
270
271fn build_odbc_reader<C: Cursor>(
272    cursor: C,
273    schema: &Arc<Schema>,
274    params: &HashMap<String, SecretString>,
275) -> Result<OdbcReader<C>, Error> {
276    let mut builder = OdbcReaderBuilder::new();
277    builder.with_schema(Arc::clone(schema));
278
279    let bind_as_usize = |k: &str, default: Option<usize>, f: &mut dyn FnMut(usize)| {
280        params
281            .get(k)
282            .map(SecretBox::expose_secret)
283            .and_then(|s| s.parse::<usize>().ok())
284            .or(default)
285            .into_iter()
286            .for_each(f);
287    };
288
289    bind_as_usize("max_binary_size", None, &mut |s| {
290        builder.with_max_binary_size(s);
291    });
292    bind_as_usize("max_text_size", None, &mut |s| {
293        builder.with_max_text_size(s);
294    });
295    bind_as_usize("max_bytes_per_batch", Some(512_000_000), &mut |s| {
296        builder.with_max_bytes_per_batch(s);
297    });
298
299    // larger default max_num_rows_per_batch reduces IO overhead but increases memory usage
300    // lower numbers reduce memory usage but increase IO overhead
301    bind_as_usize("max_num_rows_per_batch", Some(4000), &mut |s| {
302        builder.with_max_num_rows_per_batch(s);
303    });
304
305    builder.build(cursor).context(ArrowODBCSnafu)
306}
307
308/// Binds parameter to an ODBC statement.
309///
310/// `StatementImpl<'_>::bind_input_parameter` is unsafe.
311fn bind_parameters(statement: &mut StatementImpl, params: &[ODBCParameter]) -> Result<(), Error> {
312    for (i, param) in params.iter().enumerate() {
313        unsafe {
314            statement
315                .bind_input_parameter(
316                    (i + 1).try_into().context(UnableToBindIntParameterSnafu)?,
317                    param.as_input_parameter(),
318                )
319                .unwrap();
320        }
321    }
322
323    Ok(())
324}