// datafusion_table_providers/mysql/write.rs

1use crate::mysql::MySQL;
2use crate::util::on_conflict::OnConflict;
3use crate::util::retriable_error::check_and_mark_retriable_error;
4use crate::util::{constraints, to_datafusion_error};
5use async_trait::async_trait;
6use datafusion::arrow::datatypes::SchemaRef;
7use datafusion::datasource::sink::{DataSink, DataSinkExec};
8use datafusion::{
9    catalog::Session,
10    datasource::{TableProvider, TableType},
11    execution::{SendableRecordBatchStream, TaskContext},
12    logical_expr::{dml::InsertOp, Expr},
13    physical_plan::{metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan},
14};
15use futures::StreamExt;
16use mysql_async::TxOpts;
17use snafu::ResultExt;
18use std::any::Any;
19use std::fmt;
20use std::sync::Arc;
21
22#[derive(Debug, Clone)]
23pub struct MySQLTableWriter {
24    pub read_provider: Arc<dyn TableProvider>,
25    mysql: Arc<MySQL>,
26    on_conflict: Option<OnConflict>,
27}
28
29impl MySQLTableWriter {
30    pub fn create(
31        read_provider: Arc<dyn TableProvider>,
32        mysql: MySQL,
33        on_conflict: Option<OnConflict>,
34    ) -> Arc<Self> {
35        Arc::new(Self {
36            read_provider,
37            mysql: Arc::new(mysql),
38            on_conflict,
39        })
40    }
41
42    pub fn mysql(&self) -> Arc<MySQL> {
43        Arc::clone(&self.mysql)
44    }
45}
46
47#[async_trait]
48impl TableProvider for MySQLTableWriter {
49    fn as_any(&self) -> &dyn Any {
50        self
51    }
52
53    fn schema(&self) -> SchemaRef {
54        self.read_provider.schema()
55    }
56
57    fn table_type(&self) -> TableType {
58        TableType::Base
59    }
60
61    async fn scan(
62        &self,
63        state: &dyn Session,
64        projection: Option<&Vec<usize>>,
65        filters: &[Expr],
66        limit: Option<usize>,
67    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
68        self.read_provider
69            .scan(state, projection, filters, limit)
70            .await
71    }
72
73    async fn insert_into(
74        &self,
75        _state: &dyn Session,
76        input: Arc<dyn ExecutionPlan>,
77        op: InsertOp,
78    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
79        Ok(Arc::new(DataSinkExec::new(
80            input,
81            Arc::new(MySQLDataSink::new(
82                Arc::clone(&self.mysql),
83                op == InsertOp::Overwrite,
84                self.on_conflict.clone(),
85                self.schema(),
86            )),
87            None,
88        )))
89    }
90}
91
92pub struct MySQLDataSink {
93    pub mysql: Arc<MySQL>,
94    pub overwrite: bool,
95    pub on_conflict: Option<OnConflict>,
96    schema: SchemaRef,
97}
98
99#[async_trait]
100impl DataSink for MySQLDataSink {
101    fn as_any(&self) -> &dyn Any {
102        self
103    }
104
105    fn metrics(&self) -> Option<MetricsSet> {
106        None
107    }
108
109    fn schema(&self) -> &SchemaRef {
110        &self.schema
111    }
112
113    async fn write_all(
114        &self,
115        mut data: SendableRecordBatchStream,
116        _context: &Arc<TaskContext>,
117    ) -> datafusion::common::Result<u64> {
118        let mut num_rows = 0u64;
119
120        let mut db_conn = self.mysql.connect().await.map_err(to_datafusion_error)?;
121        let mysql_conn = MySQL::mysql_conn(&mut db_conn).map_err(to_datafusion_error)?;
122
123        let mut conn_guard = mysql_conn.conn.lock().await;
124        let mut tx = conn_guard
125            .start_transaction(TxOpts::default())
126            .await
127            .context(super::UnableToBeginTransactionSnafu)
128            .map_err(to_datafusion_error)?;
129
130        if self.overwrite {
131            self.mysql
132                .delete_all_table_data(&mut tx)
133                .await
134                .map_err(to_datafusion_error)?;
135        }
136
137        while let Some(batch) = data.next().await {
138            let batch = batch.map_err(check_and_mark_retriable_error)?;
139            let batch_num_rows = batch.num_rows();
140
141            if batch_num_rows == 0 {
142                continue;
143            }
144
145            num_rows += batch_num_rows as u64;
146
147            constraints::validate_batch_with_constraints(
148                &[batch.clone()],
149                self.mysql.constraints(),
150            )
151            .await
152            .context(super::ConstraintViolationSnafu)
153            .map_err(to_datafusion_error)?;
154
155            self.mysql
156                .insert_batch(&mut tx, batch, self.on_conflict.clone())
157                .await
158                .map_err(to_datafusion_error)?;
159        }
160
161        tx.commit()
162            .await
163            .context(super::UnableToCommitMySQLTransactionSnafu)
164            .map_err(to_datafusion_error)?;
165
166        drop(conn_guard);
167
168        Ok(num_rows)
169    }
170}
171
172impl MySQLDataSink {
173    pub fn new(
174        mysql: Arc<MySQL>,
175        overwrite: bool,
176        on_conflict: Option<OnConflict>,
177        schema: SchemaRef,
178    ) -> Self {
179        Self {
180            mysql,
181            overwrite,
182            on_conflict,
183            schema,
184        }
185    }
186}
187
188impl fmt::Debug for MySQLDataSink {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        write!(f, "MySQLDataSink")
191    }
192}
193
194impl DisplayAs for MySQLDataSink {
195    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
196        write!(f, "MySQLDataSink")
197    }
198}