// datafusion_remote_table/connection/oracle.rs

1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4    project_remote_schema, Connection, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
5    RemoteType, Transform,
6};
7use bb8_oracle::OracleConnectionManager;
8use datafusion::arrow::array::{
9    make_builder, ArrayRef, Decimal128Builder, RecordBatch, StringBuilder,
10    TimestampNanosecondBuilder, TimestampSecondBuilder,
11};
12use datafusion::arrow::datatypes::SchemaRef;
13use datafusion::common::{project_schema, DataFusionError};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use futures::StreamExt;
17use oracle::sql_type::OracleType as ColumnType;
18use oracle::{Connector, Row};
19use std::sync::Arc;
20
/// Connection settings for an Oracle database.
///
/// Consumed by `connect_oracle`, which assembles them into an EZConnect
/// string of the form `//host:port/service_name`.
#[derive(Debug, Clone)]
pub struct OracleConnectionOptions {
    // Hostname or IP address of the Oracle listener.
    pub host: String,
    // Listener port.
    pub port: u16,
    pub username: String,
    pub password: String,
    // Oracle service name; becomes the path segment of the connect string.
    pub service_name: String,
}
29
30impl OracleConnectionOptions {
31    pub fn new(
32        host: impl Into<String>,
33        port: u16,
34        username: impl Into<String>,
35        password: impl Into<String>,
36        service_name: impl Into<String>,
37    ) -> Self {
38        Self {
39            host: host.into(),
40            port,
41            username: username.into(),
42            password: password.into(),
43            service_name: service_name.into(),
44        }
45    }
46}
47
/// A [`Pool`] implementation backed by a bb8 connection pool of Oracle
/// connections; created via `connect_oracle`.
#[derive(Debug)]
pub struct OraclePool {
    // Async pool; connections are checked out through the `Pool` trait impl.
    pool: bb8::Pool<OracleConnectionManager>,
}
52
53pub async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
54    let connect_string = format!(
55        "//{}:{}/{}",
56        options.host, options.port, options.service_name
57    );
58    let connector = Connector::new(
59        options.username.clone(),
60        options.password.clone(),
61        connect_string,
62    );
63    let _ = connector
64        .connect()
65        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
66    let manager = OracleConnectionManager::from_connector(connector);
67    let pool = bb8::Pool::builder()
68        .build(manager)
69        .await
70        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
71    Ok(OraclePool { pool })
72}
73
74#[async_trait::async_trait]
75impl Pool for OraclePool {
76    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
77        let conn = self.pool.get_owned().await.map_err(|e| {
78            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
79        })?;
80        Ok(Arc::new(OracleConnection { conn }))
81    }
82}
83
/// A single Oracle connection checked out from an [`OraclePool`];
/// implements the [`Connection`] trait below.
#[derive(Debug)]
pub struct OracleConnection {
    // Owned (`'static`) checkout from the bb8 pool.
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
88
89#[async_trait::async_trait]
90impl Connection for OracleConnection {
91    async fn infer_schema(
92        &self,
93        sql: &str,
94        transform: Option<Arc<dyn Transform>>,
95    ) -> DFResult<(RemoteSchema, SchemaRef)> {
96        let row = self.conn.query_row(sql, &[]).map_err(|e| {
97            DataFusionError::Execution(format!("Failed to query one row to infer schema: {e:?}"))
98        })?;
99        let remote_schema = build_remote_schema(&row)?;
100        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
101        if let Some(transform) = transform {
102            let batch = rows_to_batch(&[row], arrow_schema, None)?;
103            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
104            Ok((remote_schema, transformed_batch.schema()))
105        } else {
106            Ok((remote_schema, arrow_schema))
107        }
108    }
109
110    async fn query(
111        &self,
112        sql: String,
113        projection: Option<Vec<usize>>,
114    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
115        let result_set = self.conn.query(&sql, &[]).unwrap();
116        let mut stream = futures::stream::iter(result_set).chunks(2000).boxed();
117
118        let Some(first_chunk) = stream.next().await else {
119            return Err(DataFusionError::Execution(
120                "No data returned from oracle".to_string(),
121            ));
122        };
123        let first_chunk: Vec<Row> = first_chunk
124            .into_iter()
125            .collect::<Result<Vec<_>, _>>()
126            .map_err(|e| {
127                DataFusionError::Execution(
128                    format!("Failed to collect rows from oracle due to {e}",),
129                )
130            })?;
131        let Some(first_row) = first_chunk.first() else {
132            return Err(DataFusionError::Execution(
133                "No data returned from oracle".to_string(),
134            ));
135        };
136
137        let remote_schema = build_remote_schema(first_row)?;
138        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
139        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
140        let first_chunk = rows_to_batch(
141            first_chunk.as_slice(),
142            arrow_schema.clone(),
143            projection.as_ref(),
144        )?;
145        let schema = first_chunk.schema();
146
147        let mut stream = stream.map(move |rows| {
148            let rows: Vec<Row> = rows
149                .into_iter()
150                .collect::<Result<Vec<_>, _>>()
151                .map_err(|e| {
152                    DataFusionError::Execution(format!(
153                        "Failed to collect rows from oracle due to {e}",
154                    ))
155                })?;
156            let batch = rows_to_batch(rows.as_slice(), arrow_schema.clone(), projection.as_ref())?;
157            Ok::<RecordBatch, DataFusionError>(batch)
158        });
159
160        let output_stream = async_stream::stream! {
161           yield Ok(first_chunk);
162           while let Some(batch) = stream.next().await {
163                yield batch
164           }
165        };
166
167        Ok((
168            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
169            projected_remote_schema,
170        ))
171    }
172}
173
174fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
175    match oracle_type {
176        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
177        ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
178        ColumnType::Number(precision, scale) => {
179            Ok(RemoteType::Oracle(OracleType::Number(*precision, *scale)))
180        }
181        ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
182        ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
183        _ => Err(DataFusionError::NotImplemented(format!(
184            "Unsupported oracle type: {oracle_type:?}",
185        ))),
186    }
187}
188
189fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
190    let mut remote_fields = vec![];
191    for col in row.column_info() {
192        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
193        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
194    }
195    Ok(RemoteSchema::new(remote_fields))
196}
197
/// Appends one column value from an oracle `Row` onto an Arrow array builder.
///
/// `$builder` is downcast to `$builder_ty`, the value at `$index` is read as
/// `Option<$value_ty>`, and `Some`/`None` become `append_value`/`append_null`.
/// Both the downcast and the fetch panic on failure: a mismatch means the
/// builders no longer line up with the schema derived from these same rows,
/// which is a programming error rather than a recoverable runtime condition.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $field
                )
            });
        let v = $row
            .get::<usize, Option<$value_ty>>($index)
            .unwrap_or_else(|e| {
                panic!(
                    concat!(
                        "Failed to get ",
                        stringify!($value_ty),
                        " value for {:?}: {:?}"
                    ),
                    $field, e
                )
            });

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
232
/// Converts a slice of oracle rows into one Arrow [`RecordBatch`], keeping
/// only the columns selected by `projection` (all columns when `None`).
///
/// A builder is created for every column of the *full* schema and values are
/// appended for all of them; the projection filter is applied only when the
/// builders are finished.
///
/// # Panics
/// Panics if a builder downcast or a value fetch fails — both indicate the
/// builders no longer match the schema derived from these same rows.
///
/// # Errors
/// Returns an execution error for NUMBER values that cannot be parsed or
/// scaled to `i128`, for timestamps outside the nanosecond range, and a
/// `NotImplemented` error for unsupported column types.
fn rows_to_batch(
    rows: &[Row],
    arrow_schema: SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(&arrow_schema, projection)?;
    let mut array_builders = vec![];
    for field in arrow_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (i, col) in row.column_info().iter().enumerate() {
            let builder = &mut array_builders[i];
            match col.oracle_type() {
                ColumnType::Varchar2(_size) | ColumnType::Char(_size) => {
                    handle_primitive_type!(builder, col, StringBuilder, String, row, i);
                }
                ColumnType::Number(_precision, scale) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<Decimal128Builder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to Decimal128Builder for {col:?}")
                        });

                    // NUMBER is fetched as a string and routed through
                    // BigDecimal so no precision is lost before scaling.
                    let v = row.get::<usize, Option<String>>(i).unwrap_or_else(|e| {
                        panic!("Failed to get String value for {col:?}: {e:?}")
                    });

                    match v {
                        Some(v) => {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            // Rescale to the column's declared scale; None
                            // signals overflow past i128.
                            let Some(v) = big_decimal_to_i128(&decimal, Some(*scale as u32)) else {
                                return Err(DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                )));
                            };
                            builder.append_value(v);
                        }
                        None => builder.append_null(),
                    }
                }
                ColumnType::Date => {
                    // Oracle DATE has second precision, hence the
                    // TimestampSecondBuilder.
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampSecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to TimestampSecondBuilder for {col:?}"
                            )
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(i)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            let t = v.and_utc().timestamp();
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                ColumnType::Timestamp(_) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampNanosecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to TimestampNanosecondBuilder for {col:?}")
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(i)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            // timestamp_nanos_opt is None for dates outside
                            // roughly 1677..2262; surface that as an error.
                            let t = v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                            ))
                            })?;
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported oracle type: {col:?}",
                    )))
                }
            }
        }
    }

    // Finish only the builders the projection keeps, in schema order.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}