datafusion_remote_table/connection/oracle.rs

use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteType,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
    Float64Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, TimestampNanosecondBuilder,
    TimestampSecondBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use log::debug;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::any::Any;
use std::sync::Arc;

#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            pool_max_size: 10,
            stream_chunk_size: 2048,
        }
    }
}
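
// Usage sketch (not part of this module's API surface): `new` fills in the
// defaults `pool_max_size = 10` and `stream_chunk_size = 2048`; the `with_*`
// setter shown below is assumed to be generated by `derive_with`, and the
// connection parameters are placeholders.
//
//     let options = OracleConnectionOptions::new("localhost", 1521, "scott", "tiger", "FREEPDB1")
//         .with_stream_chunk_size(4096);
//     let conn_options: ConnectionOptions = options.clone().into();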

impl From<OracleConnectionOptions> for ConnectionOptions {
    fn from(options: OracleConnectionOptions) -> Self {
        ConnectionOptions::Oracle(options)
    }
}

#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    // Eagerly open (and immediately drop) one connection so that credential or
    // connectivity errors surface here rather than on the first pool checkout.
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}
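
// Sketch of obtaining a pooled connection and inferring the remote schema.
// `connect_oracle` validates connectivity eagerly, so most setup errors
// surface there; the query text is a placeholder and error handling is elided.
//
//     let pool = connect_oracle(&options).await?;
//     let conn = pool.get().await?;
//     let remote_schema = conn.infer_schema("SELECT * FROM EMP").await?;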

#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }

    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Oracle.query_limit_1(sql);
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&result_set)?);
        Ok(remote_schema)
    }

    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Oracle.rewrite_query(sql, unparsed_filters, limit);
        debug!("[remote-table] executing oracle query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // Chunk the row iterator so that each stream item converts at most
        // `stream_chunk_size` rows into a single RecordBatch.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
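
// Sketch of running a query and draining the resulting record-batch stream
// with `futures::StreamExt`. The SQL, schema, and limit are placeholders and
// error handling is elided.
//
//     let conn_options: ConnectionOptions = options.clone().into();
//     let mut stream = conn
//         .query(&conn_options, "SELECT * FROM EMP", table_schema, None, &[], Some(100))
//         .await?;
//     while let Some(batch) = stream.next().await {
//         println!("fetched {} rows", batch?.num_rows());
//     }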

fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            // TODO need more investigation on the precision and scale
            let precision = if *precision == 0 { 38 } else { *precision };
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(OracleType::Number(precision, scale))
        }
        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
        ColumnType::Long => Ok(OracleType::Long),
        ColumnType::CLOB => Ok(OracleType::Clob),
        ColumnType::NCLOB => Ok(OracleType::NClob),
        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
        ColumnType::LongRaw => Ok(OracleType::LongRaw),
        ColumnType::BLOB => Ok(OracleType::Blob),
        ColumnType::Date => Ok(OracleType::Date),
        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
        ColumnType::Boolean => Ok(OracleType::Boolean),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}

fn build_remote_schema(result_set: &oracle::ResultSet<Row>) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in result_set.column_info() {
        let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
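
// Roughly, one expansion of `handle_primitive_type!` for an `Int32` column
// behaves like the sketch below (assuming `just_return` simply wraps the value
// in `Ok`): downcast the dynamic builder, read the column as `Option<i32>`,
// then append the converted value or a null.
//
//     let builder = builder.as_any_mut().downcast_mut::<Int32Builder>().unwrap();
//     let v = row
//         .get::<usize, Option<i32>>(idx)
//         .map_err(|e| DataFusionError::Execution(e.to_string()))?;
//     match v {
//         Some(v) => builder.append_value(v),
//         None => builder.append_null(),
//     }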

fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // Builders are created for every column of the table schema so that the
    // builder index matches the column index; only projected columns are
    // appended to and finished below.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    // Record the row count explicitly so that an empty projection still yields
    // a batch with the correct number of rows.
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}
461}