1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5 RemoteSchemaRef, RemoteType, Transform,
6};
7use async_stream::stream;
8use bigdecimal::{num_bigint, BigDecimal};
9use chrono::Timelike;
10use datafusion::arrow::array::{
11 make_builder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder,
12 Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
13 LargeBinaryBuilder, LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder,
14 Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampNanosecondBuilder,
15 UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
16};
17use datafusion::arrow::datatypes::{i256, DataType, Date32Type, SchemaRef, TimeUnit};
18use datafusion::common::{project_schema, DataFusionError};
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use futures::lock::Mutex;
22use futures::StreamExt;
23use mysql_async::consts::{ColumnFlags, ColumnType};
24use mysql_async::prelude::Queryable;
25use mysql_async::{Column, FromValueError, Row, Value};
26use std::sync::Arc;
27
28#[derive(Debug, Clone, derive_with::With)]
29pub struct MysqlConnectionOptions {
30 pub(crate) host: String,
31 pub(crate) port: u16,
32 pub(crate) username: String,
33 pub(crate) password: String,
34 pub(crate) database: Option<String>,
35 pub(crate) chunk_size: Option<usize>,
36}
37
38impl MysqlConnectionOptions {
39 pub fn new(
40 host: impl Into<String>,
41 port: u16,
42 username: impl Into<String>,
43 password: impl Into<String>,
44 ) -> Self {
45 Self {
46 host: host.into(),
47 port,
48 username: username.into(),
49 password: password.into(),
50 database: None,
51 chunk_size: None,
52 }
53 }
54}
55
56#[derive(Debug)]
57pub struct MysqlPool {
58 pool: mysql_async::Pool,
59}
60
61pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
62 let opts_builder = mysql_async::OptsBuilder::default()
63 .ip_or_hostname(options.host.clone())
64 .tcp_port(options.port)
65 .user(Some(options.username.clone()))
66 .pass(Some(options.password.clone()))
67 .db_name(options.database.clone());
68 let pool = mysql_async::Pool::new(opts_builder);
69 Ok(MysqlPool { pool })
70}
71
72#[async_trait::async_trait]
73impl Pool for MysqlPool {
74 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
75 let conn = self.pool.get_conn().await.map_err(|e| {
76 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
77 })?;
78 Ok(Arc::new(MysqlConnection {
79 conn: Arc::new(Mutex::new(conn)),
80 }))
81 }
82}
83
84#[derive(Debug)]
85pub struct MysqlConnection {
86 conn: Arc<Mutex<mysql_async::Conn>>,
87}
88
89#[async_trait::async_trait]
90impl Connection for MysqlConnection {
91 async fn infer_schema(
92 &self,
93 sql: &str,
94 transform: Option<Arc<dyn Transform>>,
95 ) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
96 let mut conn = self.conn.lock().await;
97 let conn = &mut *conn;
98 let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
99 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
100 })?;
101 let Some(row) = row else {
102 return Err(DataFusionError::Execution(
103 "No rows returned to infer schema".to_string(),
104 ));
105 };
106 let remote_schema = Arc::new(build_remote_schema(&row)?);
107 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
108 if let Some(transform) = transform {
109 let batch = rows_to_batch(&[row], &arrow_schema, None)?;
110 let transformed_batch = transform_batch(
111 batch,
112 transform.as_ref(),
113 &arrow_schema,
114 None,
115 Some(&remote_schema),
116 )?;
117 Ok((remote_schema, transformed_batch.schema()))
118 } else {
119 Ok((remote_schema, arrow_schema))
120 }
121 }
122
123 async fn query(
124 &self,
125 conn_options: &ConnectionOptions,
126 sql: &str,
127 table_schema: SchemaRef,
128 projection: Option<&Vec<usize>>,
129 ) -> DFResult<SendableRecordBatchStream> {
130 let projected_schema = project_schema(&table_schema, projection)?;
131 let sql = sql.to_string();
132 let projection = projection.cloned();
133 let chunk_size = conn_options.chunk_size();
134 let conn = Arc::clone(&self.conn);
135 let stream = Box::pin(stream! {
136 let mut conn = conn.lock().await;
137 let mut query_iter = conn
138 .query_iter(sql)
139 .await
140 .map_err(|e| {
141 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
142 })?;
143
144 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
145 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
146 })? else {
147 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
148 return;
149 };
150
151 let mut chunked_stream = stream.chunks(chunk_size.unwrap_or(2048)).boxed();
152
153 while let Some(chunk) = chunked_stream.next().await {
154 let rows = chunk
155 .into_iter()
156 .collect::<Result<Vec<_>, _>>()
157 .map_err(|e| {
158 DataFusionError::Execution(format!(
159 "Failed to collect rows from mysql due to {e}",
160 ))
161 })?;
162
163 yield Ok::<_, DataFusionError>(rows)
164 }
165 });
166
167 let stream = stream.map(move |rows| {
168 let rows = rows?;
169 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
170 });
171
172 Ok(Box::pin(RecordBatchStreamAdapter::new(
173 projected_schema,
174 stream,
175 )))
176 }
177}
178
179fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
180 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
181 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
182 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
183 let col_length = mysql_col.column_length();
184 match mysql_col.column_type() {
185 ColumnType::MYSQL_TYPE_TINY => {
186 if is_unsigned {
187 Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
188 } else {
189 Ok(RemoteType::Mysql(MysqlType::TinyInt))
190 }
191 }
192 ColumnType::MYSQL_TYPE_SHORT => {
193 if is_unsigned {
194 Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
195 } else {
196 Ok(RemoteType::Mysql(MysqlType::SmallInt))
197 }
198 }
199 ColumnType::MYSQL_TYPE_INT24 => {
200 if is_unsigned {
201 Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
202 } else {
203 Ok(RemoteType::Mysql(MysqlType::MediumInt))
204 }
205 }
206 ColumnType::MYSQL_TYPE_LONG => {
207 if is_unsigned {
208 Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
209 } else {
210 Ok(RemoteType::Mysql(MysqlType::Integer))
211 }
212 }
213 ColumnType::MYSQL_TYPE_LONGLONG => {
214 if is_unsigned {
215 Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
216 } else {
217 Ok(RemoteType::Mysql(MysqlType::BigInt))
218 }
219 }
220 ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
221 ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
222 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
223 let precision = (mysql_col.column_length() - 2) as u8;
224 let scale = mysql_col.decimals();
225 Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
226 }
227 ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
228 ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
229 ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
230 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
231 ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
232 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
233 ColumnType::MYSQL_TYPE_STRING if is_binary => Ok(RemoteType::Mysql(MysqlType::Binary)),
234 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
235 Ok(RemoteType::Mysql(MysqlType::Varchar))
236 }
237 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
238 Ok(RemoteType::Mysql(MysqlType::Varbinary))
239 }
240 ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
241 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
242 Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
243 }
244 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
245 Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
246 }
247 ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
248 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
249 _ => Err(DataFusionError::NotImplemented(format!(
250 "Unsupported mysql type: {mysql_col:?}",
251 ))),
252 }
253}
254
255fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
256 let mut remote_fields = vec![];
257 for col in row.columns_ref() {
258 remote_fields.push(RemoteField::new(
259 col.name_str().to_string(),
260 mysql_type_to_remote_type(col)?,
261 true,
262 ));
263 }
264 Ok(RemoteSchema::new(remote_fields))
265}
266
/// Reads column `$index` of `$row` as `$value_ty`, converts it with `$convert`
/// and appends the result to `$builder` (downcast to `$builder_ty`).
///
/// A missing value or a MySQL `NULL` is appended as a null. The macro must be
/// expanded inside a function returning `Result<_, DataFusionError>`: it
/// returns early with an error when the builder has an unexpected type, when
/// the value cannot be read as `$value_ty`, or when `$convert` fails.
/// `$field` and `$mysql_col` are only used to build error messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        // A mismatching builder is a bug in the caller; report it as an internal
        // error instead of panicking inside a fallible function.
        let Some(builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            return Err(DataFusionError::Internal(format!(
                concat!(
                    "Failed to downcast builder to ",
                    stringify!($builder_ty),
                    " for {:?} and {:?}"
                ),
                $field, $mysql_col
            )));
        };
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            // A SQL NULL surfaces as a conversion error carrying `Value::NULL`.
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $mysql_col,
                )))
            }
        }
    }};
}
299
300fn rows_to_batch(
301 rows: &[Row],
302 table_schema: &SchemaRef,
303 projection: Option<&Vec<usize>>,
304) -> DFResult<RecordBatch> {
305 let projected_schema = project_schema(table_schema, projection)?;
306 let mut array_builders = vec![];
307 for field in table_schema.fields() {
308 let builder = make_builder(field.data_type(), rows.len());
309 array_builders.push(builder);
310 }
311
312 for row in rows {
313 for (idx, field) in table_schema.fields.iter().enumerate() {
314 if !projections_contains(projection, idx) {
315 continue;
316 }
317 let builder = &mut array_builders[idx];
318 let col = row.columns_ref().get(idx);
319 match field.data_type() {
320 DataType::Int8 => {
321 handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
322 Ok::<_, DataFusionError>(v)
323 });
324 }
325 DataType::Int16 => {
326 handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
327 Ok::<_, DataFusionError>(v)
328 });
329 }
330 DataType::Int32 => {
331 handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
332 Ok::<_, DataFusionError>(v)
333 });
334 }
335 DataType::Int64 => {
336 handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
337 Ok::<_, DataFusionError>(v)
338 });
339 }
340 DataType::UInt8 => {
341 handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
342 Ok::<_, DataFusionError>(v)
343 });
344 }
345 DataType::UInt16 => {
346 handle_primitive_type!(
347 builder,
348 field,
349 col,
350 UInt16Builder,
351 u16,
352 row,
353 idx,
354 |v| { Ok::<_, DataFusionError>(v) }
355 );
356 }
357 DataType::UInt32 => {
358 handle_primitive_type!(
359 builder,
360 field,
361 col,
362 UInt32Builder,
363 u32,
364 row,
365 idx,
366 |v| { Ok::<_, DataFusionError>(v) }
367 );
368 }
369 DataType::UInt64 => {
370 handle_primitive_type!(
371 builder,
372 field,
373 col,
374 UInt64Builder,
375 u64,
376 row,
377 idx,
378 |v| { Ok::<_, DataFusionError>(v) }
379 );
380 }
381 DataType::Float32 => {
382 handle_primitive_type!(
383 builder,
384 field,
385 col,
386 Float32Builder,
387 f32,
388 row,
389 idx,
390 |v| { Ok::<_, DataFusionError>(v) }
391 );
392 }
393 DataType::Float64 => {
394 handle_primitive_type!(
395 builder,
396 field,
397 col,
398 Float64Builder,
399 f64,
400 row,
401 idx,
402 |v| { Ok::<_, DataFusionError>(v) }
403 );
404 }
405 DataType::Decimal128(_precision, scale) => {
406 handle_primitive_type!(
407 builder,
408 field,
409 col,
410 Decimal128Builder,
411 BigDecimal,
412 row,
413 idx,
414 |v: BigDecimal| {
415 big_decimal_to_i128(&v, Some(*scale as u32)).ok_or_else(|| {
416 DataFusionError::Execution(format!(
417 "Failed to convert BigDecimal {v:?} to i128"
418 ))
419 })
420 }
421 );
422 }
423 DataType::Decimal256(_precision, _scale) => {
424 handle_primitive_type!(
425 builder,
426 field,
427 col,
428 Decimal256Builder,
429 BigDecimal,
430 row,
431 idx,
432 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
433 );
434 }
435 DataType::Date32 => {
436 handle_primitive_type!(
437 builder,
438 field,
439 col,
440 Date32Builder,
441 chrono::NaiveDate,
442 row,
443 idx,
444 |v: chrono::NaiveDate| {
445 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
446 }
447 );
448 }
449 DataType::Timestamp(TimeUnit::Microsecond, None) => {
450 handle_primitive_type!(
451 builder,
452 field,
453 col,
454 TimestampMicrosecondBuilder,
455 time::PrimitiveDateTime,
456 row,
457 idx,
458 |v: time::PrimitiveDateTime| {
459 let timestamp_micros =
460 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
461 Ok::<_, DataFusionError>(timestamp_micros)
462 }
463 );
464 }
465 DataType::Timestamp(TimeUnit::Nanosecond, None) => {
466 handle_primitive_type!(
467 builder,
468 field,
469 col,
470 TimestampNanosecondBuilder,
471 chrono::NaiveTime,
472 row,
473 idx,
474 |v: chrono::NaiveTime| {
475 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
476 + i64::from(v.nanosecond());
477 Ok::<_, DataFusionError>(t)
478 }
479 );
480 }
481 DataType::Time32(TimeUnit::Second) => {
482 handle_primitive_type!(
483 builder,
484 field,
485 col,
486 Time32SecondBuilder,
487 chrono::NaiveTime,
488 row,
489 idx,
490 |v: chrono::NaiveTime| {
491 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
492 }
493 );
494 }
495 DataType::Time64(TimeUnit::Nanosecond) => {
496 handle_primitive_type!(
497 builder,
498 field,
499 col,
500 Time64NanosecondBuilder,
501 chrono::NaiveTime,
502 row,
503 idx,
504 |v: chrono::NaiveTime| {
505 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
506 + i64::from(v.nanosecond());
507 Ok::<_, DataFusionError>(t)
508 }
509 );
510 }
511 DataType::Utf8 => {
512 handle_primitive_type!(
513 builder,
514 field,
515 col,
516 StringBuilder,
517 String,
518 row,
519 idx,
520 |v| { Ok::<_, DataFusionError>(v) }
521 );
522 }
523 DataType::LargeUtf8 => {
524 handle_primitive_type!(
525 builder,
526 field,
527 col,
528 LargeStringBuilder,
529 String,
530 row,
531 idx,
532 |v| { Ok::<_, DataFusionError>(v) }
533 );
534 }
535 DataType::Binary => {
536 handle_primitive_type!(
537 builder,
538 field,
539 col,
540 BinaryBuilder,
541 Vec<u8>,
542 row,
543 idx,
544 |v| { Ok::<_, DataFusionError>(v) }
545 );
546 }
547 DataType::LargeBinary => {
548 handle_primitive_type!(
549 builder,
550 field,
551 col,
552 LargeBinaryBuilder,
553 Vec<u8>,
554 row,
555 idx,
556 |v| { Ok::<_, DataFusionError>(v) }
557 );
558 }
559 _ => {
560 return Err(DataFusionError::NotImplemented(format!(
561 "Unsupported data type {:?} for col: {:?}",
562 field.data_type(),
563 col
564 )));
565 }
566 }
567 }
568 }
569 let projected_columns = array_builders
570 .into_iter()
571 .enumerate()
572 .filter(|(idx, _)| projections_contains(projection, *idx))
573 .map(|(_, mut builder)| builder.finish())
574 .collect::<Vec<ArrayRef>>();
575 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
576}
577
578fn to_decimal_256(decimal: &BigDecimal) -> i256 {
579 let (bigint_value, _) = decimal.as_bigint_and_exponent();
580 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
581
582 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
583 let fill_byte = if is_negative { 0xFF } else { 0x00 };
584
585 if bigint_bytes.len() > 32 {
586 bigint_bytes.truncate(32);
587 } else {
588 bigint_bytes.resize(32, fill_byte);
589 };
590
591 let mut array = [0u8; 32];
592 array.copy_from_slice(&bigint_bytes);
593
594 i256::from_le_bytes(array)
595}