datafusion_table_providers/
mysql.rs1use crate::mysql::write::MySQLTableWriter;
17use crate::sql::arrow_sql_gen::statement::{CreateTableBuilder, IndexBuilder, InsertBuilder};
18use crate::sql::db_connection_pool::dbconnection::mysqlconn::MySQLConnection;
19use crate::sql::db_connection_pool::dbconnection::DbConnection;
20use crate::sql::db_connection_pool::mysqlpool::MySQLConnectionPool;
21use crate::sql::db_connection_pool::{self, mysqlpool, DbConnectionPool};
22use crate::sql::sql_provider_datafusion::{self, SqlTable};
23use crate::util::{
24 self, column_reference::ColumnReference, constraints::get_primary_keys_from_constraints,
25 indexes::IndexType, on_conflict::OnConflict, secrets::to_secret_map, to_datafusion_error,
26};
27use crate::util::{column_reference, constraints, on_conflict};
28use async_trait::async_trait;
29use datafusion::arrow::array::RecordBatch;
30use datafusion::arrow::datatypes::{Schema, SchemaRef};
31use datafusion::catalog::Session;
32use datafusion::sql::unparser::dialect::MySqlDialect;
33use datafusion::{
34 catalog::TableProviderFactory, common::Constraints, datasource::TableProvider,
35 error::DataFusionError, logical_expr::CreateExternalTable, sql::TableReference,
36};
37use mysql_async::prelude::{Queryable, ToValue};
38use mysql_async::{Metrics, TxOpts};
39use sea_query::{Alias, DeleteStatement, MysqlQueryBuilder};
40use snafu::prelude::*;
41use sql_table::MySQLTable;
42use std::collections::HashMap;
43use std::sync::Arc;
44
/// Type-erased connection pool handing out `mysql_async::Conn`-backed connections.
///
/// NOTE(review): the second type parameter (`&'static (dyn ToValue + Sync)`) is presumably the
/// query-parameter type expected by [`DbConnectionPool`] — confirm.
pub type DynMySQLConnectionPool =
    dyn DbConnectionPool<mysql_async::Conn, &'static (dyn ToValue + Sync)> + Send + Sync;

/// Type-erased `mysql_async::Conn`-backed connection.
///
/// This is the type returned (boxed) by [`MySQL::connect`]. [`MySQL::mysql_conn`] downcasts it
/// to a concrete [`MySQLConnection`].
pub type DynMySQLConnection = dyn DbConnection<mysql_async::Conn, &'static (dyn ToValue + Sync)>;

/// Only compiled when the `mysql-federation` feature is enabled.
#[cfg(feature = "mysql-federation")]
pub mod federation;
pub(crate) mod mysql_window;
/// Defines [`sql_table::MySQLTable`], the provider used for reads in this module.
pub mod sql_table;
/// Defines [`write::MySQLTableWriter`], which is built from a read provider and a [`MySQL`] handle.
pub mod write;
55
/// Errors raised while creating, connecting to, or writing to MySQL tables in this module.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Checking a connection out of the pool failed (see `MySQL::connect` and the factory's `create`).
    #[snafu(display("DbConnectionError: {source}"))]
    DbConnectionError {
        source: db_connection_pool::dbconnection::GenericError,
    },

    /// Building the SQL table provider failed.
    #[snafu(display("Unable to construct SQL table: {source}"))]
    UnableToConstructSQLTable {
        source: sql_provider_datafusion::Error,
    },

    /// The `DELETE` issued by `MySQL::delete_all_table_data` failed.
    #[snafu(display("Unable to delete all data from the MySQL table: {source}"))]
    UnableToDeleteAllTableData { source: mysql_async::Error },

    /// Executing the generated `INSERT` in `MySQL::insert_batch` failed.
    #[snafu(display("Unable to insert Arrow batch to MySQL table: {source}"))]
    UnableToInsertArrowBatch { source: mysql_async::Error },

    /// `MySQL::mysql_conn` was handed a connection that is not a `MySQLConnection`.
    #[snafu(display("Unable to downcast DbConnection to MySQLConnection"))]
    UnableToDowncastDbConnection {},

    /// `start_transaction` failed.
    #[snafu(display("Unable to begin MySQL transaction: {source}"))]
    UnableToBeginTransaction { source: mysql_async::Error },

    /// `MySQLConnectionPool::new` failed for the supplied options.
    #[snafu(display("Unable to create MySQL connection pool: {source}"))]
    UnableToCreateMySQLConnectionPool { source: mysqlpool::Error },

    /// The `CREATE TABLE` statement failed.
    #[snafu(display("Unable to create the MySQL table: {source}"))]
    UnableToCreateMySQLTable { source: mysql_async::Error },

    /// A `CREATE [UNIQUE] INDEX` statement failed.
    #[snafu(display("Unable to create an index for the MySQL table: {source}"))]
    UnableToCreateIndexForMySQLTable { source: mysql_async::Error },

    /// Committing the table-creation transaction failed.
    #[snafu(display("Unable to commit the MySQL transaction: {source}"))]
    UnableToCommitMySQLTransaction { source: mysql_async::Error },

    /// Generating the `INSERT` SQL from an Arrow batch failed.
    #[snafu(display("Unable to create insertion statement for MySQL table: {source}"))]
    UnableToCreateInsertStatement {
        source: crate::sql::arrow_sql_gen::statement::Error,
    },

    /// `MySQL::connect` found no table with the configured name.
    #[snafu(display("The table '{table_name}' doesn't exist in the MySQL server"))]
    TableDoesntExist { table_name: String },

    /// A table constraint check failed.
    #[snafu(display("Constraint Violation: {source}"))]
    ConstraintViolation { source: constraints::Error },

    /// A key of the `indexes` table option is not a valid column reference.
    #[snafu(display("Error parsing column reference: {source}"))]
    UnableToParseColumnReference { source: column_reference::Error },

    /// The `on_conflict` table option could not be parsed.
    #[snafu(display("Error parsing on_conflict: {source}"))]
    UnableToParseOnConflict { source: on_conflict::Error },
}

/// Module-local result type whose error defaults to this module's [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
111
112pub struct MySQLTableFactory {
113 pool: Arc<MySQLConnectionPool>,
114}
115
116impl MySQLTableFactory {
117 #[must_use]
118 pub fn new(pool: Arc<MySQLConnectionPool>) -> Self {
119 Self { pool }
120 }
121
122 pub async fn table_provider(
123 &self,
124 table_reference: TableReference,
125 ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
126 let pool = Arc::clone(&self.pool);
127 let table_provider = Arc::new(
128 MySQLTable::new(&pool, table_reference)
129 .await
130 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?,
131 );
132
133 #[cfg(feature = "mysql-federation")]
134 let table_provider = Arc::new(
135 table_provider
136 .create_federated_table_provider()
137 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?,
138 );
139
140 Ok(table_provider)
141 }
142
143 pub async fn read_write_table_provider(
144 &self,
145 table_reference: TableReference,
146 ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
147 let read_provider = Self::table_provider(self, table_reference.clone()).await?;
148 let schema = read_provider.schema();
149
150 let table_name = table_reference.to_string();
151 let mysql = MySQL::new(
152 table_name,
153 Arc::clone(&self.pool),
154 schema,
155 Constraints::empty(),
156 );
157
158 Ok(MySQLTableWriter::create(read_provider, mysql, None))
159 }
160
161 pub fn conn_pool_metrics(&self) -> Arc<Metrics> {
162 self.pool.metrics()
163 }
164}
165
/// Stateless [`TableProviderFactory`] for MySQL-backed external tables.
///
/// It holds no configuration. Each `create` call builds its own connection pool from the
/// table's options.
#[derive(Debug, Default)]
pub struct MySQLTableProviderFactory {}

impl MySQLTableProviderFactory {
    /// Returns a new factory; equivalent to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
181
182#[async_trait]
183impl TableProviderFactory for MySQLTableProviderFactory {
184 async fn create(
185 &self,
186 _state: &dyn Session,
187 cmd: &CreateExternalTable,
188 ) -> datafusion::common::Result<Arc<dyn TableProvider>> {
189 let name = cmd.name.to_string();
190 let mut options = cmd.options.clone();
191 let schema: Schema = cmd.schema.as_ref().into();
192
193 let indexes_option_str = options.remove("indexes");
194 let unparsed_indexes: HashMap<String, IndexType> = match indexes_option_str {
195 Some(indexes_str) => util::hashmap_from_option_string(&indexes_str),
196 None => HashMap::new(),
197 };
198
199 let unparsed_indexes = unparsed_indexes
200 .into_iter()
201 .map(|(key, value)| {
202 let columns = ColumnReference::try_from(key.as_str())
203 .context(UnableToParseColumnReferenceSnafu)
204 .map_err(util::to_datafusion_error);
205 (columns, value)
206 })
207 .collect::<Vec<(Result<ColumnReference, DataFusionError>, IndexType)>>();
208
209 let mut indexes: Vec<(ColumnReference, IndexType)> = Vec::new();
210 for (columns, index_type) in unparsed_indexes {
211 let columns = columns?;
212 indexes.push((columns, index_type));
213 }
214
215 let mut on_conflict: Option<OnConflict> = None;
216 if let Some(on_conflict_str) = options.remove("on_conflict") {
217 on_conflict = Some(
218 OnConflict::try_from(on_conflict_str.as_str())
219 .context(UnableToParseOnConflictSnafu)
220 .map_err(util::to_datafusion_error)?,
221 );
222 }
223
224 let params = to_secret_map(options);
225
226 let pool = Arc::new(
227 MySQLConnectionPool::new(params)
228 .await
229 .context(UnableToCreateMySQLConnectionPoolSnafu)
230 .map_err(to_datafusion_error)?,
231 );
232 let schema = Arc::new(schema);
233 let mysql = MySQL::new(
234 name.clone(),
235 Arc::clone(&pool),
236 Arc::clone(&schema),
237 cmd.constraints.clone(),
238 );
239
240 let mut db_conn = pool
241 .connect()
242 .await
243 .context(DbConnectionSnafu)
244 .map_err(to_datafusion_error)?;
245
246 let mysql_conn = MySQL::mysql_conn(&mut db_conn).map_err(to_datafusion_error)?;
247 let mut conn_guard = mysql_conn.conn.lock().await;
248 let mut transaction = conn_guard
249 .start_transaction(TxOpts::default())
250 .await
251 .context(UnableToBeginTransactionSnafu)
252 .map_err(to_datafusion_error)?;
253
254 let primary_keys = get_primary_keys_from_constraints(&cmd.constraints, &schema);
255
256 mysql
257 .create_table(Arc::clone(&schema), &mut transaction, primary_keys)
258 .await
259 .map_err(to_datafusion_error)?;
260
261 for index in indexes {
262 mysql
263 .create_index(
264 &mut transaction,
265 index.0.iter().collect(),
266 index.1 == IndexType::Unique,
267 )
268 .await
269 .map_err(to_datafusion_error)?;
270 }
271
272 transaction
273 .commit()
274 .await
275 .context(UnableToCommitMySQLTransactionSnafu)
276 .map_err(to_datafusion_error)?;
277
278 drop(conn_guard);
279
280 let dyn_pool: Arc<DynMySQLConnectionPool> = pool;
281
282 let read_provider = Arc::new(
283 SqlTable::new_with_schema(
284 "mysql",
285 &dyn_pool,
286 Arc::clone(&schema),
287 TableReference::bare(name.clone()),
288 )
289 .with_dialect(Arc::new(MySqlDialect {})),
290 );
291
292 #[cfg(feature = "mysql-federation")]
293 let read_provider = Arc::new(read_provider.create_federated_table_provider()?);
294 Ok(MySQLTableWriter::create(read_provider, mysql, on_conflict))
295 }
296}
297
/// Handle to one MySQL table.
///
/// It is used to create the table and its indexes, insert Arrow batches, and delete all rows.
#[derive(Debug)]
pub struct MySQL {
    /// Name used for the table in generated SQL and in the existence check.
    table_name: String,
    /// Pool that `MySQL::connect` checks connections out of.
    pool: Arc<MySQLConnectionPool>,
    /// Arrow schema of the table; used to build the conflict-handling clause for inserts.
    schema: SchemaRef,
    /// Table constraints, exposed through `MySQL::constraints`.
    constraints: Constraints,
}
305
306impl MySQL {
307 #[must_use]
308 pub fn new(
309 table_name: String,
310 pool: Arc<MySQLConnectionPool>,
311 schema: SchemaRef,
312 constraints: Constraints,
313 ) -> Self {
314 Self {
315 table_name,
316 pool,
317 schema,
318 constraints,
319 }
320 }
321
322 #[must_use]
323 pub fn table_name(&self) -> &str {
324 &self.table_name
325 }
326
327 #[must_use]
328 pub fn constraints(&self) -> &Constraints {
329 &self.constraints
330 }
331
332 pub async fn connect(&self) -> Result<Box<DynMySQLConnection>> {
333 let mut conn = self.pool.connect().await.context(DbConnectionSnafu)?;
334
335 let mysql_conn = Self::mysql_conn(&mut conn)?;
336
337 if !self.table_exists(mysql_conn).await {
338 TableDoesntExistSnafu {
339 table_name: self.table_name.clone(),
340 }
341 .fail()?;
342 }
343
344 Ok(conn)
345 }
346
347 pub fn mysql_conn(db_connection: &mut Box<DynMySQLConnection>) -> Result<&mut MySQLConnection> {
348 let conn = db_connection
349 .as_any_mut()
350 .downcast_mut::<MySQLConnection>()
351 .context(UnableToDowncastDbConnectionSnafu)?;
352
353 Ok(conn)
354 }
355
356 async fn table_exists(&self, mysql_connection: &MySQLConnection) -> bool {
357 let sql = format!(
358 r#"SELECT EXISTS (
359 SELECT 1
360 FROM information_schema.tables
361 WHERE table_name = '{name}'
362 )"#,
363 name = self.table_name
364 );
365 tracing::trace!("{sql}");
366 let Ok(Some((exists,))) = mysql_connection
367 .conn
368 .lock()
369 .await
370 .query_first::<(bool,), _>(&sql)
371 .await
372 else {
373 return false;
374 };
375
376 exists
377 }
378
379 async fn insert_batch(
380 &self,
381 transaction: &mut mysql_async::Transaction<'_>,
382 batch: RecordBatch,
383 on_conflict: Option<OnConflict>,
384 ) -> Result<()> {
385 let insert_table_builder =
386 InsertBuilder::new(&TableReference::bare(self.table_name.clone()), vec![batch]);
387
388 let sea_query_on_conflict =
389 on_conflict.map(|oc| oc.build_sea_query_on_conflict(&self.schema));
390
391 let sql = insert_table_builder
392 .build_mysql(sea_query_on_conflict)
393 .context(UnableToCreateInsertStatementSnafu)?;
394
395 transaction
396 .exec_drop(&sql, ())
397 .await
398 .context(UnableToInsertArrowBatchSnafu)?;
399
400 Ok(())
401 }
402
403 async fn delete_all_table_data(
404 &self,
405 transaction: &mut mysql_async::Transaction<'_>,
406 ) -> Result<()> {
407 let delete = DeleteStatement::new()
408 .from_table(Alias::new(self.table_name.clone()))
409 .to_string(MysqlQueryBuilder);
410 transaction
411 .exec_drop(delete.as_str(), ())
412 .await
413 .context(UnableToDeleteAllTableDataSnafu)?;
414
415 Ok(())
416 }
417
418 async fn create_table(
419 &self,
420 schema: SchemaRef,
421 transaction: &mut mysql_async::Transaction<'_>,
422 primary_keys: Vec<String>,
423 ) -> Result<()> {
424 let create_table_statement =
425 CreateTableBuilder::new(schema, &self.table_name).primary_keys(primary_keys);
426 let create_stmts = create_table_statement.build_mysql();
427
428 transaction
429 .exec_drop(create_stmts, ())
430 .await
431 .context(UnableToCreateMySQLTableSnafu)
432 }
433
434 async fn create_index(
435 &self,
436 transaction: &mut mysql_async::Transaction<'_>,
437 columns: Vec<&str>,
438 unique: bool,
439 ) -> Result<()> {
440 let mut index_builder = IndexBuilder::new(&self.table_name, columns);
441 if unique {
442 index_builder = index_builder.unique();
443 }
444 let sql = index_builder.build_mysql();
445
446 transaction
447 .exec_drop(sql, ())
448 .await
449 .context(UnableToCreateIndexForMySQLTableSnafu)
450 }
451}