// datafusion_remote_table/connection/oracle.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9    Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
10    RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{DataFusionError, project_schema};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion::prelude::Expr;
17use derive_getters::Getters;
18use derive_with::With;
19use futures::StreamExt;
20use oracle::sql_type::OracleType as ColumnType;
21use oracle::{Connector, Row};
22use std::sync::Arc;
23
/// Options describing how to reach and query an Oracle database.
///
/// Constructed via [`OracleConnectionOptions::new`], which fills in defaults
/// for `pool_max_size` and `stream_chunk_size`.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    /// Hostname or IP address of the Oracle listener.
    pub(crate) host: String,
    /// Listener port.
    pub(crate) port: u16,
    /// Username used for authentication.
    pub(crate) username: String,
    /// Password used for authentication.
    pub(crate) password: String,
    /// Service name used in the connect descriptor (`//host:port/service`).
    pub(crate) service_name: String,
    /// Maximum number of pooled connections (defaults to 10 in `new`).
    pub(crate) pool_max_size: usize,
    /// Rows gathered per emitted `RecordBatch` (defaults to 2048 in `new`).
    pub(crate) stream_chunk_size: usize,
}
34
35impl OracleConnectionOptions {
36    pub fn new(
37        host: impl Into<String>,
38        port: u16,
39        username: impl Into<String>,
40        password: impl Into<String>,
41        service_name: impl Into<String>,
42    ) -> Self {
43        Self {
44            host: host.into(),
45            port,
46            username: username.into(),
47            password: password.into(),
48            service_name: service_name.into(),
49            pool_max_size: 10,
50            stream_chunk_size: 2048,
51        }
52    }
53}
54
/// A bb8-backed pool of Oracle connections; built by `connect_oracle`.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
59
60pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
61    let connect_string = format!(
62        "//{}:{}/{}",
63        options.host, options.port, options.service_name
64    );
65    let connector = Connector::new(
66        options.username.clone(),
67        options.password.clone(),
68        connect_string,
69    );
70    let _ = connector
71        .connect()
72        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
73    let manager = OracleConnectionManager::from_connector(connector);
74    let pool = bb8::Pool::builder()
75        .max_size(options.pool_max_size as u32)
76        .build(manager)
77        .await
78        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
79    Ok(OraclePool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for OraclePool {
84    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85        let conn = self.pool.get_owned().await.map_err(|e| {
86            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
87        })?;
88        Ok(Arc::new(OracleConnection { conn }))
89    }
90}
91
/// A single pooled Oracle connection implementing the `Connection` trait.
#[derive(Debug)]
pub struct OracleConnection {
    // Owned ('static) pool checkout so it can live inside an Arc<dyn Connection>.
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
96
#[async_trait::async_trait]
impl Connection for OracleConnection {
    /// Infers the remote schema by rewriting the query to fetch a single row
    /// and reading the driver's column metadata from that row.
    ///
    /// NOTE(review): `query_row` appears to error when the query yields no
    /// rows, so inference over an empty result set would fail — confirm this
    /// is the intended behavior.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Oracle.query_limit_1(sql)?;
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    /// Executes `sql` (after filter/limit pushdown rewriting) and returns a
    /// stream of `RecordBatch`es holding at most `stream_chunk_size` rows each,
    /// projected down to the columns listed in `projection`.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        // The emitted stream's schema contains only the projected columns.
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Oracle.try_rewrite_query(sql, filters, limit)?;
        // Clone so the `move` closure below can own the projection indices.
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // NOTE(review): the driver's result-set iterator is synchronous, so
        // row fetching blocks the executor thread while this stream is
        // polled — confirm this is acceptable for the target runtime.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            // Each chunk is a Vec<Result<Row, _>>: fail fast on the first
            // driver error, then convert the chunk into one RecordBatch.
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
144
145fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
146    match oracle_type {
147        ColumnType::Number(precision, scale) => {
148            // TODO need more investigation on the precision and scale
149            let precision = if *precision == 0 { 38 } else { *precision };
150            let scale = if *scale == -127 { 0 } else { *scale };
151            Ok(OracleType::Number(precision, scale))
152        }
153        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
154        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
155        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
156        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
157        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
158        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
159        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
160        ColumnType::Long => Ok(OracleType::Long),
161        ColumnType::CLOB => Ok(OracleType::Clob),
162        ColumnType::NCLOB => Ok(OracleType::NClob),
163        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
164        ColumnType::LongRaw => Ok(OracleType::LongRaw),
165        ColumnType::BLOB => Ok(OracleType::Blob),
166        ColumnType::Date => Ok(OracleType::Date),
167        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
168        ColumnType::Boolean => Ok(OracleType::Boolean),
169        _ => Err(DataFusionError::NotImplemented(format!(
170            "Unsupported oracle type: {oracle_type:?}",
171        ))),
172    }
173}
174
175fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
176    let mut remote_fields = vec![];
177    for col in row.column_info() {
178        let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
179        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
180    }
181    Ok(RemoteSchema::new(remote_fields))
182}
183
/// Reads column `$index` of `$row` as `Option<$value_ty>`, applies `$convert`
/// to non-null values, and appends the result (or a null) to `$builder`,
/// which is downcast to the concrete `$builder_ty`.
///
/// Panics if the builder's concrete type does not match `$builder_ty`
/// (a schema/builder mismatch is a bug, not a runtime condition); returns an
/// error if the driver cannot coerce the column value to `$value_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
212
213fn rows_to_batch(
214    rows: &[Row],
215    table_schema: &SchemaRef,
216    projection: Option<&Vec<usize>>,
217) -> DFResult<RecordBatch> {
218    let projected_schema = project_schema(table_schema, projection)?;
219    let mut array_builders = vec![];
220    for field in table_schema.fields() {
221        let builder = make_builder(field.data_type(), rows.len());
222        array_builders.push(builder);
223    }
224
225    for row in rows {
226        for (idx, field) in table_schema.fields.iter().enumerate() {
227            if !projections_contains(projection, idx) {
228                continue;
229            }
230            let builder = &mut array_builders[idx];
231            let col = row.column_info().get(idx);
232            match field.data_type() {
233                DataType::Int16 => {
234                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
235                        Ok::<_, DataFusionError>(v)
236                    });
237                }
238                DataType::Int32 => {
239                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
240                        Ok::<_, DataFusionError>(v)
241                    });
242                }
243                DataType::Float32 => {
244                    handle_primitive_type!(
245                        builder,
246                        field,
247                        col,
248                        Float32Builder,
249                        f32,
250                        row,
251                        idx,
252                        |v| { Ok::<_, DataFusionError>(v) }
253                    );
254                }
255                DataType::Float64 => {
256                    handle_primitive_type!(
257                        builder,
258                        field,
259                        col,
260                        Float64Builder,
261                        f64,
262                        row,
263                        idx,
264                        |v| { Ok::<_, DataFusionError>(v) }
265                    );
266                }
267                DataType::Utf8 => {
268                    handle_primitive_type!(
269                        builder,
270                        field,
271                        col,
272                        StringBuilder,
273                        String,
274                        row,
275                        idx,
276                        |v| { Ok::<_, DataFusionError>(v) }
277                    );
278                }
279                DataType::LargeUtf8 => {
280                    handle_primitive_type!(
281                        builder,
282                        field,
283                        col,
284                        LargeStringBuilder,
285                        String,
286                        row,
287                        idx,
288                        |v| { Ok::<_, DataFusionError>(v) }
289                    );
290                }
291                DataType::Decimal128(_precision, scale) => {
292                    handle_primitive_type!(
293                        builder,
294                        field,
295                        col,
296                        Decimal128Builder,
297                        String,
298                        row,
299                        idx,
300                        |v: String| {
301                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
302                                DataFusionError::Execution(format!(
303                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
304                                ))
305                            })?;
306                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
307                                DataFusionError::Execution(format!(
308                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
309                                ))
310                            })
311                        }
312                    );
313                }
314                DataType::Timestamp(TimeUnit::Second, None) => {
315                    handle_primitive_type!(
316                        builder,
317                        field,
318                        col,
319                        TimestampSecondBuilder,
320                        chrono::NaiveDateTime,
321                        row,
322                        idx,
323                        |v: chrono::NaiveDateTime| {
324                            let t = v.and_utc().timestamp();
325                            Ok::<_, DataFusionError>(t)
326                        }
327                    );
328                }
329                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
330                    handle_primitive_type!(
331                        builder,
332                        field,
333                        col,
334                        TimestampNanosecondBuilder,
335                        chrono::NaiveDateTime,
336                        row,
337                        idx,
338                        |v: chrono::NaiveDateTime| {
339                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
340                                DataFusionError::Execution(format!(
341                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
342                                ))
343                            })
344                        }
345                    );
346                }
347                DataType::Date64 => {
348                    handle_primitive_type!(
349                        builder,
350                        field,
351                        col,
352                        Date64Builder,
353                        chrono::NaiveDateTime,
354                        row,
355                        idx,
356                        |v: chrono::NaiveDateTime| {
357                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
358                        }
359                    );
360                }
361                DataType::Boolean => {
362                    handle_primitive_type!(
363                        builder,
364                        field,
365                        col,
366                        BooleanBuilder,
367                        bool,
368                        row,
369                        idx,
370                        |v| { Ok::<_, DataFusionError>(v) }
371                    );
372                }
373                DataType::Binary => {
374                    handle_primitive_type!(
375                        builder,
376                        field,
377                        col,
378                        BinaryBuilder,
379                        Vec<u8>,
380                        row,
381                        idx,
382                        |v| { Ok::<_, DataFusionError>(v) }
383                    );
384                }
385                DataType::LargeBinary => {
386                    handle_primitive_type!(
387                        builder,
388                        field,
389                        col,
390                        LargeBinaryBuilder,
391                        Vec<u8>,
392                        row,
393                        idx,
394                        |v| { Ok::<_, DataFusionError>(v) }
395                    );
396                }
397                _ => {
398                    return Err(DataFusionError::NotImplemented(format!(
399                        "Unsupported data type {:?} for col: {:?}",
400                        field.data_type(),
401                        col
402                    )));
403                }
404            }
405        }
406    }
407
408    let projected_columns = array_builders
409        .into_iter()
410        .enumerate()
411        .filter(|(idx, _)| projections_contains(projection, *idx))
412        .map(|(_, mut builder)| builder.finish())
413        .collect::<Vec<ArrayRef>>();
414    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
415}