1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4 Connection, DFResult, MysqlType, Pool, RemoteField, RemoteSchema, RemoteSchemaRef, RemoteType,
5 Transform,
6};
7use async_stream::stream;
8use bigdecimal::num_bigint;
9use chrono::Timelike;
10use datafusion::arrow::array::{
11 make_builder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder,
12 Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
13 LargeBinaryBuilder, LargeStringBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder,
14 TimestampMicrosecondBuilder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
15};
16use datafusion::arrow::datatypes::{i256, DataType, Date32Type, SchemaRef, TimeUnit};
17use datafusion::common::{project_schema, DataFusionError};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use futures::lock::Mutex;
21use futures::StreamExt;
22use mysql_async::consts::{ColumnFlags, ColumnType};
23use mysql_async::prelude::Queryable;
24use mysql_async::{Column, Row};
25use std::sync::Arc;
26
27#[derive(Debug, Clone, derive_with::With)]
28pub struct MysqlConnectionOptions {
29 pub(crate) host: String,
30 pub(crate) port: u16,
31 pub(crate) username: String,
32 pub(crate) password: String,
33 pub(crate) database: Option<String>,
34}
35
36impl MysqlConnectionOptions {
37 pub fn new(
38 host: impl Into<String>,
39 port: u16,
40 username: impl Into<String>,
41 password: impl Into<String>,
42 ) -> Self {
43 Self {
44 host: host.into(),
45 port,
46 username: username.into(),
47 password: password.into(),
48 database: None,
49 }
50 }
51}
52
53#[derive(Debug)]
54pub struct MysqlPool {
55 pool: mysql_async::Pool,
56}
57
58pub fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
59 let opts_builder = mysql_async::OptsBuilder::default()
60 .ip_or_hostname(options.host.clone())
61 .tcp_port(options.port)
62 .user(Some(options.username.clone()))
63 .pass(Some(options.password.clone()))
64 .db_name(options.database.clone());
65 let pool = mysql_async::Pool::new(opts_builder);
66 Ok(MysqlPool { pool })
67}
68
69#[async_trait::async_trait]
70impl Pool for MysqlPool {
71 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
72 let conn = self.pool.get_conn().await.map_err(|e| {
73 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
74 })?;
75 Ok(Arc::new(MysqlConnection {
76 conn: Arc::new(Mutex::new(conn)),
77 }))
78 }
79}
80
81#[derive(Debug)]
82pub struct MysqlConnection {
83 conn: Arc<Mutex<mysql_async::Conn>>,
84}
85
86#[async_trait::async_trait]
87impl Connection for MysqlConnection {
88 async fn infer_schema(
89 &self,
90 sql: &str,
91 transform: Option<Arc<dyn Transform>>,
92 ) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
93 let mut conn = self.conn.lock().await;
94 let conn = &mut *conn;
95 let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
96 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
97 })?;
98 let Some(row) = row else {
99 return Err(DataFusionError::Execution(
100 "No rows returned to infer schema".to_string(),
101 ));
102 };
103 let remote_schema = Arc::new(build_remote_schema(&row)?);
104 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
105 if let Some(transform) = transform {
106 let batch = rows_to_batch(&[row], &arrow_schema, None)?;
107 let transformed_batch = transform_batch(
108 batch,
109 transform.as_ref(),
110 &arrow_schema,
111 None,
112 Some(&remote_schema),
113 )?;
114 Ok((remote_schema, transformed_batch.schema()))
115 } else {
116 Ok((remote_schema, arrow_schema))
117 }
118 }
119
120 async fn query(
121 &self,
122 sql: String,
123 table_schema: SchemaRef,
124 projection: Option<Vec<usize>>,
125 ) -> DFResult<SendableRecordBatchStream> {
126 let projected_schema = project_schema(&table_schema, projection.as_ref())?;
127 let conn = Arc::clone(&self.conn);
128 let stream = Box::pin(stream! {
129 let mut conn = conn.lock().await;
130 let mut query_iter = conn
131 .query_iter(sql)
132 .await
133 .map_err(|e| {
134 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
135 })?;
136
137 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
138 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
139 })? else {
140 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
141 return;
142 };
143
144 let mut chunked_stream = stream.chunks(4_000).boxed();
145
146 while let Some(chunk) = chunked_stream.next().await {
147 let rows = chunk
148 .into_iter()
149 .collect::<Result<Vec<_>, _>>()
150 .map_err(|e| {
151 DataFusionError::Execution(format!(
152 "Failed to collect rows from mysql due to {e}",
153 ))
154 })?;
155
156 yield Ok::<_, DataFusionError>(rows)
157 }
158 });
159
160 let stream = stream.map(move |rows| {
161 let rows = rows?;
162 rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
163 });
164
165 Ok(Box::pin(RecordBatchStreamAdapter::new(
166 projected_schema,
167 stream,
168 )))
169 }
170}
171
172fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
173 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
174 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
175 let is_unsigned = mysql_col.flags().contains(ColumnFlags::UNSIGNED_FLAG);
176 let col_length = mysql_col.column_length();
177 match mysql_col.column_type() {
178 ColumnType::MYSQL_TYPE_TINY => {
179 if is_unsigned {
180 Ok(RemoteType::Mysql(MysqlType::TinyIntUnsigned))
181 } else {
182 Ok(RemoteType::Mysql(MysqlType::TinyInt))
183 }
184 }
185 ColumnType::MYSQL_TYPE_SHORT => {
186 if is_unsigned {
187 Ok(RemoteType::Mysql(MysqlType::SmallIntUnsigned))
188 } else {
189 Ok(RemoteType::Mysql(MysqlType::SmallInt))
190 }
191 }
192 ColumnType::MYSQL_TYPE_INT24 => {
193 if is_unsigned {
194 Ok(RemoteType::Mysql(MysqlType::MediumIntUnsigned))
195 } else {
196 Ok(RemoteType::Mysql(MysqlType::MediumInt))
197 }
198 }
199 ColumnType::MYSQL_TYPE_LONG => {
200 if is_unsigned {
201 Ok(RemoteType::Mysql(MysqlType::IntegerUnsigned))
202 } else {
203 Ok(RemoteType::Mysql(MysqlType::Integer))
204 }
205 }
206 ColumnType::MYSQL_TYPE_LONGLONG => {
207 if is_unsigned {
208 Ok(RemoteType::Mysql(MysqlType::BigIntUnsigned))
209 } else {
210 Ok(RemoteType::Mysql(MysqlType::BigInt))
211 }
212 }
213 ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
214 ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
215 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
216 let precision = (mysql_col.column_length() - 2) as u8;
217 let scale = mysql_col.decimals();
218 Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
219 }
220 ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
221 ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
222 ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
223 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
224 ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
225 ColumnType::MYSQL_TYPE_STRING if !is_binary => Ok(RemoteType::Mysql(MysqlType::Char)),
226 ColumnType::MYSQL_TYPE_STRING if is_binary => Ok(RemoteType::Mysql(MysqlType::Binary)),
227 ColumnType::MYSQL_TYPE_VAR_STRING if !is_binary => {
228 Ok(RemoteType::Mysql(MysqlType::Varchar))
229 }
230 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
231 Ok(RemoteType::Mysql(MysqlType::Varbinary))
232 }
233 ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
234 ColumnType::MYSQL_TYPE_BLOB if is_blob && !is_binary => {
235 Ok(RemoteType::Mysql(MysqlType::Text(col_length)))
236 }
237 ColumnType::MYSQL_TYPE_BLOB if is_blob && is_binary => {
238 Ok(RemoteType::Mysql(MysqlType::Blob(col_length)))
239 }
240 ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
241 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
242 _ => Err(DataFusionError::NotImplemented(format!(
243 "Unsupported mysql type: {mysql_col:?}",
244 ))),
245 }
246}
247
248fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
249 let mut remote_fields = vec![];
250 for col in row.columns_ref() {
251 remote_fields.push(RemoteField::new(
252 col.name_str().to_string(),
253 mysql_type_to_remote_type(col)?,
254 true,
255 ));
256 }
257 Ok(RemoteSchema::new(remote_fields))
258}
259
/// Appends cell `$index` of `$row` to `$builder`, which must downcast to `$builder_ty`.
///
/// The cell is read as `Option<$value_ty>`. SQL NULL, or a cell that is missing from the
/// row, is appended as null. A value that cannot be converted to `$value_ty` makes the
/// *enclosing function* return `DataFusionError::Execution`, instead of the panic that
/// `Row::get` raises. The enclosing function must therefore return a `DFResult`.
///
/// `$mysql_col` is only used in messages. A failed downcast still panics, because it
/// means the builder was not created for this field's data type (an internal bug).
macro_rules! handle_primitive_type {
    ($builder:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    concat!(
                        "Failed to downcast builder to ",
                        stringify!($builder_ty),
                        " for {:?}"
                    ),
                    $mysql_col
                )
            });

        match $row.get_opt::<Option<$value_ty>, usize>($index) {
            Some(Ok(Some(v))) => builder.append_value(v),
            Some(Ok(None)) | None => builder.append_null(),
            Some(Err(e)) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to convert mysql value for {:?} to {}: {:?}",
                    $mysql_col,
                    stringify!($value_ty),
                    e
                )));
            }
        }
    }};
}
283
284fn rows_to_batch(
285 rows: &[Row],
286 table_schema: &SchemaRef,
287 projection: Option<&Vec<usize>>,
288) -> DFResult<RecordBatch> {
289 let projected_schema = project_schema(table_schema, projection)?;
290 let mut array_builders = vec![];
291 for field in table_schema.fields() {
292 let builder = make_builder(field.data_type(), rows.len());
293 array_builders.push(builder);
294 }
295
296 for row in rows {
297 for (idx, field) in table_schema.fields.iter().enumerate() {
298 if !projections_contains(projection, idx) {
299 continue;
300 }
301 let builder = &mut array_builders[idx];
302 let col = row.columns_ref().get(idx);
303 match field.data_type() {
304 DataType::Int8 => {
305 handle_primitive_type!(builder, col, Int8Builder, i8, row, idx);
306 }
307 DataType::Int16 => {
308 handle_primitive_type!(builder, col, Int16Builder, i16, row, idx);
309 }
310 DataType::Int32 => {
311 handle_primitive_type!(builder, col, Int32Builder, i32, row, idx);
312 }
313 DataType::Int64 => {
314 handle_primitive_type!(builder, col, Int64Builder, i64, row, idx);
315 }
316 DataType::UInt8 => {
317 handle_primitive_type!(builder, col, UInt8Builder, u8, row, idx);
318 }
319 DataType::UInt16 => {
320 handle_primitive_type!(builder, col, UInt16Builder, u16, row, idx);
321 }
322 DataType::UInt32 => {
323 handle_primitive_type!(builder, col, UInt32Builder, u32, row, idx);
324 }
325 DataType::UInt64 => {
326 handle_primitive_type!(builder, col, UInt64Builder, u64, row, idx);
327 }
328 DataType::Float32 => {
329 handle_primitive_type!(builder, col, Float32Builder, f32, row, idx);
330 }
331 DataType::Float64 => {
332 handle_primitive_type!(builder, col, Float64Builder, f64, row, idx);
333 }
334 DataType::Decimal128(_precision, scale) => {
335 let builder = builder
336 .as_any_mut()
337 .downcast_mut::<Decimal128Builder>()
338 .unwrap_or_else(|| {
339 panic!("Failed to downcast builder to Decimal128Builder for {field:?}")
340 });
341 let v = row.get::<Option<bigdecimal::BigDecimal>, usize>(idx);
342
343 match v {
344 Some(Some(v)) => {
345 let Some(v) = big_decimal_to_i128(&v, Some(*scale as u32)) else {
346 return Err(DataFusionError::Execution(format!(
347 "Failed to convert BigDecimal {v:?} to i128"
348 )));
349 };
350 builder.append_value(v)
351 }
352 _ => builder.append_null(),
353 }
354 }
355 DataType::Decimal256(_precision, _scale) => {
356 let builder = builder
357 .as_any_mut()
358 .downcast_mut::<Decimal256Builder>()
359 .unwrap_or_else(|| {
360 panic!("Failed to downcast builder to Decimal256Builder for {field:?}")
361 });
362 let v = row.get::<Option<bigdecimal::BigDecimal>, usize>(idx);
363
364 match v {
365 Some(Some(v)) => builder.append_value(to_decimal_256(&v)),
366 _ => builder.append_null(),
367 }
368 }
369 DataType::Date32 => {
370 let builder = builder
372 .as_any_mut()
373 .downcast_mut::<Date32Builder>()
374 .unwrap_or_else(|| {
375 panic!("Failed to downcast builder to Date32Builder for {field:?}")
376 });
377 let v = row.get::<Option<chrono::NaiveDate>, usize>(idx);
378
379 match v {
380 Some(Some(v)) => builder.append_value(Date32Type::from_naive_date(v)),
381 _ => builder.append_null(),
382 }
383 }
384 DataType::Timestamp(TimeUnit::Microsecond, None) => {
385 let builder = builder
386 .as_any_mut()
387 .downcast_mut::<TimestampMicrosecondBuilder>()
388 .unwrap_or_else(|| {
389 panic!("Failed to downcast builder to TimestampMicrosecondBuilder for {field:?}")
390 });
391 let v = row.get::<Option<time::PrimitiveDateTime>, usize>(idx);
392
393 match v {
394 Some(Some(v)) => {
395 let timestamp_micros =
396 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
397 builder.append_value(timestamp_micros)
398 }
399 _ => builder.append_null(),
400 }
401 }
402 DataType::Timestamp(TimeUnit::Nanosecond, None) => {
403 let builder = builder
404 .as_any_mut()
405 .downcast_mut::<Time64NanosecondBuilder>()
406 .unwrap_or_else(|| {
407 panic!("Failed to downcast builder to Time64NanosecondBuilder for {field:?}")
408 });
409 let v = row.get::<Option<chrono::NaiveTime>, usize>(idx);
410
411 match v {
412 Some(Some(v)) => {
413 builder.append_value(
414 i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
415 + i64::from(v.nanosecond()),
416 );
417 }
418 _ => builder.append_null(),
419 }
420 }
421 DataType::Time64(TimeUnit::Nanosecond) => {
422 let builder = builder
423 .as_any_mut()
424 .downcast_mut::<Time64NanosecondBuilder>()
425 .unwrap_or_else(|| {
426 panic!("Failed to downcast builder to Time64NanosecondBuilder for {field:?}")
427 });
428 let v = row.get::<Option<chrono::NaiveTime>, usize>(idx);
429
430 match v {
431 Some(Some(v)) => {
432 builder.append_value(
433 i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
434 + i64::from(v.nanosecond()),
435 );
436 }
437 _ => builder.append_null(),
438 }
439 }
440 DataType::Utf8 => {
441 handle_primitive_type!(builder, col, StringBuilder, String, row, idx);
442 }
443 DataType::LargeUtf8 => {
444 handle_primitive_type!(builder, col, LargeStringBuilder, String, row, idx);
445 }
446 DataType::Binary => {
447 handle_primitive_type!(builder, col, BinaryBuilder, Vec<u8>, row, idx);
448 }
449 DataType::LargeBinary => {
450 handle_primitive_type!(builder, col, LargeBinaryBuilder, Vec<u8>, row, idx);
451 }
452 _ => {
453 return Err(DataFusionError::NotImplemented(format!(
454 "Unsupported data type {:?} for col: {:?}",
455 field.data_type(),
456 col
457 )));
458 }
459 }
460 }
461 }
462 let projected_columns = array_builders
463 .into_iter()
464 .enumerate()
465 .filter(|(idx, _)| projections_contains(projection, *idx))
466 .map(|(_, mut builder)| builder.finish())
467 .collect::<Vec<ArrayRef>>();
468 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
469}
470
471fn to_decimal_256(decimal: &bigdecimal::BigDecimal) -> i256 {
472 let (bigint_value, _) = decimal.as_bigint_and_exponent();
473 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
474
475 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
476 let fill_byte = if is_negative { 0xFF } else { 0x00 };
477
478 if bigint_bytes.len() > 32 {
479 bigint_bytes.truncate(32);
480 } else {
481 bigint_bytes.resize(32, fill_byte);
482 };
483
484 let mut array = [0u8; 32];
485 array.copy_from_slice(&bigint_bytes);
486
487 i256::from_le_bytes(array)
488}