datafusion_remote_table/connection/
mysql.rs

1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{
4    project_remote_schema, Connection, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5    RemoteType, Transform,
6};
7use async_stream::stream;
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    make_builder, ArrayRef, BinaryBuilder, Date32Builder, Float32Builder, Float64Builder,
11    Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeBinaryBuilder, LargeStringBuilder,
12    RecordBatch, StringBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder,
13};
14use datafusion::arrow::datatypes::{Date32Type, SchemaRef};
15use datafusion::common::{project_schema, DataFusionError};
16use datafusion::execution::SendableRecordBatchStream;
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use futures::lock::Mutex;
19use futures::StreamExt;
20use mysql_async::consts::{ColumnFlags, ColumnType};
21use mysql_async::prelude::Queryable;
22use mysql_async::{Column, Row};
23use std::sync::Arc;
24
25#[derive(Debug, Clone, derive_with::With)]
26pub struct MysqlConnectionOptions {
27    pub(crate) host: String,
28    pub(crate) port: u16,
29    pub(crate) username: String,
30    pub(crate) password: String,
31    pub(crate) database: Option<String>,
32}
33
34impl MysqlConnectionOptions {
35    pub fn new(
36        host: impl Into<String>,
37        port: u16,
38        username: impl Into<String>,
39        password: impl Into<String>,
40    ) -> Self {
41        Self {
42            host: host.into(),
43            port,
44            username: username.into(),
45            password: password.into(),
46            database: None,
47        }
48    }
49}
50
51#[derive(Debug)]
52pub struct MysqlPool {
53    pool: mysql_async::Pool,
54}
55
56pub fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
57    let opts_builder = mysql_async::OptsBuilder::default()
58        .ip_or_hostname(options.host.clone())
59        .tcp_port(options.port)
60        .user(Some(options.username.clone()))
61        .pass(Some(options.password.clone()))
62        .db_name(options.database.clone());
63    let pool = mysql_async::Pool::new(opts_builder);
64    Ok(MysqlPool { pool })
65}
66
67#[async_trait::async_trait]
68impl Pool for MysqlPool {
69    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
70        let conn = self.pool.get_conn().await.map_err(|e| {
71            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
72        })?;
73        Ok(Arc::new(MysqlConnection {
74            conn: Arc::new(Mutex::new(conn)),
75        }))
76    }
77}
78
79#[derive(Debug)]
80pub struct MysqlConnection {
81    conn: Arc<Mutex<mysql_async::Conn>>,
82}
83
84#[async_trait::async_trait]
85impl Connection for MysqlConnection {
86    async fn infer_schema(
87        &self,
88        sql: &str,
89        transform: Option<Arc<dyn Transform>>,
90    ) -> DFResult<(RemoteSchema, SchemaRef)> {
91        let mut conn = self.conn.lock().await;
92        let conn = &mut *conn;
93        let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
94            DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
95        })?;
96        let Some(row) = row else {
97            return Err(DataFusionError::Execution(
98                "No rows returned to infer schema".to_string(),
99            ));
100        };
101        let remote_schema = build_remote_schema(&row)?;
102        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
103        if let Some(transform) = transform {
104            let batch = rows_to_batch(&[row], &remote_schema, arrow_schema.clone(), None)?;
105            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
106            Ok((remote_schema, transformed_batch.schema()))
107        } else {
108            Ok((remote_schema, arrow_schema))
109        }
110    }
111
112    async fn query(
113        &self,
114        sql: String,
115        projection: Option<Vec<usize>>,
116    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
117        let conn = Arc::clone(&self.conn);
118        let mut stream = Box::pin(stream! {
119            let mut conn = conn.lock().await;
120            let mut query_iter = conn
121                .query_iter(sql)
122                .await
123                .map_err(|e| {
124                    DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
125                })?;
126
127            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
128                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
129                })? else {
130                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
131                return;
132            };
133
134            let mut chunked_stream = stream.chunks(4_000).boxed();
135
136            while let Some(chunk) = chunked_stream.next().await {
137                let rows = chunk
138                    .into_iter()
139                    .collect::<Result<Vec<_>, _>>()
140                    .map_err(|e| {
141                        DataFusionError::Execution(format!(
142                            "Failed to collect rows from mysql due to {e}",
143                        ))
144                    })?;
145
146                yield Ok::<_, DataFusionError>(rows)
147            }
148        });
149
150        let Some(first_chunk) = stream.next().await else {
151            return Err(DataFusionError::Execution(
152                "No data returned from mysql".to_string(),
153            ));
154        };
155        let first_chunk = first_chunk?;
156
157        let Some(first_row) = first_chunk.first() else {
158            return Err(DataFusionError::Execution(
159                "No data returned from mysql".to_string(),
160            ));
161        };
162
163        let remote_schema = build_remote_schema(first_row)?;
164        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
165        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
166        let first_chunk = rows_to_batch(
167            first_chunk.as_slice(),
168            &remote_schema,
169            arrow_schema.clone(),
170            projection.as_ref(),
171        )?;
172        let schema = first_chunk.schema();
173
174        let mut stream = stream.map(move |rows| {
175            let rows = rows?;
176            let batch = rows_to_batch(
177                rows.as_slice(),
178                &remote_schema,
179                arrow_schema.clone(),
180                projection.as_ref(),
181            )?;
182            Ok::<RecordBatch, DataFusionError>(batch)
183        });
184
185        let output_stream = async_stream::stream! {
186           yield Ok(first_chunk);
187           while let Some(batch) = stream.next().await {
188                yield batch
189           }
190        };
191
192        Ok((
193            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
194            projected_remote_schema,
195        ))
196    }
197}
198
199fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
200    let empty_flags = mysql_col.flags().is_empty();
201    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
202    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
203    let col_length = mysql_col.column_length();
204    match mysql_col.column_type() {
205        ColumnType::MYSQL_TYPE_TINY => Ok(RemoteType::Mysql(MysqlType::TinyInt)),
206        ColumnType::MYSQL_TYPE_SHORT => Ok(RemoteType::Mysql(MysqlType::SmallInt)),
207        ColumnType::MYSQL_TYPE_INT24 => Ok(RemoteType::Mysql(MysqlType::MediumInt)),
208        ColumnType::MYSQL_TYPE_LONG => Ok(RemoteType::Mysql(MysqlType::Integer)),
209        ColumnType::MYSQL_TYPE_LONGLONG => Ok(RemoteType::Mysql(MysqlType::BigInt)),
210        ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
211        ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
212        ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
213        ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
214        ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
215        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
216        ColumnType::MYSQL_TYPE_STRING if empty_flags => Ok(RemoteType::Mysql(MysqlType::Char)),
217        ColumnType::MYSQL_TYPE_STRING if is_binary => Ok(RemoteType::Mysql(MysqlType::Binary)),
218        ColumnType::MYSQL_TYPE_VAR_STRING if empty_flags => {
219            Ok(RemoteType::Mysql(MysqlType::Varchar))
220        }
221        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
222            Ok(RemoteType::Mysql(MysqlType::Varbinary))
223        }
224        ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
225        ColumnType::MYSQL_TYPE_BLOB if col_length == 1020 && is_blob && !is_binary => {
226            Ok(RemoteType::Mysql(MysqlType::TinyText))
227        }
228        ColumnType::MYSQL_TYPE_BLOB if col_length == 262140 && is_blob && !is_binary => {
229            Ok(RemoteType::Mysql(MysqlType::Text))
230        }
231        ColumnType::MYSQL_TYPE_BLOB if col_length == 67108860 && is_blob && !is_binary => {
232            Ok(RemoteType::Mysql(MysqlType::MediumText))
233        }
234        ColumnType::MYSQL_TYPE_BLOB if col_length == 4294967295 && is_blob && !is_binary => {
235            Ok(RemoteType::Mysql(MysqlType::LongText))
236        }
237        ColumnType::MYSQL_TYPE_BLOB if col_length == 255 && is_blob && is_binary => {
238            Ok(RemoteType::Mysql(MysqlType::TinyBlob))
239        }
240        ColumnType::MYSQL_TYPE_BLOB if col_length == 65535 && is_blob && is_binary => {
241            Ok(RemoteType::Mysql(MysqlType::Blob))
242        }
243        ColumnType::MYSQL_TYPE_BLOB if col_length == 16777215 && is_blob && is_binary => {
244            Ok(RemoteType::Mysql(MysqlType::MediumBlob))
245        }
246        ColumnType::MYSQL_TYPE_BLOB if col_length == 4294967295 && is_blob && is_binary => {
247            Ok(RemoteType::Mysql(MysqlType::LongBlob))
248        }
249        ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
250        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
251        _ => Err(DataFusionError::NotImplemented(format!(
252            "Unsupported mysql type: {mysql_col:?}",
253        ))),
254    }
255}
256
257fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
258    let mut remote_fields = vec![];
259    for col in row.columns_ref() {
260        remote_fields.push(RemoteField::new(
261            col.name_str().to_string(),
262            mysql_type_to_remote_type(col)?,
263            true,
264        ));
265    }
266    Ok(RemoteSchema::new(remote_fields))
267}
268
269macro_rules! handle_primitive_type {
270    ($builder:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
271        let builder = $builder
272            .as_any_mut()
273            .downcast_mut::<$builder_ty>()
274            .unwrap_or_else(|| {
275                panic!(
276                    concat!(
277                        "Failed to downcast builder to ",
278                        stringify!($builder_ty),
279                        " for {:?}"
280                    ),
281                    $mysql_col
282                )
283            });
284        let v = $row.get::<Option<$value_ty>, usize>($index);
285
286        match v {
287            Some(Some(v)) => builder.append_value(v),
288            _ => builder.append_null(),
289        }
290    }};
291}
292
293fn rows_to_batch(
294    rows: &[Row],
295    remote_schema: &RemoteSchema,
296    arrow_schema: SchemaRef,
297    projection: Option<&Vec<usize>>,
298) -> DFResult<RecordBatch> {
299    let projected_schema = project_schema(&arrow_schema, projection)?;
300    let mut array_builders = vec![];
301    for field in arrow_schema.fields() {
302        let builder = make_builder(field.data_type(), rows.len());
303        array_builders.push(builder);
304    }
305
306    for row in rows {
307        for (idx, remote_field) in remote_schema.fields.iter().enumerate() {
308            if !projections_contains(projection, idx) {
309                continue;
310            }
311            let builder = &mut array_builders[idx];
312            match remote_field.remote_type {
313                RemoteType::Mysql(MysqlType::TinyInt) => {
314                    handle_primitive_type!(builder, remote_field, Int8Builder, i8, row, idx);
315                }
316                RemoteType::Mysql(MysqlType::SmallInt) => {
317                    handle_primitive_type!(builder, remote_field, Int16Builder, i16, row, idx);
318                }
319                RemoteType::Mysql(MysqlType::MediumInt) | RemoteType::Mysql(MysqlType::Integer) => {
320                    handle_primitive_type!(builder, remote_field, Int32Builder, i32, row, idx);
321                }
322                RemoteType::Mysql(MysqlType::BigInt) => {
323                    handle_primitive_type!(builder, remote_field, Int64Builder, i64, row, idx);
324                }
325                RemoteType::Mysql(MysqlType::Float) => {
326                    handle_primitive_type!(builder, remote_field, Float32Builder, f32, row, idx);
327                }
328                RemoteType::Mysql(MysqlType::Double) => {
329                    handle_primitive_type!(builder, remote_field, Float64Builder, f64, row, idx);
330                }
331                RemoteType::Mysql(MysqlType::Date) => {
332                    let builder = builder
333                        .as_any_mut()
334                        .downcast_mut::<Date32Builder>()
335                        .unwrap_or_else(|| {
336                            panic!(
337                                "Failed to downcast builder to Date32Builder for {:?}",
338                                remote_field
339                            )
340                        });
341                    let v = row.get::<Option<chrono::NaiveDate>, usize>(idx);
342
343                    match v {
344                        Some(Some(v)) => builder.append_value(Date32Type::from_naive_date(v)),
345                        _ => builder.append_null(),
346                    }
347                }
348                RemoteType::Mysql(MysqlType::Datetime)
349                | RemoteType::Mysql(MysqlType::Timestamp) => {
350                    let builder = builder
351                        .as_any_mut()
352                        .downcast_mut::<TimestampMicrosecondBuilder>()
353                        .unwrap_or_else(|| {
354                            panic!(
355                            "Failed to downcast builder to TimestampMicrosecondBuilder for {:?}",
356                            remote_field
357                        )
358                        });
359                    let v = row.get::<Option<time::PrimitiveDateTime>, usize>(idx);
360
361                    match v {
362                        Some(Some(v)) => {
363                            let timestamp_micros =
364                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
365                            builder.append_value(timestamp_micros)
366                        }
367                        _ => builder.append_null(),
368                    }
369                }
370                RemoteType::Mysql(MysqlType::Time) => {
371                    let builder = builder
372                        .as_any_mut()
373                        .downcast_mut::<Time64NanosecondBuilder>()
374                        .unwrap_or_else(|| {
375                            panic!(
376                                "Failed to downcast builder to Time64NanosecondBuilder for {:?}",
377                                remote_field
378                            )
379                        });
380                    let v = row.get::<Option<chrono::NaiveTime>, usize>(idx);
381
382                    match v {
383                        Some(Some(v)) => {
384                            builder.append_value(
385                                i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
386                                    + i64::from(v.nanosecond()),
387                            );
388                        }
389                        _ => builder.append_null(),
390                    }
391                }
392                RemoteType::Mysql(MysqlType::Char)
393                | RemoteType::Mysql(MysqlType::Varchar)
394                | RemoteType::Mysql(MysqlType::TinyText)
395                | RemoteType::Mysql(MysqlType::Text)
396                | RemoteType::Mysql(MysqlType::MediumText) => {
397                    handle_primitive_type!(builder, remote_field, StringBuilder, String, row, idx);
398                }
399                RemoteType::Mysql(MysqlType::LongText) | RemoteType::Mysql(MysqlType::Json) => {
400                    handle_primitive_type!(
401                        builder,
402                        remote_field,
403                        LargeStringBuilder,
404                        String,
405                        row,
406                        idx
407                    );
408                }
409                RemoteType::Mysql(MysqlType::Binary)
410                | RemoteType::Mysql(MysqlType::Varbinary)
411                | RemoteType::Mysql(MysqlType::TinyBlob)
412                | RemoteType::Mysql(MysqlType::Blob)
413                | RemoteType::Mysql(MysqlType::MediumBlob) => {
414                    handle_primitive_type!(builder, remote_field, BinaryBuilder, Vec<u8>, row, idx);
415                }
416                RemoteType::Mysql(MysqlType::LongBlob) | RemoteType::Mysql(MysqlType::Geometry) => {
417                    handle_primitive_type!(
418                        builder,
419                        remote_field,
420                        LargeBinaryBuilder,
421                        Vec<u8>,
422                        row,
423                        idx
424                    );
425                }
426                _ => panic!("Invalid mysql type: {:?}", remote_field.remote_type),
427            }
428        }
429    }
430    let projected_columns = array_builders
431        .into_iter()
432        .enumerate()
433        .filter(|(idx, _)| projections_contains(projection, *idx))
434        .map(|(_, mut builder)| builder.finish())
435        .collect::<Vec<ArrayRef>>();
436    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
437}