1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8 ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9 Float64Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
10 LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, TimestampNanosecondBuilder,
11 TimestampSecondBuilder, make_builder,
12};
13use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
14use datafusion::common::{DataFusionError, project_schema};
15use datafusion::execution::SendableRecordBatchStream;
16use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
17use derive_getters::Getters;
18use derive_with::With;
19use futures::StreamExt;
20use log::debug;
21use oracle::sql_type::OracleType as ColumnType;
22use oracle::{Connector, Row};
23use std::any::Any;
24use std::sync::Arc;
25
/// Options describing how to reach an Oracle database and how to pool and
/// stream results from it.
///
/// `With` (derive_with) generates `with_*` setters and `Getters`
/// (derive_getters) generates read accessors for the crate-private fields.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    // Database host name or address.
    pub(crate) host: String,
    // Listener TCP port.
    pub(crate) port: u16,
    // Credentials used for every pooled connection.
    pub(crate) username: String,
    pub(crate) password: String,
    // Oracle service name; combined into the `//host:port/service` connect string.
    pub(crate) service_name: String,
    // Maximum number of connections held by the bb8 pool (default 10 via `new`).
    pub(crate) pool_max_size: usize,
    // Number of rows decoded per emitted RecordBatch (default 2048 via `new`).
    pub(crate) stream_chunk_size: usize,
}
36
37impl OracleConnectionOptions {
38 pub fn new(
39 host: impl Into<String>,
40 port: u16,
41 username: impl Into<String>,
42 password: impl Into<String>,
43 service_name: impl Into<String>,
44 ) -> Self {
45 Self {
46 host: host.into(),
47 port,
48 username: username.into(),
49 password: password.into(),
50 service_name: service_name.into(),
51 pool_max_size: 10,
52 stream_chunk_size: 2048,
53 }
54 }
55}
56
57impl From<OracleConnectionOptions> for ConnectionOptions {
58 fn from(options: OracleConnectionOptions) -> Self {
59 ConnectionOptions::Oracle(options)
60 }
61}
62
/// Connection pool for Oracle, backed by `bb8` with the
/// `bb8_oracle` connection manager.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
67
68pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
69 let connect_string = format!(
70 "//{}:{}/{}",
71 options.host, options.port, options.service_name
72 );
73 let connector = Connector::new(
74 options.username.clone(),
75 options.password.clone(),
76 connect_string,
77 );
78 let _ = connector
79 .connect()
80 .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
81 let manager = OracleConnectionManager::from_connector(connector);
82 let pool = bb8::Pool::builder()
83 .max_size(options.pool_max_size as u32)
84 .build(manager)
85 .await
86 .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
87 Ok(OraclePool { pool })
88}
89
#[async_trait::async_trait]
impl Pool for OraclePool {
    /// Checks out an owned connection from the bb8 pool and wraps it as a
    /// generic [`Connection`].
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}
99
/// A single Oracle connection checked out from [`OraclePool`], implementing
/// the generic [`Connection`] trait.
#[derive(Debug)]
pub struct OracleConnection {
    // Owned bb8 checkout; the `'static` lifetime comes from `get_owned`.
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
104
#[async_trait::async_trait]
impl Connection for OracleConnection {
    /// Enables downcasting to the concrete connection type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Infers the remote schema by wrapping `sql` with a limit-1 rewrite and
    /// inspecting the column metadata of the returned row.
    ///
    /// NOTE(review): `query_row` presumably errors when the query returns no
    /// rows, which would make schema inference fail on empty results — confirm.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Oracle.query_limit_1(sql);
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    /// Executes `sql` (after pushing down `unparsed_filters` and `limit` via
    /// `rewrite_query`) and returns the results as a stream of record batches.
    ///
    /// Rows are decoded against the full `table_schema` so column indices line
    /// up; only the `projection` columns are emitted in each batch.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Oracle.rewrite_query(sql, unparsed_filters, limit);
        debug!("[remote-table] executing oracle query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        // NOTE(review): `conn.query` is a synchronous driver call made inside
        // an async fn; it may block the executor thread — confirm acceptable.
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // Pull rows in chunks of `stream_chunk_size`; each chunk becomes one batch.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            // Each element is a per-row Result; fail the whole batch on the
            // first row-level driver error.
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
159
160fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
161 match oracle_type {
162 ColumnType::Number(precision, scale) => {
163 let precision = if *precision == 0 { 38 } else { *precision };
165 let scale = if *scale == -127 { 0 } else { *scale };
166 Ok(OracleType::Number(precision, scale))
167 }
168 ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
169 ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
170 ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
171 ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
172 ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
173 ColumnType::Char(size) => Ok(OracleType::Char(*size)),
174 ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
175 ColumnType::Long => Ok(OracleType::Long),
176 ColumnType::CLOB => Ok(OracleType::Clob),
177 ColumnType::NCLOB => Ok(OracleType::NClob),
178 ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
179 ColumnType::LongRaw => Ok(OracleType::LongRaw),
180 ColumnType::BLOB => Ok(OracleType::Blob),
181 ColumnType::Date => Ok(OracleType::Date),
182 ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
183 ColumnType::Boolean => Ok(OracleType::Boolean),
184 _ => Err(DataFusionError::NotImplemented(format!(
185 "Unsupported oracle type: {oracle_type:?}",
186 ))),
187 }
188}
189
190fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
191 let mut remote_fields = vec![];
192 for col in row.column_info() {
193 let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
194 remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
195 }
196 Ok(RemoteSchema::new(remote_fields))
197}
198
/// Appends one value from `$row` at `$index` into a typed Arrow builder.
///
/// * `$builder` — the dynamic builder to downcast to `$builder_ty`
/// * `$value_ty` — Rust type fetched from the oracle row
/// * `$convert` — fallible conversion applied to non-NULL fetched values
///
/// NULL database values become Arrow nulls. A failed downcast panics (it
/// indicates a schema/builder mismatch, i.e. a programming error), while a
/// failed row fetch or conversion is surfaced as an `Execution` error.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
227
/// Decodes a chunk of oracle rows into a single `RecordBatch` containing only
/// the `projection` columns of `table_schema`.
///
/// Builders are allocated for every table column so that builder indices line
/// up with schema indices; columns outside the projection are never appended
/// to and are filtered out before the batch is assembled.
///
/// # Errors
/// Returns `Execution` errors for fetch/conversion failures and
/// `NotImplemented` for Arrow data types without a decode arm below.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // One builder per table column, positionally aligned with `table_schema`.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            // Skip columns the caller did not project.
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used for error/panic messages below.
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                // Numbers are fetched as strings and parsed through BigDecimal
                // so values are rescaled to the field's declared Arrow scale.
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                // NOTE(review): naive datetimes are interpreted as UTC when
                // converted to epoch values — confirm against session timezone.
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            // Nanosecond range overflows for far-out dates.
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Keep only projected builders, in schema order, and finish them.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    // Explicit row count is required: an empty projection yields zero columns,
    // from which the batch cannot infer its row count.
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}