// datafusion_remote_table/connection/mysql.rs

1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4    Connection, DFResult, MysqlType, Pool, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteType,
5    Transform,
6};
7use async_stream::stream;
8use bigdecimal::num_bigint;
9use chrono::Timelike;
10use datafusion::arrow::array::{
11    make_builder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder,
12    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
13    LargeBinaryBuilder, LargeStringBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder,
14    TimestampMicrosecondBuilder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
15};
16use datafusion::arrow::datatypes::{i256, DataType, Date32Type, SchemaRef, TimeUnit};
17use datafusion::common::{project_schema, DataFusionError};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use futures::lock::Mutex;
21use futures::StreamExt;
22use mysql_async::consts::{ColumnFlags, ColumnType};
23use mysql_async::prelude::Queryable;
24use mysql_async::{Column, Row};
25use std::sync::Arc;
26
27#[derive(Debug, Clone, derive_with::With)]
28pub struct MysqlConnectionOptions {
29    pub(crate) host: String,
30    pub(crate) port: u16,
31    pub(crate) username: String,
32    pub(crate) password: String,
33    pub(crate) database: Option<String>,
34}
35
36impl MysqlConnectionOptions {
37    pub fn new(
38        host: impl Into<String>,
39        port: u16,
40        username: impl Into<String>,
41        password: impl Into<String>,
42    ) -> Self {
43        Self {
44            host: host.into(),
45            port,
46            username: username.into(),
47            password: password.into(),
48            database: None,
49        }
50    }
51}
52
53#[derive(Debug)]
54pub struct MysqlPool {
55    pool: mysql_async::Pool,
56}
57
58pub fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
59    let opts_builder = mysql_async::OptsBuilder::default()
60        .ip_or_hostname(options.host.clone())
61        .tcp_port(options.port)
62        .user(Some(options.username.clone()))
63        .pass(Some(options.password.clone()))
64        .db_name(options.database.clone());
65    let pool = mysql_async::Pool::new(opts_builder);
66    Ok(MysqlPool { pool })
67}
68
69#[async_trait::async_trait]
70impl Pool for MysqlPool {
71    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
72        let conn = self.pool.get_conn().await.map_err(|e| {
73            DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
74        })?;
75        Ok(Arc::new(MysqlConnection {
76            conn: Arc::new(Mutex::new(conn)),
77        }))
78    }
79}
80
81#[derive(Debug)]
82pub struct MysqlConnection {
83    conn: Arc<Mutex<mysql_async::Conn>>,
84}
85
86#[async_trait::async_trait]
87impl Connection for MysqlConnection {
88    async fn infer_schema(
89        &self,
90        sql: &str,
91        transform: Option<Arc<dyn Transform>>,
92    ) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
93        let mut conn = self.conn.lock().await;
94        let conn = &mut *conn;
95        let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
96            DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
97        })?;
98        let Some(row) = row else {
99            return Err(DataFusionError::Execution(
100                "No rows returned to infer schema".to_string(),
101            ));
102        };
103        let remote_schema = Arc::new(build_remote_schema(&row)?);
104        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
105        if let Some(transform) = transform {
106            let batch = rows_to_batch(&[row], &arrow_schema, None)?;
107            let transformed_batch = transform_batch(
108                batch,
109                transform.as_ref(),
110                &arrow_schema,
111                None,
112                Some(&remote_schema),
113            )?;
114            Ok((remote_schema, transformed_batch.schema()))
115        } else {
116            Ok((remote_schema, arrow_schema))
117        }
118    }
119
120    async fn query(
121        &self,
122        sql: String,
123        table_schema: SchemaRef,
124        projection: Option<Vec<usize>>,
125    ) -> DFResult<SendableRecordBatchStream> {
126        let projected_schema = project_schema(&table_schema, projection.as_ref())?;
127        let conn = Arc::clone(&self.conn);
128        let stream = Box::pin(stream! {
129            let mut conn = conn.lock().await;
130            let mut query_iter = conn
131                .query_iter(sql)
132                .await
133                .map_err(|e| {
134                    DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
135                })?;
136
137            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
138                    DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
139                })? else {
140                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
141                return;
142            };
143
144            let mut chunked_stream = stream.chunks(4_000).boxed();
145
146            while let Some(chunk) = chunked_stream.next().await {
147                let rows = chunk
148                    .into_iter()
149                    .collect::<Result<Vec<_>, _>>()
150                    .map_err(|e| {
151                        DataFusionError::Execution(format!(
152                            "Failed to collect rows from mysql due to {e}",
153                        ))
154                    })?;
155
156                yield Ok::<_, DataFusionError>(rows)
157            }
158        });
159
160        let stream = stream.map(move |rows| {
161            let rows = rows?;
162            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
163        });
164
165        Ok(Box::pin(RecordBatchStreamAdapter::new(
166            projected_schema,
167            stream,
168        )))
169    }
170}
171
172fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
173    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
174    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
175    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
176    let col_length = mysql_col.column_length();
177    match mysql_col.column_type() {
178        ColumnType::MYSQL_TYPE_TINY => {
179            if is_unsigned {
180                Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
181            } else {
182                Ok(RemoteType::Mysql(MysqlType::TinyInt))
183            }
184        }
185        ColumnType::MYSQL_TYPE_SHORT => {
186            if is_unsigned {
187                Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
188            } else {
189                Ok(RemoteType::Mysql(MysqlType::SmallInt))
190            }
191        }
192        ColumnType::MYSQL_TYPE_INT24 => {
193            if is_unsigned {
194                Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
195            } else {
196                Ok(RemoteType::Mysql(MysqlType::MediumInt))
197            }
198        }
199        ColumnType::MYSQL_TYPE_LONG => {
200            if is_unsigned {
201                Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
202            } else {
203                Ok(RemoteType::Mysql(MysqlType::Integer))
204            }
205        }
206        ColumnType::MYSQL_TYPE_LONGLONG => {
207            if is_unsigned {
208                Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
209            } else {
210                Ok(RemoteType::Mysql(MysqlType::BigInt))
211            }
212        }
213        ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
214        ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
215        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
216            let precision = (mysql_col.column_length() - 2) as u8;
217            let scale = mysql_col.decimals();
218            Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
219        }
220        ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
221        ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
222        ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
223        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
224        ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
225        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
226        ColumnType::MYSQL_TYPE_STRING if is_binary => Ok(RemoteType::Mysql(MysqlType::Binary)),
227        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
228            Ok(RemoteType::Mysql(MysqlType::Varchar))
229        }
230        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
231            Ok(RemoteType::Mysql(MysqlType::Varbinary))
232        }
233        ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
234        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
235            Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
236        }
237        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
238            Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
239        }
240        ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
241        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
242        _ => Err(DataFusionError::NotImplemented(format!(
243            "Unsupported mysql type: {mysql_col:?}",
244        ))),
245    }
246}
247
248fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
249    let mut remote_fields = vec![];
250    for col in row.columns_ref() {
251        remote_fields.push(RemoteField::new(
252            col.name_str().to_string(),
253            mysql_type_to_remote_type(col)?,
254            true,
255        ));
256    }
257    Ok(RemoteSchema::new(remote_fields))
258}
259
/// Appends cell `$index` of `$row` to `$builder`, read as `Option<$value_ty>`.
///
/// `$builder` must downcast to the concrete `$builder_ty`. SQL NULL, and an index with no
/// cell, are appended as nulls. If the value cannot be converted to `$value_ty`, the
/// *enclosing function* returns `Err(DataFusionError::Execution(..))`; `Row::get` would
/// panic in that case. The macro may therefore only be used inside functions that return
/// a `Result` with `DataFusionError` as the error type.
///
/// Panics if the builder is not a `$builder_ty`, which means the builders were not created
/// from the schema being iterated. `$mysql_col` is only used in diagnostics.
macro_rules! handle_primitive_type {
    ($builder:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $mysql_col
                )
            });

        // `get_opt` reports conversion failures instead of panicking like `get` does.
        match $row.get_opt::<Option<$value_ty>, usize>($index) {
            Some(Ok(Some(v))) => builder.append_value(v),
            Some(Ok(None)) | None => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to convert mysql value at index {} to {} for {:?}: {:?}",
                    $index,
                    stringify!($value_ty),
                    $mysql_col,
                    e
                )));
            }
        }
    }};
}
283
284fn rows_to_batch(
285    rows: &[Row],
286    table_schema: &SchemaRef,
287    projection: Option<&Vec<usize>>,
288) -> DFResult<RecordBatch> {
289    let projected_schema = project_schema(table_schema, projection)?;
290    let mut array_builders = vec![];
291    for field in table_schema.fields() {
292        let builder = make_builder(field.data_type(), rows.len());
293        array_builders.push(builder);
294    }
295
296    for row in rows {
297        for (idx, field) in table_schema.fields.iter().enumerate() {
298            if !projections_contains(projection, idx) {
299                continue;
300            }
301            let builder = &mut array_builders[idx];
302            let col = row.columns_ref().get(idx);
303            match field.data_type() {
304                DataType::Int8 => {
305                    handle_primitive_type!(builder, col, Int8Builder, i8, row, idx);
306                }
307                DataType::Int16 => {
308                    handle_primitive_type!(builder, col, Int16Builder, i16, row, idx);
309                }
310                DataType::Int32 => {
311                    handle_primitive_type!(builder, col, Int32Builder, i32, row, idx);
312                }
313                DataType::Int64 => {
314                    handle_primitive_type!(builder, col, Int64Builder, i64, row, idx);
315                }
316                DataType::UInt8 => {
317                    handle_primitive_type!(builder, col, UInt8Builder, u8, row, idx);
318                }
319                DataType::UInt16 => {
320                    handle_primitive_type!(builder, col, UInt16Builder, u16, row, idx);
321                }
322                DataType::UInt32 => {
323                    handle_primitive_type!(builder, col, UInt32Builder, u32, row, idx);
324                }
325                DataType::UInt64 => {
326                    handle_primitive_type!(builder, col, UInt64Builder, u64, row, idx);
327                }
328                DataType::Float32 => {
329                    handle_primitive_type!(builder, col, Float32Builder, f32, row, idx);
330                }
331                DataType::Float64 => {
332                    handle_primitive_type!(builder, col, Float64Builder, f64, row, idx);
333                }
334                DataType::Decimal128(_precision, scale) => {
335                    let builder = builder
336                        .as_any_mut()
337                        .downcast_mut::<Decimal128Builder>()
338                        .unwrap_or_else(|| {
339                            panic!("Failed to downcast builder to Decimal128Builder for {field:?}")
340                        });
341                    let v = row.get::<Option<bigdecimal::BigDecimal>, usize>(idx);
342
343                    match v {
344                        Some(Some(v)) => {
345                            let Some(v) = big_decimal_to_i128(&v, Some(*scale as u32)) else {
346                                return Err(DataFusionError::Execution(format!(
347                                    "Failed to convert BigDecimal {v:?} to i128"
348                                )));
349                            };
350                            builder.append_value(v)
351                        }
352                        _ => builder.append_null(),
353                    }
354                }
355                DataType::Decimal256(_precision, _scale) => {
356                    let builder = builder
357                        .as_any_mut()
358                        .downcast_mut::<Decimal256Builder>()
359                        .unwrap_or_else(|| {
360                            panic!("Failed to downcast builder to Decimal256Builder for {field:?}")
361                        });
362                    let v = row.get::<Option<bigdecimal::BigDecimal>, usize>(idx);
363
364                    match v {
365                        Some(Some(v)) => builder.append_value(to_decimal_256(&v)),
366                        _ => builder.append_null(),
367                    }
368                }
369                DataType::Date32 => {
370                    // TODO this could also be simplified by macro
371                    let builder = builder
372                        .as_any_mut()
373                        .downcast_mut::<Date32Builder>()
374                        .unwrap_or_else(|| {
375                            panic!("Failed to downcast builder to Date32Builder for {field:?}")
376                        });
377                    let v = row.get::<Option<chrono::NaiveDate>, usize>(idx);
378
379                    match v {
380                        Some(Some(v)) => builder.append_value(Date32Type::from_naive_date(v)),
381                        _ => builder.append_null(),
382                    }
383                }
384                DataType::Timestamp(TimeUnit::Microsecond, None) => {
385                    let builder = builder
386                                .as_any_mut()
387                                .downcast_mut::<TimestampMicrosecondBuilder>()
388                                .unwrap_or_else(|| {
389                                    panic!("Failed to downcast builder to TimestampMicrosecondBuilder for {field:?}")
390                                });
391                    let v = row.get::<Option<time::PrimitiveDateTime>, usize>(idx);
392
393                    match v {
394                        Some(Some(v)) => {
395                            let timestamp_micros =
396                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
397                            builder.append_value(timestamp_micros)
398                        }
399                        _ => builder.append_null(),
400                    }
401                }
402                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
403                    let builder = builder
404                                .as_any_mut()
405                                .downcast_mut::<Time64NanosecondBuilder>()
406                                .unwrap_or_else(|| {
407                                    panic!("Failed to downcast builder to Time64NanosecondBuilder for {field:?}")
408                                });
409                    let v = row.get::<Option<chrono::NaiveTime>, usize>(idx);
410
411                    match v {
412                        Some(Some(v)) => {
413                            builder.append_value(
414                                i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
415                                    + i64::from(v.nanosecond()),
416                            );
417                        }
418                        _ => builder.append_null(),
419                    }
420                }
421                DataType::Time64(TimeUnit::Nanosecond) => {
422                    let builder = builder
423                        .as_any_mut()
424                        .downcast_mut::<Time64NanosecondBuilder>()
425                        .unwrap_or_else(|| {
426                            panic!("Failed to downcast builder to Time64NanosecondBuilder for {field:?}")
427                        });
428                    let v = row.get::<Option<chrono::NaiveTime>, usize>(idx);
429
430                    match v {
431                        Some(Some(v)) => {
432                            builder.append_value(
433                                i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
434                                    + i64::from(v.nanosecond()),
435                            );
436                        }
437                        _ => builder.append_null(),
438                    }
439                }
440                DataType::Utf8 => {
441                    handle_primitive_type!(builder, col, StringBuilder, String, row, idx);
442                }
443                DataType::LargeUtf8 => {
444                    handle_primitive_type!(builder, col, LargeStringBuilder, String, row, idx);
445                }
446                DataType::Binary => {
447                    handle_primitive_type!(builder, col, BinaryBuilder, Vec<u8>, row, idx);
448                }
449                DataType::LargeBinary => {
450                    handle_primitive_type!(builder, col, LargeBinaryBuilder, Vec<u8>, row, idx);
451                }
452                _ => {
453                    return Err(DataFusionError::NotImplemented(format!(
454                        "Unsupported data type {:?} for col: {:?}",
455                        field.data_type(),
456                        col
457                    )));
458                }
459            }
460        }
461    }
462    let projected_columns = array_builders
463        .into_iter()
464        .enumerate()
465        .filter(|(idx, _)| projections_contains(projection, *idx))
466        .map(|(_, mut builder)| builder.finish())
467        .collect::<Vec<ArrayRef>>();
468    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
469}
470
471fn to_decimal_256(decimal: &bigdecimal::BigDecimal) -> i256 {
472    let (bigint_value, _) = decimal.as_bigint_and_exponent();
473    let mut bigint_bytes = bigint_value.to_signed_bytes_le();
474
475    let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
476    let fill_byte = if is_negative { 0xFF } else { 0x00 };
477
478    if bigint_bytes.len() > 32 {
479        bigint_bytes.truncate(32);
480    } else {
481        bigint_bytes.resize(32, fill_byte);
482    };
483
484    let mut array = [0u8; 32];
485    array.copy_from_slice(&bigint_bytes);
486
487    i256::from_le_bytes(array)
488}