datafusion_remote_table/connection/
mysql.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12    LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
13    TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt8Builder, UInt16Builder,
14    UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::prelude::Expr;
21use derive_getters::Getters;
22use derive_with::With;
23use futures::StreamExt;
24use futures::lock::Mutex;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::sync::Arc;
29
30#[derive(Debug, Clone, With, Getters)]
31pub struct MysqlConnectionOptions {
32    pub(crate) host: String,
33    pub(crate) port: u16,
34    pub(crate) username: String,
35    pub(crate) password: String,
36    pub(crate) database: Option<String>,
37    pub(crate) pool_max_size: usize,
38    pub(crate) stream_chunk_size: usize,
39}
40
41impl MysqlConnectionOptions {
42    pub fn new(
43        host: impl Into<String>,
44        port: u16,
45        username: impl Into<String>,
46        password: impl Into<String>,
47    ) -> Self {
48        Self {
49            host: host.into(),
50            port,
51            username: username.into(),
52            password: password.into(),
53            database: None,
54            pool_max_size: 10,
55            stream_chunk_size: 2048,
56        }
57    }
58}
59
60#[derive(Debug)]
61pub struct MysqlPool {
62    pool: mysql_async::Pool,
63}
64
65pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
66    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
67        mysql_async::PoolConstraints::new(0, options.pool_max_size)
68            .expect("Failed to create pool constraints"),
69    );
70    let opts_builder = mysql_async::OptsBuilder::default()
71        .ip_or_hostname(options.host.clone())
72        .tcp_port(options.port)
73        .user(Some(options.username.clone()))
74        .pass(Some(options.password.clone()))
75        .db_name(options.database.clone())
76        .pool_opts(pool_opts);
77    let pool = mysql_async::Pool::new(opts_builder);
78    Ok(MysqlPool { pool })
79}
80
81#[async_trait::async_trait]
82impl Pool for MysqlPool {
83    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
84        let conn = self.pool.get_conn().await.map_err(|e| {
85            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
86        })?;
87        Ok(Arc::new(MysqlConnection {
88            conn: Arc::new(Mutex::new(conn)),
89        }))
90    }
91}
92
93#[derive(Debug)]
94pub struct MysqlConnection {
95    conn: Arc<Mutex<mysql_async::Conn>>,
96}
97
98#[async_trait::async_trait]
99impl Connection for MysqlConnection {
100    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
101        let sql = RemoteDbType::Mysql
102            .try_rewrite_query(sql, &[], Some(1))
103            .unwrap_or_else(|| sql.to_string());
104        let mut conn = self.conn.lock().await;
105        let conn = &mut *conn;
106        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
107            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
108        })?;
109        let Some(row) = row else {
110            return Err(DataFusionError::Execution(
111                "No rows returned to infer schema".to_string(),
112            ));
113        };
114        let remote_schema = Arc::new(build_remote_schema(&row)?);
115        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
116        Ok((remote_schema, arrow_schema))
117    }
118
119    async fn query(
120        &self,
121        conn_options: &ConnectionOptions,
122        sql: &str,
123        table_schema: SchemaRef,
124        projection: Option<&Vec<usize>>,
125        filters: &[Expr],
126        limit: Option<usize>,
127    ) -> DFResult<SendableRecordBatchStream> {
128        let projected_schema = project_schema(&table_schema, projection)?;
129        let sql = RemoteDbType::Mysql
130            .try_rewrite_query(sql, filters, limit)
131            .unwrap_or_else(|| sql.to_string());
132        let projection = projection.cloned();
133        let chunk_size = conn_options.stream_chunk_size();
134        let conn = Arc::clone(&self.conn);
135        let stream = Box::pin(stream! {
136            let mut conn = conn.lock().await;
137            let mut query_iter = conn
138                .query_iter(sql.clone())
139                .await
140                .map_err(|e| {
141                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
142                })?;
143
144            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
145                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
146                })? else {
147                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
148                return;
149            };
150
151            let mut chunked_stream = stream.chunks(chunk_size).boxed();
152
153            while let Some(chunk) = chunked_stream.next().await {
154                let rows = chunk
155                    .into_iter()
156                    .collect::<Result<Vec<_>, _>>()
157                    .map_err(|e| {
158                        DataFusionError::Execution(format!(
159                            "Failed to collect rows from mysql due to {e}",
160                        ))
161                    })?;
162
163                yield Ok::<_, DataFusionError>(rows)
164            }
165        });
166
167        let stream = stream.map(move |rows| {
168            let rows = rows?;
169            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
170        });
171
172        Ok(Box::pin(RecordBatchStreamAdapter::new(
173            projected_schema,
174            stream,
175        )))
176    }
177}
178
179fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
180    let character_set = mysql_col.character_set();
181    let is_utf8_bin_character_set = character_set == 45;
182    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
183    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
184    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
185    let col_length = mysql_col.column_length();
186    match mysql_col.column_type() {
187        ColumnType::MYSQL_TYPE_TINY => {
188            if is_unsigned {
189                Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
190            } else {
191                Ok(RemoteType::Mysql(MysqlType::TinyInt))
192            }
193        }
194        ColumnType::MYSQL_TYPE_SHORT => {
195            if is_unsigned {
196                Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
197            } else {
198                Ok(RemoteType::Mysql(MysqlType::SmallInt))
199            }
200        }
201        ColumnType::MYSQL_TYPE_INT24 => {
202            if is_unsigned {
203                Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
204            } else {
205                Ok(RemoteType::Mysql(MysqlType::MediumInt))
206            }
207        }
208        ColumnType::MYSQL_TYPE_LONG => {
209            if is_unsigned {
210                Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
211            } else {
212                Ok(RemoteType::Mysql(MysqlType::Integer))
213            }
214        }
215        ColumnType::MYSQL_TYPE_LONGLONG => {
216            if is_unsigned {
217                Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
218            } else {
219                Ok(RemoteType::Mysql(MysqlType::BigInt))
220            }
221        }
222        ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
223        ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
224        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
225            let precision = (mysql_col.column_length() - 2) as u8;
226            let scale = mysql_col.decimals();
227            Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
228        }
229        ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
230        ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
231        ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
232        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
233        ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
234        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
235        ColumnType::MYSQL_TYPE_STRING if is_binary => {
236            if is_utf8_bin_character_set {
237                Ok(RemoteType::Mysql(MysqlType::Char))
238            } else {
239                Ok(RemoteType::Mysql(MysqlType::Binary))
240            }
241        }
242        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
243            Ok(RemoteType::Mysql(MysqlType::Varchar))
244        }
245        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
246            if is_utf8_bin_character_set {
247                Ok(RemoteType::Mysql(MysqlType::Varchar))
248            } else {
249                Ok(RemoteType::Mysql(MysqlType::Varbinary))
250            }
251        }
252        ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
253        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
254            Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
255        }
256        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
257            if is_utf8_bin_character_set {
258                Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
259            } else {
260                Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
261            }
262        }
263        ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
264        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
265        _ => Err(DataFusionError::NotImplemented(format!(
266            "Unsupported mysql type: {mysql_col:?}",
267        ))),
268    }
269}
270
271fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
272    let mut remote_fields = vec![];
273    for col in row.columns_ref() {
274        remote_fields.push(RemoteField::new(
275            col.name_str().to_string(),
276            mysql_type_to_remote_type(col)?,
277            true,
278        ));
279    }
280    Ok(RemoteSchema::new(remote_fields))
281}
282
/// Appends one cell of a MySQL row to a typed Arrow builder.
///
/// Arguments, in order: the type-erased builder, the Arrow field and MySQL
/// column (both used only in diagnostics), the concrete builder type, the Rust
/// type to read from the row, the row, the column index, and a conversion
/// closure returning `Result<_, DataFusionError>`.
///
/// A missing cell or a SQL `NULL` appends a null. Any other decoding failure
/// makes the *enclosing function* return `DataFusionError::Execution`, and a
/// conversion error is propagated with `?`. Panics if the builder is not of
/// the requested concrete type.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(v)) => typed_builder.append_value($convert(v)?),
            // No cell at this index, or the cell holds SQL NULL.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
313
314fn rows_to_batch(
315    rows: &[Row],
316    table_schema: &SchemaRef,
317    projection: Option<&Vec<usize>>,
318) -> DFResult<RecordBatch> {
319    let projected_schema = project_schema(table_schema, projection)?;
320    let mut array_builders = vec![];
321    for field in table_schema.fields() {
322        let builder = make_builder(field.data_type(), rows.len());
323        array_builders.push(builder);
324    }
325
326    for row in rows {
327        for (idx, field) in table_schema.fields.iter().enumerate() {
328            if !projections_contains(projection, idx) {
329                continue;
330            }
331            let builder = &mut array_builders[idx];
332            let col = row.columns_ref().get(idx);
333            match field.data_type() {
334                DataType::Int8 => {
335                    handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
336                        Ok::<_, DataFusionError>(v)
337                    });
338                }
339                DataType::Int16 => {
340                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
341                        Ok::<_, DataFusionError>(v)
342                    });
343                }
344                DataType::Int32 => {
345                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
346                        Ok::<_, DataFusionError>(v)
347                    });
348                }
349                DataType::Int64 => {
350                    handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
351                        Ok::<_, DataFusionError>(v)
352                    });
353                }
354                DataType::UInt8 => {
355                    handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
356                        Ok::<_, DataFusionError>(v)
357                    });
358                }
359                DataType::UInt16 => {
360                    handle_primitive_type!(
361                        builder,
362                        field,
363                        col,
364                        UInt16Builder,
365                        u16,
366                        row,
367                        idx,
368                        |v| { Ok::<_, DataFusionError>(v) }
369                    );
370                }
371                DataType::UInt32 => {
372                    handle_primitive_type!(
373                        builder,
374                        field,
375                        col,
376                        UInt32Builder,
377                        u32,
378                        row,
379                        idx,
380                        |v| { Ok::<_, DataFusionError>(v) }
381                    );
382                }
383                DataType::UInt64 => {
384                    handle_primitive_type!(
385                        builder,
386                        field,
387                        col,
388                        UInt64Builder,
389                        u64,
390                        row,
391                        idx,
392                        |v| { Ok::<_, DataFusionError>(v) }
393                    );
394                }
395                DataType::Float32 => {
396                    handle_primitive_type!(
397                        builder,
398                        field,
399                        col,
400                        Float32Builder,
401                        f32,
402                        row,
403                        idx,
404                        |v| { Ok::<_, DataFusionError>(v) }
405                    );
406                }
407                DataType::Float64 => {
408                    handle_primitive_type!(
409                        builder,
410                        field,
411                        col,
412                        Float64Builder,
413                        f64,
414                        row,
415                        idx,
416                        |v| { Ok::<_, DataFusionError>(v) }
417                    );
418                }
419                DataType::Decimal128(_precision, scale) => {
420                    handle_primitive_type!(
421                        builder,
422                        field,
423                        col,
424                        Decimal128Builder,
425                        BigDecimal,
426                        row,
427                        idx,
428                        |v: BigDecimal| {
429                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
430                                DataFusionError::Execution(format!(
431                                    "Failed to convert BigDecimal {v:?} to i128"
432                                ))
433                            })
434                        }
435                    );
436                }
437                DataType::Decimal256(_precision, _scale) => {
438                    handle_primitive_type!(
439                        builder,
440                        field,
441                        col,
442                        Decimal256Builder,
443                        BigDecimal,
444                        row,
445                        idx,
446                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
447                    );
448                }
449                DataType::Date32 => {
450                    handle_primitive_type!(
451                        builder,
452                        field,
453                        col,
454                        Date32Builder,
455                        chrono::NaiveDate,
456                        row,
457                        idx,
458                        |v: chrono::NaiveDate| {
459                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
460                        }
461                    );
462                }
463                DataType::Timestamp(TimeUnit::Microsecond, None) => {
464                    handle_primitive_type!(
465                        builder,
466                        field,
467                        col,
468                        TimestampMicrosecondBuilder,
469                        time::PrimitiveDateTime,
470                        row,
471                        idx,
472                        |v: time::PrimitiveDateTime| {
473                            let timestamp_micros =
474                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
475                            Ok::<_, DataFusionError>(timestamp_micros)
476                        }
477                    );
478                }
479                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
480                    handle_primitive_type!(
481                        builder,
482                        field,
483                        col,
484                        TimestampNanosecondBuilder,
485                        chrono::NaiveTime,
486                        row,
487                        idx,
488                        |v: chrono::NaiveTime| {
489                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
490                                + i64::from(v.nanosecond());
491                            Ok::<_, DataFusionError>(t)
492                        }
493                    );
494                }
495                DataType::Time32(TimeUnit::Second) => {
496                    handle_primitive_type!(
497                        builder,
498                        field,
499                        col,
500                        Time32SecondBuilder,
501                        chrono::NaiveTime,
502                        row,
503                        idx,
504                        |v: chrono::NaiveTime| {
505                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
506                        }
507                    );
508                }
509                DataType::Time64(TimeUnit::Nanosecond) => {
510                    handle_primitive_type!(
511                        builder,
512                        field,
513                        col,
514                        Time64NanosecondBuilder,
515                        chrono::NaiveTime,
516                        row,
517                        idx,
518                        |v: chrono::NaiveTime| {
519                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
520                                + i64::from(v.nanosecond());
521                            Ok::<_, DataFusionError>(t)
522                        }
523                    );
524                }
525                DataType::Utf8 => {
526                    handle_primitive_type!(
527                        builder,
528                        field,
529                        col,
530                        StringBuilder,
531                        String,
532                        row,
533                        idx,
534                        |v| { Ok::<_, DataFusionError>(v) }
535                    );
536                }
537                DataType::LargeUtf8 => {
538                    handle_primitive_type!(
539                        builder,
540                        field,
541                        col,
542                        LargeStringBuilder,
543                        String,
544                        row,
545                        idx,
546                        |v| { Ok::<_, DataFusionError>(v) }
547                    );
548                }
549                DataType::Binary => {
550                    handle_primitive_type!(
551                        builder,
552                        field,
553                        col,
554                        BinaryBuilder,
555                        Vec<u8>,
556                        row,
557                        idx,
558                        |v| { Ok::<_, DataFusionError>(v) }
559                    );
560                }
561                DataType::LargeBinary => {
562                    handle_primitive_type!(
563                        builder,
564                        field,
565                        col,
566                        LargeBinaryBuilder,
567                        Vec<u8>,
568                        row,
569                        idx,
570                        |v| { Ok::<_, DataFusionError>(v) }
571                    );
572                }
573                _ => {
574                    return Err(DataFusionError::NotImplemented(format!(
575                        "Unsupported data type {:?} for col: {:?}",
576                        field.data_type(),
577                        col
578                    )));
579                }
580            }
581        }
582    }
583    let projected_columns = array_builders
584        .into_iter()
585        .enumerate()
586        .filter(|(idx, _)| projections_contains(projection, *idx))
587        .map(|(_, mut builder)| builder.finish())
588        .collect::<Vec<ArrayRef>>();
589    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
590}
591
592fn to_decimal_256(decimal: &BigDecimal) -> i256 {
593    let (bigint_value, _) = decimal.as_bigint_and_exponent();
594    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
595
596    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
597    let fill_byte = if is_negative { 0xFF } else { 0x00 };
598
599    if bigint_bytes.len() > 32 {
600        bigint_bytes.truncate(32);
601    } else {
602        bigint_bytes.resize(32, fill_byte);
603    };
604
605    let mut array = [0u8; 32];
606    array.copy_from_slice(&bigint_bytes);
607
608    i256::from_le_bytes(array)
609}