1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::big_decimal_to_i128;
3use crate::{
4 Connection, ConnectionOptions, DFResult, Literalize, MysqlConnectionOptions, MysqlType, Pool,
5 PoolState, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType,
6};
7use arrow::array::{
8 ArrayRef, BinaryBuilder, BinaryViewBuilder, Date32Builder, Decimal128Builder,
9 Decimal256Builder, Float32Builder, Float64Builder, Int8Builder, Int16Builder, Int32Builder,
10 Int64Builder, LargeBinaryBuilder, LargeStringBuilder, RecordBatch, RecordBatchOptions,
11 StringBuilder, StringViewBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
12 TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder, UInt32Builder, UInt64Builder,
13 make_builder,
14};
15use arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
16use async_stream::stream;
17use bigdecimal::{BigDecimal, num_bigint};
18use chrono::Timelike;
19use datafusion_common::{DataFusionError, project_schema};
20use datafusion_execution::SendableRecordBatchStream;
21use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::any::Any;
29use std::sync::Arc;
30
31#[derive(Debug)]
32pub struct MysqlPool {
33 pool: mysql_async::Pool,
34}
35
36pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
37 let pool_opts = mysql_async::PoolOpts::new()
38 .with_constraints(
39 mysql_async::PoolConstraints::new(options.pool_min_idle, options.pool_max_size)
40 .expect("Failed to create pool constraints"),
41 )
42 .with_inactive_connection_ttl(options.pool_idle_timeout)
43 .with_ttl_check_interval(options.pool_ttl_check_interval);
44 let opts_builder = mysql_async::OptsBuilder::default()
45 .ip_or_hostname(options.host.clone())
46 .tcp_port(options.port)
47 .user(Some(options.username.clone()))
48 .pass(Some(options.password.clone()))
49 .db_name(options.database.clone())
50 .init(vec!["set time_zone='+00:00'".to_string()])
51 .pool_opts(pool_opts);
52 let pool = mysql_async::Pool::new(opts_builder);
53 Ok(MysqlPool { pool })
54}
55
56#[async_trait::async_trait]
57impl Pool for MysqlPool {
58 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
59 let conn = self.pool.get_conn().await.map_err(|e| {
60 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
61 })?;
62 Ok(Arc::new(MysqlConnection {
63 conn: Arc::new(Mutex::new(conn)),
64 }))
65 }
66
67 async fn state(&self) -> DFResult<PoolState> {
68 use std::sync::atomic::Ordering;
69 let metrics = self.pool.metrics();
70 Ok(PoolState {
71 connections: metrics.connection_count.load(Ordering::SeqCst),
72 idle_connections: metrics.connections_in_pool.load(Ordering::SeqCst),
73 })
74 }
75}
76
77#[derive(Debug)]
78pub struct MysqlConnection {
79 conn: Arc<Mutex<mysql_async::Conn>>,
80}
81
82#[async_trait::async_trait]
83impl Connection for MysqlConnection {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
89 let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
90 let mut conn = self.conn.lock().await;
91 let conn = &mut *conn;
92 let stmt = conn.prep(&sql).await.map_err(|e| {
93 DataFusionError::Plan(format!("Failed to prepare query {sql} on mysql: {e:?}"))
94 })?;
95 let remote_schema = Arc::new(build_remote_schema(&stmt)?);
96 Ok(remote_schema)
97 }
98
99 async fn query(
100 &self,
101 conn_options: &ConnectionOptions,
102 source: &RemoteSource,
103 table_schema: SchemaRef,
104 projection: Option<&Vec<usize>>,
105 unparsed_filters: &[String],
106 limit: Option<usize>,
107 ) -> DFResult<SendableRecordBatchStream> {
108 let projected_schema = project_schema(&table_schema, projection)?;
109
110 let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
111 debug!("[remote-table] executing mysql query: {sql}");
112
113 let projection = projection.cloned();
114 let chunk_size = conn_options.stream_chunk_size();
115 let conn = Arc::clone(&self.conn);
116 let stream = Box::pin(stream! {
117 let mut conn = conn.lock().await;
118 let mut query_iter = conn
119 .query_iter(sql.clone())
120 .await
121 .map_err(|e| {
122 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
123 })?;
124
125 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
126 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
127 })? else {
128 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
129 return;
130 };
131
132 let mut chunked_stream = stream.chunks(chunk_size).boxed();
133
134 while let Some(chunk) = chunked_stream.next().await {
135 let rows = chunk
136 .into_iter()
137 .collect::<Result<Vec<_>, _>>()
138 .map_err(|e| {
139 DataFusionError::Execution(format!(
140 "Failed to collect rows from mysql due to {e}",
141 ))
142 })?;
143
144 yield Ok::<_, DataFusionError>(rows)
145 }
146 });
147
148 let stream = stream.map(move |rows| {
149 let rows = rows?;
150 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
151 });
152
153 Ok(Box::pin(RecordBatchStreamAdapter::new(
154 projected_schema,
155 stream,
156 )))
157 }
158
159 async fn insert(
160 &self,
161 _conn_options: &ConnectionOptions,
162 _literalizer: Arc<dyn Literalize>,
163 _table: &[String],
164 _remote_schema: RemoteSchemaRef,
165 _batch: RecordBatch,
166 ) -> DFResult<usize> {
167 Err(DataFusionError::Execution(
168 "Insert operation is not supported for mysql".to_string(),
169 ))
170 }
171}
172
173fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
174 let character_set = mysql_col.character_set();
175 let is_utf8_bin_character_set = character_set == 45;
176 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
177 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
178 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
179 let col_length = mysql_col.column_length();
180 match mysql_col.column_type() {
181 ColumnType::MYSQL_TYPE_TINY => {
182 if is_unsigned {
183 Ok(MysqlType::TinyIntUnsigned)
184 } else {
185 Ok(MysqlType::TinyInt)
186 }
187 }
188 ColumnType::MYSQL_TYPE_SHORT => {
189 if is_unsigned {
190 Ok(MysqlType::SmallIntUnsigned)
191 } else {
192 Ok(MysqlType::SmallInt)
193 }
194 }
195 ColumnType::MYSQL_TYPE_INT24 => {
196 if is_unsigned {
197 Ok(MysqlType::MediumIntUnsigned)
198 } else {
199 Ok(MysqlType::MediumInt)
200 }
201 }
202 ColumnType::MYSQL_TYPE_LONG => {
203 if is_unsigned {
204 Ok(MysqlType::IntegerUnsigned)
205 } else {
206 Ok(MysqlType::Integer)
207 }
208 }
209 ColumnType::MYSQL_TYPE_LONGLONG => {
210 if is_unsigned {
211 Ok(MysqlType::BigIntUnsigned)
212 } else {
213 Ok(MysqlType::BigInt)
214 }
215 }
216 ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
217 ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
218 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
219 let precision = (mysql_col.column_length() - 2) as u8;
220 let scale = mysql_col.decimals();
221 Ok(MysqlType::Decimal(precision, scale))
222 }
223 ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
224 ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
225 ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
226 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
227 ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
228 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
229 ColumnType::MYSQL_TYPE_STRING if is_binary => {
230 if is_utf8_bin_character_set {
231 Ok(MysqlType::Char)
232 } else {
233 Ok(MysqlType::Binary)
234 }
235 }
236 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
237 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
238 if is_utf8_bin_character_set {
239 Ok(MysqlType::Varchar)
240 } else {
241 Ok(MysqlType::Varbinary)
242 }
243 }
244 ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
245 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
246 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
247 if is_utf8_bin_character_set {
248 Ok(MysqlType::Text(col_length))
249 } else {
250 Ok(MysqlType::Blob(col_length))
251 }
252 }
253 ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
254 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
255 _ => Err(DataFusionError::NotImplemented(format!(
256 "Unsupported mysql type: {mysql_col:?}",
257 ))),
258 }
259}
260
261fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
262 let mut remote_fields = vec![];
263 for col in stmt.columns() {
264 remote_fields.push(RemoteField::new(
265 col.name_str().to_string(),
266 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
267 true,
268 ));
269 }
270 Ok(RemoteSchema::new(remote_fields))
271}
272
/// Reads column `$index` of `$row` as `$value_ty`, converts it with `$convert`
/// and appends the result to `$builder` (downcast to `$builder_ty`).
///
/// A missing value or a SQL `NULL` appends a null. Any other decoding failure
/// makes the *enclosing function* return an execution error, and a failing
/// `$convert` is propagated with `?`. `$field` and `$col` are only used for
/// diagnostics.
///
/// # Panics
///
/// Panics when `$builder` is not actually a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            // Absent column value and SQL NULL are both stored as Arrow nulls.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Ok(raw)) => typed_builder.append_value($convert(raw)?),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
303
304fn rows_to_batch(
305 rows: &[Row],
306 table_schema: &SchemaRef,
307 projection: Option<&Vec<usize>>,
308) -> DFResult<RecordBatch> {
309 let projected_schema = project_schema(table_schema, projection)?;
310 let mut array_builders = vec![];
311 for field in table_schema.fields() {
312 let builder = make_builder(field.data_type(), rows.len());
313 array_builders.push(builder);
314 }
315
316 for row in rows {
317 for (idx, field) in table_schema.fields.iter().enumerate() {
318 if !projections_contains(projection, idx) {
319 continue;
320 }
321 let builder = &mut array_builders[idx];
322 let col = row.columns_ref().get(idx);
323 match field.data_type() {
324 DataType::Int8 => {
325 handle_primitive_type!(
326 builder,
327 field,
328 col,
329 Int8Builder,
330 i8,
331 row,
332 idx,
333 just_return
334 );
335 }
336 DataType::Int16 => {
337 handle_primitive_type!(
338 builder,
339 field,
340 col,
341 Int16Builder,
342 i16,
343 row,
344 idx,
345 just_return
346 );
347 }
348 DataType::Int32 => {
349 handle_primitive_type!(
350 builder,
351 field,
352 col,
353 Int32Builder,
354 i32,
355 row,
356 idx,
357 just_return
358 );
359 }
360 DataType::Int64 => {
361 handle_primitive_type!(
362 builder,
363 field,
364 col,
365 Int64Builder,
366 i64,
367 row,
368 idx,
369 just_return
370 );
371 }
372 DataType::UInt8 => {
373 handle_primitive_type!(
374 builder,
375 field,
376 col,
377 UInt8Builder,
378 u8,
379 row,
380 idx,
381 just_return
382 );
383 }
384 DataType::UInt16 => {
385 handle_primitive_type!(
386 builder,
387 field,
388 col,
389 UInt16Builder,
390 u16,
391 row,
392 idx,
393 just_return
394 );
395 }
396 DataType::UInt32 => {
397 handle_primitive_type!(
398 builder,
399 field,
400 col,
401 UInt32Builder,
402 u32,
403 row,
404 idx,
405 just_return
406 );
407 }
408 DataType::UInt64 => {
409 handle_primitive_type!(
410 builder,
411 field,
412 col,
413 UInt64Builder,
414 u64,
415 row,
416 idx,
417 just_return
418 );
419 }
420 DataType::Float32 => {
421 handle_primitive_type!(
422 builder,
423 field,
424 col,
425 Float32Builder,
426 f32,
427 row,
428 idx,
429 just_return
430 );
431 }
432 DataType::Float64 => {
433 handle_primitive_type!(
434 builder,
435 field,
436 col,
437 Float64Builder,
438 f64,
439 row,
440 idx,
441 just_return
442 );
443 }
444 DataType::Decimal128(_precision, scale) => {
445 handle_primitive_type!(
446 builder,
447 field,
448 col,
449 Decimal128Builder,
450 BigDecimal,
451 row,
452 idx,
453 |v: BigDecimal| { big_decimal_to_i128(&v, Some(*scale as i32)) }
454 );
455 }
456 DataType::Decimal256(_precision, _scale) => {
457 handle_primitive_type!(
458 builder,
459 field,
460 col,
461 Decimal256Builder,
462 BigDecimal,
463 row,
464 idx,
465 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
466 );
467 }
468 DataType::Date32 => {
469 handle_primitive_type!(
470 builder,
471 field,
472 col,
473 Date32Builder,
474 chrono::NaiveDate,
475 row,
476 idx,
477 |v: chrono::NaiveDate| {
478 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
479 }
480 );
481 }
482 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
483 match tz_opt {
484 None => {}
485 Some(tz) => {
486 if !tz.eq_ignore_ascii_case("utc") {
487 return Err(DataFusionError::NotImplemented(format!(
488 "Unsupported data type {:?} for col: {:?}",
489 field.data_type(),
490 col
491 )));
492 }
493 }
494 }
495 handle_primitive_type!(
496 builder,
497 field,
498 col,
499 TimestampMicrosecondBuilder,
500 time::PrimitiveDateTime,
501 row,
502 idx,
503 |v: time::PrimitiveDateTime| {
504 let timestamp_micros =
505 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
506 Ok::<_, DataFusionError>(timestamp_micros)
507 }
508 );
509 }
510 DataType::Time32(TimeUnit::Second) => {
511 handle_primitive_type!(
512 builder,
513 field,
514 col,
515 Time32SecondBuilder,
516 chrono::NaiveTime,
517 row,
518 idx,
519 |v: chrono::NaiveTime| {
520 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
521 }
522 );
523 }
524 DataType::Time64(TimeUnit::Nanosecond) => {
525 handle_primitive_type!(
526 builder,
527 field,
528 col,
529 Time64NanosecondBuilder,
530 chrono::NaiveTime,
531 row,
532 idx,
533 |v: chrono::NaiveTime| {
534 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
535 + i64::from(v.nanosecond());
536 Ok::<_, DataFusionError>(t)
537 }
538 );
539 }
540 DataType::Utf8 => {
541 handle_primitive_type!(
542 builder,
543 field,
544 col,
545 StringBuilder,
546 String,
547 row,
548 idx,
549 just_return
550 );
551 }
552 DataType::LargeUtf8 => {
553 handle_primitive_type!(
554 builder,
555 field,
556 col,
557 LargeStringBuilder,
558 String,
559 row,
560 idx,
561 just_return
562 );
563 }
564 DataType::Utf8View => {
565 handle_primitive_type!(
566 builder,
567 field,
568 col,
569 StringViewBuilder,
570 String,
571 row,
572 idx,
573 just_return
574 );
575 }
576 DataType::Binary => {
577 handle_primitive_type!(
578 builder,
579 field,
580 col,
581 BinaryBuilder,
582 Vec<u8>,
583 row,
584 idx,
585 just_return
586 );
587 }
588 DataType::LargeBinary => {
589 handle_primitive_type!(
590 builder,
591 field,
592 col,
593 LargeBinaryBuilder,
594 Vec<u8>,
595 row,
596 idx,
597 just_return
598 );
599 }
600 DataType::BinaryView => {
601 handle_primitive_type!(
602 builder,
603 field,
604 col,
605 BinaryViewBuilder,
606 Vec<u8>,
607 row,
608 idx,
609 just_return
610 );
611 }
612 _ => {
613 return Err(DataFusionError::NotImplemented(format!(
614 "Unsupported data type {:?} for col: {:?}",
615 field.data_type(),
616 col
617 )));
618 }
619 }
620 }
621 }
622 let projected_columns = array_builders
623 .into_iter()
624 .enumerate()
625 .filter(|(idx, _)| projections_contains(projection, *idx))
626 .map(|(_, mut builder)| builder.finish())
627 .collect::<Vec<ArrayRef>>();
628 let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
629 Ok(RecordBatch::try_new_with_options(
630 projected_schema,
631 projected_columns,
632 &options,
633 )?)
634}
635
636fn to_decimal_256(decimal: &BigDecimal) -> i256 {
637 let (bigint_value, _) = decimal.as_bigint_and_exponent();
638 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
639
640 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
641 let fill_byte = if is_negative { 0xFF } else { 0x00 };
642
643 if bigint_bytes.len() > 32 {
644 bigint_bytes.truncate(32);
645 } else {
646 bigint_bytes.resize(32, fill_byte);
647 };
648
649 let mut array = [0u8; 32];
650 array.copy_from_slice(&bigint_bytes);
651
652 i256::from_le_bytes(array)
653}