datafusion_remote_table/connection/
mysql.rs

1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12    LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
13    TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt8Builder, UInt16Builder,
14    UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::limit::LimitStream;
20use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
21use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
22use derive_getters::Getters;
23use derive_with::With;
24use futures::StreamExt;
25use futures::lock::Mutex;
26use mysql_async::consts::{ColumnFlags, ColumnType};
27use mysql_async::prelude::Queryable;
28use mysql_async::{Column, FromValueError, Row, Value};
29use std::sync::Arc;
30
31#[derive(Debug, Clone, With, Getters)]
32pub struct MysqlConnectionOptions {
33    pub(crate) host: String,
34    pub(crate) port: u16,
35    pub(crate) username: String,
36    pub(crate) password: String,
37    pub(crate) database: Option<String>,
38    pub(crate) pool_max_size: usize,
39    pub(crate) stream_chunk_size: usize,
40}
41
42impl MysqlConnectionOptions {
43    pub fn new(
44        host: impl Into<String>,
45        port: u16,
46        username: impl Into<String>,
47        password: impl Into<String>,
48    ) -> Self {
49        Self {
50            host: host.into(),
51            port,
52            username: username.into(),
53            password: password.into(),
54            database: None,
55            pool_max_size: 10,
56            stream_chunk_size: 2048,
57        }
58    }
59}
60
61#[derive(Debug)]
62pub struct MysqlPool {
63    pool: mysql_async::Pool,
64}
65
66pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
67    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
68        mysql_async::PoolConstraints::new(0, options.pool_max_size)
69            .expect("Failed to create pool constraints"),
70    );
71    let opts_builder = mysql_async::OptsBuilder::default()
72        .ip_or_hostname(options.host.clone())
73        .tcp_port(options.port)
74        .user(Some(options.username.clone()))
75        .pass(Some(options.password.clone()))
76        .db_name(options.database.clone())
77        .pool_opts(pool_opts);
78    let pool = mysql_async::Pool::new(opts_builder);
79    Ok(MysqlPool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for MysqlPool {
84    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85        let conn = self.pool.get_conn().await.map_err(|e| {
86            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
87        })?;
88        Ok(Arc::new(MysqlConnection {
89            conn: Arc::new(Mutex::new(conn)),
90        }))
91    }
92}
93
94#[derive(Debug)]
95pub struct MysqlConnection {
96    conn: Arc<Mutex<mysql_async::Conn>>,
97}
98
99#[async_trait::async_trait]
100impl Connection for MysqlConnection {
101    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
102        let sql = try_limit_query(sql, 1).unwrap_or_else(|| sql.to_string());
103        let mut conn = self.conn.lock().await;
104        let conn = &mut *conn;
105        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
106            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
107        })?;
108        let Some(row) = row else {
109            return Err(DataFusionError::Execution(
110                "No rows returned to infer schema".to_string(),
111            ));
112        };
113        let remote_schema = Arc::new(build_remote_schema(&row)?);
114        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
115        Ok((remote_schema, arrow_schema))
116    }
117
118    async fn query(
119        &self,
120        conn_options: &ConnectionOptions,
121        sql: &str,
122        table_schema: SchemaRef,
123        projection: Option<&Vec<usize>>,
124        limit: Option<usize>,
125    ) -> DFResult<SendableRecordBatchStream> {
126        let projected_schema = project_schema(&table_schema, projection)?;
127        let (sql, limit_stream) = match limit {
128            Some(limit) => {
129                if let Some(limited_sql) = try_limit_query(sql, limit) {
130                    (limited_sql, false)
131                } else {
132                    (sql.to_string(), true)
133                }
134            }
135            None => (sql.to_string(), false),
136        };
137        let projection = projection.cloned();
138        let chunk_size = conn_options.stream_chunk_size();
139        let conn = Arc::clone(&self.conn);
140        let stream = Box::pin(stream! {
141            let mut conn = conn.lock().await;
142            let mut query_iter = conn
143                .query_iter(sql)
144                .await
145                .map_err(|e| {
146                    DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
147                })?;
148
149            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
150                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
151                })? else {
152                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
153                return;
154            };
155
156            let mut chunked_stream = stream.chunks(chunk_size).boxed();
157
158            while let Some(chunk) = chunked_stream.next().await {
159                let rows = chunk
160                    .into_iter()
161                    .collect::<Result<Vec<_>, _>>()
162                    .map_err(|e| {
163                        DataFusionError::Execution(format!(
164                            "Failed to collect rows from mysql due to {e}",
165                        ))
166                    })?;
167
168                yield Ok::<_, DataFusionError>(rows)
169            }
170        });
171
172        let stream = stream.map(move |rows| {
173            let rows = rows?;
174            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
175        });
176
177        let sendable_stream = Box::pin(RecordBatchStreamAdapter::new(projected_schema, stream));
178
179        if limit_stream {
180            let metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
181            Ok(Box::pin(LimitStream::new(
182                sendable_stream,
183                0,
184                limit,
185                metrics,
186            )))
187        } else {
188            Ok(sendable_stream)
189        }
190    }
191}
192
/// Wraps `sql` in a subquery capped at `limit` rows when it starts with `SELECT`.
///
/// Leading whitespace is ignored and the keyword is matched case-insensitively.
/// Returns `None` for anything else, including input that is empty, shorter
/// than the keyword, or that has a multi-byte character inside its first six bytes.
fn try_limit_query(sql: &str, limit: usize) -> Option<String> {
    // `get(..6)` yields `None` instead of panicking when the slice would run past
    // the end of the string or split a UTF-8 character.
    let starts_with_select = matches!(
        sql.trim_start().get(..6),
        Some(prefix) if prefix.eq_ignore_ascii_case("select")
    );
    if starts_with_select {
        Some(format!("SELECT * FROM ({sql}) as __subquery LIMIT {limit}"))
    } else {
        None
    }
}
200
201fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
202    let character_set = mysql_col.character_set();
203    let is_utf8_bin_character_set = character_set == 45;
204    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
205    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
206    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
207    let col_length = mysql_col.column_length();
208    match mysql_col.column_type() {
209        ColumnType::MYSQL_TYPE_TINY => {
210            if is_unsigned {
211                Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
212            } else {
213                Ok(RemoteType::Mysql(MysqlType::TinyInt))
214            }
215        }
216        ColumnType::MYSQL_TYPE_SHORT => {
217            if is_unsigned {
218                Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
219            } else {
220                Ok(RemoteType::Mysql(MysqlType::SmallInt))
221            }
222        }
223        ColumnType::MYSQL_TYPE_INT24 => {
224            if is_unsigned {
225                Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
226            } else {
227                Ok(RemoteType::Mysql(MysqlType::MediumInt))
228            }
229        }
230        ColumnType::MYSQL_TYPE_LONG => {
231            if is_unsigned {
232                Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
233            } else {
234                Ok(RemoteType::Mysql(MysqlType::Integer))
235            }
236        }
237        ColumnType::MYSQL_TYPE_LONGLONG => {
238            if is_unsigned {
239                Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
240            } else {
241                Ok(RemoteType::Mysql(MysqlType::BigInt))
242            }
243        }
244        ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
245        ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
246        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
247            let precision = (mysql_col.column_length() - 2) as u8;
248            let scale = mysql_col.decimals();
249            Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
250        }
251        ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
252        ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
253        ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
254        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
255        ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
256        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
257        ColumnType::MYSQL_TYPE_STRING if is_binary => {
258            if is_utf8_bin_character_set {
259                Ok(RemoteType::Mysql(MysqlType::Char))
260            } else {
261                Ok(RemoteType::Mysql(MysqlType::Binary))
262            }
263        }
264        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
265            Ok(RemoteType::Mysql(MysqlType::Varchar))
266        }
267        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
268            if is_utf8_bin_character_set {
269                Ok(RemoteType::Mysql(MysqlType::Varchar))
270            } else {
271                Ok(RemoteType::Mysql(MysqlType::Varbinary))
272            }
273        }
274        ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
275        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
276            Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
277        }
278        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
279            if is_utf8_bin_character_set {
280                Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
281            } else {
282                Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
283            }
284        }
285        ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
286        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
287        _ => Err(DataFusionError::NotImplemented(format!(
288            "Unsupported mysql type: {mysql_col:?}",
289        ))),
290    }
291}
292
293fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
294    let mut remote_fields = vec![];
295    for col in row.columns_ref() {
296        remote_fields.push(RemoteField::new(
297            col.name_str().to_string(),
298            mysql_type_to_remote_type(col)?,
299            true,
300        ));
301    }
302    Ok(RemoteSchema::new(remote_fields))
303}
304
/// Appends the value at `$index` of `$row` to `$builder`.
///
/// `$builder` is downcast to `$builder_ty`, the value is read as `$value_ty` and
/// passed through `$convert` before being appended. A missing value or a MySQL
/// `NULL` appends a null instead.
///
/// Must be expanded inside a function returning `Result<_, DataFusionError>`:
/// conversion failures are returned from the enclosing function. Panics when the
/// builder is not a `$builder_ty`; `$field` and `$col` are only used in messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            // Absent column value or SQL NULL.
            None | Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
335
336fn rows_to_batch(
337    rows: &[Row],
338    table_schema: &SchemaRef,
339    projection: Option<&Vec<usize>>,
340) -> DFResult<RecordBatch> {
341    let projected_schema = project_schema(table_schema, projection)?;
342    let mut array_builders = vec![];
343    for field in table_schema.fields() {
344        let builder = make_builder(field.data_type(), rows.len());
345        array_builders.push(builder);
346    }
347
348    for row in rows {
349        for (idx, field) in table_schema.fields.iter().enumerate() {
350            if !projections_contains(projection, idx) {
351                continue;
352            }
353            let builder = &mut array_builders[idx];
354            let col = row.columns_ref().get(idx);
355            match field.data_type() {
356                DataType::Int8 => {
357                    handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
358                        Ok::<_, DataFusionError>(v)
359                    });
360                }
361                DataType::Int16 => {
362                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
363                        Ok::<_, DataFusionError>(v)
364                    });
365                }
366                DataType::Int32 => {
367                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
368                        Ok::<_, DataFusionError>(v)
369                    });
370                }
371                DataType::Int64 => {
372                    handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
373                        Ok::<_, DataFusionError>(v)
374                    });
375                }
376                DataType::UInt8 => {
377                    handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
378                        Ok::<_, DataFusionError>(v)
379                    });
380                }
381                DataType::UInt16 => {
382                    handle_primitive_type!(
383                        builder,
384                        field,
385                        col,
386                        UInt16Builder,
387                        u16,
388                        row,
389                        idx,
390                        |v| { Ok::<_, DataFusionError>(v) }
391                    );
392                }
393                DataType::UInt32 => {
394                    handle_primitive_type!(
395                        builder,
396                        field,
397                        col,
398                        UInt32Builder,
399                        u32,
400                        row,
401                        idx,
402                        |v| { Ok::<_, DataFusionError>(v) }
403                    );
404                }
405                DataType::UInt64 => {
406                    handle_primitive_type!(
407                        builder,
408                        field,
409                        col,
410                        UInt64Builder,
411                        u64,
412                        row,
413                        idx,
414                        |v| { Ok::<_, DataFusionError>(v) }
415                    );
416                }
417                DataType::Float32 => {
418                    handle_primitive_type!(
419                        builder,
420                        field,
421                        col,
422                        Float32Builder,
423                        f32,
424                        row,
425                        idx,
426                        |v| { Ok::<_, DataFusionError>(v) }
427                    );
428                }
429                DataType::Float64 => {
430                    handle_primitive_type!(
431                        builder,
432                        field,
433                        col,
434                        Float64Builder,
435                        f64,
436                        row,
437                        idx,
438                        |v| { Ok::<_, DataFusionError>(v) }
439                    );
440                }
441                DataType::Decimal128(_precision, scale) => {
442                    handle_primitive_type!(
443                        builder,
444                        field,
445                        col,
446                        Decimal128Builder,
447                        BigDecimal,
448                        row,
449                        idx,
450                        |v: BigDecimal| {
451                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
452                                DataFusionError::Execution(format!(
453                                    "Failed to convert BigDecimal {v:?} to i128"
454                                ))
455                            })
456                        }
457                    );
458                }
459                DataType::Decimal256(_precision, _scale) => {
460                    handle_primitive_type!(
461                        builder,
462                        field,
463                        col,
464                        Decimal256Builder,
465                        BigDecimal,
466                        row,
467                        idx,
468                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
469                    );
470                }
471                DataType::Date32 => {
472                    handle_primitive_type!(
473                        builder,
474                        field,
475                        col,
476                        Date32Builder,
477                        chrono::NaiveDate,
478                        row,
479                        idx,
480                        |v: chrono::NaiveDate| {
481                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
482                        }
483                    );
484                }
485                DataType::Timestamp(TimeUnit::Microsecond, None) => {
486                    handle_primitive_type!(
487                        builder,
488                        field,
489                        col,
490                        TimestampMicrosecondBuilder,
491                        time::PrimitiveDateTime,
492                        row,
493                        idx,
494                        |v: time::PrimitiveDateTime| {
495                            let timestamp_micros =
496                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
497                            Ok::<_, DataFusionError>(timestamp_micros)
498                        }
499                    );
500                }
501                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
502                    handle_primitive_type!(
503                        builder,
504                        field,
505                        col,
506                        TimestampNanosecondBuilder,
507                        chrono::NaiveTime,
508                        row,
509                        idx,
510                        |v: chrono::NaiveTime| {
511                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
512                                + i64::from(v.nanosecond());
513                            Ok::<_, DataFusionError>(t)
514                        }
515                    );
516                }
517                DataType::Time32(TimeUnit::Second) => {
518                    handle_primitive_type!(
519                        builder,
520                        field,
521                        col,
522                        Time32SecondBuilder,
523                        chrono::NaiveTime,
524                        row,
525                        idx,
526                        |v: chrono::NaiveTime| {
527                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
528                        }
529                    );
530                }
531                DataType::Time64(TimeUnit::Nanosecond) => {
532                    handle_primitive_type!(
533                        builder,
534                        field,
535                        col,
536                        Time64NanosecondBuilder,
537                        chrono::NaiveTime,
538                        row,
539                        idx,
540                        |v: chrono::NaiveTime| {
541                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
542                                + i64::from(v.nanosecond());
543                            Ok::<_, DataFusionError>(t)
544                        }
545                    );
546                }
547                DataType::Utf8 => {
548                    handle_primitive_type!(
549                        builder,
550                        field,
551                        col,
552                        StringBuilder,
553                        String,
554                        row,
555                        idx,
556                        |v| { Ok::<_, DataFusionError>(v) }
557                    );
558                }
559                DataType::LargeUtf8 => {
560                    handle_primitive_type!(
561                        builder,
562                        field,
563                        col,
564                        LargeStringBuilder,
565                        String,
566                        row,
567                        idx,
568                        |v| { Ok::<_, DataFusionError>(v) }
569                    );
570                }
571                DataType::Binary => {
572                    handle_primitive_type!(
573                        builder,
574                        field,
575                        col,
576                        BinaryBuilder,
577                        Vec<u8>,
578                        row,
579                        idx,
580                        |v| { Ok::<_, DataFusionError>(v) }
581                    );
582                }
583                DataType::LargeBinary => {
584                    handle_primitive_type!(
585                        builder,
586                        field,
587                        col,
588                        LargeBinaryBuilder,
589                        Vec<u8>,
590                        row,
591                        idx,
592                        |v| { Ok::<_, DataFusionError>(v) }
593                    );
594                }
595                _ => {
596                    return Err(DataFusionError::NotImplemented(format!(
597                        "Unsupported data type {:?} for col: {:?}",
598                        field.data_type(),
599                        col
600                    )));
601                }
602            }
603        }
604    }
605    let projected_columns = array_builders
606        .into_iter()
607        .enumerate()
608        .filter(|(idx, _)| projections_contains(projection, *idx))
609        .map(|(_, mut builder)| builder.finish())
610        .collect::<Vec<ArrayRef>>();
611    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
612}
613
614fn to_decimal_256(decimal: &BigDecimal) -> i256 {
615    let (bigint_value, _) = decimal.as_bigint_and_exponent();
616    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
617
618    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
619    let fill_byte = if is_negative { 0xFF } else { 0x00 };
620
621    if bigint_bytes.len() > 32 {
622        bigint_bytes.truncate(32);
623    } else {
624        bigint_bytes.resize(32, fill_byte);
625    };
626
627    let mut array = [0u8; 32];
628    array.copy_from_slice(&bigint_bytes);
629
630    i256::from_le_bytes(array)
631}