datafusion_table_providers/
postgres.rs

1use crate::sql::arrow_sql_gen::statement::{
2    CreateTableBuilder, Error as SqlGenError, IndexBuilder, InsertBuilder,
3};
4use crate::sql::db_connection_pool::{
5    self,
6    dbconnection::{postgresconn::PostgresConnection, DbConnection},
7    postgrespool::{self, PostgresConnectionPool},
8    DbConnectionPool,
9};
10use crate::sql::sql_provider_datafusion::SqlTable;
11use crate::util::schema::SchemaValidator;
12use crate::UnsupportedTypeAction;
13use arrow::{
14    array::RecordBatch,
15    datatypes::{Schema, SchemaRef},
16};
17use async_trait::async_trait;
18use bb8_postgres::{
19    tokio_postgres::{types::ToSql, Transaction},
20    PostgresConnectionManager,
21};
22use datafusion::catalog::Session;
23use datafusion::sql::unparser::dialect::PostgreSqlDialect;
24use datafusion::{
25    catalog::TableProviderFactory,
26    common::Constraints,
27    datasource::TableProvider,
28    error::{DataFusionError, Result as DataFusionResult},
29    logical_expr::CreateExternalTable,
30    sql::TableReference,
31};
32use postgres_native_tls::MakeTlsConnector;
33use snafu::prelude::*;
34use std::{collections::HashMap, sync::Arc};
35
36use crate::util::{
37    self,
38    column_reference::{self, ColumnReference},
39    constraints::{self, get_primary_keys_from_constraints},
40    indexes::IndexType,
41    on_conflict::{self, OnConflict},
42    secrets::to_secret_map,
43    to_datafusion_error,
44};
45
46use self::write::PostgresTableWriter;
47
48pub mod write;
49
/// Trait-object form of a Postgres connection pool.
///
/// This is a [`DbConnectionPool`] whose first type parameter is a bb8 pooled connection
/// managed by `PostgresConnectionManager<MakeTlsConnector>` and whose second type
/// parameter is `&'static (dyn ToSql + Sync)`. The object is additionally required to be
/// `Send + Sync`.
pub type DynPostgresConnectionPool = dyn DbConnectionPool<
        bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
        &'static (dyn ToSql + Sync),
    > + Send
    + Sync;
/// Trait-object form of a single Postgres connection.
///
/// This is a [`DbConnection`] parameterised with the same two types as
/// [`DynPostgresConnectionPool`].
pub type DynPostgresConnection = dyn DbConnection<
    bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
    &'static (dyn ToSql + Sync),
>;
59
/// Errors produced by the Postgres table provider.
///
/// Each variant's `snafu` display string is the user-facing message; variants that carry a
/// `source` wrap the underlying error that triggered them.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Wraps a generic error raised by the database connection layer
    /// (in this file: a failure to obtain a connection from the pool).
    #[snafu(display("DbConnectionError: {source}"))]
    DbConnectionError {
        source: db_connection_pool::dbconnection::GenericError,
    },

    /// The Postgres connection pool could not be constructed.
    #[snafu(display("Unable to create Postgres connection pool: {source}"))]
    UnableToCreatePostgresConnectionPool { source: postgrespool::Error },

    /// A type-erased connection could not be downcast to `PostgresConnection`.
    #[snafu(display("Unable to downcast DbConnection to PostgresConnection"))]
    UnableToDowncastDbConnection {},

    /// Starting a transaction on the Postgres connection failed.
    #[snafu(display("Unable to begin Postgres transaction: {source}"))]
    UnableToBeginTransaction {
        source: tokio_postgres::error::Error,
    },

    /// Executing a table-creation statement failed.
    #[snafu(display("Unable to create the Postgres table: {source}"))]
    UnableToCreatePostgresTable {
        source: tokio_postgres::error::Error,
    },

    /// Executing an index-creation statement failed.
    #[snafu(display("Unable to create an index for the Postgres table: {source}"))]
    UnableToCreateIndexForPostgresTable {
        source: tokio_postgres::error::Error,
    },

    /// Committing the transaction failed.
    #[snafu(display("Unable to commit the Postgres transaction: {source}"))]
    UnableToCommitPostgresTransaction {
        source: tokio_postgres::error::Error,
    },

    /// SQL generation failed with a DataFusion error.
    #[snafu(display("Unable to generate SQL: {source}"))]
    UnableToGenerateSQL { source: DataFusionError },

    /// The statement that deletes every row of the table failed.
    #[snafu(display("Unable to delete all data from the Postgres table: {source}"))]
    UnableToDeleteAllTableData {
        source: tokio_postgres::error::Error,
    },

    /// A statement deleting data from the table failed.
    #[snafu(display("Unable to delete data from the Postgres table: {source}"))]
    UnableToDeleteData {
        source: tokio_postgres::error::Error,
    },

    /// Executing the insert statement for an Arrow batch failed.
    #[snafu(display("Unable to insert Arrow batch to Postgres table: {source}"))]
    UnableToInsertArrowBatch {
        source: tokio_postgres::error::Error,
    },

    /// The insert statement for an Arrow batch could not be generated.
    #[snafu(display("Unable to create insertion statement for Postgres table: {source}"))]
    UnableToCreateInsertStatement { source: SqlGenError },

    /// The referenced table was not found on the server.
    #[snafu(display("The table '{table_name}' doesn't exist in the Postgres server"))]
    TableDoesntExist { table_name: String },

    /// A constraint check reported a violation.
    #[snafu(display("Constraint Violation: {source}"))]
    ConstraintViolation { source: constraints::Error },

    /// A column reference string (e.g. a key of the `indexes` option) could not be parsed.
    #[snafu(display("Error parsing column reference: {source}"))]
    UnableToParseColumnReference { source: column_reference::Error },

    /// The `on_conflict` option string could not be parsed.
    #[snafu(display("Error parsing on_conflict: {source}"))]
    UnableToParseOnConflict { source: on_conflict::Error },

    /// Table creation was requested with a schema-qualified name, which is rejected.
    #[snafu(display(
        "Failed to create '{table_name}': creating a table with a schema is not supported"
    ))]
    TableWithSchemaCreationNotSupported { table_name: String },

    /// The schema of the provided data does not match the table's expected schema.
    #[snafu(display("Schema validation error: the provided data schema does not match the expected table schema: '{table_name}'"))]
    SchemaValidationError { table_name: String },
}

/// Module-local result alias whose error type defaults to [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
136
137pub struct PostgresTableFactory {
138    pool: Arc<PostgresConnectionPool>,
139}
140
141impl PostgresTableFactory {
142    #[must_use]
143    pub fn new(pool: Arc<PostgresConnectionPool>) -> Self {
144        Self { pool }
145    }
146
147    pub async fn table_provider(
148        &self,
149        table_reference: TableReference,
150    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
151        let pool = Arc::clone(&self.pool);
152        let dyn_pool: Arc<DynPostgresConnectionPool> = pool;
153
154        let table_provider = Arc::new(
155            SqlTable::new("postgres", &dyn_pool, table_reference)
156                .await
157                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?
158                .with_dialect(Arc::new(PostgreSqlDialect {})),
159        );
160
161        #[cfg(feature = "postgres-federation")]
162        let table_provider = Arc::new(
163            table_provider
164                .create_federated_table_provider()
165                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?,
166        );
167
168        Ok(table_provider)
169    }
170
171    pub async fn read_write_table_provider(
172        &self,
173        table_reference: TableReference,
174    ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> {
175        let read_provider = Self::table_provider(self, table_reference.clone()).await?;
176        let schema = read_provider.schema();
177
178        let postgres = Postgres::new(
179            table_reference,
180            Arc::clone(&self.pool),
181            schema,
182            Constraints::default(),
183        );
184
185        Ok(PostgresTableWriter::create(read_provider, postgres, None))
186    }
187}
188
/// Stateless factory type for Postgres-backed tables.
///
/// It is a unit struct, so every value is identical; `new` and `Default::default` are
/// interchangeable.
#[derive(Debug, Default)]
pub struct PostgresTableProviderFactory;

impl PostgresTableProviderFactory {
    /// Returns a new factory value.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
204
205#[async_trait]
206impl TableProviderFactory for PostgresTableProviderFactory {
207    async fn create(
208        &self,
209        _state: &dyn Session,
210        cmd: &CreateExternalTable,
211    ) -> DataFusionResult<Arc<dyn TableProvider>> {
212        if cmd.name.schema().is_some() {
213            TableWithSchemaCreationNotSupportedSnafu {
214                table_name: cmd.name.to_string(),
215            }
216            .fail()
217            .map_err(to_datafusion_error)?;
218        }
219
220        let name = cmd.name.clone();
221        let mut options = cmd.options.clone();
222        let schema: Schema = cmd.schema.as_ref().into();
223
224        let indexes_option_str = options.remove("indexes");
225        let unparsed_indexes: HashMap<String, IndexType> = match indexes_option_str {
226            Some(indexes_str) => util::hashmap_from_option_string(&indexes_str),
227            None => HashMap::new(),
228        };
229
230        let unparsed_indexes = unparsed_indexes
231            .into_iter()
232            .map(|(key, value)| {
233                let columns = ColumnReference::try_from(key.as_str())
234                    .context(UnableToParseColumnReferenceSnafu)
235                    .map_err(to_datafusion_error);
236                (columns, value)
237            })
238            .collect::<Vec<(Result<ColumnReference, DataFusionError>, IndexType)>>();
239
240        let mut indexes: Vec<(ColumnReference, IndexType)> = Vec::new();
241        for (columns, index_type) in unparsed_indexes {
242            let columns = columns?;
243            indexes.push((columns, index_type));
244        }
245
246        let mut on_conflict: Option<OnConflict> = None;
247        if let Some(on_conflict_str) = options.remove("on_conflict") {
248            on_conflict = Some(
249                OnConflict::try_from(on_conflict_str.as_str())
250                    .context(UnableToParseOnConflictSnafu)
251                    .map_err(to_datafusion_error)?,
252            );
253        }
254
255        let params = to_secret_map(options);
256
257        let pool = Arc::new(
258            PostgresConnectionPool::new(params)
259                .await
260                .context(UnableToCreatePostgresConnectionPoolSnafu)
261                .map_err(to_datafusion_error)?,
262        );
263
264        let schema: SchemaRef = Arc::new(schema);
265        PostgresConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::default())
266            .map_err(|e| DataFusionError::External(e.into()))?;
267
268        let postgres = Postgres::new(
269            name.clone(),
270            Arc::clone(&pool),
271            Arc::clone(&schema),
272            cmd.constraints.clone(),
273        );
274
275        let mut db_conn = pool
276            .connect()
277            .await
278            .context(DbConnectionSnafu)
279            .map_err(to_datafusion_error)?;
280        let postgres_conn = Postgres::postgres_conn(&mut db_conn).map_err(to_datafusion_error)?;
281
282        let tx = postgres_conn
283            .conn
284            .transaction()
285            .await
286            .context(UnableToBeginTransactionSnafu)
287            .map_err(to_datafusion_error)?;
288
289        let primary_keys = get_primary_keys_from_constraints(&cmd.constraints, &schema);
290
291        postgres
292            .create_table(Arc::clone(&schema), &tx, primary_keys)
293            .await
294            .map_err(to_datafusion_error)?;
295
296        for index in indexes {
297            postgres
298                .create_index(&tx, index.0.iter().collect(), index.1 == IndexType::Unique)
299                .await
300                .map_err(to_datafusion_error)?;
301        }
302
303        tx.commit()
304            .await
305            .context(UnableToCommitPostgresTransactionSnafu)
306            .map_err(to_datafusion_error)?;
307
308        let dyn_pool: Arc<DynPostgresConnectionPool> = pool;
309
310        let read_provider = Arc::new(
311            SqlTable::new_with_schema("postgres", &dyn_pool, Arc::clone(&schema), name)
312                .with_dialect(Arc::new(PostgreSqlDialect {})),
313        );
314
315        #[cfg(feature = "postgres-federation")]
316        let read_provider = Arc::new(read_provider.create_federated_table_provider()?);
317
318        Ok(PostgresTableWriter::create(
319            read_provider,
320            postgres,
321            on_conflict,
322        ))
323    }
324}
325
326#[derive(Clone)]
327pub struct Postgres {
328    table: TableReference,
329    pool: Arc<PostgresConnectionPool>,
330    schema: SchemaRef,
331    constraints: Constraints,
332}
333
334impl std::fmt::Debug for Postgres {
335    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336        f.debug_struct("Postgres")
337            .field("table_name", &self.table)
338            .field("schema", &self.schema)
339            .field("constraints", &self.constraints)
340            .finish()
341    }
342}
343
344impl Postgres {
345    #[must_use]
346    pub fn new(
347        table: TableReference,
348        pool: Arc<PostgresConnectionPool>,
349        schema: SchemaRef,
350        constraints: Constraints,
351    ) -> Self {
352        Self {
353            table,
354            pool,
355            schema,
356            constraints,
357        }
358    }
359
360    #[must_use]
361    pub fn table_name(&self) -> &str {
362        self.table.table()
363    }
364
365    #[must_use]
366    pub fn constraints(&self) -> &Constraints {
367        &self.constraints
368    }
369
370    pub async fn connect(&self) -> Result<Box<DynPostgresConnection>> {
371        let mut conn = self.pool.connect().await.context(DbConnectionSnafu)?;
372
373        let pg_conn = Self::postgres_conn(&mut conn)?;
374
375        if !self.table_exists(pg_conn).await {
376            TableDoesntExistSnafu {
377                table_name: self.table.to_string(),
378            }
379            .fail()?;
380        }
381
382        Ok(conn)
383    }
384
385    pub fn postgres_conn(
386        db_connection: &mut Box<DynPostgresConnection>,
387    ) -> Result<&mut PostgresConnection> {
388        db_connection
389            .as_any_mut()
390            .downcast_mut::<PostgresConnection>()
391            .context(UnableToDowncastDbConnectionSnafu)
392    }
393
394    async fn table_exists(&self, postgres_conn: &PostgresConnection) -> bool {
395        let sql = match self.table.schema() {
396            Some(schema) => format!(
397                r#"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '{name}' AND table_schema = '{schema}')"#,
398                name = self.table.table(),
399                schema = schema
400            ),
401            None => format!(
402                r#"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '{name}')"#,
403                name = self.table.table()
404            ),
405        };
406
407        tracing::trace!("{sql}");
408
409        let Ok(row) = postgres_conn.conn.query_one(&sql, &[]).await else {
410            return false;
411        };
412
413        row.get(0)
414    }
415
416    async fn insert_batch(
417        &self,
418        transaction: &Transaction<'_>,
419        batch: RecordBatch,
420        on_conflict: Option<OnConflict>,
421    ) -> Result<()> {
422        let insert_table_builder = InsertBuilder::new(&self.table, vec![batch]);
423
424        let sea_query_on_conflict =
425            on_conflict.map(|oc| oc.build_sea_query_on_conflict(&self.schema));
426
427        let sql = insert_table_builder
428            .build_postgres(sea_query_on_conflict)
429            .context(UnableToCreateInsertStatementSnafu)?;
430
431        transaction
432            .execute(&sql, &[])
433            .await
434            .context(UnableToInsertArrowBatchSnafu)?;
435
436        Ok(())
437    }
438
439    async fn delete_all_table_data(&self, transaction: &Transaction<'_>) -> Result<()> {
440        transaction
441            .execute(
442                format!(r#"DELETE FROM {}"#, self.table.to_quoted_string()).as_str(),
443                &[],
444            )
445            .await
446            .context(UnableToDeleteAllTableDataSnafu)?;
447
448        Ok(())
449    }
450
451    async fn create_table(
452        &self,
453        schema: SchemaRef,
454        transaction: &Transaction<'_>,
455        primary_keys: Vec<String>,
456    ) -> Result<()> {
457        let create_table_statement =
458            CreateTableBuilder::new(schema, self.table.table()).primary_keys(primary_keys);
459        let create_stmts = create_table_statement.build_postgres();
460
461        for create_stmt in create_stmts {
462            transaction
463                .execute(&create_stmt, &[])
464                .await
465                .context(UnableToCreatePostgresTableSnafu)?;
466        }
467
468        Ok(())
469    }
470
471    async fn create_index(
472        &self,
473        transaction: &Transaction<'_>,
474        columns: Vec<&str>,
475        unique: bool,
476    ) -> Result<()> {
477        let mut index_builder = IndexBuilder::new(self.table.table(), columns);
478        if unique {
479            index_builder = index_builder.unique();
480        }
481        let sql = index_builder.build_postgres();
482
483        transaction
484            .execute(&sql, &[])
485            .await
486            .context(UnableToCreateIndexForPostgresTableSnafu)?;
487
488        Ok(())
489    }
490}