1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12 LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
13 TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt8Builder, UInt16Builder,
14 UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::limit::LimitStream;
20use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
21use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
22use derive_getters::Getters;
23use derive_with::With;
24use futures::StreamExt;
25use futures::lock::Mutex;
26use mysql_async::consts::{ColumnFlags, ColumnType};
27use mysql_async::prelude::Queryable;
28use mysql_async::{Column, FromValueError, Row, Value};
29use std::sync::Arc;
30
/// Connection settings for a MySQL server.
///
/// `With` derives `with_*` builder-style setters and `Getters` derives read
/// accessors for every field; see [`MysqlConnectionOptions::new`] for the
/// defaults applied to the optional fields.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    // Initial database (schema) to select; `None` selects no default database.
    pub(crate) database: Option<String>,
    // Upper bound on pooled connections (lower bound is 0; see `connect_mysql`).
    pub(crate) pool_max_size: usize,
    // Number of rows pulled per emitted `RecordBatch` when streaming results.
    pub(crate) stream_chunk_size: usize,
}
41
42impl MysqlConnectionOptions {
43 pub fn new(
44 host: impl Into<String>,
45 port: u16,
46 username: impl Into<String>,
47 password: impl Into<String>,
48 ) -> Self {
49 Self {
50 host: host.into(),
51 port,
52 username: username.into(),
53 password: password.into(),
54 database: None,
55 pool_max_size: 10,
56 stream_chunk_size: 2048,
57 }
58 }
59}
60
/// A [`Pool`] implementation backed by a `mysql_async` connection pool.
#[derive(Debug)]
pub struct MysqlPool {
    pool: mysql_async::Pool,
}
65
66pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
67 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
68 mysql_async::PoolConstraints::new(0, options.pool_max_size)
69 .expect("Failed to create pool constraints"),
70 );
71 let opts_builder = mysql_async::OptsBuilder::default()
72 .ip_or_hostname(options.host.clone())
73 .tcp_port(options.port)
74 .user(Some(options.username.clone()))
75 .pass(Some(options.password.clone()))
76 .db_name(options.database.clone())
77 .pool_opts(pool_opts);
78 let pool = mysql_async::Pool::new(opts_builder);
79 Ok(MysqlPool { pool })
80}
81
82#[async_trait::async_trait]
83impl Pool for MysqlPool {
84 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
85 let conn = self.pool.get_conn().await.map_err(|e| {
86 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
87 })?;
88 Ok(Arc::new(MysqlConnection {
89 conn: Arc::new(Mutex::new(conn)),
90 }))
91 }
92}
93
/// A single pooled MySQL connection.
///
/// The inner `mysql_async::Conn` is shared behind an async `Mutex` so it can
/// be moved into the result stream produced by `query` while remaining
/// reusable afterwards.
#[derive(Debug)]
pub struct MysqlConnection {
    conn: Arc<Mutex<mysql_async::Conn>>,
}
98
#[async_trait::async_trait]
impl Connection for MysqlConnection {
    /// Infers the remote and Arrow schemas for `sql` by fetching one row and
    /// reading its column metadata.
    ///
    /// The statement is wrapped with `LIMIT 1` when `try_limit_query` can do
    /// so; otherwise it runs unmodified. Fails with an `Execution` error when
    /// the query returns no rows, since metadata is taken from a fetched `Row`.
    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        let sql = try_limit_query(sql, 1).unwrap_or_else(|| sql.to_string());
        let mut conn = self.conn.lock().await;
        // Reborrow the guard so `query_first` sees `&mut Conn`.
        let conn = &mut *conn;
        let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
        })?;
        let Some(row) = row else {
            return Err(DataFusionError::Execution(
                "No rows returned to infer schema".to_string(),
            ));
        };
        let remote_schema = Arc::new(build_remote_schema(&row)?);
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        Ok((remote_schema, arrow_schema))
    }

    /// Streams query results as Arrow record batches.
    ///
    /// `limit` is pushed into the SQL when the statement is a plain SELECT;
    /// otherwise `limit_stream` is set and the limit is enforced client-side
    /// with a `LimitStream`. Rows are pulled in chunks of
    /// `stream_chunk_size` and converted via `rows_to_batch`, so the emitted
    /// batch schema is `table_schema` narrowed by `projection`.
    ///
    /// NOTE(review): the connection mutex is locked inside the stream and held
    /// for the whole duration of result iteration, serializing any concurrent
    /// queries issued through this `MysqlConnection`.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        // `limit_stream == true` means the limit could not be pushed into the
        // SQL and must be applied client-side below.
        let (sql, limit_stream) = match limit {
            Some(limit) => {
                if let Some(limited_sql) = try_limit_query(sql, limit) {
                    (limited_sql, false)
                } else {
                    (sql.to_string(), true)
                }
            }
            None => (sql.to_string(), false),
        };
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let conn = Arc::clone(&self.conn);
        // Async generator yielding `Vec<Row>` chunks pulled from the server.
        let stream = Box::pin(stream! {
            let mut conn = conn.lock().await;
            let mut query_iter = conn
                .query_iter(sql)
                .await
                .map_err(|e| {
                    DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
                })?;

            let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
                DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
            })? else {
                yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
                return;
            };

            let mut chunked_stream = stream.chunks(chunk_size).boxed();

            while let Some(chunk) = chunked_stream.next().await {
                let rows = chunk
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
                    .map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to collect rows from mysql due to {e}",
                        ))
                    })?;

                yield Ok::<_, DataFusionError>(rows)
            }
        });

        // Decode each row chunk into a projected RecordBatch.
        let stream = stream.map(move |rows| {
            let rows = rows?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        let sendable_stream = Box::pin(RecordBatchStreamAdapter::new(projected_schema, stream));

        if limit_stream {
            let metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
            Ok(Box::pin(LimitStream::new(
                sendable_stream,
                0,
                limit,
                metrics,
            )))
        } else {
            Ok(sendable_stream)
        }
    }
}
192
/// Wraps `sql` in `SELECT * FROM (..) AS __subquery LIMIT n` when the
/// statement is a plain SELECT; returns `None` otherwise so callers can fall
/// back to client-side limiting.
///
/// The prefix check is done on bytes via `get(..6)`, which never panics: the
/// previous `sql.trim()[0..6]` slice panicked on statements shorter than six
/// bytes (e.g. the empty string) and whenever byte offset 6 fell inside a
/// multi-byte UTF-8 character.
fn try_limit_query(sql: &str, limit: usize) -> Option<String> {
    let is_select = sql
        .trim()
        .as_bytes()
        .get(..6)
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case(b"select"));
    if is_select {
        // Note: the original (untrimmed) SQL is embedded in the subquery.
        Some(format!("SELECT * FROM ({sql}) as __subquery LIMIT {limit}"))
    } else {
        None
    }
}
200
201fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
202 let character_set = mysql_col.character_set();
203 let is_utf8_bin_character_set = character_set == 45;
204 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
205 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
206 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
207 let col_length = mysql_col.column_length();
208 match mysql_col.column_type() {
209 ColumnType::MYSQL_TYPE_TINY => {
210 if is_unsigned {
211 Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
212 } else {
213 Ok(RemoteType::Mysql(MysqlType::TinyInt))
214 }
215 }
216 ColumnType::MYSQL_TYPE_SHORT => {
217 if is_unsigned {
218 Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
219 } else {
220 Ok(RemoteType::Mysql(MysqlType::SmallInt))
221 }
222 }
223 ColumnType::MYSQL_TYPE_INT24 => {
224 if is_unsigned {
225 Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
226 } else {
227 Ok(RemoteType::Mysql(MysqlType::MediumInt))
228 }
229 }
230 ColumnType::MYSQL_TYPE_LONG => {
231 if is_unsigned {
232 Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
233 } else {
234 Ok(RemoteType::Mysql(MysqlType::Integer))
235 }
236 }
237 ColumnType::MYSQL_TYPE_LONGLONG => {
238 if is_unsigned {
239 Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
240 } else {
241 Ok(RemoteType::Mysql(MysqlType::BigInt))
242 }
243 }
244 ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
245 ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
246 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
247 let precision = (mysql_col.column_length() - 2) as u8;
248 let scale = mysql_col.decimals();
249 Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
250 }
251 ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
252 ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
253 ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
254 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
255 ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
256 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
257 ColumnType::MYSQL_TYPE_STRING if is_binary => {
258 if is_utf8_bin_character_set {
259 Ok(RemoteType::Mysql(MysqlType::Char))
260 } else {
261 Ok(RemoteType::Mysql(MysqlType::Binary))
262 }
263 }
264 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
265 Ok(RemoteType::Mysql(MysqlType::Varchar))
266 }
267 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
268 if is_utf8_bin_character_set {
269 Ok(RemoteType::Mysql(MysqlType::Varchar))
270 } else {
271 Ok(RemoteType::Mysql(MysqlType::Varbinary))
272 }
273 }
274 ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
275 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
276 Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
277 }
278 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
279 if is_utf8_bin_character_set {
280 Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
281 } else {
282 Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
283 }
284 }
285 ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
286 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
287 _ => Err(DataFusionError::NotImplemented(format!(
288 "Unsupported mysql type: {mysql_col:?}",
289 ))),
290 }
291}
292
293fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
294 let mut remote_fields = vec![];
295 for col in row.columns_ref() {
296 remote_fields.push(RemoteField::new(
297 col.name_str().to_string(),
298 mysql_type_to_remote_type(col)?,
299 true,
300 ));
301 }
302 Ok(RemoteSchema::new(remote_fields))
303}
304
// Appends the value at column `$index` of `$row` to the Arrow builder
// `$builder` (downcast to `$builder_ty`), decoding it as `$value_ty` and
// passing it through the fallible `$convert` closure first.
//
// NULL handling: `get_opt` returning `None` (presumably an out-of-range
// column index — confirm against mysql_async docs) and a decode error whose
// source value is `Value::NULL` both append a null slot.
//
// Note: `return Err(..)` inside the expansion requires the enclosing
// function to return `DFResult<_>` (used from `rows_to_batch`); the builder
// downcast failure is a programming error and panics.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
335
/// Converts a slice of MySQL rows into a single Arrow `RecordBatch` whose
/// schema is `table_schema` narrowed by `projection`.
///
/// Builders are allocated for every column of `table_schema`, but only the
/// columns selected by `projection` are populated; unprojected builders stay
/// empty and are filtered out before assembly. Per-value decoding, NULL
/// handling, and builder downcasts are delegated to `handle_primitive_type!`.
///
/// # Errors
/// - `Execution` when a cell fails to decode or convert.
/// - `NotImplemented` for Arrow data types without a decode arm.
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // One builder per table column, pre-sized to the chunk's row count.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            // Skip columns the projection excludes; their builders stay empty.
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            // Column metadata is only used for error messages below.
            let col = row.columns_ref().get(idx);
            match field.data_type() {
                DataType::Int8 => {
                    handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Int16 => {
                    handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Int32 => {
                    handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::Int64 => {
                    handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::UInt8 => {
                    handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
                        Ok::<_, DataFusionError>(v)
                    });
                }
                DataType::UInt16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt16Builder,
                        u16,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::UInt32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt32Builder,
                        u32,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::UInt64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt64Builder,
                        u64,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                // Decimal values arrive as BigDecimal and are rescaled to the
                // field's Arrow scale before narrowing to i128.
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        BigDecimal,
                        row,
                        idx,
                        |v: BigDecimal| {
                            big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
                                DataFusionError::Execution(format!(
                                    "Failed to convert BigDecimal {v:?} to i128"
                                ))
                            })
                        }
                    );
                }
                // NOTE(review): unlike the Decimal128 arm, the field's scale
                // is ignored here — `to_decimal_256` uses the value's own
                // unscaled integer. Confirm the scales always agree upstream.
                DataType::Decimal256(_precision, _scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal256Builder,
                        BigDecimal,
                        row,
                        idx,
                        |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
                    );
                }
                DataType::Date32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date32Builder,
                        chrono::NaiveDate,
                        row,
                        idx,
                        |v: chrono::NaiveDate| {
                            Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
                        }
                    );
                }
                // Datetime decoded via the `time` crate, assumed UTC for the
                // epoch conversion.
                DataType::Timestamp(TimeUnit::Microsecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampMicrosecondBuilder,
                        time::PrimitiveDateTime,
                        row,
                        idx,
                        |v: time::PrimitiveDateTime| {
                            let timestamp_micros =
                                (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
                            Ok::<_, DataFusionError>(timestamp_micros)
                        }
                    );
                }
                // NOTE(review): this arm decodes a time-of-day
                // (chrono::NaiveTime) into a nanosecond *timestamp* column —
                // confirm this matches how MySQL TIME columns are mapped to
                // Arrow upstream.
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
                                + i64::from(v.nanosecond());
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Time32(TimeUnit::Second) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time32SecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
                        }
                    );
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time64NanosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
                                + i64::from(v.nanosecond());
                            Ok::<_, DataFusionError>(t)
                        }
                    );
                }
                DataType::Utf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        StringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeUtf8 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeStringBuilder,
                        String,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::Binary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                DataType::LargeBinary => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        LargeBinaryBuilder,
                        Vec<u8>,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(v) }
                    );
                }
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {:?} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }
    // Finish only the projected builders, preserving schema order.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}
613
614fn to_decimal_256(decimal: &BigDecimal) -> i256 {
615 let (bigint_value, _) = decimal.as_bigint_and_exponent();
616 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
617
618 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
619 let fill_byte = if is_negative { 0xFF } else { 0x00 };
620
621 if bigint_bytes.len() > 32 {
622 bigint_bytes.truncate(32);
623 } else {
624 bigint_bytes.resize(32, fill_byte);
625 };
626
627 let mut array = [0u8; 32];
628 array.copy_from_slice(&bigint_bytes);
629
630 i256::from_le_bytes(array)
631}