// datafusion_remote_table/connection/mysql.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
13    Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
14    UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::sync::Arc;
29
30#[derive(Debug, Clone, With, Getters)]
31pub struct MysqlConnectionOptions {
32    pub(crate) host: String,
33    pub(crate) port: u16,
34    pub(crate) username: String,
35    pub(crate) password: String,
36    pub(crate) database: Option<String>,
37    pub(crate) pool_max_size: usize,
38    pub(crate) stream_chunk_size: usize,
39}
40
41impl MysqlConnectionOptions {
42    pub fn new(
43        host: impl Into<String>,
44        port: u16,
45        username: impl Into<String>,
46        password: impl Into<String>,
47    ) -> Self {
48        Self {
49            host: host.into(),
50            port,
51            username: username.into(),
52            password: password.into(),
53            database: None,
54            pool_max_size: 10,
55            stream_chunk_size: 2048,
56        }
57    }
58}
59
60#[derive(Debug)]
61pub struct MysqlPool {
62    pool: mysql_async::Pool,
63}
64
65pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
66    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
67        mysql_async::PoolConstraints::new(0, options.pool_max_size)
68            .expect("Failed to create pool constraints"),
69    );
70    let opts_builder = mysql_async::OptsBuilder::default()
71        .ip_or_hostname(options.host.clone())
72        .tcp_port(options.port)
73        .user(Some(options.username.clone()))
74        .pass(Some(options.password.clone()))
75        .db_name(options.database.clone())
76        .init(vec!["set time_zone='+00:00'".to_string()])
77        .pool_opts(pool_opts);
78    let pool = mysql_async::Pool::new(opts_builder);
79    Ok(MysqlPool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for MysqlPool {
84    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85        let conn = self.pool.get_conn().await.map_err(|e| {
86            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
87        })?;
88        Ok(Arc::new(MysqlConnection {
89            conn: Arc::new(Mutex::new(conn)),
90        }))
91    }
92}
93
94#[derive(Debug)]
95pub struct MysqlConnection {
96    conn: Arc<Mutex<mysql_async::Conn>>,
97}
98
99#[async_trait::async_trait]
100impl Connection for MysqlConnection {
101    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
102        let sql = RemoteDbType::Mysql.query_limit_1(sql);
103        let mut conn = self.conn.lock().await;
104        let conn = &mut *conn;
105        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
106            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
107        })?;
108        let Some(row) = row else {
109            return Err(DataFusionError::Execution(
110                "No rows returned to infer schema".to_string(),
111            ));
112        };
113        let remote_schema = Arc::new(build_remote_schema(&row)?);
114        Ok(remote_schema)
115    }
116
117    async fn query(
118        &self,
119        conn_options: &ConnectionOptions,
120        sql: &str,
121        table_schema: SchemaRef,
122        projection: Option<&Vec<usize>>,
123        unparsed_filters: &[String],
124        limit: Option<usize>,
125    ) -> DFResult<SendableRecordBatchStream> {
126        let projected_schema = project_schema(&table_schema, projection)?;
127
128        let sql = RemoteDbType::Mysql.rewrite_query(sql, unparsed_filters, limit);
129        debug!("[remote-table] executing mysql query: {sql}");
130
131        let projection = projection.cloned();
132        let chunk_size = conn_options.stream_chunk_size();
133        let conn = Arc::clone(&self.conn);
134        let stream = Box::pin(stream! {
135            let mut conn = conn.lock().await;
136            let mut query_iter = conn
137                .query_iter(sql.clone())
138                .await
139                .map_err(|e| {
140                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
141                })?;
142
143            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
144                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
145                })? else {
146                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
147                return;
148            };
149
150            let mut chunked_stream = stream.chunks(chunk_size).boxed();
151
152            while let Some(chunk) = chunked_stream.next().await {
153                let rows = chunk
154                    .into_iter()
155                    .collect::<Result<Vec<_>, _>>()
156                    .map_err(|e| {
157                        DataFusionError::Execution(format!(
158                            "Failed to collect rows from mysql due to {e}",
159                        ))
160                    })?;
161
162                yield Ok::<_, DataFusionError>(rows)
163            }
164        });
165
166        let stream = stream.map(move |rows| {
167            let rows = rows?;
168            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
169        });
170
171        Ok(Box::pin(RecordBatchStreamAdapter::new(
172            projected_schema,
173            stream,
174        )))
175    }
176}
177
178fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
179    let character_set = mysql_col.character_set();
180    let is_utf8_bin_character_set = character_set == 45;
181    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
182    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
183    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
184    let col_length = mysql_col.column_length();
185    match mysql_col.column_type() {
186        ColumnType::MYSQL_TYPE_TINY => {
187            if is_unsigned {
188                Ok(MysqlType::TinyIntUnsigned)
189            } else {
190                Ok(MysqlType::TinyInt)
191            }
192        }
193        ColumnType::MYSQL_TYPE_SHORT => {
194            if is_unsigned {
195                Ok(MysqlType::SmallIntUnsigned)
196            } else {
197                Ok(MysqlType::SmallInt)
198            }
199        }
200        ColumnType::MYSQL_TYPE_INT24 => {
201            if is_unsigned {
202                Ok(MysqlType::MediumIntUnsigned)
203            } else {
204                Ok(MysqlType::MediumInt)
205            }
206        }
207        ColumnType::MYSQL_TYPE_LONG => {
208            if is_unsigned {
209                Ok(MysqlType::IntegerUnsigned)
210            } else {
211                Ok(MysqlType::Integer)
212            }
213        }
214        ColumnType::MYSQL_TYPE_LONGLONG => {
215            if is_unsigned {
216                Ok(MysqlType::BigIntUnsigned)
217            } else {
218                Ok(MysqlType::BigInt)
219            }
220        }
221        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
222        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
223        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
224            let precision = (mysql_col.column_length() - 2) as u8;
225            let scale = mysql_col.decimals();
226            Ok(MysqlType::Decimal(precision, scale))
227        }
228        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
229        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
230        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
231        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
232        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
233        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
234        ColumnType::MYSQL_TYPE_STRING if is_binary => {
235            if is_utf8_bin_character_set {
236                Ok(MysqlType::Char)
237            } else {
238                Ok(MysqlType::Binary)
239            }
240        }
241        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
242        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
243            if is_utf8_bin_character_set {
244                Ok(MysqlType::Varchar)
245            } else {
246                Ok(MysqlType::Varbinary)
247            }
248        }
249        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
250        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
251        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
252            if is_utf8_bin_character_set {
253                Ok(MysqlType::Text(col_length))
254            } else {
255                Ok(MysqlType::Blob(col_length))
256            }
257        }
258        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
259        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
260        _ => Err(DataFusionError::NotImplemented(format!(
261            "Unsupported mysql type: {mysql_col:?}",
262        ))),
263    }
264}
265
266fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
267    let mut remote_fields = vec![];
268    for col in row.columns_ref() {
269        remote_fields.push(RemoteField::new(
270            col.name_str().to_string(),
271            RemoteType::Mysql(mysql_type_to_remote_type(col)?),
272            true,
273        ));
274    }
275    Ok(RemoteSchema::new(remote_fields))
276}
277
/// Appends column `$index` of `$row` to `$builder`.
///
/// The value is decoded as `$value_ty` and converted with `$convert`, which must return a
/// `Result` whose error converts with `?`.
///
/// A missing value (`get_opt` yields `None`) or a `Value::NULL` that cannot decode becomes
/// an Arrow null. Any other decode failure makes the *enclosing function* return an
/// `Execution` error. Panics if `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            );
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(value)) => typed_builder.append_value($convert(value)?),
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
308
309fn rows_to_batch(
310    rows: &[Row],
311    table_schema: &SchemaRef,
312    projection: Option<&Vec<usize>>,
313) -> DFResult<RecordBatch> {
314    let projected_schema = project_schema(table_schema, projection)?;
315    let mut array_builders = vec![];
316    for field in table_schema.fields() {
317        let builder = make_builder(field.data_type(), rows.len());
318        array_builders.push(builder);
319    }
320
321    for row in rows {
322        for (idx, field) in table_schema.fields.iter().enumerate() {
323            if !projections_contains(projection, idx) {
324                continue;
325            }
326            let builder = &mut array_builders[idx];
327            let col = row.columns_ref().get(idx);
328            match field.data_type() {
329                DataType::Int8 => {
330                    handle_primitive_type!(
331                        builder,
332                        field,
333                        col,
334                        Int8Builder,
335                        i8,
336                        row,
337                        idx,
338                        just_return
339                    );
340                }
341                DataType::Int16 => {
342                    handle_primitive_type!(
343                        builder,
344                        field,
345                        col,
346                        Int16Builder,
347                        i16,
348                        row,
349                        idx,
350                        just_return
351                    );
352                }
353                DataType::Int32 => {
354                    handle_primitive_type!(
355                        builder,
356                        field,
357                        col,
358                        Int32Builder,
359                        i32,
360                        row,
361                        idx,
362                        just_return
363                    );
364                }
365                DataType::Int64 => {
366                    handle_primitive_type!(
367                        builder,
368                        field,
369                        col,
370                        Int64Builder,
371                        i64,
372                        row,
373                        idx,
374                        just_return
375                    );
376                }
377                DataType::UInt8 => {
378                    handle_primitive_type!(
379                        builder,
380                        field,
381                        col,
382                        UInt8Builder,
383                        u8,
384                        row,
385                        idx,
386                        just_return
387                    );
388                }
389                DataType::UInt16 => {
390                    handle_primitive_type!(
391                        builder,
392                        field,
393                        col,
394                        UInt16Builder,
395                        u16,
396                        row,
397                        idx,
398                        just_return
399                    );
400                }
401                DataType::UInt32 => {
402                    handle_primitive_type!(
403                        builder,
404                        field,
405                        col,
406                        UInt32Builder,
407                        u32,
408                        row,
409                        idx,
410                        just_return
411                    );
412                }
413                DataType::UInt64 => {
414                    handle_primitive_type!(
415                        builder,
416                        field,
417                        col,
418                        UInt64Builder,
419                        u64,
420                        row,
421                        idx,
422                        just_return
423                    );
424                }
425                DataType::Float32 => {
426                    handle_primitive_type!(
427                        builder,
428                        field,
429                        col,
430                        Float32Builder,
431                        f32,
432                        row,
433                        idx,
434                        just_return
435                    );
436                }
437                DataType::Float64 => {
438                    handle_primitive_type!(
439                        builder,
440                        field,
441                        col,
442                        Float64Builder,
443                        f64,
444                        row,
445                        idx,
446                        just_return
447                    );
448                }
449                DataType::Decimal128(_precision, scale) => {
450                    handle_primitive_type!(
451                        builder,
452                        field,
453                        col,
454                        Decimal128Builder,
455                        BigDecimal,
456                        row,
457                        idx,
458                        |v: BigDecimal| {
459                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
460                                DataFusionError::Execution(format!(
461                                    "Failed to convert BigDecimal {v:?} to i128"
462                                ))
463                            })
464                        }
465                    );
466                }
467                DataType::Decimal256(_precision, _scale) => {
468                    handle_primitive_type!(
469                        builder,
470                        field,
471                        col,
472                        Decimal256Builder,
473                        BigDecimal,
474                        row,
475                        idx,
476                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
477                    );
478                }
479                DataType::Date32 => {
480                    handle_primitive_type!(
481                        builder,
482                        field,
483                        col,
484                        Date32Builder,
485                        chrono::NaiveDate,
486                        row,
487                        idx,
488                        |v: chrono::NaiveDate| {
489                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
490                        }
491                    );
492                }
493                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
494                    match tz_opt {
495                        None => {}
496                        Some(tz) => {
497                            if !tz.eq_ignore_ascii_case("utc") {
498                                return Err(DataFusionError::NotImplemented(format!(
499                                    "Unsupported data type {:?} for col: {:?}",
500                                    field.data_type(),
501                                    col
502                                )));
503                            }
504                        }
505                    }
506                    handle_primitive_type!(
507                        builder,
508                        field,
509                        col,
510                        TimestampMicrosecondBuilder,
511                        time::PrimitiveDateTime,
512                        row,
513                        idx,
514                        |v: time::PrimitiveDateTime| {
515                            let timestamp_micros =
516                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
517                            Ok::<_, DataFusionError>(timestamp_micros)
518                        }
519                    );
520                }
521                DataType::Time32(TimeUnit::Second) => {
522                    handle_primitive_type!(
523                        builder,
524                        field,
525                        col,
526                        Time32SecondBuilder,
527                        chrono::NaiveTime,
528                        row,
529                        idx,
530                        |v: chrono::NaiveTime| {
531                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
532                        }
533                    );
534                }
535                DataType::Time64(TimeUnit::Nanosecond) => {
536                    handle_primitive_type!(
537                        builder,
538                        field,
539                        col,
540                        Time64NanosecondBuilder,
541                        chrono::NaiveTime,
542                        row,
543                        idx,
544                        |v: chrono::NaiveTime| {
545                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
546                                + i64::from(v.nanosecond());
547                            Ok::<_, DataFusionError>(t)
548                        }
549                    );
550                }
551                DataType::Utf8 => {
552                    handle_primitive_type!(
553                        builder,
554                        field,
555                        col,
556                        StringBuilder,
557                        String,
558                        row,
559                        idx,
560                        just_return
561                    );
562                }
563                DataType::LargeUtf8 => {
564                    handle_primitive_type!(
565                        builder,
566                        field,
567                        col,
568                        LargeStringBuilder,
569                        String,
570                        row,
571                        idx,
572                        just_return
573                    );
574                }
575                DataType::Binary => {
576                    handle_primitive_type!(
577                        builder,
578                        field,
579                        col,
580                        BinaryBuilder,
581                        Vec<u8>,
582                        row,
583                        idx,
584                        just_return
585                    );
586                }
587                DataType::LargeBinary => {
588                    handle_primitive_type!(
589                        builder,
590                        field,
591                        col,
592                        LargeBinaryBuilder,
593                        Vec<u8>,
594                        row,
595                        idx,
596                        just_return
597                    );
598                }
599                _ => {
600                    return Err(DataFusionError::NotImplemented(format!(
601                        "Unsupported data type {:?} for col: {:?}",
602                        field.data_type(),
603                        col
604                    )));
605                }
606            }
607        }
608    }
609    let projected_columns = array_builders
610        .into_iter()
611        .enumerate()
612        .filter(|(idx, _)| projections_contains(projection, *idx))
613        .map(|(_, mut builder)| builder.finish())
614        .collect::<Vec<ArrayRef>>();
615    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
616    Ok(RecordBatch::try_new_with_options(
617        projected_schema,
618        projected_columns,
619        &options,
620    )?)
621}
622
623fn to_decimal_256(decimal: &BigDecimal) -> i256 {
624    let (bigint_value, _) = decimal.as_bigint_and_exponent();
625    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
626
627    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
628    let fill_byte = if is_negative { 0xFF } else { 0x00 };
629
630    if bigint_bytes.len() > 32 {
631        bigint_bytes.truncate(32);
632    } else {
633        bigint_bytes.resize(32, fill_byte);
634    };
635
636    let mut array = [0u8; 32];
637    array.copy_from_slice(&bigint_bytes);
638
639    i256::from_le_bytes(array)
640}