datafusion_remote_table/connection/oracle.rs

use crate::connection::{big_decimal_to_i128, projections_contains};
use crate::transform::transform_batch;
use crate::{
    Connection, ConnectionOptions, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteType, Transform,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    make_builder, ArrayRef, Decimal128Builder, RecordBatch, StringBuilder,
    TimestampNanosecondBuilder, TimestampSecondBuilder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::common::{project_schema, DataFusionError};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::StreamExt;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

/// Connection settings for an Oracle database reached through the
/// `//host:port/service_name` easy-connect form.
#[derive(Debug, Clone)]
pub struct OracleConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) service_name: String,
    pub(crate) chunk_size: Option<usize>,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
            chunk_size: None,
        }
    }
}
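
// A minimal construction sketch; the host, credentials, and service name are
// hypothetical placeholders:
//
//     let options =
//         OracleConnectionOptions::new("db.example.com", 1521, "scott", "tiger", "ORCLPDB1");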

/// A bb8-backed pool of Oracle connections.
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

pub(crate) async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    // Open and drop one connection eagerly so that bad credentials or an
    // unreachable host fail here rather than on first use of the pool.
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}
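
// A minimal sketch of obtaining a connection, reusing the hypothetical
// `options` from above; `connect_oracle` validates connectivity eagerly, so
// errors surface here rather than at query time:
//
//     let pool = connect_oracle(&options).await?;
//     let conn = pool.get().await?;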

/// A single pooled Oracle connection; `Pool::get_owned` ties the guard to the
/// pool itself rather than a borrow, hence the `'static` lifetime.
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
    async fn infer_schema(
        &self,
        sql: &str,
        transform: Option<Arc<dyn Transform>>,
    ) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        // Fetch a single row: the driver's column metadata is enough to
        // derive both the remote schema and its Arrow counterpart.
        let row = self.conn.query_row(sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to query one row to infer schema: {e:?}"))
        })?;
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        if let Some(transform) = transform {
            // Run the sample row through the transform so the reported
            // schema matches what the transformed batches will carry.
            let batch = rows_to_batch(&[row], &arrow_schema, None)?;
            let transformed_batch = transform_batch(
                batch,
                transform.as_ref(),
                &arrow_schema,
                None,
                Some(&remote_schema),
            )?;
            Ok((remote_schema, transformed_batch.schema()))
        } else {
            Ok((remote_schema, arrow_schema))
        }
    }

    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let projection = projection.cloned();
        let chunk_size = conn_options.chunk_size();
        let result_set = self.conn.query(sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        // Group the row iterator into chunks so each emitted RecordBatch
        // holds at most `chunk_size` rows (2048 by default).
        let stream = futures::stream::iter(result_set)
            .chunks(chunk_size.unwrap_or(2048))
            .boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
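
// A sketch of driving the two trait methods end to end. The table name is a
// hypothetical placeholder and `ConnectionOptions` construction is elided;
// `futures::StreamExt::next` drives the stream:
//
//     let (_remote_schema, schema) = conn.infer_schema("SELECT * FROM emp", None).await?;
//     let mut stream = conn
//         .query(&conn_options, "SELECT * FROM emp", schema, None)
//         .await?;
//     while let Some(batch) = stream.next().await {
//         println!("fetched {} rows", batch?.num_rows());
//     }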

fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
    match oracle_type {
        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
        ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
        ColumnType::Number(precision, scale) => {
            Ok(RemoteType::Oracle(OracleType::Number(*precision, *scale)))
        }
        ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
        ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}
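
// Judging by the builders used in `rows_to_batch` below, `to_arrow_schema`
// presumably maps these remote types roughly as: Varchar2/Char -> Utf8,
// Number(p, s) -> Decimal128(p, s), and Date/Timestamp -> Timestamp.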

fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

/// Downcast `$builder` to `$builder_ty`, read column `$index` of `$row` as an
/// `Option<$value_ty>`, and append the value (or a null) to the builder.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $field
                )
            });
        let v = $row
            .get::<usize, Option<$value_ty>>($index)
            .unwrap_or_else(|e| {
                panic!(
                    concat!(
                        "Failed to get ",
                        stringify!($value_ty),
                        " value for {:?}: {:?}"
                    ),
                    $field, e
                )
            });

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
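
// For illustration, the `DataType::Utf8` arm below expands to roughly this
// (the panics indicate a schema/builder mismatch, i.e. a programming error,
// not bad data):
//
//     let builder = builder.as_any_mut().downcast_mut::<StringBuilder>().unwrap();
//     match row.get::<usize, Option<String>>(idx).unwrap() {
//         Some(v) => builder.append_value(v),
//         None => builder.append_null(),
//     }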

fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // One builder per table column; unprojected builders stay empty and are
    // filtered out when the projected columns are collected below.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.column_info().get(idx);
            match field.data_type() {
                DataType::Utf8 => {
                    handle_primitive_type!(builder, col, StringBuilder, String, row, idx);
                }
                DataType::Decimal128(_precision, scale) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<Decimal128Builder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to Decimal128Builder for {col:?}")
                        });

                    // Oracle NUMBER values are fetched as strings and routed
                    // through BigDecimal to avoid losing precision.
                    let v = row.get::<usize, Option<String>>(idx).unwrap_or_else(|e| {
                        panic!("Failed to get String value for {col:?}: {e:?}")
                    });

                    match v {
                        Some(v) => {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            let Some(v) = big_decimal_to_i128(&decimal, Some(*scale as u32)) else {
                                return Err(DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                )));
                            };
                            builder.append_value(v);
                        }
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Second, None) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampSecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to TimestampSecondBuilder for {col:?}"
                            )
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(idx)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            let t = v.and_utc().timestamp();
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampNanosecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to TimestampNanosecondBuilder for {col:?}")
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(idx)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            // Nanosecond timestamps only cover roughly the
                            // years 1677-2262; out-of-range dates error here.
                            let t = v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })?;
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}
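
// A note on projection, for illustration: with `projection = Some(&vec![0])`,
// only column 0 is decoded and finished, and the resulting batch's schema is
// `project_schema(&table_schema, Some(&vec![0]))`. Builders for unprojected
// columns are created but never appended to, so they are cheap to discard.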