datafusion_remote_table/connection/
oracle.rs

1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
5    RemoteSchemaRef, RemoteType, Transform,
6};
7use bb8_oracle::OracleConnectionManager;
8use datafusion::arrow::array::{
9    make_builder, ArrayRef, Decimal128Builder, RecordBatch, StringBuilder,
10    TimestampNanosecondBuilder, TimestampSecondBuilder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{project_schema, DataFusionError};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use futures::StreamExt;
17use oracle::sql_type::OracleType as ColumnType;
18use oracle::{Connector, Row};
19use std::sync::Arc;
20
/// Connection settings for an Oracle database reached as `//host:port/service_name`.
#[derive(Debug, Clone)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    // Preferred number of rows per emitted RecordBatch; `None` defers to the
    // caller-level default (presumably `ConnectionOptions::chunk_size()` — confirm).
    pub(crate) chunk_size: Option<usize>,
}

impl OracleConnectionOptions {
    /// Creates options for `//host:port/service_name` with no explicit chunk size.
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            chunk_size: None,
        }
    }

    /// Builder-style setter for the per-batch row count.
    ///
    /// Previously `chunk_size` could not be set through the public API at all
    /// (the field is `pub(crate)` and `new` always leaves it `None`).
    pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = Some(chunk_size);
        self
    }
}
49
/// A bb8 connection pool over Oracle connections, wrapped so the crate's
/// generic `Pool` trait can be implemented for it.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
54
55pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
56    let connect_string = format!(
57        "//{}:{}/{}",
58        options.host, options.port, options.service_name
59    );
60    let connector = Connector::new(
61        options.username.clone(),
62        options.password.clone(),
63        connect_string,
64    );
65    let _ = connector
66        .connect()
67        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
68    let manager = OracleConnectionManager::from_connector(connector);
69    let pool = bb8::Pool::builder()
70        .build(manager)
71        .await
72        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
73    Ok(OraclePool { pool })
74}
75
76#[async_trait::async_trait]
77impl Pool for OraclePool {
78    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
79        let conn = self.pool.get_owned().await.map_err(|e| {
80            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
81        })?;
82        Ok(Arc::new(OracleConnection { conn }))
83    }
84}
85
/// A single pooled Oracle connection implementing the generic `Connection` trait.
#[derive(Debug)]
pub struct OracleConnection {
    // Owned ('static) pooled handle; returned to the pool when dropped.
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
90
#[async_trait::async_trait]
impl Connection for OracleConnection {
    /// Infers the remote (Oracle) schema and the corresponding Arrow schema for `sql`.
    ///
    /// Fetches one row to obtain column metadata. When a `transform` is provided,
    /// a single-row batch is built and run through the transform so that the
    /// *post-transform* Arrow schema is returned; otherwise the schema derived
    /// directly from the Oracle column types is used.
    ///
    /// NOTE(review): `query_row` appears to require at least one result row, so
    /// inference would fail for queries over empty tables — confirm.
    async fn infer_schema(
        &self,
        sql: &str,
        transform: Option<Arc<dyn Transform>>,
    ) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        let row = self.conn.query_row(sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to query one row to infer schema: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        if let Some(transform) = transform {
            // Convert the sample row and transform it to learn the output schema.
            let batch = rows_to_batch(&[row], &arrow_schema, None)?;
            let transformed_batch = transform_batch(
                batch,
                transform.as_ref(),
                &arrow_schema,
                None,
                Some(&remote_schema),
            )?;
            Ok((remote_schema, transformed_batch.schema()))
        } else {
            Ok((remote_schema, arrow_schema))
        }
    }

    /// Executes `sql` and streams the result as Arrow `RecordBatch`es.
    ///
    /// Rows are pulled from the Oracle result set in chunks of
    /// `conn_options.chunk_size()` (falling back to 2048) and converted with
    /// `rows_to_batch`; `projection` selects which columns of `table_schema`
    /// are materialized.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        // Own the projection so it can move into the stream closure below.
        let projection = projection.cloned();
        let chunk_size = conn_options.chunk_size();
        let result_set = self.conn.query(sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // NOTE(review): iterating the Oracle result set is synchronous, so each
        // chunk fetch blocks the async task driving this stream — confirm this
        // is acceptable for the target runtime.
        let stream = futures::stream::iter(result_set)
            .chunks(chunk_size.unwrap_or(2048))
            .boxed();

        let stream = stream.map(move |rows| {
            // Each element is a Result<Row, _>; fail the whole batch on the
            // first row-level error.
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
153
154fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
155    match oracle_type {
156        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
157        ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
158        ColumnType::Number(precision, scale) => {
159            Ok(RemoteType::Oracle(OracleType::Number(*precision, *scale)))
160        }
161        ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
162        ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
163        _ => Err(DataFusionError::NotImplemented(format!(
164            "Unsupported oracle type: {oracle_type:?}",
165        ))),
166    }
167}
168
169fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
170    let mut remote_fields = vec![];
171    for col in row.column_info() {
172        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
173        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
174    }
175    Ok(RemoteSchema::new(remote_fields))
176}
177
/// Appends one column value from `$row` at `$index` to `$builder`.
///
/// Downcasts the dynamic array builder to `$builder_ty`, reads the value as
/// `Option<$value_ty>`, and appends the value or a null. Both the downcast and
/// the fetch panic on failure: a mismatch there indicates a schema-mapping bug
/// in this module rather than bad data, so it is treated as unrecoverable.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $field
                )
            });
        let v = $row
            .get::<usize, Option<$value_ty>>($index)
            .unwrap_or_else(|e| {
                panic!(
                    concat!(
                        "Failed to get ",
                        stringify!($value_ty),
                        " value for {:?}: {:?}"
                    ),
                    $field, e
                )
            });

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
212
/// Converts a slice of Oracle rows into one Arrow `RecordBatch` shaped by
/// `projection` over `table_schema`.
///
/// Builders are allocated for *every* column of `table_schema`, but values are
/// appended only for projected columns; the non-projected builders are simply
/// dropped (empty) when the final column list is assembled at the end.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        // `rows.len()` is only a capacity hint; builders grow as needed.
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used to enrich panic/error messages.
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Utf8 => {
                    handle_primitive_type!(builder, col, StringBuilder, String, row, idx);
                }
                DataType::Decimal128(_precision, scale) => {
                    // Oracle NUMBER values are fetched as strings and parsed via
                    // BigDecimal to avoid precision loss, then rescaled to the
                    // target Arrow scale as an i128.
                    // NOTE(review): `_precision` is not checked here; values wider
                    // than the declared precision are appended unvalidated —
                    // confirm downstream validation catches them.
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<Decimal128Builder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to Decimal128Builder for {col:?}")
                        });

                    let v = row.get::<usize, Option<String>>(idx).unwrap_or_else(|e| {
                        panic!("Failed to get String value for {col:?}: {e:?}")
                    });

                    match v {
                        Some(v) => {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            let Some(v) = big_decimal_to_i128(&decimal, Some(*scale as u32)) else {
                                return Err(DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                )));
                            };
                            builder.append_value(v);
                        }
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    // Naive datetimes from the driver are interpreted as UTC
                    // (`and_utc()`) before taking the epoch-second value.
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampSecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to TimestampSecondBuilder for {col:?}"
                            )
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(idx)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            let t = v.and_utc().timestamp();
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    // Same as the second-resolution arm, but nanosecond conversion
                    // can overflow i64 (~year 2262), hence the fallible
                    // `timestamp_nanos_opt` and explicit error.
                    let builder = builder
                                .as_any_mut()
                                .downcast_mut::<TimestampNanosecondBuilder>()
                                .unwrap_or_else(|| {
                                    panic!("Failed to downcast builder to TimestampNanosecondBuilder for {col:?}")
                                });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(idx)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            let t = v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                        DataFusionError::Execution(format!(
                                        "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                    ))
                                    })?;
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Finish only the projected builders, in schema order, to match
    // `projected_schema`.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}