datafusion_remote_table/connection/
oracle.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9    Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
10    RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{DataFusionError, project_schema};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion::prelude::Expr;
17use derive_getters::Getters;
18use derive_with::With;
19use futures::StreamExt;
20use oracle::sql_type::OracleType as ColumnType;
21use oracle::{Connector, Row};
22use std::sync::Arc;
23
/// Settings used to connect to an Oracle database and to size its connection pool.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    /// Host placed in the `//host:port/service_name` connect string.
    pub(crate) host: String,
    /// Port placed in the connect string.
    pub(crate) port: u16,
    /// User name handed to the Oracle connector.
    pub(crate) username: String,
    /// Password handed to the Oracle connector.
    pub(crate) password: String,
    /// Service name placed in the connect string.
    pub(crate) service_name: String,
    /// Maximum size passed to the bb8 pool builder; `new` sets it to 10.
    pub(crate) pool_max_size: usize,
    /// `new` sets it to 2048.
    /// NOTE(review): presumably the number of rows per streamed batch — confirm against the caller.
    pub(crate) stream_chunk_size: usize,
}
34
35impl OracleConnectionOptions {
36    pub fn new(
37        host: impl Into<String>,
38        port: u16,
39        username: impl Into<String>,
40        password: impl Into<String>,
41        service_name: impl Into<String>,
42    ) -> Self {
43        Self {
44            host: host.into(),
45            port,
46            username: username.into(),
47            password: password.into(),
48            service_name: service_name.into(),
49            pool_max_size: 10,
50            stream_chunk_size: 2048,
51        }
52    }
53}
54
/// Pool of Oracle connections backed by `bb8`.
#[derive(Debug)]
pub struct OraclePool {
    /// The underlying bb8 pool from which connections are checked out.
    pool: bb8::Pool<OracleConnectionManager>,
}
59
60pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
61    let connect_string = format!(
62        "//{}:{}/{}",
63        options.host, options.port, options.service_name
64    );
65    let connector = Connector::new(
66        options.username.clone(),
67        options.password.clone(),
68        connect_string,
69    );
70    let _ = connector
71        .connect()
72        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
73    let manager = OracleConnectionManager::from_connector(connector);
74    let pool = bb8::Pool::builder()
75        .max_size(options.pool_max_size as u32)
76        .build(manager)
77        .await
78        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
79    Ok(OraclePool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for OraclePool {
84    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85        let conn = self.pool.get_owned().await.map_err(|e| {
86            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
87        })?;
88        Ok(Arc::new(OracleConnection { conn }))
89    }
90}
91
/// A single Oracle connection checked out of an `OraclePool`.
#[derive(Debug)]
pub struct OracleConnection {
    /// Owned pooled connection obtained via `get_owned` on the bb8 pool.
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
96
97#[async_trait::async_trait]
98impl Connection for OracleConnection {
99    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
100        let sql = RemoteDbType::Oracle
101            .try_rewrite_query(sql, &[], Some(1))
102            .unwrap_or_else(|| sql.to_string());
103        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
104            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
105        })?;
106        let remote_schema = Arc::new(build_remote_schema(&row)?);
107        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
108        Ok((remote_schema, arrow_schema))
109    }
110
111    async fn query(
112        &self,
113        conn_options: &ConnectionOptions,
114        sql: &str,
115        table_schema: SchemaRef,
116        projection: Option<&Vec<usize>>,
117        filters: &[Expr],
118        limit: Option<usize>,
119    ) -> DFResult<SendableRecordBatchStream> {
120        let projected_schema = project_schema(&table_schema, projection)?;
121        let sql = RemoteDbType::Oracle
122            .try_rewrite_query(sql, filters, limit)
123            .unwrap_or_else(|| sql.to_string());
124        let projection = projection.cloned();
125        let chunk_size = conn_options.stream_chunk_size();
126        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
127            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
128        })?;
129        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();
130
131        let stream = stream.map(move |rows| {
132            let rows: Vec<Row> = rows
133                .into_iter()
134                .collect::<Result<Vec<_>, _>>()
135                .map_err(|e| {
136                    DataFusionError::Execution(format!(
137                        "Failed to collect rows from oracle due to {e}",
138                    ))
139                })?;
140            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
141        });
142
143        Ok(Box::pin(RecordBatchStreamAdapter::new(
144            projected_schema,
145            stream,
146        )))
147    }
148}
149
150fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
151    match oracle_type {
152        ColumnType::Number(precision, scale) => {
153            // TODO need more investigation on the precision and scale
154            let precision = if *precision == 0 { 38 } else { *precision };
155            let scale = if *scale == -127 { 0 } else { *scale };
156            Ok(RemoteType::Oracle(OracleType::Number(precision, scale)))
157        }
158        ColumnType::BinaryFloat => Ok(RemoteType::Oracle(OracleType::BinaryFloat)),
159        ColumnType::BinaryDouble => Ok(RemoteType::Oracle(OracleType::BinaryDouble)),
160        ColumnType::Float(precision) => Ok(RemoteType::Oracle(OracleType::Float(*precision))),
161        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
162        ColumnType::NVarchar2(size) => Ok(RemoteType::Oracle(OracleType::NVarchar2(*size))),
163        ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
164        ColumnType::NChar(size) => Ok(RemoteType::Oracle(OracleType::NChar(*size))),
165        ColumnType::Long => Ok(RemoteType::Oracle(OracleType::Long)),
166        ColumnType::CLOB => Ok(RemoteType::Oracle(OracleType::Clob)),
167        ColumnType::NCLOB => Ok(RemoteType::Oracle(OracleType::NClob)),
168        ColumnType::Raw(size) => Ok(RemoteType::Oracle(OracleType::Raw(*size))),
169        ColumnType::LongRaw => Ok(RemoteType::Oracle(OracleType::LongRaw)),
170        ColumnType::BLOB => Ok(RemoteType::Oracle(OracleType::Blob)),
171        ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
172        ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
173        ColumnType::Boolean => Ok(RemoteType::Oracle(OracleType::Boolean)),
174        _ => Err(DataFusionError::NotImplemented(format!(
175            "Unsupported oracle type: {oracle_type:?}",
176        ))),
177    }
178}
179
180fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
181    let mut remote_fields = vec![];
182    for col in row.column_info() {
183        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
184        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
185    }
186    Ok(RemoteSchema::new(remote_fields))
187}
188
// Appends the value at `$index` of `$row` to `$builder`.
//
// `$builder` is downcast to `$builder_ty` (panicking if the downcast fails),
// the cell is read as `Option<$value_ty>`, and a present value is passed
// through `$convert` before being appended; an absent value appends a null.
// Read and conversion failures are propagated with `?`, so the macro must be
// expanded inside a function returning a compatible `Result`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let typed_builder = match $builder.as_any_mut().downcast_mut::<$builder_ty>() {
            Some(typed_builder) => typed_builder,
            None => panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            ),
        };
        let maybe_value = match $row.get::<usize, Option<$value_ty>>($index) {
            Ok(maybe_value) => maybe_value,
            Err(e) => Err(DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            )))?,
        };

        if let Some(raw) = maybe_value {
            typed_builder.append_value($convert(raw)?);
        } else {
            typed_builder.append_null();
        }
    }};
}
217
218fn rows_to_batch(
219    rows: &[Row],
220    table_schema: &SchemaRef,
221    projection: Option<&Vec<usize>>,
222) -> DFResult<RecordBatch> {
223    let projected_schema = project_schema(table_schema, projection)?;
224    let mut array_builders = vec![];
225    for field in table_schema.fields() {
226        let builder = make_builder(field.data_type(), rows.len());
227        array_builders.push(builder);
228    }
229
230    for row in rows {
231        for (idx, field) in table_schema.fields.iter().enumerate() {
232            if !projections_contains(projection, idx) {
233                continue;
234            }
235            let builder = &mut array_builders[idx];
236            let col = row.column_info().get(idx);
237            match field.data_type() {
238                DataType::Int16 => {
239                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
240                        Ok::<_, DataFusionError>(v)
241                    });
242                }
243                DataType::Int32 => {
244                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
245                        Ok::<_, DataFusionError>(v)
246                    });
247                }
248                DataType::Float32 => {
249                    handle_primitive_type!(
250                        builder,
251                        field,
252                        col,
253                        Float32Builder,
254                        f32,
255                        row,
256                        idx,
257                        |v| { Ok::<_, DataFusionError>(v) }
258                    );
259                }
260                DataType::Float64 => {
261                    handle_primitive_type!(
262                        builder,
263                        field,
264                        col,
265                        Float64Builder,
266                        f64,
267                        row,
268                        idx,
269                        |v| { Ok::<_, DataFusionError>(v) }
270                    );
271                }
272                DataType::Utf8 => {
273                    handle_primitive_type!(
274                        builder,
275                        field,
276                        col,
277                        StringBuilder,
278                        String,
279                        row,
280                        idx,
281                        |v| { Ok::<_, DataFusionError>(v) }
282                    );
283                }
284                DataType::LargeUtf8 => {
285                    handle_primitive_type!(
286                        builder,
287                        field,
288                        col,
289                        LargeStringBuilder,
290                        String,
291                        row,
292                        idx,
293                        |v| { Ok::<_, DataFusionError>(v) }
294                    );
295                }
296                DataType::Decimal128(_precision, scale) => {
297                    handle_primitive_type!(
298                        builder,
299                        field,
300                        col,
301                        Decimal128Builder,
302                        String,
303                        row,
304                        idx,
305                        |v: String| {
306                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
307                                DataFusionError::Execution(format!(
308                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
309                                ))
310                            })?;
311                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
312                                DataFusionError::Execution(format!(
313                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
314                                ))
315                            })
316                        }
317                    );
318                }
319                DataType::Timestamp(TimeUnit::Second, None) => {
320                    handle_primitive_type!(
321                        builder,
322                        field,
323                        col,
324                        TimestampSecondBuilder,
325                        chrono::NaiveDateTime,
326                        row,
327                        idx,
328                        |v: chrono::NaiveDateTime| {
329                            let t = v.and_utc().timestamp();
330                            Ok::<_, DataFusionError>(t)
331                        }
332                    );
333                }
334                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
335                    handle_primitive_type!(
336                        builder,
337                        field,
338                        col,
339                        TimestampNanosecondBuilder,
340                        chrono::NaiveDateTime,
341                        row,
342                        idx,
343                        |v: chrono::NaiveDateTime| {
344                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
345                                DataFusionError::Execution(format!(
346                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
347                                ))
348                            })
349                        }
350                    );
351                }
352                DataType::Date64 => {
353                    handle_primitive_type!(
354                        builder,
355                        field,
356                        col,
357                        Date64Builder,
358                        chrono::NaiveDateTime,
359                        row,
360                        idx,
361                        |v: chrono::NaiveDateTime| {
362                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
363                        }
364                    );
365                }
366                DataType::Boolean => {
367                    handle_primitive_type!(
368                        builder,
369                        field,
370                        col,
371                        BooleanBuilder,
372                        bool,
373                        row,
374                        idx,
375                        |v| { Ok::<_, DataFusionError>(v) }
376                    );
377                }
378                DataType::Binary => {
379                    handle_primitive_type!(
380                        builder,
381                        field,
382                        col,
383                        BinaryBuilder,
384                        Vec<u8>,
385                        row,
386                        idx,
387                        |v| { Ok::<_, DataFusionError>(v) }
388                    );
389                }
390                DataType::LargeBinary => {
391                    handle_primitive_type!(
392                        builder,
393                        field,
394                        col,
395                        LargeBinaryBuilder,
396                        Vec<u8>,
397                        row,
398                        idx,
399                        |v| { Ok::<_, DataFusionError>(v) }
400                    );
401                }
402                _ => {
403                    return Err(DataFusionError::NotImplemented(format!(
404                        "Unsupported data type {:?} for col: {:?}",
405                        field.data_type(),
406                        col
407                    )));
408                }
409            }
410        }
411    }
412
413    let projected_columns = array_builders
414        .into_iter()
415        .enumerate()
416        .filter(|(idx, _)| projections_contains(projection, *idx))
417        .map(|(_, mut builder)| builder.finish())
418        .collect::<Vec<ArrayRef>>();
419    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
420}