use crate::sql::arrow_sql_gen::statement::{
    CreateTableBuilder, Error as SqlGenError, IndexBuilder, InsertBuilder,
};
use crate::sql::db_connection_pool::{
    self,
    dbconnection::{postgresconn::PostgresConnection, DbConnection},
    postgrespool::{self, PostgresConnectionPool},
    DbConnectionPool,
};
use crate::sql::sql_provider_datafusion::SqlTable;
use crate::util::schema::SchemaValidator;
use crate::UnsupportedTypeAction;
use arrow::{
    array::RecordBatch,
    datatypes::{Schema, SchemaRef},
};
use async_trait::async_trait;
use bb8_postgres::{
    tokio_postgres::{types::ToSql, Transaction},
    PostgresConnectionManager,
};
use datafusion::catalog::Session;
use datafusion::sql::unparser::dialect::PostgreSqlDialect;
use datafusion::{
    catalog::TableProviderFactory,
    common::Constraints,
    datasource::TableProvider,
    error::{DataFusionError, Result as DataFusionResult},
    logical_expr::CreateExternalTable,
    sql::TableReference,
};
use postgres_native_tls::MakeTlsConnector;
use snafu::prelude::*;
use std::{collections::HashMap, sync::Arc};

use crate::util::{
    self,
    column_reference::{self, ColumnReference},
    constraints::{self, get_primary_keys_from_constraints},
    indexes::IndexType,
    on_conflict::{self, OnConflict},
    secrets::to_secret_map,
    to_datafusion_error,
};

use self::write::PostgresTableWriter;

pub mod write;

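/// Trait-object alias for the dynamic Postgres connection pool used in this
/// module: a [`DbConnectionPool`] yielding `bb8` pooled connections
/// established over a native-TLS connector.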
pub type DynPostgresConnectionPool = dyn DbConnectionPool<
        bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
        &'static (dyn ToSql + Sync),
    > + Send
    + Sync;

/// Trait-object alias for a single dynamic Postgres connection checked out of the pool.
pub type DynPostgresConnection = dyn DbConnection<
    bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
    &'static (dyn ToSql + Sync),
>;

/// Errors that can be returned by the Postgres table provider and writer.
#[derive(Debug, Snafu)]
pub enum Error {
    #[snafu(display("DbConnectionError: {source}"))]
    DbConnectionError {
        source: db_connection_pool::dbconnection::GenericError,
    },

    #[snafu(display("Unable to create Postgres connection pool: {source}"))]
    UnableToCreatePostgresConnectionPool { source: postgrespool::Error },

    #[snafu(display("Unable to downcast DbConnection to PostgresConnection"))]
    UnableToDowncastDbConnection {},

    #[snafu(display("Unable to begin Postgres transaction: {source}"))]
    UnableToBeginTransaction {
        source: tokio_postgres::error::Error,
    },

    #[snafu(display("Unable to create the Postgres table: {source}"))]
    UnableToCreatePostgresTable {
        source: tokio_postgres::error::Error,
    },

    #[snafu(display("Unable to create an index for the Postgres table: {source}"))]
    UnableToCreateIndexForPostgresTable {
        source: tokio_postgres::error::Error,
    },

    #[snafu(display("Unable to commit the Postgres transaction: {source}"))]
    UnableToCommitPostgresTransaction {
        source: tokio_postgres::error::Error,
    },

    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },

    #[snafu(display("Unable to delete all data from the Postgres table: {source}"))]
    UnableToDeleteAllTableData {
        source: tokio_postgres::error::Error,
    },

    #[snafu(display("Unable to delete data from the Postgres table: {source}"))]
    UnableToDeleteData {
        source: tokio_postgres::error::Error,
    },

    #[snafu(display("Unable to insert Arrow batch into the Postgres table: {source}"))]
    UnableToInsertArrowBatch {
        source: tokio_postgres::error::Error,
    },

    #[snafu(display("Unable to create insertion statement for the Postgres table: {source}"))]
    UnableToCreateInsertStatement { source: SqlGenError },

    #[snafu(display("The table '{table_name}' does not exist on the Postgres server"))]
    TableDoesntExist { table_name: String },

    #[snafu(display("Constraint violation: {source}"))]
    ConstraintViolation { source: constraints::Error },

    #[snafu(display("Error parsing column reference: {source}"))]
    UnableToParseColumnReference { source: column_reference::Error },

    #[snafu(display("Error parsing on_conflict: {source}"))]
    UnableToParseOnConflict { source: on_conflict::Error },

    #[snafu(display(
        "Failed to create '{table_name}': creating a table with a schema is not supported"
    ))]
    TableWithSchemaCreationNotSupported { table_name: String },

    #[snafu(display("Schema validation error: the provided data schema does not match the expected table schema for '{table_name}'"))]
    SchemaValidationError { table_name: String },
}

type Result<T, E = Error> = std::result::Result<T, E>;

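/// Factory for creating read-only and read-write [`TableProvider`]s backed by
/// existing tables on a Postgres server, sharing a [`PostgresConnectionPool`].
///
/// A minimal usage sketch (illustrative, not compiled): the connection
/// parameter keys shown are assumptions; see [`PostgresConnectionPool`] for
/// the exact set it accepts.
///
/// ```ignore
/// use std::{collections::HashMap, sync::Arc};
///
/// // Hypothetical connection parameters, converted into secrets.
/// let params = to_secret_map(HashMap::from([
///     ("host".to_string(), "localhost".to_string()),
///     ("user".to_string(), "postgres".to_string()),
///     ("db".to_string(), "postgres".to_string()),
///     ("pass".to_string(), "password".to_string()),
/// ]));
/// let pool = Arc::new(PostgresConnectionPool::new(params).await?);
///
/// let factory = PostgresTableFactory::new(pool);
/// // Read-only provider for an existing table named `my_table`.
/// let provider = factory
///     .table_provider(TableReference::bare("my_table"))
///     .await?;
/// ```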
pub struct PostgresTableFactory {
    pool: Arc<PostgresConnectionPool>,
}

impl PostgresTableFactory {
    /// Creates a new `PostgresTableFactory` backed by the given connection pool.
    #[must_use]
    pub fn new(pool: Arc<PostgresConnectionPool>) -> Self {
        Self { pool }
    }

    /// Creates a read-only [`TableProvider`] for an existing Postgres table.
    pub async fn table_provider(
        &self,
        table_reference: TableReference,
    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
        let pool = Arc::clone(&self.pool);
        let dyn_pool: Arc<DynPostgresConnectionPool> = pool;

        let table_provider = Arc::new(
            SqlTable::new("postgres", &dyn_pool, table_reference)
                .await
                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
                .with_dialect(Arc::new(PostgreSqlDialect {})),
        );

        #[cfg(feature = "postgres-federation")]
        let table_provider = Arc::new(
            table_provider
                .create_federated_table_provider()
                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?,
        );

        Ok(table_provider)
    }

    /// Creates a read-write [`TableProvider`] that reads through the SQL
    /// provider and writes through [`PostgresTableWriter`].
    pub async fn read_write_table_provider(
        &self,
        table_reference: TableReference,
    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
        let read_provider = self.table_provider(table_reference.clone()).await?;
        let schema = read_provider.schema();

        let postgres = Postgres::new(
            table_reference,
            Arc::clone(&self.pool),
            schema,
            Constraints::default(),
        );

        Ok(PostgresTableWriter::create(read_provider, postgres, None))
    }
}

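/// A [`TableProviderFactory`] that handles `CREATE EXTERNAL TABLE` statements
/// by creating the table on the Postgres server and returning a writable
/// provider for it.
///
/// A registration sketch (illustrative, not compiled). The `STORED AS POSTGRES`
/// key and the connection option keys are assumptions that depend on how the
/// factory is registered and on what [`PostgresConnectionPool`] accepts;
/// `indexes` and `on_conflict` options are parsed by this factory itself.
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::execution::session_state::SessionStateBuilder;
/// use datafusion::prelude::SessionContext;
///
/// let mut state = SessionStateBuilder::new().with_default_features().build();
/// state
///     .table_factories_mut()
///     .insert("POSTGRES".to_string(), Arc::new(PostgresTableProviderFactory::new()));
/// let ctx = SessionContext::new_with_state(state);
///
/// ctx.sql(
///     "CREATE EXTERNAL TABLE events (id INT, name TEXT) \
///      STORED AS POSTGRES LOCATION 'events' \
///      OPTIONS ('host' 'localhost', 'user' 'postgres', 'db' 'postgres', 'pass' 'password')",
/// )
/// .await?;
/// ```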
#[derive(Debug)]
pub struct PostgresTableProviderFactory;

impl PostgresTableProviderFactory {
    #[must_use]
    pub fn new() -> Self {
        Self {}
    }
}

impl Default for PostgresTableProviderFactory {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl TableProviderFactory for PostgresTableProviderFactory {
    async fn create(
        &self,
        _state: &dyn Session,
        cmd: &CreateExternalTable,
    ) -> DataFusionResult<Arc<dyn TableProvider>> {
        if cmd.name.schema().is_some() {
            TableWithSchemaCreationNotSupportedSnafu {
                table_name: cmd.name.to_string(),
            }
            .fail()
            .map_err(to_datafusion_error)?;
        }

        let name = cmd.name.clone();
        let mut options = cmd.options.clone();
        let schema: Schema = cmd.schema.as_ref().into();

        // Parse the `indexes` option into column references and index types.
        let indexes_option_str = options.remove("indexes");
        let unparsed_indexes: HashMap<String, IndexType> = match indexes_option_str {
            Some(indexes_str) => util::hashmap_from_option_string(&indexes_str),
            None => HashMap::new(),
        };

        let unparsed_indexes = unparsed_indexes
            .into_iter()
            .map(|(key, value)| {
                let columns = ColumnReference::try_from(key.as_str())
                    .context(UnableToParseColumnReferenceSnafu)
                    .map_err(to_datafusion_error);
                (columns, value)
            })
            .collect::<Vec<(Result<ColumnReference, DataFusionError>, IndexType)>>();

        let mut indexes: Vec<(ColumnReference, IndexType)> = Vec::new();
        for (columns, index_type) in unparsed_indexes {
            let columns = columns?;
            indexes.push((columns, index_type));
        }

        // Parse the optional `on_conflict` option used when writing to the table.
        let mut on_conflict: Option<OnConflict> = None;
        if let Some(on_conflict_str) = options.remove("on_conflict") {
            on_conflict = Some(
                OnConflict::try_from(on_conflict_str.as_str())
                    .context(UnableToParseOnConflictSnafu)
                    .map_err(to_datafusion_error)?,
            );
        }

        // The remaining options are treated as connection parameters.
        let params = to_secret_map(options);

        let pool = Arc::new(
            PostgresConnectionPool::new(params)
                .await
                .context(UnableToCreatePostgresConnectionPoolSnafu)
                .map_err(to_datafusion_error)?,
        );

        let schema: SchemaRef = Arc::new(schema);
        PostgresConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::default())
            .map_err(|e| DataFusionError::External(e.into()))?;

        let postgres = Postgres::new(
            name.clone(),
            Arc::clone(&pool),
            Arc::clone(&schema),
            cmd.constraints.clone(),
        );

        let mut db_conn = pool
            .connect()
            .await
            .context(DbConnectionSnafu)
            .map_err(to_datafusion_error)?;
        let postgres_conn = Postgres::postgres_conn(&mut db_conn).map_err(to_datafusion_error)?;

        let tx = postgres_conn
            .conn
            .transaction()
            .await
            .context(UnableToBeginTransactionSnafu)
            .map_err(to_datafusion_error)?;

        let primary_keys = get_primary_keys_from_constraints(&cmd.constraints, &schema);

        // Create the table and any requested indexes within a single transaction.
        postgres
            .create_table(Arc::clone(&schema), &tx, primary_keys)
            .await
            .map_err(to_datafusion_error)?;

        for index in indexes {
            postgres
                .create_index(&tx, index.0.iter().collect(), index.1 == IndexType::Unique)
                .await
                .map_err(to_datafusion_error)?;
        }

        tx.commit()
            .await
            .context(UnableToCommitPostgresTransactionSnafu)
            .map_err(to_datafusion_error)?;

        let dyn_pool: Arc<DynPostgresConnectionPool> = pool;

        let read_provider = Arc::new(
            SqlTable::new_with_schema("postgres", &dyn_pool, Arc::clone(&schema), name)
                .with_dialect(Arc::new(PostgreSqlDialect {})),
        );

        #[cfg(feature = "postgres-federation")]
        let read_provider = Arc::new(read_provider.create_federated_table_provider()?);

        Ok(PostgresTableWriter::create(
            read_provider,
            postgres,
            on_conflict,
        ))
    }
}

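/// A handle to a single Postgres table: the table reference, the shared
/// connection pool, the Arrow schema, and any declared constraints. Used by
/// [`PostgresTableWriter`] to create, truncate, and insert into the table.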
#[derive(Clone)]
pub struct Postgres {
    table: TableReference,
    pool: Arc<PostgresConnectionPool>,
    schema: SchemaRef,
    constraints: Constraints,
}

// Manual Debug impl: the connection pool is omitted from the output.
impl std::fmt::Debug for Postgres {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Postgres")
            .field("table_name", &self.table)
            .field("schema", &self.schema)
            .field("constraints", &self.constraints)
            .finish()
    }
}

impl Postgres {
    /// Creates a new `Postgres` handle for the given table, pool, schema, and constraints.
    #[must_use]
    pub fn new(
        table: TableReference,
        pool: Arc<PostgresConnectionPool>,
        schema: SchemaRef,
        constraints: Constraints,
    ) -> Self {
        Self {
            table,
            pool,
            schema,
            constraints,
        }
    }

    /// Returns the unqualified name of the underlying table.
    #[must_use]
    pub fn table_name(&self) -> &str {
        self.table.table()
    }

    /// Returns the constraints declared for this table.
    #[must_use]
    pub fn constraints(&self) -> &Constraints {
        &self.constraints
    }

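    /// Checks out a connection from the pool and verifies that the target
    /// table exists, returning [`Error::TableDoesntExist`] otherwise.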
    pub async fn connect(&self) -> Result<Box<DynPostgresConnection>> {
        let mut conn = self.pool.connect().await.context(DbConnectionSnafu)?;

        let pg_conn = Self::postgres_conn(&mut conn)?;

        if !self.table_exists(pg_conn).await {
            TableDoesntExistSnafu {
                table_name: self.table.to_string(),
            }
            .fail()?;
        }

        Ok(conn)
    }

    /// Downcasts a dynamic [`DbConnection`] to a concrete [`PostgresConnection`].
    pub fn postgres_conn(
        db_connection: &mut Box<DynPostgresConnection>,
    ) -> Result<&mut PostgresConnection> {
        db_connection
            .as_any_mut()
            .downcast_mut::<PostgresConnection>()
            .context(UnableToDowncastDbConnectionSnafu)
    }

    async fn table_exists(&self, postgres_conn: &PostgresConnection) -> bool {
        let sql = match self.table.schema() {
            Some(schema) => format!(
                r#"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '{name}' AND table_schema = '{schema}')"#,
                name = self.table.table(),
                schema = schema
            ),
            None => format!(
                r#"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '{name}')"#,
                name = self.table.table()
            ),
        };

        tracing::trace!("{sql}");

        let Ok(row) = postgres_conn.conn.query_one(&sql, &[]).await else {
            return false;
        };

        row.get(0)
    }

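    /// Builds a single multi-row `INSERT` statement for the batch (with an
    /// optional `ON CONFLICT` clause) and executes it within the transaction.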
    async fn insert_batch(
        &self,
        transaction: &Transaction<'_>,
        batch: RecordBatch,
        on_conflict: Option<OnConflict>,
    ) -> Result<()> {
        let insert_table_builder = InsertBuilder::new(&self.table, vec![batch]);

        let sea_query_on_conflict =
            on_conflict.map(|oc| oc.build_sea_query_on_conflict(&self.schema));

        let sql = insert_table_builder
            .build_postgres(sea_query_on_conflict)
            .context(UnableToCreateInsertStatementSnafu)?;

        transaction
            .execute(&sql, &[])
            .await
            .context(UnableToInsertArrowBatchSnafu)?;

        Ok(())
    }

    async fn delete_all_table_data(&self, transaction: &Transaction<'_>) -> Result<()> {
        transaction
            .execute(
                format!(r#"DELETE FROM {}"#, self.table.to_quoted_string()).as_str(),
                &[],
            )
            .await
            .context(UnableToDeleteAllTableDataSnafu)?;

        Ok(())
    }

    async fn create_table(
        &self,
        schema: SchemaRef,
        transaction: &Transaction<'_>,
        primary_keys: Vec<String>,
    ) -> Result<()> {
        let create_table_statement =
            CreateTableBuilder::new(schema, self.table.table()).primary_keys(primary_keys);
        let create_stmts = create_table_statement.build_postgres();

        for create_stmt in create_stmts {
            transaction
                .execute(&create_stmt, &[])
                .await
                .context(UnableToCreatePostgresTableSnafu)?;
        }

        Ok(())
    }

    async fn create_index(
        &self,
        transaction: &Transaction<'_>,
        columns: Vec<&str>,
        unique: bool,
    ) -> Result<()> {
        let mut index_builder = IndexBuilder::new(self.table.table(), columns);
        if unique {
            index_builder = index_builder.unique();
        }
        let sql = index_builder.build_postgres();

        transaction
            .execute(&sql, &[])
            .await
            .context(UnableToCreateIndexForPostgresTableSnafu)?;

        Ok(())
    }
}