use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteType,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
    Float64Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, TimestampNanosecondBuilder,
    TimestampSecondBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use log::debug;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

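/// Options for connecting to an Oracle database.
///
/// `pool_max_size` and `stream_chunk_size` default to 10 and 2048 via [`new`](Self::new).
/// A minimal usage sketch (host, credentials, and service name are illustrative,
/// and the `with_*` setters are assumed to be generated by the `With` derive):
///
/// ```ignore
/// let options = OracleConnectionOptions::new("localhost", 1521, "scott", "tiger", "XEPDB1")
///     .with_pool_max_size(4)
///     .with_stream_chunk_size(1024);
/// ```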
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
}

impl OracleConnectionOptions {
    /// Creates options with a default `pool_max_size` of 10 and a default
    /// `stream_chunk_size` of 2048.
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            pool_max_size: 10,
            stream_chunk_size: 2048,
        }
    }
}

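/// A bb8-backed pool of Oracle connections.
///
/// A sketch of the intended flow (connection details are illustrative,
/// error handling elided):
///
/// ```ignore
/// let options = OracleConnectionOptions::new("localhost", 1521, "scott", "tiger", "XEPDB1");
/// let pool = connect_oracle(&options).await?;
/// let conn = pool.get().await?;
/// let schema = conn.infer_schema("SELECT * FROM employees").await?;
/// ```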
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

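// Builds the pool from an EZConnect-style connect string of the form
// "//host:port/service_name", e.g. "//localhost:1521/XEPDB1".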
pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    // Eagerly open (and drop) one connection so that bad credentials or an
    // unreachable host fail fast instead of surfacing on first pool checkout.
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

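// `get_owned` yields a `PooledConnection<'static, _>` whose guard keeps a
// handle to the pool itself, so the checked-out connection can be stored in
// `OracleConnection` without borrowing from `OraclePool`.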
#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}

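/// A single Oracle connection checked out of the pool.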
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

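// `infer_schema` executes the query wrapped in a one-row limit purely to read
// column metadata; `query` rewrites the SQL with pushed-down filters and limit,
// then streams rows in `stream_chunk_size`-sized chunks, converting each chunk
// into a `RecordBatch`.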
#[async_trait::async_trait]
impl Connection for OracleConnection {
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        // Fetch a single row only to read the column metadata it carries.
        let sql = RemoteDbType::Oracle.query_limit_1(sql);
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Oracle.rewrite_query(sql, unparsed_filters, limit);
        debug!("[remote-table] executing oracle query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // Group the row iterator into fixed-size chunks; each chunk becomes one RecordBatch.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}

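// Maps the driver-reported column type to this crate's `OracleType`. Note the
// NUMBER special case: an unconstrained NUMBER is reported with precision 0
// and scale -127, which is normalized here to the maximum precision (38) and
// a scale of 0.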
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            let precision = if *precision == 0 { 38 } else { *precision };
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(OracleType::Number(precision, scale))
        }
        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
        ColumnType::Long => Ok(OracleType::Long),
        ColumnType::CLOB => Ok(OracleType::Clob),
        ColumnType::NCLOB => Ok(OracleType::NClob),
        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
        ColumnType::LongRaw => Ok(OracleType::LongRaw),
        ColumnType::BLOB => Ok(OracleType::Blob),
        ColumnType::Date => Ok(OracleType::Date),
        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
        ColumnType::Boolean => Ok(OracleType::Boolean),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}

fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

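// Downcasts the dynamically typed builder to `$builder_ty`, reads column
// `$index` of `$row` as `Option<$value_ty>`, and appends either the converted
// value (via `$convert`) or a null. Downcast failures panic because the
// builder type was chosen from the same schema that selected the match arm.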
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}

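// Converts a chunk of rows into a single `RecordBatch`. A builder is created
// for every column of the table schema, but only projected columns are
// populated and emitted; the explicit row count keeps the batch correct even
// when the projection is empty.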
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}