datafusion_table_providers/mysql/
write.rs1use crate::mysql::MySQL;
2use crate::util::on_conflict::OnConflict;
3use crate::util::retriable_error::check_and_mark_retriable_error;
4use crate::util::{constraints, to_datafusion_error};
5use async_trait::async_trait;
6use datafusion::arrow::datatypes::SchemaRef;
7use datafusion::datasource::sink::{DataSink, DataSinkExec};
8use datafusion::{
9 catalog::Session,
10 datasource::{TableProvider, TableType},
11 execution::{SendableRecordBatchStream, TaskContext},
12 logical_expr::{dml::InsertOp, Expr},
13 physical_plan::{metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan},
14};
15use futures::StreamExt;
16use mysql_async::TxOpts;
17use snafu::ResultExt;
18use std::any::Any;
19use std::fmt;
20use std::sync::Arc;
21
22#[derive(Debug, Clone)]
23pub struct MySQLTableWriter {
24 pub read_provider: Arc<dyn TableProvider>,
25 mysql: Arc<MySQL>,
26 on_conflict: Option<OnConflict>,
27}
28
29impl MySQLTableWriter {
30 pub fn create(
31 read_provider: Arc<dyn TableProvider>,
32 mysql: MySQL,
33 on_conflict: Option<OnConflict>,
34 ) -> Arc<Self> {
35 Arc::new(Self {
36 read_provider,
37 mysql: Arc::new(mysql),
38 on_conflict,
39 })
40 }
41
42 pub fn mysql(&self) -> Arc<MySQL> {
43 Arc::clone(&self.mysql)
44 }
45}
46
47#[async_trait]
48impl TableProvider for MySQLTableWriter {
49 fn as_any(&self) -> &dyn Any {
50 self
51 }
52
53 fn schema(&self) -> SchemaRef {
54 self.read_provider.schema()
55 }
56
57 fn table_type(&self) -> TableType {
58 TableType::Base
59 }
60
61 async fn scan(
62 &self,
63 state: &dyn Session,
64 projection: Option<&Vec<usize>>,
65 filters: &[Expr],
66 limit: Option<usize>,
67 ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
68 self.read_provider
69 .scan(state, projection, filters, limit)
70 .await
71 }
72
73 async fn insert_into(
74 &self,
75 _state: &dyn Session,
76 input: Arc<dyn ExecutionPlan>,
77 op: InsertOp,
78 ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
79 Ok(Arc::new(DataSinkExec::new(
80 input,
81 Arc::new(MySQLDataSink::new(
82 Arc::clone(&self.mysql),
83 op == InsertOp::Overwrite,
84 self.on_conflict.clone(),
85 self.schema(),
86 )),
87 None,
88 )))
89 }
90}
91
92pub struct MySQLDataSink {
93 pub mysql: Arc<MySQL>,
94 pub overwrite: bool,
95 pub on_conflict: Option<OnConflict>,
96 schema: SchemaRef,
97}
98
99#[async_trait]
100impl DataSink for MySQLDataSink {
101 fn as_any(&self) -> &dyn Any {
102 self
103 }
104
105 fn metrics(&self) -> Option<MetricsSet> {
106 None
107 }
108
109 fn schema(&self) -> &SchemaRef {
110 &self.schema
111 }
112
113 async fn write_all(
114 &self,
115 mut data: SendableRecordBatchStream,
116 _context: &Arc<TaskContext>,
117 ) -> datafusion::common::Result<u64> {
118 let mut num_rows = 0u64;
119
120 let mut db_conn = self.mysql.connect().await.map_err(to_datafusion_error)?;
121 let mysql_conn = MySQL::mysql_conn(&mut db_conn).map_err(to_datafusion_error)?;
122
123 let mut conn_guard = mysql_conn.conn.lock().await;
124 let mut tx = conn_guard
125 .start_transaction(TxOpts::default())
126 .await
127 .context(super::UnableToBeginTransactionSnafu)
128 .map_err(to_datafusion_error)?;
129
130 if self.overwrite {
131 self.mysql
132 .delete_all_table_data(&mut tx)
133 .await
134 .map_err(to_datafusion_error)?;
135 }
136
137 while let Some(batch) = data.next().await {
138 let batch = batch.map_err(check_and_mark_retriable_error)?;
139 let batch_num_rows = batch.num_rows();
140
141 if batch_num_rows == 0 {
142 continue;
143 }
144
145 num_rows += batch_num_rows as u64;
146
147 constraints::validate_batch_with_constraints(
148 &[batch.clone()],
149 self.mysql.constraints(),
150 )
151 .await
152 .context(super::ConstraintViolationSnafu)
153 .map_err(to_datafusion_error)?;
154
155 self.mysql
156 .insert_batch(&mut tx, batch, self.on_conflict.clone())
157 .await
158 .map_err(to_datafusion_error)?;
159 }
160
161 tx.commit()
162 .await
163 .context(super::UnableToCommitMySQLTransactionSnafu)
164 .map_err(to_datafusion_error)?;
165
166 drop(conn_guard);
167
168 Ok(num_rows)
169 }
170}
171
172impl MySQLDataSink {
173 pub fn new(
174 mysql: Arc<MySQL>,
175 overwrite: bool,
176 on_conflict: Option<OnConflict>,
177 schema: SchemaRef,
178 ) -> Self {
179 Self {
180 mysql,
181 overwrite,
182 on_conflict,
183 schema,
184 }
185 }
186}
187
188impl fmt::Debug for MySQLDataSink {
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 write!(f, "MySQLDataSink")
191 }
192}
193
194impl DisplayAs for MySQLDataSink {
195 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
196 write!(f, "MySQLDataSink")
197 }
198}