1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::big_decimal_to_i128;
3use crate::{
4 Connection, ConnectionOptions, DFResult, Literalize, MysqlConnectionOptions, MysqlType, Pool,
5 PoolState, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType,
6};
7use async_stream::stream;
8use bigdecimal::{BigDecimal, num_bigint};
9use chrono::Timelike;
10use datafusion::arrow::array::{
11 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
12 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
13 LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
14 Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
15 UInt32Builder, UInt64Builder, make_builder,
16};
17use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
18use datafusion::common::{DataFusionError, project_schema};
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use futures::StreamExt;
22use futures::lock::Mutex;
23use log::debug;
24use mysql_async::consts::{ColumnFlags, ColumnType};
25use mysql_async::prelude::Queryable;
26use mysql_async::{Column, FromValueError, Row, Value};
27use std::any::Any;
28use std::sync::Arc;
29
30#[derive(Debug)]
31pub struct MysqlPool {
32 pool: mysql_async::Pool,
33}
34
35pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
36 let pool_opts = mysql_async::PoolOpts::new()
37 .with_constraints(
38 mysql_async::PoolConstraints::new(options.pool_min_idle, options.pool_max_size)
39 .expect("Failed to create pool constraints"),
40 )
41 .with_inactive_connection_ttl(options.pool_idle_timeout)
42 .with_ttl_check_interval(options.pool_ttl_check_interval);
43 let opts_builder = mysql_async::OptsBuilder::default()
44 .ip_or_hostname(options.host.clone())
45 .tcp_port(options.port)
46 .user(Some(options.username.clone()))
47 .pass(Some(options.password.clone()))
48 .db_name(options.database.clone())
49 .init(vec!["set time_zone='+00:00'".to_string()])
50 .pool_opts(pool_opts);
51 let pool = mysql_async::Pool::new(opts_builder);
52 Ok(MysqlPool { pool })
53}
54
55#[async_trait::async_trait]
56impl Pool for MysqlPool {
57 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
58 let conn = self.pool.get_conn().await.map_err(|e| {
59 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
60 })?;
61 Ok(Arc::new(MysqlConnection {
62 conn: Arc::new(Mutex::new(conn)),
63 }))
64 }
65
66 async fn state(&self) -> DFResult<PoolState> {
67 use std::sync::atomic::Ordering;
68 let metrics = self.pool.metrics();
69 Ok(PoolState {
70 connections: metrics.connection_count.load(Ordering::SeqCst),
71 idle_connections: metrics.connections_in_pool.load(Ordering::SeqCst),
72 })
73 }
74}
75
76#[derive(Debug)]
77pub struct MysqlConnection {
78 conn: Arc<Mutex<mysql_async::Conn>>,
79}
80
81#[async_trait::async_trait]
82impl Connection for MysqlConnection {
83 fn as_any(&self) -> &dyn Any {
84 self
85 }
86
87 async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
88 let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
89 let mut conn = self.conn.lock().await;
90 let conn = &mut *conn;
91 let stmt = conn.prep(&sql).await.map_err(|e| {
92 DataFusionError::Plan(format!("Failed to prepare query {sql} on mysql: {e:?}"))
93 })?;
94 let remote_schema = Arc::new(build_remote_schema(&stmt)?);
95 Ok(remote_schema)
96 }
97
98 async fn query(
99 &self,
100 conn_options: &ConnectionOptions,
101 source: &RemoteSource,
102 table_schema: SchemaRef,
103 projection: Option<&Vec<usize>>,
104 unparsed_filters: &[String],
105 limit: Option<usize>,
106 ) -> DFResult<SendableRecordBatchStream> {
107 let projected_schema = project_schema(&table_schema, projection)?;
108
109 let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
110 debug!("[remote-table] executing mysql query: {sql}");
111
112 let projection = projection.cloned();
113 let chunk_size = conn_options.stream_chunk_size();
114 let conn = Arc::clone(&self.conn);
115 let stream = Box::pin(stream! {
116 let mut conn = conn.lock().await;
117 let mut query_iter = conn
118 .query_iter(sql.clone())
119 .await
120 .map_err(|e| {
121 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
122 })?;
123
124 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
125 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
126 })? else {
127 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
128 return;
129 };
130
131 let mut chunked_stream = stream.chunks(chunk_size).boxed();
132
133 while let Some(chunk) = chunked_stream.next().await {
134 let rows = chunk
135 .into_iter()
136 .collect::<Result<Vec<_>, _>>()
137 .map_err(|e| {
138 DataFusionError::Execution(format!(
139 "Failed to collect rows from mysql due to {e}",
140 ))
141 })?;
142
143 yield Ok::<_, DataFusionError>(rows)
144 }
145 });
146
147 let stream = stream.map(move |rows| {
148 let rows = rows?;
149 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
150 });
151
152 Ok(Box::pin(RecordBatchStreamAdapter::new(
153 projected_schema,
154 stream,
155 )))
156 }
157
158 async fn insert(
159 &self,
160 _conn_options: &ConnectionOptions,
161 _literalizer: Arc<dyn Literalize>,
162 _table: &[String],
163 _remote_schema: RemoteSchemaRef,
164 _input: SendableRecordBatchStream,
165 ) -> DFResult<usize> {
166 Err(DataFusionError::Execution(
167 "Insert operation is not supported for mysql".to_string(),
168 ))
169 }
170}
171
172fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
173 let character_set = mysql_col.character_set();
174 let is_utf8_bin_character_set = character_set == 45;
175 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
176 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
177 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
178 let col_length = mysql_col.column_length();
179 match mysql_col.column_type() {
180 ColumnType::MYSQL_TYPE_TINY => {
181 if is_unsigned {
182 Ok(MysqlType::TinyIntUnsigned)
183 } else {
184 Ok(MysqlType::TinyInt)
185 }
186 }
187 ColumnType::MYSQL_TYPE_SHORT => {
188 if is_unsigned {
189 Ok(MysqlType::SmallIntUnsigned)
190 } else {
191 Ok(MysqlType::SmallInt)
192 }
193 }
194 ColumnType::MYSQL_TYPE_INT24 => {
195 if is_unsigned {
196 Ok(MysqlType::MediumIntUnsigned)
197 } else {
198 Ok(MysqlType::MediumInt)
199 }
200 }
201 ColumnType::MYSQL_TYPE_LONG => {
202 if is_unsigned {
203 Ok(MysqlType::IntegerUnsigned)
204 } else {
205 Ok(MysqlType::Integer)
206 }
207 }
208 ColumnType::MYSQL_TYPE_LONGLONG => {
209 if is_unsigned {
210 Ok(MysqlType::BigIntUnsigned)
211 } else {
212 Ok(MysqlType::BigInt)
213 }
214 }
215 ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
216 ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
217 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
218 let precision = (mysql_col.column_length() - 2) as u8;
219 let scale = mysql_col.decimals();
220 Ok(MysqlType::Decimal(precision, scale))
221 }
222 ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
223 ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
224 ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
225 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
226 ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
227 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
228 ColumnType::MYSQL_TYPE_STRING if is_binary => {
229 if is_utf8_bin_character_set {
230 Ok(MysqlType::Char)
231 } else {
232 Ok(MysqlType::Binary)
233 }
234 }
235 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
236 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
237 if is_utf8_bin_character_set {
238 Ok(MysqlType::Varchar)
239 } else {
240 Ok(MysqlType::Varbinary)
241 }
242 }
243 ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
244 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
245 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
246 if is_utf8_bin_character_set {
247 Ok(MysqlType::Text(col_length))
248 } else {
249 Ok(MysqlType::Blob(col_length))
250 }
251 }
252 ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
253 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
254 _ => Err(DataFusionError::NotImplemented(format!(
255 "Unsupported mysql type: {mysql_col:?}",
256 ))),
257 }
258}
259
260fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
261 let mut remote_fields = vec![];
262 for col in stmt.columns() {
263 remote_fields.push(RemoteField::new(
264 col.name_str().to_string(),
265 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
266 true,
267 ));
268 }
269 Ok(RemoteSchema::new(remote_fields))
270}
271
/// Reads column `$index` of `$row` as `$value_ty`, converts it with `$convert`
/// and appends the result to `$builder` (downcast to `$builder_ty`).
///
/// * A missing column or a SQL `NULL` appends a null.
/// * Any other decode failure makes the *enclosing function* return a
///   `DataFusionError::Execution`.
/// * `$convert` must return a `Result`; its error is propagated with `?`.
///
/// Panics if `$builder` is not actually a `$builder_ty`. `$field` and `$col`
/// are only used in diagnostics.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(value)) => typed_builder.append_value($convert(value)?),
            // No such column, or the value is SQL NULL.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
302
303fn rows_to_batch(
304 rows: &[Row],
305 table_schema: &SchemaRef,
306 projection: Option<&Vec<usize>>,
307) -> DFResult<RecordBatch> {
308 let projected_schema = project_schema(table_schema, projection)?;
309 let mut array_builders = vec![];
310 for field in table_schema.fields() {
311 let builder = make_builder(field.data_type(), rows.len());
312 array_builders.push(builder);
313 }
314
315 for row in rows {
316 for (idx, field) in table_schema.fields.iter().enumerate() {
317 if !projections_contains(projection, idx) {
318 continue;
319 }
320 let builder = &mut array_builders[idx];
321 let col = row.columns_ref().get(idx);
322 match field.data_type() {
323 DataType::Int8 => {
324 handle_primitive_type!(
325 builder,
326 field,
327 col,
328 Int8Builder,
329 i8,
330 row,
331 idx,
332 just_return
333 );
334 }
335 DataType::Int16 => {
336 handle_primitive_type!(
337 builder,
338 field,
339 col,
340 Int16Builder,
341 i16,
342 row,
343 idx,
344 just_return
345 );
346 }
347 DataType::Int32 => {
348 handle_primitive_type!(
349 builder,
350 field,
351 col,
352 Int32Builder,
353 i32,
354 row,
355 idx,
356 just_return
357 );
358 }
359 DataType::Int64 => {
360 handle_primitive_type!(
361 builder,
362 field,
363 col,
364 Int64Builder,
365 i64,
366 row,
367 idx,
368 just_return
369 );
370 }
371 DataType::UInt8 => {
372 handle_primitive_type!(
373 builder,
374 field,
375 col,
376 UInt8Builder,
377 u8,
378 row,
379 idx,
380 just_return
381 );
382 }
383 DataType::UInt16 => {
384 handle_primitive_type!(
385 builder,
386 field,
387 col,
388 UInt16Builder,
389 u16,
390 row,
391 idx,
392 just_return
393 );
394 }
395 DataType::UInt32 => {
396 handle_primitive_type!(
397 builder,
398 field,
399 col,
400 UInt32Builder,
401 u32,
402 row,
403 idx,
404 just_return
405 );
406 }
407 DataType::UInt64 => {
408 handle_primitive_type!(
409 builder,
410 field,
411 col,
412 UInt64Builder,
413 u64,
414 row,
415 idx,
416 just_return
417 );
418 }
419 DataType::Float32 => {
420 handle_primitive_type!(
421 builder,
422 field,
423 col,
424 Float32Builder,
425 f32,
426 row,
427 idx,
428 just_return
429 );
430 }
431 DataType::Float64 => {
432 handle_primitive_type!(
433 builder,
434 field,
435 col,
436 Float64Builder,
437 f64,
438 row,
439 idx,
440 just_return
441 );
442 }
443 DataType::Decimal128(_precision, scale) => {
444 handle_primitive_type!(
445 builder,
446 field,
447 col,
448 Decimal128Builder,
449 BigDecimal,
450 row,
451 idx,
452 |v: BigDecimal| { big_decimal_to_i128(&v, Some(*scale as i32)) }
453 );
454 }
455 DataType::Decimal256(_precision, _scale) => {
456 handle_primitive_type!(
457 builder,
458 field,
459 col,
460 Decimal256Builder,
461 BigDecimal,
462 row,
463 idx,
464 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
465 );
466 }
467 DataType::Date32 => {
468 handle_primitive_type!(
469 builder,
470 field,
471 col,
472 Date32Builder,
473 chrono::NaiveDate,
474 row,
475 idx,
476 |v: chrono::NaiveDate| {
477 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
478 }
479 );
480 }
481 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
482 match tz_opt {
483 None => {}
484 Some(tz) => {
485 if !tz.eq_ignore_ascii_case("utc") {
486 return Err(DataFusionError::NotImplemented(format!(
487 "Unsupported data type {:?} for col: {:?}",
488 field.data_type(),
489 col
490 )));
491 }
492 }
493 }
494 handle_primitive_type!(
495 builder,
496 field,
497 col,
498 TimestampMicrosecondBuilder,
499 time::PrimitiveDateTime,
500 row,
501 idx,
502 |v: time::PrimitiveDateTime| {
503 let timestamp_micros =
504 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
505 Ok::<_, DataFusionError>(timestamp_micros)
506 }
507 );
508 }
509 DataType::Time32(TimeUnit::Second) => {
510 handle_primitive_type!(
511 builder,
512 field,
513 col,
514 Time32SecondBuilder,
515 chrono::NaiveTime,
516 row,
517 idx,
518 |v: chrono::NaiveTime| {
519 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
520 }
521 );
522 }
523 DataType::Time64(TimeUnit::Nanosecond) => {
524 handle_primitive_type!(
525 builder,
526 field,
527 col,
528 Time64NanosecondBuilder,
529 chrono::NaiveTime,
530 row,
531 idx,
532 |v: chrono::NaiveTime| {
533 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
534 + i64::from(v.nanosecond());
535 Ok::<_, DataFusionError>(t)
536 }
537 );
538 }
539 DataType::Utf8 => {
540 handle_primitive_type!(
541 builder,
542 field,
543 col,
544 StringBuilder,
545 String,
546 row,
547 idx,
548 just_return
549 );
550 }
551 DataType::LargeUtf8 => {
552 handle_primitive_type!(
553 builder,
554 field,
555 col,
556 LargeStringBuilder,
557 String,
558 row,
559 idx,
560 just_return
561 );
562 }
563 DataType::Binary => {
564 handle_primitive_type!(
565 builder,
566 field,
567 col,
568 BinaryBuilder,
569 Vec<u8>,
570 row,
571 idx,
572 just_return
573 );
574 }
575 DataType::LargeBinary => {
576 handle_primitive_type!(
577 builder,
578 field,
579 col,
580 LargeBinaryBuilder,
581 Vec<u8>,
582 row,
583 idx,
584 just_return
585 );
586 }
587 _ => {
588 return Err(DataFusionError::NotImplemented(format!(
589 "Unsupported data type {:?} for col: {:?}",
590 field.data_type(),
591 col
592 )));
593 }
594 }
595 }
596 }
597 let projected_columns = array_builders
598 .into_iter()
599 .enumerate()
600 .filter(|(idx, _)| projections_contains(projection, *idx))
601 .map(|(_, mut builder)| builder.finish())
602 .collect::<Vec<ArrayRef>>();
603 let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
604 Ok(RecordBatch::try_new_with_options(
605 projected_schema,
606 projected_columns,
607 &options,
608 )?)
609}
610
611fn to_decimal_256(decimal: &BigDecimal) -> i256 {
612 let (bigint_value, _) = decimal.as_bigint_and_exponent();
613 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
614
615 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
616 let fill_byte = if is_negative { 0xFF } else { 0x00 };
617
618 if bigint_bytes.len() > 32 {
619 bigint_bytes.truncate(32);
620 } else {
621 bigint_bytes.resize(32, fill_byte);
622 };
623
624 let mut array = [0u8; 32];
625 array.copy_from_slice(&bigint_bytes);
626
627 i256::from_le_bytes(array)
628}