datafusion_remote_table/connection/mysql.rs

1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::big_decimal_to_i128;
3use crate::{
4    Connection, ConnectionOptions, DFResult, Literalize, MysqlConnectionOptions, MysqlType, Pool,
5    PoolState, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType,
6};
7use arrow::array::{
8    ArrayRef, BinaryBuilder, BinaryViewBuilder, Date32Builder, Decimal128Builder,
9    Decimal256Builder, Float32Builder, Float64Builder, Int8Builder, Int16Builder, Int32Builder,
10    Int64Builder, LargeBinaryBuilder, LargeStringBuilder, RecordBatch, RecordBatchOptions,
11    StringBuilder, StringViewBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
12    TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder, UInt32Builder, UInt64Builder,
13    make_builder,
14};
15use arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
16use async_stream::stream;
17use bigdecimal::{BigDecimal, num_bigint};
18use chrono::Timelike;
19use datafusion_common::{DataFusionError, project_schema};
20use datafusion_execution::SendableRecordBatchStream;
21use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::any::Any;
29use std::sync::Arc;
30
31#[derive(Debug)]
32pub struct MysqlPool {
33    pool: mysql_async::Pool,
34}
35
36pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
37    let pool_opts = mysql_async::PoolOpts::new()
38        .with_constraints(
39            mysql_async::PoolConstraints::new(options.pool_min_idle, options.pool_max_size)
40                .expect("Failed to create pool constraints"),
41        )
42        .with_inactive_connection_ttl(options.pool_idle_timeout)
43        .with_ttl_check_interval(options.pool_ttl_check_interval);
44    let opts_builder = mysql_async::OptsBuilder::default()
45        .ip_or_hostname(options.host.clone())
46        .tcp_port(options.port)
47        .user(Some(options.username.clone()))
48        .pass(Some(options.password.clone()))
49        .db_name(options.database.clone())
50        .init(vec!["set time_zone='+00:00'".to_string()])
51        .pool_opts(pool_opts);
52    let pool = mysql_async::Pool::new(opts_builder);
53    Ok(MysqlPool { pool })
54}
55
56#[async_trait::async_trait]
57impl Pool for MysqlPool {
58    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
59        let conn = self.pool.get_conn().await.map_err(|e| {
60            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
61        })?;
62        Ok(Arc::new(MysqlConnection {
63            conn: Arc::new(Mutex::new(conn)),
64        }))
65    }
66
67    async fn state(&self) -> DFResult<PoolState> {
68        use std::sync::atomic::Ordering;
69        let metrics = self.pool.metrics();
70        Ok(PoolState {
71            connections: metrics.connection_count.load(Ordering::SeqCst),
72            idle_connections: metrics.connections_in_pool.load(Ordering::SeqCst),
73        })
74    }
75}
76
77#[derive(Debug)]
78pub struct MysqlConnection {
79    conn: Arc<Mutex<mysql_async::Conn>>,
80}
81
82#[async_trait::async_trait]
83impl Connection for MysqlConnection {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87
88    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
89        let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
90        let mut conn = self.conn.lock().await;
91        let conn = &mut *conn;
92        let stmt = conn.prep(&sql).await.map_err(|e| {
93            DataFusionError::Plan(format!("Failed to prepare query {sql} on mysql: {e:?}"))
94        })?;
95        let remote_schema = Arc::new(build_remote_schema(&stmt)?);
96        Ok(remote_schema)
97    }
98
99    async fn query(
100        &self,
101        conn_options: &ConnectionOptions,
102        source: &RemoteSource,
103        table_schema: SchemaRef,
104        projection: Option<&Vec<usize>>,
105        unparsed_filters: &[String],
106        limit: Option<usize>,
107    ) -> DFResult<SendableRecordBatchStream> {
108        let projected_schema = project_schema(&table_schema, projection)?;
109
110        let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
111        debug!("[remote-table] executing mysql query: {sql}");
112
113        let projection = projection.cloned();
114        let chunk_size = conn_options.stream_chunk_size();
115        let conn = Arc::clone(&self.conn);
116        let stream = Box::pin(stream! {
117            let mut conn = conn.lock().await;
118            let mut query_iter = conn
119                .query_iter(sql.clone())
120                .await
121                .map_err(|e| {
122                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
123                })?;
124
125            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
126                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
127                })? else {
128                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
129                return;
130            };
131
132            let mut chunked_stream = stream.chunks(chunk_size).boxed();
133
134            while let Some(chunk) = chunked_stream.next().await {
135                let rows = chunk
136                    .into_iter()
137                    .collect::<Result<Vec<_>, _>>()
138                    .map_err(|e| {
139                        DataFusionError::Execution(format!(
140                            "Failed to collect rows from mysql due to {e}",
141                        ))
142                    })?;
143
144                yield Ok::<_, DataFusionError>(rows)
145            }
146        });
147
148        let stream = stream.map(move |rows| {
149            let rows = rows?;
150            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
151        });
152
153        Ok(Box::pin(RecordBatchStreamAdapter::new(
154            projected_schema,
155            stream,
156        )))
157    }
158
159    async fn insert(
160        &self,
161        _conn_options: &ConnectionOptions,
162        _literalizer: Arc<dyn Literalize>,
163        _table: &[String],
164        _remote_schema: RemoteSchemaRef,
165        _batch: RecordBatch,
166    ) -> DFResult<usize> {
167        Err(DataFusionError::Execution(
168            "Insert operation is not supported for mysql".to_string(),
169        ))
170    }
171}
172
173fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
174    let character_set = mysql_col.character_set();
175    let is_utf8_bin_character_set = character_set == 45;
176    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
177    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
178    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
179    let col_length = mysql_col.column_length();
180    match mysql_col.column_type() {
181        ColumnType::MYSQL_TYPE_TINY => {
182            if is_unsigned {
183                Ok(MysqlType::TinyIntUnsigned)
184            } else {
185                Ok(MysqlType::TinyInt)
186            }
187        }
188        ColumnType::MYSQL_TYPE_SHORT => {
189            if is_unsigned {
190                Ok(MysqlType::SmallIntUnsigned)
191            } else {
192                Ok(MysqlType::SmallInt)
193            }
194        }
195        ColumnType::MYSQL_TYPE_INT24 => {
196            if is_unsigned {
197                Ok(MysqlType::MediumIntUnsigned)
198            } else {
199                Ok(MysqlType::MediumInt)
200            }
201        }
202        ColumnType::MYSQL_TYPE_LONG => {
203            if is_unsigned {
204                Ok(MysqlType::IntegerUnsigned)
205            } else {
206                Ok(MysqlType::Integer)
207            }
208        }
209        ColumnType::MYSQL_TYPE_LONGLONG => {
210            if is_unsigned {
211                Ok(MysqlType::BigIntUnsigned)
212            } else {
213                Ok(MysqlType::BigInt)
214            }
215        }
216        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
217        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
218        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
219            let precision = (mysql_col.column_length() - 2) as u8;
220            let scale = mysql_col.decimals();
221            Ok(MysqlType::Decimal(precision, scale))
222        }
223        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
224        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
225        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
226        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
227        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
228        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
229        ColumnType::MYSQL_TYPE_STRING if is_binary => {
230            if is_utf8_bin_character_set {
231                Ok(MysqlType::Char)
232            } else {
233                Ok(MysqlType::Binary)
234            }
235        }
236        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
237        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
238            if is_utf8_bin_character_set {
239                Ok(MysqlType::Varchar)
240            } else {
241                Ok(MysqlType::Varbinary)
242            }
243        }
244        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
245        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
246        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
247            if is_utf8_bin_character_set {
248                Ok(MysqlType::Text(col_length))
249            } else {
250                Ok(MysqlType::Blob(col_length))
251            }
252        }
253        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
254        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
255        _ => Err(DataFusionError::NotImplemented(format!(
256            "Unsupported mysql type: {mysql_col:?}",
257        ))),
258    }
259}
260
261fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
262    let mut remote_fields = vec![];
263    for col in stmt.columns() {
264        remote_fields.push(RemoteField::new(
265            col.name_str().to_string(),
266            RemoteType::Mysql(mysql_type_to_remote_type(col)?),
267            true,
268        ));
269    }
270    Ok(RemoteSchema::new(remote_fields))
271}
272
/// Appends the value at `$index` of `$row` to `$builder`, which must be a `$builder_ty`.
///
/// The value is read as `$value_ty` via `get_opt` and passed through `$convert`. `$convert` is
/// a callable returning a `Result`, and its error is propagated with `?`.
/// - A missing value or a SQL `NULL` appends a null.
/// - Any other decoding failure makes the *enclosing function* return a
///   `DataFusionError::Execution`.
/// - Panics if `$builder` cannot be downcast to `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };
        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(v)) => builder.append_value($convert(v)?),
            None | Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
303
304fn rows_to_batch(
305    rows: &[Row],
306    table_schema: &SchemaRef,
307    projection: Option<&Vec<usize>>,
308) -> DFResult<RecordBatch> {
309    let projected_schema = project_schema(table_schema, projection)?;
310    let mut array_builders = vec![];
311    for field in table_schema.fields() {
312        let builder = make_builder(field.data_type(), rows.len());
313        array_builders.push(builder);
314    }
315
316    for row in rows {
317        for (idx, field) in table_schema.fields.iter().enumerate() {
318            if !projections_contains(projection, idx) {
319                continue;
320            }
321            let builder = &mut array_builders[idx];
322            let col = row.columns_ref().get(idx);
323            match field.data_type() {
324                DataType::Int8 => {
325                    handle_primitive_type!(
326                        builder,
327                        field,
328                        col,
329                        Int8Builder,
330                        i8,
331                        row,
332                        idx,
333                        just_return
334                    );
335                }
336                DataType::Int16 => {
337                    handle_primitive_type!(
338                        builder,
339                        field,
340                        col,
341                        Int16Builder,
342                        i16,
343                        row,
344                        idx,
345                        just_return
346                    );
347                }
348                DataType::Int32 => {
349                    handle_primitive_type!(
350                        builder,
351                        field,
352                        col,
353                        Int32Builder,
354                        i32,
355                        row,
356                        idx,
357                        just_return
358                    );
359                }
360                DataType::Int64 => {
361                    handle_primitive_type!(
362                        builder,
363                        field,
364                        col,
365                        Int64Builder,
366                        i64,
367                        row,
368                        idx,
369                        just_return
370                    );
371                }
372                DataType::UInt8 => {
373                    handle_primitive_type!(
374                        builder,
375                        field,
376                        col,
377                        UInt8Builder,
378                        u8,
379                        row,
380                        idx,
381                        just_return
382                    );
383                }
384                DataType::UInt16 => {
385                    handle_primitive_type!(
386                        builder,
387                        field,
388                        col,
389                        UInt16Builder,
390                        u16,
391                        row,
392                        idx,
393                        just_return
394                    );
395                }
396                DataType::UInt32 => {
397                    handle_primitive_type!(
398                        builder,
399                        field,
400                        col,
401                        UInt32Builder,
402                        u32,
403                        row,
404                        idx,
405                        just_return
406                    );
407                }
408                DataType::UInt64 => {
409                    handle_primitive_type!(
410                        builder,
411                        field,
412                        col,
413                        UInt64Builder,
414                        u64,
415                        row,
416                        idx,
417                        just_return
418                    );
419                }
420                DataType::Float32 => {
421                    handle_primitive_type!(
422                        builder,
423                        field,
424                        col,
425                        Float32Builder,
426                        f32,
427                        row,
428                        idx,
429                        just_return
430                    );
431                }
432                DataType::Float64 => {
433                    handle_primitive_type!(
434                        builder,
435                        field,
436                        col,
437                        Float64Builder,
438                        f64,
439                        row,
440                        idx,
441                        just_return
442                    );
443                }
444                DataType::Decimal128(_precision, scale) => {
445                    handle_primitive_type!(
446                        builder,
447                        field,
448                        col,
449                        Decimal128Builder,
450                        BigDecimal,
451                        row,
452                        idx,
453                        |v: BigDecimal| { big_decimal_to_i128(&v, Some(*scale as i32)) }
454                    );
455                }
456                DataType::Decimal256(_precision, _scale) => {
457                    handle_primitive_type!(
458                        builder,
459                        field,
460                        col,
461                        Decimal256Builder,
462                        BigDecimal,
463                        row,
464                        idx,
465                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
466                    );
467                }
468                DataType::Date32 => {
469                    handle_primitive_type!(
470                        builder,
471                        field,
472                        col,
473                        Date32Builder,
474                        chrono::NaiveDate,
475                        row,
476                        idx,
477                        |v: chrono::NaiveDate| {
478                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
479                        }
480                    );
481                }
482                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
483                    match tz_opt {
484                        None => {}
485                        Some(tz) => {
486                            if !tz.eq_ignore_ascii_case("utc") {
487                                return Err(DataFusionError::NotImplemented(format!(
488                                    "Unsupported data type {:?} for col: {:?}",
489                                    field.data_type(),
490                                    col
491                                )));
492                            }
493                        }
494                    }
495                    handle_primitive_type!(
496                        builder,
497                        field,
498                        col,
499                        TimestampMicrosecondBuilder,
500                        time::PrimitiveDateTime,
501                        row,
502                        idx,
503                        |v: time::PrimitiveDateTime| {
504                            let timestamp_micros =
505                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
506                            Ok::<_, DataFusionError>(timestamp_micros)
507                        }
508                    );
509                }
510                DataType::Time32(TimeUnit::Second) => {
511                    handle_primitive_type!(
512                        builder,
513                        field,
514                        col,
515                        Time32SecondBuilder,
516                        chrono::NaiveTime,
517                        row,
518                        idx,
519                        |v: chrono::NaiveTime| {
520                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
521                        }
522                    );
523                }
524                DataType::Time64(TimeUnit::Nanosecond) => {
525                    handle_primitive_type!(
526                        builder,
527                        field,
528                        col,
529                        Time64NanosecondBuilder,
530                        chrono::NaiveTime,
531                        row,
532                        idx,
533                        |v: chrono::NaiveTime| {
534                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
535                                + i64::from(v.nanosecond());
536                            Ok::<_, DataFusionError>(t)
537                        }
538                    );
539                }
540                DataType::Utf8 => {
541                    handle_primitive_type!(
542                        builder,
543                        field,
544                        col,
545                        StringBuilder,
546                        String,
547                        row,
548                        idx,
549                        just_return
550                    );
551                }
552                DataType::LargeUtf8 => {
553                    handle_primitive_type!(
554                        builder,
555                        field,
556                        col,
557                        LargeStringBuilder,
558                        String,
559                        row,
560                        idx,
561                        just_return
562                    );
563                }
564                DataType::Utf8View => {
565                    handle_primitive_type!(
566                        builder,
567                        field,
568                        col,
569                        StringViewBuilder,
570                        String,
571                        row,
572                        idx,
573                        just_return
574                    );
575                }
576                DataType::Binary => {
577                    handle_primitive_type!(
578                        builder,
579                        field,
580                        col,
581                        BinaryBuilder,
582                        Vec<u8>,
583                        row,
584                        idx,
585                        just_return
586                    );
587                }
588                DataType::LargeBinary => {
589                    handle_primitive_type!(
590                        builder,
591                        field,
592                        col,
593                        LargeBinaryBuilder,
594                        Vec<u8>,
595                        row,
596                        idx,
597                        just_return
598                    );
599                }
600                DataType::BinaryView => {
601                    handle_primitive_type!(
602                        builder,
603                        field,
604                        col,
605                        BinaryViewBuilder,
606                        Vec<u8>,
607                        row,
608                        idx,
609                        just_return
610                    );
611                }
612                _ => {
613                    return Err(DataFusionError::NotImplemented(format!(
614                        "Unsupported data type {:?} for col: {:?}",
615                        field.data_type(),
616                        col
617                    )));
618                }
619            }
620        }
621    }
622    let projected_columns = array_builders
623        .into_iter()
624        .enumerate()
625        .filter(|(idx, _)| projections_contains(projection, *idx))
626        .map(|(_, mut builder)| builder.finish())
627        .collect::<Vec<ArrayRef>>();
628    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
629    Ok(RecordBatch::try_new_with_options(
630        projected_schema,
631        projected_columns,
632        &options,
633    )?)
634}
635
636fn to_decimal_256(decimal: &BigDecimal) -> i256 {
637    let (bigint_value, _) = decimal.as_bigint_and_exponent();
638    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
639
640    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
641    let fill_byte = if is_negative { 0xFF } else { 0x00 };
642
643    if bigint_bytes.len() > 32 {
644        bigint_bytes.truncate(32);
645    } else {
646        bigint_bytes.resize(32, fill_byte);
647    };
648
649    let mut array = [0u8; 32];
650    array.copy_from_slice(&bigint_bytes);
651
652    i256::from_le_bytes(array)
653}