use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteType,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
    Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
    RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::prelude::Expr;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

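/// Connection options for an Oracle database.
///
/// `new` defaults `pool_max_size` to 10 and `stream_chunk_size` to 2048; the
/// `With` derive is expected to generate per-field `with_*` setters for
/// overriding them. A minimal usage sketch (host, credentials, and service
/// name are placeholders):
///
/// ```ignore
/// let options = OracleConnectionOptions::new("localhost", 1521, "scott", "tiger", "XEPDB1")
///     .with_stream_chunk_size(1024);
/// ```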
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            pool_max_size: 10,
            stream_chunk_size: 2048,
        }
    }
}

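/// A bb8-backed pool of Oracle connections.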
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

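/// Builds an [`OraclePool`] from the given options.
///
/// The connect string uses the EZConnect form `//host:port/service_name`, and
/// a probe connection is opened (and immediately dropped) before the pool is
/// built, so bad credentials or an unreachable host fail here rather than on
/// first use. A hedged usage sketch with placeholder credentials:
///
/// ```ignore
/// let options = OracleConnectionOptions::new("db.example.com", 1521, "app", "secret", "ORCLPDB1");
/// let pool = connect_oracle(&options).await?;
/// let conn = pool.get().await?;
/// ```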
pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    // Probe once to validate the connector; the connection is dropped immediately.
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to Oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create Oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

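// Checking a connection out with `get_owned` yields an owned (`'static`) guard,
// which lets `OracleConnection` hold it without borrowing the pool.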
#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get Oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}

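/// A single pooled Oracle connection implementing the [`Connection`] trait.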
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        // Fetch a single row so the driver reports column metadata without
        // scanning the full result set.
        let sql = RemoteDbType::Oracle.query_limit_1(sql)?;
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on Oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        // Push supported filters and the limit down into the rewritten SQL.
        let sql = RemoteDbType::Oracle.try_rewrite_query(sql, filters, limit)?;
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on Oracle: {e:?}"))
        })?;
        // Iterate the result set lazily, converting every `chunk_size` rows
        // into one Arrow record batch.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from Oracle due to {e}"
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}

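/// Maps a driver-reported [`ColumnType`] to this crate's [`OracleType`].
///
/// Oracle reports `NUMBER` with precision 0 when the precision is unspecified
/// (mapped here to the maximum of 38) and scale -127 for unconstrained,
/// float-like numbers (mapped here to scale 0).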
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            let precision = if *precision == 0 { 38 } else { *precision };
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(OracleType::Number(precision, scale))
        }
        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
        ColumnType::Long => Ok(OracleType::Long),
        ColumnType::CLOB => Ok(OracleType::Clob),
        ColumnType::NCLOB => Ok(OracleType::NClob),
        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
        ColumnType::LongRaw => Ok(OracleType::LongRaw),
        ColumnType::BLOB => Ok(OracleType::Blob),
        ColumnType::Date => Ok(OracleType::Date),
        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
        ColumnType::Boolean => Ok(OracleType::Boolean),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported Oracle type: {oracle_type:?}",
        ))),
    }
}

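/// Builds a [`RemoteSchema`] from the column metadata of a sample row.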
fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

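/// Downcasts `$builder` to `$builder_ty`, reads column `$index` of `$row` as
/// `Option<$value_ty>`, applies `$convert` to non-null values, and appends the
/// result (or a null) to the builder.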
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}

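/// Converts a chunk of Oracle rows into a single Arrow [`RecordBatch`].
///
/// Builders are allocated for every field of `table_schema`, but only the
/// projected columns are populated and emitted; the rest stay empty and are
/// discarded at the end.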
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder, field, col, Int16Builder, i16, row, idx, just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder, field, col, Int32Builder, i32, row, idx, just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder, field, col, Float32Builder, f32, row, idx, just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder, field, col, Float64Builder, f64, row, idx, just_return
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder, field, col, StringBuilder, String, row, idx, just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder, field, col, LargeStringBuilder, String, row, idx, just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    // NUMBER values arrive as strings and are converted to i128 at
                    // the target scale, avoiding floating-point precision loss.
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}"
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}"
                                ))
                            })
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp())
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            // `timestamp_nanos_opt` is None when the value overflows
                            // i64 nanoseconds (outside roughly 1677-2262).
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder, field, col, BooleanBuilder, bool, row, idx, just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder, field, col, BinaryBuilder, Vec<u8>, row, idx, just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder, field, col, LargeBinaryBuilder, Vec<u8>, row, idx, just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}