1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12 LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
13 Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
14 UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::any::Any;
29use std::sync::Arc;
30
31#[derive(Debug, Clone, With, Getters)]
32pub struct MysqlConnectionOptions {
33 pub(crate) host: String,
34 pub(crate) port: u16,
35 pub(crate) username: String,
36 pub(crate) password: String,
37 pub(crate) database: Option<String>,
38 pub(crate) pool_max_size: usize,
39 pub(crate) stream_chunk_size: usize,
40}
41
42impl MysqlConnectionOptions {
43 pub fn new(
44 host: impl Into<String>,
45 port: u16,
46 username: impl Into<String>,
47 password: impl Into<String>,
48 ) -> Self {
49 Self {
50 host: host.into(),
51 port,
52 username: username.into(),
53 password: password.into(),
54 database: None,
55 pool_max_size: 10,
56 stream_chunk_size: 2048,
57 }
58 }
59}
60
61impl From<MysqlConnectionOptions> for ConnectionOptions {
62 fn from(options: MysqlConnectionOptions) -> Self {
63 ConnectionOptions::Mysql(options)
64 }
65}
66
67#[derive(Debug)]
68pub struct MysqlPool {
69 pool: mysql_async::Pool,
70}
71
72pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
73 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
74 mysql_async::PoolConstraints::new(0, options.pool_max_size)
75 .expect("Failed to create pool constraints"),
76 );
77 let opts_builder = mysql_async::OptsBuilder::default()
78 .ip_or_hostname(options.host.clone())
79 .tcp_port(options.port)
80 .user(Some(options.username.clone()))
81 .pass(Some(options.password.clone()))
82 .db_name(options.database.clone())
83 .init(vec!["set time_zone='+00:00'".to_string()])
84 .pool_opts(pool_opts);
85 let pool = mysql_async::Pool::new(opts_builder);
86 Ok(MysqlPool { pool })
87}
88
89#[async_trait::async_trait]
90impl Pool for MysqlPool {
91 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
92 let conn = self.pool.get_conn().await.map_err(|e| {
93 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
94 })?;
95 Ok(Arc::new(MysqlConnection {
96 conn: Arc::new(Mutex::new(conn)),
97 }))
98 }
99}
100
101#[derive(Debug)]
102pub struct MysqlConnection {
103 conn: Arc<Mutex<mysql_async::Conn>>,
104}
105
106#[async_trait::async_trait]
107impl Connection for MysqlConnection {
108 fn as_any(&self) -> &dyn Any {
109 self
110 }
111
112 async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
113 let sql = RemoteDbType::Mysql.query_limit_1(sql);
114 let mut conn = self.conn.lock().await;
115 let conn = &mut *conn;
116 let stmt = conn.prep(&sql).await.map_err(|e| {
117 DataFusionError::Execution(format!("Failed to prepare query {sql} on mysql: {e:?}"))
118 })?;
119 let remote_schema = Arc::new(build_remote_schema(&stmt)?);
120 Ok(remote_schema)
121 }
122
123 async fn query(
124 &self,
125 conn_options: &ConnectionOptions,
126 sql: &str,
127 table_schema: SchemaRef,
128 projection: Option<&Vec<usize>>,
129 unparsed_filters: &[String],
130 limit: Option<usize>,
131 ) -> DFResult<SendableRecordBatchStream> {
132 let projected_schema = project_schema(&table_schema, projection)?;
133
134 let sql = RemoteDbType::Mysql.rewrite_query(sql, unparsed_filters, limit);
135 debug!("[remote-table] executing mysql query: {sql}");
136
137 let projection = projection.cloned();
138 let chunk_size = conn_options.stream_chunk_size();
139 let conn = Arc::clone(&self.conn);
140 let stream = Box::pin(stream! {
141 let mut conn = conn.lock().await;
142 let mut query_iter = conn
143 .query_iter(sql.clone())
144 .await
145 .map_err(|e| {
146 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
147 })?;
148
149 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
150 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
151 })? else {
152 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
153 return;
154 };
155
156 let mut chunked_stream = stream.chunks(chunk_size).boxed();
157
158 while let Some(chunk) = chunked_stream.next().await {
159 let rows = chunk
160 .into_iter()
161 .collect::<Result<Vec<_>, _>>()
162 .map_err(|e| {
163 DataFusionError::Execution(format!(
164 "Failed to collect rows from mysql due to {e}",
165 ))
166 })?;
167
168 yield Ok::<_, DataFusionError>(rows)
169 }
170 });
171
172 let stream = stream.map(move |rows| {
173 let rows = rows?;
174 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
175 });
176
177 Ok(Box::pin(RecordBatchStreamAdapter::new(
178 projected_schema,
179 stream,
180 )))
181 }
182}
183
184fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
185 let character_set = mysql_col.character_set();
186 let is_utf8_bin_character_set = character_set == 45;
187 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
188 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
189 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
190 let col_length = mysql_col.column_length();
191 match mysql_col.column_type() {
192 ColumnType::MYSQL_TYPE_TINY => {
193 if is_unsigned {
194 Ok(MysqlType::TinyIntUnsigned)
195 } else {
196 Ok(MysqlType::TinyInt)
197 }
198 }
199 ColumnType::MYSQL_TYPE_SHORT => {
200 if is_unsigned {
201 Ok(MysqlType::SmallIntUnsigned)
202 } else {
203 Ok(MysqlType::SmallInt)
204 }
205 }
206 ColumnType::MYSQL_TYPE_INT24 => {
207 if is_unsigned {
208 Ok(MysqlType::MediumIntUnsigned)
209 } else {
210 Ok(MysqlType::MediumInt)
211 }
212 }
213 ColumnType::MYSQL_TYPE_LONG => {
214 if is_unsigned {
215 Ok(MysqlType::IntegerUnsigned)
216 } else {
217 Ok(MysqlType::Integer)
218 }
219 }
220 ColumnType::MYSQL_TYPE_LONGLONG => {
221 if is_unsigned {
222 Ok(MysqlType::BigIntUnsigned)
223 } else {
224 Ok(MysqlType::BigInt)
225 }
226 }
227 ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
228 ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
229 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
230 let precision = (mysql_col.column_length() - 2) as u8;
231 let scale = mysql_col.decimals();
232 Ok(MysqlType::Decimal(precision, scale))
233 }
234 ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
235 ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
236 ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
237 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
238 ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
239 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
240 ColumnType::MYSQL_TYPE_STRING if is_binary => {
241 if is_utf8_bin_character_set {
242 Ok(MysqlType::Char)
243 } else {
244 Ok(MysqlType::Binary)
245 }
246 }
247 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
248 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
249 if is_utf8_bin_character_set {
250 Ok(MysqlType::Varchar)
251 } else {
252 Ok(MysqlType::Varbinary)
253 }
254 }
255 ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
256 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
257 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
258 if is_utf8_bin_character_set {
259 Ok(MysqlType::Text(col_length))
260 } else {
261 Ok(MysqlType::Blob(col_length))
262 }
263 }
264 ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
265 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
266 _ => Err(DataFusionError::NotImplemented(format!(
267 "Unsupported mysql type: {mysql_col:?}",
268 ))),
269 }
270}
271
272fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
273 let mut remote_fields = vec![];
274 for col in stmt.columns() {
275 remote_fields.push(RemoteField::new(
276 col.name_str().to_string(),
277 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
278 true,
279 ));
280 }
281 Ok(RemoteSchema::new(remote_fields))
282}
283
// Appends one cell of a MySQL row to an Arrow array builder.
//
// Arguments, in order: the type-erased builder, the Arrow field and the MySQL column
// (both only used in diagnostics), the concrete builder type, the Rust type the cell is
// read as, the row, the column index, and a fallible conversion from the read value to
// the value the builder accepts.
//
// A missing cell or a SQL NULL appends a null. Any other read failure returns a
// `DataFusionError::Execution` from the enclosing function, and a conversion failure is
// propagated with `?`. Panics if the builder is not of the stated concrete type.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            // No cell at this index, or the cell holds SQL NULL.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Ok(value)) => typed_builder.append_value($convert(value)?),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
314
315fn rows_to_batch(
316 rows: &[Row],
317 table_schema: &SchemaRef,
318 projection: Option<&Vec<usize>>,
319) -> DFResult<RecordBatch> {
320 let projected_schema = project_schema(table_schema, projection)?;
321 let mut array_builders = vec![];
322 for field in table_schema.fields() {
323 let builder = make_builder(field.data_type(), rows.len());
324 array_builders.push(builder);
325 }
326
327 for row in rows {
328 for (idx, field) in table_schema.fields.iter().enumerate() {
329 if !projections_contains(projection, idx) {
330 continue;
331 }
332 let builder = &mut array_builders[idx];
333 let col = row.columns_ref().get(idx);
334 match field.data_type() {
335 DataType::Int8 => {
336 handle_primitive_type!(
337 builder,
338 field,
339 col,
340 Int8Builder,
341 i8,
342 row,
343 idx,
344 just_return
345 );
346 }
347 DataType::Int16 => {
348 handle_primitive_type!(
349 builder,
350 field,
351 col,
352 Int16Builder,
353 i16,
354 row,
355 idx,
356 just_return
357 );
358 }
359 DataType::Int32 => {
360 handle_primitive_type!(
361 builder,
362 field,
363 col,
364 Int32Builder,
365 i32,
366 row,
367 idx,
368 just_return
369 );
370 }
371 DataType::Int64 => {
372 handle_primitive_type!(
373 builder,
374 field,
375 col,
376 Int64Builder,
377 i64,
378 row,
379 idx,
380 just_return
381 );
382 }
383 DataType::UInt8 => {
384 handle_primitive_type!(
385 builder,
386 field,
387 col,
388 UInt8Builder,
389 u8,
390 row,
391 idx,
392 just_return
393 );
394 }
395 DataType::UInt16 => {
396 handle_primitive_type!(
397 builder,
398 field,
399 col,
400 UInt16Builder,
401 u16,
402 row,
403 idx,
404 just_return
405 );
406 }
407 DataType::UInt32 => {
408 handle_primitive_type!(
409 builder,
410 field,
411 col,
412 UInt32Builder,
413 u32,
414 row,
415 idx,
416 just_return
417 );
418 }
419 DataType::UInt64 => {
420 handle_primitive_type!(
421 builder,
422 field,
423 col,
424 UInt64Builder,
425 u64,
426 row,
427 idx,
428 just_return
429 );
430 }
431 DataType::Float32 => {
432 handle_primitive_type!(
433 builder,
434 field,
435 col,
436 Float32Builder,
437 f32,
438 row,
439 idx,
440 just_return
441 );
442 }
443 DataType::Float64 => {
444 handle_primitive_type!(
445 builder,
446 field,
447 col,
448 Float64Builder,
449 f64,
450 row,
451 idx,
452 just_return
453 );
454 }
455 DataType::Decimal128(_precision, scale) => {
456 handle_primitive_type!(
457 builder,
458 field,
459 col,
460 Decimal128Builder,
461 BigDecimal,
462 row,
463 idx,
464 |v: BigDecimal| {
465 big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
466 DataFusionError::Execution(format!(
467 "Failed to convert BigDecimal {v:?} to i128"
468 ))
469 })
470 }
471 );
472 }
473 DataType::Decimal256(_precision, _scale) => {
474 handle_primitive_type!(
475 builder,
476 field,
477 col,
478 Decimal256Builder,
479 BigDecimal,
480 row,
481 idx,
482 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
483 );
484 }
485 DataType::Date32 => {
486 handle_primitive_type!(
487 builder,
488 field,
489 col,
490 Date32Builder,
491 chrono::NaiveDate,
492 row,
493 idx,
494 |v: chrono::NaiveDate| {
495 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
496 }
497 );
498 }
499 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
500 match tz_opt {
501 None => {}
502 Some(tz) => {
503 if !tz.eq_ignore_ascii_case("utc") {
504 return Err(DataFusionError::NotImplemented(format!(
505 "Unsupported data type {:?} for col: {:?}",
506 field.data_type(),
507 col
508 )));
509 }
510 }
511 }
512 handle_primitive_type!(
513 builder,
514 field,
515 col,
516 TimestampMicrosecondBuilder,
517 time::PrimitiveDateTime,
518 row,
519 idx,
520 |v: time::PrimitiveDateTime| {
521 let timestamp_micros =
522 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
523 Ok::<_, DataFusionError>(timestamp_micros)
524 }
525 );
526 }
527 DataType::Time32(TimeUnit::Second) => {
528 handle_primitive_type!(
529 builder,
530 field,
531 col,
532 Time32SecondBuilder,
533 chrono::NaiveTime,
534 row,
535 idx,
536 |v: chrono::NaiveTime| {
537 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
538 }
539 );
540 }
541 DataType::Time64(TimeUnit::Nanosecond) => {
542 handle_primitive_type!(
543 builder,
544 field,
545 col,
546 Time64NanosecondBuilder,
547 chrono::NaiveTime,
548 row,
549 idx,
550 |v: chrono::NaiveTime| {
551 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
552 + i64::from(v.nanosecond());
553 Ok::<_, DataFusionError>(t)
554 }
555 );
556 }
557 DataType::Utf8 => {
558 handle_primitive_type!(
559 builder,
560 field,
561 col,
562 StringBuilder,
563 String,
564 row,
565 idx,
566 just_return
567 );
568 }
569 DataType::LargeUtf8 => {
570 handle_primitive_type!(
571 builder,
572 field,
573 col,
574 LargeStringBuilder,
575 String,
576 row,
577 idx,
578 just_return
579 );
580 }
581 DataType::Binary => {
582 handle_primitive_type!(
583 builder,
584 field,
585 col,
586 BinaryBuilder,
587 Vec<u8>,
588 row,
589 idx,
590 just_return
591 );
592 }
593 DataType::LargeBinary => {
594 handle_primitive_type!(
595 builder,
596 field,
597 col,
598 LargeBinaryBuilder,
599 Vec<u8>,
600 row,
601 idx,
602 just_return
603 );
604 }
605 _ => {
606 return Err(DataFusionError::NotImplemented(format!(
607 "Unsupported data type {:?} for col: {:?}",
608 field.data_type(),
609 col
610 )));
611 }
612 }
613 }
614 }
615 let projected_columns = array_builders
616 .into_iter()
617 .enumerate()
618 .filter(|(idx, _)| projections_contains(projection, *idx))
619 .map(|(_, mut builder)| builder.finish())
620 .collect::<Vec<ArrayRef>>();
621 let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
622 Ok(RecordBatch::try_new_with_options(
623 projected_schema,
624 projected_columns,
625 &options,
626 )?)
627}
628
629fn to_decimal_256(decimal: &BigDecimal) -> i256 {
630 let (bigint_value, _) = decimal.as_bigint_and_exponent();
631 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
632
633 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
634 let fill_byte = if is_negative { 0xFF } else { 0x00 };
635
636 if bigint_bytes.len() > 32 {
637 bigint_bytes.truncate(32);
638 } else {
639 bigint_bytes.resize(32, fill_byte);
640 };
641
642 let mut array = [0u8; 32];
643 array.copy_from_slice(&bigint_bytes);
644
645 i256::from_le_bytes(array)
646}