1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::big_decimal_to_i128;
3use crate::{
4 Connection, ConnectionOptions, DFResult, Literalize, MysqlType, Pool, RemoteField,
5 RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType,
6};
7use async_stream::stream;
8use bigdecimal::{BigDecimal, num_bigint};
9use chrono::Timelike;
10use datafusion::arrow::array::{
11 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
12 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
13 LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
14 Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
15 UInt32Builder, UInt64Builder, make_builder,
16};
17use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
18use datafusion::common::{DataFusionError, project_schema};
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use derive_getters::Getters;
22use derive_with::With;
23use futures::StreamExt;
24use futures::lock::Mutex;
25use log::debug;
26use mysql_async::consts::{ColumnFlags, ColumnType};
27use mysql_async::prelude::Queryable;
28use mysql_async::{Column, FromValueError, Row, Value};
29use std::any::Any;
30use std::sync::Arc;
31
32#[derive(Debug, Clone, With, Getters)]
33pub struct MysqlConnectionOptions {
34 pub(crate) host: String,
35 pub(crate) port: u16,
36 pub(crate) username: String,
37 pub(crate) password: String,
38 pub(crate) database: Option<String>,
39 pub(crate) pool_max_size: usize,
40 pub(crate) stream_chunk_size: usize,
41}
42
43impl MysqlConnectionOptions {
44 pub fn new(
45 host: impl Into<String>,
46 port: u16,
47 username: impl Into<String>,
48 password: impl Into<String>,
49 ) -> Self {
50 Self {
51 host: host.into(),
52 port,
53 username: username.into(),
54 password: password.into(),
55 database: None,
56 pool_max_size: 10,
57 stream_chunk_size: 2048,
58 }
59 }
60}
61
62impl From<MysqlConnectionOptions> for ConnectionOptions {
63 fn from(options: MysqlConnectionOptions) -> Self {
64 ConnectionOptions::Mysql(options)
65 }
66}
67
68#[derive(Debug)]
69pub struct MysqlPool {
70 pool: mysql_async::Pool,
71}
72
73pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
74 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
75 mysql_async::PoolConstraints::new(0, options.pool_max_size)
76 .expect("Failed to create pool constraints"),
77 );
78 let opts_builder = mysql_async::OptsBuilder::default()
79 .ip_or_hostname(options.host.clone())
80 .tcp_port(options.port)
81 .user(Some(options.username.clone()))
82 .pass(Some(options.password.clone()))
83 .db_name(options.database.clone())
84 .init(vec!["set time_zone='+00:00'".to_string()])
85 .pool_opts(pool_opts);
86 let pool = mysql_async::Pool::new(opts_builder);
87 Ok(MysqlPool { pool })
88}
89
90#[async_trait::async_trait]
91impl Pool for MysqlPool {
92 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
93 let conn = self.pool.get_conn().await.map_err(|e| {
94 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
95 })?;
96 Ok(Arc::new(MysqlConnection {
97 conn: Arc::new(Mutex::new(conn)),
98 }))
99 }
100}
101
102#[derive(Debug)]
103pub struct MysqlConnection {
104 conn: Arc<Mutex<mysql_async::Conn>>,
105}
106
107#[async_trait::async_trait]
108impl Connection for MysqlConnection {
109 fn as_any(&self) -> &dyn Any {
110 self
111 }
112
113 async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
114 let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
115 let mut conn = self.conn.lock().await;
116 let conn = &mut *conn;
117 let stmt = conn.prep(&sql).await.map_err(|e| {
118 DataFusionError::Execution(format!("Failed to prepare query {sql} on mysql: {e:?}"))
119 })?;
120 let remote_schema = Arc::new(build_remote_schema(&stmt)?);
121 Ok(remote_schema)
122 }
123
124 async fn query(
125 &self,
126 conn_options: &ConnectionOptions,
127 source: &RemoteSource,
128 table_schema: SchemaRef,
129 projection: Option<&Vec<usize>>,
130 unparsed_filters: &[String],
131 limit: Option<usize>,
132 ) -> DFResult<SendableRecordBatchStream> {
133 let projected_schema = project_schema(&table_schema, projection)?;
134
135 let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
136 debug!("[remote-table] executing mysql query: {sql}");
137
138 let projection = projection.cloned();
139 let chunk_size = conn_options.stream_chunk_size();
140 let conn = Arc::clone(&self.conn);
141 let stream = Box::pin(stream! {
142 let mut conn = conn.lock().await;
143 let mut query_iter = conn
144 .query_iter(sql.clone())
145 .await
146 .map_err(|e| {
147 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
148 })?;
149
150 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
151 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
152 })? else {
153 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
154 return;
155 };
156
157 let mut chunked_stream = stream.chunks(chunk_size).boxed();
158
159 while let Some(chunk) = chunked_stream.next().await {
160 let rows = chunk
161 .into_iter()
162 .collect::<Result<Vec<_>, _>>()
163 .map_err(|e| {
164 DataFusionError::Execution(format!(
165 "Failed to collect rows from mysql due to {e}",
166 ))
167 })?;
168
169 yield Ok::<_, DataFusionError>(rows)
170 }
171 });
172
173 let stream = stream.map(move |rows| {
174 let rows = rows?;
175 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
176 });
177
178 Ok(Box::pin(RecordBatchStreamAdapter::new(
179 projected_schema,
180 stream,
181 )))
182 }
183
184 async fn insert(
185 &self,
186 _conn_options: &ConnectionOptions,
187 _literalizer: Arc<dyn Literalize>,
188 _table: &[String],
189 _remote_schema: RemoteSchemaRef,
190 _input: SendableRecordBatchStream,
191 ) -> DFResult<usize> {
192 Err(DataFusionError::Execution(
193 "Insert operation is not supported for mysql".to_string(),
194 ))
195 }
196}
197
198fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
199 let character_set = mysql_col.character_set();
200 let is_utf8_bin_character_set = character_set == 45;
201 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
202 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
203 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
204 let col_length = mysql_col.column_length();
205 match mysql_col.column_type() {
206 ColumnType::MYSQL_TYPE_TINY => {
207 if is_unsigned {
208 Ok(MysqlType::TinyIntUnsigned)
209 } else {
210 Ok(MysqlType::TinyInt)
211 }
212 }
213 ColumnType::MYSQL_TYPE_SHORT => {
214 if is_unsigned {
215 Ok(MysqlType::SmallIntUnsigned)
216 } else {
217 Ok(MysqlType::SmallInt)
218 }
219 }
220 ColumnType::MYSQL_TYPE_INT24 => {
221 if is_unsigned {
222 Ok(MysqlType::MediumIntUnsigned)
223 } else {
224 Ok(MysqlType::MediumInt)
225 }
226 }
227 ColumnType::MYSQL_TYPE_LONG => {
228 if is_unsigned {
229 Ok(MysqlType::IntegerUnsigned)
230 } else {
231 Ok(MysqlType::Integer)
232 }
233 }
234 ColumnType::MYSQL_TYPE_LONGLONG => {
235 if is_unsigned {
236 Ok(MysqlType::BigIntUnsigned)
237 } else {
238 Ok(MysqlType::BigInt)
239 }
240 }
241 ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
242 ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
243 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
244 let precision = (mysql_col.column_length() - 2) as u8;
245 let scale = mysql_col.decimals();
246 Ok(MysqlType::Decimal(precision, scale))
247 }
248 ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
249 ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
250 ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
251 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
252 ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
253 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
254 ColumnType::MYSQL_TYPE_STRING if is_binary => {
255 if is_utf8_bin_character_set {
256 Ok(MysqlType::Char)
257 } else {
258 Ok(MysqlType::Binary)
259 }
260 }
261 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
262 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
263 if is_utf8_bin_character_set {
264 Ok(MysqlType::Varchar)
265 } else {
266 Ok(MysqlType::Varbinary)
267 }
268 }
269 ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
270 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
271 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
272 if is_utf8_bin_character_set {
273 Ok(MysqlType::Text(col_length))
274 } else {
275 Ok(MysqlType::Blob(col_length))
276 }
277 }
278 ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
279 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
280 _ => Err(DataFusionError::NotImplemented(format!(
281 "Unsupported mysql type: {mysql_col:?}",
282 ))),
283 }
284}
285
286fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
287 let mut remote_fields = vec![];
288 for col in stmt.columns() {
289 remote_fields.push(RemoteField::new(
290 col.name_str().to_string(),
291 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
292 true,
293 ));
294 }
295 Ok(RemoteSchema::new(remote_fields))
296}
297
/// Appends the cell at `$index` of `$row` to `$builder`.
///
/// The builder is downcast to `$builder_ty`, the cell is read as `$value_ty`,
/// and `$convert` turns it into the value the builder accepts. A missing cell
/// or a SQL `NULL` is appended as null.
///
/// Must be expanded inside a function returning `Result<_, DataFusionError>`:
/// a failed read returns early with an execution error, and an error from
/// `$convert` is propagated with `?`.
///
/// # Panics
///
/// Panics when `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(value)) => typed_builder.append_value($convert(value)?),
            // No cell at this index, or the cell holds SQL NULL.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
328
329fn rows_to_batch(
330 rows: &[Row],
331 table_schema: &SchemaRef,
332 projection: Option<&Vec<usize>>,
333) -> DFResult<RecordBatch> {
334 let projected_schema = project_schema(table_schema, projection)?;
335 let mut array_builders = vec![];
336 for field in table_schema.fields() {
337 let builder = make_builder(field.data_type(), rows.len());
338 array_builders.push(builder);
339 }
340
341 for row in rows {
342 for (idx, field) in table_schema.fields.iter().enumerate() {
343 if !projections_contains(projection, idx) {
344 continue;
345 }
346 let builder = &mut array_builders[idx];
347 let col = row.columns_ref().get(idx);
348 match field.data_type() {
349 DataType::Int8 => {
350 handle_primitive_type!(
351 builder,
352 field,
353 col,
354 Int8Builder,
355 i8,
356 row,
357 idx,
358 just_return
359 );
360 }
361 DataType::Int16 => {
362 handle_primitive_type!(
363 builder,
364 field,
365 col,
366 Int16Builder,
367 i16,
368 row,
369 idx,
370 just_return
371 );
372 }
373 DataType::Int32 => {
374 handle_primitive_type!(
375 builder,
376 field,
377 col,
378 Int32Builder,
379 i32,
380 row,
381 idx,
382 just_return
383 );
384 }
385 DataType::Int64 => {
386 handle_primitive_type!(
387 builder,
388 field,
389 col,
390 Int64Builder,
391 i64,
392 row,
393 idx,
394 just_return
395 );
396 }
397 DataType::UInt8 => {
398 handle_primitive_type!(
399 builder,
400 field,
401 col,
402 UInt8Builder,
403 u8,
404 row,
405 idx,
406 just_return
407 );
408 }
409 DataType::UInt16 => {
410 handle_primitive_type!(
411 builder,
412 field,
413 col,
414 UInt16Builder,
415 u16,
416 row,
417 idx,
418 just_return
419 );
420 }
421 DataType::UInt32 => {
422 handle_primitive_type!(
423 builder,
424 field,
425 col,
426 UInt32Builder,
427 u32,
428 row,
429 idx,
430 just_return
431 );
432 }
433 DataType::UInt64 => {
434 handle_primitive_type!(
435 builder,
436 field,
437 col,
438 UInt64Builder,
439 u64,
440 row,
441 idx,
442 just_return
443 );
444 }
445 DataType::Float32 => {
446 handle_primitive_type!(
447 builder,
448 field,
449 col,
450 Float32Builder,
451 f32,
452 row,
453 idx,
454 just_return
455 );
456 }
457 DataType::Float64 => {
458 handle_primitive_type!(
459 builder,
460 field,
461 col,
462 Float64Builder,
463 f64,
464 row,
465 idx,
466 just_return
467 );
468 }
469 DataType::Decimal128(_precision, scale) => {
470 handle_primitive_type!(
471 builder,
472 field,
473 col,
474 Decimal128Builder,
475 BigDecimal,
476 row,
477 idx,
478 |v: BigDecimal| { big_decimal_to_i128(&v, Some(*scale as i32)) }
479 );
480 }
481 DataType::Decimal256(_precision, _scale) => {
482 handle_primitive_type!(
483 builder,
484 field,
485 col,
486 Decimal256Builder,
487 BigDecimal,
488 row,
489 idx,
490 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
491 );
492 }
493 DataType::Date32 => {
494 handle_primitive_type!(
495 builder,
496 field,
497 col,
498 Date32Builder,
499 chrono::NaiveDate,
500 row,
501 idx,
502 |v: chrono::NaiveDate| {
503 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
504 }
505 );
506 }
507 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
508 match tz_opt {
509 None => {}
510 Some(tz) => {
511 if !tz.eq_ignore_ascii_case("utc") {
512 return Err(DataFusionError::NotImplemented(format!(
513 "Unsupported data type {:?} for col: {:?}",
514 field.data_type(),
515 col
516 )));
517 }
518 }
519 }
520 handle_primitive_type!(
521 builder,
522 field,
523 col,
524 TimestampMicrosecondBuilder,
525 time::PrimitiveDateTime,
526 row,
527 idx,
528 |v: time::PrimitiveDateTime| {
529 let timestamp_micros =
530 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
531 Ok::<_, DataFusionError>(timestamp_micros)
532 }
533 );
534 }
535 DataType::Time32(TimeUnit::Second) => {
536 handle_primitive_type!(
537 builder,
538 field,
539 col,
540 Time32SecondBuilder,
541 chrono::NaiveTime,
542 row,
543 idx,
544 |v: chrono::NaiveTime| {
545 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
546 }
547 );
548 }
549 DataType::Time64(TimeUnit::Nanosecond) => {
550 handle_primitive_type!(
551 builder,
552 field,
553 col,
554 Time64NanosecondBuilder,
555 chrono::NaiveTime,
556 row,
557 idx,
558 |v: chrono::NaiveTime| {
559 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
560 + i64::from(v.nanosecond());
561 Ok::<_, DataFusionError>(t)
562 }
563 );
564 }
565 DataType::Utf8 => {
566 handle_primitive_type!(
567 builder,
568 field,
569 col,
570 StringBuilder,
571 String,
572 row,
573 idx,
574 just_return
575 );
576 }
577 DataType::LargeUtf8 => {
578 handle_primitive_type!(
579 builder,
580 field,
581 col,
582 LargeStringBuilder,
583 String,
584 row,
585 idx,
586 just_return
587 );
588 }
589 DataType::Binary => {
590 handle_primitive_type!(
591 builder,
592 field,
593 col,
594 BinaryBuilder,
595 Vec<u8>,
596 row,
597 idx,
598 just_return
599 );
600 }
601 DataType::LargeBinary => {
602 handle_primitive_type!(
603 builder,
604 field,
605 col,
606 LargeBinaryBuilder,
607 Vec<u8>,
608 row,
609 idx,
610 just_return
611 );
612 }
613 _ => {
614 return Err(DataFusionError::NotImplemented(format!(
615 "Unsupported data type {:?} for col: {:?}",
616 field.data_type(),
617 col
618 )));
619 }
620 }
621 }
622 }
623 let projected_columns = array_builders
624 .into_iter()
625 .enumerate()
626 .filter(|(idx, _)| projections_contains(projection, *idx))
627 .map(|(_, mut builder)| builder.finish())
628 .collect::<Vec<ArrayRef>>();
629 let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
630 Ok(RecordBatch::try_new_with_options(
631 projected_schema,
632 projected_columns,
633 &options,
634 )?)
635}
636
637fn to_decimal_256(decimal: &BigDecimal) -> i256 {
638 let (bigint_value, _) = decimal.as_bigint_and_exponent();
639 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
640
641 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
642 let fill_byte = if is_negative { 0xFF } else { 0x00 };
643
644 if bigint_bytes.len() > 32 {
645 bigint_bytes.truncate(32);
646 } else {
647 bigint_bytes.resize(32, fill_byte);
648 };
649
650 let mut array = [0u8; 32];
651 array.copy_from_slice(&bigint_bytes);
652
653 i256::from_le_bytes(array)
654}