1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use bb8_oracle::OracleConnectionManager;
7use datafusion::arrow::array::{
8 ArrayRef, BinaryBuilder, BooleanBuilder, Date64Builder, Decimal128Builder, Float32Builder,
9 Float64Builder, Int16Builder, Int32Builder, LargeBinaryBuilder, LargeStringBuilder,
10 RecordBatch, StringBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, make_builder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{DataFusionError, project_schema};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use derive_getters::Getters;
17use derive_with::With;
18use futures::StreamExt;
19use oracle::sql_type::OracleType as ColumnType;
20use oracle::{Connector, Row};
21use std::sync::Arc;
22
/// Connection settings for an Oracle database, consumed by [`connect_oracle`].
///
/// `With` generates `with_*` builder-style setters and `Getters` generates
/// read accessors, so the fields themselves can stay crate-private.
#[derive(Debug, Clone, With, Getters)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    /// Oracle service name, used in the `//host:port/service_name` connect string.
    pub(crate) service_name: String,
    /// Upper bound on pooled connections (cast to `u32` for bb8).
    pub(crate) pool_max_size: usize,
    /// Number of rows gathered into each Arrow `RecordBatch` when streaming.
    pub(crate) stream_chunk_size: usize,
}
33
34impl OracleConnectionOptions {
35 pub fn new(
36 host: impl Into<String>,
37 port: u16,
38 username: impl Into<String>,
39 password: impl Into<String>,
40 service_name: impl Into<String>,
41 ) -> Self {
42 Self {
43 host: host.into(),
44 port,
45 username: username.into(),
46 password: password.into(),
47 service_name: service_name.into(),
48 pool_max_size: 10,
49 stream_chunk_size: 2048,
50 }
51 }
52}
53
/// A bb8-backed Oracle connection pool; implements the crate's [`Pool`] trait.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}
58
59pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
60 let connect_string = format!(
61 "//{}:{}/{}",
62 options.host, options.port, options.service_name
63 );
64 let connector = Connector::new(
65 options.username.clone(),
66 options.password.clone(),
67 connect_string,
68 );
69 let _ = connector
70 .connect()
71 .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
72 let manager = OracleConnectionManager::from_connector(connector);
73 let pool = bb8::Pool::builder()
74 .max_size(options.pool_max_size as u32)
75 .build(manager)
76 .await
77 .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
78 Ok(OraclePool { pool })
79}
80
81#[async_trait::async_trait]
82impl Pool for OraclePool {
83 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
84 let conn = self.pool.get_owned().await.map_err(|e| {
85 DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
86 })?;
87 Ok(Arc::new(OracleConnection { conn }))
88 }
89}
90
/// A single checked-out Oracle connection; implements the crate's
/// [`Connection`] trait. Returning it to the pool happens on drop.
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
95
#[async_trait::async_trait]
impl Connection for OracleConnection {
    /// Infers the remote (Oracle) schema and the corresponding Arrow schema
    /// for `sql` by executing it and reading the column metadata of the
    /// resulting row.
    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        // Wrap SELECTs with `ROWNUM <= 1` so inference fetches a single row;
        // other statements run unchanged.
        let sql = try_limit1_query(sql).unwrap_or_else(|| sql.to_string());
        let row = self.conn.query_row(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on oracle: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        Ok((remote_schema, arrow_schema))
    }

    /// Executes `sql` and streams the result as Arrow record batches of
    /// `stream_chunk_size` rows, projected down to `projection`.
    ///
    /// NOTE(review): the oracle driver's result-set iterator is synchronous;
    /// wrapping it in `stream::iter` means each poll does blocking I/O on the
    /// async executor — confirm this is acceptable for the runtime in use.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        // Own a copy of the projection so the closure below can be 'static.
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let result_set = self.conn.query(sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // Group driver rows into fixed-size chunks; each chunk becomes one batch.
        let stream = futures::stream::iter(result_set).chunks(chunk_size).boxed();

        let stream = stream.map(move |rows| {
            // Each chunk is Vec<Result<Row, _>>; fail the whole batch on the
            // first row-level driver error.
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
141
/// Wraps a `SELECT` statement so it returns at most one row (Oracle `ROWNUM`
/// pagination), used by schema inference to avoid scanning the full result.
///
/// Returns `None` when `sql` does not start with `SELECT` (case-insensitive),
/// in which case the caller falls back to running the statement unchanged.
fn try_limit1_query(sql: &str) -> Option<String> {
    // `str::get` avoids the panics the old `[0..6]` byte slice hit on
    // statements shorter than six bytes or starting mid-char-boundary.
    let head = sql.trim().get(0..6)?;
    if head.eq_ignore_ascii_case("select") {
        Some(format!("SELECT * FROM ({sql}) WHERE ROWNUM <= 1"))
    } else {
        None
    }
}
149
150fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
151 match oracle_type {
152 ColumnType::Number(precision, scale) => {
153 let precision = if *precision == 0 { 38 } else { *precision };
155 let scale = if *scale == -127 { 0 } else { *scale };
156 Ok(RemoteType::Oracle(OracleType::Number(precision, scale)))
157 }
158 ColumnType::BinaryFloat => Ok(RemoteType::Oracle(OracleType::BinaryFloat)),
159 ColumnType::BinaryDouble => Ok(RemoteType::Oracle(OracleType::BinaryDouble)),
160 ColumnType::Float(precision) => Ok(RemoteType::Oracle(OracleType::Float(*precision))),
161 ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
162 ColumnType::NVarchar2(size) => Ok(RemoteType::Oracle(OracleType::NVarchar2(*size))),
163 ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
164 ColumnType::NChar(size) => Ok(RemoteType::Oracle(OracleType::NChar(*size))),
165 ColumnType::Long => Ok(RemoteType::Oracle(OracleType::Long)),
166 ColumnType::CLOB => Ok(RemoteType::Oracle(OracleType::Clob)),
167 ColumnType::NCLOB => Ok(RemoteType::Oracle(OracleType::NClob)),
168 ColumnType::Raw(size) => Ok(RemoteType::Oracle(OracleType::Raw(*size))),
169 ColumnType::LongRaw => Ok(RemoteType::Oracle(OracleType::LongRaw)),
170 ColumnType::BLOB => Ok(RemoteType::Oracle(OracleType::Blob)),
171 ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
172 ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
173 ColumnType::Boolean => Ok(RemoteType::Oracle(OracleType::Boolean)),
174 _ => Err(DataFusionError::NotImplemented(format!(
175 "Unsupported oracle type: {oracle_type:?}",
176 ))),
177 }
178}
179
180fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
181 let mut remote_fields = vec![];
182 for col in row.column_info() {
183 let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
184 remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
185 }
186 Ok(RemoteSchema::new(remote_fields))
187}
188
/// Appends one (possibly NULL) cell from an Oracle `$row` to an Arrow builder.
///
/// `$builder` is a type-erased `Box<dyn ArrayBuilder>` that is downcast to the
/// concrete `$builder_ty`; the cell is fetched at `$index` as
/// `Option<$value_ty>` and passed through the fallible `$convert` closure
/// before being appended. `$field` and `$col` are only used to enrich
/// panic/error messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        // A downcast failure means the builder was created for a different
        // Arrow type than the match arm expects — a programming bug, so panic.
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        // Fetching as Option<_> lets the driver map SQL NULL to None.
        let v = $row.get::<usize, Option<$value_ty>>($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
217
/// Converts a chunk of Oracle rows into one Arrow [`RecordBatch`] whose
/// columns are `table_schema` restricted to `projection` (all columns when
/// `projection` is `None`).
///
/// A builder is allocated for every field of `table_schema`, but values are
/// appended only for projected columns; the unprojected builders stay empty
/// and are dropped by the final filter.
///
/// # Errors
/// - `DataFusionError::Execution` when a cell cannot be read or converted.
/// - `DataFusionError::NotImplemented` for unsupported Arrow data types.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        // Capacity hint = number of rows in this chunk.
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            // Skip columns the query does not project.
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used to enrich error messages.
            let col = row.column_info().get(idx);
            // Dispatch on the Arrow type; each arm names the concrete builder
            // type and the Rust type to fetch from the driver.
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Int32 => {
                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                // Oracle NUMBER: fetched as its string form, parsed via
                // BigDecimal, then rescaled into the i128 Decimal128 storage.
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        String,
                        row,
                        idx,
                        |v: String| {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            big_decimal_to_i128(&decimal, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                ))
                            })
                        }
                    );
                }
                // Timestamps are fetched as naive datetimes and converted
                // assuming UTC — NOTE(review): confirm session time zone.
                DataType::Timestamp(TimeUnit::Second, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampSecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let t = v.and_utc().timestamp();
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            // timestamp_nanos_opt is None outside ~1677-2262.
                            v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })
                        }
                    );
                }
                DataType::Date64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date64Builder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            Ok::<_, DataFusionError>(v.and_utc().timestamp_millis())
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Finish only the projected builders, in schema order, to match
    // `projected_schema` column-for-column.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}