datafusion_remote_table/connection/oracle.rs

use crate::connection::{RemoteDbType, just_return, projections_contains};
use crate::utils::big_decimal_to_i128;
use crate::{
    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteSource, RemoteType, Unparse,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
    Float64Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, TimestampNanosecondBuilder,
    TimestampSecondBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use log::debug;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::any::Any;
use std::sync::Arc;

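/// Options for connecting to an Oracle database. `new` fills in the required
/// connection fields and defaults `pool_max_size` to 10 and `stream_chunk_size`
/// to 2048; the derived `With` setters can be used to override those defaults.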
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            pool_max_size: 10,
            stream_chunk_size: 2048,
        }
    }
}

impl From<OracleConnectionOptions> for ConnectionOptions {
    fn from(options: OracleConnectionOptions) -> Self {
        ConnectionOptions::Oracle(options)
    }
}
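// A minimal usage sketch (not part of the crate's tests); host, port, credentials
// and service name below are placeholders:
//
//     let options = OracleConnectionOptions::new("localhost", 1521, "scott", "tiger", "ORCLPDB1");
//     let conn_options: ConnectionOptions = options.into();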

#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

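/// Builds a bb8 connection pool from the given options using an EZConnect-style
/// connect string (`//host:port/service_name`). A single connection is opened and
/// immediately dropped so that an unreachable host or bad credentials fail here
/// rather than on first use of the pool.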
pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}

#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }

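    /// Infers the remote schema by running a one-row probe query (when the source
    /// allows a `limit 1` rewrite) and reading the column metadata of the result set.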
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Oracle.limit_1_query_if_possible(source);
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&result_set)?);
        Ok(remote_schema)
    }

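    /// Streams query results: filters and limit are pushed down into the SQL, rows
    /// are fetched in chunks of `stream_chunk_size`, and each chunk is converted
    /// into one `RecordBatch` containing only the projected columns.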
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        // Push the already-unparsed filters and the limit down into the SQL sent to Oracle.
        let sql = RemoteDbType::Oracle.rewrite_query(source, unparsed_filters, limit);
        debug!("[remote-table] executing oracle query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        // Convert each chunk of rows into an Arrow record batch as the stream is polled.
        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }

    async fn insert(
        &self,
        _conn_options: &ConnectionOptions,
        _unparser: Arc<dyn Unparse>,
        _table: &[String],
        _remote_schema: RemoteSchemaRef,
        _input: SendableRecordBatchStream,
    ) -> DFResult<usize> {
        Err(DataFusionError::Execution(
            "Insert operation is not supported for oracle".to_string(),
        ))
    }
}

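/// Maps the driver-reported column type to this crate's `OracleType`. Oracle
/// typically reports precision 0 and scale -127 for an unconstrained `NUMBER`,
/// so those values are normalized to the widest supported decimal (precision 38,
/// scale 0); see the TODO below.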
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            // TODO need more investigation on the precision and scale
            let precision = if *precision == 0 { 38 } else { *precision };
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(OracleType::Number(precision, scale))
        }
        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
        ColumnType::Long => Ok(OracleType::Long),
        ColumnType::CLOB => Ok(OracleType::Clob),
        ColumnType::NCLOB => Ok(OracleType::NClob),
        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
        ColumnType::LongRaw => Ok(OracleType::LongRaw),
        ColumnType::BLOB => Ok(OracleType::Blob),
        ColumnType::Date => Ok(OracleType::Date),
        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
        ColumnType::Boolean => Ok(OracleType::Boolean),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}

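/// Builds a `RemoteSchema` from the column metadata of an executed Oracle result set.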
fn build_remote_schema(result_set: &oracle::ResultSet<Row>) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in result_set.column_info() {
        let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

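/// Appends one (possibly NULL) value from an Oracle row to an Arrow array builder:
/// downcasts the dynamic builder to `$builder_ty`, reads column `$index` as
/// `Option<$value_ty>`, applies `$convert`, and appends the converted value or a null.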
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}

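/// Converts a chunk of Oracle rows into a single `RecordBatch`. Builders are created
/// for every column in the table schema, but only projected columns are appended to;
/// the unused builders are discarded when the projected arrays are finished.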
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // One builder per column of the full table schema; only projected columns are filled.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32))
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Finish only the projected builders, keeping them in schema order.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}