1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{
4 project_remote_schema, Connection, DFResult, MysqlType, Pool, RemoteField, RemoteSchema,
5 RemoteType, Transform,
6};
7use async_stream::stream;
8use chrono::Timelike;
9use datafusion::arrow::array::{
10 make_builder, ArrayRef, BinaryBuilder, Date32Builder, Float32Builder, Float64Builder,
11 Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeBinaryBuilder, LargeStringBuilder,
12 RecordBatch, StringBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder,
13};
14use datafusion::arrow::datatypes::{Date32Type, SchemaRef};
15use datafusion::common::{project_schema, DataFusionError};
16use datafusion::execution::SendableRecordBatchStream;
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use futures::lock::Mutex;
19use futures::StreamExt;
20use mysql_async::consts::{ColumnFlags, ColumnType};
21use mysql_async::prelude::Queryable;
22use mysql_async::{Column, Row};
23use std::sync::Arc;
24
25#[derive(Debug, Clone, derive_with::With)]
26pub struct MysqlConnectionOptions {
27 pub(crate) host: String,
28 pub(crate) port: u16,
29 pub(crate) username: String,
30 pub(crate) password: String,
31 pub(crate) database: Option<String>,
32}
33
34impl MysqlConnectionOptions {
35 pub fn new(
36 host: impl Into<String>,
37 port: u16,
38 username: impl Into<String>,
39 password: impl Into<String>,
40 ) -> Self {
41 Self {
42 host: host.into(),
43 port,
44 username: username.into(),
45 password: password.into(),
46 database: None,
47 }
48 }
49}
50
/// Pool of MySQL connections created by [`connect_mysql`].
///
/// Implements [`Pool`]: each `get` checks a connection out of the wrapped `mysql_async` pool
/// and hands it back as a [`MysqlConnection`].
#[derive(Debug)]
pub struct MysqlPool {
    /// Underlying `mysql_async` pool that owns the physical connections.
    pool: mysql_async::Pool,
}
55
56pub fn connect_mysql(options: &MysqlConnectionOptions) -> DFResult<MysqlPool> {
57 let opts_builder = mysql_async::OptsBuilder::default()
58 .ip_or_hostname(options.host.clone())
59 .tcp_port(options.port)
60 .user(Some(options.username.clone()))
61 .pass(Some(options.password.clone()))
62 .db_name(options.database.clone());
63 let pool = mysql_async::Pool::new(opts_builder);
64 Ok(MysqlPool { pool })
65}
66
67#[async_trait::async_trait]
68impl Pool for MysqlPool {
69 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
70 let conn = self.pool.get_conn().await.map_err(|e| {
71 DataFusionError::Execution(format!("Failed to get mysql connection from pool: {:?}", e))
72 })?;
73 Ok(Arc::new(MysqlConnection {
74 conn: Arc::new(Mutex::new(conn)),
75 }))
76 }
77}
78
/// A single MySQL connection checked out of a [`MysqlPool`].
///
/// The connection sits behind an async mutex shared through an `Arc`. The stream returned by
/// `query` clones that `Arc` and locks the mutex itself, so it can keep using the connection
/// after `query` returns. The mutex allows only one user of the connection at a time.
#[derive(Debug)]
pub struct MysqlConnection {
    /// Shared, lock-protected `mysql_async` connection.
    conn: Arc<Mutex<mysql_async::Conn>>,
}
83
84#[async_trait::async_trait]
85impl Connection for MysqlConnection {
86 async fn infer_schema(
87 &self,
88 sql: &str,
89 transform: Option<Arc<dyn Transform>>,
90 ) -> DFResult<(RemoteSchema, SchemaRef)> {
91 let mut conn = self.conn.lock().await;
92 let conn = &mut *conn;
93 let row: Option<Row> = conn.query_first(sql).await.map_err(|e| {
94 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}",))
95 })?;
96 let Some(row) = row else {
97 return Err(DataFusionError::Execution(
98 "No rows returned to infer schema".to_string(),
99 ));
100 };
101 let remote_schema = build_remote_schema(&row)?;
102 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
103 if let Some(transform) = transform {
104 let batch = rows_to_batch(&[row], &remote_schema, arrow_schema.clone(), None)?;
105 let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
106 Ok((remote_schema, transformed_batch.schema()))
107 } else {
108 Ok((remote_schema, arrow_schema))
109 }
110 }
111
112 async fn query(
113 &self,
114 sql: String,
115 projection: Option<Vec<usize>>,
116 ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
117 let conn = Arc::clone(&self.conn);
118 let mut stream = Box::pin(stream! {
119 let mut conn = conn.lock().await;
120 let mut query_iter = conn
121 .query_iter(sql)
122 .await
123 .map_err(|e| {
124 DataFusionError::Execution(format!("Failed to execute query on mysql: {e:?}"))
125 })?;
126
127 let Some(stream) = query_iter.stream::<Row>().await.map_err(|e| {
128 DataFusionError::Execution(format!("Failed to get stream from mysql: {e:?}"))
129 })? else {
130 yield Err(DataFusionError::Execution("Get none stream from mysql".to_string()));
131 return;
132 };
133
134 let mut chunked_stream = stream.chunks(4_000).boxed();
135
136 while let Some(chunk) = chunked_stream.next().await {
137 let rows = chunk
138 .into_iter()
139 .collect::<Result<Vec<_>, _>>()
140 .map_err(|e| {
141 DataFusionError::Execution(format!(
142 "Failed to collect rows from mysql due to {e}",
143 ))
144 })?;
145
146 yield Ok::<_, DataFusionError>(rows)
147 }
148 });
149
150 let Some(first_chunk) = stream.next().await else {
151 return Err(DataFusionError::Execution(
152 "No data returned from mysql".to_string(),
153 ));
154 };
155 let first_chunk = first_chunk?;
156
157 let Some(first_row) = first_chunk.first() else {
158 return Err(DataFusionError::Execution(
159 "No data returned from mysql".to_string(),
160 ));
161 };
162
163 let remote_schema = build_remote_schema(first_row)?;
164 let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
165 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
166 let first_chunk = rows_to_batch(
167 first_chunk.as_slice(),
168 &remote_schema,
169 arrow_schema.clone(),
170 projection.as_ref(),
171 )?;
172 let schema = first_chunk.schema();
173
174 let mut stream = stream.map(move |rows| {
175 let rows = rows?;
176 let batch = rows_to_batch(
177 rows.as_slice(),
178 &remote_schema,
179 arrow_schema.clone(),
180 projection.as_ref(),
181 )?;
182 Ok::<RecordBatch, DataFusionError>(batch)
183 });
184
185 let output_stream = async_stream::stream! {
186 yield Ok(first_chunk);
187 while let Some(batch) = stream.next().await {
188 yield batch
189 }
190 };
191
192 Ok((
193 Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
194 projected_remote_schema,
195 ))
196 }
197}
198
199fn mysql_type_to_remote_type(mysql_col: &Column) -> DFResult<RemoteType> {
200 let empty_flags = mysql_col.flags().is_empty();
201 let is_binary = mysql_col.flags().contains(ColumnFlags::BINARY_FLAG);
202 let is_blob = mysql_col.flags().contains(ColumnFlags::BLOB_FLAG);
203 let col_length = mysql_col.column_length();
204 match mysql_col.column_type() {
205 ColumnType::MYSQL_TYPE_TINY => Ok(RemoteType::Mysql(MysqlType::TinyInt)),
206 ColumnType::MYSQL_TYPE_SHORT => Ok(RemoteType::Mysql(MysqlType::SmallInt)),
207 ColumnType::MYSQL_TYPE_INT24 => Ok(RemoteType::Mysql(MysqlType::MediumInt)),
208 ColumnType::MYSQL_TYPE_LONG => Ok(RemoteType::Mysql(MysqlType::Integer)),
209 ColumnType::MYSQL_TYPE_LONGLONG => Ok(RemoteType::Mysql(MysqlType::BigInt)),
210 ColumnType::MYSQL_TYPE_FLOAT => Ok(RemoteType::Mysql(MysqlType::Float)),
211 ColumnType::MYSQL_TYPE_DOUBLE => Ok(RemoteType::Mysql(MysqlType::Double)),
212 ColumnType::MYSQL_TYPE_DATE => Ok(RemoteType::Mysql(MysqlType::Date)),
213 ColumnType::MYSQL_TYPE_DATETIME => Ok(RemoteType::Mysql(MysqlType::Datetime)),
214 ColumnType::MYSQL_TYPE_TIME => Ok(RemoteType::Mysql(MysqlType::Time)),
215 ColumnType::MYSQL_TYPE_TIMESTAMP => Ok(RemoteType::Mysql(MysqlType::Timestamp)),
216 ColumnType::MYSQL_TYPE_STRING if empty_flags => Ok(RemoteType::Mysql(MysqlType::Char)),
217 ColumnType::MYSQL_TYPE_STRING if is_binary => Ok(RemoteType::Mysql(MysqlType::Binary)),
218 ColumnType::MYSQL_TYPE_VAR_STRING if empty_flags => {
219 Ok(RemoteType::Mysql(MysqlType::Varchar))
220 }
221 ColumnType::MYSQL_TYPE_VAR_STRING if is_binary => {
222 Ok(RemoteType::Mysql(MysqlType::Varbinary))
223 }
224 ColumnType::MYSQL_TYPE_VARCHAR => Ok(RemoteType::Mysql(MysqlType::Varchar)),
225 ColumnType::MYSQL_TYPE_BLOB if col_length == 1020 && is_blob && !is_binary => {
226 Ok(RemoteType::Mysql(MysqlType::TinyText))
227 }
228 ColumnType::MYSQL_TYPE_BLOB if col_length == 262140 && is_blob && !is_binary => {
229 Ok(RemoteType::Mysql(MysqlType::Text))
230 }
231 ColumnType::MYSQL_TYPE_BLOB if col_length == 67108860 && is_blob && !is_binary => {
232 Ok(RemoteType::Mysql(MysqlType::MediumText))
233 }
234 ColumnType::MYSQL_TYPE_BLOB if col_length == 4294967295 && is_blob && !is_binary => {
235 Ok(RemoteType::Mysql(MysqlType::LongText))
236 }
237 ColumnType::MYSQL_TYPE_BLOB if col_length == 255 && is_blob && is_binary => {
238 Ok(RemoteType::Mysql(MysqlType::TinyBlob))
239 }
240 ColumnType::MYSQL_TYPE_BLOB if col_length == 65535 && is_blob && is_binary => {
241 Ok(RemoteType::Mysql(MysqlType::Blob))
242 }
243 ColumnType::MYSQL_TYPE_BLOB if col_length == 16777215 && is_blob && is_binary => {
244 Ok(RemoteType::Mysql(MysqlType::MediumBlob))
245 }
246 ColumnType::MYSQL_TYPE_BLOB if col_length == 4294967295 && is_blob && is_binary => {
247 Ok(RemoteType::Mysql(MysqlType::LongBlob))
248 }
249 ColumnType::MYSQL_TYPE_JSON => Ok(RemoteType::Mysql(MysqlType::Json)),
250 ColumnType::MYSQL_TYPE_GEOMETRY => Ok(RemoteType::Mysql(MysqlType::Geometry)),
251 _ => Err(DataFusionError::NotImplemented(format!(
252 "Unsupported mysql type: {mysql_col:?}",
253 ))),
254 }
255}
256
257fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
258 let mut remote_fields = vec![];
259 for col in row.columns_ref() {
260 remote_fields.push(RemoteField::new(
261 col.name_str().to_string(),
262 mysql_type_to_remote_type(col)?,
263 true,
264 ));
265 }
266 Ok(RemoteSchema::new(remote_fields))
267}
268
/// Appends the cell of `$row` at `$index` to `$builder`, decoded as `Option<$value_ty>`.
///
/// `$builder` is downcast to the concrete builder type `$builder_ty`. A decoded value
/// (`Some(Some(_))`) is appended as-is. Anything else, such as a NULL cell, is appended as a
/// null. Panics with a message naming `$builder_ty` and `$mysql_col` if the downcast fails.
macro_rules! handle_primitive_type {
    ($builder:expr, $mysql_col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                concat!(
                    "Failed to downcast builder to ",
                    stringify!($builder_ty),
                    " for {:?}"
                ),
                $mysql_col
            );
        };
        if let Some(Some(value)) = $row.get::<Option<$value_ty>, usize>($index) {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
292
293fn rows_to_batch(
294 rows: &[Row],
295 remote_schema: &RemoteSchema,
296 arrow_schema: SchemaRef,
297 projection: Option<&Vec<usize>>,
298) -> DFResult<RecordBatch> {
299 let projected_schema = project_schema(&arrow_schema, projection)?;
300 let mut array_builders = vec![];
301 for field in arrow_schema.fields() {
302 let builder = make_builder(field.data_type(), rows.len());
303 array_builders.push(builder);
304 }
305
306 for row in rows {
307 for (idx, remote_field) in remote_schema.fields.iter().enumerate() {
308 if !projections_contains(projection, idx) {
309 continue;
310 }
311 let builder = &mut array_builders[idx];
312 match remote_field.remote_type {
313 RemoteType::Mysql(MysqlType::TinyInt) => {
314 handle_primitive_type!(builder, remote_field, Int8Builder, i8, row, idx);
315 }
316 RemoteType::Mysql(MysqlType::SmallInt) => {
317 handle_primitive_type!(builder, remote_field, Int16Builder, i16, row, idx);
318 }
319 RemoteType::Mysql(MysqlType::MediumInt) | RemoteType::Mysql(MysqlType::Integer) => {
320 handle_primitive_type!(builder, remote_field, Int32Builder, i32, row, idx);
321 }
322 RemoteType::Mysql(MysqlType::BigInt) => {
323 handle_primitive_type!(builder, remote_field, Int64Builder, i64, row, idx);
324 }
325 RemoteType::Mysql(MysqlType::Float) => {
326 handle_primitive_type!(builder, remote_field, Float32Builder, f32, row, idx);
327 }
328 RemoteType::Mysql(MysqlType::Double) => {
329 handle_primitive_type!(builder, remote_field, Float64Builder, f64, row, idx);
330 }
331 RemoteType::Mysql(MysqlType::Date) => {
332 let builder = builder
333 .as_any_mut()
334 .downcast_mut::<Date32Builder>()
335 .unwrap_or_else(|| {
336 panic!(
337 "Failed to downcast builder to Date32Builder for {:?}",
338 remote_field
339 )
340 });
341 let v = row.get::<Option<chrono::NaiveDate>, usize>(idx);
342
343 match v {
344 Some(Some(v)) => builder.append_value(Date32Type::from_naive_date(v)),
345 _ => builder.append_null(),
346 }
347 }
348 RemoteType::Mysql(MysqlType::Datetime)
349 | RemoteType::Mysql(MysqlType::Timestamp) => {
350 let builder = builder
351 .as_any_mut()
352 .downcast_mut::<TimestampMicrosecondBuilder>()
353 .unwrap_or_else(|| {
354 panic!(
355 "Failed to downcast builder to TimestampMicrosecondBuilder for {:?}",
356 remote_field
357 )
358 });
359 let v = row.get::<Option<time::PrimitiveDateTime>, usize>(idx);
360
361 match v {
362 Some(Some(v)) => {
363 let timestamp_micros =
364 (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64;
365 builder.append_value(timestamp_micros)
366 }
367 _ => builder.append_null(),
368 }
369 }
370 RemoteType::Mysql(MysqlType::Time) => {
371 let builder = builder
372 .as_any_mut()
373 .downcast_mut::<Time64NanosecondBuilder>()
374 .unwrap_or_else(|| {
375 panic!(
376 "Failed to downcast builder to Time64NanosecondBuilder for {:?}",
377 remote_field
378 )
379 });
380 let v = row.get::<Option<chrono::NaiveTime>, usize>(idx);
381
382 match v {
383 Some(Some(v)) => {
384 builder.append_value(
385 i64::from(v.num_seconds_from_midnight()) * 1_000_000_000
386 + i64::from(v.nanosecond()),
387 );
388 }
389 _ => builder.append_null(),
390 }
391 }
392 RemoteType::Mysql(MysqlType::Char)
393 | RemoteType::Mysql(MysqlType::Varchar)
394 | RemoteType::Mysql(MysqlType::TinyText)
395 | RemoteType::Mysql(MysqlType::Text)
396 | RemoteType::Mysql(MysqlType::MediumText) => {
397 handle_primitive_type!(builder, remote_field, StringBuilder, String, row, idx);
398 }
399 RemoteType::Mysql(MysqlType::LongText) | RemoteType::Mysql(MysqlType::Json) => {
400 handle_primitive_type!(
401 builder,
402 remote_field,
403 LargeStringBuilder,
404 String,
405 row,
406 idx
407 );
408 }
409 RemoteType::Mysql(MysqlType::Binary)
410 | RemoteType::Mysql(MysqlType::Varbinary)
411 | RemoteType::Mysql(MysqlType::TinyBlob)
412 | RemoteType::Mysql(MysqlType::Blob)
413 | RemoteType::Mysql(MysqlType::MediumBlob) => {
414 handle_primitive_type!(builder, remote_field, BinaryBuilder, Vec<u8>, row, idx);
415 }
416 RemoteType::Mysql(MysqlType::LongBlob) | RemoteType::Mysql(MysqlType::Geometry) => {
417 handle_primitive_type!(
418 builder,
419 remote_field,
420 LargeBinaryBuilder,
421 Vec<u8>,
422 row,
423 idx
424 );
425 }
426 _ => panic!("Invalid mysql type: {:?}", remote_field.remote_type),
427 }
428 }
429 }
430 let projected_columns = array_builders
431 .into_iter()
432 .enumerate()
433 .filter(|(idx, _)| projections_contains(projection, *idx))
434 .map(|(_, mut builder)| builder.finish())
435 .collect::<Vec<ArrayRef>>();
436 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
437}