datafusion_remote_table/connection/
mysql.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType, TableSource, Unparse,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
13    Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
14    UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::any::Any;
29use std::sync::Arc;
30
31#[derive(Debug, Clone, With, Getters)]
32pub struct MysqlConnectionOptions {
33    pub(crate) host: String,
34    pub(crate) port: u16,
35    pub(crate) username: String,
36    pub(crate) password: String,
37    pub(crate) database: Option<String>,
38    pub(crate) pool_max_size: usize,
39    pub(crate) stream_chunk_size: usize,
40}
41
42impl MysqlConnectionOptions {
43    pub fn new(
44        host: impl Into<String>,
45        port: u16,
46        username: impl Into<String>,
47        password: impl Into<String>,
48    ) -> Self {
49        Self {
50            host: host.into(),
51            port,
52            username: username.into(),
53            password: password.into(),
54            database: None,
55            pool_max_size: 10,
56            stream_chunk_size: 2048,
57        }
58    }
59}
60
61impl From<MysqlConnectionOptions> for ConnectionOptions {
62    fn from(options: MysqlConnectionOptions) -> Self {
63        ConnectionOptions::Mysql(options)
64    }
65}
66
67#[derive(Debug)]
68pub struct MysqlPool {
69    pool: mysql_async::Pool,
70}
71
72pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
73    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
74        mysql_async::PoolConstraints::new(0, options.pool_max_size)
75            .expect("Failed to create pool constraints"),
76    );
77    let opts_builder = mysql_async::OptsBuilder::default()
78        .ip_or_hostname(options.host.clone())
79        .tcp_port(options.port)
80        .user(Some(options.username.clone()))
81        .pass(Some(options.password.clone()))
82        .db_name(options.database.clone())
83        .init(vec!["set time_zone='+00:00'".to_string()])
84        .pool_opts(pool_opts);
85    let pool = mysql_async::Pool::new(opts_builder);
86    Ok(MysqlPool { pool })
87}
88
89#[async_trait::async_trait]
90impl Pool for MysqlPool {
91    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
92        let conn = self.pool.get_conn().await.map_err(|e| {
93            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
94        })?;
95        Ok(Arc::new(MysqlConnection {
96            conn: Arc::new(Mutex::new(conn)),
97        }))
98    }
99}
100
101#[derive(Debug)]
102pub struct MysqlConnection {
103    conn: Arc<Mutex<mysql_async::Conn>>,
104}
105
106#[async_trait::async_trait]
107impl Connection for MysqlConnection {
108    fn as_any(&self) -> &dyn Any {
109        self
110    }
111
112    async fn infer_schema(&self, source: &TableSource) -> DFResult<RemoteSchemaRef> {
113        match source {
114            TableSource::Table(_table) => Err(DataFusionError::Execution(
115                "Mysql does not support infer schema for table".to_string(),
116            )),
117            TableSource::Query(_query) => {
118                let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
119                let mut conn = self.conn.lock().await;
120                let conn = &mut *conn;
121                let stmt = conn.prep(&sql).await.map_err(|e| {
122                    DataFusionError::Execution(format!(
123                        "Failed to prepare query {sql} on mysql: {e:?}"
124                    ))
125                })?;
126                let remote_schema = Arc::new(build_remote_schema(&stmt)?);
127                Ok(remote_schema)
128            }
129        }
130    }
131
132    async fn query(
133        &self,
134        conn_options: &ConnectionOptions,
135        source: &TableSource,
136        table_schema: SchemaRef,
137        projection: Option<&Vec<usize>>,
138        unparsed_filters: &[String],
139        limit: Option<usize>,
140    ) -> DFResult<SendableRecordBatchStream> {
141        let projected_schema = project_schema(&table_schema, projection)?;
142
143        let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
144        debug!("[remote-table] executing mysql query: {sql}");
145
146        let projection = projection.cloned();
147        let chunk_size = conn_options.stream_chunk_size();
148        let conn = Arc::clone(&self.conn);
149        let stream = Box::pin(stream! {
150            let mut conn = conn.lock().await;
151            let mut query_iter = conn
152                .query_iter(sql.clone())
153                .await
154                .map_err(|e| {
155                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
156                })?;
157
158            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
159                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
160                })? else {
161                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
162                return;
163            };
164
165            let mut chunked_stream = stream.chunks(chunk_size).boxed();
166
167            while let Some(chunk) = chunked_stream.next().await {
168                let rows = chunk
169                    .into_iter()
170                    .collect::<Result<Vec<_>, _>>()
171                    .map_err(|e| {
172                        DataFusionError::Execution(format!(
173                            "Failed to collect rows from mysql due to {e}",
174                        ))
175                    })?;
176
177                yield Ok::<_, DataFusionError>(rows)
178            }
179        });
180
181        let stream = stream.map(move |rows| {
182            let rows = rows?;
183            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
184        });
185
186        Ok(Box::pin(RecordBatchStreamAdapter::new(
187            projected_schema,
188            stream,
189        )))
190    }
191
192    async fn insert(
193        &self,
194        _conn_options: &ConnectionOptions,
195        _unparser: Arc<dyn Unparse>,
196        _table: &[String],
197        _remote_schema: RemoteSchemaRef,
198        _input: SendableRecordBatchStream,
199    ) -> DFResult<usize> {
200        Err(DataFusionError::Execution(
201            "Insert operation is not supported for mysql".to_string(),
202        ))
203    }
204}
205
206fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
207    let character_set = mysql_col.character_set();
208    let is_utf8_bin_character_set = character_set == 45;
209    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
210    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
211    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
212    let col_length = mysql_col.column_length();
213    match mysql_col.column_type() {
214        ColumnType::MYSQL_TYPE_TINY => {
215            if is_unsigned {
216                Ok(MysqlType::TinyIntUnsigned)
217            } else {
218                Ok(MysqlType::TinyInt)
219            }
220        }
221        ColumnType::MYSQL_TYPE_SHORT => {
222            if is_unsigned {
223                Ok(MysqlType::SmallIntUnsigned)
224            } else {
225                Ok(MysqlType::SmallInt)
226            }
227        }
228        ColumnType::MYSQL_TYPE_INT24 => {
229            if is_unsigned {
230                Ok(MysqlType::MediumIntUnsigned)
231            } else {
232                Ok(MysqlType::MediumInt)
233            }
234        }
235        ColumnType::MYSQL_TYPE_LONG => {
236            if is_unsigned {
237                Ok(MysqlType::IntegerUnsigned)
238            } else {
239                Ok(MysqlType::Integer)
240            }
241        }
242        ColumnType::MYSQL_TYPE_LONGLONG => {
243            if is_unsigned {
244                Ok(MysqlType::BigIntUnsigned)
245            } else {
246                Ok(MysqlType::BigInt)
247            }
248        }
249        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
250        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
251        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
252            let precision = (mysql_col.column_length() - 2) as u8;
253            let scale = mysql_col.decimals();
254            Ok(MysqlType::Decimal(precision, scale))
255        }
256        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
257        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
258        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
259        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
260        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
261        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
262        ColumnType::MYSQL_TYPE_STRING if is_binary => {
263            if is_utf8_bin_character_set {
264                Ok(MysqlType::Char)
265            } else {
266                Ok(MysqlType::Binary)
267            }
268        }
269        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
270        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
271            if is_utf8_bin_character_set {
272                Ok(MysqlType::Varchar)
273            } else {
274                Ok(MysqlType::Varbinary)
275            }
276        }
277        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
278        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
279        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
280            if is_utf8_bin_character_set {
281                Ok(MysqlType::Text(col_length))
282            } else {
283                Ok(MysqlType::Blob(col_length))
284            }
285        }
286        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
287        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
288        _ => Err(DataFusionError::NotImplemented(format!(
289            "Unsupported mysql type: {mysql_col:?}",
290        ))),
291    }
292}
293
294fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
295    let mut remote_fields = vec![];
296    for col in stmt.columns() {
297        remote_fields.push(RemoteField::new(
298            col.name_str().to_string(),
299            RemoteType::Mysql(mysql_type_to_remote_type(col)?),
300            true,
301        ));
302    }
303    Ok(RemoteSchema::new(remote_fields))
304}
305
/// Appends the value at `$index` of `$row` to `$builder`, downcast to `$builder_ty`.
///
/// The value is read as `$value_ty` and passed through `$convert`, a callable returning
/// a `Result` whose error propagates with `?`. Missing values and SQL `NULL`s become
/// Arrow nulls. Any other read failure makes the *enclosing function* return an
/// `Execution` error. A builder of the wrong concrete type panics, because that indicates
/// a schema/builder mismatch rather than bad data.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            );
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(value)) => typed_builder.append_value($convert(value)?),
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
336
337fn rows_to_batch(
338    rows: &[Row],
339    table_schema: &SchemaRef,
340    projection: Option<&Vec<usize>>,
341) -> DFResult<RecordBatch> {
342    let projected_schema = project_schema(table_schema, projection)?;
343    let mut array_builders = vec![];
344    for field in table_schema.fields() {
345        let builder = make_builder(field.data_type(), rows.len());
346        array_builders.push(builder);
347    }
348
349    for row in rows {
350        for (idx, field) in table_schema.fields.iter().enumerate() {
351            if !projections_contains(projection, idx) {
352                continue;
353            }
354            let builder = &mut array_builders[idx];
355            let col = row.columns_ref().get(idx);
356            match field.data_type() {
357                DataType::Int8 => {
358                    handle_primitive_type!(
359                        builder,
360                        field,
361                        col,
362                        Int8Builder,
363                        i8,
364                        row,
365                        idx,
366                        just_return
367                    );
368                }
369                DataType::Int16 => {
370                    handle_primitive_type!(
371                        builder,
372                        field,
373                        col,
374                        Int16Builder,
375                        i16,
376                        row,
377                        idx,
378                        just_return
379                    );
380                }
381                DataType::Int32 => {
382                    handle_primitive_type!(
383                        builder,
384                        field,
385                        col,
386                        Int32Builder,
387                        i32,
388                        row,
389                        idx,
390                        just_return
391                    );
392                }
393                DataType::Int64 => {
394                    handle_primitive_type!(
395                        builder,
396                        field,
397                        col,
398                        Int64Builder,
399                        i64,
400                        row,
401                        idx,
402                        just_return
403                    );
404                }
405                DataType::UInt8 => {
406                    handle_primitive_type!(
407                        builder,
408                        field,
409                        col,
410                        UInt8Builder,
411                        u8,
412                        row,
413                        idx,
414                        just_return
415                    );
416                }
417                DataType::UInt16 => {
418                    handle_primitive_type!(
419                        builder,
420                        field,
421                        col,
422                        UInt16Builder,
423                        u16,
424                        row,
425                        idx,
426                        just_return
427                    );
428                }
429                DataType::UInt32 => {
430                    handle_primitive_type!(
431                        builder,
432                        field,
433                        col,
434                        UInt32Builder,
435                        u32,
436                        row,
437                        idx,
438                        just_return
439                    );
440                }
441                DataType::UInt64 => {
442                    handle_primitive_type!(
443                        builder,
444                        field,
445                        col,
446                        UInt64Builder,
447                        u64,
448                        row,
449                        idx,
450                        just_return
451                    );
452                }
453                DataType::Float32 => {
454                    handle_primitive_type!(
455                        builder,
456                        field,
457                        col,
458                        Float32Builder,
459                        f32,
460                        row,
461                        idx,
462                        just_return
463                    );
464                }
465                DataType::Float64 => {
466                    handle_primitive_type!(
467                        builder,
468                        field,
469                        col,
470                        Float64Builder,
471                        f64,
472                        row,
473                        idx,
474                        just_return
475                    );
476                }
477                DataType::Decimal128(_precision, scale) => {
478                    handle_primitive_type!(
479                        builder,
480                        field,
481                        col,
482                        Decimal128Builder,
483                        BigDecimal,
484                        row,
485                        idx,
486                        |v: BigDecimal| {
487                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
488                                DataFusionError::Execution(format!(
489                                    "Failed to convert BigDecimal {v:?} to i128"
490                                ))
491                            })
492                        }
493                    );
494                }
495                DataType::Decimal256(_precision, _scale) => {
496                    handle_primitive_type!(
497                        builder,
498                        field,
499                        col,
500                        Decimal256Builder,
501                        BigDecimal,
502                        row,
503                        idx,
504                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
505                    );
506                }
507                DataType::Date32 => {
508                    handle_primitive_type!(
509                        builder,
510                        field,
511                        col,
512                        Date32Builder,
513                        chrono::NaiveDate,
514                        row,
515                        idx,
516                        |v: chrono::NaiveDate| {
517                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
518                        }
519                    );
520                }
521                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
522                    match tz_opt {
523                        None => {}
524                        Some(tz) => {
525                            if !tz.eq_ignore_ascii_case("utc") {
526                                return Err(DataFusionError::NotImplemented(format!(
527                                    "Unsupported data type {:?} for col: {:?}",
528                                    field.data_type(),
529                                    col
530                                )));
531                            }
532                        }
533                    }
534                    handle_primitive_type!(
535                        builder,
536                        field,
537                        col,
538                        TimestampMicrosecondBuilder,
539                        time::PrimitiveDateTime,
540                        row,
541                        idx,
542                        |v: time::PrimitiveDateTime| {
543                            let timestamp_micros =
544                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
545                            Ok::<_, DataFusionError>(timestamp_micros)
546                        }
547                    );
548                }
549                DataType::Time32(TimeUnit::Second) => {
550                    handle_primitive_type!(
551                        builder,
552                        field,
553                        col,
554                        Time32SecondBuilder,
555                        chrono::NaiveTime,
556                        row,
557                        idx,
558                        |v: chrono::NaiveTime| {
559                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
560                        }
561                    );
562                }
563                DataType::Time64(TimeUnit::Nanosecond) => {
564                    handle_primitive_type!(
565                        builder,
566                        field,
567                        col,
568                        Time64NanosecondBuilder,
569                        chrono::NaiveTime,
570                        row,
571                        idx,
572                        |v: chrono::NaiveTime| {
573                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
574                                + i64::from(v.nanosecond());
575                            Ok::<_, DataFusionError>(t)
576                        }
577                    );
578                }
579                DataType::Utf8 => {
580                    handle_primitive_type!(
581                        builder,
582                        field,
583                        col,
584                        StringBuilder,
585                        String,
586                        row,
587                        idx,
588                        just_return
589                    );
590                }
591                DataType::LargeUtf8 => {
592                    handle_primitive_type!(
593                        builder,
594                        field,
595                        col,
596                        LargeStringBuilder,
597                        String,
598                        row,
599                        idx,
600                        just_return
601                    );
602                }
603                DataType::Binary => {
604                    handle_primitive_type!(
605                        builder,
606                        field,
607                        col,
608                        BinaryBuilder,
609                        Vec<u8>,
610                        row,
611                        idx,
612                        just_return
613                    );
614                }
615                DataType::LargeBinary => {
616                    handle_primitive_type!(
617                        builder,
618                        field,
619                        col,
620                        LargeBinaryBuilder,
621                        Vec<u8>,
622                        row,
623                        idx,
624                        just_return
625                    );
626                }
627                _ => {
628                    return Err(DataFusionError::NotImplemented(format!(
629                        "Unsupported data type {:?} for col: {:?}",
630                        field.data_type(),
631                        col
632                    )));
633                }
634            }
635        }
636    }
637    let projected_columns = array_builders
638        .into_iter()
639        .enumerate()
640        .filter(|(idx, _)| projections_contains(projection, *idx))
641        .map(|(_, mut builder)| builder.finish())
642        .collect::<Vec<ArrayRef>>();
643    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
644    Ok(RecordBatch::try_new_with_options(
645        projected_schema,
646        projected_columns,
647        &options,
648    )?)
649}
650
651fn to_decimal_256(decimal: &BigDecimal) -> i256 {
652    let (bigint_value, _) = decimal.as_bigint_and_exponent();
653    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
654
655    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
656    let fill_byte = if is_negative { 0xFF } else { 0x00 };
657
658    if bigint_bytes.len() > 32 {
659        bigint_bytes.truncate(32);
660    } else {
661        bigint_bytes.resize(32, fill_byte);
662    };
663
664    let mut array = [0u8; 32];
665    array.copy_from_slice(&bigint_bytes);
666
667    i256::from_le_bytes(array)
668}