1use crate::connection::{big_decimal_to_i128, projections_contains};
2use crate::transform::transform_batch;
3use crate::{
4 project_remote_schema, Connection, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5 RemoteType, Transform,
6};
7use async_stream::stream;
8use bigdecimal::num_bigint;
9use chrono::Timelike;
10use datafusion::arrow::array::{
11 make_builder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder,
12 Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
13 LargeBinaryBuilder, LargeStringBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder,
14 TimestampMicrosecondBuilder,
15};
16use datafusion::arrow::datatypes::{i256, Date32Type, SchemaRef};
17use datafusion::common::{project_schema, DataFusionError};
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use futures::lock::Mutex;
21use futures::StreamExt;
22use mysql_async::consts::{ColumnFlags, ColumnType};
23use mysql_async::prelude::Queryable;
24use mysql_async::{Column, Row};
25use std::sync::Arc;
26
27#[derive(Debug, Clone, derive_with::With)]
28pub struct MysqlConnectionOptions {
29 pub(crate) host: String,
30 pub(crate) port: u16,
31 pub(crate) username: String,
32 pub(crate) password: String,
33 pub(crate) database: Option<String>,
34}
35
36impl MysqlConnectionOptions {
37 pub fn new(
38 host: impl Into<String>,
39 port: u16,
40 username: impl Into<String>,
41 password: impl Into<String>,
42 ) -> Self {
43 Self {
44 host: host.into(),
45 port,
46 username: username.into(),
47 password: password.into(),
48 database: None,
49 }
50 }
51}
52
53#[derive(Debug)]
54pub struct MysqlPool {
55 pool: mysql_async::Pool,
56}
57
58pub fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
59 let opts_builder = mysql_async::OptsBuilder::default()
60 .ip_or_hostname(options.host.clone())
61 .tcp_port(options.port)
62 .user(Some(options.username.clone()))
63 .pass(Some(options.password.clone()))
64 .db_name(options.database.clone());
65 let pool = mysql_async::Pool::new(opts_builder);
66 Ok(MysqlPool { pool })
67}
68
69#[async_trait::async_trait]
70impl Pool for MysqlPool {
71 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
72 let conn = self.pool.get_conn().await.map_err(|e| {
73 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
74 })?;
75 Ok(Arc::new(MysqlConnection {
76 conn: Arc::new(Mutex::new(conn)),
77 }))
78 }
79}
80
81#[derive(Debug)]
82pub struct MysqlConnection {
83 conn: Arc<Mutex<mysql_async::Conn>>,
84}
85
86#[async_trait::async_trait]
87impl Connection for MysqlConnection {
88 async fn infer_schema(
89 &self,
90 sql: &str,
91 transform: Option<Arc<dyn Transform>>,
92 ) -> DFResult<(RemoteSchema, SchemaRef)> {
93 let mut conn = self.conn.lock().await;
94 let conn = &mut *conn;
95 let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
96 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
97 })?;
98 let Some(row) = row else {
99 return Err(DataFusionError::Execution(
100 "No rows returned to infer schema".to_string(),
101 ));
102 };
103 let remote_schema = build_remote_schema(&row)?;
104 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
105 if let Some(transform) = transform {
106 let batch = rows_to_batch(&[row], &remote_schema, arrow_schema.clone(), None)?;
107 let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
108 Ok((remote_schema, transformed_batch.schema()))
109 } else {
110 Ok((remote_schema, arrow_schema))
111 }
112 }
113
114 async fn query(
115 &self,
116 sql: String,
117 projection: Option<Vec<usize>>,
118 ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
119 let conn = Arc::clone(&self.conn);
120 let mut stream = Box::pin(stream! {
121 let mut conn = conn.lock().await;
122 let mut query_iter = conn
123 .query_iter(sql)
124 .await
125 .map_err(|e| {
126 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
127 })?;
128
129 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
130 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
131 })? else {
132 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
133 return;
134 };
135
136 let mut chunked_stream = stream.chunks(4_000).boxed();
137
138 while let Some(chunk) = chunked_stream.next().await {
139 let rows = chunk
140 .into_iter()
141 .collect::<Result<Vec<_>, _>>()
142 .map_err(|e| {
143 DataFusionError::Execution(format!(
144 "Failed to collect rows from mysql due to {e}",
145 ))
146 })?;
147
148 yield Ok::<_, DataFusionError>(rows)
149 }
150 });
151
152 let Some(first_chunk) = stream.next().await else {
153 return Err(DataFusionError::Execution(
154 "No data returned from mysql".to_string(),
155 ));
156 };
157 let first_chunk = first_chunk?;
158
159 let Some(first_row) = first_chunk.first() else {
160 return Err(DataFusionError::Execution(
161 "No data returned from mysql".to_string(),
162 ));
163 };
164
165 let remote_schema = build_remote_schema(first_row)?;
166 let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
167 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
168 let first_chunk = rows_to_batch(
169 first_chunk.as_slice(),
170 &remote_schema,
171 arrow_schema.clone(),
172 projection.as_ref(),
173 )?;
174 let schema = first_chunk.schema();
175
176 let mut stream = stream.map(move |rows| {
177 let rows = rows?;
178 let batch = rows_to_batch(
179 rows.as_slice(),
180 &remote_schema,
181 arrow_schema.clone(),
182 projection.as_ref(),
183 )?;
184 Ok::<RecordBatch, DataFusionError>(batch)
185 });
186
187 let output_stream = async_stream::stream! {
188 yield Ok(first_chunk);
189 while let Some(batch) = stream.next().await {
190 yield batch
191 }
192 };
193
194 Ok((
195 Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
196 projected_remote_schema,
197 ))
198 }
199}
200
201fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
202 let empty_flags = mysql_col.flags().is_empty();
203 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
204 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
205 let col_length = mysql_col.column_length();
206 match mysql_col.column_type() {
207 ColumnType::MYSQL_TYPE_TINY => Ok(RemoteType::Mysql(MysqlType::TinyInt)),
208 ColumnType::MYSQL_TYPE_SHORT => Ok(RemoteType::Mysql(MysqlType::SmallInt)),
209 ColumnType::MYSQL_TYPE_INT24 => Ok(RemoteType::Mysql(MysqlType::MediumInt)),
210 ColumnType::MYSQL_TYPE_LONG => Ok(RemoteType::Mysql(MysqlType::Integer)),
211 ColumnType::MYSQL_TYPE_LONGLONG => Ok(RemoteType::Mysql(MysqlType::BigInt)),
212 ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
213 ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
214 ColumnType::MYSQL_TYPE_NEWDECIMAL => {
215 let precision = (mysql_col.column_length() - 2) as u8;
216 let scale = mysql_col.decimals();
217 Ok(RemoteType::Mysql(MysqlType::Decimal(precision, scale)))
218 }
219 ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
220 ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
221 ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
222 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
223 ColumnType::MYSQL_TYPE_YEAR => Ok(RemoteType::Mysql(MysqlType::Year)),
224 ColumnType::MYSQL_TYPE_STRING if empty_flags => Ok(RemoteType::Mysql(MysqlType::Char)),
225 ColumnType::MYSQL_TYPE_STRING if is_binary => Ok(RemoteType::Mysql(MysqlType::Binary)),
226 ColumnType::MYSQL_TYPE_VAR_STRING if empty_flags => {
227 Ok(RemoteType::Mysql(MysqlType::Varchar))
228 }
229 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
230 Ok(RemoteType::Mysql(MysqlType::Varbinary))
231 }
232 ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
233 ColumnType::MYSQL_TYPE_BLOB if col_length == 1020 && is_blob && !is_binary => {
234 Ok(RemoteType::Mysql(MysqlType::TinyText))
235 }
236 ColumnType::MYSQL_TYPE_BLOB if col_length == 262140 && is_blob && !is_binary => {
237 Ok(RemoteType::Mysql(MysqlType::Text))
238 }
239 ColumnType::MYSQL_TYPE_BLOB if col_length == 67108860 && is_blob && !is_binary => {
240 Ok(RemoteType::Mysql(MysqlType::MediumText))
241 }
242 ColumnType::MYSQL_TYPE_BLOB if col_length == 4294967295 && is_blob && !is_binary => {
243 Ok(RemoteType::Mysql(MysqlType::LongText))
244 }
245 ColumnType::MYSQL_TYPE_BLOB if col_length == 255 && is_blob && is_binary => {
246 Ok(RemoteType::Mysql(MysqlType::TinyBlob))
247 }
248 ColumnType::MYSQL_TYPE_BLOB if col_length == 65535 && is_blob && is_binary => {
249 Ok(RemoteType::Mysql(MysqlType::Blob))
250 }
251 ColumnType::MYSQL_TYPE_BLOB if col_length == 16777215 && is_blob && is_binary => {
252 Ok(RemoteType::Mysql(MysqlType::MediumBlob))
253 }
254 ColumnType::MYSQL_TYPE_BLOB if col_length == 4294967295 && is_blob && is_binary => {
255 Ok(RemoteType::Mysql(MysqlType::LongBlob))
256 }
257 ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
258 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
259 _ => Err(DataFusionError::NotImplemented(format!(
260 "Unsupported mysql type: {mysql_col:?}",
261 ))),
262 }
263}
264
265fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
266 let mut remote_fields = vec![];
267 for col in row.columns_ref() {
268 remote_fields.push(RemoteField::new(
269 col.name_str().to_string(),
270 mysql_type_to_remote_type(col)?,
271 true,
272 ));
273 }
274 Ok(RemoteSchema::new(remote_fields))
275}
276
/// Appends column `$index` of `$row` to `$builder` after downcasting it to
/// `$builder_ty`.
///
/// The cell is read as `Option<$value_ty>`. A non-null value is appended with
/// `append_value`. A SQL NULL or a missing column is appended with `append_null`.
///
/// Panics with a message naming `$builder_ty` and `$mysql_col` if the builder is
/// not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let typed_builder = match $builder.as_any_mut().downcast_mut::<$builder_ty>() {
            Some(typed) => typed,
            None => panic!(
                concat!(
                    "Failed to downcast builder to ",
                    stringify!($builder_ty),
                    " for {:?}"
                ),
                $mysql_col
            ),
        };
        if let Some(Some(cell)) = $row.get::<Option<$value_ty>, usize>($index) {
            typed_builder.append_value(cell)
        } else {
            typed_builder.append_null()
        }
    }};
}
300
301fn rows_to_batch(
302 rows: &[Row],
303 remote_schema: &RemoteSchema,
304 arrow_schema: SchemaRef,
305 projection: Option<&Vec<usize>>,
306) -> DFResult<RecordBatch> {
307 let projected_schema = project_schema(&arrow_schema, projection)?;
308 let mut array_builders = vec![];
309 for field in arrow_schema.fields() {
310 let builder = make_builder(field.data_type(), rows.len());
311 array_builders.push(builder);
312 }
313
314 for row in rows {
315 for (idx, remote_field) in remote_schema.fields.iter().enumerate() {
316 if !projections_contains(projection, idx) {
317 continue;
318 }
319 let builder = &mut array_builders[idx];
320 match remote_field.remote_type {
321 RemoteType::Mysql(MysqlType::TinyInt) => {
322 handle_primitive_type!(builder, remote_field, Int8Builder, i8, row, idx);
323 }
324 RemoteType::Mysql(MysqlType::SmallInt) | RemoteType::Mysql(MysqlType::Year) => {
325 handle_primitive_type!(builder, remote_field, Int16Builder, i16, row, idx);
326 }
327 RemoteType::Mysql(MysqlType::MediumInt) | RemoteType::Mysql(MysqlType::Integer) => {
328 handle_primitive_type!(builder, remote_field, Int32Builder, i32, row, idx);
329 }
330 RemoteType::Mysql(MysqlType::BigInt) => {
331 handle_primitive_type!(builder, remote_field, Int64Builder, i64, row, idx);
332 }
333 RemoteType::Mysql(MysqlType::Float) => {
334 handle_primitive_type!(builder, remote_field, Float32Builder, f32, row, idx);
335 }
336 RemoteType::Mysql(MysqlType::Double) => {
337 handle_primitive_type!(builder, remote_field, Float64Builder, f64, row, idx);
338 }
339 RemoteType::Mysql(MysqlType::Decimal(precision, scale)) => {
340 if precision > 38 {
341 let builder = builder
342 .as_any_mut()
343 .downcast_mut::<Decimal256Builder>()
344 .unwrap_or_else(|| {
345 panic!(
346 "Failed to downcast builder to Decimal256Builder for {:?}",
347 remote_field
348 )
349 });
350 let v = row.get::<Option<bigdecimal::BigDecimal>, usize>(idx);
351
352 match v {
353 Some(Some(v)) => builder.append_value(to_decimal_256(&v)),
354 _ => builder.append_null(),
355 }
356 } else {
357 let builder = builder
358 .as_any_mut()
359 .downcast_mut::<Decimal128Builder>()
360 .unwrap_or_else(|| {
361 panic!(
362 "Failed to downcast builder to Decimal128Builder for {:?}",
363 remote_field
364 )
365 });
366 let v = row.get::<Option<bigdecimal::BigDecimal>, usize>(idx);
367
368 match v {
369 Some(Some(v)) => {
370 let Some(v) = big_decimal_to_i128(&v, Some(scale as u32)) else {
371 return Err(DataFusionError::Execution(format!(
372 "Failed to convert BigDecimal {v:?} to i128"
373 )));
374 };
375 builder.append_value(v)
376 }
377 _ => builder.append_null(),
378 }
379 }
380 }
381 RemoteType::Mysql(MysqlType::Date) => {
382 let builder = builder
384 .as_any_mut()
385 .downcast_mut::<Date32Builder>()
386 .unwrap_or_else(|| {
387 panic!(
388 "Failed to downcast builder to Date32Builder for {:?}",
389 remote_field
390 )
391 });
392 let v = row.get::<Option<chrono::NaiveDate>, usize>(idx);
393
394 match v {
395 Some(Some(v)) => builder.append_value(Date32Type::from_naive_date(v)),
396 _ => builder.append_null(),
397 }
398 }
399 RemoteType::Mysql(MysqlType::Datetime)
400 | RemoteType::Mysql(MysqlType::Timestamp) => {
401 let builder = builder
402 .as_any_mut()
403 .downcast_mut::<TimestampMicrosecondBuilder>()
404 .unwrap_or_else(|| {
405 panic!(
406 "Failed to downcast builder to TimestampMicrosecondBuilder for {:?}",
407 remote_field
408 )
409 });
410 let v = row.get::<Option<time::PrimitiveDateTime>, usize>(idx);
411
412 match v {
413 Some(Some(v)) => {
414 let timestamp_micros =
415 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
416 builder.append_value(timestamp_micros)
417 }
418 _ => builder.append_null(),
419 }
420 }
421 RemoteType::Mysql(MysqlType::Time) => {
422 let builder = builder
423 .as_any_mut()
424 .downcast_mut::<Time64NanosecondBuilder>()
425 .unwrap_or_else(|| {
426 panic!(
427 "Failed to downcast builder to Time64NanosecondBuilder for {:?}",
428 remote_field
429 )
430 });
431 let v = row.get::<Option<chrono::NaiveTime>, usize>(idx);
432
433 match v {
434 Some(Some(v)) => {
435 builder.append_value(
436 i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
437 + i64::from(v.nanosecond()),
438 );
439 }
440 _ => builder.append_null(),
441 }
442 }
443 RemoteType::Mysql(MysqlType::Char)
444 | RemoteType::Mysql(MysqlType::Varchar)
445 | RemoteType::Mysql(MysqlType::TinyText)
446 | RemoteType::Mysql(MysqlType::Text)
447 | RemoteType::Mysql(MysqlType::MediumText) => {
448 handle_primitive_type!(builder, remote_field, StringBuilder, String, row, idx);
449 }
450 RemoteType::Mysql(MysqlType::LongText) | RemoteType::Mysql(MysqlType::Json) => {
451 handle_primitive_type!(
452 builder,
453 remote_field,
454 LargeStringBuilder,
455 String,
456 row,
457 idx
458 );
459 }
460 RemoteType::Mysql(MysqlType::Binary)
461 | RemoteType::Mysql(MysqlType::Varbinary)
462 | RemoteType::Mysql(MysqlType::TinyBlob)
463 | RemoteType::Mysql(MysqlType::Blob)
464 | RemoteType::Mysql(MysqlType::MediumBlob) => {
465 handle_primitive_type!(builder, remote_field, BinaryBuilder, Vec<u8>, row, idx);
466 }
467 RemoteType::Mysql(MysqlType::LongBlob) | RemoteType::Mysql(MysqlType::Geometry) => {
468 handle_primitive_type!(
469 builder,
470 remote_field,
471 LargeBinaryBuilder,
472 Vec<u8>,
473 row,
474 idx
475 );
476 }
477 _ => panic!("Invalid mysql type: {:?}", remote_field.remote_type),
478 }
479 }
480 }
481 let projected_columns = array_builders
482 .into_iter()
483 .enumerate()
484 .filter(|(idx, _)| projections_contains(projection, *idx))
485 .map(|(_, mut builder)| builder.finish())
486 .collect::<Vec<ArrayRef>>();
487 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
488}
489
490fn to_decimal_256(decimal: &bigdecimal::BigDecimal) -> i256 {
491 let (bigint_value, _) = decimal.as_bigint_and_exponent();
492 let mut bigint_bytes = bigint_value.to_signed_bytes_le();
493
494 let is_negative = bigint_value.sign() == num_bigint::Sign::Minus;
495 let fill_byte = if is_negative { 0xFF } else { 0x00 };
496
497 if bigint_bytes.len() > 32 {
498 bigint_bytes.truncate(32);
499 } else {
500 bigint_bytes.resize(32, fill_byte);
501 };
502
503 let mut array = [0u8; 32];
504 array.copy_from_slice(&bigint_bytes);
505
506 i256::from_le_bytes(array)
507}