datafusion_table_providers/postgres/
write.rs

1use std::{any::Any, fmt, sync::Arc};
2
3use arrow::datatypes::SchemaRef;
4use arrow_schema::{DataType, Field, Schema};
5use async_trait::async_trait;
6use datafusion::{
7    catalog::Session,
8    common::{Constraints, SchemaExt},
9    datasource::{
10        sink::{DataSink, DataSinkExec},
11        TableProvider, TableType,
12    },
13    execution::{SendableRecordBatchStream, TaskContext},
14    logical_expr::{dml::InsertOp, Expr},
15    physical_plan::{metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan},
16};
17use futures::StreamExt;
18use snafu::prelude::*;
19
20use crate::util::{
21    constraints, on_conflict::OnConflict, retriable_error::check_and_mark_retriable_error,
22};
23
24use crate::postgres::Postgres;
25
26use super::to_datafusion_error;
27
/// A DataFusion `TableProvider` that serves reads from `read_provider` while
/// routing `INSERT INTO` statements to the wrapped Postgres table.
#[derive(Debug, Clone)]
pub struct PostgresTableWriter {
    /// Provider used to satisfy `scan`; also supplies the schema reported by
    /// this writer.
    pub read_provider: Arc<dyn TableProvider>,
    /// Shared Postgres table handle; supplies `constraints()`, the table
    /// schema, and the write operations used by the data sink.
    postgres: Arc<Postgres>,
    /// Optional conflict-resolution clause forwarded to `insert_batch` on
    /// every written batch.
    on_conflict: Option<OnConflict>,
}
34
35impl PostgresTableWriter {
36    pub fn create(
37        read_provider: Arc<dyn TableProvider>,
38        postgres: Postgres,
39        on_conflict: Option<OnConflict>,
40    ) -> Arc<Self> {
41        Arc::new(Self {
42            read_provider,
43            postgres: Arc::new(postgres),
44            on_conflict,
45        })
46    }
47
48    pub fn postgres(&self) -> Arc<Postgres> {
49        Arc::clone(&self.postgres)
50    }
51}
52
53#[async_trait]
54impl TableProvider for PostgresTableWriter {
55    fn as_any(&self) -> &dyn Any {
56        self
57    }
58
59    fn schema(&self) -> SchemaRef {
60        self.read_provider.schema()
61    }
62
63    fn table_type(&self) -> TableType {
64        TableType::Base
65    }
66
67    fn constraints(&self) -> Option<&Constraints> {
68        Some(self.postgres.constraints())
69    }
70
71    async fn scan(
72        &self,
73        state: &dyn Session,
74        projection: Option<&Vec<usize>>,
75        filters: &[Expr],
76        limit: Option<usize>,
77    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
78        self.read_provider
79            .scan(state, projection, filters, limit)
80            .await
81    }
82
83    async fn insert_into(
84        &self,
85        _state: &dyn Session,
86        input: Arc<dyn ExecutionPlan>,
87        op: InsertOp,
88    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
89        Ok(Arc::new(DataSinkExec::new(
90            input,
91            Arc::new(PostgresDataSink::new(
92                Arc::clone(&self.postgres),
93                op,
94                self.on_conflict.clone(),
95                self.schema(),
96            )),
97            None,
98        )) as _)
99    }
100}
101
/// `DataSink` that streams DataFusion record batches into a Postgres table
/// within a single transaction (see `write_all`).
#[derive(Clone)]
struct PostgresDataSink {
    // Target Postgres table handle (connection, schema, constraints).
    postgres: Arc<Postgres>,
    // When `InsertOp::Overwrite`, all existing table data is deleted before
    // the new batches are written; otherwise rows are appended.
    overwrite: InsertOp,
    // Optional conflict-resolution clause forwarded to `insert_batch`.
    on_conflict: Option<OnConflict>,
    // Schema reported back to DataFusion; taken from the writer's read
    // provider at `insert_into` time.
    schema: SchemaRef,
}
109
110#[async_trait]
111impl DataSink for PostgresDataSink {
112    fn as_any(&self) -> &dyn Any {
113        self
114    }
115
116    fn metrics(&self) -> Option<MetricsSet> {
117        None
118    }
119
120    fn schema(&self) -> &SchemaRef {
121        &self.schema
122    }
123
124    async fn write_all(
125        &self,
126        mut data: SendableRecordBatchStream,
127        _context: &Arc<TaskContext>,
128    ) -> datafusion::common::Result<u64> {
129        let mut num_rows = 0;
130
131        let mut db_conn = self.postgres.connect().await.map_err(to_datafusion_error)?;
132        let postgres_conn = Postgres::postgres_conn(&mut db_conn).map_err(to_datafusion_error)?;
133
134        let tx = postgres_conn
135            .conn
136            .transaction()
137            .await
138            .context(super::UnableToBeginTransactionSnafu)
139            .map_err(to_datafusion_error)?;
140
141        if matches!(self.overwrite, InsertOp::Overwrite) {
142            self.postgres
143                .delete_all_table_data(&tx)
144                .await
145                .map_err(to_datafusion_error)?;
146        }
147
148        let postgres_fields = self
149            .postgres
150            .schema
151            .fields
152            .iter()
153            .map(|f| {
154                Arc::new(Field::new(
155                    f.name(),
156                    if f.data_type() == &DataType::LargeUtf8 {
157                        DataType::Utf8
158                    } else {
159                        f.data_type().clone()
160                    },
161                    f.is_nullable(),
162                ))
163            })
164            .collect::<Vec<_>>();
165
166        let postgres_schema = Arc::new(Schema::new(postgres_fields));
167
168        while let Some(batch) = data.next().await {
169            let batch = batch.map_err(check_and_mark_retriable_error)?;
170
171            // for the purposes of PostgreSQL, LargeUtf8 is equivalent to Utf8
172            // because Postgres physically cannot store anything larger than 1Gb in text (VARCHAR)
173            // normalize LargeUtf8 fields to Utf8 for both the incoming batch, and Postgres if it happens to specify any
174            let batch_fields = batch
175                .schema_ref()
176                .fields()
177                .iter()
178                .map(|f| {
179                    Arc::new(Field::new(
180                        f.name(),
181                        if f.data_type() == &DataType::LargeUtf8 {
182                            DataType::Utf8
183                        } else {
184                            f.data_type().clone()
185                        },
186                        f.is_nullable(),
187                    ))
188                })
189                .collect::<Vec<_>>();
190            let batch_schema = Arc::new(Schema::new(batch_fields));
191
192            if !Arc::clone(&postgres_schema).equivalent_names_and_types(&batch_schema) {
193                return Err(to_datafusion_error(super::Error::SchemaValidationError {
194                    table_name: self.postgres.table.to_string(),
195                }));
196            }
197
198            let batch_num_rows = batch.num_rows();
199
200            if batch_num_rows == 0 {
201                continue;
202            };
203
204            num_rows += batch_num_rows as u64;
205
206            constraints::validate_batch_with_constraints(
207                &[batch.clone()],
208                self.postgres.constraints(),
209            )
210            .await
211            .context(super::ConstraintViolationSnafu)
212            .map_err(to_datafusion_error)?;
213
214            self.postgres
215                .insert_batch(&tx, batch, self.on_conflict.clone())
216                .await
217                .map_err(to_datafusion_error)?;
218        }
219
220        tx.commit()
221            .await
222            .context(super::UnableToCommitPostgresTransactionSnafu)
223            .map_err(to_datafusion_error)?;
224
225        Ok(num_rows)
226    }
227}
228
229impl PostgresDataSink {
230    fn new(
231        postgres: Arc<Postgres>,
232        overwrite: InsertOp,
233        on_conflict: Option<OnConflict>,
234        schema: SchemaRef,
235    ) -> Self {
236        Self {
237            postgres,
238            overwrite,
239            on_conflict,
240            schema,
241        }
242    }
243}
244
245impl std::fmt::Debug for PostgresDataSink {
246    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
247        write!(f, "PostgresDataSink")
248    }
249}
250
251impl DisplayAs for PostgresDataSink {
252    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
253        write!(f, "PostgresDataSink")
254    }
255}