// datafusion_remote_table/connection/mysql.rs

1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4    project_remote_schema, Connection, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5    RemoteType, Transform,
6};
7use async_stream::stream;
8use bigdecimal::num_bigint;
9use chrono::Timelike;
10use datafusion::arrow::array::{
11    make_builder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder,
12    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
13    LargeBinaryBuilder, LargeStringBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder,
14    TimestampMicrosecondBuilder,
15};
16use datafusion::arrow::datatypes::{i256, Date32Type, SchemaRef};
17use datafusion::common::{project_schema, DataFusionError};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use futures::lock::Mutex;
21use futures::StreamExt;
22use mysql_async::consts::{ColumnFlags, ColumnType};
23use mysql_async::prelude::Queryable;
24use mysql_async::{Column, Row};
25use std::sync::Arc;
26
27#[derive(Debug, Clone, derive_with::With)]
28pub struct MysqlConnectionOptions {
29    pub(crate) host: String,
30    pub(crate) port: u16,
31    pub(crate) username: String,
32    pub(crate) password: String,
33    pub(crate) database: Option<String>,
34}
35
36impl MysqlConnectionOptions {
37    pub fn new(
38        host: impl Into<String>,
39        port: u16,
40        username: impl Into<String>,
41        password: impl Into<String>,
42    ) -> Self {
43        Self {
44            host: host.into(),
45            port,
46            username: username.into(),
47            password: password.into(),
48            database: None,
49        }
50    }
51}
52
53#[derive(Debug)]
54pub struct MysqlPool {
55    pool: mysql_async::Pool,
56}
57
58pub fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
59    let opts_builder = mysql_async::OptsBuilder::default()
60        .ip_or_hostname(options.host.clone())
61        .tcp_port(options.port)
62        .user(Some(options.username.clone()))
63        .pass(Some(options.password.clone()))
64        .db_name(options.database.clone());
65    let pool = mysql_async::Pool::new(opts_builder);
66    Ok(MysqlPool { pool })
67}
68
69#[async_trait::async_trait]
70impl Pool for MysqlPool {
71    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
72        let conn = self.pool.get_conn().await.map_err(|e| {
73            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
74        })?;
75        Ok(Arc::new(MysqlConnection {
76            conn: Arc::new(Mutex::new(conn)),
77        }))
78    }
79}
80
81#[derive(Debug)]
82pub struct MysqlConnection {
83    conn: Arc<Mutex<mysql_async::Conn>>,
84}
85
86#[async_trait::async_trait]
87impl Connection for MysqlConnection {
88    async fn infer_schema(
89        &self,
90        sql: &str,
91        transform: Option<Arc<dyn Transform>>,
92    ) -> DFResult<(RemoteSchema, SchemaRef)> {
93        let mut conn = self.conn.lock().await;
94        let conn = &mut *conn;
95        let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
96            DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
97        })?;
98        let Some(row) = row else {
99            return Err(DataFusionError::Execution(
100                "No rows returned to infer schema".to_string(),
101            ));
102        };
103        let remote_schema = build_remote_schema(&row)?;
104        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
105        if let Some(transform) = transform {
106            let batch = rows_to_batch(&[row], &remote_schema, arrow_schema.clone(), None)?;
107            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
108            Ok((remote_schema, transformed_batch.schema()))
109        } else {
110            Ok((remote_schema, arrow_schema))
111        }
112    }
113
114    async fn query(
115        &self,
116        sql: String,
117        projection: Option<Vec<usize>>,
118    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
119        let conn = Arc::clone(&self.conn);
120        let mut stream = Box::pin(stream! {
121            let mut conn = conn.lock().await;
122            let mut query_iter = conn
123                .query_iter(sql)
124                .await
125                .map_err(|e| {
126                    DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
127                })?;
128
129            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
130                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
131                })? else {
132                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
133                return;
134            };
135
136            let mut chunked_stream = stream.chunks(4_000).boxed();
137
138            while let Some(chunk) = chunked_stream.next().await {
139                let rows = chunk
140                    .into_iter()
141                    .collect::<Result<Vec<_>, _>>()
142                    .map_err(|e| {
143                        DataFusionError::Execution(format!(
144                            "Failed to collect rows from mysql due to {e}",
145                        ))
146                    })?;
147
148                yield Ok::<_, DataFusionError>(rows)
149            }
150        });
151
152        let Some(first_chunk) = stream.next().await else {
153            return Err(DataFusionError::Execution(
154                "No data returned from mysql".to_string(),
155            ));
156        };
157        let first_chunk = first_chunk?;
158
159        let Some(first_row) = first_chunk.first() else {
160            return Err(DataFusionError::Execution(
161                "No data returned from mysql".to_string(),
162            ));
163        };
164
165        let remote_schema = build_remote_schema(first_row)?;
166        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
167        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
168        let first_chunk = rows_to_batch(
169            first_chunk.as_slice(),
170            &remote_schema,
171            arrow_schema.clone(),
172            projection.as_ref(),
173        )?;
174        let schema = first_chunk.schema();
175
176        let mut stream = stream.map(move |rows| {
177            let rows = rows?;
178            let batch = rows_to_batch(
179                rows.as_slice(),
180                &remote_schema,
181                arrow_schema.clone(),
182                projection.as_ref(),
183            )?;
184            Ok::<RecordBatch, DataFusionError>(batch)
185        });
186
187        let output_stream = async_stream::stream! {
188           yield Ok(first_chunk);
189           while let Some(batch) = stream.next().await {
190                yield batch
191           }
192        };
193
194        Ok((
195            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
196            projected_remote_schema,
197        ))
198    }
199}
200
201fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
202    let empty_flags = mysql_col.flags().is_empty();
203    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
204    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
205    let col_length = mysql_col.column_length();
206    match mysql_col.column_type() {
207        ColumnType::MYSQL_TYPE_TINY => Ok(RemoteType::Mysql(MysqlType::TinyInt)),
208        ColumnType::MYSQL_TYPE_SHORT => Ok(RemoteType::Mysql(MysqlType::SmallInt)),
209        ColumnType::MYSQL_TYPE_INT24 => Ok(RemoteType::Mysql(MysqlType::MediumInt)),
210        ColumnType::MYSQL_TYPE_LONG => Ok(RemoteType::Mysql(MysqlType::Integer)),
211        ColumnType::MYSQL_TYPE_LONGLONG => Ok(RemoteType::Mysql(MysqlType::BigInt)),
212        ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
213        ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
214        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
215            let precision = (mysql_col.column_length() - 2) as u8;
216            let scale = mysql_col.decimals();
217            Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
218        }
219        ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
220        ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
221        ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
222        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
223        ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
224        ColumnType::MYSQL_TYPE_STRING if empty_flags => Ok(RemoteType::Mysql(MysqlType::Char)),
225        ColumnType::MYSQL_TYPE_STRING if is_binary => Ok(RemoteType::Mysql(MysqlType::Binary)),
226        ColumnType::MYSQL_TYPE_VAR_STRING if empty_flags => {
227            Ok(RemoteType::Mysql(MysqlType::Varchar))
228        }
229        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
230            Ok(RemoteType::Mysql(MysqlType::Varbinary))
231        }
232        ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
233        ColumnType::MYSQL_TYPE_BLOB if col_length == 1020 && is_blob && !is_binary => {
234            Ok(RemoteType::Mysql(MysqlType::TinyText))
235        }
236        ColumnType::MYSQL_TYPE_BLOB if col_length == 262140 && is_blob && !is_binary => {
237            Ok(RemoteType::Mysql(MysqlType::Text))
238        }
239        ColumnType::MYSQL_TYPE_BLOB if col_length == 67108860 && is_blob && !is_binary => {
240            Ok(RemoteType::Mysql(MysqlType::MediumText))
241        }
242        ColumnType::MYSQL_TYPE_BLOB if col_length == 4294967295 && is_blob && !is_binary => {
243            Ok(RemoteType::Mysql(MysqlType::LongText))
244        }
245        ColumnType::MYSQL_TYPE_BLOB if col_length == 255 && is_blob && is_binary => {
246            Ok(RemoteType::Mysql(MysqlType::TinyBlob))
247        }
248        ColumnType::MYSQL_TYPE_BLOB if col_length == 65535 && is_blob && is_binary => {
249            Ok(RemoteType::Mysql(MysqlType::Blob))
250        }
251        ColumnType::MYSQL_TYPE_BLOB if col_length == 16777215 && is_blob && is_binary => {
252            Ok(RemoteType::Mysql(MysqlType::MediumBlob))
253        }
254        ColumnType::MYSQL_TYPE_BLOB if col_length == 4294967295 && is_blob && is_binary => {
255            Ok(RemoteType::Mysql(MysqlType::LongBlob))
256        }
257        ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
258        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
259        _ => Err(DataFusionError::NotImplemented(format!(
260            "Unsupported mysql type: {mysql_col:?}",
261        ))),
262    }
263}
264
265fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
266    let mut remote_fields = vec![];
267    for col in row.columns_ref() {
268        remote_fields.push(RemoteField::new(
269            col.name_str().to_string(),
270            mysql_type_to_remote_type(col)?,
271            true,
272        ));
273    }
274    Ok(RemoteSchema::new(remote_fields))
275}
276
/// Appends the value of column `$index` of `$row` to `$builder`, downcast to `$builder_ty`.
///
/// The value is read as `Option<$value_ty>`. SQL `NULL` (or a missing value) appends a
/// null. A value that cannot be converted to `$value_ty` (for example an out-of-range
/// integer) makes the *enclosing function* return an execution error instead of panicking.
/// The macro may therefore only be used inside functions returning `DFResult<_>`.
///
/// Panics if the builder is not a `$builder_ty`, which means the Arrow and remote schemas
/// disagree.
macro_rules! handle_primitive_type {
    ($builder:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $mysql_col
                )
            });

        // `get_opt` reports conversion failures instead of panicking like `get`.
        match $row.get_opt::<Option<$value_ty>, usize>($index) {
            Some(Ok(Some(v))) => builder.append_value(v),
            Some(Ok(None)) | None => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    concat!(
                        "Failed to convert mysql value at column {} to ",
                        stringify!($value_ty),
                        " for {:?}: {:?}"
                    ),
                    $index, $mysql_col, e
                )));
            }
        }
    }};
}
300
301fn rows_to_batch(
302    rows: &[Row],
303    remote_schema: &RemoteSchema,
304    arrow_schema: SchemaRef,
305    projection: Option<&Vec<usize>>,
306) -> DFResult<RecordBatch> {
307    let projected_schema = project_schema(&arrow_schema, projection)?;
308    let mut array_builders = vec![];
309    for field in arrow_schema.fields() {
310        let builder = make_builder(field.data_type(), rows.len());
311        array_builders.push(builder);
312    }
313
314    for row in rows {
315        for (idx, remote_field) in remote_schema.fields.iter().enumerate() {
316            if !projections_contains(projection, idx) {
317                continue;
318            }
319            let builder = &mut array_builders[idx];
320            match remote_field.remote_type {
321                RemoteType::Mysql(MysqlType::TinyInt) => {
322                    handle_primitive_type!(builder, remote_field, Int8Builder, i8, row, idx);
323                }
324                RemoteType::Mysql(MysqlType::SmallInt) | RemoteType::Mysql(MysqlType::Year) => {
325                    handle_primitive_type!(builder, remote_field, Int16Builder, i16, row, idx);
326                }
327                RemoteType::Mysql(MysqlType::MediumInt) | RemoteType::Mysql(MysqlType::Integer) => {
328                    handle_primitive_type!(builder, remote_field, Int32Builder, i32, row, idx);
329                }
330                RemoteType::Mysql(MysqlType::BigInt) => {
331                    handle_primitive_type!(builder, remote_field, Int64Builder, i64, row, idx);
332                }
333                RemoteType::Mysql(MysqlType::Float) => {
334                    handle_primitive_type!(builder, remote_field, Float32Builder, f32, row, idx);
335                }
336                RemoteType::Mysql(MysqlType::Double) => {
337                    handle_primitive_type!(builder, remote_field, Float64Builder, f64, row, idx);
338                }
339                RemoteType::Mysql(MysqlType::Decimal(precision, scale)) => {
340                    if precision > 38 {
341                        let builder = builder
342                            .as_any_mut()
343                            .downcast_mut::<Decimal256Builder>()
344                            .unwrap_or_else(|| {
345                                panic!(
346                                    "Failed to downcast builder to Decimal256Builder for {:?}",
347                                    remote_field
348                                )
349                            });
350                        let v = row.get::<Option<bigdecimal::BigDecimal>, usize>(idx);
351
352                        match v {
353                            Some(Some(v)) => builder.append_value(to_decimal_256(&v)),
354                            _ => builder.append_null(),
355                        }
356                    } else {
357                        let builder = builder
358                            .as_any_mut()
359                            .downcast_mut::<Decimal128Builder>()
360                            .unwrap_or_else(|| {
361                                panic!(
362                                    "Failed to downcast builder to Decimal128Builder for {:?}",
363                                    remote_field
364                                )
365                            });
366                        let v = row.get::<Option<bigdecimal::BigDecimal>, usize>(idx);
367
368                        match v {
369                            Some(Some(v)) => {
370                                let Some(v) = big_decimal_to_i128(&v, Some(scale as u32)) else {
371                                    return Err(DataFusionError::Execution(format!(
372                                        "Failed to convert BigDecimal {v:?} to i128"
373                                    )));
374                                };
375                                builder.append_value(v)
376                            }
377                            _ => builder.append_null(),
378                        }
379                    }
380                }
381                RemoteType::Mysql(MysqlType::Date) => {
382                    // TODO this could also be simplified by macro
383                    let builder = builder
384                        .as_any_mut()
385                        .downcast_mut::<Date32Builder>()
386                        .unwrap_or_else(|| {
387                            panic!(
388                                "Failed to downcast builder to Date32Builder for {:?}",
389                                remote_field
390                            )
391                        });
392                    let v = row.get::<Option<chrono::NaiveDate>, usize>(idx);
393
394                    match v {
395                        Some(Some(v)) => builder.append_value(Date32Type::from_naive_date(v)),
396                        _ => builder.append_null(),
397                    }
398                }
399                RemoteType::Mysql(MysqlType::Datetime)
400                | RemoteType::Mysql(MysqlType::Timestamp) => {
401                    let builder = builder
402                        .as_any_mut()
403                        .downcast_mut::<TimestampMicrosecondBuilder>()
404                        .unwrap_or_else(|| {
405                            panic!(
406                            "Failed to downcast builder to TimestampMicrosecondBuilder for {:?}",
407                            remote_field
408                        )
409                        });
410                    let v = row.get::<Option<time::PrimitiveDateTime>, usize>(idx);
411
412                    match v {
413                        Some(Some(v)) => {
414                            let timestamp_micros =
415                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
416                            builder.append_value(timestamp_micros)
417                        }
418                        _ => builder.append_null(),
419                    }
420                }
421                RemoteType::Mysql(MysqlType::Time) => {
422                    let builder = builder
423                        .as_any_mut()
424                        .downcast_mut::<Time64NanosecondBuilder>()
425                        .unwrap_or_else(|| {
426                            panic!(
427                                "Failed to downcast builder to Time64NanosecondBuilder for {:?}",
428                                remote_field
429                            )
430                        });
431                    let v = row.get::<Option<chrono::NaiveTime>, usize>(idx);
432
433                    match v {
434                        Some(Some(v)) => {
435                            builder.append_value(
436                                i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
437                                    + i64::from(v.nanosecond()),
438                            );
439                        }
440                        _ => builder.append_null(),
441                    }
442                }
443                RemoteType::Mysql(MysqlType::Char)
444                | RemoteType::Mysql(MysqlType::Varchar)
445                | RemoteType::Mysql(MysqlType::TinyText)
446                | RemoteType::Mysql(MysqlType::Text)
447                | RemoteType::Mysql(MysqlType::MediumText) => {
448                    handle_primitive_type!(builder, remote_field, StringBuilder, String, row, idx);
449                }
450                RemoteType::Mysql(MysqlType::LongText) | RemoteType::Mysql(MysqlType::Json) => {
451                    handle_primitive_type!(
452                        builder,
453                        remote_field,
454                        LargeStringBuilder,
455                        String,
456                        row,
457                        idx
458                    );
459                }
460                RemoteType::Mysql(MysqlType::Binary)
461                | RemoteType::Mysql(MysqlType::Varbinary)
462                | RemoteType::Mysql(MysqlType::TinyBlob)
463                | RemoteType::Mysql(MysqlType::Blob)
464                | RemoteType::Mysql(MysqlType::MediumBlob) => {
465                    handle_primitive_type!(builder, remote_field, BinaryBuilder, Vec<u8>, row, idx);
466                }
467                RemoteType::Mysql(MysqlType::LongBlob) | RemoteType::Mysql(MysqlType::Geometry) => {
468                    handle_primitive_type!(
469                        builder,
470                        remote_field,
471                        LargeBinaryBuilder,
472                        Vec<u8>,
473                        row,
474                        idx
475                    );
476                }
477                _ => panic!("Invalid mysql type: {:?}", remote_field.remote_type),
478            }
479        }
480    }
481    let projected_columns = array_builders
482        .into_iter()
483        .enumerate()
484        .filter(|(idx, _)| projections_contains(projection, *idx))
485        .map(|(_, mut builder)| builder.finish())
486        .collect::<Vec<ArrayRef>>();
487    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
488}
489
490fn to_decimal_256(decimal: &bigdecimal::BigDecimal) -> i256 {
491    let (bigint_value, _) = decimal.as_bigint_and_exponent();
492    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
493
494    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
495    let fill_byte = if is_negative { 0xFF } else { 0x00 };
496
497    if bigint_bytes.len() > 32 {
498        bigint_bytes.truncate(32);
499    } else {
500        bigint_bytes.resize(32, fill_byte);
501    };
502
503    let mut array = [0u8; 32];
504    array.copy_from_slice(&bigint_bytes);
505
506    i256::from_le_bytes(array)
507}