1use crate::connection::{RemoteDbType, big_decimal_to_i128, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12 LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
13 TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt8Builder, UInt16Builder,
14 UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::prelude::Expr;
21use derive_getters::Getters;
22use derive_with::With;
23use futures::StreamExt;
24use futures::lock::Mutex;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::sync::Arc;
29
/// Settings used to reach a MySQL server and to size its connection pool.
///
/// NOTE(review): the derived `Debug` implementation prints every field,
/// including `password` — confirm this value is never logged verbatim.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// IP address or hostname of the MySQL server.
    pub(crate) host: String,
    /// TCP port of the MySQL server.
    pub(crate) port: u16,
    /// User name used to authenticate.
    pub(crate) username: String,
    /// Password used to authenticate.
    pub(crate) password: String,
    /// Optional default database; `None` connects without selecting one.
    pub(crate) database: Option<String>,
    /// Upper bound of the connection pool (the lower bound is fixed at 0 in `connect_mysql`).
    pub(crate) pool_max_size: usize,
    /// Number of rows per streamed chunk — presumably surfaced to `query` through
    /// `ConnectionOptions::stream_chunk_size`; verify against that type.
    pub(crate) stream_chunk_size: usize,
}
40
41impl MysqlConnectionOptions {
42 pub fn new(
43 host: impl Into<String>,
44 port: u16,
45 username: impl Into<String>,
46 password: impl Into<String>,
47 ) -> Self {
48 Self {
49 host: host.into(),
50 port,
51 username: username.into(),
52 password: password.into(),
53 database: None,
54 pool_max_size: 10,
55 stream_chunk_size: 2048,
56 }
57 }
58}
59
/// Connection pool for a MySQL server, wrapping [`mysql_async::Pool`].
#[derive(Debug)]
pub struct MysqlPool {
    /// Underlying `mysql_async` pool from which connections are checked out.
    pool: mysql_async::Pool,
}
64
65pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
66 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
67 mysql_async::PoolConstraints::new(0, options.pool_max_size)
68 .expect("Failed to create pool constraints"),
69 );
70 let opts_builder = mysql_async::OptsBuilder::default()
71 .ip_or_hostname(options.host.clone())
72 .tcp_port(options.port)
73 .user(Some(options.username.clone()))
74 .pass(Some(options.password.clone()))
75 .db_name(options.database.clone())
76 .pool_opts(pool_opts);
77 let pool = mysql_async::Pool::new(opts_builder);
78 Ok(MysqlPool { pool })
79}
80
81#[async_trait::async_trait]
82impl Pool for MysqlPool {
83 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
84 let conn = self.pool.get_conn().await.map_err(|e| {
85 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
86 })?;
87 Ok(Arc::new(MysqlConnection {
88 conn: Arc::new(Mutex::new(conn)),
89 }))
90 }
91}
92
/// A single MySQL connection checked out of a [`MysqlPool`].
#[derive(Debug)]
pub struct MysqlConnection {
    /// The underlying connection; the async mutex lets `query` move a handle into the
    /// returned stream and hold the lock while rows are being read.
    conn: Arc<Mutex<mysql_async::Conn>>,
}
97
98#[async_trait::async_trait]
99impl Connection for MysqlConnection {
100 async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
101 let sql = RemoteDbType::Mysql
102 .try_rewrite_query(sql, &[], Some(1))
103 .unwrap_or_else(|| sql.to_string());
104 let mut conn = self.conn.lock().await;
105 let conn = &mut *conn;
106 let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
107 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
108 })?;
109 let Some(row) = row else {
110 return Err(DataFusionError::Execution(
111 "No rows returned to infer schema".to_string(),
112 ));
113 };
114 let remote_schema = Arc::new(build_remote_schema(&row)?);
115 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
116 Ok((remote_schema, arrow_schema))
117 }
118
119 async fn query(
120 &self,
121 conn_options: &ConnectionOptions,
122 sql: &str,
123 table_schema: SchemaRef,
124 projection: Option<&Vec<usize>>,
125 filters: &[Expr],
126 limit: Option<usize>,
127 ) -> DFResult<SendableRecordBatchStream> {
128 let projected_schema = project_schema(&table_schema, projection)?;
129 let sql = RemoteDbType::Mysql
130 .try_rewrite_query(sql, filters, limit)
131 .unwrap_or_else(|| sql.to_string());
132 let projection = projection.cloned();
133 let chunk_size = conn_options.stream_chunk_size();
134 let conn = Arc::clone(&self.conn);
135 let stream = Box::pin(stream! {
136 let mut conn = conn.lock().await;
137 let mut query_iter = conn
138 .query_iter(sql.clone())
139 .await
140 .map_err(|e| {
141 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
142 })?;
143
144 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
145 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
146 })? else {
147 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
148 return;
149 };
150
151 let mut chunked_stream = stream.chunks(chunk_size).boxed();
152
153 while let Some(chunk) = chunked_stream.next().await {
154 let rows = chunk
155 .into_iter()
156 .collect::<Result<Vec<_>, _>>()
157 .map_err(|e| {
158 DataFusionError::Execution(format!(
159 "Failed to collect rows from mysql due to {e}",
160 ))
161 })?;
162
163 yield Ok::<_, DataFusionError>(rows)
164 }
165 });
166
167 let stream = stream.map(move |rows| {
168 let rows = rows?;
169 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
170 });
171
172 Ok(Box::pin(RecordBatchStreamAdapter::new(
173 projected_schema,
174 stream,
175 )))
176 }
177}
178
179fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
180 let character_set = mysql_col.character_set();
181 let is_utf8_bin_character_set = character_set == 45;
182 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
183 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
184 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
185 let col_length = mysql_col.column_length();
186 match mysql_col.column_type() {
187 ColumnType::MYSQL_TYPE_TINY => {
188 if is_unsigned {
189 Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
190 } else {
191 Ok(RemoteType::Mysql(MysqlType::TinyInt))
192 }
193 }
194 ColumnType::MYSQL_TYPE_SHORT => {
195 if is_unsigned {
196 Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
197 } else {
198 Ok(RemoteType::Mysql(MysqlType::SmallInt))
199 }
200 }
201 ColumnType::MYSQL_TYPE_INT24 => {
202 if is_unsigned {
203 Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
204 } else {
205 Ok(RemoteType::Mysql(MysqlType::MediumInt))
206 }
207 }
208 ColumnType::MYSQL_TYPE_LONG => {
209 if is_unsigned {
210 Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
211 } else {
212 Ok(RemoteType::Mysql(MysqlType::Integer))
213 }
214 }
215 ColumnType::MYSQL_TYPE_LONGLONG => {
216 if is_unsigned {
217 Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
218 } else {
219 Ok(RemoteType::Mysql(MysqlType::BigInt))
220 }
221 }
222 ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
223 ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
224 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
225 let precision = (mysql_col.column_length() - 2) as u8;
226 let scale = mysql_col.decimals();
227 Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
228 }
229 ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
230 ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
231 ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
232 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
233 ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
234 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
235 ColumnType::MYSQL_TYPE_STRING if is_binary => {
236 if is_utf8_bin_character_set {
237 Ok(RemoteType::Mysql(MysqlType::Char))
238 } else {
239 Ok(RemoteType::Mysql(MysqlType::Binary))
240 }
241 }
242 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
243 Ok(RemoteType::Mysql(MysqlType::Varchar))
244 }
245 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
246 if is_utf8_bin_character_set {
247 Ok(RemoteType::Mysql(MysqlType::Varchar))
248 } else {
249 Ok(RemoteType::Mysql(MysqlType::Varbinary))
250 }
251 }
252 ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
253 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
254 Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
255 }
256 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
257 if is_utf8_bin_character_set {
258 Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
259 } else {
260 Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
261 }
262 }
263 ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
264 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
265 _ => Err(DataFusionError::NotImplemented(format!(
266 "Unsupported mysql type: {mysql_col:?}",
267 ))),
268 }
269}
270
271fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
272 let mut remote_fields = vec![];
273 for col in row.columns_ref() {
274 remote_fields.push(RemoteField::new(
275 col.name_str().to_string(),
276 mysql_type_to_remote_type(col)?,
277 true,
278 ));
279 }
280 Ok(RemoteSchema::new(remote_fields))
281}
282
/// Appends the value at `$index` of `$row` to `$builder`.
///
/// The builder is downcast to `$builder_ty`, the cell is decoded as `$value_ty` and then
/// passed through `$convert`, which must return a `Result`. A missing cell or a SQL
/// `NULL` appends a null. Conversion errors are propagated with `?` and decode failures
/// `return` a `DataFusionError::Execution`, so the macro may only be used inside a
/// function returning a compatible `Result`.
///
/// Panics if `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        match $row.get_opt::<$value_ty, usize>($index) {
            // Both an absent cell and an explicit SQL NULL become an Arrow null.
            None | Some(Err(FromValueError(Value::NULL))) => typed_builder.append_null(),
            Some(Ok(raw)) => typed_builder.append_value($convert(raw)?),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
313
314fn rows_to_batch(
315 rows: &[Row],
316 table_schema: &SchemaRef,
317 projection: Option<&Vec<usize>>,
318) -> DFResult<RecordBatch> {
319 let projected_schema = project_schema(table_schema, projection)?;
320 let mut array_builders = vec![];
321 for field in table_schema.fields() {
322 let builder = make_builder(field.data_type(), rows.len());
323 array_builders.push(builder);
324 }
325
326 for row in rows {
327 for (idx, field) in table_schema.fields.iter().enumerate() {
328 if !projections_contains(projection, idx) {
329 continue;
330 }
331 let builder = &mut array_builders[idx];
332 let col = row.columns_ref().get(idx);
333 match field.data_type() {
334 DataType::Int8 => {
335 handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
336 Ok::<_, DataFusionError>(v)
337 });
338 }
339 DataType::Int16 => {
340 handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
341 Ok::<_, DataFusionError>(v)
342 });
343 }
344 DataType::Int32 => {
345 handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
346 Ok::<_, DataFusionError>(v)
347 });
348 }
349 DataType::Int64 => {
350 handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
351 Ok::<_, DataFusionError>(v)
352 });
353 }
354 DataType::UInt8 => {
355 handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
356 Ok::<_, DataFusionError>(v)
357 });
358 }
359 DataType::UInt16 => {
360 handle_primitive_type!(
361 builder,
362 field,
363 col,
364 UInt16Builder,
365 u16,
366 row,
367 idx,
368 |v| { Ok::<_, DataFusionError>(v) }
369 );
370 }
371 DataType::UInt32 => {
372 handle_primitive_type!(
373 builder,
374 field,
375 col,
376 UInt32Builder,
377 u32,
378 row,
379 idx,
380 |v| { Ok::<_, DataFusionError>(v) }
381 );
382 }
383 DataType::UInt64 => {
384 handle_primitive_type!(
385 builder,
386 field,
387 col,
388 UInt64Builder,
389 u64,
390 row,
391 idx,
392 |v| { Ok::<_, DataFusionError>(v) }
393 );
394 }
395 DataType::Float32 => {
396 handle_primitive_type!(
397 builder,
398 field,
399 col,
400 Float32Builder,
401 f32,
402 row,
403 idx,
404 |v| { Ok::<_, DataFusionError>(v) }
405 );
406 }
407 DataType::Float64 => {
408 handle_primitive_type!(
409 builder,
410 field,
411 col,
412 Float64Builder,
413 f64,
414 row,
415 idx,
416 |v| { Ok::<_, DataFusionError>(v) }
417 );
418 }
419 DataType::Decimal128(_precision, scale) => {
420 handle_primitive_type!(
421 builder,
422 field,
423 col,
424 Decimal128Builder,
425 BigDecimal,
426 row,
427 idx,
428 |v: BigDecimal| {
429 big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
430 DataFusionError::Execution(format!(
431 "Failed to convert BigDecimal {v:?} to i128"
432 ))
433 })
434 }
435 );
436 }
437 DataType::Decimal256(_precision, _scale) => {
438 handle_primitive_type!(
439 builder,
440 field,
441 col,
442 Decimal256Builder,
443 BigDecimal,
444 row,
445 idx,
446 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
447 );
448 }
449 DataType::Date32 => {
450 handle_primitive_type!(
451 builder,
452 field,
453 col,
454 Date32Builder,
455 chrono::NaiveDate,
456 row,
457 idx,
458 |v: chrono::NaiveDate| {
459 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
460 }
461 );
462 }
463 DataType::Timestamp(TimeUnit::Microsecond, None) => {
464 handle_primitive_type!(
465 builder,
466 field,
467 col,
468 TimestampMicrosecondBuilder,
469 time::PrimitiveDateTime,
470 row,
471 idx,
472 |v: time::PrimitiveDateTime| {
473 let timestamp_micros =
474 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
475 Ok::<_, DataFusionError>(timestamp_micros)
476 }
477 );
478 }
479 DataType::Timestamp(TimeUnit::Nanosecond, None) => {
480 handle_primitive_type!(
481 builder,
482 field,
483 col,
484 TimestampNanosecondBuilder,
485 chrono::NaiveTime,
486 row,
487 idx,
488 |v: chrono::NaiveTime| {
489 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
490 + i64::from(v.nanosecond());
491 Ok::<_, DataFusionError>(t)
492 }
493 );
494 }
495 DataType::Time32(TimeUnit::Second) => {
496 handle_primitive_type!(
497 builder,
498 field,
499 col,
500 Time32SecondBuilder,
501 chrono::NaiveTime,
502 row,
503 idx,
504 |v: chrono::NaiveTime| {
505 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
506 }
507 );
508 }
509 DataType::Time64(TimeUnit::Nanosecond) => {
510 handle_primitive_type!(
511 builder,
512 field,
513 col,
514 Time64NanosecondBuilder,
515 chrono::NaiveTime,
516 row,
517 idx,
518 |v: chrono::NaiveTime| {
519 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
520 + i64::from(v.nanosecond());
521 Ok::<_, DataFusionError>(t)
522 }
523 );
524 }
525 DataType::Utf8 => {
526 handle_primitive_type!(
527 builder,
528 field,
529 col,
530 StringBuilder,
531 String,
532 row,
533 idx,
534 |v| { Ok::<_, DataFusionError>(v) }
535 );
536 }
537 DataType::LargeUtf8 => {
538 handle_primitive_type!(
539 builder,
540 field,
541 col,
542 LargeStringBuilder,
543 String,
544 row,
545 idx,
546 |v| { Ok::<_, DataFusionError>(v) }
547 );
548 }
549 DataType::Binary => {
550 handle_primitive_type!(
551 builder,
552 field,
553 col,
554 BinaryBuilder,
555 Vec<u8>,
556 row,
557 idx,
558 |v| { Ok::<_, DataFusionError>(v) }
559 );
560 }
561 DataType::LargeBinary => {
562 handle_primitive_type!(
563 builder,
564 field,
565 col,
566 LargeBinaryBuilder,
567 Vec<u8>,
568 row,
569 idx,
570 |v| { Ok::<_, DataFusionError>(v) }
571 );
572 }
573 _ => {
574 return Err(DataFusionError::NotImplemented(format!(
575 "Unsupported data type {:?} for col: {:?}",
576 field.data_type(),
577 col
578 )));
579 }
580 }
581 }
582 }
583 let projected_columns = array_builders
584 .into_iter()
585 .enumerate()
586 .filter(|(idx, _)| projections_contains(projection, *idx))
587 .map(|(_, mut builder)| builder.finish())
588 .collect::<Vec<ArrayRef>>();
589 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
590}
591
592fn to_decimal_256(decimal: &BigDecimal) -> i256 {
593 let (bigint_value, _) = decimal.as_bigint_and_exponent();
594 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
595
596 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
597 let fill_byte = if is_negative { 0xFF } else { 0x00 };
598
599 if bigint_bytes.len() > 32 {
600 bigint_bytes.truncate(32);
601 } else {
602 bigint_bytes.resize(32, fill_byte);
603 };
604
605 let mut array = [0u8; 32];
606 array.copy_from_slice(&bigint_bytes);
607
608 i256::from_le_bytes(array)
609}