datafusion_table_providers/sql/db_connection_pool/dbconnection/mysqlconn.rs

1use std::{any::Any, sync::Arc};
2
3use crate::sql::arrow_sql_gen::mysql::map_column_to_data_type;
4use crate::sql::arrow_sql_gen::{self, mysql::rows_to_arrow};
5use async_stream::stream;
6use datafusion::arrow::datatypes::{Field, Schema, SchemaRef};
7use datafusion::error::DataFusionError;
8use datafusion::execution::SendableRecordBatchStream;
9use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
10use datafusion::sql::unparser::dialect::{Dialect, MySqlDialect};
11use datafusion::sql::TableReference;
12use futures::lock::Mutex;
13use futures::{stream, StreamExt};
14use mysql_async::consts::ColumnType;
15use mysql_async::prelude::Queryable;
16use mysql_async::{prelude::ToValue, Conn, Params, Row};
17use snafu::prelude::*;
18
19use super::Result;
20use super::{AsyncDbConnection, DbConnection};
21
22#[derive(Debug, Snafu)]
23pub enum Error {
24    #[snafu(display("Query execution failed.\n{source}\nFor details, refer to the MySQL manual: https://dev.mysql.com/doc/mysql-errors/9.1/en/error-reference-introduction.html"))]
25    QueryError { source: mysql_async::Error },
26
27    #[snafu(display("Failed to convert query result to Arrow.\n{source}.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
28    ConversionError { source: arrow_sql_gen::mysql::Error },
29
30    #[snafu(display("An unexpected error occurred. Verify the configuration and try again."))]
31    QueryResultStreamError {},
32
33    #[snafu(display("Unsupported data type '{data_type}' for field '{column_name}'.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
34    UnsupportedDataTypeError {
35        column_name: String,
36        data_type: String,
37    },
38
39    #[snafu(display("Unable to extract precision and scale from type: {data_type}.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
40    UnableToGetDecimalPrecisionAndScale { data_type: String },
41
42    #[snafu(display("Failed to find the field '{field}'.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
43    MissingField { field: String },
44}
45
46pub struct MySQLConnection {
47    pub conn: Arc<Mutex<Conn>>,
48}
49
50impl MySQLConnection {
51    /// Create a [`TableReference`] in a manner that properly handles the unique quote style of MySQL.
52    ///
53    /// [`TableReference::from`] uses `DefaultDialect` and therefore gets quoting incorrect.
54    fn to_mysql_quoted_string(tbl: &TableReference) -> String {
55        let q = MySqlDialect {}
56            .identifier_quote_style("") // parameter unimportant for `MySqlDialect`.
57            .unwrap_or_default();
58
59        [tbl.catalog(), tbl.schema(), Some(tbl.table())]
60            .into_iter()
61            .flatten()
62            .map(|part| {
63                if part.starts_with(q) && part.ends_with(q) {
64                    part.to_string()
65                } else {
66                    format!("{quote}{part}{quote}", quote = q)
67                }
68            })
69            .collect::<Vec<_>>()
70            .join(".")
71    }
72}
73
74impl<'a> DbConnection<Conn, &'a (dyn ToValue + Sync)> for MySQLConnection {
75    fn as_any(&self) -> &dyn Any {
76        self
77    }
78
79    fn as_any_mut(&mut self) -> &mut dyn Any {
80        self
81    }
82
83    fn as_async(&self) -> Option<&dyn super::AsyncDbConnection<Conn, &'a (dyn ToValue + Sync)>> {
84        Some(self)
85    }
86}
87
88#[async_trait::async_trait]
89impl<'a> AsyncDbConnection<Conn, &'a (dyn ToValue + Sync)> for MySQLConnection {
90    fn new(conn: Conn) -> Self {
91        MySQLConnection {
92            conn: Arc::new(Mutex::new(conn)),
93        }
94    }
95
96    async fn tables(&self, schema: &str) -> Result<Vec<String>, super::Error> {
97        let mut conn = self.conn.lock().await;
98        let conn = &mut *conn;
99
100        let query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ?";
101        let tables: Vec<Row> = conn
102            .exec(query, (schema,))
103            .await
104            .boxed()
105            .context(super::UnableToGetTablesSnafu)?;
106
107        let table_names = tables
108            .iter()
109            .filter_map(|row| row.get::<String, _>("TABLE_NAME"))
110            .collect();
111
112        Ok(table_names)
113    }
114
115    async fn schemas(&self) -> Result<Vec<String>, super::Error> {
116        let mut conn = self.conn.lock().await;
117        let conn = &mut *conn;
118
119        let query = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA \
120                    WHERE SCHEMA_NAME NOT IN ('information_schema', 'mysql', \
121                    'performance_schema', 'sys')";
122
123        let schemas: Vec<Row> = conn
124            .exec(query, ())
125            .await
126            .boxed()
127            .context(super::UnableToGetSchemasSnafu)?;
128
129        let schema_names = schemas
130            .iter()
131            .filter_map(|row| row.get::<String, _>("SCHEMA_NAME"))
132            .collect();
133
134        Ok(schema_names)
135    }
136
137    async fn get_schema(
138        &self,
139        table_reference: &TableReference,
140    ) -> Result<SchemaRef, super::Error> {
141        let mut conn = self.conn.lock().await;
142        let conn = &mut *conn;
143
144        // we don't use SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{}' AND TABLE_SCHEMA = '{}'
145        // as table_reference don't always have schema specified so we need to extract schema/db name from connection properties
146        // to ensure we are querying information for correct table
147        let columns_meta: Vec<Row> = match conn
148            .exec(
149                format!(
150                    "SHOW COLUMNS FROM {table_name}",
151                    table_name = Self::to_mysql_quoted_string(table_reference)
152                ),
153                Params::Empty,
154            )
155            .await
156        {
157            Ok(columns_meta) => columns_meta,
158            Err(e) => match e {
159                mysql_async::Error::Server(server_error) => {
160                    if server_error.code == 1146 {
161                        return Err(super::Error::UndefinedTable {
162                            source: Box::new(server_error.clone()),
163                            table_name: table_reference.to_string(),
164                        });
165                    }
166                    return Err(super::Error::UnableToGetSchema {
167                        source: Box::new(mysql_async::Error::Server(server_error)),
168                    });
169                }
170                _ => {
171                    return Err(super::Error::UnableToGetSchema {
172                        source: Box::new(e),
173                    })
174                }
175            },
176        };
177
178        Ok(columns_meta_to_schema(columns_meta).context(super::UnableToGetSchemaSnafu)?)
179    }
180
181    async fn query_arrow(
182        &self,
183        sql: &str,
184        params: &[&'a (dyn ToValue + Sync)],
185        projected_schema: Option<SchemaRef>,
186    ) -> Result<SendableRecordBatchStream> {
187        let params_vec: Vec<_> = params.iter().map(|&p| p.to_value()).collect();
188        let sql = sql.replace('"', "");
189
190        let conn = Arc::clone(&self.conn);
191
192        let mut stream = Box::pin(stream! {
193            let mut conn = conn.lock().await;
194            let mut exec_iter = conn
195                .exec_iter(sql, Params::from(params_vec))
196                .await
197                .context(QuerySnafu)?;
198
199            let Some(stream) = exec_iter.stream::<Row>().await.context(QuerySnafu)? else {
200                yield Err(Error::QueryResultStreamError {});
201                return;
202            };
203
204            let mut chunked_stream = stream.chunks(4_000).boxed();
205
206            while let Some(chunk) = chunked_stream.next().await {
207                let rows = chunk
208                    .into_iter()
209                    .collect::<Result<Vec<_>, _>>()
210                    .context(QuerySnafu)?;
211
212                let rec = rows_to_arrow(&rows, &projected_schema).context(ConversionSnafu)?;
213                yield Ok::<_, Error>(rec)
214            }
215        });
216
217        let Some(first_chunk) = stream.next().await else {
218            return Ok(Box::pin(RecordBatchStreamAdapter::new(
219                Arc::new(Schema::empty()),
220                stream::empty(),
221            )));
222        };
223
224        let first_chunk = first_chunk?;
225        let schema = first_chunk.schema();
226
227        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, {
228            stream! {
229                yield Ok(first_chunk);
230                while let Some(batch) = stream.next().await {
231                    yield batch
232                        .map_err(|e| DataFusionError::Execution(format!("Failed to fetch batch: {e}")))
233                }
234            }
235        })))
236    }
237
238    async fn execute(&self, query: &str, params: &[&'a (dyn ToValue + Sync)]) -> Result<u64> {
239        let mut conn = self.conn.lock().await;
240        let conn = &mut *conn;
241        let params_vec: Vec<_> = params.iter().map(|&p| p.to_value()).collect();
242        let _: Vec<Row> = conn
243            .exec(query, Params::from(params_vec))
244            .await
245            .context(QuerySnafu)?;
246        return Ok(conn.affected_rows());
247    }
248}
249
250fn columns_meta_to_schema(columns_meta: Vec<Row>) -> Result<SchemaRef> {
251    let mut fields = Vec::new();
252
253    for row in columns_meta.iter() {
254        let column_name: String = row.get("Field").ok_or(Error::MissingField {
255            field: "Field".to_string(),
256        })?;
257
258        let data_type: String = row.get("Type").ok_or(Error::MissingField {
259            field: "Type".to_string(),
260        })?;
261
262        let column_type = map_str_type_to_column_type(&column_name, &data_type)?;
263        let column_is_binary = map_str_type_to_is_binary(&data_type);
264        let column_is_enum = map_str_type_to_is_enum(&data_type);
265        let column_use_large_str_or_blob = map_str_type_to_use_large_str_or_blob(&data_type);
266
267        let (precision, scale) = match column_type {
268            ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => {
269                let (precision, scale) = extract_decimal_precision_and_scale(&data_type)
270                    .context(super::UnableToGetSchemaSnafu)?;
271                (Some(precision), Some(scale))
272            }
273            _ => (None, None),
274        };
275
276        let arrow_data_type = map_column_to_data_type(
277            column_type,
278            column_is_binary,
279            column_is_enum,
280            column_use_large_str_or_blob,
281            precision,
282            scale,
283        )
284        .context(UnsupportedDataTypeSnafu {
285            column_name: column_name.clone(),
286            data_type,
287        })?;
288
289        fields.push(Field::new(&column_name, arrow_data_type, true));
290    }
291    Ok(Arc::new(Schema::new(fields)))
292}
293
294fn map_str_type_to_column_type(column_name: &str, data_type: &str) -> Result<ColumnType> {
295    let data_type = data_type.to_lowercase();
296    let column_type = match data_type.as_str() {
297        _ if data_type.starts_with("decimal") || data_type.starts_with("numeric") => {
298            ColumnType::MYSQL_TYPE_DECIMAL
299        }
300        // most types can have addtional information: unsigned, size, etc so we use starts_with
301        _ if data_type.starts_with("tinyint") => ColumnType::MYSQL_TYPE_TINY,
302        _ if data_type.starts_with("smallint") => ColumnType::MYSQL_TYPE_SHORT,
303        _ if data_type.starts_with("int") => ColumnType::MYSQL_TYPE_LONG,
304        _ if data_type.starts_with("bigint") => ColumnType::MYSQL_TYPE_LONGLONG,
305        _ if data_type.starts_with("mediumint") => ColumnType::MYSQL_TYPE_INT24,
306        _ if data_type.starts_with("float") => ColumnType::MYSQL_TYPE_FLOAT,
307        _ if data_type.starts_with("double") => ColumnType::MYSQL_TYPE_DOUBLE,
308        _ if data_type.eq("null") => ColumnType::MYSQL_TYPE_NULL,
309        _ if data_type.starts_with("timestamp") => ColumnType::MYSQL_TYPE_TIMESTAMP,
310        _ if data_type.starts_with("time") => ColumnType::MYSQL_TYPE_TIME,
311        _ if data_type.starts_with("datetime") => ColumnType::MYSQL_TYPE_DATETIME,
312        _ if data_type.eq("date") => ColumnType::MYSQL_TYPE_DATE,
313        _ if data_type.eq("year") => ColumnType::MYSQL_TYPE_YEAR,
314        _ if data_type.eq("newdate") => ColumnType::MYSQL_TYPE_NEWDATE,
315        _ if data_type.starts_with("bit") => ColumnType::MYSQL_TYPE_BIT,
316        _ if data_type.starts_with("array") => ColumnType::MYSQL_TYPE_TYPED_ARRAY,
317        _ if data_type.starts_with("json") => ColumnType::MYSQL_TYPE_JSON,
318        _ if data_type.starts_with("newdecimal") => ColumnType::MYSQL_TYPE_NEWDECIMAL,
319        // MySQL ENUM & SET value is exported as MYSQL_TYPE_STRING under c api: https://dev.mysql.com/doc/c-api/9.0/en/c-api-data-structures.html
320        _ if data_type.starts_with("enum") => ColumnType::MYSQL_TYPE_STRING,
321        _ if data_type.starts_with("set") => ColumnType::MYSQL_TYPE_STRING,
322        _ if data_type.starts_with("tinyblob") => ColumnType::MYSQL_TYPE_BLOB,
323        _ if data_type.starts_with("tinytext") => ColumnType::MYSQL_TYPE_BLOB,
324        _ if data_type.starts_with("mediumblob") => ColumnType::MYSQL_TYPE_BLOB,
325        _ if data_type.starts_with("mediumtext") => ColumnType::MYSQL_TYPE_BLOB,
326        _ if data_type.starts_with("longblob") => ColumnType::MYSQL_TYPE_BLOB,
327        _ if data_type.starts_with("longtext") => ColumnType::MYSQL_TYPE_BLOB,
328        _ if data_type.starts_with("blob") => ColumnType::MYSQL_TYPE_BLOB,
329        _ if data_type.starts_with("text") => ColumnType::MYSQL_TYPE_BLOB,
330        _ if data_type.starts_with("varchar") => ColumnType::MYSQL_TYPE_VAR_STRING,
331        _ if data_type.starts_with("varbinary") => ColumnType::MYSQL_TYPE_VAR_STRING,
332        _ if data_type.starts_with("char") => ColumnType::MYSQL_TYPE_STRING,
333        _ if data_type.starts_with("binary") => ColumnType::MYSQL_TYPE_STRING,
334        _ if data_type.starts_with("geometry") => ColumnType::MYSQL_TYPE_GEOMETRY,
335        _ => UnsupportedDataTypeSnafu {
336            column_name,
337            data_type,
338        }
339        .fail()?,
340    };
341
342    Ok(column_type)
343}
344
/// Report whether a MySQL column type string denotes binary (byte) data.
///
/// Recognises `binary`, `varbinary` and the `*blob` family by prefix, so suffixes such as the
/// length in `varbinary(16)` are accepted. Matching is ASCII case-insensitive, the same as the
/// type-string matching in `map_str_type_to_column_type`.
fn map_str_type_to_is_binary(data_type: &str) -> bool {
    const BINARY_TYPE_PREFIXES: [&str; 6] = [
        "binary",
        "varbinary",
        "tinyblob",
        "mediumblob",
        "blob",
        "longblob",
    ];

    let data_type = data_type.to_ascii_lowercase();
    BINARY_TYPE_PREFIXES
        .iter()
        .any(|&prefix| data_type.starts_with(prefix))
}
357
/// Report whether a MySQL column type string is one of the `long*` types (`longtext`,
/// `longblob`), which should map to large Arrow string/binary types.
///
/// Matching is ASCII case-insensitive, the same as the type-string matching in
/// `map_str_type_to_column_type`.
fn map_str_type_to_use_large_str_or_blob(data_type: &str) -> bool {
    data_type.to_ascii_lowercase().starts_with("long")
}
364
/// Report whether a MySQL column type string declares an `ENUM` column, e.g. `enum('a','b')`.
///
/// `SET` columns are not treated as enums. Matching is ASCII case-insensitive, the same as the
/// type-string matching in `map_str_type_to_column_type`.
fn map_str_type_to_is_enum(data_type: &str) -> bool {
    data_type.to_ascii_lowercase().starts_with("enum")
}
371
372fn extract_decimal_precision_and_scale(data_type: &str) -> Result<(u8, i8)> {
373    let (start, end) = match (data_type.find('('), data_type.find(')')) {
374        (Some(start), Some(end)) => (start, end),
375        _ => UnableToGetDecimalPrecisionAndScaleSnafu { data_type }.fail()?,
376    };
377    let parts: Vec<&str> = data_type[start + 1..end].split(',').collect();
378    if parts.len() != 2 {
379        UnableToGetDecimalPrecisionAndScaleSnafu { data_type }.fail()?;
380    }
381
382    let precision =
383        parts[0]
384            .parse::<u8>()
385            .map_err(|_| Error::UnableToGetDecimalPrecisionAndScale {
386                data_type: data_type.to_string(),
387            })?;
388    let scale = parts[1]
389        .parse::<i8>()
390        .map_err(|_| Error::UnableToGetDecimalPrecisionAndScale {
391            data_type: data_type.to_string(),
392        })?;
393
394    Ok((precision, scale))
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_decimal_precision_and_scale() {
        let test_cases = vec![
            ("decimal(10,2)", 10, 2),
            ("DECIMAL(5,3)", 5, 3),
            ("numeric(12,4)", 12, 4),
            ("NUMERIC(8,6)", 8, 6),
            ("decimal(38,0)", 38, 0),
            // Attributes after the argument list are ignored.
            ("decimal(10,2) unsigned", 10, 2),
            ("decimal(65,30) unsigned zerofill", 65, 30),
        ];

        for (data_type, expected_precision, expected_scale) in test_cases {
            let (precision, scale) = extract_decimal_precision_and_scale(data_type)
                .expect("Should extract precision and scale");
            assert_eq!(
                precision, expected_precision,
                "Incorrect precision for: {}",
                data_type
            );
            assert_eq!(scale, expected_scale, "Incorrect scale for: {}", data_type);
        }
    }

    #[test]
    fn test_extract_decimal_precision_and_scale_rejects_malformed_input() {
        for data_type in [
            "decimal",
            "decimal(",
            "decimal(a,b)",
            "decimal(10,2,3)",
            "decimal(300,2)",
        ] {
            assert!(
                extract_decimal_precision_and_scale(data_type).is_err(),
                "Expected an error for: {}",
                data_type
            );
        }
    }

    #[test]
    fn test_map_str_type_to_column_type() {
        assert!(matches!(
            map_str_type_to_column_type("c", "tinyint(1)"),
            Ok(ColumnType::MYSQL_TYPE_TINY)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "int(11) unsigned"),
            Ok(ColumnType::MYSQL_TYPE_LONG)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "bigint(20) unsigned"),
            Ok(ColumnType::MYSQL_TYPE_LONGLONG)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "timestamp"),
            Ok(ColumnType::MYSQL_TYPE_TIMESTAMP)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "time(3)"),
            Ok(ColumnType::MYSQL_TYPE_TIME)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "datetime(6)"),
            Ok(ColumnType::MYSQL_TYPE_DATETIME)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "date"),
            Ok(ColumnType::MYSQL_TYPE_DATE)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "VARCHAR(255)"),
            Ok(ColumnType::MYSQL_TYPE_VAR_STRING)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "decimal(10,2)"),
            Ok(ColumnType::MYSQL_TYPE_DECIMAL)
        ));
        assert!(matches!(
            map_str_type_to_column_type("c", "geometry"),
            Ok(ColumnType::MYSQL_TYPE_GEOMETRY)
        ));
        assert!(map_str_type_to_column_type("c", "not_a_type").is_err());
    }

    #[test]
    fn test_type_flag_helpers() {
        assert!(map_str_type_to_is_binary("varbinary(16)"));
        assert!(map_str_type_to_is_binary("longblob"));
        assert!(!map_str_type_to_is_binary("varchar(16)"));

        assert!(map_str_type_to_is_enum("enum('a','b')"));
        assert!(!map_str_type_to_is_enum("set('a')"));

        assert!(map_str_type_to_use_large_str_or_blob("longtext"));
        assert!(!map_str_type_to_use_large_str_or_blob("mediumtext"));
    }
}