1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType, TableSource, Unparse,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12 LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
13 Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
14 UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::any::Any;
29use std::sync::Arc;
30
/// Settings needed to reach a MySQL server and to size the connection pool.
///
/// Builder-style and getter methods are generated by the `With` and `Getters`
/// derives.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// Host name or IP address of the server.
    pub(crate) host: String,
    /// TCP port of the server.
    pub(crate) port: u16,
    /// User name used to authenticate.
    pub(crate) username: String,
    /// Password used to authenticate.
    pub(crate) password: String,
    /// Database selected on connect; `None` selects no database.
    pub(crate) database: Option<String>,
    /// Upper bound passed to the pool constraints (the lower bound is 0).
    pub(crate) pool_max_size: usize,
    /// NOTE(review): presumably the number of rows grouped into one record
    /// batch when streaming query results — confirm against
    /// `ConnectionOptions::stream_chunk_size`.
    pub(crate) stream_chunk_size: usize,
}
41
42impl MysqlConnectionOptions {
43 pub fn new(
44 host: impl Into<String>,
45 port: u16,
46 username: impl Into<String>,
47 password: impl Into<String>,
48 ) -> Self {
49 Self {
50 host: host.into(),
51 port,
52 username: username.into(),
53 password: password.into(),
54 database: None,
55 pool_max_size: 10,
56 stream_chunk_size: 2048,
57 }
58 }
59}
60
61impl From<MysqlConnectionOptions> for ConnectionOptions {
62 fn from(options: MysqlConnectionOptions) -> Self {
63 ConnectionOptions::Mysql(options)
64 }
65}
66
/// Connection pool for MySQL; a thin wrapper around [`mysql_async::Pool`].
#[derive(Debug)]
pub struct MysqlPool {
    /// The underlying `mysql_async` pool that connections are taken from.
    pool: mysql_async::Pool,
}
71
72pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
73 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
74 mysql_async::PoolConstraints::new(0, options.pool_max_size)
75 .expect("Failed to create pool constraints"),
76 );
77 let opts_builder = mysql_async::OptsBuilder::default()
78 .ip_or_hostname(options.host.clone())
79 .tcp_port(options.port)
80 .user(Some(options.username.clone()))
81 .pass(Some(options.password.clone()))
82 .db_name(options.database.clone())
83 .init(vec!["set time_zone='+00:00'".to_string()])
84 .pool_opts(pool_opts);
85 let pool = mysql_async::Pool::new(opts_builder);
86 Ok(MysqlPool { pool })
87}
88
89#[async_trait::async_trait]
90impl Pool for MysqlPool {
91 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
92 let conn = self.pool.get_conn().await.map_err(|e| {
93 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
94 })?;
95 Ok(Arc::new(MysqlConnection {
96 conn: Arc::new(Mutex::new(conn)),
97 }))
98 }
99}
100
/// A single MySQL connection taken from a [`MysqlPool`].
#[derive(Debug)]
pub struct MysqlConnection {
    /// The connection sits behind an async mutex inside an `Arc` so that the
    /// stream returned by `query` can own a handle to it and lock it while
    /// rows are being read.
    conn: Arc<Mutex<mysql_async::Conn>>,
}
105
106#[async_trait::async_trait]
107impl Connection for MysqlConnection {
108 fn as_any(&self) -> &dyn Any {
109 self
110 }
111
112 async fn infer_schema(&self, source: &TableSource) -> DFResult<RemoteSchemaRef> {
113 match source {
114 TableSource::Table(_table) => Err(DataFusionError::Execution(
115 "Mysql does not support infer schema for table".to_string(),
116 )),
117 TableSource::Query(_query) => {
118 let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
119 let mut conn = self.conn.lock().await;
120 let conn = &mut *conn;
121 let stmt = conn.prep(&sql).await.map_err(|e| {
122 DataFusionError::Execution(format!(
123 "Failed to prepare query {sql} on mysql: {e:?}"
124 ))
125 })?;
126 let remote_schema = Arc::new(build_remote_schema(&stmt)?);
127 Ok(remote_schema)
128 }
129 }
130 }
131
132 async fn query(
133 &self,
134 conn_options: &ConnectionOptions,
135 source: &TableSource,
136 table_schema: SchemaRef,
137 projection: Option<&Vec<usize>>,
138 unparsed_filters: &[String],
139 limit: Option<usize>,
140 ) -> DFResult<SendableRecordBatchStream> {
141 let projected_schema = project_schema(&table_schema, projection)?;
142
143 let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
144 debug!("[remote-table] executing mysql query: {sql}");
145
146 let projection = projection.cloned();
147 let chunk_size = conn_options.stream_chunk_size();
148 let conn = Arc::clone(&self.conn);
149 let stream = Box::pin(stream! {
150 let mut conn = conn.lock().await;
151 let mut query_iter = conn
152 .query_iter(sql.clone())
153 .await
154 .map_err(|e| {
155 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
156 })?;
157
158 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
159 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
160 })? else {
161 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
162 return;
163 };
164
165 let mut chunked_stream = stream.chunks(chunk_size).boxed();
166
167 while let Some(chunk) = chunked_stream.next().await {
168 let rows = chunk
169 .into_iter()
170 .collect::<Result<Vec<_>, _>>()
171 .map_err(|e| {
172 DataFusionError::Execution(format!(
173 "Failed to collect rows from mysql due to {e}",
174 ))
175 })?;
176
177 yield Ok::<_, DataFusionError>(rows)
178 }
179 });
180
181 let stream = stream.map(move |rows| {
182 let rows = rows?;
183 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
184 });
185
186 Ok(Box::pin(RecordBatchStreamAdapter::new(
187 projected_schema,
188 stream,
189 )))
190 }
191
192 async fn insert(
193 &self,
194 _conn_options: &ConnectionOptions,
195 _unparser: Arc<dyn Unparse>,
196 _table: &[String],
197 _remote_schema: RemoteSchemaRef,
198 _input: SendableRecordBatchStream,
199 ) -> DFResult<usize> {
200 Err(DataFusionError::Execution(
201 "Insert operation is not supported for mysql".to_string(),
202 ))
203 }
204}
205
206fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
207 let character_set = mysql_col.character_set();
208 let is_utf8_bin_character_set = character_set == 45;
209 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
210 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
211 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
212 let col_length = mysql_col.column_length();
213 match mysql_col.column_type() {
214 ColumnType::MYSQL_TYPE_TINY => {
215 if is_unsigned {
216 Ok(MysqlType::TinyIntUnsigned)
217 } else {
218 Ok(MysqlType::TinyInt)
219 }
220 }
221 ColumnType::MYSQL_TYPE_SHORT => {
222 if is_unsigned {
223 Ok(MysqlType::SmallIntUnsigned)
224 } else {
225 Ok(MysqlType::SmallInt)
226 }
227 }
228 ColumnType::MYSQL_TYPE_INT24 => {
229 if is_unsigned {
230 Ok(MysqlType::MediumIntUnsigned)
231 } else {
232 Ok(MysqlType::MediumInt)
233 }
234 }
235 ColumnType::MYSQL_TYPE_LONG => {
236 if is_unsigned {
237 Ok(MysqlType::IntegerUnsigned)
238 } else {
239 Ok(MysqlType::Integer)
240 }
241 }
242 ColumnType::MYSQL_TYPE_LONGLONG => {
243 if is_unsigned {
244 Ok(MysqlType::BigIntUnsigned)
245 } else {
246 Ok(MysqlType::BigInt)
247 }
248 }
249 ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
250 ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
251 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
252 let precision = (mysql_col.column_length() - 2) as u8;
253 let scale = mysql_col.decimals();
254 Ok(MysqlType::Decimal(precision, scale))
255 }
256 ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
257 ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
258 ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
259 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
260 ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
261 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
262 ColumnType::MYSQL_TYPE_STRING if is_binary => {
263 if is_utf8_bin_character_set {
264 Ok(MysqlType::Char)
265 } else {
266 Ok(MysqlType::Binary)
267 }
268 }
269 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
270 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
271 if is_utf8_bin_character_set {
272 Ok(MysqlType::Varchar)
273 } else {
274 Ok(MysqlType::Varbinary)
275 }
276 }
277 ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
278 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
279 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
280 if is_utf8_bin_character_set {
281 Ok(MysqlType::Text(col_length))
282 } else {
283 Ok(MysqlType::Blob(col_length))
284 }
285 }
286 ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
287 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
288 _ => Err(DataFusionError::NotImplemented(format!(
289 "Unsupported mysql type: {mysql_col:?}",
290 ))),
291 }
292}
293
294fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
295 let mut remote_fields = vec![];
296 for col in stmt.columns() {
297 remote_fields.push(RemoteField::new(
298 col.name_str().to_string(),
299 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
300 true,
301 ));
302 }
303 Ok(RemoteSchema::new(remote_fields))
304}
305
/// Reads column `$index` of `$row` as `$value_ty`, converts it with
/// `$convert` and appends the result to `$builder` downcast to `$builder_ty`.
///
/// A missing column or a SQL `NULL` appends a null. Any other extraction
/// failure returns `DataFusionError::Execution` from the enclosing function,
/// and a failing `$convert` is propagated with `?`. `$field` and `$col` are
/// only used in diagnostics.
///
/// Panics if `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(value)) => typed_builder.append_value($convert(value)?),
            // No such column, or the value is SQL NULL.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
336
337fn rows_to_batch(
338 rows: &[Row],
339 table_schema: &SchemaRef,
340 projection: Option<&Vec<usize>>,
341) -> DFResult<RecordBatch> {
342 let projected_schema = project_schema(table_schema, projection)?;
343 let mut array_builders = vec![];
344 for field in table_schema.fields() {
345 let builder = make_builder(field.data_type(), rows.len());
346 array_builders.push(builder);
347 }
348
349 for row in rows {
350 for (idx, field) in table_schema.fields.iter().enumerate() {
351 if !projections_contains(projection, idx) {
352 continue;
353 }
354 let builder = &mut array_builders[idx];
355 let col = row.columns_ref().get(idx);
356 match field.data_type() {
357 DataType::Int8 => {
358 handle_primitive_type!(
359 builder,
360 field,
361 col,
362 Int8Builder,
363 i8,
364 row,
365 idx,
366 just_return
367 );
368 }
369 DataType::Int16 => {
370 handle_primitive_type!(
371 builder,
372 field,
373 col,
374 Int16Builder,
375 i16,
376 row,
377 idx,
378 just_return
379 );
380 }
381 DataType::Int32 => {
382 handle_primitive_type!(
383 builder,
384 field,
385 col,
386 Int32Builder,
387 i32,
388 row,
389 idx,
390 just_return
391 );
392 }
393 DataType::Int64 => {
394 handle_primitive_type!(
395 builder,
396 field,
397 col,
398 Int64Builder,
399 i64,
400 row,
401 idx,
402 just_return
403 );
404 }
405 DataType::UInt8 => {
406 handle_primitive_type!(
407 builder,
408 field,
409 col,
410 UInt8Builder,
411 u8,
412 row,
413 idx,
414 just_return
415 );
416 }
417 DataType::UInt16 => {
418 handle_primitive_type!(
419 builder,
420 field,
421 col,
422 UInt16Builder,
423 u16,
424 row,
425 idx,
426 just_return
427 );
428 }
429 DataType::UInt32 => {
430 handle_primitive_type!(
431 builder,
432 field,
433 col,
434 UInt32Builder,
435 u32,
436 row,
437 idx,
438 just_return
439 );
440 }
441 DataType::UInt64 => {
442 handle_primitive_type!(
443 builder,
444 field,
445 col,
446 UInt64Builder,
447 u64,
448 row,
449 idx,
450 just_return
451 );
452 }
453 DataType::Float32 => {
454 handle_primitive_type!(
455 builder,
456 field,
457 col,
458 Float32Builder,
459 f32,
460 row,
461 idx,
462 just_return
463 );
464 }
465 DataType::Float64 => {
466 handle_primitive_type!(
467 builder,
468 field,
469 col,
470 Float64Builder,
471 f64,
472 row,
473 idx,
474 just_return
475 );
476 }
477 DataType::Decimal128(_precision, scale) => {
478 handle_primitive_type!(
479 builder,
480 field,
481 col,
482 Decimal128Builder,
483 BigDecimal,
484 row,
485 idx,
486 |v: BigDecimal| {
487 big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
488 DataFusionError::Execution(format!(
489 "Failed to convert BigDecimal {v:?} to i128"
490 ))
491 })
492 }
493 );
494 }
495 DataType::Decimal256(_precision, _scale) => {
496 handle_primitive_type!(
497 builder,
498 field,
499 col,
500 Decimal256Builder,
501 BigDecimal,
502 row,
503 idx,
504 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
505 );
506 }
507 DataType::Date32 => {
508 handle_primitive_type!(
509 builder,
510 field,
511 col,
512 Date32Builder,
513 chrono::NaiveDate,
514 row,
515 idx,
516 |v: chrono::NaiveDate| {
517 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
518 }
519 );
520 }
521 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
522 match tz_opt {
523 None => {}
524 Some(tz) => {
525 if !tz.eq_ignore_ascii_case("utc") {
526 return Err(DataFusionError::NotImplemented(format!(
527 "Unsupported data type {:?} for col: {:?}",
528 field.data_type(),
529 col
530 )));
531 }
532 }
533 }
534 handle_primitive_type!(
535 builder,
536 field,
537 col,
538 TimestampMicrosecondBuilder,
539 time::PrimitiveDateTime,
540 row,
541 idx,
542 |v: time::PrimitiveDateTime| {
543 let timestamp_micros =
544 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
545 Ok::<_, DataFusionError>(timestamp_micros)
546 }
547 );
548 }
549 DataType::Time32(TimeUnit::Second) => {
550 handle_primitive_type!(
551 builder,
552 field,
553 col,
554 Time32SecondBuilder,
555 chrono::NaiveTime,
556 row,
557 idx,
558 |v: chrono::NaiveTime| {
559 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
560 }
561 );
562 }
563 DataType::Time64(TimeUnit::Nanosecond) => {
564 handle_primitive_type!(
565 builder,
566 field,
567 col,
568 Time64NanosecondBuilder,
569 chrono::NaiveTime,
570 row,
571 idx,
572 |v: chrono::NaiveTime| {
573 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
574 + i64::from(v.nanosecond());
575 Ok::<_, DataFusionError>(t)
576 }
577 );
578 }
579 DataType::Utf8 => {
580 handle_primitive_type!(
581 builder,
582 field,
583 col,
584 StringBuilder,
585 String,
586 row,
587 idx,
588 just_return
589 );
590 }
591 DataType::LargeUtf8 => {
592 handle_primitive_type!(
593 builder,
594 field,
595 col,
596 LargeStringBuilder,
597 String,
598 row,
599 idx,
600 just_return
601 );
602 }
603 DataType::Binary => {
604 handle_primitive_type!(
605 builder,
606 field,
607 col,
608 BinaryBuilder,
609 Vec<u8>,
610 row,
611 idx,
612 just_return
613 );
614 }
615 DataType::LargeBinary => {
616 handle_primitive_type!(
617 builder,
618 field,
619 col,
620 LargeBinaryBuilder,
621 Vec<u8>,
622 row,
623 idx,
624 just_return
625 );
626 }
627 _ => {
628 return Err(DataFusionError::NotImplemented(format!(
629 "Unsupported data type {:?} for col: {:?}",
630 field.data_type(),
631 col
632 )));
633 }
634 }
635 }
636 }
637 let projected_columns = array_builders
638 .into_iter()
639 .enumerate()
640 .filter(|(idx, _)| projections_contains(projection, *idx))
641 .map(|(_, mut builder)| builder.finish())
642 .collect::<Vec<ArrayRef>>();
643 let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
644 Ok(RecordBatch::try_new_with_options(
645 projected_schema,
646 projected_columns,
647 &options,
648 )?)
649}
650
651fn to_decimal_256(decimal: &BigDecimal) -> i256 {
652 let (bigint_value, _) = decimal.as_bigint_and_exponent();
653 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
654
655 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
656 let fill_byte = if is_negative { 0xFF } else { 0x00 };
657
658 if bigint_bytes.len() > 32 {
659 bigint_bytes.truncate(32);
660 } else {
661 bigint_bytes.resize(32, fill_byte);
662 };
663
664 let mut array = [0u8; 32];
665 array.copy_from_slice(&bigint_bytes);
666
667 i256::from_le_bytes(array)
668}