1use std::{any::Any, sync::Arc};
2
3use crate::sql::arrow_sql_gen::mysql::map_column_to_data_type;
4use crate::sql::arrow_sql_gen::{self, mysql::rows_to_arrow};
5use async_stream::stream;
6use datafusion::arrow::datatypes::{Field, Schema, SchemaRef};
7use datafusion::error::DataFusionError;
8use datafusion::execution::SendableRecordBatchStream;
9use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
10use datafusion::sql::unparser::dialect::{Dialect, MySqlDialect};
11use datafusion::sql::TableReference;
12use futures::lock::Mutex;
13use futures::{stream, StreamExt};
14use mysql_async::consts::ColumnType;
15use mysql_async::prelude::Queryable;
16use mysql_async::{prelude::ToValue, Conn, Params, Row};
17use snafu::prelude::*;
18
19use super::Result;
20use super::{AsyncDbConnection, DbConnection};
21
/// Errors produced by the MySQL connection layer.
#[derive(Debug, Snafu)]
pub enum Error {
    /// The server rejected or failed a query; wraps the driver error.
    #[snafu(display("Query execution failed.\n{source}\nFor details, refer to the MySQL manual: https://dev.mysql.com/doc/mysql-errors/9.1/en/error-reference-introduction.html"))]
    QueryError { source: mysql_async::Error },

    /// Row data could not be converted into Arrow record batches.
    #[snafu(display("Failed to convert query result to Arrow.\n{source}.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
    ConversionError { source: arrow_sql_gen::mysql::Error },

    /// `exec_iter` produced no row stream for the statement.
    #[snafu(display("An unexpected error occurred. Verify the configuration and try again."))]
    QueryResultStreamError {},

    /// A MySQL column type had no Arrow mapping.
    #[snafu(display("Unsupported data type '{data_type}' for field '{column_name}'.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
    UnsupportedDataTypeError {
        column_name: String,
        data_type: String,
    },

    /// A decimal/numeric type string had no parseable `(precision,scale)`.
    #[snafu(display("Unable to extract precision and scale from type: {data_type}.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
    UnableToGetDecimalPrecisionAndScale { data_type: String },

    /// A `SHOW COLUMNS` metadata row was missing an expected column.
    #[snafu(display("Failed to find the field '{field}'.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
    MissingField { field: String },
}
45
/// A database connection backed by `mysql_async`.
pub struct MySQLConnection {
    // Async mutex so the single underlying connection can be shared across
    // concurrently running queries and streams.
    pub conn: Arc<Mutex<Conn>>,
}
49
50impl MySQLConnection {
51 fn to_mysql_quoted_string(tbl: &TableReference) -> String {
55 let q = MySqlDialect {}
56 .identifier_quote_style("") .unwrap_or_default();
58
59 [tbl.catalog(), tbl.schema(), Some(tbl.table())]
60 .into_iter()
61 .flatten()
62 .map(|part| {
63 if part.starts_with(q) && part.ends_with(q) {
64 part.to_string()
65 } else {
66 format!("{quote}{part}{quote}", quote = q)
67 }
68 })
69 .collect::<Vec<_>>()
70 .join(".")
71 }
72}
73
impl<'a> DbConnection<Conn, &'a (dyn ToValue + Sync)> for MySQLConnection {
    /// Upcast to `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// This connection is async-capable; always returns `Some(self)`.
    fn as_async(&self) -> Option<&dyn super::AsyncDbConnection<Conn, &'a (dyn ToValue + Sync)>> {
        Some(self)
    }
}
87
#[async_trait::async_trait]
impl<'a> AsyncDbConnection<Conn, &'a (dyn ToValue + Sync)> for MySQLConnection {
    /// Wraps an established connection behind an async mutex for shared use.
    fn new(conn: Conn) -> Self {
        MySQLConnection {
            conn: Arc::new(Mutex::new(conn)),
        }
    }

    /// Lists the table names in `schema` via `INFORMATION_SCHEMA.TABLES`.
    async fn tables(&self, schema: &str) -> Result<Vec<String>, super::Error> {
        let mut conn = self.conn.lock().await;
        let conn = &mut *conn;

        let query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ?";
        let tables: Vec<Row> = conn
            .exec(query, (schema,))
            .await
            .boxed()
            .context(super::UnableToGetTablesSnafu)?;

        // Rows with no TABLE_NAME value are silently skipped.
        let table_names = tables
            .iter()
            .filter_map(|row| row.get::<String, _>("TABLE_NAME"))
            .collect();

        Ok(table_names)
    }

    /// Lists user schema (database) names, excluding MySQL's built-in
    /// system schemas.
    async fn schemas(&self) -> Result<Vec<String>, super::Error> {
        let mut conn = self.conn.lock().await;
        let conn = &mut *conn;

        let query = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA \
                 WHERE SCHEMA_NAME NOT IN ('information_schema', 'mysql', \
                 'performance_schema', 'sys')";

        let schemas: Vec<Row> = conn
            .exec(query, ())
            .await
            .boxed()
            .context(super::UnableToGetSchemasSnafu)?;

        let schema_names = schemas
            .iter()
            .filter_map(|row| row.get::<String, _>("SCHEMA_NAME"))
            .collect();

        Ok(schema_names)
    }

    /// Builds the Arrow schema of `table_reference` from `SHOW COLUMNS`.
    ///
    /// Server error 1146 (ER_NO_SUCH_TABLE) is mapped to
    /// `super::Error::UndefinedTable`; any other failure becomes
    /// `super::Error::UnableToGetSchema`.
    async fn get_schema(
        &self,
        table_reference: &TableReference,
    ) -> Result<SchemaRef, super::Error> {
        let mut conn = self.conn.lock().await;
        let conn = &mut *conn;

        let columns_meta: Vec<Row> = match conn
            .exec(
                format!(
                    "SHOW COLUMNS FROM {table_name}",
                    table_name = Self::to_mysql_quoted_string(table_reference)
                ),
                Params::Empty,
            )
            .await
        {
            Ok(columns_meta) => columns_meta,
            Err(e) => match e {
                mysql_async::Error::Server(server_error) => {
                    // 1146: table does not exist — surface a dedicated error.
                    if server_error.code == 1146 {
                        return Err(super::Error::UndefinedTable {
                            source: Box::new(server_error.clone()),
                            table_name: table_reference.to_string(),
                        });
                    }
                    return Err(super::Error::UnableToGetSchema {
                        source: Box::new(mysql_async::Error::Server(server_error)),
                    });
                }
                _ => {
                    return Err(super::Error::UnableToGetSchema {
                        source: Box::new(e),
                    })
                }
            },
        };

        Ok(columns_meta_to_schema(columns_meta).context(super::UnableToGetSchemaSnafu)?)
    }

    /// Executes `sql` and streams the result set as Arrow record batches.
    ///
    /// Rows are fetched in chunks of 4000 and converted with `rows_to_arrow`.
    /// The first batch is awaited eagerly to learn the result schema; an
    /// empty result yields an empty stream with an empty schema.
    async fn query_arrow(
        &self,
        sql: &str,
        params: &[&'a (dyn ToValue + Sync)],
        projected_schema: Option<SchemaRef>,
    ) -> Result<SendableRecordBatchStream> {
        // Convert params to owned values so the stream below can be 'static.
        let params_vec: Vec<_> = params.iter().map(|&p| p.to_value()).collect();
        // Strip double-quote identifier quoting (MySQL uses backticks).
        // NOTE(review): this removes EVERY '"', including any inside string
        // literals — confirm upstream SQL never carries quoted string data.
        let sql = sql.replace('"', "");

        let conn = Arc::clone(&self.conn);

        let mut stream = Box::pin(stream! {
            // The connection mutex is held for the lifetime of this stream.
            let mut conn = conn.lock().await;
            let mut exec_iter = conn
                .exec_iter(sql, Params::from(params_vec))
                .await
                .context(QuerySnafu)?;

            let Some(stream) = exec_iter.stream::<Row>().await.context(QuerySnafu)? else {
                yield Err(Error::QueryResultStreamError {});
                return;
            };

            // Batch rows so each Arrow conversion amortizes its overhead.
            let mut chunked_stream = stream.chunks(4_000).boxed();

            while let Some(chunk) = chunked_stream.next().await {
                let rows = chunk
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()
                    .context(QuerySnafu)?;

                let rec = rows_to_arrow(&rows, &projected_schema).context(ConversionSnafu)?;
                yield Ok::<_, Error>(rec)
            }
        });

        // Peek the first batch to derive the schema of the adapter stream.
        let Some(first_chunk) = stream.next().await else {
            return Ok(Box::pin(RecordBatchStreamAdapter::new(
                Arc::new(Schema::empty()),
                stream::empty(),
            )));
        };

        let first_chunk = first_chunk?;
        let schema = first_chunk.schema();

        // Re-emit the peeked batch, then forward the rest of the stream.
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, {
            stream! {
                yield Ok(first_chunk);
                while let Some(batch) = stream.next().await {
                    yield batch
                        .map_err(|e| DataFusionError::Execution(format!("Failed to fetch batch: {e}")))
                }
            }
        })))
    }

    /// Executes a statement and returns the number of affected rows.
    async fn execute(&self, query: &str, params: &[&'a (dyn ToValue + Sync)]) -> Result<u64> {
        let mut conn = self.conn.lock().await;
        let conn = &mut *conn;
        let params_vec: Vec<_> = params.iter().map(|&p| p.to_value()).collect();
        // The result rows (if any) are discarded; only the count matters.
        let _: Vec<Row> = conn
            .exec(query, Params::from(params_vec))
            .await
            .context(QuerySnafu)?;
        return Ok(conn.affected_rows());
    }
}
249
250fn columns_meta_to_schema(columns_meta: Vec<Row>) -> Result<SchemaRef> {
251 let mut fields = Vec::new();
252
253 for row in columns_meta.iter() {
254 let column_name: String = row.get("Field").ok_or(Error::MissingField {
255 field: "Field".to_string(),
256 })?;
257
258 let data_type: String = row.get("Type").ok_or(Error::MissingField {
259 field: "Type".to_string(),
260 })?;
261
262 let column_type = map_str_type_to_column_type(&column_name, &data_type)?;
263 let column_is_binary = map_str_type_to_is_binary(&data_type);
264 let column_is_enum = map_str_type_to_is_enum(&data_type);
265 let column_use_large_str_or_blob = map_str_type_to_use_large_str_or_blob(&data_type);
266
267 let (precision, scale) = match column_type {
268 ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => {
269 let (precision, scale) = extract_decimal_precision_and_scale(&data_type)
270 .context(super::UnableToGetSchemaSnafu)?;
271 (Some(precision), Some(scale))
272 }
273 _ => (None, None),
274 };
275
276 let arrow_data_type = map_column_to_data_type(
277 column_type,
278 column_is_binary,
279 column_is_enum,
280 column_use_large_str_or_blob,
281 precision,
282 scale,
283 )
284 .context(UnsupportedDataTypeSnafu {
285 column_name: column_name.clone(),
286 data_type,
287 })?;
288
289 fields.push(Field::new(&column_name, arrow_data_type, true));
290 }
291 Ok(Arc::new(Schema::new(fields)))
292}
293
294fn map_str_type_to_column_type(column_name: &str, data_type: &str) -> Result<ColumnType> {
295 let data_type = data_type.to_lowercase();
296 let column_type = match data_type.as_str() {
297 _ if data_type.starts_with("decimal") || data_type.starts_with("numeric") => {
298 ColumnType::MYSQL_TYPE_DECIMAL
299 }
300 _ if data_type.starts_with("tinyint") => ColumnType::MYSQL_TYPE_TINY,
302 _ if data_type.starts_with("smallint") => ColumnType::MYSQL_TYPE_SHORT,
303 _ if data_type.starts_with("int") => ColumnType::MYSQL_TYPE_LONG,
304 _ if data_type.starts_with("bigint") => ColumnType::MYSQL_TYPE_LONGLONG,
305 _ if data_type.starts_with("mediumint") => ColumnType::MYSQL_TYPE_INT24,
306 _ if data_type.starts_with("float") => ColumnType::MYSQL_TYPE_FLOAT,
307 _ if data_type.starts_with("double") => ColumnType::MYSQL_TYPE_DOUBLE,
308 _ if data_type.eq("null") => ColumnType::MYSQL_TYPE_NULL,
309 _ if data_type.starts_with("timestamp") => ColumnType::MYSQL_TYPE_TIMESTAMP,
310 _ if data_type.starts_with("time") => ColumnType::MYSQL_TYPE_TIME,
311 _ if data_type.starts_with("datetime") => ColumnType::MYSQL_TYPE_DATETIME,
312 _ if data_type.eq("date") => ColumnType::MYSQL_TYPE_DATE,
313 _ if data_type.eq("year") => ColumnType::MYSQL_TYPE_YEAR,
314 _ if data_type.eq("newdate") => ColumnType::MYSQL_TYPE_NEWDATE,
315 _ if data_type.starts_with("bit") => ColumnType::MYSQL_TYPE_BIT,
316 _ if data_type.starts_with("array") => ColumnType::MYSQL_TYPE_TYPED_ARRAY,
317 _ if data_type.starts_with("json") => ColumnType::MYSQL_TYPE_JSON,
318 _ if data_type.starts_with("newdecimal") => ColumnType::MYSQL_TYPE_NEWDECIMAL,
319 _ if data_type.starts_with("enum") => ColumnType::MYSQL_TYPE_STRING,
321 _ if data_type.starts_with("set") => ColumnType::MYSQL_TYPE_STRING,
322 _ if data_type.starts_with("tinyblob") => ColumnType::MYSQL_TYPE_BLOB,
323 _ if data_type.starts_with("tinytext") => ColumnType::MYSQL_TYPE_BLOB,
324 _ if data_type.starts_with("mediumblob") => ColumnType::MYSQL_TYPE_BLOB,
325 _ if data_type.starts_with("mediumtext") => ColumnType::MYSQL_TYPE_BLOB,
326 _ if data_type.starts_with("longblob") => ColumnType::MYSQL_TYPE_BLOB,
327 _ if data_type.starts_with("longtext") => ColumnType::MYSQL_TYPE_BLOB,
328 _ if data_type.starts_with("blob") => ColumnType::MYSQL_TYPE_BLOB,
329 _ if data_type.starts_with("text") => ColumnType::MYSQL_TYPE_BLOB,
330 _ if data_type.starts_with("varchar") => ColumnType::MYSQL_TYPE_VAR_STRING,
331 _ if data_type.starts_with("varbinary") => ColumnType::MYSQL_TYPE_VAR_STRING,
332 _ if data_type.starts_with("char") => ColumnType::MYSQL_TYPE_STRING,
333 _ if data_type.starts_with("binary") => ColumnType::MYSQL_TYPE_STRING,
334 _ if data_type.starts_with("geometry") => ColumnType::MYSQL_TYPE_GEOMETRY,
335 _ => UnsupportedDataTypeSnafu {
336 column_name,
337 data_type,
338 }
339 .fail()?,
340 };
341
342 Ok(column_type)
343}
344
/// Returns true when the MySQL type string denotes a byte-oriented (binary)
/// type rather than a character type.
///
/// Case-insensitive, matching the lowercasing done in
/// `map_str_type_to_column_type` (the original compared case-sensitively,
/// so an uppercase `BLOB` mapped to a blob column type but was not flagged
/// binary). Uses short-circuit `||` rather than bitwise `|`.
fn map_str_type_to_is_binary(data_type: &str) -> bool {
    let data_type = data_type.to_lowercase();
    data_type.starts_with("binary")
        || data_type.starts_with("varbinary")
        || data_type.starts_with("tinyblob")
        || data_type.starts_with("mediumblob")
        || data_type.starts_with("blob")
        || data_type.starts_with("longblob")
}
357
/// Returns true when the MySQL type is a `long*` variant (`longtext`,
/// `longblob`), which should map to Arrow's large string/binary types.
///
/// Case-insensitive, matching the lowercasing done in
/// `map_str_type_to_column_type`.
fn map_str_type_to_use_large_str_or_blob(data_type: &str) -> bool {
    data_type.to_lowercase().starts_with("long")
}
364
/// Returns true when the MySQL type string is an `enum(...)` declaration.
///
/// Case-insensitive, matching the lowercasing done in
/// `map_str_type_to_column_type`.
fn map_str_type_to_is_enum(data_type: &str) -> bool {
    data_type.to_lowercase().starts_with("enum")
}
371
372fn extract_decimal_precision_and_scale(data_type: &str) -> Result<(u8, i8)> {
373 let (start, end) = match (data_type.find('('), data_type.find(')')) {
374 (Some(start), Some(end)) => (start, end),
375 _ => UnableToGetDecimalPrecisionAndScaleSnafu { data_type }.fail()?,
376 };
377 let parts: Vec<&str> = data_type[start + 1..end].split(',').collect();
378 if parts.len() != 2 {
379 UnableToGetDecimalPrecisionAndScaleSnafu { data_type }.fail()?;
380 }
381
382 let precision =
383 parts[0]
384 .parse::<u8>()
385 .map_err(|_| Error::UnableToGetDecimalPrecisionAndScale {
386 data_type: data_type.to_string(),
387 })?;
388 let scale = parts[1]
389 .parse::<i8>()
390 .map_err(|_| Error::UnableToGetDecimalPrecisionAndScale {
391 data_type: data_type.to_string(),
392 })?;
393
394 Ok((precision, scale))
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Happy-path coverage: precision/scale extraction for well-formed
    /// decimal and numeric declarations, upper- and lower-case.
    #[test]
    fn test_extract_decimal_precision_and_scale() {
        let test_cases = vec![
            ("decimal(10,2)", 10, 2),
            ("DECIMAL(5,3)", 5, 3),
            ("numeric(12,4)", 12, 4),
            ("NUMERIC(8,6)", 8, 6),
            ("decimal(38,0)", 38, 0),
        ];

        for (data_type, expected_precision, expected_scale) in test_cases {
            let (precision, scale) = extract_decimal_precision_and_scale(data_type)
                .expect("Should extract precision and scale");
            assert_eq!(
                precision, expected_precision,
                "Incorrect precision for: {}",
                data_type
            );
            assert_eq!(scale, expected_scale, "Incorrect scale for: {}", data_type);
        }
    }

    /// Negative-path coverage: malformed declarations must error, not panic.
    #[test]
    fn test_extract_decimal_precision_and_scale_rejects_malformed_input() {
        // No parenthesized arguments at all.
        assert!(extract_decimal_precision_and_scale("decimal").is_err());
        // Wrong number of comma-separated parts.
        assert!(extract_decimal_precision_and_scale("decimal(10)").is_err());
        assert!(extract_decimal_precision_and_scale("decimal(1,2,3)").is_err());
        // Non-numeric precision.
        assert!(extract_decimal_precision_and_scale("decimal(abc,2)").is_err());
        // Precision exceeding u8::MAX.
        assert!(extract_decimal_precision_and_scale("decimal(300,2)").is_err());
    }
}