// datafusion_remote_table/connection/oracle.rs
use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4 Connection, DFResult, OracleType, Pool, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteType,
5 Transform,
6};
7use bb8_oracle::OracleConnectionManager;
8use datafusion::arrow::array::{
9 make_builder, ArrayRef, Decimal128Builder, RecordBatch, StringBuilder,
10 TimestampNanosecondBuilder, TimestampSecondBuilder,
11};
12use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
13use datafusion::common::{project_schema, DataFusionError};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use futures::StreamExt;
17use oracle::sql_type::OracleType as ColumnType;
18use oracle::{Connector, Row};
19use std::sync::Arc;
20
/// Connection parameters for an Oracle database.
///
/// The fields are combined into an EZCONNECT-style string
/// (`//host:port/service_name`) plus credentials when a pool is created.
#[derive(Debug, Clone)]
pub struct OracleConnectionOptions {
    /// Hostname or IP address of the Oracle listener.
    pub host: String,
    /// Listener port.
    pub port: u16,
    /// Username for authentication.
    pub username: String,
    /// Password for authentication.
    pub password: String,
    /// Oracle service name appended to the connect string.
    pub service_name: String,
}
29
30impl OracleConnectionOptions {
31 pub fn new(
32 host: impl Into<String>,
33 port: u16,
34 username: impl Into<String>,
35 password: impl Into<String>,
36 service_name: impl Into<String>,
37 ) -> Self {
38 Self {
39 host: host.into(),
40 port,
41 username: username.into(),
42 password: password.into(),
43 service_name: service_name.into(),
44 }
45 }
46}
47
/// A bb8-backed pool of Oracle connections; implements this crate's
/// [`Pool`] trait so DataFusion can check out connections on demand.
#[derive(Debug)]
pub struct OraclePool {
    // Underlying bb8 pool; connections are created by OracleConnectionManager.
    pool: bb8::Pool<OracleConnectionManager>,
}
52
53pub async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
54 let connect_string = format!(
55 "//{}:{}/{}",
56 options.host, options.port, options.service_name
57 );
58 let connector = Connector::new(
59 options.username.clone(),
60 options.password.clone(),
61 connect_string,
62 );
63 let _ = connector
64 .connect()
65 .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
66 let manager = OracleConnectionManager::from_connector(connector);
67 let pool = bb8::Pool::builder()
68 .build(manager)
69 .await
70 .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
71 Ok(OraclePool { pool })
72}
73
74#[async_trait::async_trait]
75impl Pool for OraclePool {
76 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
77 let conn = self.pool.get_owned().await.map_err(|e| {
78 DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
79 })?;
80 Ok(Arc::new(OracleConnection { conn }))
81 }
82}
83
/// A single pooled Oracle connection — the unit handed out by
/// [`OraclePool`] via the [`Pool`] trait.
#[derive(Debug)]
pub struct OracleConnection {
    // Owned (`'static`) pooled connection so this struct can outlive the
    // `Pool::get` call that produced it.
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}
88
89#[async_trait::async_trait]
90impl Connection for OracleConnection {
91 async fn infer_schema(
92 &self,
93 sql: &str,
94 transform: Option<Arc<dyn Transform>>,
95 ) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
96 let row = self.conn.query_row(sql, &[]).map_err(|e| {
97 DataFusionError::Execution(format!("Failed to query one row to infer schema: {e:?}"))
98 })?;
99 let remote_schema = Arc::new(build_remote_schema(&row)?);
100 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
101 if let Some(transform) = transform {
102 let batch = rows_to_batch(&[row], &arrow_schema, None)?;
103 let transformed_batch = transform_batch(
104 batch,
105 transform.as_ref(),
106 &arrow_schema,
107 None,
108 Some(&remote_schema),
109 )?;
110 Ok((remote_schema, transformed_batch.schema()))
111 } else {
112 Ok((remote_schema, arrow_schema))
113 }
114 }
115
116 async fn query(
117 &self,
118 sql: String,
119 table_schema: SchemaRef,
120 projection: Option<Vec<usize>>,
121 ) -> DFResult<SendableRecordBatchStream> {
122 let projected_schema = project_schema(&table_schema, projection.as_ref())?;
123 let result_set = self.conn.query(&sql, &[]).map_err(|e| {
124 DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
125 })?;
126 let stream = futures::stream::iter(result_set).chunks(2000).boxed();
127
128 let stream = stream.map(move |rows| {
129 let rows: Vec<Row> = rows
130 .into_iter()
131 .collect::<Result<Vec<_>, _>>()
132 .map_err(|e| {
133 DataFusionError::Execution(format!(
134 "Failed to collect rows from oracle due to {e}",
135 ))
136 })?;
137 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
138 });
139
140 Ok(Box::pin(RecordBatchStreamAdapter::new(
141 projected_schema,
142 stream,
143 )))
144 }
145}
146
147fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
148 match oracle_type {
149 ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
150 ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
151 ColumnType::Number(precision, scale) => {
152 Ok(RemoteType::Oracle(OracleType::Number(*precision, *scale)))
153 }
154 ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
155 ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
156 _ => Err(DataFusionError::NotImplemented(format!(
157 "Unsupported oracle type: {oracle_type:?}",
158 ))),
159 }
160}
161
162fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
163 let mut remote_fields = vec![];
164 for col in row.column_info() {
165 let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
166 remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
167 }
168 Ok(RemoteSchema::new(remote_fields))
169}
170
/// Downcasts `$builder` to `$builder_ty`, reads column `$index` of `$row` as
/// `Option<$value_ty>`, and appends the value (or a null) to the builder.
///
/// Both the downcast and the column read panic on failure: a mismatch here is
/// a programming error (schema and builder disagree), not a runtime condition.
/// `$field` is only used to enrich the panic messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $field
                )
            });
        let v = $row
            .get::<usize, Option<$value_ty>>($index)
            .unwrap_or_else(|e| {
                panic!(
                    concat!(
                        "Failed to get ",
                        stringify!($value_ty),
                        " value for {:?}: {:?}"
                    ),
                    $field, e
                )
            });

        // SQL NULL becomes a null slot in the Arrow array.
        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
205
/// Converts a slice of Oracle driver rows into one Arrow [`RecordBatch`]
/// containing only the columns selected by `projection`.
///
/// `table_schema` is the full (unprojected) table schema; column index `idx`
/// in each row is assumed to line up with field `idx` of that schema.
/// Supported field types: Utf8, Decimal128, and timezone-less second/nanosecond
/// timestamps — anything else returns `DataFusionError::NotImplemented`.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // One builder per *table* field (not per projected field) so `idx` can
    // index `array_builders` directly; builders for unprojected columns stay
    // empty and are filtered out at the end.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            // Skip columns the caller did not project.
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used to enrich panic/error messages.
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Utf8 => {
                    handle_primitive_type!(builder, col, StringBuilder, String, row, idx);
                }
                DataType::Decimal128(_precision, scale) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<Decimal128Builder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to Decimal128Builder for {col:?}")
                        });

                    // Fetch NUMBER values as strings and re-parse, avoiding
                    // lossy float round-trips.
                    let v = row.get::<usize, Option<String>>(idx).unwrap_or_else(|e| {
                        panic!("Failed to get String value for {col:?}: {e:?}")
                    });

                    match v {
                        Some(v) => {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            // NOTE(review): Arrow's decimal scale is signed;
                            // `*scale as u32` wraps for negative scales —
                            // confirm negative-scale NUMBER columns cannot
                            // reach this path.
                            let Some(v) = big_decimal_to_i128(&decimal, Some(*scale as u32)) else {
                                return Err(DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                )));
                            };
                            builder.append_value(v);
                        }
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampSecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to TimestampSecondBuilder for {col:?}"
                            )
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(idx)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            // Naive datetime interpreted as UTC to get epoch seconds.
                            let t = v.and_utc().timestamp();
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampNanosecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to TimestampNanosecondBuilder for {col:?}")
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(idx)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            // timestamp_nanos_opt returns None when the value
                            // overflows i64 nanoseconds; surface that as an error.
                            let t = v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })?;
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    // Finish only the projected builders, preserving table-schema order,
    // which matches the field order of `projected_schema`.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}