datafusion_remote_table/connection/
mysql.rs

1use crate::connection::{RemoteDbType, big_decimal_to_i128, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10    ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11    Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12    LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
13    TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt8Builder, UInt16Builder,
14    UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::prelude::Expr;
21use derive_getters::Getters;
22use derive_with::With;
23use futures::StreamExt;
24use futures::lock::Mutex;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::sync::Arc;
29
/// Connection settings for a MySQL server, consumed by `connect_mysql`.
///
/// `new` fills in defaults for the optional settings; the derived `With`
/// impl provides builder-style overrides and `Getters` provides accessors.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// Server host name or IP address (passed to `OptsBuilder::ip_or_hostname`).
    pub(crate) host: String,
    /// Server TCP port.
    pub(crate) port: u16,
    /// User name used to authenticate.
    pub(crate) username: String,
    /// Password used to authenticate.
    pub(crate) password: String,
    /// Default database selected on connect; `None` (the default from `new`) selects none.
    pub(crate) database: Option<String>,
    /// Upper bound on pooled connections (used as the pool's max constraint); defaults to 10.
    pub(crate) pool_max_size: usize,
    /// Rows per streamed chunk; defaults to 2048.
    /// NOTE(review): presumably surfaced through `ConnectionOptions::stream_chunk_size()`,
    /// which `MysqlConnection::query` uses as its chunk size — confirm.
    pub(crate) stream_chunk_size: usize,
}
40
41impl MysqlConnectionOptions {
42    pub fn new(
43        host: impl Into<String>,
44        port: u16,
45        username: impl Into<String>,
46        password: impl Into<String>,
47    ) -> Self {
48        Self {
49            host: host.into(),
50            port,
51            username: username.into(),
52            password: password.into(),
53            database: None,
54            pool_max_size: 10,
55            stream_chunk_size: 2048,
56        }
57    }
58}
59
/// A pool of MySQL connections backed by `mysql_async::Pool`.
///
/// Created by `connect_mysql`; each `Pool::get` call checks out one
/// connection and wraps it in a `MysqlConnection`.
#[derive(Debug)]
pub struct MysqlPool {
    pool: mysql_async::Pool,
}
64
65pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
66    let pool_opts = mysql_async::PoolOpts::new().with_constraints(
67        mysql_async::PoolConstraints::new(0, options.pool_max_size)
68            .expect("Failed to create pool constraints"),
69    );
70    let opts_builder = mysql_async::OptsBuilder::default()
71        .ip_or_hostname(options.host.clone())
72        .tcp_port(options.port)
73        .user(Some(options.username.clone()))
74        .pass(Some(options.password.clone()))
75        .db_name(options.database.clone())
76        .pool_opts(pool_opts);
77    let pool = mysql_async::Pool::new(opts_builder);
78    Ok(MysqlPool { pool })
79}
80
81#[async_trait::async_trait]
82impl Pool for MysqlPool {
83    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
84        let conn = self.pool.get_conn().await.map_err(|e| {
85            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
86        })?;
87        Ok(Arc::new(MysqlConnection {
88            conn: Arc::new(Mutex::new(conn)),
89        }))
90    }
91}
92
/// A single MySQL connection checked out of a `MysqlPool`.
///
/// The connection sits behind a `futures::lock::Mutex` (an async mutex) in an
/// `Arc`, so `query` can move a clone of the handle into the returned stream
/// and keep the connection locked while rows are being streamed.
#[derive(Debug)]
pub struct MysqlConnection {
    conn: Arc<Mutex<mysql_async::Conn>>,
}
97
#[async_trait::async_trait]
impl Connection for MysqlConnection {
    /// Infers the remote schema of `sql` from the column metadata of its first row.
    ///
    /// The query is first rewritten by `RemoteDbType::Mysql.query_limit_1`
    /// (presumably adding a `LIMIT 1` — confirm) and executed with `query_first`.
    ///
    /// # Errors
    ///
    /// Fails if the query cannot be rewritten or executed, if it returns no
    /// rows (this implementation needs at least one row to read column
    /// metadata from), or if a column's MySQL type is not supported by
    /// `mysql_type_to_remote_type`.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Mysql.query_limit_1(sql)?;
        let mut conn = self.conn.lock().await;
        let conn = &mut *conn;
        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
        })?;
        // An empty result carries no row to take column metadata from.
        let Some(row) = row else {
            return Err(DataFusionError::Execution(
                "No rows returned to infer schema".to_string(),
            ));
        };
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    /// Executes `sql` and streams the result as Arrow record batches.
    ///
    /// `filters` and `limit` are handed to `RemoteDbType::Mysql.try_rewrite_query`
    /// to build the SQL that is actually sent. Rows are pulled in chunks of
    /// `conn_options.stream_chunk_size()` and each chunk becomes one
    /// `RecordBatch` (via `rows_to_batch`) with the projected schema.
    ///
    /// The connection mutex is locked inside the stream body and the guard
    /// lives until that body finishes, so the connection stays busy until the
    /// stream is drained or dropped.
    ///
    /// # Errors
    ///
    /// Schema projection and query rewriting fail eagerly; execution, row
    /// decoding and conversion errors are delivered as `Err` items of the
    /// returned stream.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Mysql.try_rewrite_query(sql, filters, limit)?;
        // Owned copies so the `move` closure below can outlive this call.
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let conn = Arc::clone(&self.conn);
        // Stage 1: yields `Result<Vec<Row>, DataFusionError>` chunks.
        // NOTE(review): `?` inside `stream!` relies on async-stream turning the
        // error into a yielded `Err` item that ends the stream — confirm against
        // the async-stream version in use.
        let stream = Box::pin(stream! {
            let mut conn = conn.lock().await;
            let mut query_iter = conn
                .query_iter(sql.clone())
                .await
                .map_err(|e| {
                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
                })?;

            // `None` here means no row stream was produced (presumably no
            // result set is available); it is reported as an error.
            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
                })? else {
                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
                return;
            };

            let mut chunked_stream = stream.chunks(chunk_size).boxed();

            while let Some(chunk) = chunked_stream.next().await {
                // A single failed row fails the whole chunk.
                let rows = chunk
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to collect rows from mysql due to {e}",
                        ))
                    })?;

                yield Ok::<_, DataFusionError>(rows)
            }
        });

        // Stage 2: convert each row chunk into a projected RecordBatch.
        let stream = stream.map(move |rows| {
            let rows = rows?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
173
174fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
175    let character_set = mysql_col.character_set();
176    let is_utf8_bin_character_set = character_set == 45;
177    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
178    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
179    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
180    let col_length = mysql_col.column_length();
181    match mysql_col.column_type() {
182        ColumnType::MYSQL_TYPE_TINY => {
183            if is_unsigned {
184                Ok(MysqlType::TinyIntUnsigned)
185            } else {
186                Ok(MysqlType::TinyInt)
187            }
188        }
189        ColumnType::MYSQL_TYPE_SHORT => {
190            if is_unsigned {
191                Ok(MysqlType::SmallIntUnsigned)
192            } else {
193                Ok(MysqlType::SmallInt)
194            }
195        }
196        ColumnType::MYSQL_TYPE_INT24 => {
197            if is_unsigned {
198                Ok(MysqlType::MediumIntUnsigned)
199            } else {
200                Ok(MysqlType::MediumInt)
201            }
202        }
203        ColumnType::MYSQL_TYPE_LONG => {
204            if is_unsigned {
205                Ok(MysqlType::IntegerUnsigned)
206            } else {
207                Ok(MysqlType::Integer)
208            }
209        }
210        ColumnType::MYSQL_TYPE_LONGLONG => {
211            if is_unsigned {
212                Ok(MysqlType::BigIntUnsigned)
213            } else {
214                Ok(MysqlType::BigInt)
215            }
216        }
217        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
218        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
219        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
220            let precision = (mysql_col.column_length() - 2) as u8;
221            let scale = mysql_col.decimals();
222            Ok(MysqlType::Decimal(precision, scale))
223        }
224        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
225        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
226        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
227        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
228        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
229        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
230        ColumnType::MYSQL_TYPE_STRING if is_binary => {
231            if is_utf8_bin_character_set {
232                Ok(MysqlType::Char)
233            } else {
234                Ok(MysqlType::Binary)
235            }
236        }
237        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
238        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
239            if is_utf8_bin_character_set {
240                Ok(MysqlType::Varchar)
241            } else {
242                Ok(MysqlType::Varbinary)
243            }
244        }
245        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
246        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
247        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
248            if is_utf8_bin_character_set {
249                Ok(MysqlType::Text(col_length))
250            } else {
251                Ok(MysqlType::Blob(col_length))
252            }
253        }
254        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
255        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
256        _ => Err(DataFusionError::NotImplemented(format!(
257            "Unsupported mysql type: {mysql_col:?}",
258        ))),
259    }
260}
261
262fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
263    let mut remote_fields = vec![];
264    for col in row.columns_ref() {
265        remote_fields.push(RemoteField::new(
266            col.name_str().to_string(),
267            RemoteType::Mysql(mysql_type_to_remote_type(col)?),
268            true,
269        ));
270    }
271    Ok(RemoteSchema::new(remote_fields))
272}
273
/// Appends the value at column `$index` of `$row` to `$builder`.
///
/// * `$builder` is downcast to the concrete `$builder_ty`; a mismatch panics,
///   since it means the builder does not match the field's Arrow type.
/// * The cell is read with `get_opt::<$value_ty, usize>`, then converted by
///   `$convert`, a callable returning `Result<_, DataFusionError>`. A `?` is
///   applied to that result, so a conversion error returns from the
///   *enclosing* function.
/// * A missing cell (`None`) or a SQL `NULL` that fails conversion appends a null.
/// * Any other read error makes the enclosing function `return Err(...)`.
///
/// `$field` and `$col` are only used in diagnostic messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            // A NULL cell cannot be read as a non-Option type; record it as null.
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
304
305fn rows_to_batch(
306    rows: &[Row],
307    table_schema: &SchemaRef,
308    projection: Option<&Vec<usize>>,
309) -> DFResult<RecordBatch> {
310    let projected_schema = project_schema(table_schema, projection)?;
311    let mut array_builders = vec![];
312    for field in table_schema.fields() {
313        let builder = make_builder(field.data_type(), rows.len());
314        array_builders.push(builder);
315    }
316
317    for row in rows {
318        for (idx, field) in table_schema.fields.iter().enumerate() {
319            if !projections_contains(projection, idx) {
320                continue;
321            }
322            let builder = &mut array_builders[idx];
323            let col = row.columns_ref().get(idx);
324            match field.data_type() {
325                DataType::Int8 => {
326                    handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
327                        Ok::<_, DataFusionError>(v)
328                    });
329                }
330                DataType::Int16 => {
331                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
332                        Ok::<_, DataFusionError>(v)
333                    });
334                }
335                DataType::Int32 => {
336                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
337                        Ok::<_, DataFusionError>(v)
338                    });
339                }
340                DataType::Int64 => {
341                    handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
342                        Ok::<_, DataFusionError>(v)
343                    });
344                }
345                DataType::UInt8 => {
346                    handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
347                        Ok::<_, DataFusionError>(v)
348                    });
349                }
350                DataType::UInt16 => {
351                    handle_primitive_type!(
352                        builder,
353                        field,
354                        col,
355                        UInt16Builder,
356                        u16,
357                        row,
358                        idx,
359                        |v| { Ok::<_, DataFusionError>(v) }
360                    );
361                }
362                DataType::UInt32 => {
363                    handle_primitive_type!(
364                        builder,
365                        field,
366                        col,
367                        UInt32Builder,
368                        u32,
369                        row,
370                        idx,
371                        |v| { Ok::<_, DataFusionError>(v) }
372                    );
373                }
374                DataType::UInt64 => {
375                    handle_primitive_type!(
376                        builder,
377                        field,
378                        col,
379                        UInt64Builder,
380                        u64,
381                        row,
382                        idx,
383                        |v| { Ok::<_, DataFusionError>(v) }
384                    );
385                }
386                DataType::Float32 => {
387                    handle_primitive_type!(
388                        builder,
389                        field,
390                        col,
391                        Float32Builder,
392                        f32,
393                        row,
394                        idx,
395                        |v| { Ok::<_, DataFusionError>(v) }
396                    );
397                }
398                DataType::Float64 => {
399                    handle_primitive_type!(
400                        builder,
401                        field,
402                        col,
403                        Float64Builder,
404                        f64,
405                        row,
406                        idx,
407                        |v| { Ok::<_, DataFusionError>(v) }
408                    );
409                }
410                DataType::Decimal128(_precision, scale) => {
411                    handle_primitive_type!(
412                        builder,
413                        field,
414                        col,
415                        Decimal128Builder,
416                        BigDecimal,
417                        row,
418                        idx,
419                        |v: BigDecimal| {
420                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
421                                DataFusionError::Execution(format!(
422                                    "Failed to convert BigDecimal {v:?} to i128"
423                                ))
424                            })
425                        }
426                    );
427                }
428                DataType::Decimal256(_precision, _scale) => {
429                    handle_primitive_type!(
430                        builder,
431                        field,
432                        col,
433                        Decimal256Builder,
434                        BigDecimal,
435                        row,
436                        idx,
437                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
438                    );
439                }
440                DataType::Date32 => {
441                    handle_primitive_type!(
442                        builder,
443                        field,
444                        col,
445                        Date32Builder,
446                        chrono::NaiveDate,
447                        row,
448                        idx,
449                        |v: chrono::NaiveDate| {
450                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
451                        }
452                    );
453                }
454                DataType::Timestamp(TimeUnit::Microsecond, None) => {
455                    handle_primitive_type!(
456                        builder,
457                        field,
458                        col,
459                        TimestampMicrosecondBuilder,
460                        time::PrimitiveDateTime,
461                        row,
462                        idx,
463                        |v: time::PrimitiveDateTime| {
464                            let timestamp_micros =
465                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
466                            Ok::<_, DataFusionError>(timestamp_micros)
467                        }
468                    );
469                }
470                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
471                    handle_primitive_type!(
472                        builder,
473                        field,
474                        col,
475                        TimestampNanosecondBuilder,
476                        chrono::NaiveTime,
477                        row,
478                        idx,
479                        |v: chrono::NaiveTime| {
480                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
481                                + i64::from(v.nanosecond());
482                            Ok::<_, DataFusionError>(t)
483                        }
484                    );
485                }
486                DataType::Time32(TimeUnit::Second) => {
487                    handle_primitive_type!(
488                        builder,
489                        field,
490                        col,
491                        Time32SecondBuilder,
492                        chrono::NaiveTime,
493                        row,
494                        idx,
495                        |v: chrono::NaiveTime| {
496                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
497                        }
498                    );
499                }
500                DataType::Time64(TimeUnit::Nanosecond) => {
501                    handle_primitive_type!(
502                        builder,
503                        field,
504                        col,
505                        Time64NanosecondBuilder,
506                        chrono::NaiveTime,
507                        row,
508                        idx,
509                        |v: chrono::NaiveTime| {
510                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
511                                + i64::from(v.nanosecond());
512                            Ok::<_, DataFusionError>(t)
513                        }
514                    );
515                }
516                DataType::Utf8 => {
517                    handle_primitive_type!(
518                        builder,
519                        field,
520                        col,
521                        StringBuilder,
522                        String,
523                        row,
524                        idx,
525                        |v| { Ok::<_, DataFusionError>(v) }
526                    );
527                }
528                DataType::LargeUtf8 => {
529                    handle_primitive_type!(
530                        builder,
531                        field,
532                        col,
533                        LargeStringBuilder,
534                        String,
535                        row,
536                        idx,
537                        |v| { Ok::<_, DataFusionError>(v) }
538                    );
539                }
540                DataType::Binary => {
541                    handle_primitive_type!(
542                        builder,
543                        field,
544                        col,
545                        BinaryBuilder,
546                        Vec<u8>,
547                        row,
548                        idx,
549                        |v| { Ok::<_, DataFusionError>(v) }
550                    );
551                }
552                DataType::LargeBinary => {
553                    handle_primitive_type!(
554                        builder,
555                        field,
556                        col,
557                        LargeBinaryBuilder,
558                        Vec<u8>,
559                        row,
560                        idx,
561                        |v| { Ok::<_, DataFusionError>(v) }
562                    );
563                }
564                _ => {
565                    return Err(DataFusionError::NotImplemented(format!(
566                        "Unsupported data type {:?} for col: {:?}",
567                        field.data_type(),
568                        col
569                    )));
570                }
571            }
572        }
573    }
574    let projected_columns = array_builders
575        .into_iter()
576        .enumerate()
577        .filter(|(idx, _)| projections_contains(projection, *idx))
578        .map(|(_, mut builder)| builder.finish())
579        .collect::<Vec<ArrayRef>>();
580    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
581}
582
583fn to_decimal_256(decimal: &BigDecimal) -> i256 {
584    let (bigint_value, _) = decimal.as_bigint_and_exponent();
585    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
586
587    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
588    let fill_byte = if is_negative { 0xFF } else { 0x00 };
589
590    if bigint_bytes.len() > 32 {
591        bigint_bytes.truncate(32);
592    } else {
593        bigint_bytes.resize(32, fill_byte);
594    };
595
596    let mut array = [0u8; 32];
597    array.copy_from_slice(&bigint_bytes);
598
599    i256::from_le_bytes(array)
600}