datafusion_remote_table/connection/
mysql.rs

1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::big_decimal_to_i128;
3use crate::{
4    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5    RemoteSchemaRef, RemoteSource, RemoteType, Unparse,
6};
7use async_stream::stream;
8use bigdecimal::{BigDecimal, num_bigint};
9use chrono::Timelike;
10use datafusion::arrow::array::{
11    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
12    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
13    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
14    Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
15    UInt32Builder, UInt64Builder, make_builder,
16};
17use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
18use datafusion::common::{DataFusionError, project_schema};
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use derive_getters::Getters;
22use derive_with::With;
23use futures::StreamExt;
24use futures::lock::Mutex;
25use log::debug;
26use mysql_async::consts::{ColumnFlags, ColumnType};
27use mysql_async::prelude::Queryable;
28use mysql_async::{Column, FromValueError, Row, Value};
29use std::any::Any;
30use std::sync::Arc;
31
/// Settings used to open a pool of MySQL connections.
///
/// Setter-style and getter methods are generated by the `With` and `Getters`
/// derives.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// IP address or hostname of the MySQL server.
    pub(crate) host: String,
    /// TCP port of the MySQL server.
    pub(crate) port: u16,
    /// User name used to authenticate.
    pub(crate) username: String,
    /// Password used to authenticate.
    pub(crate) password: String,
    /// Optional default database selected when connecting.
    pub(crate) database: Option<String>,
    /// Upper bound on the number of pooled connections.
    pub(crate) pool_max_size: usize,
    /// Chunk size for streaming results.
    // NOTE(review): presumably the number of rows per emitted record batch — confirm against
    // `ConnectionOptions::stream_chunk_size`.
    pub(crate) stream_chunk_size: usize,
}
42
43impl MysqlConnectionOptions {
44    pub fn new(
45        host: impl Into<String>,
46        port: u16,
47        username: impl Into<String>,
48        password: impl Into<String>,
49    ) -> Self {
50        Self {
51            host: host.into(),
52            port,
53            username: username.into(),
54            password: password.into(),
55            database: None,
56            pool_max_size: 10,
57            stream_chunk_size: 2048,
58        }
59    }
60}
61
62impl From<MysqlConnectionOptions> for ConnectionOptions {
63    fn from(options: MysqlConnectionOptions) -> Self {
64        ConnectionOptions::Mysql(options)
65    }
66}
67
/// Connection pool for MySQL, backed by a `mysql_async::Pool`.
#[derive(Debug)]
pub struct MysqlPool {
    /// Underlying `mysql_async` pool that hands out connections.
    pool: mysql_async::Pool,
}
72
73pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
74    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
75        mysql_async::PoolConstraints::new(0, options.pool_max_size)
76            .expect("Failed to create pool constraints"),
77    );
78    let opts_builder = mysql_async::OptsBuilder::default()
79        .ip_or_hostname(options.host.clone())
80        .tcp_port(options.port)
81        .user(Some(options.username.clone()))
82        .pass(Some(options.password.clone()))
83        .db_name(options.database.clone())
84        .init(vec!["set time_zone='+00:00'".to_string()])
85        .pool_opts(pool_opts);
86    let pool = mysql_async::Pool::new(opts_builder);
87    Ok(MysqlPool { pool })
88}
89
90#[async_trait::async_trait]
91impl Pool for MysqlPool {
92    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
93        let conn = self.pool.get_conn().await.map_err(|e| {
94            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
95        })?;
96        Ok(Arc::new(MysqlConnection {
97            conn: Arc::new(Mutex::new(conn)),
98        }))
99    }
100}
101
/// A single MySQL connection taken from a [`MysqlPool`].
#[derive(Debug)]
pub struct MysqlConnection {
    /// The raw connection, shared behind an async mutex so that it can be moved
    /// into a result stream while this handle still refers to it.
    conn: Arc<Mutex<mysql_async::Conn>>,
}
106
107#[async_trait::async_trait]
108impl Connection for MysqlConnection {
109    fn as_any(&self) -> &dyn Any {
110        self
111    }
112
113    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
114        let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
115        let mut conn = self.conn.lock().await;
116        let conn = &mut *conn;
117        let stmt = conn.prep(&sql).await.map_err(|e| {
118            DataFusionError::Execution(format!("Failed to prepare query {sql} on mysql: {e:?}"))
119        })?;
120        let remote_schema = Arc::new(build_remote_schema(&stmt)?);
121        Ok(remote_schema)
122    }
123
124    async fn query(
125        &self,
126        conn_options: &ConnectionOptions,
127        source: &RemoteSource,
128        table_schema: SchemaRef,
129        projection: Option<&Vec<usize>>,
130        unparsed_filters: &[String],
131        limit: Option<usize>,
132    ) -> DFResult<SendableRecordBatchStream> {
133        let projected_schema = project_schema(&table_schema, projection)?;
134
135        let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
136        debug!("[remote-table] executing mysql query: {sql}");
137
138        let projection = projection.cloned();
139        let chunk_size = conn_options.stream_chunk_size();
140        let conn = Arc::clone(&self.conn);
141        let stream = Box::pin(stream! {
142            let mut conn = conn.lock().await;
143            let mut query_iter = conn
144                .query_iter(sql.clone())
145                .await
146                .map_err(|e| {
147                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
148                })?;
149
150            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
151                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
152                })? else {
153                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
154                return;
155            };
156
157            let mut chunked_stream = stream.chunks(chunk_size).boxed();
158
159            while let Some(chunk) = chunked_stream.next().await {
160                let rows = chunk
161                    .into_iter()
162                    .collect::<Result<Vec<_>, _>>()
163                    .map_err(|e| {
164                        DataFusionError::Execution(format!(
165                            "Failed to collect rows from mysql due to {e}",
166                        ))
167                    })?;
168
169                yield Ok::<_, DataFusionError>(rows)
170            }
171        });
172
173        let stream = stream.map(move |rows| {
174            let rows = rows?;
175            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
176        });
177
178        Ok(Box::pin(RecordBatchStreamAdapter::new(
179            projected_schema,
180            stream,
181        )))
182    }
183
184    async fn insert(
185        &self,
186        _conn_options: &ConnectionOptions,
187        _unparser: Arc<dyn Unparse>,
188        _table: &[String],
189        _remote_schema: RemoteSchemaRef,
190        _input: SendableRecordBatchStream,
191    ) -> DFResult<usize> {
192        Err(DataFusionError::Execution(
193            "Insert operation is not supported for mysql".to_string(),
194        ))
195    }
196}
197
198fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
199    let character_set = mysql_col.character_set();
200    let is_utf8_bin_character_set = character_set == 45;
201    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
202    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
203    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
204    let col_length = mysql_col.column_length();
205    match mysql_col.column_type() {
206        ColumnType::MYSQL_TYPE_TINY => {
207            if is_unsigned {
208                Ok(MysqlType::TinyIntUnsigned)
209            } else {
210                Ok(MysqlType::TinyInt)
211            }
212        }
213        ColumnType::MYSQL_TYPE_SHORT => {
214            if is_unsigned {
215                Ok(MysqlType::SmallIntUnsigned)
216            } else {
217                Ok(MysqlType::SmallInt)
218            }
219        }
220        ColumnType::MYSQL_TYPE_INT24 => {
221            if is_unsigned {
222                Ok(MysqlType::MediumIntUnsigned)
223            } else {
224                Ok(MysqlType::MediumInt)
225            }
226        }
227        ColumnType::MYSQL_TYPE_LONG => {
228            if is_unsigned {
229                Ok(MysqlType::IntegerUnsigned)
230            } else {
231                Ok(MysqlType::Integer)
232            }
233        }
234        ColumnType::MYSQL_TYPE_LONGLONG => {
235            if is_unsigned {
236                Ok(MysqlType::BigIntUnsigned)
237            } else {
238                Ok(MysqlType::BigInt)
239            }
240        }
241        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
242        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
243        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
244            let precision = (mysql_col.column_length() - 2) as u8;
245            let scale = mysql_col.decimals();
246            Ok(MysqlType::Decimal(precision, scale))
247        }
248        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
249        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
250        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
251        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
252        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
253        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
254        ColumnType::MYSQL_TYPE_STRING if is_binary => {
255            if is_utf8_bin_character_set {
256                Ok(MysqlType::Char)
257            } else {
258                Ok(MysqlType::Binary)
259            }
260        }
261        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
262        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
263            if is_utf8_bin_character_set {
264                Ok(MysqlType::Varchar)
265            } else {
266                Ok(MysqlType::Varbinary)
267            }
268        }
269        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
270        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
271        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
272            if is_utf8_bin_character_set {
273                Ok(MysqlType::Text(col_length))
274            } else {
275                Ok(MysqlType::Blob(col_length))
276            }
277        }
278        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
279        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
280        _ => Err(DataFusionError::NotImplemented(format!(
281            "Unsupported mysql type: {mysql_col:?}",
282        ))),
283    }
284}
285
286fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
287    let mut remote_fields = vec![];
288    for col in stmt.columns() {
289        remote_fields.push(RemoteField::new(
290            col.name_str().to_string(),
291            RemoteType::Mysql(mysql_type_to_remote_type(col)?),
292            true,
293        ));
294    }
295    Ok(RemoteSchema::new(remote_fields))
296}
297
/// Reads the value at `$index` of `$row` as `$value_ty` and appends it to `$builder`.
///
/// `$builder` is downcast to `$builder_ty` (panicking if the builder has a
/// different type). A missing column or a SQL `NULL` appends a null; a value
/// is passed through `$convert`, whose error is propagated with `?`; any
/// other decode failure makes the enclosing function return an execution
/// error. `$field` and `$col` are only used in diagnostics.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            // Both "no such column" and an explicit SQL NULL become an Arrow null.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Ok(value)) => typed_builder.append_value($convert(value)?),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
328
329fn rows_to_batch(
330    rows: &[Row],
331    table_schema: &SchemaRef,
332    projection: Option<&Vec<usize>>,
333) -> DFResult<RecordBatch> {
334    let projected_schema = project_schema(table_schema, projection)?;
335    let mut array_builders = vec![];
336    for field in table_schema.fields() {
337        let builder = make_builder(field.data_type(), rows.len());
338        array_builders.push(builder);
339    }
340
341    for row in rows {
342        for (idx, field) in table_schema.fields.iter().enumerate() {
343            if !projections_contains(projection, idx) {
344                continue;
345            }
346            let builder = &mut array_builders[idx];
347            let col = row.columns_ref().get(idx);
348            match field.data_type() {
349                DataType::Int8 => {
350                    handle_primitive_type!(
351                        builder,
352                        field,
353                        col,
354                        Int8Builder,
355                        i8,
356                        row,
357                        idx,
358                        just_return
359                    );
360                }
361                DataType::Int16 => {
362                    handle_primitive_type!(
363                        builder,
364                        field,
365                        col,
366                        Int16Builder,
367                        i16,
368                        row,
369                        idx,
370                        just_return
371                    );
372                }
373                DataType::Int32 => {
374                    handle_primitive_type!(
375                        builder,
376                        field,
377                        col,
378                        Int32Builder,
379                        i32,
380                        row,
381                        idx,
382                        just_return
383                    );
384                }
385                DataType::Int64 => {
386                    handle_primitive_type!(
387                        builder,
388                        field,
389                        col,
390                        Int64Builder,
391                        i64,
392                        row,
393                        idx,
394                        just_return
395                    );
396                }
397                DataType::UInt8 => {
398                    handle_primitive_type!(
399                        builder,
400                        field,
401                        col,
402                        UInt8Builder,
403                        u8,
404                        row,
405                        idx,
406                        just_return
407                    );
408                }
409                DataType::UInt16 => {
410                    handle_primitive_type!(
411                        builder,
412                        field,
413                        col,
414                        UInt16Builder,
415                        u16,
416                        row,
417                        idx,
418                        just_return
419                    );
420                }
421                DataType::UInt32 => {
422                    handle_primitive_type!(
423                        builder,
424                        field,
425                        col,
426                        UInt32Builder,
427                        u32,
428                        row,
429                        idx,
430                        just_return
431                    );
432                }
433                DataType::UInt64 => {
434                    handle_primitive_type!(
435                        builder,
436                        field,
437                        col,
438                        UInt64Builder,
439                        u64,
440                        row,
441                        idx,
442                        just_return
443                    );
444                }
445                DataType::Float32 => {
446                    handle_primitive_type!(
447                        builder,
448                        field,
449                        col,
450                        Float32Builder,
451                        f32,
452                        row,
453                        idx,
454                        just_return
455                    );
456                }
457                DataType::Float64 => {
458                    handle_primitive_type!(
459                        builder,
460                        field,
461                        col,
462                        Float64Builder,
463                        f64,
464                        row,
465                        idx,
466                        just_return
467                    );
468                }
469                DataType::Decimal128(_precision, scale) => {
470                    handle_primitive_type!(
471                        builder,
472                        field,
473                        col,
474                        Decimal128Builder,
475                        BigDecimal,
476                        row,
477                        idx,
478                        |v: BigDecimal| { big_decimal_to_i128(&v, Some(*scale as i32)) }
479                    );
480                }
481                DataType::Decimal256(_precision, _scale) => {
482                    handle_primitive_type!(
483                        builder,
484                        field,
485                        col,
486                        Decimal256Builder,
487                        BigDecimal,
488                        row,
489                        idx,
490                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
491                    );
492                }
493                DataType::Date32 => {
494                    handle_primitive_type!(
495                        builder,
496                        field,
497                        col,
498                        Date32Builder,
499                        chrono::NaiveDate,
500                        row,
501                        idx,
502                        |v: chrono::NaiveDate| {
503                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
504                        }
505                    );
506                }
507                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
508                    match tz_opt {
509                        None => {}
510                        Some(tz) => {
511                            if !tz.eq_ignore_ascii_case("utc") {
512                                return Err(DataFusionError::NotImplemented(format!(
513                                    "Unsupported data type {:?} for col: {:?}",
514                                    field.data_type(),
515                                    col
516                                )));
517                            }
518                        }
519                    }
520                    handle_primitive_type!(
521                        builder,
522                        field,
523                        col,
524                        TimestampMicrosecondBuilder,
525                        time::PrimitiveDateTime,
526                        row,
527                        idx,
528                        |v: time::PrimitiveDateTime| {
529                            let timestamp_micros =
530                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
531                            Ok::<_, DataFusionError>(timestamp_micros)
532                        }
533                    );
534                }
535                DataType::Time32(TimeUnit::Second) => {
536                    handle_primitive_type!(
537                        builder,
538                        field,
539                        col,
540                        Time32SecondBuilder,
541                        chrono::NaiveTime,
542                        row,
543                        idx,
544                        |v: chrono::NaiveTime| {
545                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
546                        }
547                    );
548                }
549                DataType::Time64(TimeUnit::Nanosecond) => {
550                    handle_primitive_type!(
551                        builder,
552                        field,
553                        col,
554                        Time64NanosecondBuilder,
555                        chrono::NaiveTime,
556                        row,
557                        idx,
558                        |v: chrono::NaiveTime| {
559                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
560                                + i64::from(v.nanosecond());
561                            Ok::<_, DataFusionError>(t)
562                        }
563                    );
564                }
565                DataType::Utf8 => {
566                    handle_primitive_type!(
567                        builder,
568                        field,
569                        col,
570                        StringBuilder,
571                        String,
572                        row,
573                        idx,
574                        just_return
575                    );
576                }
577                DataType::LargeUtf8 => {
578                    handle_primitive_type!(
579                        builder,
580                        field,
581                        col,
582                        LargeStringBuilder,
583                        String,
584                        row,
585                        idx,
586                        just_return
587                    );
588                }
589                DataType::Binary => {
590                    handle_primitive_type!(
591                        builder,
592                        field,
593                        col,
594                        BinaryBuilder,
595                        Vec<u8>,
596                        row,
597                        idx,
598                        just_return
599                    );
600                }
601                DataType::LargeBinary => {
602                    handle_primitive_type!(
603                        builder,
604                        field,
605                        col,
606                        LargeBinaryBuilder,
607                        Vec<u8>,
608                        row,
609                        idx,
610                        just_return
611                    );
612                }
613                _ => {
614                    return Err(DataFusionError::NotImplemented(format!(
615                        "Unsupported data type {:?} for col: {:?}",
616                        field.data_type(),
617                        col
618                    )));
619                }
620            }
621        }
622    }
623    let projected_columns = array_builders
624        .into_iter()
625        .enumerate()
626        .filter(|(idx, _)| projections_contains(projection, *idx))
627        .map(|(_, mut builder)| builder.finish())
628        .collect::<Vec<ArrayRef>>();
629    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
630    Ok(RecordBatch::try_new_with_options(
631        projected_schema,
632        projected_columns,
633        &options,
634    )?)
635}
636
637fn to_decimal_256(decimal: &BigDecimal) -> i256 {
638    let (bigint_value, _) = decimal.as_bigint_and_exponent();
639    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
640
641    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
642    let fill_byte = if is_negative { 0xFF } else { 0x00 };
643
644    if bigint_bytes.len() > 32 {
645        bigint_bytes.truncate(32);
646    } else {
647        bigint_bytes.resize(32, fill_byte);
648    };
649
650    let mut array = [0u8; 32];
651    array.copy_from_slice(&bigint_bytes);
652
653    i256::from_le_bytes(array)
654}