datafusion_remote_table/connection/
postgres.rs

1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{
4    Connection, DFResult, Pool, PostgresType, RemoteField, RemoteSchema, RemoteType, Transform,
5};
6use bb8_postgres::tokio_postgres::types::{FromSql, Type};
7use bb8_postgres::tokio_postgres::{NoTls, Row};
8use bb8_postgres::PostgresConnectionManager;
9use chrono::Timelike;
10use datafusion::arrow::array::{
11    make_builder, ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder,
12    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
13    ListBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder, TimestampNanosecondBuilder,
14};
15use datafusion::arrow::datatypes::{Date32Type, Schema, SchemaRef};
16use datafusion::common::project_schema;
17use datafusion::error::DataFusionError;
18use datafusion::execution::SendableRecordBatchStream;
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use futures::{stream, StreamExt};
21use std::string::ToString;
22use std::sync::Arc;
23use std::time::{SystemTime, UNIX_EPOCH};
24
/// Settings needed to open connections to a postgres server.
///
/// `database` is optional; when it is `None` no database name is set on the
/// connection config and the server-side default applies.
///
/// `Debug` is implemented by hand so that the password never ends up in logs
/// or error messages; every other field is printed as usual.
#[derive(Clone)]
pub struct PostgresConnectionOptions {
    pub host: String,
    pub port: u16,
    pub username: String,
    pub password: String,
    pub database: Option<String>,
}

impl std::fmt::Debug for PostgresConnectionOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresConnectionOptions")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            // Never print the secret itself.
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .finish()
    }
}
33
34#[derive(Debug)]
35pub struct PostgresPool {
36    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
37}
38
39#[async_trait::async_trait]
40impl Pool for PostgresPool {
41    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
42        let conn = self.pool.get_owned().await.map_err(|e| {
43            DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
44        })?;
45        Ok(Arc::new(PostgresConnection { conn }))
46    }
47}
48
49pub(crate) async fn connect_postgres(
50    options: &PostgresConnectionOptions,
51) -> DFResult<PostgresPool> {
52    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
53    config
54        .host(&options.host)
55        .port(options.port)
56        .user(&options.username)
57        .password(&options.password);
58    if let Some(database) = &options.database {
59        config.dbname(database);
60    }
61    let manager = PostgresConnectionManager::new(config, NoTls);
62    let pool = bb8::Pool::builder()
63        .max_size(5)
64        .build(manager)
65        .await
66        .map_err(|e| {
67            DataFusionError::Execution(format!(
68                "Failed to create postgres connection pool due to {e}",
69            ))
70        })?;
71
72    Ok(PostgresPool { pool })
73}
74
75#[derive(Debug)]
76pub(crate) struct PostgresConnection {
77    conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
78}
79
80#[async_trait::async_trait]
81impl Connection for PostgresConnection {
82    async fn infer_schema(
83        &self,
84        sql: &str,
85        transform: Option<&dyn Transform>,
86    ) -> DFResult<(RemoteSchema, SchemaRef)> {
87        let mut stream = self
88            .conn
89            .query_raw(sql, Vec::<String>::new())
90            .await
91            .map_err(|e| {
92                DataFusionError::Execution(format!(
93                    "Failed to execute query {sql} on postgres due to {e}",
94                ))
95            })?
96            .chunks(1)
97            .boxed();
98
99        let Some(first_chunk) = stream.next().await else {
100            return Err(DataFusionError::Execution(
101                "No data returned from postgres".to_string(),
102            ));
103        };
104        let first_chunk: Vec<Row> = first_chunk
105            .into_iter()
106            .collect::<Result<Vec<_>, _>>()
107            .map_err(|e| {
108                DataFusionError::Execution(format!(
109                    "Failed to collect rows from postgres due to {e}",
110                ))
111            })?;
112        let Some(first_row) = first_chunk.first() else {
113            return Err(DataFusionError::Execution(
114                "No data returned from postgres".to_string(),
115            ));
116        };
117        let (remote_schema, pg_types) = build_remote_schema(first_row)?;
118        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
119        let batch = rows_to_batch(
120            std::slice::from_ref(first_row),
121            &pg_types,
122            arrow_schema,
123            None,
124        )?;
125        if let Some(transform) = transform {
126            let transformed_batch = transform_batch(batch, transform, &remote_schema)?;
127            Ok((remote_schema, transformed_batch.schema()))
128        } else {
129            Ok((remote_schema, batch.schema()))
130        }
131    }
132
133    async fn query(
134        &self,
135        sql: String,
136        projection: Option<Vec<usize>>,
137    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
138        let mut stream = self
139            .conn
140            .query_raw(&sql, Vec::<String>::new())
141            .await
142            .map_err(|e| {
143                DataFusionError::Execution(format!(
144                    "Failed to execute query {sql} on postgres due to {e}",
145                ))
146            })?
147            .chunks(2048)
148            .boxed();
149
150        let Some(first_chunk) = stream.next().await else {
151            return Ok((
152                Box::pin(RecordBatchStreamAdapter::new(
153                    Arc::new(Schema::empty()),
154                    stream::empty(),
155                )),
156                RemoteSchema::empty(),
157            ));
158        };
159        let first_chunk: Vec<Row> = first_chunk
160            .into_iter()
161            .collect::<Result<Vec<_>, _>>()
162            .map_err(|e| {
163                DataFusionError::Execution(format!(
164                    "Failed to collect rows from postgres due to {e}",
165                ))
166            })?;
167        let Some(first_row) = first_chunk.first() else {
168            return Err(DataFusionError::Execution(
169                "No data returned from postgres".to_string(),
170            ));
171        };
172        let (remote_schema, pg_types) = build_remote_schema(first_row)?;
173        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
174        let first_chunk = rows_to_batch(
175            first_chunk.as_slice(),
176            &pg_types,
177            arrow_schema.clone(),
178            projection.as_ref(),
179        )?;
180        let schema = first_chunk.schema();
181
182        let mut stream = stream.map(move |rows| {
183            let rows: Vec<Row> = rows
184                .into_iter()
185                .collect::<Result<Vec<_>, _>>()
186                .map_err(|e| {
187                    DataFusionError::Execution(format!(
188                        "Failed to collect rows from postgres due to {e}",
189                    ))
190                })?;
191            let batch = rows_to_batch(
192                rows.as_slice(),
193                &pg_types,
194                arrow_schema.clone(),
195                projection.as_ref(),
196            )?;
197            Ok::<RecordBatch, DataFusionError>(batch)
198        });
199
200        let output_stream = async_stream::stream! {
201           yield Ok(first_chunk);
202           while let Some(batch) = stream.next().await {
203                match batch {
204                    Ok(batch) => {
205                        yield Ok(batch); // we can yield the batch as-is because we've already converted to Arrow in the chunk map
206                    }
207                    Err(e) => {
208                        yield Err(DataFusionError::Execution(format!("Failed to fetch batch: {e}")));
209                    }
210                }
211           }
212        };
213
214        Ok((
215            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
216            remote_schema,
217        ))
218    }
219}
220
221fn pg_type_to_remote_type(pg_type: &Type) -> DFResult<RemoteType> {
222    match pg_type {
223        &Type::BOOL => Ok(RemoteType::Postgres(PostgresType::Bool)),
224        &Type::CHAR => Ok(RemoteType::Postgres(PostgresType::Char)),
225        &Type::INT2 => Ok(RemoteType::Postgres(PostgresType::Int2)),
226        &Type::INT4 => Ok(RemoteType::Postgres(PostgresType::Int4)),
227        &Type::INT8 => Ok(RemoteType::Postgres(PostgresType::Int8)),
228        &Type::FLOAT4 => Ok(RemoteType::Postgres(PostgresType::Float4)),
229        &Type::FLOAT8 => Ok(RemoteType::Postgres(PostgresType::Float8)),
230        &Type::TEXT => Ok(RemoteType::Postgres(PostgresType::Text)),
231        &Type::VARCHAR => Ok(RemoteType::Postgres(PostgresType::Varchar)),
232        &Type::BPCHAR => Ok(RemoteType::Postgres(PostgresType::Bpchar)),
233        &Type::BYTEA => Ok(RemoteType::Postgres(PostgresType::Bytea)),
234        &Type::DATE => Ok(RemoteType::Postgres(PostgresType::Date)),
235        &Type::TIMESTAMP => Ok(RemoteType::Postgres(PostgresType::Timestamp)),
236        &Type::TIMESTAMPTZ => Ok(RemoteType::Postgres(PostgresType::TimestampTz)),
237        &Type::TIME => Ok(RemoteType::Postgres(PostgresType::Time)),
238        &Type::INT2_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int2Array)),
239        &Type::INT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int4Array)),
240        &Type::INT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int8Array)),
241        &Type::FLOAT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float4Array)),
242        &Type::FLOAT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float8Array)),
243        &Type::TEXT_ARRAY => Ok(RemoteType::Postgres(PostgresType::TextArray)),
244        &Type::VARCHAR_ARRAY => Ok(RemoteType::Postgres(PostgresType::VarcharArray)),
245        &Type::BYTEA_ARRAY => Ok(RemoteType::Postgres(PostgresType::ByteaArray)),
246        other if other.name().eq_ignore_ascii_case("geometry") => {
247            Ok(RemoteType::Postgres(PostgresType::PostGisGeometry))
248        }
249        _ => Err(DataFusionError::NotImplemented(format!(
250            "Unsupported postgres type {pg_type:?}",
251        ))),
252    }
253}
254
255fn build_remote_schema(row: &Row) -> DFResult<(RemoteSchema, Vec<Type>)> {
256    let mut remote_fields = vec![];
257    let mut pg_types = vec![];
258    for col in row.columns() {
259        let col_type = col.type_();
260        pg_types.push(col_type.clone());
261        remote_fields.push(RemoteField::new(
262            col.name(),
263            pg_type_to_remote_type(col_type)?,
264            true,
265        ));
266    }
267    Ok((RemoteSchema::new(remote_fields), pg_types))
268}
269
/// Reads column `$index` of `$row` as `Option<$value_ty>` and appends it to
/// `$builder`, which must downcast to `$builder_ty`; SQL NULL becomes an
/// Arrow null.
///
/// `$pg_type` is only used to label messages.
///
/// Must be expanded inside a function returning `Result<_, DataFusionError>`:
/// a failure to decode the column is propagated with `?` as
/// `DataFusionError::Execution` rather than panicking. A failed builder
/// downcast still panics, since it means the builder does not match the
/// column's postgres type, which is a programming error.
macro_rules! handle_primitive_type {
    ($builder:expr, $pg_type:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .expect(concat!(
                "Failed to downcast builder to ",
                stringify!($builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        let v: Option<$value_ty> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "{} due to {}",
                concat!(
                    "Failed to get ",
                    stringify!($value_ty),
                    " value for column ",
                    stringify!($pg_type)
                ),
                e
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
294
/// Reads column `$index` of `$row` as a postgres array and appends it as one
/// list entry to `$builder`, which must downcast to
/// `ListBuilder<Box<dyn ArrayBuilder>>` with a values builder of type
/// `$values_builder_ty`.
///
/// The array is decoded as `Vec<Option<$primitive_value_ty>>`, so NULL
/// elements inside the array become nulls in the values builder instead of
/// making the whole column fail to decode. A NULL array becomes a null list
/// entry. `$pg_type` is only used to label messages.
///
/// Must be expanded inside a function returning `Result<_, DataFusionError>`:
/// a failure to decode the column is propagated with `?` as
/// `DataFusionError::Execution` rather than panicking. Failed builder
/// downcasts still panic, since they mean the builders do not match the
/// column's postgres type, which is a programming error.
macro_rules! handle_primitive_array_type {
    ($builder:expr, $pg_type:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .expect(concat!(
                "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for ",
                stringify!($pg_type)
            ));
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .expect(concat!(
                "Failed to downcast values builder to ",
                stringify!($values_builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        let v: Option<Vec<Option<$primitive_value_ty>>> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "{} due to {}",
                concat!(
                    "Failed to get ",
                    stringify!($primitive_value_ty),
                    " array value for column ",
                    stringify!($pg_type)
                ),
                e
            ))
        })?;

        match v {
            Some(v) => {
                // Elements are already `Option`s, so nulls pass straight through.
                values_builder.extend(v);
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}
331
332pub struct GeometryFromSql<'a> {
333    wkb: &'a [u8],
334}
335
336impl<'a> FromSql<'a> for GeometryFromSql<'a> {
337    fn from_sql(
338        _ty: &Type,
339        raw: &'a [u8],
340    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
341        Ok(GeometryFromSql { wkb: raw })
342    }
343
344    fn accepts(ty: &Type) -> bool {
345        matches!(ty.name(), "geometry")
346    }
347}
348
349fn rows_to_batch(
350    rows: &[Row],
351    pg_types: &[Type],
352    arrow_schema: SchemaRef,
353    projection: Option<&Vec<usize>>,
354) -> DFResult<RecordBatch> {
355    let projected_schema = project_schema(&arrow_schema, projection)?;
356    let mut array_builders = vec![];
357    for field in arrow_schema.fields() {
358        let builder = make_builder(field.data_type(), rows.len());
359        array_builders.push(builder);
360    }
361    for row in rows {
362        for (idx, pg_type) in pg_types.iter().enumerate() {
363            if !projections_contains(projection, idx) {
364                continue;
365            }
366            let builder = &mut array_builders[idx];
367            match pg_type {
368                &Type::BOOL => {
369                    handle_primitive_type!(builder, Type::BOOL, BooleanBuilder, bool, row, idx);
370                }
371                &Type::CHAR => {
372                    handle_primitive_type!(builder, Type::CHAR, Int8Builder, i8, row, idx);
373                }
374                &Type::INT2 => {
375                    handle_primitive_type!(builder, Type::INT2, Int16Builder, i16, row, idx);
376                }
377                &Type::INT4 => {
378                    handle_primitive_type!(builder, Type::INT4, Int32Builder, i32, row, idx);
379                }
380                &Type::INT8 => {
381                    handle_primitive_type!(builder, Type::INT8, Int64Builder, i64, row, idx);
382                }
383                &Type::FLOAT4 => {
384                    handle_primitive_type!(builder, Type::FLOAT4, Float32Builder, f32, row, idx);
385                }
386                &Type::FLOAT8 => {
387                    handle_primitive_type!(builder, Type::FLOAT8, Float64Builder, f64, row, idx);
388                }
389                &Type::TEXT => {
390                    handle_primitive_type!(builder, Type::TEXT, StringBuilder, &str, row, idx);
391                }
392                &Type::VARCHAR => {
393                    handle_primitive_type!(builder, Type::VARCHAR, StringBuilder, &str, row, idx);
394                }
395                &Type::BPCHAR => {
396                    let builder = builder
397                        .as_any_mut()
398                        .downcast_mut::<StringBuilder>()
399                        .expect("Failed to downcast builder to StringBuilder for Type::BPCHAR");
400                    let v: Option<&str> = row
401                        .try_get(idx)
402                        .expect("Failed to get &str value for column Type::BPCHAR");
403
404                    match v {
405                        Some(v) => builder.append_value(v.trim_end()),
406                        None => builder.append_null(),
407                    }
408                }
409                &Type::BYTEA => {
410                    handle_primitive_type!(builder, Type::BYTEA, BinaryBuilder, Vec<u8>, row, idx);
411                }
412                &Type::TIMESTAMP => {
413                    let builder = builder
414                        .as_any_mut()
415                        .downcast_mut::<TimestampNanosecondBuilder>()
416                        .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
417                    let v: Option<SystemTime> = row
418                        .try_get(idx)
419                        .expect("Failed to get SystemTime value for column Type::TIMESTAMP");
420
421                    match v {
422                        Some(v) => {
423                            if let Ok(v) = v.duration_since(UNIX_EPOCH) {
424                                let timestamp: i64 = v
425                                    .as_nanos()
426                                    .try_into()
427                                    .expect("Failed to convert SystemTime to i64");
428                                builder.append_value(timestamp);
429                            }
430                        }
431                        None => builder.append_null(),
432                    }
433                }
434                &Type::TIMESTAMPTZ => {
435                    let builder = builder
436                        .as_any_mut()
437                        .downcast_mut::<TimestampNanosecondBuilder>()
438                        .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
439                    let v: Option<chrono::DateTime<chrono::Utc>> = row.try_get(idx).expect(
440                        "Failed to get chrono::DateTime<chrono::Utc> value for column Type::TIMESTAMPTZ",
441                    );
442
443                    match v {
444                        Some(v) => {
445                            let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for Type::TIMESTAMP"));
446                            builder.append_value(timestamp);
447                        }
448                        None => builder.append_null(),
449                    }
450                }
451                &Type::TIME => {
452                    let builder = builder
453                        .as_any_mut()
454                        .downcast_mut::<Time64NanosecondBuilder>()
455                        .expect(
456                            "Failed to downcast builder to Time64NanosecondBuilder for Type::TIME",
457                        );
458                    let v: Option<chrono::NaiveTime> = row
459                        .try_get(idx)
460                        .expect("Failed to get chrono::NaiveTime value for column Type::TIME");
461
462                    match v {
463                        Some(v) => {
464                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
465                                * 1_000_000_000
466                                + i64::from(v.nanosecond());
467                            builder.append_value(timestamp);
468                        }
469                        None => builder.append_null(),
470                    }
471                }
472                &Type::DATE => {
473                    let builder = builder
474                        .as_any_mut()
475                        .downcast_mut::<Date32Builder>()
476                        .expect("Failed to downcast builder to Date32Builder for Type::DATE");
477                    let v: Option<chrono::NaiveDate> = row
478                        .try_get(idx)
479                        .expect("Failed to get chrono::NaiveDate value for column Type::DATE");
480
481                    match v {
482                        Some(v) => builder.append_value(Date32Type::from_naive_date(v)),
483                        None => builder.append_null(),
484                    }
485                }
486                &Type::INT2_ARRAY => {
487                    handle_primitive_array_type!(
488                        builder,
489                        Type::INT2_ARRAY,
490                        Int16Builder,
491                        i16,
492                        row,
493                        idx
494                    );
495                }
496                &Type::INT4_ARRAY => {
497                    handle_primitive_array_type!(
498                        builder,
499                        Type::INT4_ARRAY,
500                        Int32Builder,
501                        i32,
502                        row,
503                        idx
504                    );
505                }
506                &Type::INT8_ARRAY => {
507                    handle_primitive_array_type!(
508                        builder,
509                        Type::INT8_ARRAY,
510                        Int64Builder,
511                        i64,
512                        row,
513                        idx
514                    );
515                }
516                &Type::FLOAT4_ARRAY => {
517                    handle_primitive_array_type!(
518                        builder,
519                        Type::FLOAT4_ARRAY,
520                        Float32Builder,
521                        f32,
522                        row,
523                        idx
524                    );
525                }
526                &Type::FLOAT8_ARRAY => {
527                    handle_primitive_array_type!(
528                        builder,
529                        Type::FLOAT8_ARRAY,
530                        Float64Builder,
531                        f64,
532                        row,
533                        idx
534                    );
535                }
536                &Type::TEXT_ARRAY => {
537                    handle_primitive_array_type!(
538                        builder,
539                        Type::TEXT_ARRAY,
540                        StringBuilder,
541                        &str,
542                        row,
543                        idx
544                    );
545                }
546                &Type::VARCHAR_ARRAY => {
547                    handle_primitive_array_type!(
548                        builder,
549                        Type::VARCHAR_ARRAY,
550                        StringBuilder,
551                        &str,
552                        row,
553                        idx
554                    );
555                }
556                &Type::BYTEA_ARRAY => {
557                    handle_primitive_array_type!(
558                        builder,
559                        Type::BYTEA_ARRAY,
560                        BinaryBuilder,
561                        Vec<u8>,
562                        row,
563                        idx
564                    );
565                }
566                other if other.name().eq_ignore_ascii_case("geometry") => {
567                    let builder = builder
568                        .as_any_mut()
569                        .downcast_mut::<BinaryBuilder>()
570                        .expect("Failed to downcast builder to BinaryBuilder for Type::geometry");
571                    let v: Option<GeometryFromSql> = row
572                        .try_get(idx)
573                        .expect("Failed to get GeometryFromSql value for column Type::geometry");
574
575                    match v {
576                        Some(v) => builder.append_value(v.wkb),
577                        None => builder.append_null(),
578                    }
579                }
580                _ => {
581                    return Err(DataFusionError::Execution(format!(
582                        "Unsupported postgres type {pg_type:?}",
583                    )));
584                }
585            }
586        }
587    }
588    let projected_columns = array_builders
589        .into_iter()
590        .enumerate()
591        .filter(|(idx, _)| projections_contains(projection, *idx))
592        .map(|(_, mut builder)| builder.finish())
593        .collect::<Vec<ArrayRef>>();
594    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
595}