// datafusion_remote_table/connection/oracle.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9    Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
10    RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{DataFusionError, project_schema};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use derive_getters::Getters;
17use derive_with::With;
18use futures::StreamExt;
19use oracle::sql_type::OracleType as ColumnType;
20use oracle::{Connector, Row};
21use std::sync::Arc;
22
/// Connection settings for an Oracle database, consumed by `connect_oracle`.
///
/// `With` derives builder-style `with_*` setters and `Getters` derives read
/// accessors, so callers outside this crate can tweak fields without the
/// fields being public.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    // Hostname or IP address of the Oracle listener.
    pub(crate) host: String,
    // Listener TCP port.
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    // Oracle service name used in the EZCONNECT string (`//host:port/service`).
    pub(crate) service_name: String,
    // Maximum number of pooled connections (defaults to 10 in `new`).
    pub(crate) pool_max_size: usize,
    // Number of rows gathered into each streamed RecordBatch (defaults to 2048).
    pub(crate) stream_chunk_size: usize,
}
33
34impl OracleConnectionOptions {
35    pub fn new(
36        host: impl Into<String>,
37        port: u16,
38        username: impl Into<String>,
39        password: impl Into<String>,
40        service_name: impl Into<String>,
41    ) -> Self {
42        Self {
43            host: host.into(),
44            port,
45            username: username.into(),
46            password: password.into(),
47            service_name: service_name.into(),
48            pool_max_size: 10,
49            stream_chunk_size: 2048,
50        }
51    }
52}
53
/// `Pool` implementation backed by a bb8 pool of Oracle connections.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
58
59pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
60    let connect_string = format!(
61        "//{}:{}/{}",
62        options.host, options.port, options.service_name
63    );
64    let connector = Connector::new(
65        options.username.clone(),
66        options.password.clone(),
67        connect_string,
68    );
69    let _ = connector
70        .connect()
71        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
72    let manager = OracleConnectionManager::from_connector(connector);
73    let pool = bb8::Pool::builder()
74        .max_size(options.pool_max_size as u32)
75        .build(manager)
76        .await
77        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
78    Ok(OraclePool { pool })
79}
80
81#[async_trait::async_trait]
82impl Pool for OraclePool {
83    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
84        let conn = self.pool.get_owned().await.map_err(|e| {
85            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
86        })?;
87        Ok(Arc::new(OracleConnection { conn }))
88    }
89}
90
/// A single Oracle connection checked out from `OraclePool`.
///
/// The `'static` lifetime is possible because the connection is obtained via
/// `get_owned`, which detaches it from the pool's borrow.
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
95
96#[async_trait::async_trait]
97impl Connection for OracleConnection {
98    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
99        let sql = RemoteDbType::Oracle.query_limit_1(sql)?;
100        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
101            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
102        })?;
103        let remote_schema = Arc::new(build_remote_schema(&row)?);
104        Ok(remote_schema)
105    }
106
107    async fn query(
108        &self,
109        conn_options: &ConnectionOptions,
110        sql: &str,
111        table_schema: SchemaRef,
112        projection: Option<&Vec<usize>>,
113        unparsed_filters: &[String],
114        limit: Option<usize>,
115    ) -> DFResult<SendableRecordBatchStream> {
116        let projected_schema = project_schema(&table_schema, projection)?;
117        let sql = RemoteDbType::Oracle.try_rewrite_query(sql, unparsed_filters, limit)?;
118        let projection = projection.cloned();
119        let chunk_size = conn_options.stream_chunk_size();
120        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
121            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
122        })?;
123        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();
124
125        let stream = stream.map(move |rows| {
126            let rows: Vec<Row> = rows
127                .into_iter()
128                .collect::<Result<Vec<_>, _>>()
129                .map_err(|e| {
130                    DataFusionError::Execution(format!(
131                        "Failed to collect rows from oracle due to {e}",
132                    ))
133                })?;
134            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
135        });
136
137        Ok(Box::pin(RecordBatchStreamAdapter::new(
138            projected_schema,
139            stream,
140        )))
141    }
142}
143
144fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
145    match oracle_type {
146        ColumnType::Number(precision, scale) => {
147            // TODO need more investigation on the precision and scale
148            let precision = if *precision == 0 { 38 } else { *precision };
149            let scale = if *scale == -127 { 0 } else { *scale };
150            Ok(OracleType::Number(precision, scale))
151        }
152        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
153        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
154        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
155        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
156        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
157        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
158        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
159        ColumnType::Long => Ok(OracleType::Long),
160        ColumnType::CLOB => Ok(OracleType::Clob),
161        ColumnType::NCLOB => Ok(OracleType::NClob),
162        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
163        ColumnType::LongRaw => Ok(OracleType::LongRaw),
164        ColumnType::BLOB => Ok(OracleType::Blob),
165        ColumnType::Date => Ok(OracleType::Date),
166        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
167        ColumnType::Boolean => Ok(OracleType::Boolean),
168        _ => Err(DataFusionError::NotImplemented(format!(
169            "Unsupported oracle type: {oracle_type:?}",
170        ))),
171    }
172}
173
174fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
175    let mut remote_fields = vec![];
176    for col in row.column_info() {
177        let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
178        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
179    }
180    Ok(RemoteSchema::new(remote_fields))
181}
182
// Appends one (possibly NULL) column value from an oracle `Row` onto the
// Arrow builder for that column.
//
//   $builder    - the type-erased `ArrayBuilder` for this column
//   $builder_ty - concrete Arrow builder type to downcast to (e.g. `Int32Builder`)
//   $value_ty   - Rust type requested from the driver row
//   $index      - column index into the row
//   $convert    - fallible conversion applied to non-NULL values before
//                 appending (`just_return` for the identity case)
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        // A downcast failure means the Arrow DataType and the chosen builder
        // type disagree — a programming error in the caller's match arm, so
        // panic with enough context to locate it.
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        // Fetch as Option so SQL NULL maps to a null entry instead of an error.
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
211
/// Converts a slice of Oracle driver rows into an Arrow `RecordBatch`
/// containing only the columns selected by `projection`.
///
/// Builders are allocated for *every* field in `table_schema` so that the
/// builder index lines up with the row's column index; builders for
/// non-projected columns are never appended to and are filtered out before
/// the batch is assembled.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        // Capacity hint only; builders grow as needed.
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            // Skip columns the caller did not ask for.
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used for error/panic context below.
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                // NUMBER columns are fetched as strings and re-parsed through
                // BigDecimal so no precision is lost on the driver boundary.
                // The declared precision is not validated here; only the target
                // scale is applied during the i128 conversion.
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                // Timestamps are read as naive datetimes and interpreted as
                // UTC when converted to epoch values — NOTE(review): confirm
                // this matches the session time-zone semantics upstream.
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            // timestamp_nanos_opt is None when the value
                            // overflows the i64 nanosecond range.
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                // Oracle DATE carries a time component, hence Date64 (millis).
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Finish only the projected builders, preserving projection order by index.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}