// datafusion_remote_table/connection/oracle.rs

use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteType,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
    Float64Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, TimestampNanosecondBuilder,
    TimestampSecondBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use log::debug;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            pool_max_size: 10,
            stream_chunk_size: 2048,
        }
    }
}
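
// Example (sketch, not from the library docs): constructing options for a local
// Oracle instance. The host, credentials, and service name below are placeholders,
// and `with_pool_max_size` is assumed to be one of the setters generated by the
// `With` derive on `OracleConnectionOptions`.
//
//     let options = OracleConnectionOptions::new("localhost", 1521, "scott", "tiger", "FREEPDB1")
//         .with_pool_max_size(20);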

#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

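/// Connects to Oracle using an EZCONNECT-style connect string of the form
/// `//host:port/service_name`, performs an eager test connection so that bad
/// credentials or an unreachable host fail fast, and then builds a `bb8` pool
/// capped at `pool_max_size` connections.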
pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}
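
// Example (sketch): creating the pool and checking out a connection. This assumes
// an async runtime (e.g. tokio) and reuses the placeholder `options` from the
// sketch above; `connect_oracle` is crate-private, so callers normally reach it
// through the crate's public entry points.
//
//     let pool = connect_oracle(&options).await?;
//     let conn = pool.get().await?;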

#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Oracle.query_limit_1(sql);
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Oracle.rewrite_query(sql, unparsed_filters, limit);
        debug!("[remote-table] executing oracle query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
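
// Example (sketch): the intended call sequence for this connection, with the
// `ConnectionOptions` value and the Arrow `table_schema` (derived elsewhere from
// the remote schema) left as placeholders.
//
//     let conn = pool.get().await?;
//     let remote_schema = conn.infer_schema("SELECT * FROM employees").await?;
//     let stream = conn
//         .query(&conn_options, "SELECT * FROM employees", table_schema, None, &[], Some(10))
//         .await?;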

fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            // Oracle reports NUMBER columns declared without a precision as precision 0,
            // and a scale of -127 marks an unscaled (float-style) NUMBER, so fall back to
            // the widest precision and a scale of 0 in those cases.
            // TODO need more investigation on the precision and scale
            let precision = if *precision == 0 { 38 } else { *precision };
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(OracleType::Number(precision, scale))
        }
        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
        ColumnType::Long => Ok(OracleType::Long),
        ColumnType::CLOB => Ok(OracleType::Clob),
        ColumnType::NCLOB => Ok(OracleType::NClob),
        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
        ColumnType::LongRaw => Ok(OracleType::LongRaw),
        ColumnType::BLOB => Ok(OracleType::Blob),
        ColumnType::Date => Ok(OracleType::Date),
        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
        ColumnType::Boolean => Ok(OracleType::Boolean),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}

fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

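/// Appends one column value of one row to an Arrow array builder: downcasts the
/// type-erased builder to the concrete `$builder_ty`, reads column `$index` of
/// `$row` as an `Option<$value_ty>`, then appends either `$convert(value)?` or a null.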
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}

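/// Converts a chunk of Oracle rows into an Arrow `RecordBatch`. Builders are
/// allocated for every column of `table_schema`, but values are appended (and
/// arrays emitted) only for the projected columns; the explicit row count in
/// `RecordBatchOptions` keeps batches with zero projected columns well-formed.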
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}