1use crate::connection::{RemoteDbType, big_decimal_to_i128, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8 ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9 Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
10 RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{DataFusionError, project_schema};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion::prelude::Expr;
17use derive_getters::Getters;
18use derive_with::With;
19use futures::StreamExt;
20use oracle::sql_type::OracleType as ColumnType;
21use oracle::{Connector, Row};
22use std::sync::Arc;
23
/// Connection settings for an Oracle database.
///
/// Built with [`OracleConnectionOptions::new`]; the `With` derive supplies
/// `with_*` setters for overriding the pooling/streaming defaults, and the
/// `Getters` derive exposes read-only accessors.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    // Host name or IP address of the Oracle listener.
    pub(crate) host: String,
    // Listener port.
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    // Service name used in the EZConnect string "//host:port/service".
    pub(crate) service_name: String,
    // Maximum number of pooled connections (defaults to 10 in `new`).
    pub(crate) pool_max_size: usize,
    // Rows per emitted RecordBatch when streaming results (defaults to 2048).
    pub(crate) stream_chunk_size: usize,
}
34
35impl OracleConnectionOptions {
36 pub fn new(
37 host: impl Into<String>,
38 port: u16,
39 username: impl Into<String>,
40 password: impl Into<String>,
41 service_name: impl Into<String>,
42 ) -> Self {
43 Self {
44 host: host.into(),
45 port,
46 username: username.into(),
47 password: password.into(),
48 service_name: service_name.into(),
49 pool_max_size: 10,
50 stream_chunk_size: 2048,
51 }
52 }
53}
54
/// A bb8-backed pool of Oracle connections; constructed by [`connect_oracle`].
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
59
60pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
61 let connect_string = format!(
62 "//{}:{}/{}",
63 options.host, options.port, options.service_name
64 );
65 let connector = Connector::new(
66 options.username.clone(),
67 options.password.clone(),
68 connect_string,
69 );
70 let _ = connector
71 .connect()
72 .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
73 let manager = OracleConnectionManager::from_connector(connector);
74 let pool = bb8::Pool::builder()
75 .max_size(options.pool_max_size as u32)
76 .build(manager)
77 .await
78 .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
79 Ok(OraclePool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for OraclePool {
84 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85 let conn = self.pool.get_owned().await.map_err(|e| {
86 DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
87 })?;
88 Ok(Arc::new(OracleConnection { conn }))
89 }
90}
91
/// A single checked-out Oracle connection implementing the generic
/// [`Connection`] trait.
#[derive(Debug)]
pub struct OracleConnection {
    // Owned ('static) checkout so the connection can live inside an Arc
    // independently of a pool guard's borrow.
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
96
#[async_trait::async_trait]
impl Connection for OracleConnection {
    /// Infers both the remote (Oracle) schema and the matching Arrow schema
    /// for `sql` by executing it with a limit of one row and reading the
    /// result row's column metadata.
    ///
    /// NOTE(review): `query_row` appears to require at least one row, so
    /// inference would fail on a query with an empty result set — confirm
    /// against the `oracle` crate docs.
    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        // Rewrite to fetch a single row; fall back to the original SQL when
        // the rewrite does not apply.
        let sql = RemoteDbType::Oracle
            .try_rewrite_query(sql, &[], Some(1))
            .unwrap_or_else(|| sql.to_string());
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        Ok((remote_schema, arrow_schema))
    }

    /// Executes `sql` (with `filters` and `limit` pushed down where
    /// supported) and streams the result as Arrow record batches of
    /// `stream_chunk_size` rows each.
    ///
    /// `projection` selects which `table_schema` columns are materialized;
    /// the returned stream's schema is the projected schema.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        // Push filters/limit into the SQL text when the rewrite applies.
        let sql = RemoteDbType::Oracle
            .try_rewrite_query(sql, filters, limit)
            .unwrap_or_else(|| sql.to_string());
        // Owned copies moved into the stream closure below.
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // The result set is a synchronous iterator of Result<Row>; chunk it
        // so each stream item becomes one RecordBatch worth of rows.
        // NOTE(review): row fetching runs inline on the polling task — it is
        // not offloaded to a blocking thread; verify this is acceptable here.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            // Fail the whole chunk on the first row-level error.
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
149
150fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
151 match oracle_type {
152 ColumnType::Number(precision, scale) => {
153 let precision = if *precision == 0 { 38 } else { *precision };
155 let scale = if *scale == -127 { 0 } else { *scale };
156 Ok(RemoteType::Oracle(OracleType::Number(precision, scale)))
157 }
158 ColumnType::BinaryFloat => Ok(RemoteType::Oracle(OracleType::BinaryFloat)),
159 ColumnType::BinaryDouble => Ok(RemoteType::Oracle(OracleType::BinaryDouble)),
160 ColumnType::Float(precision) => Ok(RemoteType::Oracle(OracleType::Float(*precision))),
161 ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
162 ColumnType::NVarchar2(size) => Ok(RemoteType::Oracle(OracleType::NVarchar2(*size))),
163 ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
164 ColumnType::NChar(size) => Ok(RemoteType::Oracle(OracleType::NChar(*size))),
165 ColumnType::Long => Ok(RemoteType::Oracle(OracleType::Long)),
166 ColumnType::CLOB => Ok(RemoteType::Oracle(OracleType::Clob)),
167 ColumnType::NCLOB => Ok(RemoteType::Oracle(OracleType::NClob)),
168 ColumnType::Raw(size) => Ok(RemoteType::Oracle(OracleType::Raw(*size))),
169 ColumnType::LongRaw => Ok(RemoteType::Oracle(OracleType::LongRaw)),
170 ColumnType::BLOB => Ok(RemoteType::Oracle(OracleType::Blob)),
171 ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
172 ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
173 ColumnType::Boolean => Ok(RemoteType::Oracle(OracleType::Boolean)),
174 _ => Err(DataFusionError::NotImplemented(format!(
175 "Unsupported oracle type: {oracle_type:?}",
176 ))),
177 }
178}
179
180fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
181 let mut remote_fields = vec![];
182 for col in row.column_info() {
183 let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
184 remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
185 }
186 Ok(RemoteSchema::new(remote_fields))
187}
188
/// Appends one cell to an Arrow array builder.
///
/// Downcasts `$builder` to the concrete `$builder_ty`, reads column `$index`
/// of `$row` as `Option<$value_ty>`, applies the fallible `$convert` closure
/// to a present value, and appends the result (or a null for `None`).
///
/// Panics only when the builder's concrete type does not match the Arrow
/// data type it was created for — that would indicate a schema-mapping bug
/// in `rows_to_batch`, not a data error.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        // Reading as Option<_> lets SQL NULLs surface as None instead of errors.
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
217
/// Converts a chunk of Oracle rows into a single Arrow [`RecordBatch`] with
/// the projected schema.
///
/// Builders are allocated for every column of `table_schema`, but values are
/// only appended for (and arrays only emitted for) columns selected by
/// `projection`; unprojected builders stay empty and are filtered out at the
/// end. Columns are fetched by position, so `table_schema` must be in the
/// same column order as the query's result set.
///
/// # Errors
/// Fails on unsupported Arrow data types, on value-extraction errors from the
/// driver, and on decimal/timestamp conversions that do not fit the target
/// type.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        // Capacity hint only; builders grow as needed.
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            // Skip columns the caller did not ask for.
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used to enrich error messages below.
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Int32 => {
                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    // Fetch NUMBER as a string and parse through BigDecimal to
                    // avoid float rounding; rescale to the field's Arrow scale.
                    // NOTE(review): the field's precision is not validated
                    // here — an out-of-range value only fails if the i128
                    // conversion itself fails.
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    // Naive datetimes are interpreted as UTC when converting
                    // to epoch seconds.
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            // timestamp_nanos_opt is None when the instant
                            // overflows the i64 nanosecond range.
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    // Date64 is epoch milliseconds (UTC interpretation).
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Emit only the projected columns, in table_schema order, matching
    // projected_schema's field order.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}