datafusion_remote_table/connection/mysql.rs

1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{
4    project_remote_schema, Connection, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5    RemoteType, Transform,
6};
7use async_stream::stream;
8use datafusion::arrow::array::{
9    make_builder, ArrayRef, Float32Builder, Float64Builder, Int16Builder, Int32Builder,
10    Int64Builder, Int8Builder, RecordBatch,
11};
12use datafusion::arrow::datatypes::SchemaRef;
13use datafusion::common::{project_schema, DataFusionError};
14use datafusion::execution::SendableRecordBatchStream;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use futures::lock::Mutex;
17use futures::StreamExt;
18use mysql_async::consts::ColumnType;
19use mysql_async::prelude::Queryable;
20use mysql_async::{Column, Row};
21use std::sync::Arc;
22
23#[derive(Debug, Clone, derive_with::With)]
24pub struct MysqlConnectionOptions {
25    pub(crate) host: String,
26    pub(crate) port: u16,
27    pub(crate) username: String,
28    pub(crate) password: String,
29    pub(crate) database: Option<String>,
30}
31
32impl MysqlConnectionOptions {
33    pub fn new(
34        host: impl Into<String>,
35        port: u16,
36        username: impl Into<String>,
37        password: impl Into<String>,
38    ) -> Self {
39        Self {
40            host: host.into(),
41            port,
42            username: username.into(),
43            password: password.into(),
44            database: None,
45        }
46    }
47}
48
49#[derive(Debug)]
50pub struct MysqlPool {
51    pool: mysql_async::Pool,
52}
53
54pub fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
55    let opts_builder = mysql_async::OptsBuilder::default()
56        .ip_or_hostname(options.host.clone())
57        .tcp_port(options.port)
58        .user(Some(options.username.clone()))
59        .pass(Some(options.password.clone()))
60        .db_name(options.database.clone());
61    let pool = mysql_async::Pool::new(opts_builder);
62    Ok(MysqlPool { pool })
63}
64
65#[async_trait::async_trait]
66impl Pool for MysqlPool {
67    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
68        let conn = self.pool.get_conn().await.map_err(|e| {
69            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
70        })?;
71        Ok(Arc::new(MysqlConnection {
72            conn: Arc::new(Mutex::new(conn)),
73        }))
74    }
75}
76
/// A single MySQL connection handed out by [`MysqlPool`].
///
/// The `mysql_async::Conn` sits behind an async mutex (`futures::lock::Mutex`) inside an
/// `Arc`, so `infer_schema` and `query` take turns on the same underlying connection.
#[derive(Debug)]
pub struct MysqlConnection {
    conn: Arc<Mutex<mysql_async::Conn>>,
}

#[async_trait::async_trait]
impl Connection for MysqlConnection {
    /// Runs `sql` and infers the remote and Arrow schemas from the first returned row.
    ///
    /// When `transform` is given, that first row is converted into a one-row batch and passed
    /// through `transform_batch`. The returned Arrow schema is then the transformed batch's
    /// schema, while the returned remote schema is the untransformed one.
    ///
    /// # Errors
    /// Returns `DataFusionError::Execution` if the query fails or yields no rows. Errors from
    /// schema building (e.g. unsupported column types), batch conversion and the transform
    /// are propagated.
    async fn infer_schema(
        &self,
        sql: &str,
        transform: Option<Arc<dyn Transform>>,
    ) -> DFResult<(RemoteSchema, SchemaRef)> {
        // The lock guard lives until the end of this function.
        let mut conn = self.conn.lock().await;
        let conn = &mut *conn;
        let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
        })?;
        // Column metadata is read from an actual row, so an empty result cannot be inferred.
        let Some(row) = row else {
            return Err(DataFusionError::Execution(
                "No rows returned to infer schema".to_string(),
            ));
        };
        let remote_schema = build_remote_schema(&row)?;
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        if let Some(transform) = transform {
            // Run the transform on the sample row to learn the schema it produces.
            let batch = rows_to_batch(&[row], arrow_schema.clone(), None)?;
            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
            Ok((remote_schema, transformed_batch.schema()))
        } else {
            Ok((remote_schema, arrow_schema))
        }
    }

    /// Runs `sql` and returns its rows as a stream of Arrow record batches, together with
    /// the remote schema narrowed to `projection`.
    ///
    /// Rows are fetched in chunks of 4,000, and each chunk becomes one `RecordBatch`. The
    /// first chunk is pulled eagerly, before returning, to derive the schema from its first
    /// row. The stream's schema is the Arrow schema after `projection` is applied.
    ///
    /// NOTE(review): because the schema comes from the first row, a query that returns no
    /// rows is reported as an error ("No data returned from mysql") rather than as an empty
    /// stream.
    ///
    /// # Errors
    /// Fails if the query fails, returns no rows, or its first chunk cannot be converted.
    /// Errors on later chunks are yielded as items of the returned stream.
    async fn query(
        &self,
        sql: String,
        projection: Option<Vec<usize>>,
    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
        let conn = Arc::clone(&self.conn);
        let mut stream = Box::pin(stream! {
            // The lock guard is owned by this stream body, so the connection stays locked
            // until the stream finishes or is dropped.
            let mut conn = conn.lock().await;
            // NOTE(review): `?` inside `stream!` presumably yields the error and ends the
            // stream (async-stream rewriting); confirm against the async-stream version used.
            let mut query_iter = conn
                .query_iter(sql)
                .await
                .map_err(|e| {
                    DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
                })?;

            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
                })? else {
                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
                return;
            };

            // Group rows so that each yielded item becomes one record batch.
            let mut chunked_stream = stream.chunks(4_000).boxed();

            while let Some(chunk) = chunked_stream.next().await {
                // A chunk holds per-row results; the first row error aborts the chunk.
                let rows = chunk
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to collect rows from mysql due to {e}",
                        ))
                    })?;

                yield Ok::<_, DataFusionError>(rows)
            }
        });

        // Pull the first chunk now: its first row supplies the column metadata for the schema.
        let Some(first_chunk) = stream.next().await else {
            return Err(DataFusionError::Execution(
                "No data returned from mysql".to_string(),
            ));
        };
        let first_chunk = first_chunk?;

        let Some(first_row) = first_chunk.first() else {
            return Err(DataFusionError::Execution(
                "No data returned from mysql".to_string(),
            ));
        };

        let remote_schema = build_remote_schema(first_row)?;
        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
        // Full (unprojected) Arrow schema; `rows_to_batch` applies the projection itself.
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        let first_chunk = rows_to_batch(
            first_chunk.as_slice(),
            arrow_schema.clone(),
            projection.as_ref(),
        )?;
        let schema = first_chunk.schema();

        // Remaining chunks are converted lazily with the same schema and projection.
        let mut stream = stream.map(move |rows| {
            let rows = rows?;
            let batch = rows_to_batch(rows.as_slice(), arrow_schema.clone(), projection.as_ref())?;
            Ok::<RecordBatch, DataFusionError>(batch)
        });

        // Re-emit the already-converted first batch ahead of the remaining ones.
        let output_stream = async_stream::stream! {
           yield Ok(first_chunk);
           while let Some(batch) = stream.next().await {
                yield batch
           }
        };

        Ok((
            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
            projected_remote_schema,
        ))
    }
}
190
191fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
192    match mysql_col.column_type() {
193        ColumnType::MYSQL_TYPE_TINY => Ok(RemoteType::Mysql(MysqlType::TinyInt)),
194        ColumnType::MYSQL_TYPE_SHORT => Ok(RemoteType::Mysql(MysqlType::SmallInt)),
195        ColumnType::MYSQL_TYPE_LONG => Ok(RemoteType::Mysql(MysqlType::Integer)),
196        ColumnType::MYSQL_TYPE_LONGLONG => Ok(RemoteType::Mysql(MysqlType::BigInt)),
197        ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
198        ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
199        _ => Err(DataFusionError::NotImplemented(format!(
200            "Unsupported mysql type: {mysql_col:?}",
201        ))),
202    }
203}
204
205fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
206    let mut remote_fields = vec![];
207    for col in row.columns_ref() {
208        remote_fields.push(RemoteField::new(
209            col.name_str().to_string(),
210            mysql_type_to_remote_type(col)?,
211            true,
212        ));
213    }
214    Ok(RemoteSchema::new(remote_fields))
215}
216
/// Appends the value at `$index` of `$row` to `$builder`, downcast to `$builder_ty`.
///
/// The value is read as `Option<$value_ty>` through `Row::get_opt`:
/// - `Row::get::<T>` panics when the value cannot be converted to `T`, and SQL `NULL` cannot
///   be converted to a plain primitive. Reading `Option<T>` turns `NULL` into an Arrow null.
/// - A present value that does not fit `$value_ty` (e.g. an out-of-range number) makes the
///   enclosing function return `DataFusionError::Execution` instead of panicking. The macro
///   must therefore be expanded inside a function returning `DFResult`.
///
/// A missing value (index out of range) is appended as null, as before.
///
/// # Panics
/// Panics if `$builder` is not a `$builder_ty`, i.e. the Arrow schema and the MySQL column
/// type disagree.
macro_rules! handle_primitive_type {
    ($builder:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .expect(concat!(
                "Failed to downcast builder to ",
                stringify!($builder_ty),
                " for ",
                stringify!($mysql_col)
            ));
        let v: Option<$value_ty> = match $row.get_opt::<Option<$value_ty>, usize>($index) {
            // Converted successfully; an inner `None` means SQL NULL.
            Some(Ok(v)) => v,
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to convert mysql value of column {:?} to {}: {:?}",
                    $mysql_col,
                    stringify!($value_ty),
                    e
                )));
            }
            None => None,
        };

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
236
237fn rows_to_batch(
238    rows: &[Row],
239    arrow_schema: SchemaRef,
240    projection: Option<&Vec<usize>>,
241) -> DFResult<RecordBatch> {
242    let projected_schema = project_schema(&arrow_schema, projection)?;
243    let mut array_builders = vec![];
244    for field in arrow_schema.fields() {
245        let builder = make_builder(field.data_type(), rows.len());
246        array_builders.push(builder);
247    }
248
249    for row in rows {
250        for (idx, col) in row.columns_ref().iter().enumerate() {
251            if !projections_contains(projection, idx) {
252                continue;
253            }
254            let builder = &mut array_builders[idx];
255            match col.column_type() {
256                ColumnType::MYSQL_TYPE_TINY => {
257                    handle_primitive_type!(builder, col, Int8Builder, i8, row, idx);
258                }
259                ColumnType::MYSQL_TYPE_SHORT => {
260                    handle_primitive_type!(builder, col, Int16Builder, i16, row, idx);
261                }
262                ColumnType::MYSQL_TYPE_LONG => {
263                    handle_primitive_type!(builder, col, Int32Builder, i32, row, idx);
264                }
265                ColumnType::MYSQL_TYPE_LONGLONG => {
266                    handle_primitive_type!(builder, col, Int64Builder, i64, row, idx);
267                }
268                ColumnType::MYSQL_TYPE_FLOAT => {
269                    handle_primitive_type!(builder, col, Float32Builder, f32, row, idx);
270                }
271                ColumnType::MYSQL_TYPE_DOUBLE => {
272                    handle_primitive_type!(builder, col, Float64Builder, f64, row, idx);
273                }
274                _ => {
275                    return Err(DataFusionError::NotImplemented(format!(
276                        "Unsupported mysql type: {col:?}",
277                    )));
278                }
279            }
280        }
281    }
282    let projected_columns = array_builders
283        .into_iter()
284        .enumerate()
285        .filter(|(idx, _)| projections_contains(projection, *idx))
286        .map(|(_, mut builder)| builder.finish())
287        .collect::<Vec<ArrayRef>>();
288    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
289}