1use crate::connection::{RemoteDbType, big_decimal_to_i128, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
4 RemoteSchemaRef, RemoteType,
5};
6use async_stream::stream;
7use bigdecimal::{BigDecimal, num_bigint};
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder, Float32Builder,
11 Float64Builder, Int8Builder, Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder,
12 LargeStringBuilder, RecordBatch, StringBuilder, Time32SecondBuilder, Time64NanosecondBuilder,
13 TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt8Builder, UInt16Builder,
14 UInt32Builder, UInt64Builder, make_builder,
15};
16use datafusion::arrow::datatypes::{DataType, Date32Type, SchemaRef, TimeUnit, i256};
17use datafusion::common::{DataFusionError, project_schema};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::prelude::Expr;
21use derive_getters::Getters;
22use derive_with::With;
23use futures::StreamExt;
24use futures::lock::Mutex;
25use mysql_async::consts::{ColumnFlags, ColumnType};
26use mysql_async::prelude::Queryable;
27use mysql_async::{Column, FromValueError, Row, Value};
28use std::sync::Arc;
29
/// Settings used to build a MySQL connection pool (see `connect_mysql`).
///
/// Construct with `MysqlConnectionOptions::new`, which fills in defaults for
/// the optional fields. The `With` / `Getters` derives presumably generate
/// `with_<field>` setters and `<field>()` getters — confirm against those crates.
#[derive(Debug, Clone, With, Getters)]
pub struct MysqlConnectionOptions {
    /// Host name or IP address, passed to `OptsBuilder::ip_or_hostname`.
    pub(crate) host: String,
    /// TCP port, passed to `OptsBuilder::tcp_port`.
    pub(crate) port: u16,
    /// Login user, passed to `OptsBuilder::user`.
    pub(crate) username: String,
    /// Login password, passed to `OptsBuilder::pass`.
    pub(crate) password: String,
    /// Database passed to `OptsBuilder::db_name`; `None` by default.
    pub(crate) database: Option<String>,
    /// Maximum pool size, used as the upper bound of `PoolConstraints`; 10 by default.
    pub(crate) pool_max_size: usize,
    /// 2048 by default. NOTE(review): presumably the number of rows per streamed
    /// chunk read through `ConnectionOptions::stream_chunk_size()` — confirm.
    pub(crate) stream_chunk_size: usize,
}
40
41impl MysqlConnectionOptions {
42 pub fn new(
43 host: impl Into<String>,
44 port: u16,
45 username: impl Into<String>,
46 password: impl Into<String>,
47 ) -> Self {
48 Self {
49 host: host.into(),
50 port,
51 username: username.into(),
52 password: password.into(),
53 database: None,
54 pool_max_size: 10,
55 stream_chunk_size: 2048,
56 }
57 }
58}
59
/// Pool of MySQL connections created by `connect_mysql`.
///
/// Thin wrapper over a `mysql_async::Pool`; `Pool::get` hands out connections from it.
#[derive(Debug)]
pub struct MysqlPool {
    /// Underlying mysql_async pool.
    pool: mysql_async::Pool,
}
64
65pub(crate) fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
66 let pool_opts = mysql_async::PoolOpts::new().with_constraints(
67 mysql_async::PoolConstraints::new(0, options.pool_max_size)
68 .expect("Failed to create pool constraints"),
69 );
70 let opts_builder = mysql_async::OptsBuilder::default()
71 .ip_or_hostname(options.host.clone())
72 .tcp_port(options.port)
73 .user(Some(options.username.clone()))
74 .pass(Some(options.password.clone()))
75 .db_name(options.database.clone())
76 .pool_opts(pool_opts);
77 let pool = mysql_async::Pool::new(opts_builder);
78 Ok(MysqlPool { pool })
79}
80
81#[async_trait::async_trait]
82impl Pool for MysqlPool {
83 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
84 let conn = self.pool.get_conn().await.map_err(|e| {
85 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
86 })?;
87 Ok(Arc::new(MysqlConnection {
88 conn: Arc::new(Mutex::new(conn)),
89 }))
90 }
91}
92
/// A single MySQL connection checked out of a [`MysqlPool`].
///
/// The connection sits behind an `Arc<Mutex<_>>` (futures' async mutex) so
/// that the stream built by `query` can clone the `Arc` and hold the lock
/// while it reads rows.
#[derive(Debug)]
pub struct MysqlConnection {
    /// Shared, lock-protected mysql_async connection.
    conn: Arc<Mutex<mysql_async::Conn>>,
}
97
98#[async_trait::async_trait]
99impl Connection for MysqlConnection {
100 async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
101 let sql = RemoteDbType::Mysql.query_limit_1(sql)?;
102 let mut conn = self.conn.lock().await;
103 let conn = &mut *conn;
104 let row: Option<Row> = conn.query_first(&sql).await.map_err(|e| {
105 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}",))
106 })?;
107 let Some(row) = row else {
108 return Err(DataFusionError::Execution(
109 "No rows returned to infer schema".to_string(),
110 ));
111 };
112 let remote_schema = Arc::new(build_remote_schema(&row)?);
113 Ok(remote_schema)
114 }
115
116 async fn query(
117 &self,
118 conn_options: &ConnectionOptions,
119 sql: &str,
120 table_schema: SchemaRef,
121 projection: Option<&Vec<usize>>,
122 filters: &[Expr],
123 limit: Option<usize>,
124 ) -> DFResult<SendableRecordBatchStream> {
125 let projected_schema = project_schema(&table_schema, projection)?;
126 let sql = RemoteDbType::Mysql.try_rewrite_query(sql, filters, limit)?;
127 let projection = projection.cloned();
128 let chunk_size = conn_options.stream_chunk_size();
129 let conn = Arc::clone(&self.conn);
130 let stream = Box::pin(stream! {
131 let mut conn = conn.lock().await;
132 let mut query_iter = conn
133 .query_iter(sql.clone())
134 .await
135 .map_err(|e| {
136 DataFusionError::Execution(format!("Failed to execute query {sql} on mysql: {e:?}"))
137 })?;
138
139 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
140 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
141 })? else {
142 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
143 return;
144 };
145
146 let mut chunked_stream = stream.chunks(chunk_size).boxed();
147
148 while let Some(chunk) = chunked_stream.next().await {
149 let rows = chunk
150 .into_iter()
151 .collect::<Result<Vec<_>, _>>()
152 .map_err(|e| {
153 DataFusionError::Execution(format!(
154 "Failed to collect rows from mysql due to {e}",
155 ))
156 })?;
157
158 yield Ok::<_, DataFusionError>(rows)
159 }
160 });
161
162 let stream = stream.map(move |rows| {
163 let rows = rows?;
164 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
165 });
166
167 Ok(Box::pin(RecordBatchStreamAdapter::new(
168 projected_schema,
169 stream,
170 )))
171 }
172}
173
174fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<MysqlType> {
175 let character_set = mysql_col.character_set();
176 let is_utf8_bin_character_set = character_set == 45;
177 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
178 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
179 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
180 let col_length = mysql_col.column_length();
181 match mysql_col.column_type() {
182 ColumnType::MYSQL_TYPE_TINY => {
183 if is_unsigned {
184 Ok(MysqlType::TinyIntUnsigned)
185 } else {
186 Ok(MysqlType::TinyInt)
187 }
188 }
189 ColumnType::MYSQL_TYPE_SHORT => {
190 if is_unsigned {
191 Ok(MysqlType::SmallIntUnsigned)
192 } else {
193 Ok(MysqlType::SmallInt)
194 }
195 }
196 ColumnType::MYSQL_TYPE_INT24 => {
197 if is_unsigned {
198 Ok(MysqlType::MediumIntUnsigned)
199 } else {
200 Ok(MysqlType::MediumInt)
201 }
202 }
203 ColumnType::MYSQL_TYPE_LONG => {
204 if is_unsigned {
205 Ok(MysqlType::IntegerUnsigned)
206 } else {
207 Ok(MysqlType::Integer)
208 }
209 }
210 ColumnType::MYSQL_TYPE_LONGLONG => {
211 if is_unsigned {
212 Ok(MysqlType::BigIntUnsigned)
213 } else {
214 Ok(MysqlType::BigInt)
215 }
216 }
217 ColumnType::MYSQL_TYPE_FLOAT => Ok(MysqlType::Float),
218 ColumnType::MYSQL_TYPE_DOUBLE => Ok(MysqlType::Double),
219 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
220 let precision = (mysql_col.column_length() - 2) as u8;
221 let scale = mysql_col.decimals();
222 Ok(MysqlType::Decimal(precision, scale))
223 }
224 ColumnType::MYSQL_TYPE_DATE => Ok(MysqlType::Date),
225 ColumnType::MYSQL_TYPE_DATETIME => Ok(MysqlType::Datetime),
226 ColumnType::MYSQL_TYPE_TIME => Ok(MysqlType::Time),
227 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(MysqlType::Timestamp),
228 ColumnType::MYSQL_TYPE_YEAR => Ok(MysqlType::Year),
229 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(MysqlType::Char),
230 ColumnType::MYSQL_TYPE_STRING if is_binary => {
231 if is_utf8_bin_character_set {
232 Ok(MysqlType::Char)
233 } else {
234 Ok(MysqlType::Binary)
235 }
236 }
237 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => Ok(MysqlType::Varchar),
238 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
239 if is_utf8_bin_character_set {
240 Ok(MysqlType::Varchar)
241 } else {
242 Ok(MysqlType::Varbinary)
243 }
244 }
245 ColumnType::MYSQL_TYPE_VARCHAR => Ok(MysqlType::Varchar),
246 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => Ok(MysqlType::Text(col_length)),
247 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
248 if is_utf8_bin_character_set {
249 Ok(MysqlType::Text(col_length))
250 } else {
251 Ok(MysqlType::Blob(col_length))
252 }
253 }
254 ColumnType::MYSQL_TYPE_JSON => Ok(MysqlType::Json),
255 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(MysqlType::Geometry),
256 _ => Err(DataFusionError::NotImplemented(format!(
257 "Unsupported mysql type: {mysql_col:?}",
258 ))),
259 }
260}
261
262fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
263 let mut remote_fields = vec![];
264 for col in row.columns_ref() {
265 remote_fields.push(RemoteField::new(
266 col.name_str().to_string(),
267 RemoteType::Mysql(mysql_type_to_remote_type(col)?),
268 true,
269 ));
270 }
271 Ok(RemoteSchema::new(remote_fields))
272}
273
/// Reads column `$index` of `$row` as `$value_ty` and appends it to `$builder`.
///
/// - `$builder` is a type-erased Arrow builder (anything with `as_any_mut()`),
///   downcast to `$builder_ty`.
/// - `$field` and `$col` are only used in panic and error messages.
/// - `$convert` maps the fetched `$value_ty` to the builder's value type and
///   returns a `Result`; its error is propagated with `?`.
///
/// A `None` from `Row::get_opt` and a SQL `NULL` are both appended as nulls.
/// Any other fetch error makes the *calling function* return
/// `Err(DataFusionError::Execution(..))`, and a failed downcast panics. Use
/// the macro only inside a function that returns a compatible `Result`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        // A builder whose concrete type does not match the field's data type
        // is a programming error, hence the panic rather than an Err.
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v = $row.get_opt::<$value_ty, usize>($index);

        match v {
            // `get_opt` returned no value for this index.
            None => builder.append_null(),
            Some(Ok(v)) => builder.append_value($convert(v)?),
            // SQL NULL surfaces as a conversion error carrying `Value::NULL`;
            // store it as an Arrow null instead of failing.
            Some(Err(FromValueError(Value::NULL))) => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to get optional {:?} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col,
                )));
            }
        }
    }};
}
304
305fn rows_to_batch(
306 rows: &[Row],
307 table_schema: &SchemaRef,
308 projection: Option<&Vec<usize>>,
309) -> DFResult<RecordBatch> {
310 let projected_schema = project_schema(table_schema, projection)?;
311 let mut array_builders = vec![];
312 for field in table_schema.fields() {
313 let builder = make_builder(field.data_type(), rows.len());
314 array_builders.push(builder);
315 }
316
317 for row in rows {
318 for (idx, field) in table_schema.fields.iter().enumerate() {
319 if !projections_contains(projection, idx) {
320 continue;
321 }
322 let builder = &mut array_builders[idx];
323 let col = row.columns_ref().get(idx);
324 match field.data_type() {
325 DataType::Int8 => {
326 handle_primitive_type!(builder, field, col, Int8Builder, i8, row, idx, |v| {
327 Ok::<_, DataFusionError>(v)
328 });
329 }
330 DataType::Int16 => {
331 handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, |v| {
332 Ok::<_, DataFusionError>(v)
333 });
334 }
335 DataType::Int32 => {
336 handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx, |v| {
337 Ok::<_, DataFusionError>(v)
338 });
339 }
340 DataType::Int64 => {
341 handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx, |v| {
342 Ok::<_, DataFusionError>(v)
343 });
344 }
345 DataType::UInt8 => {
346 handle_primitive_type!(builder, field, col, UInt8Builder, u8, row, idx, |v| {
347 Ok::<_, DataFusionError>(v)
348 });
349 }
350 DataType::UInt16 => {
351 handle_primitive_type!(
352 builder,
353 field,
354 col,
355 UInt16Builder,
356 u16,
357 row,
358 idx,
359 |v| { Ok::<_, DataFusionError>(v) }
360 );
361 }
362 DataType::UInt32 => {
363 handle_primitive_type!(
364 builder,
365 field,
366 col,
367 UInt32Builder,
368 u32,
369 row,
370 idx,
371 |v| { Ok::<_, DataFusionError>(v) }
372 );
373 }
374 DataType::UInt64 => {
375 handle_primitive_type!(
376 builder,
377 field,
378 col,
379 UInt64Builder,
380 u64,
381 row,
382 idx,
383 |v| { Ok::<_, DataFusionError>(v) }
384 );
385 }
386 DataType::Float32 => {
387 handle_primitive_type!(
388 builder,
389 field,
390 col,
391 Float32Builder,
392 f32,
393 row,
394 idx,
395 |v| { Ok::<_, DataFusionError>(v) }
396 );
397 }
398 DataType::Float64 => {
399 handle_primitive_type!(
400 builder,
401 field,
402 col,
403 Float64Builder,
404 f64,
405 row,
406 idx,
407 |v| { Ok::<_, DataFusionError>(v) }
408 );
409 }
410 DataType::Decimal128(_precision, scale) => {
411 handle_primitive_type!(
412 builder,
413 field,
414 col,
415 Decimal128Builder,
416 BigDecimal,
417 row,
418 idx,
419 |v: BigDecimal| {
420 big_decimal_to_i128(&v, Some(*scale as i32)).ok_or_else(|| {
421 DataFusionError::Execution(format!(
422 "Failed to convert BigDecimal {v:?} to i128"
423 ))
424 })
425 }
426 );
427 }
428 DataType::Decimal256(_precision, _scale) => {
429 handle_primitive_type!(
430 builder,
431 field,
432 col,
433 Decimal256Builder,
434 BigDecimal,
435 row,
436 idx,
437 |v: BigDecimal| { Ok::<_, DataFusionError>(to_decimal_256(&v)) }
438 );
439 }
440 DataType::Date32 => {
441 handle_primitive_type!(
442 builder,
443 field,
444 col,
445 Date32Builder,
446 chrono::NaiveDate,
447 row,
448 idx,
449 |v: chrono::NaiveDate| {
450 Ok::<_, DataFusionError>(Date32Type::from_naive_date(v))
451 }
452 );
453 }
454 DataType::Timestamp(TimeUnit::Microsecond, None) => {
455 handle_primitive_type!(
456 builder,
457 field,
458 col,
459 TimestampMicrosecondBuilder,
460 time::PrimitiveDateTime,
461 row,
462 idx,
463 |v: time::PrimitiveDateTime| {
464 let timestamp_micros =
465 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
466 Ok::<_, DataFusionError>(timestamp_micros)
467 }
468 );
469 }
470 DataType::Timestamp(TimeUnit::Nanosecond, None) => {
471 handle_primitive_type!(
472 builder,
473 field,
474 col,
475 TimestampNanosecondBuilder,
476 chrono::NaiveTime,
477 row,
478 idx,
479 |v: chrono::NaiveTime| {
480 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
481 + i64::from(v.nanosecond());
482 Ok::<_, DataFusionError>(t)
483 }
484 );
485 }
486 DataType::Time32(TimeUnit::Second) => {
487 handle_primitive_type!(
488 builder,
489 field,
490 col,
491 Time32SecondBuilder,
492 chrono::NaiveTime,
493 row,
494 idx,
495 |v: chrono::NaiveTime| {
496 Ok::<_, DataFusionError>(v.num_seconds_from_midnight() as i32)
497 }
498 );
499 }
500 DataType::Time64(TimeUnit::Nanosecond) => {
501 handle_primitive_type!(
502 builder,
503 field,
504 col,
505 Time64NanosecondBuilder,
506 chrono::NaiveTime,
507 row,
508 idx,
509 |v: chrono::NaiveTime| {
510 let t = i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
511 + i64::from(v.nanosecond());
512 Ok::<_, DataFusionError>(t)
513 }
514 );
515 }
516 DataType::Utf8 => {
517 handle_primitive_type!(
518 builder,
519 field,
520 col,
521 StringBuilder,
522 String,
523 row,
524 idx,
525 |v| { Ok::<_, DataFusionError>(v) }
526 );
527 }
528 DataType::LargeUtf8 => {
529 handle_primitive_type!(
530 builder,
531 field,
532 col,
533 LargeStringBuilder,
534 String,
535 row,
536 idx,
537 |v| { Ok::<_, DataFusionError>(v) }
538 );
539 }
540 DataType::Binary => {
541 handle_primitive_type!(
542 builder,
543 field,
544 col,
545 BinaryBuilder,
546 Vec<u8>,
547 row,
548 idx,
549 |v| { Ok::<_, DataFusionError>(v) }
550 );
551 }
552 DataType::LargeBinary => {
553 handle_primitive_type!(
554 builder,
555 field,
556 col,
557 LargeBinaryBuilder,
558 Vec<u8>,
559 row,
560 idx,
561 |v| { Ok::<_, DataFusionError>(v) }
562 );
563 }
564 _ => {
565 return Err(DataFusionError::NotImplemented(format!(
566 "Unsupported data type {:?} for col: {:?}",
567 field.data_type(),
568 col
569 )));
570 }
571 }
572 }
573 }
574 let projected_columns = array_builders
575 .into_iter()
576 .enumerate()
577 .filter(|(idx, _)| projections_contains(projection, *idx))
578 .map(|(_, mut builder)| builder.finish())
579 .collect::<Vec<ArrayRef>>();
580 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
581}
582
583fn to_decimal_256(decimal: &BigDecimal) -> i256 {
584 let (bigint_value, _) = decimal.as_bigint_and_exponent();
585 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
586
587 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
588 let fill_byte = if is_negative { 0xFF } else { 0x00 };
589
590 if bigint_bytes.len() > 32 {
591 bigint_bytes.truncate(32);
592 } else {
593 bigint_bytes.resize(32, fill_byte);
594 };
595
596 let mut array = [0u8; 32];
597 array.copy_from_slice(&bigint_bytes);
598
599 i256::from_le_bytes(array)
600}