datafusion_remote_table/connection/oracle.rs

use crate::connection::projections_contains;
use crate::transform::transform_batch;
use crate::{
    project_remote_schema, Connection, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
    RemoteType, Transform,
};
use bb8_oracle::OracleConnectionManager;
use datafusion::arrow::array::{make_builder, ArrayRef, RecordBatch, StringBuilder};
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{project_schema, DataFusionError};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::StreamExt;
use oracle::sql_type::OracleType as ColumnType;
use oracle::{Connector, Row};
use std::sync::Arc;

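/// Connection options for an Oracle database. `connect_oracle` combines these into an
/// Easy Connect-style connect string of the form `//host:port/service_name`.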
#[derive(Debug, Clone)]
pub struct OracleConnectionOptions {
    pub host: String,
    pub port: u16,
    pub username: String,
    pub password: String,
    pub service_name: String,
}

impl OracleConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
        service_name: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            service_name: service_name.into(),
        }
    }
}

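/// An Oracle connection pool backed by bb8 and `OracleConnectionManager`.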
#[derive(Debug)]
pub struct OraclePool {
    pool: bb8::Pool<OracleConnectionManager>,
}

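/// Builds the connect string from the options, verifies it with a one-off test connection,
/// and constructs a bb8 pool around the resulting connector.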
pub async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
    let connect_string = format!(
        "//{}:{}/{}",
        options.host, options.port, options.service_name
    );
    let connector = Connector::new(
        options.username.clone(),
        options.password.clone(),
        connect_string,
    );
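    // Open and immediately drop a test connection so that configuration errors surface
    // here rather than on first use of the pool.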
    let _ = connector
        .connect()
        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
    let manager = OracleConnectionManager::from_connector(connector);
    let pool = bb8::Pool::builder()
        .build(manager)
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {e:?}")))?;
    Ok(OraclePool { pool })
}

#[async_trait::async_trait]
impl Pool for OraclePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
        })?;
        Ok(Arc::new(OracleConnection { conn }))
    }
}

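/// A single pooled Oracle connection used to infer schemas and execute queries.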
#[derive(Debug)]
pub struct OracleConnection {
    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
}

#[async_trait::async_trait]
impl Connection for OracleConnection {
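    // Infer the remote schema by executing the query and inspecting the column metadata of
    // the first returned row; when a transform is provided, the Arrow schema is derived from
    // a transformed sample batch instead.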
    async fn infer_schema(
        &self,
        sql: &str,
        transform: Option<Arc<dyn Transform>>,
    ) -> DFResult<(RemoteSchema, SchemaRef)> {
        let row = self.conn.query_row(sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to query one row to infer schema: {e:?}"))
        })?;
        let remote_schema = build_remote_schema(&row)?;
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        if let Some(transform) = transform {
            let batch = rows_to_batch(&[row], arrow_schema, None)?;
            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
            Ok((remote_schema, transformed_batch.schema()))
        } else {
            Ok((remote_schema, arrow_schema))
        }
    }

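    // Execute the query and stream the results as Arrow record batches of up to 2000 rows
    // each. The first chunk is materialized eagerly so the remote schema and output schema
    // can be derived from its column metadata; an empty result set is reported as an error.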
    async fn query(
        &self,
        sql: String,
        projection: Option<Vec<usize>>,
    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
        let result_set = self.conn.query(&sql, &[]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on oracle: {e:?}"))
        })?;
        let mut stream = futures::stream::iter(result_set).chunks(2000).boxed();

        let Some(first_chunk) = stream.next().await else {
            return Err(DataFusionError::Execution(
                "No data returned from oracle".to_string(),
            ));
        };
        let first_chunk: Vec<Row> = first_chunk
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to collect rows from oracle due to {e}"
                ))
            })?;
        let Some(first_row) = first_chunk.first() else {
            return Err(DataFusionError::Execution(
                "No data returned from oracle".to_string(),
            ));
        };

        let remote_schema = build_remote_schema(first_row)?;
        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        let first_chunk = rows_to_batch(
            first_chunk.as_slice(),
            arrow_schema.clone(),
            projection.as_ref(),
        )?;
        let schema = first_chunk.schema();

        let mut stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from oracle due to {e}"
                    ))
                })?;
            let batch = rows_to_batch(rows.as_slice(), arrow_schema.clone(), projection.as_ref())?;
            Ok::<RecordBatch, DataFusionError>(batch)
        });

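        // Re-emit the eagerly fetched first batch, then drain the remaining chunks lazily.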
        let output_stream = async_stream::stream! {
            yield Ok(first_chunk);
            while let Some(batch) = stream.next().await {
                yield batch
            }
        };

        Ok((
            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
            projected_remote_schema,
        ))
    }
}

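/// Maps an Oracle column type to the crate's `RemoteType`; only VARCHAR2 is supported so far.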
fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
    match oracle_type {
        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported oracle type: {oracle_type:?}"
        ))),
    }
}

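/// Builds a `RemoteSchema` from the column metadata (name, type, nullability) of a fetched row.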
fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for col in row.column_info() {
        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
    }
    Ok(RemoteSchema::new(remote_fields))
}

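/// Converts a slice of Oracle rows into an Arrow `RecordBatch`. Builders are created for every
/// column of the full schema, but only the columns selected by `projection` are finished and
/// included in the resulting batch.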
fn rows_to_batch(
    rows: &[Row],
    arrow_schema: SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(&arrow_schema, projection)?;
    let mut array_builders = vec![];
    for field in arrow_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (i, col) in row.column_info().iter().enumerate() {
            let builder = &mut array_builders[i];
            match col.oracle_type() {
                ColumnType::Varchar2(_size) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<StringBuilder>()
                        .unwrap();
                    let v = row.get::<usize, Option<String>>(i).map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to get value from oracle row: {e:?}"
                        ))
                    })?;

                    match v {
                        Some(v) => builder.append_value(v),
                        None => builder.append_null(),
                    }
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported oracle type: {col:?}"
                    )))
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}