1use crate::connection::{RemoteDbType, big_decimal_to_i128, just_return, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12 LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
13 Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
14 UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use log::debug;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::any::Any;
29use std::sync::Arc;
30
/// Options describing how to reach a MySQL server and how to pool and
/// stream from its connections.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    // Default database (schema) to select on connect; `None` selects none.
    pub(crate) database: Option<String>,
    // Upper bound on concurrently pooled connections.
    pub(crate) pool_max_size: usize,
    // Number of rows gathered into each streamed RecordBatch.
    pub(crate) stream_chunk_size: usize,
}
41
42impl MysqlConnectionOptions {
43 pub fn new(
44 host: impl Into<String>,
45 port: u16,
46 username: impl Into<String>,
47 password: impl Into<String>,
48 ) -> Self {
49 Self {
50 host: host.into(),
51 port,
52 username: username.into(),
53 password: password.into(),
54 database: None,
55 pool_max_size: 10,
56 stream_chunk_size: 2048,
57 }
58 }
59}
60
impl From<MysqlConnectionOptions> for ConnectionOptions {
    /// Wraps MySQL-specific options in the backend-agnostic options enum.
    fn from(options: MysqlConnectionOptions) -> Self {
        ConnectionOptions::Mysql(options)
    }
}
66
/// A [`Pool`] implementation backed by an `mysql_async` connection pool.
#[derive(Debug)]
pub struct MysqlPool {
    pool: mysql_async::Pool,
}
71
72pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
73 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
74 mysql_async::PoolConstraints::new(0, options.pool_max_size)
75 .expect("Failed to create pool constraints"),
76 );
77 let opts_builder = mysql_async::OptsBuilder::default()
78 .ip_or_hostname(options.host.clone())
79 .tcp_port(options.port)
80 .user(Some(options.username.clone()))
81 .pass(Some(options.password.clone()))
82 .db_name(options.database.clone())
83 .init(vec!["set time_zone='+00:00'".to_string()])
84 .pool_opts(pool_opts);
85 let pool = mysql_async::Pool::new(opts_builder);
86 Ok(MysqlPool { pool })
87}
88
89#[async_trait::async_trait]
90impl Pool for MysqlPool {
91 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
92 let conn = self.pool.get_conn().await.map_err(|e| {
93 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
94 })?;
95 Ok(Arc::new(MysqlConnection {
96 conn: Arc::new(Mutex::new(conn)),
97 }))
98 }
99}
100
/// A single MySQL connection, shared behind an async mutex so the
/// `Connection` trait's `&self` methods can mutate it.
#[derive(Debug)]
pub struct MysqlConnection {
    conn: Arc<Mutex<mysql_async::Conn>>,
}
105
#[async_trait::async_trait]
impl Connection for MysqlConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Infers the remote schema by running `sql` wrapped with a LIMIT-1
    /// clause and reading the column metadata of the returned row.
    ///
    /// Errors when the query yields no rows, since metadata is taken from a
    /// concrete `Row`.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Mysql.query_limit_1(sql);
        let mut conn = self.conn.lock().await;
        let conn = &mut *conn;
        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
        })?;
        let Some(row) = row else {
            return Err(DataFusionError::Execution(
                "No rows returned to infer schema".to_string(),
            ));
        };
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        Ok(remote_schema)
    }

    /// Executes `sql` (with `unparsed_filters` and `limit` pushed down into
    /// the query text) and streams the result as record batches of
    /// `stream_chunk_size` rows, projected down to `projection`.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Mysql.rewrite_query(sql, unparsed_filters, limit);
        debug!("[remote-table] executing mysql query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let conn = Arc::clone(&self.conn);
        // NOTE(review): the mutex guard lives inside the stream, so this
        // connection stays locked until the stream is fully consumed or
        // dropped — concurrent queries on the same connection serialize here.
        let stream = Box::pin(stream! {
            let mut conn = conn.lock().await;
            let mut query_iter = conn
                .query_iter(sql.clone())
                .await
                .map_err(|e| {
                    DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
                })?;

            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
                DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
            })? else {
                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
                return;
            };

            // Group incoming rows so each yielded item becomes one batch.
            let mut chunked_stream = stream.chunks(chunk_size).boxed();

            while let Some(chunk) = chunked_stream.next().await {
                let rows = chunk
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to collect rows from mysql due to {e}",
                        ))
                    })?;

                yield Ok::<_, DataFusionError>(rows)
            }
        });

        // Convert each chunk of rows into an arrow RecordBatch with the
        // requested projection applied.
        let stream = stream.map(move |rows| {
            let rows = rows?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
188
/// Maps a MySQL wire-protocol column description to a [`MysqlType`].
///
/// Text-vs-binary disambiguation: the protocol reports CHAR/VARCHAR/TEXT
/// columns with a binary collation using the same column types as
/// BINARY/VARBINARY/BLOB plus `BINARY_FLAG`, so the character-set id is used
/// to tell them apart.
fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
    let character_set = mysql_col.character_set();
    // NOTE(review): 45 is the collation id of one utf8mb4 collation; other
    // utf8/utf8mb4 collation ids would not be caught here — confirm this
    // matches the server configurations we need to support.
    let is_utf8_bin_character_set = character_set == 45;
    let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
    let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
    let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
    let col_length = mysql_col.column_length();
    match mysql_col.column_type() {
        ColumnType::MYSQL_TYPE_TINY => {
            if is_unsigned {
                Ok(MysqlType::TinyIntUnsigned)
            } else {
                Ok(MysqlType::TinyInt)
            }
        }
        ColumnType::MYSQL_TYPE_SHORT => {
            if is_unsigned {
                Ok(MysqlType::SmallIntUnsigned)
            } else {
                Ok(MysqlType::SmallInt)
            }
        }
        ColumnType::MYSQL_TYPE_INT24 => {
            if is_unsigned {
                Ok(MysqlType::MediumIntUnsigned)
            } else {
                Ok(MysqlType::MediumInt)
            }
        }
        ColumnType::MYSQL_TYPE_LONG => {
            if is_unsigned {
                Ok(MysqlType::IntegerUnsigned)
            } else {
                Ok(MysqlType::Integer)
            }
        }
        ColumnType::MYSQL_TYPE_LONGLONG => {
            if is_unsigned {
                Ok(MysqlType::BigIntUnsigned)
            } else {
                Ok(MysqlType::BigInt)
            }
        }
        ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
        ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
        ColumnType::MYSQL_TYPE_NEWDECIMAL => {
            // NOTE(review): assumes the display length always includes a sign
            // and a decimal point (length = precision + 2). For unsigned
            // columns or scale-0 decimals the protocol length may differ by
            // one — verify against real column metadata.
            let precision = (mysql_col.column_length() - 2) as u8;
            let scale = mysql_col.decimals();
            Ok(MysqlType::Decimal(precision, scale))
        }
        ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
        ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
        ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
        ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
        ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
        ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
        ColumnType::MYSQL_TYPE_STRING if is_binary => {
            if is_utf8_bin_character_set {
                Ok(MysqlType::Char)
            } else {
                Ok(MysqlType::Binary)
            }
        }
        ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
        ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
            if is_utf8_bin_character_set {
                Ok(MysqlType::Varchar)
            } else {
                Ok(MysqlType::Varbinary)
            }
        }
        ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
        ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
        ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
            if is_utf8_bin_character_set {
                Ok(MysqlType::Text(col_length))
            } else {
                Ok(MysqlType::Blob(col_length))
            }
        }
        ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
        ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
        // ColumnType is an external, extensible enum; anything unhandled is
        // surfaced as NotImplemented rather than guessed at.
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported mysql type: {mysql_col:?}",
        ))),
    }
}
276
277fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
278 let mut remote_fields = vec![];
279 for col in row.columns_ref() {
280 remote_fields.push(RemoteField::new(
281 col.name_str().to_string(),
282 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
283 true,
284 ));
285 }
286 Ok(RemoteSchema::new(remote_fields))
287}
288
// Appends one value from `$row` at column `$index` into `$builder`.
//
// `$builder_ty` is the concrete arrow builder to downcast to, `$value_ty` the
// Rust type requested from mysql, and `$convert` a fallible
// `$value_ty -> builder value` conversion. A missing column (`None`) and a
// SQL NULL (`FromValueError(Value::NULL)`) both append null; any other
// conversion error returns early from the *enclosing* function with an
// execution error. The downcast panic indicates a schema/builder mismatch,
// i.e. a programming bug, not bad data.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
319
/// Converts a chunk of mysql rows into one arrow [`RecordBatch`].
///
/// `table_schema` is the full (unprojected) table schema; `projection` is the
/// set of column indices to keep. Builders are allocated for *every* field —
/// unprojected ones simply stay empty and are filtered out at the end — so
/// builder index always equals schema field index.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    // Row-major fill: for each row, append one value per projected column.
    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used for error messages below.
            let col = row.columns_ref().get(idx);
            match field.data_type() {
                DataType::Int8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int8Builder,
                        i8,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::UInt8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt8Builder,
                        u8,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::UInt16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt16Builder,
                        u16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::UInt32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt32Builder,
                        u32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::UInt64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt64Builder,
                        u64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    // Rescales the BigDecimal to the arrow column's scale
                    // before storing the unscaled i128 digits.
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        BigDecimal,
                        row,
                        idx,
                        |v: BigDecimal| {
                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal {v:?} to i128"
                                ))
                            })
                        }
                    );
                }
                DataType::Decimal256(_precision, _scale) => {
                    // NOTE(review): unlike the Decimal128 arm, no rescaling
                    // to `_scale` happens here — `to_decimal_256` ignores the
                    // BigDecimal exponent. Confirm the source scale always
                    // matches the arrow scale for this path.
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal256Builder,
                        BigDecimal,
                        row,
                        idx,
                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
                    );
                }
                DataType::Date32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date32Builder,
                        chrono::NaiveDate,
                        row,
                        idx,
                        |v: chrono::NaiveDate| {
                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
                    // Only naive or UTC timestamps are supported; the session
                    // time zone is forced to +00:00 at connect time.
                    match tz_opt {
                        None => {}
                        Some(tz) => {
                            if !tz.eq_ignore_ascii_case("utc") {
                                return Err(DataFusionError::NotImplemented(format!(
                                    "Unsupported data type {:?} for col: {:?}",
                                    field.data_type(),
                                    col
                                )));
                            }
                        }
                    }
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampMicrosecondBuilder,
                        time::PrimitiveDateTime,
                        row,
                        idx,
                        |v: time::PrimitiveDateTime| {
                            // nanos -> micros; the wall-clock value is read
                            // as if it were UTC.
                            let timestamp_micros =
                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
                            Ok::<_, DataFusionError>(timestamp_micros)
                        }
                    );
                }
                DataType::Time32(TimeUnit::Second) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time32SecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
                        }
                    );
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time64NanosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            // Seconds since midnight plus the sub-second
                            // nanosecond component.
                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
                                + i64::from(v.nanosecond());
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        just_return
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }
    // Keep only the projected builders, in schema order.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    // Explicit row count supports zero-column projections (e.g. COUNT(*)).
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}
633
634fn to_decimal_256(decimal: &BigDecimal) -> i256 {
635 let (bigint_value, _) = decimal.as_bigint_and_exponent();
636 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
637
638 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
639 let fill_byte = if is_negative { 0xFF } else { 0x00 };
640
641 if bigint_bytes.len() > 32 {
642 bigint_bytes.truncate(32);
643 } else {
644 bigint_bytes.resize(32, fill_byte);
645 };
646
647 let mut array = [0u8; 32];
648 array.copy_from_slice(&bigint_bytes);
649
650 i256::from_le_bytes(array)
651}