// datafusion_remote_table/connection/oracle.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9    Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
10    RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{DataFusionError, project_schema};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion::prelude::Expr;
17use derive_getters::Getters;
18use derive_with::With;
19use futures::StreamExt;
20use oracle::sql_type::OracleType as ColumnType;
21use oracle::{Connector, Row};
22use std::sync::Arc;
23
/// Connection settings for an Oracle database, consumed by `connect_oracle`.
///
/// `new` fills `pool_max_size` (10) and `stream_chunk_size` (2048) with
/// defaults; the derived `With` impl provides `with_*` setters to override
/// them, and `Getters` exposes read accessors.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    // Upper bound on connections held by the bb8 pool.
    pub(crate) pool_max_size: usize,
    // Number of rows gathered into each emitted RecordBatch when streaming.
    pub(crate) stream_chunk_size: usize,
}
34
35impl OracleConnectionOptions {
36    pub fn new(
37        host: impl Into<String>,
38        port: u16,
39        username: impl Into<String>,
40        password: impl Into<String>,
41        service_name: impl Into<String>,
42    ) -> Self {
43        Self {
44            host: host.into(),
45            port,
46            username: username.into(),
47            password: password.into(),
48            service_name: service_name.into(),
49            pool_max_size: 10,
50            stream_chunk_size: 2048,
51        }
52    }
53}
54
/// A bb8-backed Oracle connection pool, created by `connect_oracle`.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
59
60pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
61    let connect_string = format!(
62        "//{}:{}/{}",
63        options.host, options.port, options.service_name
64    );
65    let connector = Connector::new(
66        options.username.clone(),
67        options.password.clone(),
68        connect_string,
69    );
70    let _ = connector
71        .connect()
72        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
73    let manager = OracleConnectionManager::from_connector(connector);
74    let pool = bb8::Pool::builder()
75        .max_size(options.pool_max_size as u32)
76        .build(manager)
77        .await
78        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
79    Ok(OraclePool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for OraclePool {
84    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85        let conn = self.pool.get_owned().await.map_err(|e| {
86            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
87        })?;
88        Ok(Arc::new(OracleConnection { conn }))
89    }
90}
91
/// A single checked-out Oracle connection; implements `Connection` for
/// schema inference and streamed queries.
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
96
97#[async_trait::async_trait]
98impl Connection for OracleConnection {
99    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
100        let sql = RemoteDbType::Oracle.query_limit_1(sql)?;
101        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
102            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
103        })?;
104        let remote_schema = Arc::new(build_remote_schema(&row)?);
105        Ok(remote_schema)
106    }
107
108    async fn query(
109        &self,
110        conn_options: &ConnectionOptions,
111        sql: &str,
112        table_schema: SchemaRef,
113        projection: Option<&Vec<usize>>,
114        filters: &[Expr],
115        limit: Option<usize>,
116    ) -> DFResult<SendableRecordBatchStream> {
117        let projected_schema = project_schema(&table_schema, projection)?;
118        let sql = RemoteDbType::Oracle.try_rewrite_query(sql, filters, limit)?;
119        let projection = projection.cloned();
120        let chunk_size = conn_options.stream_chunk_size();
121        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
122            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
123        })?;
124        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();
125
126        let stream = stream.map(move |rows| {
127            let rows: Vec<Row> = rows
128                .into_iter()
129                .collect::<Result<Vec<_>, _>>()
130                .map_err(|e| {
131                    DataFusionError::Execution(format!(
132                        "Failed to collect rows from oracle due to {e}",
133                    ))
134                })?;
135            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
136        });
137
138        Ok(Box::pin(RecordBatchStreamAdapter::new(
139            projected_schema,
140            stream,
141        )))
142    }
143}
144
145fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
146    match oracle_type {
147        ColumnType::Number(precision, scale) => {
148            // TODO need more investigation on the precision and scale
149            let precision = if *precision == 0 { 38 } else { *precision };
150            let scale = if *scale == -127 { 0 } else { *scale };
151            Ok(OracleType::Number(precision, scale))
152        }
153        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
154        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
155        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
156        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
157        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
158        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
159        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
160        ColumnType::Long => Ok(OracleType::Long),
161        ColumnType::CLOB => Ok(OracleType::Clob),
162        ColumnType::NCLOB => Ok(OracleType::NClob),
163        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
164        ColumnType::LongRaw => Ok(OracleType::LongRaw),
165        ColumnType::BLOB => Ok(OracleType::Blob),
166        ColumnType::Date => Ok(OracleType::Date),
167        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
168        ColumnType::Boolean => Ok(OracleType::Boolean),
169        _ => Err(DataFusionError::NotImplemented(format!(
170            "Unsupported oracle type: {oracle_type:?}",
171        ))),
172    }
173}
174
175fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
176    let mut remote_fields = vec![];
177    for col in row.column_info() {
178        let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
179        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
180    }
181    Ok(RemoteSchema::new(remote_fields))
182}
183
/// Reads column `$index` of `$row` as `Option<$value_ty>`, converts the value
/// with `$convert` (a fallible `value -> DFResult<_>` function/closure), and
/// appends the result — or a null — to `$builder` downcast to `$builder_ty`.
///
/// Panics if the builder's concrete type does not match `$builder_ty`, which
/// would indicate a mismatch between the Arrow data type and the builder
/// created for it in `rows_to_batch`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        // Column values are fetched as Option so SQL NULL maps to a null entry.
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
212
213fn rows_to_batch(
214    rows: &[Row],
215    table_schema: &SchemaRef,
216    projection: Option<&Vec<usize>>,
217) -> DFResult<RecordBatch> {
218    let projected_schema = project_schema(table_schema, projection)?;
219    let mut array_builders = vec![];
220    for field in table_schema.fields() {
221        let builder = make_builder(field.data_type(), rows.len());
222        array_builders.push(builder);
223    }
224
225    for row in rows {
226        for (idx, field) in table_schema.fields.iter().enumerate() {
227            if !projections_contains(projection, idx) {
228                continue;
229            }
230            let builder = &mut array_builders[idx];
231            let col = row.column_info().get(idx);
232            match field.data_type() {
233                DataType::Int16 => {
234                    handle_primitive_type!(
235                        builder,
236                        field,
237                        col,
238                        Int16Builder,
239                        i16,
240                        row,
241                        idx,
242                        just_return
243                    );
244                }
245                DataType::Int32 => {
246                    handle_primitive_type!(
247                        builder,
248                        field,
249                        col,
250                        Int32Builder,
251                        i32,
252                        row,
253                        idx,
254                        just_return
255                    );
256                }
257                DataType::Float32 => {
258                    handle_primitive_type!(
259                        builder,
260                        field,
261                        col,
262                        Float32Builder,
263                        f32,
264                        row,
265                        idx,
266                        just_return
267                    );
268                }
269                DataType::Float64 => {
270                    handle_primitive_type!(
271                        builder,
272                        field,
273                        col,
274                        Float64Builder,
275                        f64,
276                        row,
277                        idx,
278                        just_return
279                    );
280                }
281                DataType::Utf8 => {
282                    handle_primitive_type!(
283                        builder,
284                        field,
285                        col,
286                        StringBuilder,
287                        String,
288                        row,
289                        idx,
290                        just_return
291                    );
292                }
293                DataType::LargeUtf8 => {
294                    handle_primitive_type!(
295                        builder,
296                        field,
297                        col,
298                        LargeStringBuilder,
299                        String,
300                        row,
301                        idx,
302                        just_return
303                    );
304                }
305                DataType::Decimal128(_precision, scale) => {
306                    handle_primitive_type!(
307                        builder,
308                        field,
309                        col,
310                        Decimal128Builder,
311                        String,
312                        row,
313                        idx,
314                        |v: String| {
315                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
316                                DataFusionError::Execution(format!(
317                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
318                                ))
319                            })?;
320                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
321                                DataFusionError::Execution(format!(
322                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
323                                ))
324                            })
325                        }
326                    );
327                }
328                DataType::Timestamp(TimeUnit::Second, None) => {
329                    handle_primitive_type!(
330                        builder,
331                        field,
332                        col,
333                        TimestampSecondBuilder,
334                        chrono::NaiveDateTime,
335                        row,
336                        idx,
337                        |v: chrono::NaiveDateTime| {
338                            let t = v.and_utc().timestamp();
339                            Ok::<_, DataFusionError>(t)
340                        }
341                    );
342                }
343                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
344                    handle_primitive_type!(
345                        builder,
346                        field,
347                        col,
348                        TimestampNanosecondBuilder,
349                        chrono::NaiveDateTime,
350                        row,
351                        idx,
352                        |v: chrono::NaiveDateTime| {
353                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
354                                DataFusionError::Execution(format!(
355                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
356                                ))
357                            })
358                        }
359                    );
360                }
361                DataType::Date64 => {
362                    handle_primitive_type!(
363                        builder,
364                        field,
365                        col,
366                        Date64Builder,
367                        chrono::NaiveDateTime,
368                        row,
369                        idx,
370                        |v: chrono::NaiveDateTime| {
371                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
372                        }
373                    );
374                }
375                DataType::Boolean => {
376                    handle_primitive_type!(
377                        builder,
378                        field,
379                        col,
380                        BooleanBuilder,
381                        bool,
382                        row,
383                        idx,
384                        just_return
385                    );
386                }
387                DataType::Binary => {
388                    handle_primitive_type!(
389                        builder,
390                        field,
391                        col,
392                        BinaryBuilder,
393                        Vec<u8>,
394                        row,
395                        idx,
396                        just_return
397                    );
398                }
399                DataType::LargeBinary => {
400                    handle_primitive_type!(
401                        builder,
402                        field,
403                        col,
404                        LargeBinaryBuilder,
405                        Vec<u8>,
406                        row,
407                        idx,
408                        just_return
409                    );
410                }
411                _ => {
412                    return Err(DataFusionError::NotImplemented(format!(
413                        "Unsupported data type {:?} for col: {:?}",
414                        field.data_type(),
415                        col
416                    )));
417                }
418            }
419        }
420    }
421
422    let projected_columns = array_builders
423        .into_iter()
424        .enumerate()
425        .filter(|(idx, _)| projections_contains(projection, *idx))
426        .map(|(_, mut builder)| builder.finish())
427        .collect::<Vec<ArrayRef>>();
428    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
429}