// datafusion_table_providers/sqlite/write.rs

1use std::{any::Any, fmt, sync::Arc};
2
3use async_trait::async_trait;
4use datafusion::arrow::{array::RecordBatch, datatypes::SchemaRef};
5use datafusion::datasource::sink::{DataSink, DataSinkExec};
6use datafusion::{
7    catalog::Session,
8    common::Constraints,
9    datasource::{TableProvider, TableType},
10    error::DataFusionError,
11    execution::{SendableRecordBatchStream, TaskContext},
12    logical_expr::{dml::InsertOp, Expr},
13    physical_plan::{metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan},
14};
15use futures::StreamExt;
16use snafu::prelude::*;
17
18use crate::util::{
19    constraints,
20    on_conflict::OnConflict,
21    retriable_error::{check_and_mark_retriable_error, to_retriable_data_write_error},
22};
23
24use super::{to_datafusion_error, Sqlite};
25
/// [`TableProvider`] that adds SQLite write support on top of an existing
/// read provider: reads are delegated to `read_provider`, inserts go to `sqlite`.
#[derive(Debug, Clone)]
pub struct SqliteTableWriter {
    /// Provider that serves the read path (schema and scans).
    pub read_provider: Arc<dyn TableProvider>,
    // Shared handle to the SQLite table that receives writes.
    sqlite: Arc<Sqlite>,
    // Optional conflict-resolution strategy applied on insert.
    on_conflict: Option<OnConflict>,
}
32
33impl SqliteTableWriter {
34    pub fn create(
35        read_provider: Arc<dyn TableProvider>,
36        sqlite: Sqlite,
37        on_conflict: Option<OnConflict>,
38    ) -> Arc<Self> {
39        Arc::new(Self {
40            read_provider,
41            sqlite: Arc::new(sqlite),
42            on_conflict,
43        })
44    }
45
46    pub fn sqlite(&self) -> Arc<Sqlite> {
47        Arc::clone(&self.sqlite)
48    }
49}
50
51#[async_trait]
52impl TableProvider for SqliteTableWriter {
53    fn as_any(&self) -> &dyn Any {
54        self
55    }
56
57    fn schema(&self) -> SchemaRef {
58        self.read_provider.schema()
59    }
60
61    fn table_type(&self) -> TableType {
62        TableType::Base
63    }
64
65    fn constraints(&self) -> Option<&Constraints> {
66        Some(self.sqlite.constraints())
67    }
68
69    async fn scan(
70        &self,
71        state: &dyn Session,
72        projection: Option<&Vec<usize>>,
73        filters: &[Expr],
74        limit: Option<usize>,
75    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
76        self.read_provider
77            .scan(state, projection, filters, limit)
78            .await
79    }
80
81    async fn insert_into(
82        &self,
83        _state: &dyn Session,
84        input: Arc<dyn ExecutionPlan>,
85        op: InsertOp,
86    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
87        Ok(Arc::new(DataSinkExec::new(
88            input,
89            Arc::new(SqliteDataSink::new(
90                Arc::clone(&self.sqlite),
91                op,
92                self.on_conflict.clone(),
93                self.schema(),
94            )),
95            None,
96        )) as _)
97    }
98}
99
/// [`DataSink`] that streams record batches into a SQLite table.
#[derive(Clone)]
struct SqliteDataSink {
    // Target SQLite table handle.
    sqlite: Arc<Sqlite>,
    // Insert mode; `InsertOp::Overwrite` clears the table before writing.
    overwrite: InsertOp,
    // Optional conflict-resolution strategy applied on insert.
    on_conflict: Option<OnConflict>,
    // Schema reported back to DataFusion for the sink.
    schema: SchemaRef,
}
107
#[async_trait]
impl DataSink for SqliteDataSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    // No sink-level metrics are collected.
    fn metrics(&self) -> Option<MetricsSet> {
        None
    }

    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Streams `data` into the SQLite table inside a single transaction.
    ///
    /// Two tasks cooperate:
    /// - a spawned producer pulls batches off the input stream, validates them
    ///   against the table constraints, and forwards them over an mpsc channel;
    /// - a blocking closure on the SQLite connection receives the batches and
    ///   inserts them, committing only if the producer signalled completion via
    ///   a oneshot channel (so a dropped/failed stream never commits).
    ///
    /// Returns the number of rows written.
    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        _context: &Arc<TaskContext>,
    ) -> datafusion::common::Result<u64> {
        // Capacity 1: at most one batch in flight, so the (slower) SQLite
        // writer applies backpressure to the producing stream.
        let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(1);

        // Since the main task/stream can be dropped or fail, we use a oneshot channel to signal that all data is received and we should commit the transaction
        let (notify_commit_transaction, mut on_commit_transaction) =
            tokio::sync::oneshot::channel();

        let mut db_conn = self
            .sqlite
            .connect()
            .await
            .map_err(to_retriable_data_write_error)?;
        let sqlite_conn =
            Sqlite::sqlite_conn(&mut db_conn).map_err(to_retriable_data_write_error)?;

        // Cloned so the spawned task can own them ('static bound on tokio::spawn).
        let constraints = self.sqlite.constraints().clone();
        let mut data = data;
        // Producer task: counts rows, validates constraints, forwards batches.
        let task = tokio::spawn(async move {
            let mut num_rows: u64 = 0;
            while let Some(data_batch) = data.next().await {
                let data_batch = data_batch.map_err(check_and_mark_retriable_error)?;
                num_rows += u64::try_from(data_batch.num_rows()).map_err(|e| {
                    DataFusionError::Execution(format!("Unable to convert num_rows() to u64: {e}"))
                })?;

                constraints::validate_batch_with_constraints(&[data_batch.clone()], &constraints)
                    .await
                    .context(super::ConstraintViolationSnafu)
                    .map_err(to_datafusion_error)?;

                batch_tx.send(data_batch).await.map_err(|err| {
                    DataFusionError::Execution(format!("Error sending data batch: {err}"))
                })?;
            }

            // Commit signal must be sent BEFORE dropping the sender: the writer
            // checks the oneshot only after the batch channel closes.
            if notify_commit_transaction.send(()).is_err() {
                return Err(DataFusionError::Execution(
                    "Unable to send message to commit transaction to SQLite writer.".to_string(),
                ));
            };

            // Drop the sender to signal the receiver that no more data is coming
            drop(batch_tx);

            Ok::<_, DataFusionError>(num_rows)
        });

        let overwrite = self.overwrite;
        let sqlite = Arc::clone(&self.sqlite);
        let on_conflict = self.on_conflict.clone();
        // Writer side: runs on the connection's blocking thread via
        // tokio_rusqlite's `call`, all inserts in one transaction.
        sqlite_conn
            .conn
            .call(move |conn| {
                let transaction = conn.transaction()?;

                // Overwrite semantics: clear existing rows first, inside the
                // same transaction so a failed write leaves the table intact.
                if matches!(overwrite, InsertOp::Overwrite) {
                    sqlite.delete_all_table_data(&transaction)?;
                }

                // Drain batches until the producer drops its sender.
                while let Some(data_batch) = batch_rx.blocking_recv() {
                    if data_batch.num_rows() > 0 {
                        sqlite.insert_batch(&transaction, data_batch, on_conflict.as_ref())?;
                    }
                }

                // No commit signal means the producer failed or was dropped
                // mid-stream: roll back by returning before `commit`.
                if on_commit_transaction.try_recv().is_err() {
                    return Err(tokio_rusqlite::Error::Other(
                        "No message to commit transaction has been received.".into(),
                    ));
                }

                transaction.commit()?;

                Ok(())
            })
            .await
            .context(super::UnableToInsertIntoTableAsyncSnafu)
            .map_err(|e| {
                // Disk-full is surfaced as a distinct, non-retriable error;
                // everything else goes through the retriable-error mapping.
                if let super::Error::UnableToInsertIntoTableAsync {
                    source:
                        tokio_rusqlite::Error::Rusqlite(rusqlite::Error::SqliteFailure(
                            rusqlite::ffi::Error {
                                code: rusqlite::ffi::ErrorCode::DiskFull,
                                ..
                            },
                            _,
                        )),
                } = e
                {
                    DataFusionError::External(super::Error::DiskFull {}.into())
                } else {
                    to_retriable_data_write_error(e)
                }
            })?;

        // Double `?`: outer for the JoinError, inner for the task's own Result.
        let num_rows = task.await.map_err(to_retriable_data_write_error)??;

        Ok(num_rows)
    }
}
226
227impl SqliteDataSink {
228    fn new(
229        sqlite: Arc<Sqlite>,
230        overwrite: InsertOp,
231        on_conflict: Option<OnConflict>,
232        schema: SchemaRef,
233    ) -> Self {
234        Self {
235            sqlite,
236            overwrite,
237            on_conflict,
238            schema,
239        }
240    }
241}
242
243impl std::fmt::Debug for SqliteDataSink {
244    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
245        write!(f, "SqliteDataSink")
246    }
247}
248
249impl DisplayAs for SqliteDataSink {
250    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
251        write!(f, "SqliteDataSink")
252    }
253}
254
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use datafusion::arrow::{
        array::{Int64Array, RecordBatch, StringArray},
        datatypes::{DataType, Field, Schema},
    };
    use datafusion::{
        catalog::TableProviderFactory,
        common::{Constraints, TableReference, ToDFSchema},
        execution::context::SessionContext,
        logical_expr::{dml::InsertOp, CreateExternalTable},
        physical_plan::collect,
    };

    use crate::sqlite::SqliteTableProviderFactory;
    use crate::util::test::MockExec;

    /// End-to-end: create a SQLite-backed table via the provider factory, then
    /// push one record batch through the writer's `insert_into` plan.
    #[tokio::test]
    #[allow(clippy::unreadable_literal)]
    async fn test_round_trip_sqlite() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("time_in_string", DataType::Utf8, false),
            Field::new("time_int", DataType::Int64, false),
        ]));
        let df_schema = ToDFSchema::to_dfschema_ref(Arc::clone(&schema)).expect("df schema");
        let cmd = CreateExternalTable {
            schema: df_schema,
            name: TableReference::bare("test_table"),
            location: String::new(),
            file_type: String::new(),
            table_partition_cols: vec![],
            if_not_exists: true,
            definition: None,
            order_exprs: vec![],
            unbounded: false,
            options: HashMap::new(),
            constraints: Constraints::empty(),
            column_defaults: HashMap::default(),
            temporary: false,
        };
        let ctx = SessionContext::new();
        let table = SqliteTableProviderFactory::default()
            .create(&ctx.state(), &cmd)
            .await
            .expect("table should be created");

        let time_strings = StringArray::from(vec![
            "1970-01-01",
            "2012-12-01T11:11:11Z",
            "2012-12-01T11:11:12Z",
        ]);
        let time_ints = Int64Array::from(vec![0, 1354360271, 1354360272]);
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(time_strings), Arc::new(time_ints)],
        )
        .expect("data should be created");

        let input = MockExec::new(vec![Ok(batch)], schema);

        let plan = table
            .insert_into(&ctx.state(), Arc::new(input), InsertOp::Append)
            .await
            .expect("insertion should be successful");

        collect(plan, ctx.task_ctx())
            .await
            .expect("insert successful");
    }
}