1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::big_decimal_to_i128;
3use crate::{
4 Connection, ConnectionOptions, DFResult, Literalize, MysqlConnectionOptions, MysqlType, Pool,
5 RemoteField, RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType,
6};
7use async_stream::stream;
8use bigdecimal::{BigDecimal, num_bigint};
9use chrono::Timelike;
10use datafusion::arrow::array::{
11 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
12 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
13 LargeStringBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time32SecondBuilder,
14 Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt8Builder, UInt16Builder,
15 UInt32Builder, UInt64Builder, make_builder,
16};
17use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
18use datafusion::common::{DataFusionError, project_schema};
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use futures::StreamExt;
22use futures::lock::Mutex;
23use log::debug;
24use mysql_async::consts::{ColumnFlags, ColumnType};
25use mysql_async::prelude::Queryable;
26use mysql_async::{Column, FromValueError, Row, Value};
27use std::any::Any;
28use std::sync::Arc;
29
/// Connection pool for a MySQL server.
///
/// Thin wrapper around [`mysql_async::Pool`]; connections are handed out
/// through the crate's `Pool` trait implementation for this type.
#[derive(Debug)]
pub struct MysqlPool {
    /// The underlying `mysql_async` pool that owns the physical connections.
    pool: mysql_async::Pool,
}
34
35pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
36 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
37 mysql_async::PoolConstraints::new(0, options.pool_max_size)
38 .expect("Failed to create pool constraints"),
39 );
40 let opts_builder = mysql_async::OptsBuilder::default()
41 .ip_or_hostname(options.host.clone())
42 .tcp_port(options.port)
43 .user(Some(options.username.clone()))
44 .pass(Some(options.password.clone()))
45 .db_name(options.database.clone())
46 .init(vec!["set time_zone='+00:00'".to_string()])
47 .pool_opts(pool_opts);
48 let pool = mysql_async::Pool::new(opts_builder);
49 Ok(MysqlPool { pool })
50}
51
52#[async_trait::async_trait]
53impl Pool for MysqlPool {
54 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
55 let conn = self.pool.get_conn().await.map_err(|e| {
56 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {e:?}"))
57 })?;
58 Ok(Arc::new(MysqlConnection {
59 conn: Arc::new(Mutex::new(conn)),
60 }))
61 }
62}
63
/// A single MySQL connection checked out from a [`MysqlPool`].
///
/// The raw `mysql_async` connection sits behind an `Arc<Mutex<..>>` so it can
/// be moved into the record-batch stream produced by `query` and locked there.
#[derive(Debug)]
pub struct MysqlConnection {
    /// Shared, mutex-guarded handle to the underlying `mysql_async` connection.
    conn: Arc<Mutex<mysql_async::Conn>>,
}
68
69#[async_trait::async_trait]
70impl Connection for MysqlConnection {
71 fn as_any(&self) -> &dyn Any {
72 self
73 }
74
75 async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
76 let sql = RemoteDbType::Mysql.limit_1_query_if_possible(source);
77 let mut conn = self.conn.lock().await;
78 let conn = &mut *conn;
79 let stmt = conn.prep(&sql).await.map_err(|e| {
80 DataFusionError::Plan(format!("Failed to prepare query {sql} on mysql: {e:?}"))
81 })?;
82 let remote_schema = Arc::new(build_remote_schema(&stmt)?);
83 Ok(remote_schema)
84 }
85
86 async fn query(
87 &self,
88 conn_options: &ConnectionOptions,
89 source: &RemoteSource,
90 table_schema: SchemaRef,
91 projection: Option<&Vec<usize>>,
92 unparsed_filters: &[String],
93 limit: Option<usize>,
94 ) -> DFResult<SendableRecordBatchStream> {
95 let projected_schema = project_schema(&table_schema, projection)?;
96
97 let sql = RemoteDbType::Mysql.rewrite_query(source, unparsed_filters, limit);
98 debug!("[remote-table] executing mysql query: {sql}");
99
100 let projection = projection.cloned();
101 let chunk_size = conn_options.stream_chunk_size();
102 let conn = Arc::clone(&self.conn);
103 let stream = Box::pin(stream! {
104 let mut conn = conn.lock().await;
105 let mut query_iter = conn
106 .query_iter(sql.clone())
107 .await
108 .map_err(|e| {
109 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
110 })?;
111
112 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
113 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
114 })? else {
115 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
116 return;
117 };
118
119 let mut chunked_stream = stream.chunks(chunk_size).boxed();
120
121 while let Some(chunk) = chunked_stream.next().await {
122 let rows = chunk
123 .into_iter()
124 .collect::<Result<Vec<_>, _>>()
125 .map_err(|e| {
126 DataFusionError::Execution(format!(
127 "Failed to collect rows from mysql due to {e}",
128 ))
129 })?;
130
131 yield Ok::<_, DataFusionError>(rows)
132 }
133 });
134
135 let stream = stream.map(move |rows| {
136 let rows = rows?;
137 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
138 });
139
140 Ok(Box::pin(RecordBatchStreamAdapter::new(
141 projected_schema,
142 stream,
143 )))
144 }
145
146 async fn insert(
147 &self,
148 _conn_options: &ConnectionOptions,
149 _literalizer: Arc<dyn Literalize>,
150 _table: &[String],
151 _remote_schema: RemoteSchemaRef,
152 _input: SendableRecordBatchStream,
153 ) -> DFResult<usize> {
154 Err(DataFusionError::Execution(
155 "Insert operation is not supported for mysql".to_string(),
156 ))
157 }
158}
159
160fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
161 let character_set = mysql_col.character_set();
162 let is_utf8_bin_character_set = character_set == 45;
163 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
164 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
165 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
166 let col_length = mysql_col.column_length();
167 match mysql_col.column_type() {
168 ColumnType::MYSQL_TYPE_TINY => {
169 if is_unsigned {
170 Ok(MysqlType::TinyIntUnsigned)
171 } else {
172 Ok(MysqlType::TinyInt)
173 }
174 }
175 ColumnType::MYSQL_TYPE_SHORT => {
176 if is_unsigned {
177 Ok(MysqlType::SmallIntUnsigned)
178 } else {
179 Ok(MysqlType::SmallInt)
180 }
181 }
182 ColumnType::MYSQL_TYPE_INT24 => {
183 if is_unsigned {
184 Ok(MysqlType::MediumIntUnsigned)
185 } else {
186 Ok(MysqlType::MediumInt)
187 }
188 }
189 ColumnType::MYSQL_TYPE_LONG => {
190 if is_unsigned {
191 Ok(MysqlType::IntegerUnsigned)
192 } else {
193 Ok(MysqlType::Integer)
194 }
195 }
196 ColumnType::MYSQL_TYPE_LONGLONG => {
197 if is_unsigned {
198 Ok(MysqlType::BigIntUnsigned)
199 } else {
200 Ok(MysqlType::BigInt)
201 }
202 }
203 ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
204 ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
205 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
206 let precision = (mysql_col.column_length() - 2) as u8;
207 let scale = mysql_col.decimals();
208 Ok(MysqlType::Decimal(precision, scale))
209 }
210 ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
211 ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
212 ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
213 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
214 ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
215 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
216 ColumnType::MYSQL_TYPE_STRING if is_binary => {
217 if is_utf8_bin_character_set {
218 Ok(MysqlType::Char)
219 } else {
220 Ok(MysqlType::Binary)
221 }
222 }
223 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
224 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
225 if is_utf8_bin_character_set {
226 Ok(MysqlType::Varchar)
227 } else {
228 Ok(MysqlType::Varbinary)
229 }
230 }
231 ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
232 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
233 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
234 if is_utf8_bin_character_set {
235 Ok(MysqlType::Text(col_length))
236 } else {
237 Ok(MysqlType::Blob(col_length))
238 }
239 }
240 ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
241 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
242 _ => Err(DataFusionError::NotImplemented(format!(
243 "Unsupported mysql type: {mysql_col:?}",
244 ))),
245 }
246}
247
248fn build_remote_schema(stmt: &mysql_async::Statement) -> DFResult<RemoteSchema> {
249 let mut remote_fields = vec![];
250 for col in stmt.columns() {
251 remote_fields.push(RemoteField::new(
252 col.name_str().to_string(),
253 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
254 true,
255 ));
256 }
257 Ok(RemoteSchema::new(remote_fields))
258}
259
/// Reads column `$index` of `$row` as `$value_ty`, converts it with
/// `$convert`, and appends the result to `$builder` (downcast to
/// `$builder_ty`).
///
/// A missing column or a SQL `NULL` appends a null. Any other conversion
/// failure makes the *enclosing function* return an execution error, and a
/// failing `$convert` is propagated with `?`, so the macro can only be used
/// inside a function returning `DFResult`. `$field` and `$col` are used purely
/// for diagnostics.
///
/// Panics when `$builder` is not actually a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(v)) => builder.append_value($convert(v)?),
            // No such column, or the driver reported a SQL NULL.
            None | Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
290
291fn rows_to_batch(
292 rows: &[Row],
293 table_schema: &SchemaRef,
294 projection: Option<&Vec<usize>>,
295) -> DFResult<RecordBatch> {
296 let projected_schema = project_schema(table_schema, projection)?;
297 let mut array_builders = vec![];
298 for field in table_schema.fields() {
299 let builder = make_builder(field.data_type(), rows.len());
300 array_builders.push(builder);
301 }
302
303 for row in rows {
304 for (idx, field) in table_schema.fields.iter().enumerate() {
305 if !projections_contains(projection, idx) {
306 continue;
307 }
308 let builder = &mut array_builders[idx];
309 let col = row.columns_ref().get(idx);
310 match field.data_type() {
311 DataType::Int8 => {
312 handle_primitive_type!(
313 builder,
314 field,
315 col,
316 Int8Builder,
317 i8,
318 row,
319 idx,
320 just_return
321 );
322 }
323 DataType::Int16 => {
324 handle_primitive_type!(
325 builder,
326 field,
327 col,
328 Int16Builder,
329 i16,
330 row,
331 idx,
332 just_return
333 );
334 }
335 DataType::Int32 => {
336 handle_primitive_type!(
337 builder,
338 field,
339 col,
340 Int32Builder,
341 i32,
342 row,
343 idx,
344 just_return
345 );
346 }
347 DataType::Int64 => {
348 handle_primitive_type!(
349 builder,
350 field,
351 col,
352 Int64Builder,
353 i64,
354 row,
355 idx,
356 just_return
357 );
358 }
359 DataType::UInt8 => {
360 handle_primitive_type!(
361 builder,
362 field,
363 col,
364 UInt8Builder,
365 u8,
366 row,
367 idx,
368 just_return
369 );
370 }
371 DataType::UInt16 => {
372 handle_primitive_type!(
373 builder,
374 field,
375 col,
376 UInt16Builder,
377 u16,
378 row,
379 idx,
380 just_return
381 );
382 }
383 DataType::UInt32 => {
384 handle_primitive_type!(
385 builder,
386 field,
387 col,
388 UInt32Builder,
389 u32,
390 row,
391 idx,
392 just_return
393 );
394 }
395 DataType::UInt64 => {
396 handle_primitive_type!(
397 builder,
398 field,
399 col,
400 UInt64Builder,
401 u64,
402 row,
403 idx,
404 just_return
405 );
406 }
407 DataType::Float32 => {
408 handle_primitive_type!(
409 builder,
410 field,
411 col,
412 Float32Builder,
413 f32,
414 row,
415 idx,
416 just_return
417 );
418 }
419 DataType::Float64 => {
420 handle_primitive_type!(
421 builder,
422 field,
423 col,
424 Float64Builder,
425 f64,
426 row,
427 idx,
428 just_return
429 );
430 }
431 DataType::Decimal128(_precision, scale) => {
432 handle_primitive_type!(
433 builder,
434 field,
435 col,
436 Decimal128Builder,
437 BigDecimal,
438 row,
439 idx,
440 |v: BigDecimal| { big_decimal_to_i128(&v, Some(*scale as i32)) }
441 );
442 }
443 DataType::Decimal256(_precision, _scale) => {
444 handle_primitive_type!(
445 builder,
446 field,
447 col,
448 Decimal256Builder,
449 BigDecimal,
450 row,
451 idx,
452 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
453 );
454 }
455 DataType::Date32 => {
456 handle_primitive_type!(
457 builder,
458 field,
459 col,
460 Date32Builder,
461 chrono::NaiveDate,
462 row,
463 idx,
464 |v: chrono::NaiveDate| {
465 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
466 }
467 );
468 }
469 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
470 match tz_opt {
471 None => {}
472 Some(tz) => {
473 if !tz.eq_ignore_ascii_case("utc") {
474 return Err(DataFusionError::NotImplemented(format!(
475 "Unsupported data type {:?} for col: {:?}",
476 field.data_type(),
477 col
478 )));
479 }
480 }
481 }
482 handle_primitive_type!(
483 builder,
484 field,
485 col,
486 TimestampMicrosecondBuilder,
487 time::PrimitiveDateTime,
488 row,
489 idx,
490 |v: time::PrimitiveDateTime| {
491 let timestamp_micros =
492 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
493 Ok::<_, DataFusionError>(timestamp_micros)
494 }
495 );
496 }
497 DataType::Time32(TimeUnit::Second) => {
498 handle_primitive_type!(
499 builder,
500 field,
501 col,
502 Time32SecondBuilder,
503 chrono::NaiveTime,
504 row,
505 idx,
506 |v: chrono::NaiveTime| {
507 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
508 }
509 );
510 }
511 DataType::Time64(TimeUnit::Nanosecond) => {
512 handle_primitive_type!(
513 builder,
514 field,
515 col,
516 Time64NanosecondBuilder,
517 chrono::NaiveTime,
518 row,
519 idx,
520 |v: chrono::NaiveTime| {
521 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
522 + i64::from(v.nanosecond());
523 Ok::<_, DataFusionError>(t)
524 }
525 );
526 }
527 DataType::Utf8 => {
528 handle_primitive_type!(
529 builder,
530 field,
531 col,
532 StringBuilder,
533 String,
534 row,
535 idx,
536 just_return
537 );
538 }
539 DataType::LargeUtf8 => {
540 handle_primitive_type!(
541 builder,
542 field,
543 col,
544 LargeStringBuilder,
545 String,
546 row,
547 idx,
548 just_return
549 );
550 }
551 DataType::Binary => {
552 handle_primitive_type!(
553 builder,
554 field,
555 col,
556 BinaryBuilder,
557 Vec<u8>,
558 row,
559 idx,
560 just_return
561 );
562 }
563 DataType::LargeBinary => {
564 handle_primitive_type!(
565 builder,
566 field,
567 col,
568 LargeBinaryBuilder,
569 Vec<u8>,
570 row,
571 idx,
572 just_return
573 );
574 }
575 _ => {
576 return Err(DataFusionError::NotImplemented(format!(
577 "Unsupported data type {:?} for col: {:?}",
578 field.data_type(),
579 col
580 )));
581 }
582 }
583 }
584 }
585 let projected_columns = array_builders
586 .into_iter()
587 .enumerate()
588 .filter(|(idx, _)| projections_contains(projection, *idx))
589 .map(|(_, mut builder)| builder.finish())
590 .collect::<Vec<ArrayRef>>();
591 let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
592 Ok(RecordBatch::try_new_with_options(
593 projected_schema,
594 projected_columns,
595 &options,
596 )?)
597}
598
599fn to_decimal_256(decimal: &BigDecimal) -> i256 {
600 let (bigint_value, _) = decimal.as_bigint_and_exponent();
601 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
602
603 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
604 let fill_byte = if is_negative { 0xFF } else { 0x00 };
605
606 if bigint_bytes.len() > 32 {
607 bigint_bytes.truncate(32);
608 } else {
609 bigint_bytes.resize(32, fill_byte);
610 };
611
612 let mut array = [0u8; 32];
613 array.copy_from_slice(&bigint_bytes);
614
615 i256::from_le_bytes(array)
616}