datafusion_remote_table/connection/oracle.rs

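//! Oracle connector: connects through the `oracle` driver with `bb8_oracle`
//! pooling, infers a `RemoteSchema` from a probe query, and streams query
//! results as Arrow `RecordBatch`es by appending Oracle column values into
//! Arrow array builders.
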
use crate::connection::{big_decimal_to_i128, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteType,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
    Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
    RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::limit::LimitStream;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

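/// Connection options for an Oracle data source.
///
/// `new` fills in defaults of 10 pooled connections (`pool_max_size`) and a
/// 2048-row `stream_chunk_size`. The `With` derive is assumed to generate
/// `with_*` setters for overriding individual fields, e.g. (placeholder host,
/// credentials, and service name):
///
/// ```ignore
/// let options = OracleConnectionOptions::new("db-host", 1521, "scott", "tiger", "ORCLPDB1")
///     .with_stream_chunk_size(4096);
/// ```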
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            pool_max_size: 10,
            stream_chunk_size: 2048,
        }
    }
}

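/// A bb8-managed pool of Oracle connections, handed out as `Arc<dyn Connection>`
/// through the [`Pool`] trait.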
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

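/// Builds an [`OraclePool`] from the given options.
///
/// The connect string uses the EZConnect form `//host:port/service_name`. One
/// connection is opened eagerly (and immediately dropped) so that bad
/// credentials or an unreachable host fail here rather than on first query;
/// the bb8 pool is then built from the same `Connector`.
///
/// A minimal sketch of end-to-end usage, with placeholder host, credentials,
/// and SQL:
///
/// ```ignore
/// let options = OracleConnectionOptions::new("db-host", 1521, "scott", "tiger", "ORCLPDB1");
/// let pool = connect_oracle(&options).await?;
/// let conn = pool.get().await?; // Arc<dyn Connection>
/// let (remote_schema, arrow_schema) = conn.infer_schema("SELECT * FROM employees").await?;
/// ```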
pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}

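/// A single pooled Oracle connection. Implements [`Connection`] for schema
/// inference and for streaming query results as Arrow record batches.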
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        // Fetch a single row (via ROWNUM when possible) just to read the column metadata.
        let sql = try_limit_query(sql, 1).unwrap_or_else(|| sql.to_string());
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        Ok((remote_schema, arrow_schema))
    }

    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        // Push the limit into the SQL via ROWNUM when possible; otherwise fall
        // back to wrapping the result in a LimitStream below.
        let (sql, limit_stream) = match limit {
            Some(limit) => {
                if let Some(limited_sql) = try_limit_query(sql, limit) {
                    (limited_sql, false)
                } else {
                    (sql.to_string(), true)
                }
            }
            None => (sql.to_string(), false),
        };
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        // Convert each chunk of rows into a projected RecordBatch.
        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        let sendable_stream = Box::pin(RecordBatchStreamAdapter::new(projected_schema, stream));

        if limit_stream {
            // The limit could not be pushed into the SQL, so enforce it on the stream.
            let metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
            Ok(Box::pin(LimitStream::new(
                sendable_stream,
                0,
                limit,
                metrics,
            )))
        } else {
            Ok(sendable_stream)
        }
    }
}

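/// Wraps a `SELECT` statement so that Oracle returns at most `limit` rows
/// (`SELECT * FROM (...) WHERE ROWNUM <= limit`). Returns `None` for anything
/// that does not start with `SELECT`, in which case callers limit client-side.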
fn try_limit_query(sql: &str, limit: usize) -> Option<String> {
    // Slice with `get` so statements shorter than 6 bytes (or with a non-ASCII
    // prefix) do not panic.
    let starts_with_select = sql
        .trim()
        .get(0..6)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("select"));
    if starts_with_select {
        Some(format!("SELECT * FROM ({sql}) WHERE ROWNUM <= {limit}"))
    } else {
        None
    }
}

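/// Maps an Oracle driver column type to the crate's [`RemoteType`].
///
/// `NUMBER` metadata is normalized: a reported precision of 0 is widened to the
/// maximum (38) and a scale of -127 is treated as 0 (see the TODO in the
/// `Number` arm; this mapping may need refinement).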
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            // TODO need more investigation on the precision and scale
            let precision = if *precision == 0 { 38 } else { *precision };
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(RemoteType::Oracle(OracleType::Number(precision, scale)))
        }
        ColumnType::BinaryFloat => Ok(RemoteType::Oracle(OracleType::BinaryFloat)),
        ColumnType::BinaryDouble => Ok(RemoteType::Oracle(OracleType::BinaryDouble)),
        ColumnType::Float(precision) => Ok(RemoteType::Oracle(OracleType::Float(*precision))),
        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
        ColumnType::NVarchar2(size) => Ok(RemoteType::Oracle(OracleType::NVarchar2(*size))),
        ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
        ColumnType::NChar(size) => Ok(RemoteType::Oracle(OracleType::NChar(*size))),
        ColumnType::Long => Ok(RemoteType::Oracle(OracleType::Long)),
        ColumnType::CLOB => Ok(RemoteType::Oracle(OracleType::Clob)),
        ColumnType::NCLOB => Ok(RemoteType::Oracle(OracleType::NClob)),
        ColumnType::Raw(size) => Ok(RemoteType::Oracle(OracleType::Raw(*size))),
        ColumnType::LongRaw => Ok(RemoteType::Oracle(OracleType::LongRaw)),
        ColumnType::BLOB => Ok(RemoteType::Oracle(OracleType::Blob)),
        ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
        ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
        ColumnType::Boolean => Ok(RemoteType::Oracle(OracleType::Boolean)),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}

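/// Builds a [`RemoteSchema`] from the column metadata of a fetched row.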
fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

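/// Downcasts the dynamic Arrow builder to `$builder_ty`, reads column `$index`
/// of `$row` as `Option<$value_ty>`, applies `$convert`, and appends the
/// converted value (or a null) to the builder.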
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}

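/// Converts a slice of Oracle rows into a single projected Arrow [`RecordBatch`].
///
/// One builder is allocated per column of the full table schema, but only the
/// projected columns are populated and finished into arrays.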
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Int32 => {
                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}