// datafusion_table_providers/sql/db_connection_pool/dbconnection/postgresconn.rs

1use std::any::Any;
2use std::error::Error;
3use std::sync::Arc;
4
5use crate::sql::arrow_sql_gen::postgres::rows_to_arrow;
6use crate::sql::arrow_sql_gen::postgres::schema::pg_data_type_to_arrow_type;
7use crate::sql::arrow_sql_gen::postgres::schema::ParseContext;
8use crate::util::handle_unsupported_type_error;
9use crate::util::schema::SchemaValidator;
10use arrow::datatypes::Field;
11use arrow::datatypes::Schema;
12use arrow::datatypes::SchemaRef;
13use arrow_schema::DataType;
14use async_stream::stream;
15use bb8_postgres::tokio_postgres::types::ToSql;
16use bb8_postgres::PostgresConnectionManager;
17use datafusion::error::DataFusionError;
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::sql::TableReference;
21use futures::stream;
22use futures::StreamExt;
23use postgres_native_tls::MakeTlsConnector;
24use snafu::prelude::*;
25
26use crate::UnsupportedTypeAction;
27
28use super::AsyncDbConnection;
29use super::DbConnection;
30use super::Result;
31
/// Catalog query that describes the columns of one relation.
///
/// Parameters: `$1` is the schema name, `$2` is the relation name. Tables, views and
/// materialized views are covered (`relkind IN ('r','v','m')`).
///
/// One row is returned per live column, ordered by attribute number:
/// 1. `column_name`
/// 2. `data_type` — `'array'`, `'enum'`, `'composite'`, or the `format_type` rendering
/// 3. `is_nullable` — `'YES'` / `'NO'`
/// 4. `type_details` — JSONB describing array element types, enum labels or composite
///    attributes; `NULL` for every other type
///
/// Enum/composite details are matched to the column type by OID rather than by bare type
/// name, so a type defined in a different schema than the table is still described, and a
/// same-named type in the table's schema can never be attached to an unrelated column type.
const SCHEMA_QUERY: &str = r"
WITH custom_type_details AS (
SELECT
t.oid,
t.typname,
t.typtype,
CASE
    WHEN t.typtype = 'e' THEN
        jsonb_build_object(
            'type', 'enum',
            'values', (
                SELECT jsonb_agg(e.enumlabel ORDER BY e.enumsortorder)
                FROM pg_enum e
                WHERE e.enumtypid = t.oid
            )
        )
    WHEN t.typtype = 'c' THEN
        jsonb_build_object(
            'type', 'composite',
            'attributes', (
                SELECT jsonb_agg(
                    jsonb_build_object(
                        'name', a2.attname,
                        'type', pg_catalog.format_type(a2.atttypid, a2.atttypmod)
                    )
                    ORDER BY a2.attnum
                )
                FROM pg_attribute a2
                WHERE a2.attrelid = t.typrelid
                AND a2.attnum > 0
                AND NOT a2.attisdropped
            )
        )
END as type_details
FROM pg_type t
WHERE t.typtype IN ('e', 'c')
)
SELECT
    a.attname AS column_name,
    CASE
    -- when an array type is encountered, label as 'array'
    WHEN t.typcategory = 'A' THEN 'array'
    -- if it’s a user-defined enum or composite type then output that specific string
    WHEN t.typtype = 'e' THEN 'enum'
    WHEN t.typtype = 'c' THEN 'composite'
    ELSE pg_catalog.format_type(a.atttypid, a.atttypmod)
    END AS data_type,
    CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
    CASE
    WHEN t.typcategory = 'A' THEN
        jsonb_build_object(
        'type', 'array',
        'element_type', (
            SELECT pg_catalog.format_type(et.oid, a.atttypmod)
            FROM pg_type t2
            JOIN pg_type et ON t2.typelem = et.oid
            WHERE t2.oid = a.atttypid
        )
        )
    ELSE custom.type_details
    END AS type_details
FROM pg_class cls
JOIN pg_namespace ns ON cls.relnamespace = ns.oid
JOIN pg_attribute a ON a.attrelid = cls.oid
LEFT JOIN pg_type t ON t.oid = a.atttypid
LEFT JOIN custom_type_details custom ON custom.oid = t.oid
WHERE ns.nspname = $1
    AND cls.relname = $2
    AND cls.relkind IN ('r','v','m')  -- covers tables, normal views, & materialized views
    AND a.attnum > 0
    AND NOT a.attisdropped
ORDER BY a.attnum;
";
105
/// Lists every namespace except `pg_catalog`, `information_schema` and the namespaces
/// whose name starts with `pg_toast`. Takes no parameters; returns one `schema_name`
/// column.
const SCHEMAS_QUERY: &str = r"
SELECT nspname AS schema_name
FROM pg_namespace
WHERE nspname NOT IN ('pg_catalog', 'information_schema')
  AND nspname !~ '^pg_toast';
";
112
/// Lists the names found in `pg_tables` for the schema passed as `$1`; returns one
/// `tablename` column.
const TABLES_QUERY: &str = r"
SELECT tablename
FROM pg_tables
WHERE schemaname = $1;
";
118
119#[derive(Debug, Snafu)]
120pub enum PostgresError {
121    #[snafu(display(
122        "Query execution failed.\n{source}\nFor details, refer to the PostgreSQL manual: https://www.postgresql.org/docs/17/index.html"
123    ))]
124    QueryError {
125        source: bb8_postgres::tokio_postgres::Error,
126    },
127
128    #[snafu(display("Failed to convert query result to Arrow.\n{source}\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
129    ConversionError {
130        source: crate::sql::arrow_sql_gen::postgres::Error,
131    },
132}
133
134pub struct PostgresConnection {
135    pub conn: bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
136    unsupported_type_action: UnsupportedTypeAction,
137}
138
139impl SchemaValidator for PostgresConnection {
140    type Error = super::Error;
141
142    fn is_data_type_supported(data_type: &DataType) -> bool {
143        !matches!(data_type, DataType::Map(_, _))
144    }
145
146    fn unsupported_type_error(data_type: &DataType, field_name: &str) -> Self::Error {
147        super::Error::UnsupportedDataType {
148            data_type: data_type.to_string(),
149            field_name: field_name.to_string(),
150        }
151    }
152}
153
154impl<'a>
155    DbConnection<
156        bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
157        &'a (dyn ToSql + Sync),
158    > for PostgresConnection
159{
160    fn as_any(&self) -> &dyn Any {
161        self
162    }
163
164    fn as_any_mut(&mut self) -> &mut dyn Any {
165        self
166    }
167
168    fn as_async(
169        &self,
170    ) -> Option<
171        &dyn AsyncDbConnection<
172            bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
173            &'a (dyn ToSql + Sync),
174        >,
175    > {
176        Some(self)
177    }
178}
179
180#[async_trait::async_trait]
181impl<'a>
182    AsyncDbConnection<
183        bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
184        &'a (dyn ToSql + Sync),
185    > for PostgresConnection
186{
187    fn new(
188        conn: bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
189    ) -> Self {
190        PostgresConnection {
191            conn,
192            unsupported_type_action: UnsupportedTypeAction::default(),
193        }
194    }
195
196    async fn tables(&self, schema: &str) -> Result<Vec<String>, super::Error> {
197        let rows = self
198            .conn
199            .query(TABLES_QUERY, &[&schema])
200            .await
201            .map_err(|e| super::Error::UnableToGetTables {
202                source: Box::new(e),
203            })?;
204
205        Ok(rows.iter().map(|r| r.get::<usize, String>(0)).collect())
206    }
207
208    async fn schemas(&self) -> Result<Vec<String>, super::Error> {
209        let rows = self.conn.query(SCHEMAS_QUERY, &[]).await.map_err(|e| {
210            super::Error::UnableToGetSchemas {
211                source: Box::new(e),
212            }
213        })?;
214
215        Ok(rows.iter().map(|r| r.get::<usize, String>(0)).collect())
216    }
217
218    async fn get_schema(
219        &self,
220        table_reference: &TableReference,
221    ) -> Result<SchemaRef, super::Error> {
222        let table_name = table_reference.table();
223        let schema_name = table_reference.schema().unwrap_or("public");
224
225        let rows = match self
226            .conn
227            .query(SCHEMA_QUERY, &[&schema_name, &table_name])
228            .await
229        {
230            Ok(rows) => rows,
231            Err(e) => {
232                if let Some(error_source) = e.source() {
233                    if let Some(pg_error) =
234                        error_source.downcast_ref::<tokio_postgres::error::DbError>()
235                    {
236                        if pg_error.code() == &tokio_postgres::error::SqlState::UNDEFINED_TABLE {
237                            return Err(super::Error::UndefinedTable {
238                                source: Box::new(pg_error.clone()),
239                                table_name: table_reference.to_string(),
240                            });
241                        }
242                    }
243                }
244                return Err(super::Error::UnableToGetSchema {
245                    source: Box::new(e),
246                });
247            }
248        };
249
250        let mut fields = Vec::new();
251        for row in rows {
252            let column_name = row.get::<usize, String>(0);
253            let pg_type = row.get::<usize, String>(1);
254            let nullable_str = row.get::<usize, String>(2);
255            let nullable = nullable_str == "YES";
256            let type_details = row.get::<usize, Option<serde_json::Value>>(3);
257            let mut context =
258                ParseContext::new().with_unsupported_type_action(self.unsupported_type_action);
259
260            if let Some(type_details) = type_details {
261                context = context.with_type_details(type_details);
262            };
263
264            let Ok(arrow_type) = pg_data_type_to_arrow_type(&pg_type, &context) else {
265                handle_unsupported_type_error(
266                    self.unsupported_type_action,
267                    super::Error::UnsupportedDataType {
268                        data_type: pg_type.to_string(),
269                        field_name: column_name.to_string(),
270                    },
271                )?;
272
273                continue;
274            };
275
276            fields.push(Field::new(column_name, arrow_type, nullable));
277        }
278
279        let schema = Arc::new(Schema::new(fields));
280        Ok(schema)
281    }
282
283    async fn query_arrow(
284        &self,
285        sql: &str,
286        params: &[&'a (dyn ToSql + Sync)],
287        projected_schema: Option<SchemaRef>,
288    ) -> Result<SendableRecordBatchStream> {
289        // TODO: We should have a way to detect if params have been passed
290        // if they haven't we should use .copy_out instead, because it should be much faster
291        let streamable = self
292            .conn
293            .query_raw(sql, params.iter().copied()) // use .query_raw to get access to the underlying RowStream
294            .await
295            .context(QuerySnafu)?;
296
297        // chunk the stream into groups of rows
298        let mut stream = streamable.chunks(4_000).boxed().map(move |rows| {
299            let rows = rows
300                .into_iter()
301                .collect::<std::result::Result<Vec<_>, _>>()
302                .context(QuerySnafu)?;
303            let rec = rows_to_arrow(rows.as_slice(), &projected_schema).context(ConversionSnafu)?;
304            Ok::<_, PostgresError>(rec)
305        });
306
307        let Some(first_chunk) = stream.next().await else {
308            return Ok(Box::pin(RecordBatchStreamAdapter::new(
309                Arc::new(Schema::empty()),
310                stream::empty(),
311            )));
312        };
313
314        let first_chunk = first_chunk?;
315        let schema = first_chunk.schema(); // pull out the schema from the first chunk to use in the DataFusion Stream Adapter
316
317        let output_stream = stream! {
318           yield Ok(first_chunk);
319           while let Some(batch) = stream.next().await {
320                match batch {
321                    Ok(batch) => {
322                        yield Ok(batch); // we can yield the batch as-is because we've already converted to Arrow in the chunk map
323                    }
324                    Err(e) => {
325                        yield Err(DataFusionError::Execution(format!("Failed to fetch batch: {e}")));
326                    }
327                }
328           }
329        };
330
331        Ok(Box::pin(RecordBatchStreamAdapter::new(
332            schema,
333            output_stream,
334        )))
335    }
336
337    async fn execute(&self, sql: &str, params: &[&'a (dyn ToSql + Sync)]) -> Result<u64> {
338        Ok(self.conn.execute(sql, params).await?)
339    }
340}
341
342impl PostgresConnection {
343    #[must_use]
344    pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self {
345        self.unsupported_type_action = action;
346        self
347    }
348}