// datafusion_table_providers/sql/db_connection_pool/dbconnection/postgresconn.rs
use std::any::Any;
use std::error::Error;
use std::sync::Arc;

use crate::sql::arrow_sql_gen::postgres::rows_to_arrow;
use crate::sql::arrow_sql_gen::postgres::schema::pg_data_type_to_arrow_type;
use crate::sql::arrow_sql_gen::postgres::schema::ParseContext;
use crate::util::handle_unsupported_type_error;
use crate::util::schema::SchemaValidator;
use arrow::datatypes::Field;
use arrow::datatypes::Schema;
use arrow::datatypes::SchemaRef;
use arrow_schema::DataType;
use async_stream::stream;
use bb8_postgres::tokio_postgres::types::ToSql;
use bb8_postgres::PostgresConnectionManager;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::sql::TableReference;
use futures::stream;
use futures::StreamExt;
use postgres_native_tls::MakeTlsConnector;
use snafu::prelude::*;

use crate::UnsupportedTypeAction;

use super::AsyncDbConnection;
use super::DbConnection;
use super::Result;
31
/// Column-metadata query for a single relation.
///
/// Parameters: `$1` = schema name, `$2` = relation name.
///
/// Returns one row per live, user-visible column (`attnum > 0`, not dropped)
/// of a table, view, or materialized view, ordered by column position:
/// - `column_name`
/// - `data_type`: the formatted type name, or the literal strings `'array'`,
///   `'enum'`, or `'composite'` for those type categories
/// - `is_nullable`: `'YES'` / `'NO'` derived from `attnotnull`
/// - `type_details`: JSONB describing an array's element type, or the
///   labels/attributes of a user-defined enum/composite type (NULL otherwise)
const SCHEMA_QUERY: &str = r"
WITH custom_type_details AS (
SELECT
t.typname,
t.typtype,
CASE
 WHEN t.typtype = 'e' THEN
 jsonb_build_object(
 'type', 'enum',
 'values', (
 SELECT jsonb_agg(e.enumlabel ORDER BY e.enumsortorder)
 FROM pg_enum e
 WHERE e.enumtypid = t.oid
 )
 )
 WHEN t.typtype = 'c' THEN
 jsonb_build_object(
 'type', 'composite',
 'attributes', (
 SELECT jsonb_agg(
 jsonb_build_object(
 'name', a2.attname,
 'type', pg_catalog.format_type(a2.atttypid, a2.atttypmod)
 )
 ORDER BY a2.attnum
 )
 FROM pg_attribute a2
 WHERE a2.attrelid = t.typrelid
 AND a2.attnum > 0
 AND NOT a2.attisdropped
 )
 )
END as type_details
FROM pg_type t
JOIN pg_namespace n ON t.typnamespace = n.oid
WHERE n.nspname = $1
)
SELECT
 a.attname AS column_name,
 CASE
 -- when an array type is encountered, label as 'array'
 WHEN t.typcategory = 'A' THEN 'array'
 -- if it’s a user-defined enum or composite type then output that specific string
 WHEN t.typtype = 'e' THEN 'enum'
 WHEN t.typtype = 'c' THEN 'composite'
 ELSE pg_catalog.format_type(a.atttypid, a.atttypmod)
 END AS data_type,
 CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
 CASE
 WHEN t.typcategory = 'A' THEN
 jsonb_build_object(
 'type', 'array',
 'element_type', (
 SELECT pg_catalog.format_type(et.oid, a.atttypmod)
 FROM pg_type t2
 JOIN pg_type et ON t2.typelem = et.oid
 WHERE t2.oid = a.atttypid
 )
 )
 ELSE custom.type_details
 END AS type_details
FROM pg_class cls
JOIN pg_namespace ns ON cls.relnamespace = ns.oid
JOIN pg_attribute a ON a.attrelid = cls.oid
LEFT JOIN pg_type t ON t.oid = a.atttypid
LEFT JOIN custom_type_details custom ON custom.typname = t.typname
WHERE ns.nspname = $1
 AND cls.relname = $2
 AND cls.relkind IN ('r','v','m') -- covers tables, normal views, & materialized views
 AND a.attnum > 0
 AND NOT a.attisdropped
ORDER BY a.attnum;
";
105
/// Lists every schema name except the PostgreSQL system schemas
/// (`pg_catalog`, `information_schema`) and internal TOAST schemas.
const SCHEMAS_QUERY: &str = "
SELECT nspname AS schema_name
FROM pg_namespace
WHERE nspname NOT IN ('pg_catalog', 'information_schema')
  AND nspname !~ '^pg_toast';
";
112
/// Lists the ordinary tables of a schema; `$1` = schema name.
/// NOTE(review): `pg_tables` excludes views/materialized views, even though
/// `SCHEMA_QUERY` accepts them — presumably intentional; confirm with callers.
const TABLES_QUERY: &str = "
SELECT tablename
FROM pg_tables
WHERE schemaname = $1;
";
118
/// Errors produced while executing a query on a [`PostgresConnection`]
/// and converting its results to Arrow.
#[derive(Debug, Snafu)]
pub enum PostgresError {
    /// The statement failed on the PostgreSQL side (or the connection broke).
    #[snafu(display(
        "Query execution failed.\n{source}\nFor details, refer to the PostgreSQL manual: https://www.postgresql.org/docs/17/index.html"
    ))]
    QueryError {
        source: bb8_postgres::tokio_postgres::Error,
    },

    /// Rows were fetched but could not be converted into Arrow record batches.
    #[snafu(display("Failed to convert query result to Arrow.\n{source}\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
    ConversionError {
        source: crate::sql::arrow_sql_gen::postgres::Error,
    },
}
133
/// A PostgreSQL connection checked out of a `bb8` pool, with TLS support.
pub struct PostgresConnection {
    /// The underlying pooled `tokio_postgres` client.
    pub conn: bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
    // Policy for columns whose type has no Arrow mapping;
    // configured via `with_unsupported_type_action`.
    unsupported_type_action: UnsupportedTypeAction,
}
138
139impl SchemaValidator for PostgresConnection {
140 type Error = super::Error;
141
142 fn is_data_type_supported(data_type: &DataType) -> bool {
143 !matches!(data_type, DataType::Map(_, _))
144 }
145
146 fn unsupported_type_error(data_type: &DataType, field_name: &str) -> Self::Error {
147 super::Error::UnsupportedDataType {
148 data_type: data_type.to_string(),
149 field_name: field_name.to_string(),
150 }
151 }
152}
153
impl<'a>
    DbConnection<
        bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
        &'a (dyn ToSql + Sync),
    > for PostgresConnection
{
    /// Upcasts to `Any` so callers can downcast to the concrete connection.
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// This connection is always async-capable, so this always returns `Some`.
    fn as_async(
        &self,
    ) -> Option<
        &dyn AsyncDbConnection<
            bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
            &'a (dyn ToSql + Sync),
        >,
    > {
        Some(self)
    }
}
179
180#[async_trait::async_trait]
181impl<'a>
182 AsyncDbConnection<
183 bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
184 &'a (dyn ToSql + Sync),
185 > for PostgresConnection
186{
187 fn new(
188 conn: bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
189 ) -> Self {
190 PostgresConnection {
191 conn,
192 unsupported_type_action: UnsupportedTypeAction::default(),
193 }
194 }
195
196 async fn tables(&self, schema: &str) -> Result<Vec<String>, super::Error> {
197 let rows = self
198 .conn
199 .query(TABLES_QUERY, &[&schema])
200 .await
201 .map_err(|e| super::Error::UnableToGetTables {
202 source: Box::new(e),
203 })?;
204
205 Ok(rows.iter().map(|r| r.get::<usize, String>(0)).collect())
206 }
207
208 async fn schemas(&self) -> Result<Vec<String>, super::Error> {
209 let rows = self.conn.query(SCHEMAS_QUERY, &[]).await.map_err(|e| {
210 super::Error::UnableToGetSchemas {
211 source: Box::new(e),
212 }
213 })?;
214
215 Ok(rows.iter().map(|r| r.get::<usize, String>(0)).collect())
216 }
217
218 async fn get_schema(
219 &self,
220 table_reference: &TableReference,
221 ) -> Result<SchemaRef, super::Error> {
222 let table_name = table_reference.table();
223 let schema_name = table_reference.schema().unwrap_or("public");
224
225 let rows = match self
226 .conn
227 .query(SCHEMA_QUERY, &[&schema_name, &table_name])
228 .await
229 {
230 Ok(rows) => rows,
231 Err(e) => {
232 if let Some(error_source) = e.source() {
233 if let Some(pg_error) =
234 error_source.downcast_ref::<tokio_postgres::error::DbError>()
235 {
236 if pg_error.code() == &tokio_postgres::error::SqlState::UNDEFINED_TABLE {
237 return Err(super::Error::UndefinedTable {
238 source: Box::new(pg_error.clone()),
239 table_name: table_reference.to_string(),
240 });
241 }
242 }
243 }
244 return Err(super::Error::UnableToGetSchema {
245 source: Box::new(e),
246 });
247 }
248 };
249
250 let mut fields = Vec::new();
251 for row in rows {
252 let column_name = row.get::<usize, String>(0);
253 let pg_type = row.get::<usize, String>(1);
254 let nullable_str = row.get::<usize, String>(2);
255 let nullable = nullable_str == "YES";
256 let type_details = row.get::<usize, Option<serde_json::Value>>(3);
257 let mut context =
258 ParseContext::new().with_unsupported_type_action(self.unsupported_type_action);
259
260 if let Some(type_details) = type_details {
261 context = context.with_type_details(type_details);
262 };
263
264 let Ok(arrow_type) = pg_data_type_to_arrow_type(&pg_type, &context) else {
265 handle_unsupported_type_error(
266 self.unsupported_type_action,
267 super::Error::UnsupportedDataType {
268 data_type: pg_type.to_string(),
269 field_name: column_name.to_string(),
270 },
271 )?;
272
273 continue;
274 };
275
276 fields.push(Field::new(column_name, arrow_type, nullable));
277 }
278
279 let schema = Arc::new(Schema::new(fields));
280 Ok(schema)
281 }
282
283 async fn query_arrow(
284 &self,
285 sql: &str,
286 params: &[&'a (dyn ToSql + Sync)],
287 projected_schema: Option<SchemaRef>,
288 ) -> Result<SendableRecordBatchStream> {
289 let streamable = self
292 .conn
293 .query_raw(sql, params.iter().copied()) .await
295 .context(QuerySnafu)?;
296
297 let mut stream = streamable.chunks(4_000).boxed().map(move |rows| {
299 let rows = rows
300 .into_iter()
301 .collect::<std::result::Result<Vec<_>, _>>()
302 .context(QuerySnafu)?;
303 let rec = rows_to_arrow(rows.as_slice(), &projected_schema).context(ConversionSnafu)?;
304 Ok::<_, PostgresError>(rec)
305 });
306
307 let Some(first_chunk) = stream.next().await else {
308 return Ok(Box::pin(RecordBatchStreamAdapter::new(
309 Arc::new(Schema::empty()),
310 stream::empty(),
311 )));
312 };
313
314 let first_chunk = first_chunk?;
315 let schema = first_chunk.schema(); let output_stream = stream! {
318 yield Ok(first_chunk);
319 while let Some(batch) = stream.next().await {
320 match batch {
321 Ok(batch) => {
322 yield Ok(batch); }
324 Err(e) => {
325 yield Err(DataFusionError::Execution(format!("Failed to fetch batch: {e}")));
326 }
327 }
328 }
329 };
330
331 Ok(Box::pin(RecordBatchStreamAdapter::new(
332 schema,
333 output_stream,
334 )))
335 }
336
337 async fn execute(&self, sql: &str, params: &[&'a (dyn ToSql + Sync)]) -> Result<u64> {
338 Ok(self.conn.execute(sql, params).await?)
339 }
340}
341
impl PostgresConnection {
    /// Sets how this connection reacts when a column's PostgreSQL type
    /// cannot be mapped to an Arrow type (used by `get_schema`).
    /// Builder-style: consumes and returns `self`.
    #[must_use]
    pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self {
        self.unsupported_type_action = action;
        self
    }
}