datafusion_table_providers/sqlite/write.rs

use std::{any::Any, fmt, sync::Arc};

use async_trait::async_trait;
use datafusion::arrow::{array::RecordBatch, datatypes::SchemaRef};
use datafusion::datasource::sink::{DataSink, DataSinkExec};
use datafusion::{
    catalog::Session,
    common::Constraints,
    datasource::{TableProvider, TableType},
    error::DataFusionError,
    execution::{SendableRecordBatchStream, TaskContext},
    logical_expr::{dml::InsertOp, Expr},
    physical_plan::{metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan},
};
use futures::StreamExt;
use snafu::prelude::*;

use crate::util::{
    constraints,
    on_conflict::OnConflict,
    retriable_error::{check_and_mark_retriable_error, to_retriable_data_write_error},
};

use super::{to_datafusion_error, Sqlite};

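/// Table provider that serves reads from the wrapped `read_provider` and
/// routes inserts into the backing SQLite table, optionally applying an
/// `ON CONFLICT` strategy.
///
/// A usage sketch, assuming an already-constructed `Sqlite` handle and read
/// provider (their construction lives outside this module):
///
/// ```ignore
/// let writer = SqliteTableWriter::create(read_provider, sqlite, None);
/// let plan = writer
///     .insert_into(&ctx.state(), input_plan, InsertOp::Append)
///     .await?;
/// ```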
#[derive(Debug, Clone)]
pub struct SqliteTableWriter {
    pub read_provider: Arc<dyn TableProvider>,
    sqlite: Arc<Sqlite>,
    on_conflict: Option<OnConflict>,
}

impl SqliteTableWriter {
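    /// Builds the writer around an existing read provider and SQLite handle.
    /// The `Sqlite` value is moved behind an `Arc` so the same handle can be
    /// shared with the `SqliteDataSink` created in `insert_into`.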
    pub fn create(
        read_provider: Arc<dyn TableProvider>,
        sqlite: Sqlite,
        on_conflict: Option<OnConflict>,
    ) -> Arc<Self> {
        Arc::new(Self {
            read_provider,
            sqlite: Arc::new(sqlite),
            on_conflict,
        })
    }

    pub fn sqlite(&self) -> Arc<Sqlite> {
        Arc::clone(&self.sqlite)
    }
}

#[async_trait]
impl TableProvider for SqliteTableWriter {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.read_provider.schema()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    fn constraints(&self) -> Option<&Constraints> {
        Some(self.sqlite.constraints())
    }

    async fn scan(
        &self,
        state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        self.read_provider
            .scan(state, projection, filters, limit)
            .await
    }

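    // Writes are planned as a DataSinkExec that streams the input plan's
    // record batches into a SqliteDataSink at execution time.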
    async fn insert_into(
        &self,
        _state: &dyn Session,
        input: Arc<dyn ExecutionPlan>,
        op: InsertOp,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(DataSinkExec::new(
            input,
            Arc::new(SqliteDataSink::new(
                Arc::clone(&self.sqlite),
                op,
                self.on_conflict.clone(),
                self.schema(),
            )),
            None,
        )) as _)
    }
}

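/// `DataSink` that writes a stream of record batches into SQLite within a
/// single transaction, honoring the requested insert and on-conflict
/// semantics.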
#[derive(Clone)]
struct SqliteDataSink {
    sqlite: Arc<Sqlite>,
    overwrite: InsertOp,
    on_conflict: Option<OnConflict>,
    schema: SchemaRef,
}

#[async_trait]
impl DataSink for SqliteDataSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn metrics(&self) -> Option<MetricsSet> {
        None
    }

    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        _context: &Arc<TaskContext>,
    ) -> datafusion::common::Result<u64> {
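        // The write is split across two tasks: a producer drains the input
        // stream, validates constraints, and forwards batches over a bounded
        // channel, while a consumer applies them to SQLite inside a single
        // transaction. A oneshot channel signals that the stream finished
        // cleanly before the transaction is allowed to commit.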
        let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(1);

        let (notify_commit_transaction, mut on_commit_transaction) =
            tokio::sync::oneshot::channel();

        let mut db_conn = self
            .sqlite
            .connect()
            .await
            .map_err(to_retriable_data_write_error)?;
        let sqlite_conn =
            Sqlite::sqlite_conn(&mut db_conn).map_err(to_retriable_data_write_error)?;

        let constraints = self.sqlite.constraints().clone();
        let mut data = data;
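        // Producer task: count rows, enforce the table's constraints, and
        // hand each batch to the SQLite writer. Dropping `batch_tx` closes
        // the channel so the consumer's receive loop terminates.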
        let task = tokio::spawn(async move {
            let mut num_rows: u64 = 0;
            while let Some(data_batch) = data.next().await {
                let data_batch = data_batch.map_err(check_and_mark_retriable_error)?;
                num_rows += u64::try_from(data_batch.num_rows()).map_err(|e| {
                    DataFusionError::Execution(format!("Unable to convert num_rows() to u64: {e}"))
                })?;

                constraints::validate_batch_with_constraints(&[data_batch.clone()], &constraints)
                    .await
                    .context(super::ConstraintViolationSnafu)
                    .map_err(to_datafusion_error)?;

                batch_tx.send(data_batch).await.map_err(|err| {
                    DataFusionError::Execution(format!("Error sending data batch: {err}"))
                })?;
            }

            if notify_commit_transaction.send(()).is_err() {
                return Err(DataFusionError::Execution(
                    "Unable to send message to commit transaction to SQLite writer.".to_string(),
                ));
            };

            drop(batch_tx);

            Ok::<_, DataFusionError>(num_rows)
        });

        let overwrite = self.overwrite;
        let sqlite = Arc::clone(&self.sqlite);
        let on_conflict = self.on_conflict.clone();
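        // Consumer: apply every batch on the blocking SQLite connection
        // within one transaction, so an overwrite followed by inserts either
        // commits fully or rolls back as a unit.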
        sqlite_conn
            .conn
            .call(move |conn| {
                let transaction = conn.transaction()?;

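                // Overwrite semantics: clear the table inside the same
                // transaction before writing the new rows.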
                if matches!(overwrite, InsertOp::Overwrite) {
                    sqlite.delete_all_table_data(&transaction)?;
                }

                while let Some(data_batch) = batch_rx.blocking_recv() {
                    if data_batch.num_rows() > 0 {
                        sqlite.insert_batch(&transaction, data_batch, on_conflict.as_ref())?;
                    }
                }

                if on_commit_transaction.try_recv().is_err() {
                    return Err(tokio_rusqlite::Error::Other(
                        "No message to commit transaction has been received.".into(),
                    ));
                }

                transaction.commit()?;

                Ok(())
            })
            .await
            .context(super::UnableToInsertIntoTableAsyncSnafu)
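            // Surface SQLITE_FULL as a dedicated DiskFull error so callers
            // can distinguish it from transient, retriable write failures.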
            .map_err(|e| {
                if let super::Error::UnableToInsertIntoTableAsync {
                    source:
                        tokio_rusqlite::Error::Rusqlite(rusqlite::Error::SqliteFailure(
                            rusqlite::ffi::Error {
                                code: rusqlite::ffi::ErrorCode::DiskFull,
                                ..
                            },
                            _,
                        )),
                } = e
                {
                    DataFusionError::External(super::Error::DiskFull {}.into())
                } else {
                    to_retriable_data_write_error(e)
                }
            })?;

        let num_rows = task.await.map_err(to_retriable_data_write_error)??;

        Ok(num_rows)
    }
}

impl SqliteDataSink {
    fn new(
        sqlite: Arc<Sqlite>,
        overwrite: InsertOp,
        on_conflict: Option<OnConflict>,
        schema: SchemaRef,
    ) -> Self {
        Self {
            sqlite,
            overwrite,
            on_conflict,
            schema,
        }
    }
}

impl std::fmt::Debug for SqliteDataSink {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "SqliteDataSink")
    }
}

impl DisplayAs for SqliteDataSink {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
        write!(f, "SqliteDataSink")
    }
}

#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use datafusion::arrow::{
        array::{Int64Array, RecordBatch, StringArray},
        datatypes::{DataType, Schema},
    };
    use datafusion::{
        catalog::TableProviderFactory,
        common::{Constraints, TableReference, ToDFSchema},
        execution::context::SessionContext,
        logical_expr::{dml::InsertOp, CreateExternalTable},
        physical_plan::collect,
    };

    use crate::sqlite::SqliteTableProviderFactory;
    use crate::util::test::MockExec;

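    // Round-trip: create the table through the factory, plan an append via
    // `insert_into`, and execute the resulting plan with `collect`.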
    #[tokio::test]
    #[allow(clippy::unreadable_literal)]
    async fn test_round_trip_sqlite() {
        let schema = Arc::new(Schema::new(vec![
            datafusion::arrow::datatypes::Field::new("time_in_string", DataType::Utf8, false),
            datafusion::arrow::datatypes::Field::new("time_int", DataType::Int64, false),
        ]));
        let df_schema = ToDFSchema::to_dfschema_ref(Arc::clone(&schema)).expect("df schema");
        let external_table = CreateExternalTable {
            schema: df_schema,
            name: TableReference::bare("test_table"),
            location: String::new(),
            file_type: String::new(),
            table_partition_cols: vec![],
            if_not_exists: true,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options: HashMap::new(),
            constraints: Constraints::empty(),
            column_defaults: HashMap::default(),
            temporary: false,
        };
        let ctx = SessionContext::new();
        let table = SqliteTableProviderFactory::default()
            .create(&ctx.state(), &external_table)
            .await
            .expect("table should be created");

        let arr1 = StringArray::from(vec![
            "1970-01-01",
            "2012-12-01T11:11:11Z",
            "2012-12-01T11:11:12Z",
        ]);
        let arr3 = Int64Array::from(vec![0, 1354360271, 1354360272]);
        let data = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(arr1), Arc::new(arr3)])
            .expect("data should be created");

        let exec = MockExec::new(vec![Ok(data)], schema);

        let insertion = table
            .insert_into(&ctx.state(), Arc::new(exec), InsertOp::Append)
            .await
            .expect("insertion should be successful");

        collect(insertion, ctx.task_ctx())
            .await
            .expect("insert successful");
    }
}