datafusion_table_providers/
mysql.rs

1/*
2Copyright 2024 The Spice.ai OSS Authors
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8     https://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16use crate::mysql::write::MySQLTableWriter;
17use crate::sql::arrow_sql_gen::statement::{CreateTableBuilder, IndexBuilder, InsertBuilder};
18use crate::sql::db_connection_pool::dbconnection::mysqlconn::MySQLConnection;
19use crate::sql::db_connection_pool::dbconnection::DbConnection;
20use crate::sql::db_connection_pool::mysqlpool::MySQLConnectionPool;
21use crate::sql::db_connection_pool::{self, mysqlpool, DbConnectionPool};
22use crate::sql::sql_provider_datafusion::{self, SqlTable};
23use crate::util::{
24    self, column_reference::ColumnReference, constraints::get_primary_keys_from_constraints,
25    indexes::IndexType, on_conflict::OnConflict, secrets::to_secret_map, to_datafusion_error,
26};
27use crate::util::{column_reference, constraints, on_conflict};
28use async_trait::async_trait;
29use datafusion::arrow::array::RecordBatch;
30use datafusion::arrow::datatypes::{Schema, SchemaRef};
31use datafusion::catalog::Session;
32use datafusion::sql::unparser::dialect::MySqlDialect;
33use datafusion::{
34    catalog::TableProviderFactory, common::Constraints, datasource::TableProvider,
35    error::DataFusionError, logical_expr::CreateExternalTable, sql::TableReference,
36};
37use mysql_async::prelude::{Queryable, ToValue};
38use mysql_async::{Metrics, TxOpts};
39use sea_query::{Alias, DeleteStatement, MysqlQueryBuilder};
40use snafu::prelude::*;
41use sql_table::MySQLTable;
42use std::collections::HashMap;
43use std::sync::Arc;
44
/// Trait-object alias for a [`DbConnectionPool`] parameterised over
/// `mysql_async::Conn` and `&'static (dyn ToValue + Sync)`, usable across threads
/// (`Send + Sync`).
pub type DynMySQLConnectionPool =
    dyn DbConnectionPool<mysql_async::Conn, &'static (dyn ToValue + Sync)> + Send + Sync;

/// Trait-object alias for a [`DbConnection`] parameterised over
/// `mysql_async::Conn` and `&'static (dyn ToValue + Sync)`.
pub type DynMySQLConnection = dyn DbConnection<mysql_async::Conn, &'static (dyn ToValue + Sync)>;

// The `federation` module is only compiled when the `mysql-federation` feature is enabled.
#[cfg(feature = "mysql-federation")]
pub mod federation;
pub(crate) mod mysql_window;
pub mod sql_table;
pub mod write;
55
/// Errors raised by the MySQL table provider code in this module.
///
/// Each variant's user-facing message is given by its `snafu(display(...))`
/// attribute; variants with a `source` field wrap the underlying error.
#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("DbConnectionError: {source}"))]
    DbConnectionError {
        source: db_connection_pool::dbconnection::GenericError,
    },

    #[snafu(display("Unable to construct SQL table: {source}"))]
    UnableToConstructSQLTable {
        source: sql_provider_datafusion::Error,
    },

    #[snafu(display("Unable to delete all data from the MySQL table: {source}"))]
    UnableToDeleteAllTableData { source: mysql_async::Error },

    #[snafu(display("Unable to insert Arrow batch to MySQL table: {source}"))]
    UnableToInsertArrowBatch { source: mysql_async::Error },

    // Raised by `MySQL::mysql_conn` when the boxed connection is not a `MySQLConnection`.
    #[snafu(display("Unable to downcast DbConnection to MySQLConnection"))]
    UnableToDowncastDbConnection {},

    #[snafu(display("Unable to begin MySQL transaction: {source}"))]
    UnableToBeginTransaction { source: mysql_async::Error },

    #[snafu(display("Unable to create MySQL connection pool: {source}"))]
    UnableToCreateMySQLConnectionPool { source: mysqlpool::Error },

    #[snafu(display("Unable to create the MySQL table: {source}"))]
    UnableToCreateMySQLTable { source: mysql_async::Error },

    #[snafu(display("Unable to create an index for the MySQL table: {source}"))]
    UnableToCreateIndexForMySQLTable { source: mysql_async::Error },

    #[snafu(display("Unable to commit the MySQL transaction: {source}"))]
    UnableToCommitMySQLTransaction { source: mysql_async::Error },

    #[snafu(display("Unable to create insertion statement for MySQL table: {source}"))]
    UnableToCreateInsertStatement {
        source: crate::sql::arrow_sql_gen::statement::Error,
    },

    // Raised by `MySQL::connect` when the existence check for the table returns false.
    #[snafu(display("The table '{table_name}' doesn't exist in the MySQL server"))]
    TableDoesntExist { table_name: String },

    #[snafu(display("Constraint Violation: {source}"))]
    ConstraintViolation { source: constraints::Error },

    // Raised while parsing the keys of the `indexes` option.
    #[snafu(display("Error parsing column reference: {source}"))]
    UnableToParseColumnReference { source: column_reference::Error },

    // Raised while parsing the `on_conflict` option.
    #[snafu(display("Error parsing on_conflict: {source}"))]
    UnableToParseOnConflict { source: on_conflict::Error },
}
109
/// Module-local `Result` alias whose error type defaults to this module's [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
111
112pub struct MySQLTableFactory {
113    pool: Arc<MySQLConnectionPool>,
114}
115
116impl MySQLTableFactory {
117    #[must_use]
118    pub fn new(pool: Arc<MySQLConnectionPool>) -> Self {
119        Self { pool }
120    }
121
122    pub async fn table_provider(
123        &self,
124        table_reference: TableReference,
125    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
126        let pool = Arc::clone(&self.pool);
127        let table_provider = Arc::new(
128            MySQLTable::new(&pool, table_reference)
129                .await
130                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?,
131        );
132
133        #[cfg(feature = "mysql-federation")]
134        let table_provider = Arc::new(
135            table_provider
136                .create_federated_table_provider()
137                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?,
138        );
139
140        Ok(table_provider)
141    }
142
143    pub async fn read_write_table_provider(
144        &self,
145        table_reference: TableReference,
146    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
147        let read_provider = Self::table_provider(self, table_reference.clone()).await?;
148        let schema = read_provider.schema();
149
150        let table_name = table_reference.to_string();
151        let mysql = MySQL::new(
152            table_name,
153            Arc::clone(&self.pool),
154            schema,
155            Constraints::empty(),
156        );
157
158        Ok(MySQLTableWriter::create(read_provider, mysql, None))
159    }
160
161    pub fn conn_pool_metrics(&self) -> Arc<Metrics> {
162        self.pool.metrics()
163    }
164}
165
/// Stateless [`TableProviderFactory`] implementation for MySQL-backed external tables.
#[derive(Debug)]
pub struct MySQLTableProviderFactory {}

impl MySQLTableProviderFactory {
    /// Creates a new factory; the type carries no fields.
    #[must_use]
    pub fn new() -> Self {
        MySQLTableProviderFactory {}
    }
}

impl Default for MySQLTableProviderFactory {
    /// Equivalent to [`MySQLTableProviderFactory::new`].
    fn default() -> Self {
        MySQLTableProviderFactory::new()
    }
}
181
182#[async_trait]
183impl TableProviderFactory for MySQLTableProviderFactory {
184    async fn create(
185        &self,
186        _state: &dyn Session,
187        cmd: &CreateExternalTable,
188    ) -> datafusion::common::Result<Arc<dyn TableProvider>> {
189        let name = cmd.name.to_string();
190        let mut options = cmd.options.clone();
191        let schema: Schema = cmd.schema.as_ref().into();
192
193        let indexes_option_str = options.remove("indexes");
194        let unparsed_indexes: HashMap<String, IndexType> = match indexes_option_str {
195            Some(indexes_str) => util::hashmap_from_option_string(&indexes_str),
196            None => HashMap::new(),
197        };
198
199        let unparsed_indexes = unparsed_indexes
200            .into_iter()
201            .map(|(key, value)| {
202                let columns = ColumnReference::try_from(key.as_str())
203                    .context(UnableToParseColumnReferenceSnafu)
204                    .map_err(util::to_datafusion_error);
205                (columns, value)
206            })
207            .collect::<Vec<(Result<ColumnReference, DataFusionError>, IndexType)>>();
208
209        let mut indexes: Vec<(ColumnReference, IndexType)> = Vec::new();
210        for (columns, index_type) in unparsed_indexes {
211            let columns = columns?;
212            indexes.push((columns, index_type));
213        }
214
215        let mut on_conflict: Option<OnConflict> = None;
216        if let Some(on_conflict_str) = options.remove("on_conflict") {
217            on_conflict = Some(
218                OnConflict::try_from(on_conflict_str.as_str())
219                    .context(UnableToParseOnConflictSnafu)
220                    .map_err(util::to_datafusion_error)?,
221            );
222        }
223
224        let params = to_secret_map(options);
225
226        let pool = Arc::new(
227            MySQLConnectionPool::new(params)
228                .await
229                .context(UnableToCreateMySQLConnectionPoolSnafu)
230                .map_err(to_datafusion_error)?,
231        );
232        let schema = Arc::new(schema);
233        let mysql = MySQL::new(
234            name.clone(),
235            Arc::clone(&pool),
236            Arc::clone(&schema),
237            cmd.constraints.clone(),
238        );
239
240        let mut db_conn = pool
241            .connect()
242            .await
243            .context(DbConnectionSnafu)
244            .map_err(to_datafusion_error)?;
245
246        let mysql_conn = MySQL::mysql_conn(&mut db_conn).map_err(to_datafusion_error)?;
247        let mut conn_guard = mysql_conn.conn.lock().await;
248        let mut transaction = conn_guard
249            .start_transaction(TxOpts::default())
250            .await
251            .context(UnableToBeginTransactionSnafu)
252            .map_err(to_datafusion_error)?;
253
254        let primary_keys = get_primary_keys_from_constraints(&cmd.constraints, &schema);
255
256        mysql
257            .create_table(Arc::clone(&schema), &mut transaction, primary_keys)
258            .await
259            .map_err(to_datafusion_error)?;
260
261        for index in indexes {
262            mysql
263                .create_index(
264                    &mut transaction,
265                    index.0.iter().collect(),
266                    index.1 == IndexType::Unique,
267                )
268                .await
269                .map_err(to_datafusion_error)?;
270        }
271
272        transaction
273            .commit()
274            .await
275            .context(UnableToCommitMySQLTransactionSnafu)
276            .map_err(to_datafusion_error)?;
277
278        drop(conn_guard);
279
280        let dyn_pool: Arc<DynMySQLConnectionPool> = pool;
281
282        let read_provider = Arc::new(
283            SqlTable::new_with_schema(
284                "mysql",
285                &dyn_pool,
286                Arc::clone(&schema),
287                TableReference::bare(name.clone()),
288            )
289            .with_dialect(Arc::new(MySqlDialect {})),
290        );
291
292        #[cfg(feature = "mysql-federation")]
293        let read_provider = Arc::new(read_provider.create_federated_table_provider()?);
294        Ok(MySQLTableWriter::create(read_provider, mysql, on_conflict))
295    }
296}
297
/// Handle for one MySQL table: its name, the pool used to reach the server, and
/// the Arrow schema and constraints it was constructed with.
#[derive(Debug)]
pub struct MySQL {
    // Table name as used in the generated SQL statements and the existence check.
    table_name: String,
    // Pool from which `connect` obtains connections.
    pool: Arc<MySQLConnectionPool>,
    // Arrow schema; consulted when building the on-conflict clause for inserts.
    schema: SchemaRef,
    // Constraints supplied at construction; exposed through `constraints()`.
    constraints: Constraints,
}
305
306impl MySQL {
307    #[must_use]
308    pub fn new(
309        table_name: String,
310        pool: Arc<MySQLConnectionPool>,
311        schema: SchemaRef,
312        constraints: Constraints,
313    ) -> Self {
314        Self {
315            table_name,
316            pool,
317            schema,
318            constraints,
319        }
320    }
321
322    #[must_use]
323    pub fn table_name(&self) -> &str {
324        &self.table_name
325    }
326
327    #[must_use]
328    pub fn constraints(&self) -> &Constraints {
329        &self.constraints
330    }
331
332    pub async fn connect(&self) -> Result<Box<DynMySQLConnection>> {
333        let mut conn = self.pool.connect().await.context(DbConnectionSnafu)?;
334
335        let mysql_conn = Self::mysql_conn(&mut conn)?;
336
337        if !self.table_exists(mysql_conn).await {
338            TableDoesntExistSnafu {
339                table_name: self.table_name.clone(),
340            }
341            .fail()?;
342        }
343
344        Ok(conn)
345    }
346
347    pub fn mysql_conn(db_connection: &mut Box<DynMySQLConnection>) -> Result<&mut MySQLConnection> {
348        let conn = db_connection
349            .as_any_mut()
350            .downcast_mut::<MySQLConnection>()
351            .context(UnableToDowncastDbConnectionSnafu)?;
352
353        Ok(conn)
354    }
355
356    async fn table_exists(&self, mysql_connection: &MySQLConnection) -> bool {
357        let sql = format!(
358            r#"SELECT EXISTS (
359          SELECT 1
360          FROM information_schema.tables
361          WHERE table_name = '{name}'
362        )"#,
363            name = self.table_name
364        );
365        tracing::trace!("{sql}");
366        let Ok(Some((exists,))) = mysql_connection
367            .conn
368            .lock()
369            .await
370            .query_first::<(bool,), _>(&sql)
371            .await
372        else {
373            return false;
374        };
375
376        exists
377    }
378
379    async fn insert_batch(
380        &self,
381        transaction: &mut mysql_async::Transaction<'_>,
382        batch: RecordBatch,
383        on_conflict: Option<OnConflict>,
384    ) -> Result<()> {
385        let insert_table_builder =
386            InsertBuilder::new(&TableReference::bare(self.table_name.clone()), vec![batch]);
387
388        let sea_query_on_conflict =
389            on_conflict.map(|oc| oc.build_sea_query_on_conflict(&self.schema));
390
391        let sql = insert_table_builder
392            .build_mysql(sea_query_on_conflict)
393            .context(UnableToCreateInsertStatementSnafu)?;
394
395        transaction
396            .exec_drop(&sql, ())
397            .await
398            .context(UnableToInsertArrowBatchSnafu)?;
399
400        Ok(())
401    }
402
403    async fn delete_all_table_data(
404        &self,
405        transaction: &mut mysql_async::Transaction<'_>,
406    ) -> Result<()> {
407        let delete = DeleteStatement::new()
408            .from_table(Alias::new(self.table_name.clone()))
409            .to_string(MysqlQueryBuilder);
410        transaction
411            .exec_drop(delete.as_str(), ())
412            .await
413            .context(UnableToDeleteAllTableDataSnafu)?;
414
415        Ok(())
416    }
417
418    async fn create_table(
419        &self,
420        schema: SchemaRef,
421        transaction: &mut mysql_async::Transaction<'_>,
422        primary_keys: Vec<String>,
423    ) -> Result<()> {
424        let create_table_statement =
425            CreateTableBuilder::new(schema, &self.table_name).primary_keys(primary_keys);
426        let create_stmts = create_table_statement.build_mysql();
427
428        transaction
429            .exec_drop(create_stmts, ())
430            .await
431            .context(UnableToCreateMySQLTableSnafu)
432    }
433
434    async fn create_index(
435        &self,
436        transaction: &mut mysql_async::Transaction<'_>,
437        columns: Vec<&str>,
438        unique: bool,
439    ) -> Result<()> {
440        let mut index_builder = IndexBuilder::new(&self.table_name, columns);
441        if unique {
442            index_builder = index_builder.unique();
443        }
444        let sql = index_builder.build_mysql();
445
446        transaction
447            .exec_drop(sql, ())
448            .await
449            .context(UnableToCreateIndexForMySQLTableSnafu)
450    }
451}