1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8 ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9 Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
10 RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{DataFusionError, project_schema};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use derive_getters::Getters;
17use derive_with::With;
18use futures::StreamExt;
19use oracle::sql_type::OracleType as ColumnType;
20use oracle::{Connector, Row};
21use std::sync::Arc;
22
/// Connection settings for an Oracle database endpoint.
///
/// `With` (from `derive_with`) generates `with_*` builder-style setters and
/// `Getters` (from `derive_getters`) generates read-only accessors, so the
/// fields can stay `pub(crate)`.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    /// Hostname or IP address of the Oracle listener.
    pub(crate) host: String,
    /// Listener port (commonly 1521).
    pub(crate) port: u16,
    /// Username for authentication.
    pub(crate) username: String,
    /// Password for authentication.
    pub(crate) password: String,
    /// Oracle service name, used in the `//host:port/service_name` connect string.
    pub(crate) service_name: String,
    /// Maximum number of pooled connections (see `connect_oracle`).
    pub(crate) pool_max_size: usize,
    /// Number of rows gathered into each streamed `RecordBatch`.
    pub(crate) stream_chunk_size: usize,
}
33
34impl OracleConnectionOptions {
35 pub fn new(
36 host: impl Into<String>,
37 port: u16,
38 username: impl Into<String>,
39 password: impl Into<String>,
40 service_name: impl Into<String>,
41 ) -> Self {
42 Self {
43 host: host.into(),
44 port,
45 username: username.into(),
46 password: password.into(),
47 service_name: service_name.into(),
48 pool_max_size: 10,
49 stream_chunk_size: 2048,
50 }
51 }
52}
53
/// A bb8-backed pool of Oracle connections; see [`connect_oracle`] for
/// construction.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
58
59pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
60 let connect_string = format!(
61 "//{}:{}/{}",
62 options.host, options.port, options.service_name
63 );
64 let connector = Connector::new(
65 options.username.clone(),
66 options.password.clone(),
67 connect_string,
68 );
69 let _ = connector
70 .connect()
71 .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
72 let manager = OracleConnectionManager::from_connector(connector);
73 let pool = bb8::Pool::builder()
74 .max_size(options.pool_max_size as u32)
75 .build(manager)
76 .await
77 .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
78 Ok(OraclePool { pool })
79}
80
81#[async_trait::async_trait]
82impl Pool for OraclePool {
83 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
84 let conn = self.pool.get_owned().await.map_err(|e| {
85 DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
86 })?;
87 Ok(Arc::new(OracleConnection { conn }))
88 }
89}
90
/// A single pooled Oracle connection implementing the crate's generic
/// [`Connection`] API.
#[derive(Debug)]
pub struct OracleConnection {
    // Owned pool checkout; returned to the pool when dropped.
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
95
#[async_trait::async_trait]
impl Connection for OracleConnection {
    /// Infers the remote schema of `sql` by executing a 1-row rewrite of the
    /// query and reading column metadata off the returned row.
    ///
    /// NOTE(review): `query_row` requires the query to yield at least one
    /// row; an empty result presumably surfaces as the execution error below
    /// — confirm that is acceptable for schema inference.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        // Rewrite so only a single row is fetched for metadata inspection.
        let sql = RemoteDbType::Oracle.query_limit_1(sql)?;
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    /// Executes `sql` (rewritten to push down `unparsed_filters` and `limit`)
    /// and returns the rows as a stream of Arrow batches of
    /// `conn_options.stream_chunk_size()` rows each, projected to
    /// `projection` (or all columns when `None`).
    ///
    /// NOTE(review): `self.conn.query` and the row iteration are synchronous
    /// (the `oracle` crate is blocking), so polling the returned stream
    /// blocks the executor thread — confirm this is intentional.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Oracle.try_rewrite_query(sql, unparsed_filters, limit)?;
        // Own a copy so the `move` closure below is 'static.
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // Lazily pull rows from the driver, grouping them into fixed-size
        // chunks; each chunk becomes one RecordBatch.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            // Each element is a Result<Row, _>; fail the whole batch on the
            // first row error.
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
143
/// Maps a driver-level [`ColumnType`] to this crate's [`OracleType`].
///
/// # Errors
/// Returns [`DataFusionError::NotImplemented`] for column types without a
/// mapping.
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            // Precision 0 is the driver's "unspecified" marker; fall back to
            // 38, Oracle's maximum NUMBER precision.
            let precision = if *precision == 0 { 38 } else { *precision };
            // NOTE(review): scale -127 appears to be the sentinel for an
            // unspecified scale; mapping it to 0 drops any fractional digits
            // — confirm against how NUMBER values are actually read.
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(OracleType::Number(precision, scale))
        }
        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
        ColumnType::Long => Ok(OracleType::Long),
        ColumnType::CLOB => Ok(OracleType::Clob),
        ColumnType::NCLOB => Ok(OracleType::NClob),
        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
        ColumnType::LongRaw => Ok(OracleType::LongRaw),
        ColumnType::BLOB => Ok(OracleType::Blob),
        ColumnType::Date => Ok(OracleType::Date),
        // Fractional-second precision is discarded here.
        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
        ColumnType::Boolean => Ok(OracleType::Boolean),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}
173
174fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
175 let mut remote_fields = vec![];
176 for col in row.column_info() {
177 let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
178 remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
179 }
180 Ok(RemoteSchema::new(remote_fields))
181}
182
// Appends one nullable value from an oracle `Row` into an Arrow builder.
//
//   $builder    - the dynamic builder to downcast and append into
//   $field      - Arrow field, used only in panic/error messages
//   $col        - oracle column info, used only in panic/error messages
//   $builder_ty - concrete Arrow builder type expected at this column
//   $value_ty   - Rust type to fetch the value as
//   $row        - the oracle `Row` being read
//   $index      - zero-based column index
//   $convert    - fallible adapter from $value_ty to the builder's value type
//
// Panics if `$builder` does not downcast to `$builder_ty` — that indicates a
// bug in the match arm of `rows_to_batch`, not bad data.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        // Fetch as Option<$value_ty> so a SQL NULL maps to an Arrow null.
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
211
/// Converts a chunk of oracle rows into a single Arrow [`RecordBatch`] whose
/// schema is `project_schema(table_schema, projection)`.
///
/// Builders are allocated for every column of `table_schema` so that column
/// indices line up with `projection`, but values are appended — and arrays
/// finished — only for projected columns.
///
/// # Errors
/// Fails if a value cannot be fetched or converted, or if a projected
/// column's Arrow type has no match arm below.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // One builder per table column, pre-sized for this chunk's row count.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            // Skip columns the caller did not project.
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata; only used for error/panic messages.
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                // Oracle NUMBER is fetched as a string, parsed through
                // BigDecimal, then rescaled into the Arrow decimal's
                // (precision, scale) as an i128.
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                // Naive datetimes are interpreted as UTC (`and_utc`) before
                // taking the epoch value.
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            // timestamp_nanos_opt is None outside the i64
                            // nanosecond range (~1677..2262).
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                // Date64 is epoch milliseconds, again treating the naive
                // value as UTC.
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Finish only the projected builders, in table-schema order, to match
    // `projected_schema`.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}