// datafusion_table_providers/postgres/write.rs

use std::{any::Any, fmt, sync::Arc};
2
3use arrow::datatypes::SchemaRef;
4use arrow_schema::{DataType, Field, Schema};
5use async_trait::async_trait;
6use datafusion::{
7 catalog::Session,
8 common::{Constraints, SchemaExt},
9 datasource::{
10 sink::{DataSink, DataSinkExec},
11 TableProvider, TableType,
12 },
13 execution::{SendableRecordBatchStream, TaskContext},
14 logical_expr::{dml::InsertOp, Expr},
15 physical_plan::{metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan},
16};
17use futures::StreamExt;
18use snafu::prelude::*;
19
20use crate::util::{
21 constraints, on_conflict::OnConflict, retriable_error::check_and_mark_retriable_error,
22};
23
24use crate::postgres::Postgres;
25
26use super::to_datafusion_error;
27
/// A [`TableProvider`] that serves reads from an inner provider and routes
/// inserts into a Postgres table.
#[derive(Debug, Clone)]
pub struct PostgresTableWriter {
    /// Provider used to answer scans; also supplies the reported schema.
    pub read_provider: Arc<dyn TableProvider>,
    /// Shared handle to the Postgres table that inserts are written to.
    postgres: Arc<Postgres>,
    /// Optional ON CONFLICT behavior forwarded to each batch insert.
    on_conflict: Option<OnConflict>,
}
34
35impl PostgresTableWriter {
36 pub fn create(
37 read_provider: Arc<dyn TableProvider>,
38 postgres: Postgres,
39 on_conflict: Option<OnConflict>,
40 ) -> Arc<Self> {
41 Arc::new(Self {
42 read_provider,
43 postgres: Arc::new(postgres),
44 on_conflict,
45 })
46 }
47
48 pub fn postgres(&self) -> Arc<Postgres> {
49 Arc::clone(&self.postgres)
50 }
51}
52
53#[async_trait]
54impl TableProvider for PostgresTableWriter {
55 fn as_any(&self) -> &dyn Any {
56 self
57 }
58
59 fn schema(&self) -> SchemaRef {
60 self.read_provider.schema()
61 }
62
63 fn table_type(&self) -> TableType {
64 TableType::Base
65 }
66
67 fn constraints(&self) -> Option<&Constraints> {
68 Some(self.postgres.constraints())
69 }
70
71 async fn scan(
72 &self,
73 state: &dyn Session,
74 projection: Option<&Vec<usize>>,
75 filters: &[Expr],
76 limit: Option<usize>,
77 ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
78 self.read_provider
79 .scan(state, projection, filters, limit)
80 .await
81 }
82
83 async fn insert_into(
84 &self,
85 _state: &dyn Session,
86 input: Arc<dyn ExecutionPlan>,
87 op: InsertOp,
88 ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
89 Ok(Arc::new(DataSinkExec::new(
90 input,
91 Arc::new(PostgresDataSink::new(
92 Arc::clone(&self.postgres),
93 op,
94 self.on_conflict.clone(),
95 self.schema(),
96 )),
97 None,
98 )) as _)
99 }
100}
101
/// [`DataSink`] that writes a record-batch stream into a Postgres table
/// inside a single transaction.
#[derive(Clone)]
struct PostgresDataSink {
    /// Handle to the target Postgres table.
    postgres: Arc<Postgres>,
    /// Insert semantics; `InsertOp::Overwrite` truncates the table first.
    overwrite: InsertOp,
    /// Optional ON CONFLICT clause forwarded to each batch insert.
    on_conflict: Option<OnConflict>,
    /// Schema reported to DataFusion for this sink's input.
    schema: SchemaRef,
}
109
110#[async_trait]
111impl DataSink for PostgresDataSink {
112 fn as_any(&self) -> &dyn Any {
113 self
114 }
115
116 fn metrics(&self) -> Option<MetricsSet> {
117 None
118 }
119
120 fn schema(&self) -> &SchemaRef {
121 &self.schema
122 }
123
124 async fn write_all(
125 &self,
126 mut data: SendableRecordBatchStream,
127 _context: &Arc<TaskContext>,
128 ) -> datafusion::common::Result<u64> {
129 let mut num_rows = 0;
130
131 let mut db_conn = self.postgres.connect().await.map_err(to_datafusion_error)?;
132 let postgres_conn = Postgres::postgres_conn(&mut db_conn).map_err(to_datafusion_error)?;
133
134 let tx = postgres_conn
135 .conn
136 .transaction()
137 .await
138 .context(super::UnableToBeginTransactionSnafu)
139 .map_err(to_datafusion_error)?;
140
141 if matches!(self.overwrite, InsertOp::Overwrite) {
142 self.postgres
143 .delete_all_table_data(&tx)
144 .await
145 .map_err(to_datafusion_error)?;
146 }
147
148 let postgres_fields = self
149 .postgres
150 .schema
151 .fields
152 .iter()
153 .map(|f| {
154 Arc::new(Field::new(
155 f.name(),
156 if f.data_type() == &DataType::LargeUtf8 {
157 DataType::Utf8
158 } else {
159 f.data_type().clone()
160 },
161 f.is_nullable(),
162 ))
163 })
164 .collect::<Vec<_>>();
165
166 let postgres_schema = Arc::new(Schema::new(postgres_fields));
167
168 while let Some(batch) = data.next().await {
169 let batch = batch.map_err(check_and_mark_retriable_error)?;
170
171 let batch_fields = batch
175 .schema_ref()
176 .fields()
177 .iter()
178 .map(|f| {
179 Arc::new(Field::new(
180 f.name(),
181 if f.data_type() == &DataType::LargeUtf8 {
182 DataType::Utf8
183 } else {
184 f.data_type().clone()
185 },
186 f.is_nullable(),
187 ))
188 })
189 .collect::<Vec<_>>();
190 let batch_schema = Arc::new(Schema::new(batch_fields));
191
192 if !Arc::clone(&postgres_schema).equivalent_names_and_types(&batch_schema) {
193 return Err(to_datafusion_error(super::Error::SchemaValidationError {
194 table_name: self.postgres.table.to_string(),
195 }));
196 }
197
198 let batch_num_rows = batch.num_rows();
199
200 if batch_num_rows == 0 {
201 continue;
202 };
203
204 num_rows += batch_num_rows as u64;
205
206 constraints::validate_batch_with_constraints(
207 &[batch.clone()],
208 self.postgres.constraints(),
209 )
210 .await
211 .context(super::ConstraintViolationSnafu)
212 .map_err(to_datafusion_error)?;
213
214 self.postgres
215 .insert_batch(&tx, batch, self.on_conflict.clone())
216 .await
217 .map_err(to_datafusion_error)?;
218 }
219
220 tx.commit()
221 .await
222 .context(super::UnableToCommitPostgresTransactionSnafu)
223 .map_err(to_datafusion_error)?;
224
225 Ok(num_rows)
226 }
227}
228
229impl PostgresDataSink {
230 fn new(
231 postgres: Arc<Postgres>,
232 overwrite: InsertOp,
233 on_conflict: Option<OnConflict>,
234 schema: SchemaRef,
235 ) -> Self {
236 Self {
237 postgres,
238 overwrite,
239 on_conflict,
240 schema,
241 }
242 }
243}
244
245impl std::fmt::Debug for PostgresDataSink {
246 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
247 write!(f, "PostgresDataSink")
248 }
249}
250
251impl DisplayAs for PostgresDataSink {
252 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
253 write!(f, "PostgresDataSink")
254 }
255}