datafusion_remote_table/connection/oracle.rs

1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{
4    project_remote_schema, Connection, DFResult, OracleType, Pool, RemoteField, RemoteSchema,
5    RemoteType, Transform,
6};
7use bb8_oracle::OracleConnectionManager;
8use datafusion::arrow::array::{make_builder, ArrayRef, RecordBatch, StringBuilder};
9use datafusion::arrow::datatypes::SchemaRef;
10use datafusion::common::{project_schema, DataFusionError};
11use datafusion::execution::SendableRecordBatchStream;
12use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
13use futures::StreamExt;
14use oracle::sql_type::OracleType as ColumnType;
15use oracle::{Connector, Row};
16use std::sync::Arc;
17
/// Parameters needed to reach and authenticate against an Oracle database.
///
/// [`connect_oracle`] turns `host`, `port` and `service_name` into a
/// `//host:port/service_name` connect string.
#[derive(Debug, Clone)]
pub struct OracleConnectionOptions {
    /// Host name or IP address of the database server.
    pub host: String,
    /// TCP port of the database server.
    pub port: u16,
    /// User name used to authenticate.
    pub username: String,
    /// Password used to authenticate.
    pub password: String,
    /// Service name placed after the `/` in the connect string.
    pub service_name: String,
}
26
27impl OracleConnectionOptions {
28    pub fn new(
29        host: impl Into<String>,
30        port: u16,
31        username: impl Into<String>,
32        password: impl Into<String>,
33        service_name: impl Into<String>,
34    ) -> Self {
35        Self {
36            host: host.into(),
37            port,
38            username: username.into(),
39            password: password.into(),
40            service_name: service_name.into(),
41        }
42    }
43}
44
45#[derive(Debug)]
46pub struct OraclePool {
47    pool: bb8::Pool<OracleConnectionManager>,
48}
49
50pub async fn connect_oracle(options: &OracleConnectionOptions) -> DFResult<OraclePool> {
51    let connect_string = format!(
52        "//{}:{}/{}",
53        options.host, options.port, options.service_name
54    );
55    let connector = Connector::new(
56        options.username.clone(),
57        options.password.clone(),
58        connect_string,
59    );
60    let _ = connector
61        .connect()
62        .map_err(|e| DataFusionError::Internal(format!("Failed to connect to oracle: {e:?}")))?;
63    let manager = OracleConnectionManager::from_connector(connector);
64    let pool = bb8::Pool::builder()
65        .build(manager)
66        .await
67        .map_err(|e| DataFusionError::Internal(format!("Failed to create oracle pool: {:?}", e)))?;
68    Ok(OraclePool { pool })
69}
70
71#[async_trait::async_trait]
72impl Pool for OraclePool {
73    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
74        let conn = self.pool.get_owned().await.map_err(|e| {
75            DataFusionError::Execution(format!("Failed to get oracle connection due to {e:?}"))
76        })?;
77        Ok(Arc::new(OracleConnection { conn }))
78    }
79}
80
81#[derive(Debug)]
82pub struct OracleConnection {
83    conn: bb8::PooledConnection<'static, OracleConnectionManager>,
84}
85
86#[async_trait::async_trait]
87impl Connection for OracleConnection {
88    async fn infer_schema(
89        &self,
90        sql: &str,
91        transform: Option<Arc<dyn Transform>>,
92    ) -> DFResult<(RemoteSchema, SchemaRef)> {
93        let row = self.conn.query_row(sql, &[]).map_err(|e| {
94            DataFusionError::Execution(format!("Failed to query one row to infer schema: {e:?}"))
95        })?;
96        let remote_schema = build_remote_schema(&row)?;
97        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
98        if let Some(transform) = transform {
99            let batch = rows_to_batch(&[row], arrow_schema, None)?;
100            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
101            Ok((remote_schema, transformed_batch.schema()))
102        } else {
103            Ok((remote_schema, arrow_schema))
104        }
105    }
106
107    async fn query(
108        &self,
109        sql: String,
110        projection: Option<Vec<usize>>,
111    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
112        let result_set = self.conn.query(&sql, &[]).unwrap();
113        let mut stream = futures::stream::iter(result_set).chunks(2000).boxed();
114
115        let Some(first_chunk) = stream.next().await else {
116            return Err(DataFusionError::Execution(
117                "No data returned from oracle".to_string(),
118            ));
119        };
120        let first_chunk: Vec<Row> = first_chunk
121            .into_iter()
122            .collect::<Result<Vec<_>, _>>()
123            .map_err(|e| {
124                DataFusionError::Execution(
125                    format!("Failed to collect rows from oracle due to {e}",),
126                )
127            })?;
128        let Some(first_row) = first_chunk.first() else {
129            return Err(DataFusionError::Execution(
130                "No data returned from oracle".to_string(),
131            ));
132        };
133
134        let remote_schema = build_remote_schema(first_row)?;
135        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
136        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
137        let first_chunk = rows_to_batch(
138            first_chunk.as_slice(),
139            arrow_schema.clone(),
140            projection.as_ref(),
141        )?;
142        let schema = first_chunk.schema();
143
144        let mut stream = stream.map(move |rows| {
145            let rows: Vec<Row> = rows
146                .into_iter()
147                .collect::<Result<Vec<_>, _>>()
148                .map_err(|e| {
149                    DataFusionError::Execution(format!(
150                        "Failed to collect rows from oracle due to {e}",
151                    ))
152                })?;
153            let batch = rows_to_batch(rows.as_slice(), arrow_schema.clone(), projection.as_ref())?;
154            Ok::<RecordBatch, DataFusionError>(batch)
155        });
156
157        let output_stream = async_stream::stream! {
158           yield Ok(first_chunk);
159           while let Some(batch) = stream.next().await {
160                yield batch
161           }
162        };
163
164        Ok((
165            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
166            projected_remote_schema,
167        ))
168    }
169}
170
171fn oracle_type_to_remote_type(oracle_type: &ColumnType) -> DFResult<RemoteType> {
172    match oracle_type {
173        ColumnType::Varchar2(size) => Ok(RemoteType::Oracle(OracleType::Varchar2(*size))),
174        _ => Err(DataFusionError::NotImplemented(format!(
175            "Unsupported oracle type: {oracle_type:?}",
176        ))),
177    }
178}
179
180fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
181    let mut remote_fields = vec![];
182    for col in row.column_info() {
183        let remote_type = oracle_type_to_remote_type(col.oracle_type())?;
184        remote_fields.push(RemoteField::new(col.name(), remote_type, col.nullable()));
185    }
186    Ok(RemoteSchema::new(remote_fields))
187}
188
189fn rows_to_batch(
190    rows: &[Row],
191    arrow_schema: SchemaRef,
192    projection: Option<&Vec<usize>>,
193) -> DFResult<RecordBatch> {
194    let projected_schema = project_schema(&arrow_schema, projection)?;
195    let mut array_builders = vec![];
196    for field in arrow_schema.fields() {
197        let builder = make_builder(field.data_type(), rows.len());
198        array_builders.push(builder);
199    }
200
201    for row in rows {
202        for (i, col) in row.column_info().iter().enumerate() {
203            let builder = &mut array_builders[i];
204            match col.oracle_type() {
205                ColumnType::Varchar2(_size) => {
206                    let builder = builder
207                        .as_any_mut()
208                        .downcast_mut::<StringBuilder>()
209                        .unwrap();
210                    let v = row.get::<usize, Option<String>>(i).unwrap();
211
212                    match v {
213                        Some(v) => builder.append_value(v),
214                        None => builder.append_null(),
215                    }
216                }
217                _ => {
218                    return Err(DataFusionError::NotImplemented(format!(
219                        "Unsupported oracle type: {col:?}",
220                    )))
221                }
222            }
223        }
224    }
225
226    let projected_columns = array_builders
227        .into_iter()
228        .enumerate()
229        .filter(|(idx, _)| projections_contains(projection, *idx))
230        .map(|(_, mut builder)| builder.finish())
231        .collect::<Vec<ArrayRef>>();
232    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
233}