1use crate::connection::{RemoteDbType, big_decimal_to_i128, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8 ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9 Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
10 RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{DataFusionError, project_schema};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion::prelude::Expr;
17use derive_getters::Getters;
18use derive_with::With;
19use futures::StreamExt;
20use oracle::sql_type::OracleType as ColumnType;
21use oracle::{Connector, Row};
22use std::sync::Arc;
23
/// Connection settings for an Oracle database.
///
/// Constructed via [`OracleConnectionOptions::new`], which fills in defaults of
/// `pool_max_size = 10` and `stream_chunk_size = 2048`; the `With` derive
/// provides builder-style `with_*` overrides and `Getters` read accessors.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    // Host name or IP address of the Oracle listener.
    pub(crate) host: String,
    // Listener port (commonly 1521).
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    // Oracle service name used in the EZCONNECT string `//host:port/service`.
    pub(crate) service_name: String,
    // Maximum number of pooled connections (passed to bb8 as a u32).
    pub(crate) pool_max_size: usize,
    // Number of rows grouped into each emitted RecordBatch when streaming.
    pub(crate) stream_chunk_size: usize,
}
34
35impl OracleConnectionOptions {
36 pub fn new(
37 host: impl Into<String>,
38 port: u16,
39 username: impl Into<String>,
40 password: impl Into<String>,
41 service_name: impl Into<String>,
42 ) -> Self {
43 Self {
44 host: host.into(),
45 port,
46 username: username.into(),
47 password: password.into(),
48 service_name: service_name.into(),
49 pool_max_size: 10,
50 stream_chunk_size: 2048,
51 }
52 }
53}
54
/// A bb8-backed pool of Oracle connections; implements the crate's [`Pool`]
/// trait to hand out [`OracleConnection`] instances.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
59
60pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
61 let connect_string = format!(
62 "//{}:{}/{}",
63 options.host, options.port, options.service_name
64 );
65 let connector = Connector::new(
66 options.username.clone(),
67 options.password.clone(),
68 connect_string,
69 );
70 let _ = connector
71 .connect()
72 .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
73 let manager = OracleConnectionManager::from_connector(connector);
74 let pool = bb8::Pool::builder()
75 .max_size(options.pool_max_size as u32)
76 .build(manager)
77 .await
78 .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
79 Ok(OraclePool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for OraclePool {
84 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85 let conn = self.pool.get_owned().await.map_err(|e| {
86 DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
87 })?;
88 Ok(Arc::new(OracleConnection { conn }))
89 }
90}
91
/// A single checked-out Oracle connection; the `'static` pooled guard returns
/// the connection to its [`OraclePool`] on drop.
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
96
#[async_trait::async_trait]
impl Connection for OracleConnection {
    /// Infer the remote schema of `sql` by executing a single-row variant of
    /// the query and reading column metadata off the returned row.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        // Rewrite to fetch at most one row; only column metadata is needed.
        let sql = RemoteDbType::Oracle.query_limit_1(sql)?;
        // NOTE(review): query_row needs at least one row; on an empty result
        // this presumably surfaces as the mapped execution error — confirm.
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    /// Execute `sql` (after rewriting it with the pushed-down `filters` and
    /// `limit`) and stream the results as Arrow record batches of at most
    /// `stream_chunk_size` rows each.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        // The emitted stream advertises the projected schema, while batches
        // are built from the full table schema plus the projection mask.
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Oracle.try_rewrite_query(sql, filters, limit)?;
        let projection = projection.cloned();
        // NOTE(review): `chunks` panics on a zero chunk size; assumes
        // stream_chunk_size is validated to be > 0 upstream — confirm.
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // The result set yields Result<Row, _>; group rows into chunks so
        // each chunk becomes one RecordBatch.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            // Fail the whole batch on the first row-level fetch error.
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
144
/// Map a driver-level [`ColumnType`] to this crate's [`OracleType`].
///
/// Returns `DataFusionError::NotImplemented` for column types the crate does
/// not support.
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<OracleType> {
    match oracle_type {
        ColumnType::Number(precision, scale) => {
            // A reported precision of 0 marks an unconstrained NUMBER; fall
            // back to the maximum supported precision of 38.
            let precision = if *precision == 0 { 38 } else { *precision };
            // NOTE(review): scale -127 appears to mark an unspecified scale
            // (e.g. plain NUMBER); it is normalized to 0 — confirm against
            // the oracle crate's documentation.
            let scale = if *scale == -127 { 0 } else { *scale };
            Ok(OracleType::Number(precision, scale))
        }
        ColumnType::BinaryFloat => Ok(OracleType::BinaryFloat),
        ColumnType::BinaryDouble => Ok(OracleType::BinaryDouble),
        ColumnType::Float(precision) => Ok(OracleType::Float(*precision)),
        ColumnType::Varchar2(size) => Ok(OracleType::Varchar2(*size)),
        ColumnType::NVarchar2(size) => Ok(OracleType::NVarchar2(*size)),
        ColumnType::Char(size) => Ok(OracleType::Char(*size)),
        ColumnType::NChar(size) => Ok(OracleType::NChar(*size)),
        ColumnType::Long => Ok(OracleType::Long),
        ColumnType::CLOB => Ok(OracleType::Clob),
        ColumnType::NCLOB => Ok(OracleType::NClob),
        ColumnType::Raw(size) => Ok(OracleType::Raw(*size)),
        ColumnType::LongRaw => Ok(OracleType::LongRaw),
        ColumnType::BLOB => Ok(OracleType::Blob),
        ColumnType::Date => Ok(OracleType::Date),
        ColumnType::Timestamp(_) => Ok(OracleType::Timestamp),
        ColumnType::Boolean => Ok(OracleType::Boolean),
        // The driver enum is non-exhaustive from our perspective; reject
        // anything not explicitly handled above.
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}
174
175fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
176 let mut remote_fields = vec![];
177 for col in row.column_info() {
178 let remote_type = RemoteType::Oracle(oracle_type_to_remote_type(col.oracle_type())?);
179 remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
180 }
181 Ok(RemoteSchema::new(remote_fields))
182}
183
// Appends one (possibly NULL) value from `$row` at `$index` to `$builder`.
//
// `$builder_ty` is the concrete Arrow builder type to downcast to,
// `$value_ty` the Rust type fetched from the oracle row, and `$convert` a
// fallible closure adapting the fetched value to the builder's value type.
// `$field`/`$col` are only used to enrich panic and error messages.
//
// Panics if the builder's runtime type does not match `$builder_ty`: a
// schema/builder mismatch is a programming error, not a data error.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        // Fetch as Option so SQL NULL maps cleanly onto append_null.
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
212
/// Convert a slice of Oracle rows into one Arrow [`RecordBatch`] whose schema
/// is `projection` applied to `table_schema`.
///
/// Builders are allocated for every field of `table_schema`, but values are
/// appended — and columns emitted — only for projected indices, so the
/// builder index always lines up with the row's column index.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        // rows.len() is only a capacity hint; builders grow as needed.
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            // Unprojected columns keep empty builders and are filtered out
            // below.
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used to enrich error messages.
            let col = row.column_info().get(idx);
            // Dispatch on the Arrow type the table schema expects for this
            // column; each arm fetches the matching Rust type from the row.
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Int32 => {
                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                // NUMBER columns are fetched as text and parsed through
                // BigDecimal, then rescaled to the schema's target scale
                // before being stored as an i128.
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                // Naive datetimes are interpreted as UTC when converted to
                // epoch seconds.
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                // Nanosecond conversion can overflow i64 (~year 2262), hence
                // the fallible timestamp_nanos_opt.
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                // Date64 stores epoch milliseconds; Oracle DATE carries a
                // time component, fetched here as a full NaiveDateTime.
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Emit only projected columns, preserving schema order.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}