// datafusion_remote_table/connection/oracle.rs
use crate::connection::{big_decimal_to_i128, projections_contains};
use crate::transform::transform_batch;
use crate::{
    project_remote_schema, Connection, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteType, Transform,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{
    make_builder, ArrayRef, Decimal128Builder, RecordBatch, StringBuilder,
    TimestampNanosecondBuilder, TimestampSecondBuilder,
};
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{project_schema, DataFusionError};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::StreamExt;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

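/// Options for connecting to an Oracle database: host, port, credentials, and service name.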
#[derive(Debug, Clone)]
pub struct OracleConnectionOptions {
    pub host: String,
    pub port: u16,
    pub username: String,
    pub password: String,
    pub service_name: String,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
        }
    }
}

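/// A pool of Oracle connections backed by `bb8` and `bb8-oracle`.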
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

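/// Builds an EZConnect string (`//host:port/service_name`), verifies that a connection can
/// actually be established, and then wraps the [`Connector`] in a `bb8` pool.
///
/// Illustrative usage only: the credentials and service name below are placeholders, and the
/// import path assumes these items are re-exported at the crate root.
///
/// ```rust,ignore
/// use datafusion_remote_table::{connect_oracle, OracleConnectionOptions};
///
/// # async fn demo() -> datafusion::common::Result<()> {
/// let options = OracleConnectionOptions::new("localhost", 1521, "scott", "tiger", "ORCLPDB1");
/// let pool = connect_oracle(&options).await?;
/// # Ok(())
/// # }
/// ```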
pub async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
    // Open and drop one connection up front so invalid options fail fast with a clear error.
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}

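/// A single pooled Oracle connection used to infer schemas and stream query results.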
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
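    // Infer the schema by fetching a single row and reading the driver's column metadata.
    // When a transform is supplied, the transformed batch's schema is reported instead.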
    async fn infer_schema(
        &self,
        sql: &str,
        transform: Option<Arc<dyn Transform>>,
    ) -> DFResult<(RemoteSchema, SchemaRef)> {
        let row = self.conn.query_row(sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to query one row to infer schema: {e:?}"))
        })?;
        let remote_schema = build_remote_schema(&row)?;
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        if let Some(transform) = transform {
            let batch = rows_to_batch(&[row], arrow_schema, None)?;
            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
            Ok((remote_schema, transformed_batch.schema()))
        } else {
            Ok((remote_schema, arrow_schema))
        }
    }

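    // Execute the query and pull rows in chunks of 2000. The first chunk is materialized
    // eagerly to derive the schema; the remaining chunks are converted to record batches
    // lazily as the stream is polled.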
    async fn query(
        &self,
        sql: String,
        projection: Option<Vec<usize>>,
    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute oracle query: {e:?}"))
        })?;
        let mut stream = futures::stream::iter(result_set).chunks(2000).boxed();

        let Some(first_chunk) = stream.next().await else {
            return Err(DataFusionError::Execution(
                "No data returned from oracle".to_string(),
            ));
        };
        let first_chunk: Vec<Row> = first_chunk
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to collect rows from oracle due to {e}"
                ))
            })?;
        let Some(first_row) = first_chunk.first() else {
            return Err(DataFusionError::Execution(
                "No data returned from oracle".to_string(),
            ));
        };

        let remote_schema = build_remote_schema(first_row)?;
        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        let first_chunk = rows_to_batch(
            first_chunk.as_slice(),
            arrow_schema.clone(),
            projection.as_ref(),
        )?;
        let schema = first_chunk.schema();

        let mut stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}"
                    ))
                })?;
            let batch = rows_to_batch(rows.as_slice(), arrow_schema.clone(), projection.as_ref())?;
            Ok::<RecordBatch, DataFusionError>(batch)
        });

        let output_stream = async_stream::stream! {
            yield Ok(first_chunk);
            while let Some(batch) = stream.next().await {
                yield batch
            }
        };

        Ok((
            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
            projected_remote_schema,
        ))
    }
}

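/// Maps an Oracle column type reported by the driver to this crate's `RemoteType`.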
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
    match oracle_type {
        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
        ColumnType::Char(size) => Ok(RemoteType::Oracle(OracleType::Char(*size))),
        ColumnType::Number(precision, scale) => {
            Ok(RemoteType::Oracle(OracleType::Number(*precision, *scale)))
        }
        ColumnType::Date => Ok(RemoteType::Oracle(OracleType::Date)),
        ColumnType::Timestamp(_) => Ok(RemoteType::Oracle(OracleType::Timestamp)),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}",
        ))),
    }
}

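/// Builds a [`RemoteSchema`] from the column metadata (name, type, nullability) of a fetched row.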
fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

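// Downcasts the builder to the expected concrete builder type, reads the value at `$index`
// as `Option<$value_ty>`, and appends it (or null) to the builder.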
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $field
                )
            });
        let v = $row
            .get::<usize, Option<$value_ty>>($index)
            .unwrap_or_else(|e| {
                panic!(
                    concat!(
                        "Failed to get ",
                        stringify!($value_ty),
                        " value for {:?}: {:?}"
                    ),
                    $field, e
                )
            });

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}

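/// Converts a slice of Oracle rows into an Arrow [`RecordBatch`], keeping only the columns
/// selected by `projection`. Strings go to `StringBuilder`, NUMBER values are parsed as
/// `BigDecimal` and stored as `Decimal128`, DATE maps to second-precision timestamps, and
/// TIMESTAMP maps to nanosecond-precision timestamps.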
fn rows_to_batch(
    rows: &[Row],
    arrow_schema: SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(&arrow_schema, projection)?;
    let mut array_builders = vec![];
    for field in arrow_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (i, col) in row.column_info().iter().enumerate() {
            let builder = &mut array_builders[i];
            match col.oracle_type() {
                ColumnType::Varchar2(_size) | ColumnType::Char(_size) => {
                    handle_primitive_type!(builder, col, StringBuilder, String, row, i);
                }
                ColumnType::Number(_precision, scale) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<Decimal128Builder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to Decimal128Builder for {col:?}")
                        });

                    let v = row.get::<usize, Option<String>>(i).unwrap_or_else(|e| {
                        panic!("Failed to get String value for {col:?}: {e:?}")
                    });

                    match v {
                        Some(v) => {
                            let decimal = v.parse::<bigdecimal::BigDecimal>().map_err(|e| {
                                DataFusionError::Execution(format!(
                                    "Failed to parse BigDecimal from {v:?}: {e:?}",
                                ))
                            })?;
                            let Some(v) = big_decimal_to_i128(&decimal, Some(*scale as u32)) else {
                                return Err(DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal to i128 for {decimal:?}",
                                )));
                            };
                            builder.append_value(v);
                        }
                        None => builder.append_null(),
                    }
                }
                ColumnType::Date => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampSecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to TimestampSecondBuilder for {col:?}"
                            )
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(i)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            let t = v.and_utc().timestamp();
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                ColumnType::Timestamp(_) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampNanosecondBuilder>()
                        .unwrap_or_else(|| {
                            panic!("Failed to downcast builder to TimestampNanosecondBuilder for {col:?}")
                        });
                    let v = row
                        .get::<usize, Option<chrono::NaiveDateTime>>(i)
                        .unwrap_or_else(|e| {
                            panic!("Failed to get chrono::NaiveDateTime value for {col:?}: {e:?}")
                        });

                    match v {
                        Some(v) => {
                            let t = v.and_utc().timestamp_nanos_opt().ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert chrono::NaiveDateTime {v} to nanos timestamp"
                                ))
                            })?;
                            builder.append_value(t);
                        }
                        None => builder.append_null(),
                    }
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported oracle type: {col:?}",
                    )))
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}