datafusion_remote_table/connection/
mysql.rs

1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::big_decimal_to_i128;
3use crate::{
4    Connection, ConnectionOptions, DFResult, Literalize, MysqlConnectionOptions, MysqlType, Pool,
5    PoolState, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType,
6};
7use async_stream::stream;
8use bigdecimal::{BigDecimal, num_bigint};
9use chrono::Timelike;
10use datafusion::arrow::array::{
11    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
12    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
13    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
14    Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
15    UInt32Builder, UInt64Builder, make_builder,
16};
17use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
18use datafusion::common::{DataFusionError, project_schema};
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use futures::StreamExt;
22use futures::lock::Mutex;
23use log::debug;
24use mysql_async::consts::{ColumnFlags, ColumnType};
25use mysql_async::prelude::Queryable;
26use mysql_async::{Column, FromValueError, Row, Value};
27use std::any::Any;
28use std::sync::Arc;
29
30#[derive(Debug)]
31pub struct MysqlPool {
32    pool: mysql_async::Pool,
33}
34
35pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
36    let pool_opts = mysql_async::PoolOpts::new()
37        .with_constraints(
38            mysql_async::PoolConstraints::new(options.pool_min_idle, options.pool_max_size)
39                .expect("Failed to create pool constraints"),
40        )
41        .with_inactive_connection_ttl(options.pool_idle_timeout)
42        .with_ttl_check_interval(options.pool_ttl_check_interval);
43    let opts_builder = mysql_async::OptsBuilder::default()
44        .ip_or_hostname(options.host.clone())
45        .tcp_port(options.port)
46        .user(Some(options.username.clone()))
47        .pass(Some(options.password.clone()))
48        .db_name(options.database.clone())
49        .init(vec!["set time_zone='+00:00'".to_string()])
50        .pool_opts(pool_opts);
51    let pool = mysql_async::Pool::new(opts_builder);
52    Ok(MysqlPool { pool })
53}
54
55#[async_trait::async_trait]
56impl Pool for MysqlPool {
57    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
58        let conn = self.pool.get_conn().await.map_err(|e| {
59            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
60        })?;
61        Ok(Arc::new(MysqlConnection {
62            conn: Arc::new(Mutex::new(conn)),
63        }))
64    }
65
66    async fn state(&self) -> DFResult<PoolState> {
67        use std::sync::atomic::Ordering;
68        let metrics = self.pool.metrics();
69        Ok(PoolState {
70            connections: metrics.connection_count.load(Ordering::SeqCst),
71            idle_connections: metrics.connections_in_pool.load(Ordering::SeqCst),
72        })
73    }
74}
75
76#[derive(Debug)]
77pub struct MysqlConnection {
78    conn: Arc<Mutex<mysql_async::Conn>>,
79}
80
81#[async_trait::async_trait]
82impl Connection for MysqlConnection {
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
88        let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
89        let mut conn = self.conn.lock().await;
90        let conn = &mut *conn;
91        let stmt = conn.prep(&sql).await.map_err(|e| {
92            DataFusionError::Plan(format!("Failed to prepare query {sql} on mysql: {e:?}"))
93        })?;
94        let remote_schema = Arc::new(build_remote_schema(&stmt)?);
95        Ok(remote_schema)
96    }
97
98    async fn query(
99        &self,
100        conn_options: &ConnectionOptions,
101        source: &RemoteSource,
102        table_schema: SchemaRef,
103        projection: Option<&Vec<usize>>,
104        unparsed_filters: &[String],
105        limit: Option<usize>,
106    ) -> DFResult<SendableRecordBatchStream> {
107        let projected_schema = project_schema(&table_schema, projection)?;
108
109        let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
110        debug!("[remote-table] executing mysql query: {sql}");
111
112        let projection = projection.cloned();
113        let chunk_size = conn_options.stream_chunk_size();
114        let conn = Arc::clone(&self.conn);
115        let stream = Box::pin(stream! {
116            let mut conn = conn.lock().await;
117            let mut query_iter = conn
118                .query_iter(sql.clone())
119                .await
120                .map_err(|e| {
121                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
122                })?;
123
124            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
125                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
126                })? else {
127                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
128                return;
129            };
130
131            let mut chunked_stream = stream.chunks(chunk_size).boxed();
132
133            while let Some(chunk) = chunked_stream.next().await {
134                let rows = chunk
135                    .into_iter()
136                    .collect::<Result<Vec<_>, _>>()
137                    .map_err(|e| {
138                        DataFusionError::Execution(format!(
139                            "Failed to collect rows from mysql due to {e}",
140                        ))
141                    })?;
142
143                yield Ok::<_, DataFusionError>(rows)
144            }
145        });
146
147        let stream = stream.map(move |rows| {
148            let rows = rows?;
149            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
150        });
151
152        Ok(Box::pin(RecordBatchStreamAdapter::new(
153            projected_schema,
154            stream,
155        )))
156    }
157
158    async fn insert(
159        &self,
160        _conn_options: &ConnectionOptions,
161        _literalizer: Arc<dyn Literalize>,
162        _table: &[String],
163        _remote_schema: RemoteSchemaRef,
164        _input: SendableRecordBatchStream,
165    ) -> DFResult<usize> {
166        Err(DataFusionError::Execution(
167            "Insert operation is not supported for mysql".to_string(),
168        ))
169    }
170}
171
172fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
173    let character_set = mysql_col.character_set();
174    let is_utf8_bin_character_set = character_set == 45;
175    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
176    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
177    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
178    let col_length = mysql_col.column_length();
179    match mysql_col.column_type() {
180        ColumnType::MYSQL_TYPE_TINY => {
181            if is_unsigned {
182                Ok(MysqlType::TinyIntUnsigned)
183            } else {
184                Ok(MysqlType::TinyInt)
185            }
186        }
187        ColumnType::MYSQL_TYPE_SHORT => {
188            if is_unsigned {
189                Ok(MysqlType::SmallIntUnsigned)
190            } else {
191                Ok(MysqlType::SmallInt)
192            }
193        }
194        ColumnType::MYSQL_TYPE_INT24 => {
195            if is_unsigned {
196                Ok(MysqlType::MediumIntUnsigned)
197            } else {
198                Ok(MysqlType::MediumInt)
199            }
200        }
201        ColumnType::MYSQL_TYPE_LONG => {
202            if is_unsigned {
203                Ok(MysqlType::IntegerUnsigned)
204            } else {
205                Ok(MysqlType::Integer)
206            }
207        }
208        ColumnType::MYSQL_TYPE_LONGLONG => {
209            if is_unsigned {
210                Ok(MysqlType::BigIntUnsigned)
211            } else {
212                Ok(MysqlType::BigInt)
213            }
214        }
215        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
216        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
217        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
218            let precision = (mysql_col.column_length() - 2) as u8;
219            let scale = mysql_col.decimals();
220            Ok(MysqlType::Decimal(precision, scale))
221        }
222        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
223        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
224        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
225        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
226        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
227        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
228        ColumnType::MYSQL_TYPE_STRING if is_binary => {
229            if is_utf8_bin_character_set {
230                Ok(MysqlType::Char)
231            } else {
232                Ok(MysqlType::Binary)
233            }
234        }
235        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
236        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
237            if is_utf8_bin_character_set {
238                Ok(MysqlType::Varchar)
239            } else {
240                Ok(MysqlType::Varbinary)
241            }
242        }
243        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
244        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
245        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
246            if is_utf8_bin_character_set {
247                Ok(MysqlType::Text(col_length))
248            } else {
249                Ok(MysqlType::Blob(col_length))
250            }
251        }
252        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
253        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
254        _ => Err(DataFusionError::NotImplemented(format!(
255            "Unsupported mysql type: {mysql_col:?}",
256        ))),
257    }
258}
259
260fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
261    let mut remote_fields = vec![];
262    for col in stmt.columns() {
263        remote_fields.push(RemoteField::new(
264            col.name_str().to_string(),
265            RemoteType::Mysql(mysql_type_to_remote_type(col)?),
266            true,
267        ));
268    }
269    Ok(RemoteSchema::new(remote_fields))
270}
271
/// Appends the value at `$index` of `$row` to `$builder`.
///
/// The builder is downcast to `$builder_ty`, the cell is read as `$value_ty`, and
/// `$convert` (a fallible conversion, applied with `?`) turns it into the value to
/// append. A missing cell or SQL NULL appends a null. Any other read failure returns
/// a `DataFusionError::Execution` from the enclosing function.
///
/// Panics if the builder is not a `$builder_ty`; `$field` and `$col` are only used
/// in the panic and error messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            // Both "no such cell" and "cell is SQL NULL" become an Arrow null.
            None | Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
302
303fn rows_to_batch(
304    rows: &[Row],
305    table_schema: &SchemaRef,
306    projection: Option<&Vec<usize>>,
307) -> DFResult<RecordBatch> {
308    let projected_schema = project_schema(table_schema, projection)?;
309    let mut array_builders = vec![];
310    for field in table_schema.fields() {
311        let builder = make_builder(field.data_type(), rows.len());
312        array_builders.push(builder);
313    }
314
315    for row in rows {
316        for (idx, field) in table_schema.fields.iter().enumerate() {
317            if !projections_contains(projection, idx) {
318                continue;
319            }
320            let builder = &mut array_builders[idx];
321            let col = row.columns_ref().get(idx);
322            match field.data_type() {
323                DataType::Int8 => {
324                    handle_primitive_type!(
325                        builder,
326                        field,
327                        col,
328                        Int8Builder,
329                        i8,
330                        row,
331                        idx,
332                        just_return
333                    );
334                }
335                DataType::Int16 => {
336                    handle_primitive_type!(
337                        builder,
338                        field,
339                        col,
340                        Int16Builder,
341                        i16,
342                        row,
343                        idx,
344                        just_return
345                    );
346                }
347                DataType::Int32 => {
348                    handle_primitive_type!(
349                        builder,
350                        field,
351                        col,
352                        Int32Builder,
353                        i32,
354                        row,
355                        idx,
356                        just_return
357                    );
358                }
359                DataType::Int64 => {
360                    handle_primitive_type!(
361                        builder,
362                        field,
363                        col,
364                        Int64Builder,
365                        i64,
366                        row,
367                        idx,
368                        just_return
369                    );
370                }
371                DataType::UInt8 => {
372                    handle_primitive_type!(
373                        builder,
374                        field,
375                        col,
376                        UInt8Builder,
377                        u8,
378                        row,
379                        idx,
380                        just_return
381                    );
382                }
383                DataType::UInt16 => {
384                    handle_primitive_type!(
385                        builder,
386                        field,
387                        col,
388                        UInt16Builder,
389                        u16,
390                        row,
391                        idx,
392                        just_return
393                    );
394                }
395                DataType::UInt32 => {
396                    handle_primitive_type!(
397                        builder,
398                        field,
399                        col,
400                        UInt32Builder,
401                        u32,
402                        row,
403                        idx,
404                        just_return
405                    );
406                }
407                DataType::UInt64 => {
408                    handle_primitive_type!(
409                        builder,
410                        field,
411                        col,
412                        UInt64Builder,
413                        u64,
414                        row,
415                        idx,
416                        just_return
417                    );
418                }
419                DataType::Float32 => {
420                    handle_primitive_type!(
421                        builder,
422                        field,
423                        col,
424                        Float32Builder,
425                        f32,
426                        row,
427                        idx,
428                        just_return
429                    );
430                }
431                DataType::Float64 => {
432                    handle_primitive_type!(
433                        builder,
434                        field,
435                        col,
436                        Float64Builder,
437                        f64,
438                        row,
439                        idx,
440                        just_return
441                    );
442                }
443                DataType::Decimal128(_precision, scale) => {
444                    handle_primitive_type!(
445                        builder,
446                        field,
447                        col,
448                        Decimal128Builder,
449                        BigDecimal,
450                        row,
451                        idx,
452                        |v: BigDecimal| { big_decimal_to_i128(&v, Some(*scale as i32)) }
453                    );
454                }
455                DataType::Decimal256(_precision, _scale) => {
456                    handle_primitive_type!(
457                        builder,
458                        field,
459                        col,
460                        Decimal256Builder,
461                        BigDecimal,
462                        row,
463                        idx,
464                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
465                    );
466                }
467                DataType::Date32 => {
468                    handle_primitive_type!(
469                        builder,
470                        field,
471                        col,
472                        Date32Builder,
473                        chrono::NaiveDate,
474                        row,
475                        idx,
476                        |v: chrono::NaiveDate| {
477                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
478                        }
479                    );
480                }
481                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
482                    match tz_opt {
483                        None => {}
484                        Some(tz) => {
485                            if !tz.eq_ignore_ascii_case("utc") {
486                                return Err(DataFusionError::NotImplemented(format!(
487                                    "Unsupported data type {:?} for col: {:?}",
488                                    field.data_type(),
489                                    col
490                                )));
491                            }
492                        }
493                    }
494                    handle_primitive_type!(
495                        builder,
496                        field,
497                        col,
498                        TimestampMicrosecondBuilder,
499                        time::PrimitiveDateTime,
500                        row,
501                        idx,
502                        |v: time::PrimitiveDateTime| {
503                            let timestamp_micros =
504                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
505                            Ok::<_, DataFusionError>(timestamp_micros)
506                        }
507                    );
508                }
509                DataType::Time32(TimeUnit::Second) => {
510                    handle_primitive_type!(
511                        builder,
512                        field,
513                        col,
514                        Time32SecondBuilder,
515                        chrono::NaiveTime,
516                        row,
517                        idx,
518                        |v: chrono::NaiveTime| {
519                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
520                        }
521                    );
522                }
523                DataType::Time64(TimeUnit::Nanosecond) => {
524                    handle_primitive_type!(
525                        builder,
526                        field,
527                        col,
528                        Time64NanosecondBuilder,
529                        chrono::NaiveTime,
530                        row,
531                        idx,
532                        |v: chrono::NaiveTime| {
533                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
534                                + i64::from(v.nanosecond());
535                            Ok::<_, DataFusionError>(t)
536                        }
537                    );
538                }
539                DataType::Utf8 => {
540                    handle_primitive_type!(
541                        builder,
542                        field,
543                        col,
544                        StringBuilder,
545                        String,
546                        row,
547                        idx,
548                        just_return
549                    );
550                }
551                DataType::LargeUtf8 => {
552                    handle_primitive_type!(
553                        builder,
554                        field,
555                        col,
556                        LargeStringBuilder,
557                        String,
558                        row,
559                        idx,
560                        just_return
561                    );
562                }
563                DataType::Binary => {
564                    handle_primitive_type!(
565                        builder,
566                        field,
567                        col,
568                        BinaryBuilder,
569                        Vec<u8>,
570                        row,
571                        idx,
572                        just_return
573                    );
574                }
575                DataType::LargeBinary => {
576                    handle_primitive_type!(
577                        builder,
578                        field,
579                        col,
580                        LargeBinaryBuilder,
581                        Vec<u8>,
582                        row,
583                        idx,
584                        just_return
585                    );
586                }
587                _ => {
588                    return Err(DataFusionError::NotImplemented(format!(
589                        "Unsupported data type {:?} for col: {:?}",
590                        field.data_type(),
591                        col
592                    )));
593                }
594            }
595        }
596    }
597    let projected_columns = array_builders
598        .into_iter()
599        .enumerate()
600        .filter(|(idx, _)| projections_contains(projection, *idx))
601        .map(|(_, mut builder)| builder.finish())
602        .collect::<Vec<ArrayRef>>();
603    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
604    Ok(RecordBatch::try_new_with_options(
605        projected_schema,
606        projected_columns,
607        &options,
608    )?)
609}
610
611fn to_decimal_256(decimal: &BigDecimal) -> i256 {
612    let (bigint_value, _) = decimal.as_bigint_and_exponent();
613    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
614
615    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
616    let fill_byte = if is_negative { 0xFF } else { 0x00 };
617
618    if bigint_bytes.len() > 32 {
619        bigint_bytes.truncate(32);
620    } else {
621        bigint_bytes.resize(32, fill_byte);
622    };
623
624    let mut array = [0u8; 32];
625    array.copy_from_slice(&bigint_bytes);
626
627    i256::from_le_bytes(array)
628}