datafusion_remote_table/connection/
mysql.rs

1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5    RemoteSchemaRef, RemoteType, Transform,
6};
7use async_stream::stream;
8use bigdecimal::{num_bigint, BigDecimal};
9use chrono::Timelike;
10use datafusion::arrow::array::{
11    make_builder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder,
12    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
13    LargeBinaryBuilder, LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder,
14    Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampNanosecondBuilder,
15    UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
16};
17use datafusion::arrow::datatypes::{i256, DataType, Date32Type, SchemaRef, TimeUnit};
18use datafusion::common::{project_schema, DataFusionError};
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use futures::lock::Mutex;
22use futures::StreamExt;
23use mysql_async::consts::{ColumnFlags, ColumnType};
24use mysql_async::prelude::Queryable;
25use mysql_async::{Column, FromValueError, Row, Value};
26use std::sync::Arc;
27
28#[derive(Debug, Clone, derive_with::With)]
29pub struct MysqlConnectionOptions {
30    pub(crate) host: String,
31    pub(crate) port: u16,
32    pub(crate) username: String,
33    pub(crate) password: String,
34    pub(crate) database: Option<String>,
35    pub(crate) chunk_size: Option<usize>,
36}
37
38impl MysqlConnectionOptions {
39    pub fn new(
40        host: impl Into<String>,
41        port: u16,
42        username: impl Into<String>,
43        password: impl Into<String>,
44    ) -> Self {
45        Self {
46            host: host.into(),
47            port,
48            username: username.into(),
49            password: password.into(),
50            database: None,
51            chunk_size: None,
52        }
53    }
54}
55
/// Pool of MySQL connections produced by `connect_mysql`.
///
/// Each `Pool::get` call checks out one connection from the inner
/// `mysql_async::Pool` and wraps it in a `MysqlConnection`.
#[derive(Debug)]
pub struct MysqlPool {
    pool: mysql_async::Pool,
}
60
61pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
62    let opts_builder = mysql_async::OptsBuilder::default()
63        .ip_or_hostname(options.host.clone())
64        .tcp_port(options.port)
65        .user(Some(options.username.clone()))
66        .pass(Some(options.password.clone()))
67        .db_name(options.database.clone());
68    let pool = mysql_async::Pool::new(opts_builder);
69    Ok(MysqlPool { pool })
70}
71
72#[async_trait::async_trait]
73impl Pool for MysqlPool {
74    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
75        let conn = self.pool.get_conn().await.map_err(|e| {
76            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
77        })?;
78        Ok(Arc::new(MysqlConnection {
79            conn: Arc::new(Mutex::new(conn)),
80        }))
81    }
82}
83
/// A single MySQL connection checked out of a `MysqlPool`.
///
/// The connection sits behind an async (`futures::lock`) mutex so the `&self`
/// trait methods can issue queries. `query` moves a clone of the `Arc` into the
/// returned stream, so the connection stays alive, and is locked, while that
/// stream is being consumed.
#[derive(Debug)]
pub struct MysqlConnection {
    conn: Arc<Mutex<mysql_async::Conn>>,
}
88
89#[async_trait::async_trait]
90impl Connection for MysqlConnection {
91    async fn infer_schema(
92        &self,
93        sql: &str,
94        transform: Option<Arc<dyn Transform>>,
95    ) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
96        let mut conn = self.conn.lock().await;
97        let conn = &mut *conn;
98        let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
99            DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
100        })?;
101        let Some(row) = row else {
102            return Err(DataFusionError::Execution(
103                "No rows returned to infer schema".to_string(),
104            ));
105        };
106        let remote_schema = Arc::new(build_remote_schema(&row)?);
107        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
108        if let Some(transform) = transform {
109            let batch = rows_to_batch(&[row], &arrow_schema, None)?;
110            let transformed_batch = transform_batch(
111                batch,
112                transform.as_ref(),
113                &arrow_schema,
114                None,
115                Some(&remote_schema),
116            )?;
117            Ok((remote_schema, transformed_batch.schema()))
118        } else {
119            Ok((remote_schema, arrow_schema))
120        }
121    }
122
123    async fn query(
124        &self,
125        conn_options: &ConnectionOptions,
126        sql: &str,
127        table_schema: SchemaRef,
128        projection: Option<&Vec<usize>>,
129    ) -> DFResult<SendableRecordBatchStream> {
130        let projected_schema = project_schema(&table_schema, projection)?;
131        let sql = sql.to_string();
132        let projection = projection.cloned();
133        let chunk_size = conn_options.chunk_size();
134        let conn = Arc::clone(&self.conn);
135        let stream = Box::pin(stream! {
136            let mut conn = conn.lock().await;
137            let mut query_iter = conn
138                .query_iter(sql)
139                .await
140                .map_err(|e| {
141                    DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
142                })?;
143
144            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
145                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
146                })? else {
147                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
148                return;
149            };
150
151            let mut chunked_stream = stream.chunks(chunk_size.unwrap_or(2048)).boxed();
152
153            while let Some(chunk) = chunked_stream.next().await {
154                let rows = chunk
155                    .into_iter()
156                    .collect::<Result<Vec<_>, _>>()
157                    .map_err(|e| {
158                        DataFusionError::Execution(format!(
159                            "Failed to collect rows from mysql due to {e}",
160                        ))
161                    })?;
162
163                yield Ok::<_, DataFusionError>(rows)
164            }
165        });
166
167        let stream = stream.map(move |rows| {
168            let rows = rows?;
169            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
170        });
171
172        Ok(Box::pin(RecordBatchStreamAdapter::new(
173            projected_schema,
174            stream,
175        )))
176    }
177}
178
179fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
180    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
181    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
182    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
183    let col_length = mysql_col.column_length();
184    match mysql_col.column_type() {
185        ColumnType::MYSQL_TYPE_TINY => {
186            if is_unsigned {
187                Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
188            } else {
189                Ok(RemoteType::Mysql(MysqlType::TinyInt))
190            }
191        }
192        ColumnType::MYSQL_TYPE_SHORT => {
193            if is_unsigned {
194                Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
195            } else {
196                Ok(RemoteType::Mysql(MysqlType::SmallInt))
197            }
198        }
199        ColumnType::MYSQL_TYPE_INT24 => {
200            if is_unsigned {
201                Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
202            } else {
203                Ok(RemoteType::Mysql(MysqlType::MediumInt))
204            }
205        }
206        ColumnType::MYSQL_TYPE_LONG => {
207            if is_unsigned {
208                Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
209            } else {
210                Ok(RemoteType::Mysql(MysqlType::Integer))
211            }
212        }
213        ColumnType::MYSQL_TYPE_LONGLONG => {
214            if is_unsigned {
215                Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
216            } else {
217                Ok(RemoteType::Mysql(MysqlType::BigInt))
218            }
219        }
220        ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
221        ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
222        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
223            let precision = (mysql_col.column_length() - 2) as u8;
224            let scale = mysql_col.decimals();
225            Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
226        }
227        ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
228        ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
229        ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
230        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
231        ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
232        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
233        ColumnType::MYSQL_TYPE_STRING if is_binary => Ok(RemoteType::Mysql(MysqlType::Binary)),
234        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
235            Ok(RemoteType::Mysql(MysqlType::Varchar))
236        }
237        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
238            Ok(RemoteType::Mysql(MysqlType::Varbinary))
239        }
240        ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
241        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
242            Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
243        }
244        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
245            Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
246        }
247        ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
248        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
249        _ => Err(DataFusionError::NotImplemented(format!(
250            "Unsupported mysql type: {mysql_col:?}",
251        ))),
252    }
253}
254
255fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
256    let mut remote_fields = vec![];
257    for col in row.columns_ref() {
258        remote_fields.push(RemoteField::new(
259            col.name_str().to_string(),
260            mysql_type_to_remote_type(col)?,
261            true,
262        ));
263    }
264    Ok(RemoteSchema::new(remote_fields))
265}
266
// Appends the value at column `$index` of `$row` to the Arrow builder `$builder`.
//
// * `$builder` is downcast to the concrete `$builder_ty`. A mismatch panics, since it
//   means the builder created for the field does not match the type this arm expects
//   (a programming error, not bad data).
// * The cell is read with `Row::get_opt::<$value_ty, usize>`. A `None` result, or a
//   conversion error that carries `Value::NULL`, is appended as null. Any other
//   conversion error makes the enclosing function return an `Execution` error.
// * `$convert` maps the extracted value to the builder's native type. It returns a
//   `Result` that is unwrapped with `?`, so its error must convert into the
//   enclosing function's error type.
// * `$field` and `$mysql_col` are used only in the panic/error messages.
//
// Must be expanded inside a function returning `DFResult<_>`, because the expansion
// contains `return Err(..)` and `?`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?} and {:?}"
                    ),
                    $field, $mysql_col
                )
            });
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            // A NULL cell that could not be converted to `$value_ty` is still a NULL.
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $mysql_col,
                )))
            }
        }
    }};
}
299
300fn rows_to_batch(
301    rows: &[Row],
302    table_schema: &SchemaRef,
303    projection: Option<&Vec<usize>>,
304) -> DFResult<RecordBatch> {
305    let projected_schema = project_schema(table_schema, projection)?;
306    let mut array_builders = vec![];
307    for field in table_schema.fields() {
308        let builder = make_builder(field.data_type(), rows.len());
309        array_builders.push(builder);
310    }
311
312    for row in rows {
313        for (idx, field) in table_schema.fields.iter().enumerate() {
314            if !projections_contains(projection, idx) {
315                continue;
316            }
317            let builder = &mut array_builders[idx];
318            let col = row.columns_ref().get(idx);
319            match field.data_type() {
320                DataType::Int8 => {
321                    handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
322                        Ok::<_, DataFusionError>(v)
323                    });
324                }
325                DataType::Int16 => {
326                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
327                        Ok::<_, DataFusionError>(v)
328                    });
329                }
330                DataType::Int32 => {
331                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
332                        Ok::<_, DataFusionError>(v)
333                    });
334                }
335                DataType::Int64 => {
336                    handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
337                        Ok::<_, DataFusionError>(v)
338                    });
339                }
340                DataType::UInt8 => {
341                    handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
342                        Ok::<_, DataFusionError>(v)
343                    });
344                }
345                DataType::UInt16 => {
346                    handle_primitive_type!(
347                        builder,
348                        field,
349                        col,
350                        UInt16Builder,
351                        u16,
352                        row,
353                        idx,
354                        |v| { Ok::<_, DataFusionError>(v) }
355                    );
356                }
357                DataType::UInt32 => {
358                    handle_primitive_type!(
359                        builder,
360                        field,
361                        col,
362                        UInt32Builder,
363                        u32,
364                        row,
365                        idx,
366                        |v| { Ok::<_, DataFusionError>(v) }
367                    );
368                }
369                DataType::UInt64 => {
370                    handle_primitive_type!(
371                        builder,
372                        field,
373                        col,
374                        UInt64Builder,
375                        u64,
376                        row,
377                        idx,
378                        |v| { Ok::<_, DataFusionError>(v) }
379                    );
380                }
381                DataType::Float32 => {
382                    handle_primitive_type!(
383                        builder,
384                        field,
385                        col,
386                        Float32Builder,
387                        f32,
388                        row,
389                        idx,
390                        |v| { Ok::<_, DataFusionError>(v) }
391                    );
392                }
393                DataType::Float64 => {
394                    handle_primitive_type!(
395                        builder,
396                        field,
397                        col,
398                        Float64Builder,
399                        f64,
400                        row,
401                        idx,
402                        |v| { Ok::<_, DataFusionError>(v) }
403                    );
404                }
405                DataType::Decimal128(_precision, scale) => {
406                    handle_primitive_type!(
407                        builder,
408                        field,
409                        col,
410                        Decimal128Builder,
411                        BigDecimal,
412                        row,
413                        idx,
414                        |v: BigDecimal| {
415                            big_decimal_to_i128(&v, Some(*scale as u32)).ok_or_else(|| {
416                                DataFusionError::Execution(format!(
417                                    "Failed to convert BigDecimal {v:?} to i128"
418                                ))
419                            })
420                        }
421                    );
422                }
423                DataType::Decimal256(_precision, _scale) => {
424                    handle_primitive_type!(
425                        builder,
426                        field,
427                        col,
428                        Decimal256Builder,
429                        BigDecimal,
430                        row,
431                        idx,
432                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
433                    );
434                }
435                DataType::Date32 => {
436                    handle_primitive_type!(
437                        builder,
438                        field,
439                        col,
440                        Date32Builder,
441                        chrono::NaiveDate,
442                        row,
443                        idx,
444                        |v: chrono::NaiveDate| {
445                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
446                        }
447                    );
448                }
449                DataType::Timestamp(TimeUnit::Microsecond, None) => {
450                    handle_primitive_type!(
451                        builder,
452                        field,
453                        col,
454                        TimestampMicrosecondBuilder,
455                        time::PrimitiveDateTime,
456                        row,
457                        idx,
458                        |v: time::PrimitiveDateTime| {
459                            let timestamp_micros =
460                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
461                            Ok::<_, DataFusionError>(timestamp_micros)
462                        }
463                    );
464                }
465                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
466                    handle_primitive_type!(
467                        builder,
468                        field,
469                        col,
470                        TimestampNanosecondBuilder,
471                        chrono::NaiveTime,
472                        row,
473                        idx,
474                        |v: chrono::NaiveTime| {
475                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
476                                + i64::from(v.nanosecond());
477                            Ok::<_, DataFusionError>(t)
478                        }
479                    );
480                }
481                DataType::Time32(TimeUnit::Second) => {
482                    handle_primitive_type!(
483                        builder,
484                        field,
485                        col,
486                        Time32SecondBuilder,
487                        chrono::NaiveTime,
488                        row,
489                        idx,
490                        |v: chrono::NaiveTime| {
491                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
492                        }
493                    );
494                }
495                DataType::Time64(TimeUnit::Nanosecond) => {
496                    handle_primitive_type!(
497                        builder,
498                        field,
499                        col,
500                        Time64NanosecondBuilder,
501                        chrono::NaiveTime,
502                        row,
503                        idx,
504                        |v: chrono::NaiveTime| {
505                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
506                                + i64::from(v.nanosecond());
507                            Ok::<_, DataFusionError>(t)
508                        }
509                    );
510                }
511                DataType::Utf8 => {
512                    handle_primitive_type!(
513                        builder,
514                        field,
515                        col,
516                        StringBuilder,
517                        String,
518                        row,
519                        idx,
520                        |v| { Ok::<_, DataFusionError>(v) }
521                    );
522                }
523                DataType::LargeUtf8 => {
524                    handle_primitive_type!(
525                        builder,
526                        field,
527                        col,
528                        LargeStringBuilder,
529                        String,
530                        row,
531                        idx,
532                        |v| { Ok::<_, DataFusionError>(v) }
533                    );
534                }
535                DataType::Binary => {
536                    handle_primitive_type!(
537                        builder,
538                        field,
539                        col,
540                        BinaryBuilder,
541                        Vec<u8>,
542                        row,
543                        idx,
544                        |v| { Ok::<_, DataFusionError>(v) }
545                    );
546                }
547                DataType::LargeBinary => {
548                    handle_primitive_type!(
549                        builder,
550                        field,
551                        col,
552                        LargeBinaryBuilder,
553                        Vec<u8>,
554                        row,
555                        idx,
556                        |v| { Ok::<_, DataFusionError>(v) }
557                    );
558                }
559                _ => {
560                    return Err(DataFusionError::NotImplemented(format!(
561                        "Unsupported data type {:?} for col: {:?}",
562                        field.data_type(),
563                        col
564                    )));
565                }
566            }
567        }
568    }
569    let projected_columns = array_builders
570        .into_iter()
571        .enumerate()
572        .filter(|(idx, _)| projections_contains(projection, *idx))
573        .map(|(_, mut builder)| builder.finish())
574        .collect::<Vec<ArrayRef>>();
575    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
576}
577
578fn to_decimal_256(decimal: &BigDecimal) -> i256 {
579    let (bigint_value, _) = decimal.as_bigint_and_exponent();
580    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
581
582    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
583    let fill_byte = if is_negative { 0xFF } else { 0x00 };
584
585    if bigint_bytes.len() > 32 {
586        bigint_bytes.truncate(32);
587    } else {
588        bigint_bytes.resize(32, fill_byte);
589    };
590
591    let mut array = [0u8; 32];
592    array.copy_from_slice(&bigint_bytes);
593
594    i256::from_le_bytes(array)
595}