datafusion_remote_table/connection/mysql.rs

1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12    LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
13    TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt8Builder, UInt16Builder,
14    UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use mysql_async::consts::{ColumnFlags, ColumnType};
25use mysql_async::prelude::Queryable;
26use mysql_async::{Column, FromValueError, Row, Value};
27use std::sync::Arc;
28
/// Connection settings for a MySQL server, consumed by `connect_mysql`.
///
/// Construct with [`MysqlConnectionOptions::new`]. Fields can then be adjusted
/// through the methods generated by the `With` derive (presumably `with_<field>`
/// setters — confirm against `derive_with`) and read through the `Getters` derive.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// Server host name or IP address (passed to `ip_or_hostname`).
    pub(crate) host: String,
    /// Server TCP port.
    pub(crate) port: u16,
    /// User name used to authenticate.
    pub(crate) username: String,
    /// Password used to authenticate.
    pub(crate) password: String,
    /// Database passed as `db_name` when connecting; `None` passes no database.
    pub(crate) database: Option<String>,
    /// Upper bound on the number of connections in the pool built by `connect_mysql`.
    pub(crate) pool_max_size: usize,
    /// Chunk size for streamed query results (`new` defaults it to 2048).
    /// NOTE(review): presumably surfaces as `ConnectionOptions::stream_chunk_size`,
    /// i.e. rows per `RecordBatch` — confirm.
    pub(crate) stream_chunk_size: usize,
}
39
40impl MysqlConnectionOptions {
41    pub fn new(
42        host: impl Into<String>,
43        port: u16,
44        username: impl Into<String>,
45        password: impl Into<String>,
46    ) -> Self {
47        Self {
48            host: host.into(),
49            port,
50            username: username.into(),
51            password: password.into(),
52            database: None,
53            pool_max_size: 10,
54            stream_chunk_size: 2048,
55        }
56    }
57}
58
/// MySQL connection pool returned by `connect_mysql`.
///
/// Thin wrapper around a `mysql_async::Pool`; connections are handed out through
/// the crate's `Pool` trait implementation.
#[derive(Debug)]
pub struct MysqlPool {
    pool: mysql_async::Pool,
}
63
64pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
65    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
66        mysql_async::PoolConstraints::new(0, options.pool_max_size)
67            .expect("Failed to create pool constraints"),
68    );
69    let opts_builder = mysql_async::OptsBuilder::default()
70        .ip_or_hostname(options.host.clone())
71        .tcp_port(options.port)
72        .user(Some(options.username.clone()))
73        .pass(Some(options.password.clone()))
74        .db_name(options.database.clone())
75        .pool_opts(pool_opts);
76    let pool = mysql_async::Pool::new(opts_builder);
77    Ok(MysqlPool { pool })
78}
79
80#[async_trait::async_trait]
81impl Pool for MysqlPool {
82    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
83        let conn = self.pool.get_conn().await.map_err(|e| {
84            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
85        })?;
86        Ok(Arc::new(MysqlConnection {
87            conn: Arc::new(Mutex::new(conn)),
88        }))
89    }
90}
91
/// A single MySQL connection checked out of a [`MysqlPool`].
///
/// The connection sits behind an `Arc<futures::lock::Mutex<_>>` so the stream
/// returned by `query` can own a handle to it. The stream holds the lock while
/// rows are read.
#[derive(Debug)]
pub struct MysqlConnection {
    conn: Arc<Mutex<mysql_async::Conn>>,
}
96
97#[async_trait::async_trait]
98impl Connection for MysqlConnection {
99    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
100        let sql = try_limit1_query(sql).unwrap_or_else(|| sql.to_string());
101        let mut conn = self.conn.lock().await;
102        let conn = &mut *conn;
103        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
104            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
105        })?;
106        let Some(row) = row else {
107            return Err(DataFusionError::Execution(
108                "No rows returned to infer schema".to_string(),
109            ));
110        };
111        let remote_schema = Arc::new(build_remote_schema(&row)?);
112        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
113        Ok((remote_schema, arrow_schema))
114    }
115
116    async fn query(
117        &self,
118        conn_options: &ConnectionOptions,
119        sql: &str,
120        table_schema: SchemaRef,
121        projection: Option<&Vec<usize>>,
122    ) -> DFResult<SendableRecordBatchStream> {
123        let projected_schema = project_schema(&table_schema, projection)?;
124        let sql = sql.to_string();
125        let projection = projection.cloned();
126        let chunk_size = conn_options.stream_chunk_size();
127        let conn = Arc::clone(&self.conn);
128        let stream = Box::pin(stream! {
129            let mut conn = conn.lock().await;
130            let mut query_iter = conn
131                .query_iter(sql)
132                .await
133                .map_err(|e| {
134                    DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
135                })?;
136
137            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
138                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
139                })? else {
140                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
141                return;
142            };
143
144            let mut chunked_stream = stream.chunks(chunk_size).boxed();
145
146            while let Some(chunk) = chunked_stream.next().await {
147                let rows = chunk
148                    .into_iter()
149                    .collect::<Result<Vec<_>, _>>()
150                    .map_err(|e| {
151                        DataFusionError::Execution(format!(
152                            "Failed to collect rows from mysql due to {e}",
153                        ))
154                    })?;
155
156                yield Ok::<_, DataFusionError>(rows)
157            }
158        });
159
160        let stream = stream.map(move |rows| {
161            let rows = rows?;
162            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
163        });
164
165        Ok(Box::pin(RecordBatchStreamAdapter::new(
166            projected_schema,
167            stream,
168        )))
169    }
170}
171
/// Wraps a `SELECT` statement in a `LIMIT 1` subquery for cheap schema inference.
///
/// The statement is first trimmed of surrounding whitespace and trailing `;`,
/// because a semicolon is not valid inside a subquery. Returns `None` when the
/// trimmed statement does not start with `select` (ASCII case-insensitive). The
/// caller then runs the original statement unchanged.
///
/// Never panics: inputs shorter than six bytes, or whose sixth byte is not a
/// UTF-8 character boundary, simply yield `None`.
fn try_limit1_query(sql: &str) -> Option<String> {
    let trimmed = sql.trim().trim_end_matches(';').trim_end();
    // `get` avoids the panics that direct slicing (`[0..6]`) has on short or
    // non-ASCII input.
    let is_select = trimmed
        .get(..6)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("select"));
    is_select.then(|| format!("SELECT * FROM ({trimmed}) as __subquery LIMIT 1"))
}
179
180fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
181    let character_set = mysql_col.character_set();
182    let is_utf8_bin_character_set = character_set == 45;
183    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
184    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
185    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
186    let col_length = mysql_col.column_length();
187    match mysql_col.column_type() {
188        ColumnType::MYSQL_TYPE_TINY => {
189            if is_unsigned {
190                Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
191            } else {
192                Ok(RemoteType::Mysql(MysqlType::TinyInt))
193            }
194        }
195        ColumnType::MYSQL_TYPE_SHORT => {
196            if is_unsigned {
197                Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
198            } else {
199                Ok(RemoteType::Mysql(MysqlType::SmallInt))
200            }
201        }
202        ColumnType::MYSQL_TYPE_INT24 => {
203            if is_unsigned {
204                Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
205            } else {
206                Ok(RemoteType::Mysql(MysqlType::MediumInt))
207            }
208        }
209        ColumnType::MYSQL_TYPE_LONG => {
210            if is_unsigned {
211                Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
212            } else {
213                Ok(RemoteType::Mysql(MysqlType::Integer))
214            }
215        }
216        ColumnType::MYSQL_TYPE_LONGLONG => {
217            if is_unsigned {
218                Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
219            } else {
220                Ok(RemoteType::Mysql(MysqlType::BigInt))
221            }
222        }
223        ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
224        ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
225        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
226            let precision = (mysql_col.column_length() - 2) as u8;
227            let scale = mysql_col.decimals();
228            Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
229        }
230        ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
231        ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
232        ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
233        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
234        ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
235        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
236        ColumnType::MYSQL_TYPE_STRING if is_binary => {
237            if is_utf8_bin_character_set {
238                Ok(RemoteType::Mysql(MysqlType::Char))
239            } else {
240                Ok(RemoteType::Mysql(MysqlType::Binary))
241            }
242        }
243        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
244            Ok(RemoteType::Mysql(MysqlType::Varchar))
245        }
246        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
247            if is_utf8_bin_character_set {
248                Ok(RemoteType::Mysql(MysqlType::Varchar))
249            } else {
250                Ok(RemoteType::Mysql(MysqlType::Varbinary))
251            }
252        }
253        ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
254        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
255            Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
256        }
257        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
258            if is_utf8_bin_character_set {
259                Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
260            } else {
261                Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
262            }
263        }
264        ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
265        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
266        _ => Err(DataFusionError::NotImplemented(format!(
267            "Unsupported mysql type: {mysql_col:?}",
268        ))),
269    }
270}
271
272fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
273    let mut remote_fields = vec![];
274    for col in row.columns_ref() {
275        remote_fields.push(RemoteField::new(
276            col.name_str().to_string(),
277            mysql_type_to_remote_type(col)?,
278            true,
279        ));
280    }
281    Ok(RemoteSchema::new(remote_fields))
282}
283
/// Reads column `$index` of `$row` as `$value_ty` and appends it to `$builder`.
///
/// - `$builder` is downcast to `$builder_ty`; a failed downcast panics.
/// - `$field` and `$col` are used only in panic/error messages.
/// - `$convert` maps the read value to what the builder stores and returns a
///   `Result`; its error is propagated with `?` from the *enclosing* function.
/// - A `None` from `get_opt`, or a SQL `NULL` (`FromValueError(Value::NULL)`),
///   appends a null.
/// - Any other read failure makes the *enclosing* function return
///   `DataFusionError::Execution`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            // NULL cannot be converted into a non-Option `$value_ty`; treat it as null.
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
314
315fn rows_to_batch(
316    rows: &[Row],
317    table_schema: &SchemaRef,
318    projection: Option<&Vec<usize>>,
319) -> DFResult<RecordBatch> {
320    let projected_schema = project_schema(table_schema, projection)?;
321    let mut array_builders = vec![];
322    for field in table_schema.fields() {
323        let builder = make_builder(field.data_type(), rows.len());
324        array_builders.push(builder);
325    }
326
327    for row in rows {
328        for (idx, field) in table_schema.fields.iter().enumerate() {
329            if !projections_contains(projection, idx) {
330                continue;
331            }
332            let builder = &mut array_builders[idx];
333            let col = row.columns_ref().get(idx);
334            match field.data_type() {
335                DataType::Int8 => {
336                    handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
337                        Ok::<_, DataFusionError>(v)
338                    });
339                }
340                DataType::Int16 => {
341                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
342                        Ok::<_, DataFusionError>(v)
343                    });
344                }
345                DataType::Int32 => {
346                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
347                        Ok::<_, DataFusionError>(v)
348                    });
349                }
350                DataType::Int64 => {
351                    handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
352                        Ok::<_, DataFusionError>(v)
353                    });
354                }
355                DataType::UInt8 => {
356                    handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
357                        Ok::<_, DataFusionError>(v)
358                    });
359                }
360                DataType::UInt16 => {
361                    handle_primitive_type!(
362                        builder,
363                        field,
364                        col,
365                        UInt16Builder,
366                        u16,
367                        row,
368                        idx,
369                        |v| { Ok::<_, DataFusionError>(v) }
370                    );
371                }
372                DataType::UInt32 => {
373                    handle_primitive_type!(
374                        builder,
375                        field,
376                        col,
377                        UInt32Builder,
378                        u32,
379                        row,
380                        idx,
381                        |v| { Ok::<_, DataFusionError>(v) }
382                    );
383                }
384                DataType::UInt64 => {
385                    handle_primitive_type!(
386                        builder,
387                        field,
388                        col,
389                        UInt64Builder,
390                        u64,
391                        row,
392                        idx,
393                        |v| { Ok::<_, DataFusionError>(v) }
394                    );
395                }
396                DataType::Float32 => {
397                    handle_primitive_type!(
398                        builder,
399                        field,
400                        col,
401                        Float32Builder,
402                        f32,
403                        row,
404                        idx,
405                        |v| { Ok::<_, DataFusionError>(v) }
406                    );
407                }
408                DataType::Float64 => {
409                    handle_primitive_type!(
410                        builder,
411                        field,
412                        col,
413                        Float64Builder,
414                        f64,
415                        row,
416                        idx,
417                        |v| { Ok::<_, DataFusionError>(v) }
418                    );
419                }
420                DataType::Decimal128(_precision, scale) => {
421                    handle_primitive_type!(
422                        builder,
423                        field,
424                        col,
425                        Decimal128Builder,
426                        BigDecimal,
427                        row,
428                        idx,
429                        |v: BigDecimal| {
430                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
431                                DataFusionError::Execution(format!(
432                                    "Failed to convert BigDecimal {v:?} to i128"
433                                ))
434                            })
435                        }
436                    );
437                }
438                DataType::Decimal256(_precision, _scale) => {
439                    handle_primitive_type!(
440                        builder,
441                        field,
442                        col,
443                        Decimal256Builder,
444                        BigDecimal,
445                        row,
446                        idx,
447                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
448                    );
449                }
450                DataType::Date32 => {
451                    handle_primitive_type!(
452                        builder,
453                        field,
454                        col,
455                        Date32Builder,
456                        chrono::NaiveDate,
457                        row,
458                        idx,
459                        |v: chrono::NaiveDate| {
460                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
461                        }
462                    );
463                }
464                DataType::Timestamp(TimeUnit::Microsecond, None) => {
465                    handle_primitive_type!(
466                        builder,
467                        field,
468                        col,
469                        TimestampMicrosecondBuilder,
470                        time::PrimitiveDateTime,
471                        row,
472                        idx,
473                        |v: time::PrimitiveDateTime| {
474                            let timestamp_micros =
475                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
476                            Ok::<_, DataFusionError>(timestamp_micros)
477                        }
478                    );
479                }
480                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
481                    handle_primitive_type!(
482                        builder,
483                        field,
484                        col,
485                        TimestampNanosecondBuilder,
486                        chrono::NaiveTime,
487                        row,
488                        idx,
489                        |v: chrono::NaiveTime| {
490                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
491                                + i64::from(v.nanosecond());
492                            Ok::<_, DataFusionError>(t)
493                        }
494                    );
495                }
496                DataType::Time32(TimeUnit::Second) => {
497                    handle_primitive_type!(
498                        builder,
499                        field,
500                        col,
501                        Time32SecondBuilder,
502                        chrono::NaiveTime,
503                        row,
504                        idx,
505                        |v: chrono::NaiveTime| {
506                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
507                        }
508                    );
509                }
510                DataType::Time64(TimeUnit::Nanosecond) => {
511                    handle_primitive_type!(
512                        builder,
513                        field,
514                        col,
515                        Time64NanosecondBuilder,
516                        chrono::NaiveTime,
517                        row,
518                        idx,
519                        |v: chrono::NaiveTime| {
520                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
521                                + i64::from(v.nanosecond());
522                            Ok::<_, DataFusionError>(t)
523                        }
524                    );
525                }
526                DataType::Utf8 => {
527                    handle_primitive_type!(
528                        builder,
529                        field,
530                        col,
531                        StringBuilder,
532                        String,
533                        row,
534                        idx,
535                        |v| { Ok::<_, DataFusionError>(v) }
536                    );
537                }
538                DataType::LargeUtf8 => {
539                    handle_primitive_type!(
540                        builder,
541                        field,
542                        col,
543                        LargeStringBuilder,
544                        String,
545                        row,
546                        idx,
547                        |v| { Ok::<_, DataFusionError>(v) }
548                    );
549                }
550                DataType::Binary => {
551                    handle_primitive_type!(
552                        builder,
553                        field,
554                        col,
555                        BinaryBuilder,
556                        Vec<u8>,
557                        row,
558                        idx,
559                        |v| { Ok::<_, DataFusionError>(v) }
560                    );
561                }
562                DataType::LargeBinary => {
563                    handle_primitive_type!(
564                        builder,
565                        field,
566                        col,
567                        LargeBinaryBuilder,
568                        Vec<u8>,
569                        row,
570                        idx,
571                        |v| { Ok::<_, DataFusionError>(v) }
572                    );
573                }
574                _ => {
575                    return Err(DataFusionError::NotImplemented(format!(
576                        "Unsupported data type {:?} for col: {:?}",
577                        field.data_type(),
578                        col
579                    )));
580                }
581            }
582        }
583    }
584    let projected_columns = array_builders
585        .into_iter()
586        .enumerate()
587        .filter(|(idx, _)| projections_contains(projection, *idx))
588        .map(|(_, mut builder)| builder.finish())
589        .collect::<Vec<ArrayRef>>();
590    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
591}
592
593fn to_decimal_256(decimal: &BigDecimal) -> i256 {
594    let (bigint_value, _) = decimal.as_bigint_and_exponent();
595    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
596
597    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
598    let fill_byte = if is_negative { 0xFF } else { 0x00 };
599
600    if bigint_bytes.len() > 32 {
601        bigint_bytes.truncate(32);
602    } else {
603        bigint_bytes.resize(32, fill_byte);
604    };
605
606    let mut array = [0u8; 32];
607    array.copy_from_slice(&bigint_bytes);
608
609    i256::from_le_bytes(array)
610}