1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12 LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
13 Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
14 UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::sync::Arc;
29
/// Settings used to open connections to a MySQL server.
///
/// Create an instance with [`MysqlConnectionOptions::new`], which fills in the
/// optional fields with defaults.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// Host name or IP address of the MySQL server.
    pub(crate) host: String,
    /// TCP port of the MySQL server.
    pub(crate) port: u16,
    /// User name used to authenticate.
    pub(crate) username: String,
    /// Password used to authenticate.
    pub(crate) password: String,
    /// Database name handed to the driver when connecting; `None` means no
    /// database name is supplied.
    pub(crate) database: Option<String>,
    /// Upper bound of the connection pool size (the lower bound is fixed to 0
    /// in `connect_mysql`).
    pub(crate) pool_max_size: usize,
    /// Chunk size setting for streamed query results.
    // NOTE(review): presumably the number of rows per record batch — confirm
    // against how `ConnectionOptions::stream_chunk_size` is derived.
    pub(crate) stream_chunk_size: usize,
}
40
41impl MysqlConnectionOptions {
42 pub fn new(
43 host: impl Into<String>,
44 port: u16,
45 username: impl Into<String>,
46 password: impl Into<String>,
47 ) -> Self {
48 Self {
49 host: host.into(),
50 port,
51 username: username.into(),
52 password: password.into(),
53 database: None,
54 pool_max_size: 10,
55 stream_chunk_size: 2048,
56 }
57 }
58}
59
/// Connection pool for a MySQL server, wrapping a [`mysql_async::Pool`].
#[derive(Debug)]
pub struct MysqlPool {
    /// Underlying driver pool that connections are checked out from.
    pool: mysql_async::Pool,
}
64
65pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
66 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
67 mysql_async::PoolConstraints::new(0, options.pool_max_size)
68 .expect("Failed to create pool constraints"),
69 );
70 let opts_builder = mysql_async::OptsBuilder::default()
71 .ip_or_hostname(options.host.clone())
72 .tcp_port(options.port)
73 .user(Some(options.username.clone()))
74 .pass(Some(options.password.clone()))
75 .db_name(options.database.clone())
76 .init(vec!["set time_zone='+00:00'".to_string()])
77 .pool_opts(pool_opts);
78 let pool = mysql_async::Pool::new(opts_builder);
79 Ok(MysqlPool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for MysqlPool {
84 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85 let conn = self.pool.get_conn().await.map_err(|e| {
86 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
87 })?;
88 Ok(Arc::new(MysqlConnection {
89 conn: Arc::new(Mutex::new(conn)),
90 }))
91 }
92}
93
/// A single MySQL connection exposed through the [`Connection`] trait.
#[derive(Debug)]
pub struct MysqlConnection {
    /// The driver connection, behind an async mutex and an `Arc` so that the
    /// record-batch stream created by `query` can own a handle to it.
    conn: Arc<Mutex<mysql_async::Conn>>,
}
98
#[async_trait::async_trait]
impl Connection for MysqlConnection {
    /// Infers the remote schema of `sql` from the column metadata of the first
    /// row it returns.
    ///
    /// # Errors
    ///
    /// Returns an execution error when the query fails or yields no rows (the
    /// schema is read from a row, so an empty result cannot be inferred), and
    /// propagates any error from `build_remote_schema`.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        // NOTE(review): presumably rewrites the query so it returns at most one
        // row — confirm against `RemoteDbType::query_limit_1`.
        let sql = RemoteDbType::Mysql.query_limit_1(sql);
        let mut conn = self.conn.lock().await;
        let conn = &mut *conn;
        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
        })?;
        let Some(row) = row else {
            return Err(DataFusionError::Execution(
                "No rows returned to infer schema".to_string(),
            ));
        };
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    /// Runs `sql` (rewritten with `unparsed_filters` and `limit`) and returns
    /// the result as a stream of record batches.
    ///
    /// The projection is not passed to the query rewrite; it is applied when
    /// rows are converted by `rows_to_batch`, which uses `table_schema` to
    /// decide how each column is decoded. Rows are grouped into chunks of
    /// `conn_options.stream_chunk_size()` and each chunk becomes one batch.
    ///
    /// # Errors
    ///
    /// Returns an error immediately when the projection cannot be applied to
    /// `table_schema`. Query, streaming and conversion failures are reported
    /// as `Err` items of the returned stream.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Mysql.rewrite_query(sql, unparsed_filters, limit);
        debug!("[remote-table] executing mysql query: {sql}");

        // Owned copies of everything the stream needs, since it is built from
        // values moved out of this call.
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let conn = Arc::clone(&self.conn);
        let stream = Box::pin(stream! {
            // The guard is a local of the stream body, so the connection stays
            // locked while the stream is being driven.
            let mut conn = conn.lock().await;
            let mut query_iter = conn
                .query_iter(sql.clone())
                .await
                .map_err(|e| {
                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
                })?;

            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
                DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
            })? else {
                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
                return;
            };

            let mut chunked_stream = stream.chunks(chunk_size).boxed();

            while let Some(chunk) = chunked_stream.next().await {
                // Each chunk is a Vec of per-row results; fail the whole chunk
                // on the first row error.
                let rows = chunk
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to collect rows from mysql due to {e}",
                        ))
                    })?;

                yield Ok::<_, DataFusionError>(rows)
            }
        });

        // Convert every chunk of rows into a record batch, forwarding errors.
        let stream = stream.map(move |rows| {
            let rows = rows?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
177
178fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
179 let character_set = mysql_col.character_set();
180 let is_utf8_bin_character_set = character_set == 45;
181 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
182 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
183 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
184 let col_length = mysql_col.column_length();
185 match mysql_col.column_type() {
186 ColumnType::MYSQL_TYPE_TINY => {
187 if is_unsigned {
188 Ok(MysqlType::TinyIntUnsigned)
189 } else {
190 Ok(MysqlType::TinyInt)
191 }
192 }
193 ColumnType::MYSQL_TYPE_SHORT => {
194 if is_unsigned {
195 Ok(MysqlType::SmallIntUnsigned)
196 } else {
197 Ok(MysqlType::SmallInt)
198 }
199 }
200 ColumnType::MYSQL_TYPE_INT24 => {
201 if is_unsigned {
202 Ok(MysqlType::MediumIntUnsigned)
203 } else {
204 Ok(MysqlType::MediumInt)
205 }
206 }
207 ColumnType::MYSQL_TYPE_LONG => {
208 if is_unsigned {
209 Ok(MysqlType::IntegerUnsigned)
210 } else {
211 Ok(MysqlType::Integer)
212 }
213 }
214 ColumnType::MYSQL_TYPE_LONGLONG => {
215 if is_unsigned {
216 Ok(MysqlType::BigIntUnsigned)
217 } else {
218 Ok(MysqlType::BigInt)
219 }
220 }
221 ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
222 ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
223 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
224 let precision = (mysql_col.column_length() - 2) as u8;
225 let scale = mysql_col.decimals();
226 Ok(MysqlType::Decimal(precision, scale))
227 }
228 ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
229 ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
230 ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
231 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
232 ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
233 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
234 ColumnType::MYSQL_TYPE_STRING if is_binary => {
235 if is_utf8_bin_character_set {
236 Ok(MysqlType::Char)
237 } else {
238 Ok(MysqlType::Binary)
239 }
240 }
241 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
242 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
243 if is_utf8_bin_character_set {
244 Ok(MysqlType::Varchar)
245 } else {
246 Ok(MysqlType::Varbinary)
247 }
248 }
249 ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
250 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
251 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
252 if is_utf8_bin_character_set {
253 Ok(MysqlType::Text(col_length))
254 } else {
255 Ok(MysqlType::Blob(col_length))
256 }
257 }
258 ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
259 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
260 _ => Err(DataFusionError::NotImplemented(format!(
261 "Unsupported mysql type: {mysql_col:?}",
262 ))),
263 }
264}
265
266fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
267 let mut remote_fields = vec![];
268 for col in row.columns_ref() {
269 remote_fields.push(RemoteField::new(
270 col.name_str().to_string(),
271 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
272 true,
273 ));
274 }
275 Ok(RemoteSchema::new(remote_fields))
276}
277
/// Reads one cell from a MySQL row and appends it to an Arrow array builder.
///
/// Arguments, in order:
/// - `$builder`: the type-erased builder for the column;
/// - `$field`, `$col`: only used in panic/error messages;
/// - `$builder_ty`: concrete builder type `$builder` is downcast to;
/// - `$value_ty`: Rust type the cell is decoded as;
/// - `$row`, `$index`: the row and the column index to read;
/// - `$convert`: callable turning the decoded value into a `Result` holding
///   the value to append.
///
/// The expansion uses `?` and `return Err(..)`, so it must be invoked inside a
/// function returning a `Result` whose error type accepts `DataFusionError`.
///
/// # Panics
///
/// Panics when `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            // NOTE(review): presumably `None` means the row has no value at
            // this index — confirm against `Row::get_opt`.
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            // A SQL NULL surfaces as a conversion error carrying `Value::NULL`;
            // record it as a null instead of failing.
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
308
309fn rows_to_batch(
310 rows: &[Row],
311 table_schema: &SchemaRef,
312 projection: Option<&Vec<usize>>,
313) -> DFResult<RecordBatch> {
314 let projected_schema = project_schema(table_schema, projection)?;
315 let mut array_builders = vec![];
316 for field in table_schema.fields() {
317 let builder = make_builder(field.data_type(), rows.len());
318 array_builders.push(builder);
319 }
320
321 for row in rows {
322 for (idx, field) in table_schema.fields.iter().enumerate() {
323 if !projections_contains(projection, idx) {
324 continue;
325 }
326 let builder = &mut array_builders[idx];
327 let col = row.columns_ref().get(idx);
328 match field.data_type() {
329 DataType::Int8 => {
330 handle_primitive_type!(
331 builder,
332 field,
333 col,
334 Int8Builder,
335 i8,
336 row,
337 idx,
338 just_return
339 );
340 }
341 DataType::Int16 => {
342 handle_primitive_type!(
343 builder,
344 field,
345 col,
346 Int16Builder,
347 i16,
348 row,
349 idx,
350 just_return
351 );
352 }
353 DataType::Int32 => {
354 handle_primitive_type!(
355 builder,
356 field,
357 col,
358 Int32Builder,
359 i32,
360 row,
361 idx,
362 just_return
363 );
364 }
365 DataType::Int64 => {
366 handle_primitive_type!(
367 builder,
368 field,
369 col,
370 Int64Builder,
371 i64,
372 row,
373 idx,
374 just_return
375 );
376 }
377 DataType::UInt8 => {
378 handle_primitive_type!(
379 builder,
380 field,
381 col,
382 UInt8Builder,
383 u8,
384 row,
385 idx,
386 just_return
387 );
388 }
389 DataType::UInt16 => {
390 handle_primitive_type!(
391 builder,
392 field,
393 col,
394 UInt16Builder,
395 u16,
396 row,
397 idx,
398 just_return
399 );
400 }
401 DataType::UInt32 => {
402 handle_primitive_type!(
403 builder,
404 field,
405 col,
406 UInt32Builder,
407 u32,
408 row,
409 idx,
410 just_return
411 );
412 }
413 DataType::UInt64 => {
414 handle_primitive_type!(
415 builder,
416 field,
417 col,
418 UInt64Builder,
419 u64,
420 row,
421 idx,
422 just_return
423 );
424 }
425 DataType::Float32 => {
426 handle_primitive_type!(
427 builder,
428 field,
429 col,
430 Float32Builder,
431 f32,
432 row,
433 idx,
434 just_return
435 );
436 }
437 DataType::Float64 => {
438 handle_primitive_type!(
439 builder,
440 field,
441 col,
442 Float64Builder,
443 f64,
444 row,
445 idx,
446 just_return
447 );
448 }
449 DataType::Decimal128(_precision, scale) => {
450 handle_primitive_type!(
451 builder,
452 field,
453 col,
454 Decimal128Builder,
455 BigDecimal,
456 row,
457 idx,
458 |v: BigDecimal| {
459 big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
460 DataFusionError::Execution(format!(
461 "Failed to convert BigDecimal {v:?} to i128"
462 ))
463 })
464 }
465 );
466 }
467 DataType::Decimal256(_precision, _scale) => {
468 handle_primitive_type!(
469 builder,
470 field,
471 col,
472 Decimal256Builder,
473 BigDecimal,
474 row,
475 idx,
476 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
477 );
478 }
479 DataType::Date32 => {
480 handle_primitive_type!(
481 builder,
482 field,
483 col,
484 Date32Builder,
485 chrono::NaiveDate,
486 row,
487 idx,
488 |v: chrono::NaiveDate| {
489 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
490 }
491 );
492 }
493 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
494 match tz_opt {
495 None => {}
496 Some(tz) => {
497 if !tz.eq_ignore_ascii_case("utc") {
498 return Err(DataFusionError::NotImplemented(format!(
499 "Unsupported data type {:?} for col: {:?}",
500 field.data_type(),
501 col
502 )));
503 }
504 }
505 }
506 handle_primitive_type!(
507 builder,
508 field,
509 col,
510 TimestampMicrosecondBuilder,
511 time::PrimitiveDateTime,
512 row,
513 idx,
514 |v: time::PrimitiveDateTime| {
515 let timestamp_micros =
516 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
517 Ok::<_, DataFusionError>(timestamp_micros)
518 }
519 );
520 }
521 DataType::Time32(TimeUnit::Second) => {
522 handle_primitive_type!(
523 builder,
524 field,
525 col,
526 Time32SecondBuilder,
527 chrono::NaiveTime,
528 row,
529 idx,
530 |v: chrono::NaiveTime| {
531 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
532 }
533 );
534 }
535 DataType::Time64(TimeUnit::Nanosecond) => {
536 handle_primitive_type!(
537 builder,
538 field,
539 col,
540 Time64NanosecondBuilder,
541 chrono::NaiveTime,
542 row,
543 idx,
544 |v: chrono::NaiveTime| {
545 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
546 + i64::from(v.nanosecond());
547 Ok::<_, DataFusionError>(t)
548 }
549 );
550 }
551 DataType::Utf8 => {
552 handle_primitive_type!(
553 builder,
554 field,
555 col,
556 StringBuilder,
557 String,
558 row,
559 idx,
560 just_return
561 );
562 }
563 DataType::LargeUtf8 => {
564 handle_primitive_type!(
565 builder,
566 field,
567 col,
568 LargeStringBuilder,
569 String,
570 row,
571 idx,
572 just_return
573 );
574 }
575 DataType::Binary => {
576 handle_primitive_type!(
577 builder,
578 field,
579 col,
580 BinaryBuilder,
581 Vec<u8>,
582 row,
583 idx,
584 just_return
585 );
586 }
587 DataType::LargeBinary => {
588 handle_primitive_type!(
589 builder,
590 field,
591 col,
592 LargeBinaryBuilder,
593 Vec<u8>,
594 row,
595 idx,
596 just_return
597 );
598 }
599 _ => {
600 return Err(DataFusionError::NotImplemented(format!(
601 "Unsupported data type {:?} for col: {:?}",
602 field.data_type(),
603 col
604 )));
605 }
606 }
607 }
608 }
609 let projected_columns = array_builders
610 .into_iter()
611 .enumerate()
612 .filter(|(idx, _)| projections_contains(projection, *idx))
613 .map(|(_, mut builder)| builder.finish())
614 .collect::<Vec<ArrayRef>>();
615 let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
616 Ok(RecordBatch::try_new_with_options(
617 projected_schema,
618 projected_columns,
619 &options,
620 )?)
621}
622
623fn to_decimal_256(decimal: &BigDecimal) -> i256 {
624 let (bigint_value, _) = decimal.as_bigint_and_exponent();
625 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
626
627 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
628 let fill_byte = if is_negative { 0xFF } else { 0x00 };
629
630 if bigint_bytes.len() > 32 {
631 bigint_bytes.truncate(32);
632 } else {
633 bigint_bytes.resize(32, fill_byte);
634 };
635
636 let mut array = [0u8; 32];
637 array.copy_from_slice(&bigint_bytes);
638
639 i256::from_le_bytes(array)
640}