use crate::connection::{big_decimal_to_i128, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteType,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
    Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
    RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::limit::LimitStream;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

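/// Connection settings for an Oracle instance, plus pool and streaming knobs.
///
/// A minimal construction sketch (host, credentials, and service name below
/// are illustrative, not from the original source):
///
/// ```ignore
/// let options =
///     OracleConnectionOptions::new("localhost", 1521, "system", "oracle", "FREEPDB1");
/// ```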
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
}

impl OracleConnectionOptions {
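    /// Creates options with default tuning: `pool_max_size = 10` and
    /// `stream_chunk_size = 2048`.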
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            pool_max_size: 10,
            stream_chunk_size: 2048,
        }
    }
}

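/// A connection pool for Oracle, backed by bb8's `OracleConnectionManager`.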
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

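/// Connects to Oracle using an Easy Connect-style string
/// (`//host:port/service_name`) and wraps the verified connector in a bb8
/// pool sized by `pool_max_size`.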
pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    // Open and immediately drop one connection so bad credentials or an
    // unreachable host fail fast, before the pool is built.
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}

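/// A single pooled Oracle connection implementing the `Connection` trait.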
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
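    /// Infers the remote and Arrow schemas by fetching a single row
    /// (via `ROWNUM <= 1` when the statement is a SELECT) and reading its
    /// column metadata.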
    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        let sql = try_limit_query(sql, 1).unwrap_or_else(|| sql.to_string());
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        Ok((remote_schema, arrow_schema))
    }

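    /// Executes `sql` and streams the result set as Arrow record batches of
    /// `stream_chunk_size` rows. A `limit` is pushed down into the SQL when
    /// possible; otherwise it is enforced client-side with a `LimitStream`.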
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        // Prefer pushing the limit into the SQL; fall back to wrapping the
        // stream in a `LimitStream` when the statement cannot be rewritten.
        let (sql, limit_stream) = match limit {
            Some(limit) => {
                if let Some(limited_sql) = try_limit_query(sql, limit) {
                    (limited_sql, false)
                } else {
                    (sql.to_string(), true)
                }
            }
            None => (sql.to_string(), false),
        };
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        let sendable_stream = Box::pin(RecordBatchStreamAdapter::new(projected_schema, stream));

        if limit_stream {
            let metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
            Ok(Box::pin(LimitStream::new(
                sendable_stream,
                0,
                limit,
                metrics,
            )))
        } else {
            Ok(sendable_stream)
        }
    }
}

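/// Wraps a SELECT statement in a `ROWNUM`-limited subquery. Returns `None`
/// for anything that does not start with `SELECT`, signalling the caller to
/// enforce the limit on the stream instead.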
fn try_limit_query(sql: &str, limit: usize) -> Option<String> {
    let trimmed = sql.trim();
    // `get` avoids the panic that direct slicing would cause on statements
    // shorter than six bytes or split on a non-ASCII char boundary.
    if trimmed
        .get(..6)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("select"))
    {
        Some(format!("SELECT * FROM ({sql}) WHERE ROWNUM <= {limit}"))
    } else {
        None
    }
}

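/// Maps an `oracle` crate column type to this crate's `RemoteType`,
/// normalizing the sentinel values Oracle reports for `NUMBER` columns
/// declared without an explicit precision or scale.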
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            // Oracle reports precision 0 and scale -127 for an unconstrained
            // NUMBER; fall back to the maximum precision and integer scale.
            let precision = if *precision == 0 { 38 } else { *precision };
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(RemoteType::Oracle(OracleType::Number(precision, scale)))
        }
        ColumnType::BinaryFloat => Ok(RemoteType::Oracle(OracleType::BinaryFloat)),
        ColumnType::BinaryDouble => Ok(RemoteType::Oracle(OracleType::BinaryDouble)),
        ColumnType::Float(precision) => Ok(RemoteType::Oracle(OracleType::Float(*precision))),
        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
        ColumnType::NVarchar2(size) => Ok(RemoteType::Oracle(OracleType::NVarchar2(*size))),
        ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
        ColumnType::NChar(size) => Ok(RemoteType::Oracle(OracleType::NChar(*size))),
        ColumnType::Long => Ok(RemoteType::Oracle(OracleType::Long)),
        ColumnType::CLOB => Ok(RemoteType::Oracle(OracleType::Clob)),
        ColumnType::NCLOB => Ok(RemoteType::Oracle(OracleType::NClob)),
        ColumnType::Raw(size) => Ok(RemoteType::Oracle(OracleType::Raw(*size))),
        ColumnType::LongRaw => Ok(RemoteType::Oracle(OracleType::LongRaw)),
        ColumnType::BLOB => Ok(RemoteType::Oracle(OracleType::Blob)),
        ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
        ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
        ColumnType::Boolean => Ok(RemoteType::Oracle(OracleType::Boolean)),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}

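/// Builds a `RemoteSchema` from the column metadata of a fetched row.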
fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

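/// Downcasts `$builder` to the concrete `$builder_ty`, reads column `$index`
/// of `$row` as an `Option<$value_ty>`, runs `$convert` on present values,
/// and appends the result (or a null) to the builder.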
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}

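/// Converts a chunk of Oracle rows into an Arrow `RecordBatch`. Builders are
/// created for every column of `table_schema` so that indexes line up with
/// the row's columns, but only projected columns are populated and emitted.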
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // One builder per column of the full table schema, so `idx` lines up
    // with the row's columns; unprojected builders stay empty.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Int32 => {
                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    // NUMBER values are fetched as strings and parsed through
                    // BigDecimal so no precision is lost before the i128
                    // conversion.
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}
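
// Illustrative tests, not from the original source: a quick sketch of the
// behavior expected from `try_limit_query` after the bounds-safe rewrite.
#[cfg(test)]
mod tests {
    use super::try_limit_query;

    #[test]
    fn select_is_wrapped_in_a_rownum_subquery() {
        assert_eq!(
            try_limit_query("SELECT * FROM t", 5).as_deref(),
            Some("SELECT * FROM (SELECT * FROM t) WHERE ROWNUM <= 5")
        );
    }

    #[test]
    fn non_select_statements_are_left_untouched() {
        assert_eq!(try_limit_query("UPDATE t SET a = 1", 5), None);
        // Inputs shorter than "select" no longer panic; they simply opt out.
        assert_eq!(try_limit_query("sel", 5), None);
    }
}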