datafusion_remote_table/connection/
mysql.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
13    Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
14    UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::any::Any;
29use std::sync::Arc;
30
/// Settings used to reach a MySQL server and to tune how results are streamed.
///
/// Create an instance with [`MysqlConnectionOptions::new`], which fills in the
/// optional settings with defaults.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// IP address or hostname of the MySQL server.
    pub(crate) host: String,
    /// TCP port of the MySQL server.
    pub(crate) port: u16,
    /// User name used to authenticate.
    pub(crate) username: String,
    /// Password used to authenticate.
    pub(crate) password: String,
    /// Database selected for the connection; `None` selects no database.
    pub(crate) database: Option<String>,
    /// Upper bound passed to the connection pool constraints.
    pub(crate) pool_max_size: usize,
    /// Number of rows grouped into one chunk while streaming query results.
    // NOTE(review): presumably surfaced through `ConnectionOptions::stream_chunk_size` — confirm.
    pub(crate) stream_chunk_size: usize,
}
41
42impl MysqlConnectionOptions {
43    pub fn new(
44        host: impl Into<String>,
45        port: u16,
46        username: impl Into<String>,
47        password: impl Into<String>,
48    ) -> Self {
49        Self {
50            host: host.into(),
51            port,
52            username: username.into(),
53            password: password.into(),
54            database: None,
55            pool_max_size: 10,
56            stream_chunk_size: 2048,
57        }
58    }
59}
60
61impl From<MysqlConnectionOptions> for ConnectionOptions {
62    fn from(options: MysqlConnectionOptions) -> Self {
63        ConnectionOptions::Mysql(options)
64    }
65}
66
/// Connection pool for a MySQL server, backed by [`mysql_async::Pool`].
#[derive(Debug)]
pub struct MysqlPool {
    /// Underlying `mysql_async` pool that connections are taken from.
    pool: mysql_async::Pool,
}
71
72pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
73    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
74        mysql_async::PoolConstraints::new(0, options.pool_max_size)
75            .expect("Failed to create pool constraints"),
76    );
77    let opts_builder = mysql_async::OptsBuilder::default()
78        .ip_or_hostname(options.host.clone())
79        .tcp_port(options.port)
80        .user(Some(options.username.clone()))
81        .pass(Some(options.password.clone()))
82        .db_name(options.database.clone())
83        .init(vec!["set time_zone='+00:00'".to_string()])
84        .pool_opts(pool_opts);
85    let pool = mysql_async::Pool::new(opts_builder);
86    Ok(MysqlPool { pool })
87}
88
89#[async_trait::async_trait]
90impl Pool for MysqlPool {
91    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
92        let conn = self.pool.get_conn().await.map_err(|e| {
93            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
94        })?;
95        Ok(Arc::new(MysqlConnection {
96            conn: Arc::new(Mutex::new(conn)),
97        }))
98    }
99}
100
/// A single MySQL connection taken from a [`MysqlPool`].
#[derive(Debug)]
pub struct MysqlConnection {
    /// The underlying connection, behind an async mutex so that it can be
    /// shared with the result stream created by `query`.
    conn: Arc<Mutex<mysql_async::Conn>>,
}
105
106#[async_trait::async_trait]
107impl Connection for MysqlConnection {
108    fn as_any(&self) -> &dyn Any {
109        self
110    }
111
112    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
113        let sql = RemoteDbType::Mysql.query_limit_1(sql);
114        let mut conn = self.conn.lock().await;
115        let conn = &mut *conn;
116        let stmt = conn.prep(&sql).await.map_err(|e| {
117            DataFusionError::Execution(format!("Failed to prepare query {sql} on mysql: {e:?}"))
118        })?;
119        let remote_schema = Arc::new(build_remote_schema(&stmt)?);
120        Ok(remote_schema)
121    }
122
123    async fn query(
124        &self,
125        conn_options: &ConnectionOptions,
126        sql: &str,
127        table_schema: SchemaRef,
128        projection: Option<&Vec<usize>>,
129        unparsed_filters: &[String],
130        limit: Option<usize>,
131    ) -> DFResult<SendableRecordBatchStream> {
132        let projected_schema = project_schema(&table_schema, projection)?;
133
134        let sql = RemoteDbType::Mysql.rewrite_query(sql, unparsed_filters, limit);
135        debug!("[remote-table] executing mysql query: {sql}");
136
137        let projection = projection.cloned();
138        let chunk_size = conn_options.stream_chunk_size();
139        let conn = Arc::clone(&self.conn);
140        let stream = Box::pin(stream! {
141            let mut conn = conn.lock().await;
142            let mut query_iter = conn
143                .query_iter(sql.clone())
144                .await
145                .map_err(|e| {
146                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
147                })?;
148
149            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
150                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
151                })? else {
152                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
153                return;
154            };
155
156            let mut chunked_stream = stream.chunks(chunk_size).boxed();
157
158            while let Some(chunk) = chunked_stream.next().await {
159                let rows = chunk
160                    .into_iter()
161                    .collect::<Result<Vec<_>, _>>()
162                    .map_err(|e| {
163                        DataFusionError::Execution(format!(
164                            "Failed to collect rows from mysql due to {e}",
165                        ))
166                    })?;
167
168                yield Ok::<_, DataFusionError>(rows)
169            }
170        });
171
172        let stream = stream.map(move |rows| {
173            let rows = rows?;
174            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
175        });
176
177        Ok(Box::pin(RecordBatchStreamAdapter::new(
178            projected_schema,
179            stream,
180        )))
181    }
182}
183
184fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
185    let character_set = mysql_col.character_set();
186    let is_utf8_bin_character_set = character_set == 45;
187    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
188    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
189    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
190    let col_length = mysql_col.column_length();
191    match mysql_col.column_type() {
192        ColumnType::MYSQL_TYPE_TINY => {
193            if is_unsigned {
194                Ok(MysqlType::TinyIntUnsigned)
195            } else {
196                Ok(MysqlType::TinyInt)
197            }
198        }
199        ColumnType::MYSQL_TYPE_SHORT => {
200            if is_unsigned {
201                Ok(MysqlType::SmallIntUnsigned)
202            } else {
203                Ok(MysqlType::SmallInt)
204            }
205        }
206        ColumnType::MYSQL_TYPE_INT24 => {
207            if is_unsigned {
208                Ok(MysqlType::MediumIntUnsigned)
209            } else {
210                Ok(MysqlType::MediumInt)
211            }
212        }
213        ColumnType::MYSQL_TYPE_LONG => {
214            if is_unsigned {
215                Ok(MysqlType::IntegerUnsigned)
216            } else {
217                Ok(MysqlType::Integer)
218            }
219        }
220        ColumnType::MYSQL_TYPE_LONGLONG => {
221            if is_unsigned {
222                Ok(MysqlType::BigIntUnsigned)
223            } else {
224                Ok(MysqlType::BigInt)
225            }
226        }
227        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
228        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
229        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
230            let precision = (mysql_col.column_length() - 2) as u8;
231            let scale = mysql_col.decimals();
232            Ok(MysqlType::Decimal(precision, scale))
233        }
234        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
235        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
236        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
237        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
238        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
239        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
240        ColumnType::MYSQL_TYPE_STRING if is_binary => {
241            if is_utf8_bin_character_set {
242                Ok(MysqlType::Char)
243            } else {
244                Ok(MysqlType::Binary)
245            }
246        }
247        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
248        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
249            if is_utf8_bin_character_set {
250                Ok(MysqlType::Varchar)
251            } else {
252                Ok(MysqlType::Varbinary)
253            }
254        }
255        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
256        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
257        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
258            if is_utf8_bin_character_set {
259                Ok(MysqlType::Text(col_length))
260            } else {
261                Ok(MysqlType::Blob(col_length))
262            }
263        }
264        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
265        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
266        _ => Err(DataFusionError::NotImplemented(format!(
267            "Unsupported mysql type: {mysql_col:?}",
268        ))),
269    }
270}
271
272fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
273    let mut remote_fields = vec![];
274    for col in stmt.columns() {
275        remote_fields.push(RemoteField::new(
276            col.name_str().to_string(),
277            RemoteType::Mysql(mysql_type_to_remote_type(col)?),
278            true,
279        ));
280    }
281    Ok(RemoteSchema::new(remote_fields))
282}
283
/// Appends one cell of a MySQL row to an Arrow array builder.
///
/// Arguments, in order: the type-erased builder, the Arrow field and MySQL
/// column (both only used in diagnostics), the concrete builder type, the
/// Rust type to read from the row, the row, the column index, and a fallible
/// conversion from the read value to the builder's value type.
///
/// A missing or SQL `NULL` value appends a null. Any other read failure makes
/// the *enclosing function* return a `DataFusionError::Execution`, and a
/// conversion error is propagated with `?`, so the macro may only be used in
/// functions returning a compatible `Result`. Panics when the builder is not
/// of the stated concrete type.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(value)) => typed_builder.append_value($convert(value)?),
            // Both an absent cell and an SQL NULL become an Arrow null.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
314
315fn rows_to_batch(
316    rows: &[Row],
317    table_schema: &SchemaRef,
318    projection: Option<&Vec<usize>>,
319) -> DFResult<RecordBatch> {
320    let projected_schema = project_schema(table_schema, projection)?;
321    let mut array_builders = vec![];
322    for field in table_schema.fields() {
323        let builder = make_builder(field.data_type(), rows.len());
324        array_builders.push(builder);
325    }
326
327    for row in rows {
328        for (idx, field) in table_schema.fields.iter().enumerate() {
329            if !projections_contains(projection, idx) {
330                continue;
331            }
332            let builder = &mut array_builders[idx];
333            let col = row.columns_ref().get(idx);
334            match field.data_type() {
335                DataType::Int8 => {
336                    handle_primitive_type!(
337                        builder,
338                        field,
339                        col,
340                        Int8Builder,
341                        i8,
342                        row,
343                        idx,
344                        just_return
345                    );
346                }
347                DataType::Int16 => {
348                    handle_primitive_type!(
349                        builder,
350                        field,
351                        col,
352                        Int16Builder,
353                        i16,
354                        row,
355                        idx,
356                        just_return
357                    );
358                }
359                DataType::Int32 => {
360                    handle_primitive_type!(
361                        builder,
362                        field,
363                        col,
364                        Int32Builder,
365                        i32,
366                        row,
367                        idx,
368                        just_return
369                    );
370                }
371                DataType::Int64 => {
372                    handle_primitive_type!(
373                        builder,
374                        field,
375                        col,
376                        Int64Builder,
377                        i64,
378                        row,
379                        idx,
380                        just_return
381                    );
382                }
383                DataType::UInt8 => {
384                    handle_primitive_type!(
385                        builder,
386                        field,
387                        col,
388                        UInt8Builder,
389                        u8,
390                        row,
391                        idx,
392                        just_return
393                    );
394                }
395                DataType::UInt16 => {
396                    handle_primitive_type!(
397                        builder,
398                        field,
399                        col,
400                        UInt16Builder,
401                        u16,
402                        row,
403                        idx,
404                        just_return
405                    );
406                }
407                DataType::UInt32 => {
408                    handle_primitive_type!(
409                        builder,
410                        field,
411                        col,
412                        UInt32Builder,
413                        u32,
414                        row,
415                        idx,
416                        just_return
417                    );
418                }
419                DataType::UInt64 => {
420                    handle_primitive_type!(
421                        builder,
422                        field,
423                        col,
424                        UInt64Builder,
425                        u64,
426                        row,
427                        idx,
428                        just_return
429                    );
430                }
431                DataType::Float32 => {
432                    handle_primitive_type!(
433                        builder,
434                        field,
435                        col,
436                        Float32Builder,
437                        f32,
438                        row,
439                        idx,
440                        just_return
441                    );
442                }
443                DataType::Float64 => {
444                    handle_primitive_type!(
445                        builder,
446                        field,
447                        col,
448                        Float64Builder,
449                        f64,
450                        row,
451                        idx,
452                        just_return
453                    );
454                }
455                DataType::Decimal128(_precision, scale) => {
456                    handle_primitive_type!(
457                        builder,
458                        field,
459                        col,
460                        Decimal128Builder,
461                        BigDecimal,
462                        row,
463                        idx,
464                        |v: BigDecimal| {
465                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
466                                DataFusionError::Execution(format!(
467                                    "Failed to convert BigDecimal {v:?} to i128"
468                                ))
469                            })
470                        }
471                    );
472                }
473                DataType::Decimal256(_precision, _scale) => {
474                    handle_primitive_type!(
475                        builder,
476                        field,
477                        col,
478                        Decimal256Builder,
479                        BigDecimal,
480                        row,
481                        idx,
482                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
483                    );
484                }
485                DataType::Date32 => {
486                    handle_primitive_type!(
487                        builder,
488                        field,
489                        col,
490                        Date32Builder,
491                        chrono::NaiveDate,
492                        row,
493                        idx,
494                        |v: chrono::NaiveDate| {
495                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
496                        }
497                    );
498                }
499                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
500                    match tz_opt {
501                        None => {}
502                        Some(tz) => {
503                            if !tz.eq_ignore_ascii_case("utc") {
504                                return Err(DataFusionError::NotImplemented(format!(
505                                    "Unsupported data type {:?} for col: {:?}",
506                                    field.data_type(),
507                                    col
508                                )));
509                            }
510                        }
511                    }
512                    handle_primitive_type!(
513                        builder,
514                        field,
515                        col,
516                        TimestampMicrosecondBuilder,
517                        time::PrimitiveDateTime,
518                        row,
519                        idx,
520                        |v: time::PrimitiveDateTime| {
521                            let timestamp_micros =
522                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
523                            Ok::<_, DataFusionError>(timestamp_micros)
524                        }
525                    );
526                }
527                DataType::Time32(TimeUnit::Second) => {
528                    handle_primitive_type!(
529                        builder,
530                        field,
531                        col,
532                        Time32SecondBuilder,
533                        chrono::NaiveTime,
534                        row,
535                        idx,
536                        |v: chrono::NaiveTime| {
537                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
538                        }
539                    );
540                }
541                DataType::Time64(TimeUnit::Nanosecond) => {
542                    handle_primitive_type!(
543                        builder,
544                        field,
545                        col,
546                        Time64NanosecondBuilder,
547                        chrono::NaiveTime,
548                        row,
549                        idx,
550                        |v: chrono::NaiveTime| {
551                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
552                                + i64::from(v.nanosecond());
553                            Ok::<_, DataFusionError>(t)
554                        }
555                    );
556                }
557                DataType::Utf8 => {
558                    handle_primitive_type!(
559                        builder,
560                        field,
561                        col,
562                        StringBuilder,
563                        String,
564                        row,
565                        idx,
566                        just_return
567                    );
568                }
569                DataType::LargeUtf8 => {
570                    handle_primitive_type!(
571                        builder,
572                        field,
573                        col,
574                        LargeStringBuilder,
575                        String,
576                        row,
577                        idx,
578                        just_return
579                    );
580                }
581                DataType::Binary => {
582                    handle_primitive_type!(
583                        builder,
584                        field,
585                        col,
586                        BinaryBuilder,
587                        Vec<u8>,
588                        row,
589                        idx,
590                        just_return
591                    );
592                }
593                DataType::LargeBinary => {
594                    handle_primitive_type!(
595                        builder,
596                        field,
597                        col,
598                        LargeBinaryBuilder,
599                        Vec<u8>,
600                        row,
601                        idx,
602                        just_return
603                    );
604                }
605                _ => {
606                    return Err(DataFusionError::NotImplemented(format!(
607                        "Unsupported data type {:?} for col: {:?}",
608                        field.data_type(),
609                        col
610                    )));
611                }
612            }
613        }
614    }
615    let projected_columns = array_builders
616        .into_iter()
617        .enumerate()
618        .filter(|(idx, _)| projections_contains(projection, *idx))
619        .map(|(_, mut builder)| builder.finish())
620        .collect::<Vec<ArrayRef>>();
621    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
622    Ok(RecordBatch::try_new_with_options(
623        projected_schema,
624        projected_columns,
625        &options,
626    )?)
627}
628
629fn to_decimal_256(decimal: &BigDecimal) -> i256 {
630    let (bigint_value, _) = decimal.as_bigint_and_exponent();
631    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
632
633    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
634    let fill_byte = if is_negative { 0xFF } else { 0x00 };
635
636    if bigint_bytes.len() > 32 {
637        bigint_bytes.truncate(32);
638    } else {
639        bigint_bytes.resize(32, fill_byte);
640    };
641
642    let mut array = [0u8; 32];
643    array.copy_from_slice(&bigint_bytes);
644
645    i256::from_le_bytes(array)
646}