datafusion_remote_table/connection/oracle.rs

use crate::connection::{big_decimal_to_i128, projections_contains};
use crate::transform::transform_batch;
use crate::{
    Connection, DFResult, OracleType, Pool, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteType,
    Transform,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    make_builder, ArrayRef, Decimal128Builder, RecordBatch, StringBuilder,
    TimestampNanosecondBuilder, TimestampSecondBuilder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{project_schema, DataFusionError};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::StreamExt;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

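/// Options describing how to reach an Oracle database: host, port, credentials, and service name.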
#[derive(Debug, Clone)]
pub struct OracleConnectionOptions {
    pub host: String,
    pub port: u16,
    pub username: String,
    pub password: String,
    pub service_name: String,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
        }
    }
}

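/// Connection pool for Oracle, backed by `bb8` and `bb8_oracle::OracleConnectionManager`.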
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

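/// Builds a `//host:port/service_name` connect string, performs a one-off connection to
/// validate the options, and then creates the bb8 pool.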
pub async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
    Ok(OraclePool { pool })
}

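/// Hands out pooled connections wrapped as [`OracleConnection`].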
#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}

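/// A single connection checked out from [`OraclePool`].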
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

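/// Maps Oracle result sets to Arrow record batches for schema inference and querying.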
#[async_trait::async_trait]
impl Connection for OracleConnection {
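    /// Fetches a single row for `sql`, derives the remote (Oracle) schema from its column
    /// metadata, and returns it together with the corresponding Arrow schema (after applying
    /// `transform` to a sample batch, if one is provided).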
    async fn infer_schema(
        &self,
        sql: &str,
        transform: Option<Arc<dyn Transform>>,
    ) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        let row = self.conn.query_row(sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to query one row to infer schema: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        if let Some(transform) = transform {
            let batch = rows_to_batch(&[row], &arrow_schema, None)?;
            let transformed_batch = transform_batch(
                batch,
                transform.as_ref(),
                &arrow_schema,
                None,
                Some(&remote_schema),
            )?;
            Ok((remote_schema, transformed_batch.schema()))
        } else {
            Ok((remote_schema, arrow_schema))
        }
    }

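    /// Executes `sql` and streams the result set as Arrow record batches, materializing rows
    /// in chunks of 2000 and keeping only the projected columns.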
    async fn query(
        &self,
        sql: String,
        table_schema: SchemaRef,
        projection: Option<Vec<usize>>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection.as_ref())?;
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        let stream = futures::stream::iter(result_set).chunks(2000).boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}

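/// Maps a driver-reported Oracle column type to this crate's [`RemoteType`]; unsupported
/// types yield `DataFusionError::NotImplemented`.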
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
    match oracle_type {
        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
        ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
        ColumnType::Number(precision, scale) => {
            Ok(RemoteType::Oracle(OracleType::Number(*precision, *scale)))
        }
        ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
        ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}

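/// Builds a [`RemoteSchema`] from the column metadata of a fetched row.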
fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

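/// Downcasts `$builder` to `$builder_ty`, reads column `$index` from `$row` as
/// `Option<$value_ty>`, and appends the value (or a null); both the downcast and the read
/// panic on failure, reporting the offending column.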
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $field
                )
            });
        let v = $row
            .get::<usize, Option<$value_ty>>($index)
            .unwrap_or_else(|e| {
                panic!(
                    concat!(
                        "Failed to get ",
                        stringify!($value_ty),
                        " value for {:?}: {:?}"
                    ),
                    $field, e
                )
            });

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}

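/// Converts a chunk of Oracle rows into a single [`RecordBatch`]; only the columns selected by
/// `projection` are decoded and emitted.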
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Utf8 => {
                    handle_primitive_type!(builder, col, StringBuilder, String, row, idx);
                }
                DataType::Decimal128(_precision, scale) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<Decimal128Builder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to Decimal128Builder for {col:?}")
                        });

                    let v = row.get::<usize, Option<String>>(idx).unwrap_or_else(|e| {
                        panic!("Failed to get String value for {col:?}: {e:?}")
                    });

                    match v {
                        Some(v) => {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            let Some(v) = big_decimal_to_i128(&decimal, Some(*scale as u32)) else {
                                return Err(DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                )));
                            };
                            builder.append_value(v);
                        }
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampSecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to TimestampSecondBuilder for {col:?}"
                            )
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(idx)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            let t = v.and_utc().timestamp();
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampNanosecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to TimestampNanosecondBuilder for {col:?}")
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(idx)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            let t = v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })?;
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}