1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12 LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
13 TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt8Builder, UInt16Builder,
14 UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use derive_getters::Getters;
21use derive_with::With;
22use futures::StreamExt;
23use futures::lock::Mutex;
24use mysql_async::consts::{ColumnFlags, ColumnType};
25use mysql_async::prelude::Queryable;
26use mysql_async::{Column, FromValueError, Row, Value};
27use std::sync::Arc;
28
29#[derive(Debug, Clone, With, Getters)]
30pub struct MysqlConnectionOptions {
31 pub(crate) host: String,
32 pub(crate) port: u16,
33 pub(crate) username: String,
34 pub(crate) password: String,
35 pub(crate) database: Option<String>,
36 pub(crate) pool_max_size: usize,
37 pub(crate) stream_chunk_size: usize,
38}
39
40impl MysqlConnectionOptions {
41 pub fn new(
42 host: impl Into<String>,
43 port: u16,
44 username: impl Into<String>,
45 password: impl Into<String>,
46 ) -> Self {
47 Self {
48 host: host.into(),
49 port,
50 username: username.into(),
51 password: password.into(),
52 database: None,
53 pool_max_size: 10,
54 stream_chunk_size: 2048,
55 }
56 }
57}
58
59#[derive(Debug)]
60pub struct MysqlPool {
61 pool: mysql_async::Pool,
62}
63
64pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
65 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
66 mysql_async::PoolConstraints::new(0, options.pool_max_size)
67 .expect("Failed to create pool constraints"),
68 );
69 let opts_builder = mysql_async::OptsBuilder::default()
70 .ip_or_hostname(options.host.clone())
71 .tcp_port(options.port)
72 .user(Some(options.username.clone()))
73 .pass(Some(options.password.clone()))
74 .db_name(options.database.clone())
75 .pool_opts(pool_opts);
76 let pool = mysql_async::Pool::new(opts_builder);
77 Ok(MysqlPool { pool })
78}
79
80#[async_trait::async_trait]
81impl Pool for MysqlPool {
82 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
83 let conn = self.pool.get_conn().await.map_err(|e| {
84 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
85 })?;
86 Ok(Arc::new(MysqlConnection {
87 conn: Arc::new(Mutex::new(conn)),
88 }))
89 }
90}
91
92#[derive(Debug)]
93pub struct MysqlConnection {
94 conn: Arc<Mutex<mysql_async::Conn>>,
95}
96
97#[async_trait::async_trait]
98impl Connection for MysqlConnection {
99 async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
100 let sql = try_limit1_query(sql).unwrap_or_else(|| sql.to_string());
101 let mut conn = self.conn.lock().await;
102 let conn = &mut *conn;
103 let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
104 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
105 })?;
106 let Some(row) = row else {
107 return Err(DataFusionError::Execution(
108 "No rows returned to infer schema".to_string(),
109 ));
110 };
111 let remote_schema = Arc::new(build_remote_schema(&row)?);
112 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
113 Ok((remote_schema, arrow_schema))
114 }
115
116 async fn query(
117 &self,
118 conn_options: &ConnectionOptions,
119 sql: &str,
120 table_schema: SchemaRef,
121 projection: Option<&Vec<usize>>,
122 ) -> DFResult<SendableRecordBatchStream> {
123 let projected_schema = project_schema(&table_schema, projection)?;
124 let sql = sql.to_string();
125 let projection = projection.cloned();
126 let chunk_size = conn_options.stream_chunk_size();
127 let conn = Arc::clone(&self.conn);
128 let stream = Box::pin(stream! {
129 let mut conn = conn.lock().await;
130 let mut query_iter = conn
131 .query_iter(sql)
132 .await
133 .map_err(|e| {
134 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
135 })?;
136
137 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
138 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
139 })? else {
140 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
141 return;
142 };
143
144 let mut chunked_stream = stream.chunks(chunk_size).boxed();
145
146 while let Some(chunk) = chunked_stream.next().await {
147 let rows = chunk
148 .into_iter()
149 .collect::<Result<Vec<_>, _>>()
150 .map_err(|e| {
151 DataFusionError::Execution(format!(
152 "Failed to collect rows from mysql due to {e}",
153 ))
154 })?;
155
156 yield Ok::<_, DataFusionError>(rows)
157 }
158 });
159
160 let stream = stream.map(move |rows| {
161 let rows = rows?;
162 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
163 });
164
165 Ok(Box::pin(RecordBatchStreamAdapter::new(
166 projected_schema,
167 stream,
168 )))
169 }
170}
171
/// Wraps a `SELECT` statement in a sub-query limited to a single row.
///
/// Returns `Some` with `SELECT * FROM ({sql}) as __subquery LIMIT 1` when the text,
/// ignoring leading whitespace, starts with `select` in any ASCII case. Returns `None`
/// for everything else, including empty or very short input.
fn try_limit1_query(sql: &str) -> Option<String> {
    // `get` rather than slice indexing: input shorter than six bytes, or with a
    // multi-byte character straddling byte 6, must give `None` instead of panicking.
    let is_select = sql
        .trim()
        .get(..6)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("select"));
    is_select.then(|| format!("SELECT * FROM ({sql}) as __subquery LIMIT 1"))
}
179
180fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
181 let character_set = mysql_col.character_set();
182 let is_utf8_bin_character_set = character_set == 45;
183 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
184 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
185 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
186 let col_length = mysql_col.column_length();
187 match mysql_col.column_type() {
188 ColumnType::MYSQL_TYPE_TINY => {
189 if is_unsigned {
190 Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
191 } else {
192 Ok(RemoteType::Mysql(MysqlType::TinyInt))
193 }
194 }
195 ColumnType::MYSQL_TYPE_SHORT => {
196 if is_unsigned {
197 Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
198 } else {
199 Ok(RemoteType::Mysql(MysqlType::SmallInt))
200 }
201 }
202 ColumnType::MYSQL_TYPE_INT24 => {
203 if is_unsigned {
204 Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
205 } else {
206 Ok(RemoteType::Mysql(MysqlType::MediumInt))
207 }
208 }
209 ColumnType::MYSQL_TYPE_LONG => {
210 if is_unsigned {
211 Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
212 } else {
213 Ok(RemoteType::Mysql(MysqlType::Integer))
214 }
215 }
216 ColumnType::MYSQL_TYPE_LONGLONG => {
217 if is_unsigned {
218 Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
219 } else {
220 Ok(RemoteType::Mysql(MysqlType::BigInt))
221 }
222 }
223 ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
224 ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
225 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
226 let precision = (mysql_col.column_length() - 2) as u8;
227 let scale = mysql_col.decimals();
228 Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
229 }
230 ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
231 ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
232 ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
233 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
234 ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
235 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
236 ColumnType::MYSQL_TYPE_STRING if is_binary => {
237 if is_utf8_bin_character_set {
238 Ok(RemoteType::Mysql(MysqlType::Char))
239 } else {
240 Ok(RemoteType::Mysql(MysqlType::Binary))
241 }
242 }
243 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
244 Ok(RemoteType::Mysql(MysqlType::Varchar))
245 }
246 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
247 if is_utf8_bin_character_set {
248 Ok(RemoteType::Mysql(MysqlType::Varchar))
249 } else {
250 Ok(RemoteType::Mysql(MysqlType::Varbinary))
251 }
252 }
253 ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
254 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
255 Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
256 }
257 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
258 if is_utf8_bin_character_set {
259 Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
260 } else {
261 Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
262 }
263 }
264 ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
265 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
266 _ => Err(DataFusionError::NotImplemented(format!(
267 "Unsupported mysql type: {mysql_col:?}",
268 ))),
269 }
270}
271
272fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
273 let mut remote_fields = vec![];
274 for col in row.columns_ref() {
275 remote_fields.push(RemoteField::new(
276 col.name_str().to_string(),
277 mysql_type_to_remote_type(col)?,
278 true,
279 ));
280 }
281 Ok(RemoteSchema::new(remote_fields))
282}
283
/// Reads column `$index` of `$row` as `$value_ty`, converts it with `$convert` and
/// appends the result to `$builder`, which must be a `$builder_ty`.
///
/// A missing column and a SQL `NULL` both append a null. The expansion uses `?` and
/// `return Err(..)`, so it may only be used inside a function returning a `Result`
/// whose error type is `DataFusionError`.
///
/// # Panics
///
/// Panics when `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            Some(Ok(v)) => typed_builder.append_value($convert(v)?),
            // No such column, or the column holds SQL NULL.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
314
315fn rows_to_batch(
316 rows: &[Row],
317 table_schema: &SchemaRef,
318 projection: Option<&Vec<usize>>,
319) -> DFResult<RecordBatch> {
320 let projected_schema = project_schema(table_schema, projection)?;
321 let mut array_builders = vec![];
322 for field in table_schema.fields() {
323 let builder = make_builder(field.data_type(), rows.len());
324 array_builders.push(builder);
325 }
326
327 for row in rows {
328 for (idx, field) in table_schema.fields.iter().enumerate() {
329 if !projections_contains(projection, idx) {
330 continue;
331 }
332 let builder = &mut array_builders[idx];
333 let col = row.columns_ref().get(idx);
334 match field.data_type() {
335 DataType::Int8 => {
336 handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
337 Ok::<_, DataFusionError>(v)
338 });
339 }
340 DataType::Int16 => {
341 handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
342 Ok::<_, DataFusionError>(v)
343 });
344 }
345 DataType::Int32 => {
346 handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
347 Ok::<_, DataFusionError>(v)
348 });
349 }
350 DataType::Int64 => {
351 handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
352 Ok::<_, DataFusionError>(v)
353 });
354 }
355 DataType::UInt8 => {
356 handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
357 Ok::<_, DataFusionError>(v)
358 });
359 }
360 DataType::UInt16 => {
361 handle_primitive_type!(
362 builder,
363 field,
364 col,
365 UInt16Builder,
366 u16,
367 row,
368 idx,
369 |v| { Ok::<_, DataFusionError>(v) }
370 );
371 }
372 DataType::UInt32 => {
373 handle_primitive_type!(
374 builder,
375 field,
376 col,
377 UInt32Builder,
378 u32,
379 row,
380 idx,
381 |v| { Ok::<_, DataFusionError>(v) }
382 );
383 }
384 DataType::UInt64 => {
385 handle_primitive_type!(
386 builder,
387 field,
388 col,
389 UInt64Builder,
390 u64,
391 row,
392 idx,
393 |v| { Ok::<_, DataFusionError>(v) }
394 );
395 }
396 DataType::Float32 => {
397 handle_primitive_type!(
398 builder,
399 field,
400 col,
401 Float32Builder,
402 f32,
403 row,
404 idx,
405 |v| { Ok::<_, DataFusionError>(v) }
406 );
407 }
408 DataType::Float64 => {
409 handle_primitive_type!(
410 builder,
411 field,
412 col,
413 Float64Builder,
414 f64,
415 row,
416 idx,
417 |v| { Ok::<_, DataFusionError>(v) }
418 );
419 }
420 DataType::Decimal128(_precision, scale) => {
421 handle_primitive_type!(
422 builder,
423 field,
424 col,
425 Decimal128Builder,
426 BigDecimal,
427 row,
428 idx,
429 |v: BigDecimal| {
430 big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
431 DataFusionError::Execution(format!(
432 "Failed to convert BigDecimal {v:?} to i128"
433 ))
434 })
435 }
436 );
437 }
438 DataType::Decimal256(_precision, _scale) => {
439 handle_primitive_type!(
440 builder,
441 field,
442 col,
443 Decimal256Builder,
444 BigDecimal,
445 row,
446 idx,
447 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
448 );
449 }
450 DataType::Date32 => {
451 handle_primitive_type!(
452 builder,
453 field,
454 col,
455 Date32Builder,
456 chrono::NaiveDate,
457 row,
458 idx,
459 |v: chrono::NaiveDate| {
460 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
461 }
462 );
463 }
464 DataType::Timestamp(TimeUnit::Microsecond, None) => {
465 handle_primitive_type!(
466 builder,
467 field,
468 col,
469 TimestampMicrosecondBuilder,
470 time::PrimitiveDateTime,
471 row,
472 idx,
473 |v: time::PrimitiveDateTime| {
474 let timestamp_micros =
475 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
476 Ok::<_, DataFusionError>(timestamp_micros)
477 }
478 );
479 }
480 DataType::Timestamp(TimeUnit::Nanosecond, None) => {
481 handle_primitive_type!(
482 builder,
483 field,
484 col,
485 TimestampNanosecondBuilder,
486 chrono::NaiveTime,
487 row,
488 idx,
489 |v: chrono::NaiveTime| {
490 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
491 + i64::from(v.nanosecond());
492 Ok::<_, DataFusionError>(t)
493 }
494 );
495 }
496 DataType::Time32(TimeUnit::Second) => {
497 handle_primitive_type!(
498 builder,
499 field,
500 col,
501 Time32SecondBuilder,
502 chrono::NaiveTime,
503 row,
504 idx,
505 |v: chrono::NaiveTime| {
506 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
507 }
508 );
509 }
510 DataType::Time64(TimeUnit::Nanosecond) => {
511 handle_primitive_type!(
512 builder,
513 field,
514 col,
515 Time64NanosecondBuilder,
516 chrono::NaiveTime,
517 row,
518 idx,
519 |v: chrono::NaiveTime| {
520 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
521 + i64::from(v.nanosecond());
522 Ok::<_, DataFusionError>(t)
523 }
524 );
525 }
526 DataType::Utf8 => {
527 handle_primitive_type!(
528 builder,
529 field,
530 col,
531 StringBuilder,
532 String,
533 row,
534 idx,
535 |v| { Ok::<_, DataFusionError>(v) }
536 );
537 }
538 DataType::LargeUtf8 => {
539 handle_primitive_type!(
540 builder,
541 field,
542 col,
543 LargeStringBuilder,
544 String,
545 row,
546 idx,
547 |v| { Ok::<_, DataFusionError>(v) }
548 );
549 }
550 DataType::Binary => {
551 handle_primitive_type!(
552 builder,
553 field,
554 col,
555 BinaryBuilder,
556 Vec<u8>,
557 row,
558 idx,
559 |v| { Ok::<_, DataFusionError>(v) }
560 );
561 }
562 DataType::LargeBinary => {
563 handle_primitive_type!(
564 builder,
565 field,
566 col,
567 LargeBinaryBuilder,
568 Vec<u8>,
569 row,
570 idx,
571 |v| { Ok::<_, DataFusionError>(v) }
572 );
573 }
574 _ => {
575 return Err(DataFusionError::NotImplemented(format!(
576 "Unsupported data type {:?} for col: {:?}",
577 field.data_type(),
578 col
579 )));
580 }
581 }
582 }
583 }
584 let projected_columns = array_builders
585 .into_iter()
586 .enumerate()
587 .filter(|(idx, _)| projections_contains(projection, *idx))
588 .map(|(_, mut builder)| builder.finish())
589 .collect::<Vec<ArrayRef>>();
590 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
591}
592
593fn to_decimal_256(decimal: &BigDecimal) -> i256 {
594 let (bigint_value, _) = decimal.as_bigint_and_exponent();
595 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
596
597 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
598 let fill_byte = if is_negative { 0xFF } else { 0x00 };
599
600 if bigint_bytes.len() > 32 {
601 bigint_bytes.truncate(32);
602 } else {
603 bigint_bytes.resize(32, fill_byte);
604 };
605
606 let mut array = [0u8; 32];
607 array.copy_from_slice(&bigint_bytes);
608
609 i256::from_le_bytes(array)
610}