datafusion_remote_table/connection/mysql.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12    LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
13    Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
14    UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::any::Any;
29use std::sync::Arc;
30
/// Connection settings for a MySQL server.
///
/// `MysqlConnectionOptions::new` fills in the defaults for `database`,
/// `pool_max_size` and `stream_chunk_size`.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// Server host name or IP address.
    pub(crate) host: String,
    /// Server TCP port.
    pub(crate) port: u16,
    /// User name to authenticate as.
    pub(crate) username: String,
    /// Password for `username`.
    pub(crate) password: String,
    /// Optional default database selected for each session.
    pub(crate) database: Option<String>,
    /// Upper bound on the number of connections the pool may open.
    pub(crate) pool_max_size: usize,
    /// Number of rows gathered into each streamed chunk (one record batch per chunk).
    pub(crate) stream_chunk_size: usize,
}
41
42impl MysqlConnectionOptions {
43    pub fn new(
44        host: impl Into<String>,
45        port: u16,
46        username: impl Into<String>,
47        password: impl Into<String>,
48    ) -> Self {
49        Self {
50            host: host.into(),
51            port,
52            username: username.into(),
53            password: password.into(),
54            database: None,
55            pool_max_size: 10,
56            stream_chunk_size: 2048,
57        }
58    }
59}
60
61impl From<MysqlConnectionOptions> for ConnectionOptions {
62    fn from(options: MysqlConnectionOptions) -> Self {
63        ConnectionOptions::Mysql(options)
64    }
65}
66
/// MySQL connection pool built by `connect_mysql`; each `get` hands out a
/// `MysqlConnection` wrapping one pooled connection.
#[derive(Debug)]
pub struct MysqlPool {
    pool: mysql_async::Pool,
}
71
72pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
73    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
74        mysql_async::PoolConstraints::new(0, options.pool_max_size)
75            .expect("Failed to create pool constraints"),
76    );
77    let opts_builder = mysql_async::OptsBuilder::default()
78        .ip_or_hostname(options.host.clone())
79        .tcp_port(options.port)
80        .user(Some(options.username.clone()))
81        .pass(Some(options.password.clone()))
82        .db_name(options.database.clone())
83        .init(vec!["set time_zone='+00:00'".to_string()])
84        .pool_opts(pool_opts);
85    let pool = mysql_async::Pool::new(opts_builder);
86    Ok(MysqlPool { pool })
87}
88
89#[async_trait::async_trait]
90impl Pool for MysqlPool {
91    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
92        let conn = self.pool.get_conn().await.map_err(|e| {
93            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
94        })?;
95        Ok(Arc::new(MysqlConnection {
96            conn: Arc::new(Mutex::new(conn)),
97        }))
98    }
99}
100
/// One pooled MySQL connection.
///
/// The connection sits behind an `Arc` + async `Mutex`. This lets the stream
/// returned by `query` own a handle and hold the lock while it reads rows.
#[derive(Debug)]
pub struct MysqlConnection {
    conn: Arc<Mutex<mysql_async::Conn>>,
}
105
106#[async_trait::async_trait]
107impl Connection for MysqlConnection {
108    fn as_any(&self) -> &dyn Any {
109        self
110    }
111
112    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
113        let sql = RemoteDbType::Mysql.query_limit_1(sql);
114        let mut conn = self.conn.lock().await;
115        let conn = &mut *conn;
116        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
117            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
118        })?;
119        let Some(row) = row else {
120            return Err(DataFusionError::Execution(
121                "No rows returned to infer schema".to_string(),
122            ));
123        };
124        let remote_schema = Arc::new(build_remote_schema(&row)?);
125        Ok(remote_schema)
126    }
127
128    async fn query(
129        &self,
130        conn_options: &ConnectionOptions,
131        sql: &str,
132        table_schema: SchemaRef,
133        projection: Option<&Vec<usize>>,
134        unparsed_filters: &[String],
135        limit: Option<usize>,
136    ) -> DFResult<SendableRecordBatchStream> {
137        let projected_schema = project_schema(&table_schema, projection)?;
138
139        let sql = RemoteDbType::Mysql.rewrite_query(sql, unparsed_filters, limit);
140        debug!("[remote-table] executing mysql query: {sql}");
141
142        let projection = projection.cloned();
143        let chunk_size = conn_options.stream_chunk_size();
144        let conn = Arc::clone(&self.conn);
145        let stream = Box::pin(stream! {
146            let mut conn = conn.lock().await;
147            let mut query_iter = conn
148                .query_iter(sql.clone())
149                .await
150                .map_err(|e| {
151                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
152                })?;
153
154            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
155                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
156                })? else {
157                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
158                return;
159            };
160
161            let mut chunked_stream = stream.chunks(chunk_size).boxed();
162
163            while let Some(chunk) = chunked_stream.next().await {
164                let rows = chunk
165                    .into_iter()
166                    .collect::<Result<Vec<_>, _>>()
167                    .map_err(|e| {
168                        DataFusionError::Execution(format!(
169                            "Failed to collect rows from mysql due to {e}",
170                        ))
171                    })?;
172
173                yield Ok::<_, DataFusionError>(rows)
174            }
175        });
176
177        let stream = stream.map(move |rows| {
178            let rows = rows?;
179            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
180        });
181
182        Ok(Box::pin(RecordBatchStreamAdapter::new(
183            projected_schema,
184            stream,
185        )))
186    }
187}
188
189fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
190    let character_set = mysql_col.character_set();
191    let is_utf8_bin_character_set = character_set == 45;
192    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
193    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
194    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
195    let col_length = mysql_col.column_length();
196    match mysql_col.column_type() {
197        ColumnType::MYSQL_TYPE_TINY => {
198            if is_unsigned {
199                Ok(MysqlType::TinyIntUnsigned)
200            } else {
201                Ok(MysqlType::TinyInt)
202            }
203        }
204        ColumnType::MYSQL_TYPE_SHORT => {
205            if is_unsigned {
206                Ok(MysqlType::SmallIntUnsigned)
207            } else {
208                Ok(MysqlType::SmallInt)
209            }
210        }
211        ColumnType::MYSQL_TYPE_INT24 => {
212            if is_unsigned {
213                Ok(MysqlType::MediumIntUnsigned)
214            } else {
215                Ok(MysqlType::MediumInt)
216            }
217        }
218        ColumnType::MYSQL_TYPE_LONG => {
219            if is_unsigned {
220                Ok(MysqlType::IntegerUnsigned)
221            } else {
222                Ok(MysqlType::Integer)
223            }
224        }
225        ColumnType::MYSQL_TYPE_LONGLONG => {
226            if is_unsigned {
227                Ok(MysqlType::BigIntUnsigned)
228            } else {
229                Ok(MysqlType::BigInt)
230            }
231        }
232        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
233        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
234        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
235            let precision = (mysql_col.column_length() - 2) as u8;
236            let scale = mysql_col.decimals();
237            Ok(MysqlType::Decimal(precision, scale))
238        }
239        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
240        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
241        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
242        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
243        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
244        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
245        ColumnType::MYSQL_TYPE_STRING if is_binary => {
246            if is_utf8_bin_character_set {
247                Ok(MysqlType::Char)
248            } else {
249                Ok(MysqlType::Binary)
250            }
251        }
252        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
253        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
254            if is_utf8_bin_character_set {
255                Ok(MysqlType::Varchar)
256            } else {
257                Ok(MysqlType::Varbinary)
258            }
259        }
260        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
261        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
262        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
263            if is_utf8_bin_character_set {
264                Ok(MysqlType::Text(col_length))
265            } else {
266                Ok(MysqlType::Blob(col_length))
267            }
268        }
269        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
270        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
271        _ => Err(DataFusionError::NotImplemented(format!(
272            "Unsupported mysql type: {mysql_col:?}",
273        ))),
274    }
275}
276
277fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
278    let mut remote_fields = vec![];
279    for col in row.columns_ref() {
280        remote_fields.push(RemoteField::new(
281            col.name_str().to_string(),
282            RemoteType::Mysql(mysql_type_to_remote_type(col)?),
283            true,
284        ));
285    }
286    Ok(RemoteSchema::new(remote_fields))
287}
288
/// Reads column `$index` of `$row` as `$value_ty` and appends it to `$builder`.
///
/// * `$builder` is downcast to `$builder_ty`. If the downcast fails, the macro
///   panics with a message naming `$field` and `$col`.
/// * `$convert` turns the decoded `$value_ty` into the value appended to the
///   builder. It returns a `Result`, and its error is propagated with `?`.
/// * Two cases append a null: `get_opt` returns `None`, or the stored value is
///   SQL `NULL` and cannot be decoded as `$value_ty`.
/// * Any other decode error makes the *enclosing function* return
///   `Err(DataFusionError::Execution(..))`. The macro must therefore be expanded
///   inside a function that returns a `Result` with `DataFusionError` errors.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        // `None` here presumably means the index is absent or the value was
        // already taken from the row — confirm against mysql_async `Row::get_opt`.
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            // SQL NULL cannot be decoded into a non-Option type; treat it as null.
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                // Returns from the function that expanded this macro.
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
319
320fn rows_to_batch(
321    rows: &[Row],
322    table_schema: &SchemaRef,
323    projection: Option<&Vec<usize>>,
324) -> DFResult<RecordBatch> {
325    let projected_schema = project_schema(table_schema, projection)?;
326    let mut array_builders = vec![];
327    for field in table_schema.fields() {
328        let builder = make_builder(field.data_type(), rows.len());
329        array_builders.push(builder);
330    }
331
332    for row in rows {
333        for (idx, field) in table_schema.fields.iter().enumerate() {
334            if !projections_contains(projection, idx) {
335                continue;
336            }
337            let builder = &mut array_builders[idx];
338            let col = row.columns_ref().get(idx);
339            match field.data_type() {
340                DataType::Int8 => {
341                    handle_primitive_type!(
342                        builder,
343                        field,
344                        col,
345                        Int8Builder,
346                        i8,
347                        row,
348                        idx,
349                        just_return
350                    );
351                }
352                DataType::Int16 => {
353                    handle_primitive_type!(
354                        builder,
355                        field,
356                        col,
357                        Int16Builder,
358                        i16,
359                        row,
360                        idx,
361                        just_return
362                    );
363                }
364                DataType::Int32 => {
365                    handle_primitive_type!(
366                        builder,
367                        field,
368                        col,
369                        Int32Builder,
370                        i32,
371                        row,
372                        idx,
373                        just_return
374                    );
375                }
376                DataType::Int64 => {
377                    handle_primitive_type!(
378                        builder,
379                        field,
380                        col,
381                        Int64Builder,
382                        i64,
383                        row,
384                        idx,
385                        just_return
386                    );
387                }
388                DataType::UInt8 => {
389                    handle_primitive_type!(
390                        builder,
391                        field,
392                        col,
393                        UInt8Builder,
394                        u8,
395                        row,
396                        idx,
397                        just_return
398                    );
399                }
400                DataType::UInt16 => {
401                    handle_primitive_type!(
402                        builder,
403                        field,
404                        col,
405                        UInt16Builder,
406                        u16,
407                        row,
408                        idx,
409                        just_return
410                    );
411                }
412                DataType::UInt32 => {
413                    handle_primitive_type!(
414                        builder,
415                        field,
416                        col,
417                        UInt32Builder,
418                        u32,
419                        row,
420                        idx,
421                        just_return
422                    );
423                }
424                DataType::UInt64 => {
425                    handle_primitive_type!(
426                        builder,
427                        field,
428                        col,
429                        UInt64Builder,
430                        u64,
431                        row,
432                        idx,
433                        just_return
434                    );
435                }
436                DataType::Float32 => {
437                    handle_primitive_type!(
438                        builder,
439                        field,
440                        col,
441                        Float32Builder,
442                        f32,
443                        row,
444                        idx,
445                        just_return
446                    );
447                }
448                DataType::Float64 => {
449                    handle_primitive_type!(
450                        builder,
451                        field,
452                        col,
453                        Float64Builder,
454                        f64,
455                        row,
456                        idx,
457                        just_return
458                    );
459                }
460                DataType::Decimal128(_precision, scale) => {
461                    handle_primitive_type!(
462                        builder,
463                        field,
464                        col,
465                        Decimal128Builder,
466                        BigDecimal,
467                        row,
468                        idx,
469                        |v: BigDecimal| {
470                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
471                                DataFusionError::Execution(format!(
472                                    "Failed to convert BigDecimal {v:?} to i128"
473                                ))
474                            })
475                        }
476                    );
477                }
478                DataType::Decimal256(_precision, _scale) => {
479                    handle_primitive_type!(
480                        builder,
481                        field,
482                        col,
483                        Decimal256Builder,
484                        BigDecimal,
485                        row,
486                        idx,
487                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
488                    );
489                }
490                DataType::Date32 => {
491                    handle_primitive_type!(
492                        builder,
493                        field,
494                        col,
495                        Date32Builder,
496                        chrono::NaiveDate,
497                        row,
498                        idx,
499                        |v: chrono::NaiveDate| {
500                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
501                        }
502                    );
503                }
504                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
505                    match tz_opt {
506                        None => {}
507                        Some(tz) => {
508                            if !tz.eq_ignore_ascii_case("utc") {
509                                return Err(DataFusionError::NotImplemented(format!(
510                                    "Unsupported data type {:?} for col: {:?}",
511                                    field.data_type(),
512                                    col
513                                )));
514                            }
515                        }
516                    }
517                    handle_primitive_type!(
518                        builder,
519                        field,
520                        col,
521                        TimestampMicrosecondBuilder,
522                        time::PrimitiveDateTime,
523                        row,
524                        idx,
525                        |v: time::PrimitiveDateTime| {
526                            let timestamp_micros =
527                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
528                            Ok::<_, DataFusionError>(timestamp_micros)
529                        }
530                    );
531                }
532                DataType::Time32(TimeUnit::Second) => {
533                    handle_primitive_type!(
534                        builder,
535                        field,
536                        col,
537                        Time32SecondBuilder,
538                        chrono::NaiveTime,
539                        row,
540                        idx,
541                        |v: chrono::NaiveTime| {
542                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
543                        }
544                    );
545                }
546                DataType::Time64(TimeUnit::Nanosecond) => {
547                    handle_primitive_type!(
548                        builder,
549                        field,
550                        col,
551                        Time64NanosecondBuilder,
552                        chrono::NaiveTime,
553                        row,
554                        idx,
555                        |v: chrono::NaiveTime| {
556                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
557                                + i64::from(v.nanosecond());
558                            Ok::<_, DataFusionError>(t)
559                        }
560                    );
561                }
562                DataType::Utf8 => {
563                    handle_primitive_type!(
564                        builder,
565                        field,
566                        col,
567                        StringBuilder,
568                        String,
569                        row,
570                        idx,
571                        just_return
572                    );
573                }
574                DataType::LargeUtf8 => {
575                    handle_primitive_type!(
576                        builder,
577                        field,
578                        col,
579                        LargeStringBuilder,
580                        String,
581                        row,
582                        idx,
583                        just_return
584                    );
585                }
586                DataType::Binary => {
587                    handle_primitive_type!(
588                        builder,
589                        field,
590                        col,
591                        BinaryBuilder,
592                        Vec<u8>,
593                        row,
594                        idx,
595                        just_return
596                    );
597                }
598                DataType::LargeBinary => {
599                    handle_primitive_type!(
600                        builder,
601                        field,
602                        col,
603                        LargeBinaryBuilder,
604                        Vec<u8>,
605                        row,
606                        idx,
607                        just_return
608                    );
609                }
610                _ => {
611                    return Err(DataFusionError::NotImplemented(format!(
612                        "Unsupported data type {:?} for col: {:?}",
613                        field.data_type(),
614                        col
615                    )));
616                }
617            }
618        }
619    }
620    let projected_columns = array_builders
621        .into_iter()
622        .enumerate()
623        .filter(|(idx, _)| projections_contains(projection, *idx))
624        .map(|(_, mut builder)| builder.finish())
625        .collect::<Vec<ArrayRef>>();
626    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
627    Ok(RecordBatch::try_new_with_options(
628        projected_schema,
629        projected_columns,
630        &options,
631    )?)
632}
633
634fn to_decimal_256(decimal: &BigDecimal) -> i256 {
635    let (bigint_value, _) = decimal.as_bigint_and_exponent();
636    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
637
638    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
639    let fill_byte = if is_negative { 0xFF } else { 0x00 };
640
641    if bigint_bytes.len() > 32 {
642        bigint_bytes.truncate(32);
643    } else {
644        bigint_bytes.resize(32, fill_byte);
645    };
646
647    let mut array = [0u8; 32];
648    array.copy_from_slice(&bigint_bytes);
649
650    i256::from_le_bytes(array)
651}