datafusion_remote_table/connection/
postgres.rs

1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{Connection, DFResult, PostgresType, RemoteField, RemoteSchema, RemoteType, Transform};
4use bb8_postgres::tokio_postgres::types::{FromSql, Type};
5use bb8_postgres::tokio_postgres::{NoTls, Row};
6use bb8_postgres::PostgresConnectionManager;
7use chrono::Timelike;
8use datafusion::arrow::array::{
9    make_builder, ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder,
10    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
11    ListBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder, TimestampNanosecondBuilder,
12};
13use datafusion::arrow::datatypes::{Date32Type, Schema, SchemaRef};
14use datafusion::common::project_schema;
15use datafusion::error::DataFusionError;
16use datafusion::execution::SendableRecordBatchStream;
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use futures::{stream, StreamExt};
19use std::string::ToString;
20use std::sync::Arc;
21use std::time::{SystemTime, UNIX_EPOCH};
22
/// Settings used to open a connection pool against a PostgreSQL server.
///
/// `Debug` is implemented by hand so that the password never ends up in logs or
/// error output.
#[derive(Clone)]
pub struct PostgresConnectionOptions {
    /// Host name or address of the server.
    pub host: String,
    /// TCP port of the server.
    pub port: u16,
    /// User name to authenticate as.
    pub username: String,
    /// Password for `username`; redacted in `Debug` output.
    pub password: String,
    /// Database to connect to; when `None`, no database name is set on the driver config.
    pub database: Option<String>,
}

impl std::fmt::Debug for PostgresConnectionOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresConnectionOptions")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            // Never print the secret itself.
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .finish()
    }
}
31
32#[derive(Debug)]
33pub(crate) struct PostgresConnection {
34    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
35}
36
37pub(crate) async fn connect_postgres(
38    options: &PostgresConnectionOptions,
39) -> DFResult<PostgresConnection> {
40    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
41    config
42        .host(&options.host)
43        .port(options.port)
44        .user(&options.username)
45        .password(&options.password);
46    if let Some(database) = &options.database {
47        config.dbname(database);
48    }
49    let manager = PostgresConnectionManager::new(config, NoTls);
50    let pool = bb8::Pool::builder()
51        .max_size(5)
52        .build(manager)
53        .await
54        .map_err(|e| {
55            DataFusionError::Execution(format!(
56                "Failed to create postgres connection pool due to {e}",
57            ))
58        })?;
59
60    Ok(PostgresConnection { pool })
61}
62
63#[async_trait::async_trait]
64impl Connection for PostgresConnection {
65    async fn infer_schema(
66        &self,
67        sql: &str,
68        transform: Option<&dyn Transform>,
69    ) -> DFResult<(RemoteSchema, SchemaRef)> {
70        let conn = self.pool.get().await.map_err(|e| {
71            DataFusionError::Execution(format!(
72                "Failed to get connection from postgres connection pool due to {e}",
73            ))
74        })?;
75        let mut stream = conn
76            .query_raw(sql, Vec::<String>::new())
77            .await
78            .map_err(|e| {
79                DataFusionError::Execution(format!(
80                    "Failed to execute query {sql} on postgres due to {e}",
81                ))
82            })?
83            .chunks(1)
84            .boxed();
85
86        let Some(first_chunk) = stream.next().await else {
87            return Err(DataFusionError::Execution(
88                "No data returned from postgres".to_string(),
89            ));
90        };
91        let first_chunk: Vec<Row> = first_chunk
92            .into_iter()
93            .collect::<Result<Vec<_>, _>>()
94            .map_err(|e| {
95                DataFusionError::Execution(format!(
96                    "Failed to collect rows from postgres due to {e}",
97                ))
98            })?;
99        let Some(first_row) = first_chunk.first() else {
100            return Err(DataFusionError::Execution(
101                "No data returned from postgres".to_string(),
102            ));
103        };
104        let (remote_schema, pg_types) = build_remote_schema(first_row)?;
105        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
106        let batch = rows_to_batch(
107            std::slice::from_ref(first_row),
108            &pg_types,
109            arrow_schema,
110            None,
111        )?;
112        if let Some(transform) = transform {
113            let transformed_batch = transform_batch(batch, transform, &remote_schema)?;
114            Ok((remote_schema, transformed_batch.schema()))
115        } else {
116            Ok((remote_schema, batch.schema()))
117        }
118    }
119
120    async fn query(
121        &self,
122        sql: String,
123        projection: Option<Vec<usize>>,
124    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
125        let conn = self.pool.get().await.map_err(|e| {
126            DataFusionError::Execution(format!(
127                "Failed to get connection from postgres connection pool due to {e}",
128            ))
129        })?;
130        let mut stream = conn
131            .query_raw(&sql, Vec::<String>::new())
132            .await
133            .map_err(|e| {
134                DataFusionError::Execution(format!(
135                    "Failed to execute query {sql} on postgres due to {e}",
136                ))
137            })?
138            .chunks(2048)
139            .boxed();
140
141        let Some(first_chunk) = stream.next().await else {
142            return Ok((
143                Box::pin(RecordBatchStreamAdapter::new(
144                    Arc::new(Schema::empty()),
145                    stream::empty(),
146                )),
147                RemoteSchema::empty(),
148            ));
149        };
150        let first_chunk: Vec<Row> = first_chunk
151            .into_iter()
152            .collect::<Result<Vec<_>, _>>()
153            .map_err(|e| {
154                DataFusionError::Execution(format!(
155                    "Failed to collect rows from postgres due to {e}",
156                ))
157            })?;
158        let Some(first_row) = first_chunk.first() else {
159            return Err(DataFusionError::Execution(
160                "No data returned from postgres".to_string(),
161            ));
162        };
163        let (remote_schema, pg_types) = build_remote_schema(first_row)?;
164        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
165        let first_chunk = rows_to_batch(
166            first_chunk.as_slice(),
167            &pg_types,
168            arrow_schema.clone(),
169            projection.as_ref(),
170        )?;
171        let schema = first_chunk.schema();
172
173        let mut stream = stream.map(move |rows| {
174            let rows: Vec<Row> = rows
175                .into_iter()
176                .collect::<Result<Vec<_>, _>>()
177                .map_err(|e| {
178                    DataFusionError::Execution(format!(
179                        "Failed to collect rows from postgres due to {e}",
180                    ))
181                })?;
182            let batch = rows_to_batch(
183                rows.as_slice(),
184                &pg_types,
185                arrow_schema.clone(),
186                projection.as_ref(),
187            )?;
188            Ok::<RecordBatch, DataFusionError>(batch)
189        });
190
191        let output_stream = async_stream::stream! {
192           yield Ok(first_chunk);
193           while let Some(batch) = stream.next().await {
194                match batch {
195                    Ok(batch) => {
196                        yield Ok(batch); // we can yield the batch as-is because we've already converted to Arrow in the chunk map
197                    }
198                    Err(e) => {
199                        yield Err(DataFusionError::Execution(format!("Failed to fetch batch: {e}")));
200                    }
201                }
202           }
203        };
204
205        Ok((
206            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
207            remote_schema,
208        ))
209    }
210}
211
212fn pg_type_to_remote_type(pg_type: &Type) -> DFResult<RemoteType> {
213    match pg_type {
214        &Type::BOOL => Ok(RemoteType::Postgres(PostgresType::Bool)),
215        &Type::CHAR => Ok(RemoteType::Postgres(PostgresType::Char)),
216        &Type::INT2 => Ok(RemoteType::Postgres(PostgresType::Int2)),
217        &Type::INT4 => Ok(RemoteType::Postgres(PostgresType::Int4)),
218        &Type::INT8 => Ok(RemoteType::Postgres(PostgresType::Int8)),
219        &Type::FLOAT4 => Ok(RemoteType::Postgres(PostgresType::Float4)),
220        &Type::FLOAT8 => Ok(RemoteType::Postgres(PostgresType::Float8)),
221        &Type::TEXT => Ok(RemoteType::Postgres(PostgresType::Text)),
222        &Type::VARCHAR => Ok(RemoteType::Postgres(PostgresType::Varchar)),
223        &Type::BYTEA => Ok(RemoteType::Postgres(PostgresType::Bytea)),
224        &Type::DATE => Ok(RemoteType::Postgres(PostgresType::Date)),
225        &Type::TIMESTAMP => Ok(RemoteType::Postgres(PostgresType::Timestamp)),
226        &Type::TIMESTAMPTZ => Ok(RemoteType::Postgres(PostgresType::TimestampTz)),
227        &Type::TIME => Ok(RemoteType::Postgres(PostgresType::Time)),
228        &Type::INT2_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int2Array)),
229        &Type::INT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int4Array)),
230        &Type::INT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int8Array)),
231        &Type::FLOAT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float4Array)),
232        &Type::FLOAT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float8Array)),
233        &Type::TEXT_ARRAY => Ok(RemoteType::Postgres(PostgresType::TextArray)),
234        &Type::VARCHAR_ARRAY => Ok(RemoteType::Postgres(PostgresType::VarcharArray)),
235        &Type::BYTEA_ARRAY => Ok(RemoteType::Postgres(PostgresType::ByteaArray)),
236        other if other.name().eq_ignore_ascii_case("geometry") => {
237            Ok(RemoteType::Postgres(PostgresType::PostGisGeometry))
238        }
239        _ => Err(DataFusionError::NotImplemented(format!(
240            "Unsupported postgres type {pg_type:?}",
241        ))),
242    }
243}
244
245fn build_remote_schema(row: &Row) -> DFResult<(RemoteSchema, Vec<Type>)> {
246    let mut remote_fields = vec![];
247    let mut pg_types = vec![];
248    for col in row.columns() {
249        let col_type = col.type_();
250        pg_types.push(col_type.clone());
251        remote_fields.push(RemoteField::new(
252            col.name(),
253            pg_type_to_remote_type(col_type)?,
254            true,
255        ));
256    }
257    Ok((RemoteSchema::new(remote_fields), pg_types))
258}
259
/// Appends the value of column `$index` of `$row` to `$builder`.
///
/// `$builder` is downcast to `$builder_ty` and the column is read as
/// `Option<$value_ty>`; `None` is appended as a null. `$pg_type` is only used to
/// label messages.
///
/// A failed column read is returned from the enclosing function as
/// `DataFusionError::Execution` through `?`, so the macro must be expanded inside a
/// function whose error type accepts `DataFusionError`. A failed downcast means the
/// builder does not match the column type, which is a programming error, and panics.
macro_rules! handle_primitive_type {
    ($builder:expr, $pg_type:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .expect(concat!(
                "Failed to downcast builder to ",
                stringify!($builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        // Reading depends on the data sent by the server, so report it instead of panicking.
        let v: Option<$value_ty> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "{}: {}",
                concat!(
                    "Failed to get ",
                    stringify!($value_ty),
                    " value for column ",
                    stringify!($pg_type)
                ),
                e
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
284
/// Appends the array value of column `$index` of `$row` to a list builder.
///
/// `$builder` is downcast to `ListBuilder<Box<dyn ArrayBuilder>>` and its values
/// builder to `$values_builder_ty`. The column is read as
/// `Option<Vec<$primitive_value_ty>>`: `Some` extends the values builder and closes
/// one list entry, `None` appends a null list. `$pg_type` is only used to label
/// messages.
///
/// A failed column read is returned from the enclosing function as
/// `DataFusionError::Execution` through `?`, so the macro must be expanded inside a
/// function whose error type accepts `DataFusionError`. A failed downcast means the
/// builder does not match the column type, which is a programming error, and panics.
macro_rules! handle_primitive_array_type {
    ($builder:expr, $pg_type:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .expect(concat!(
                "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for ",
                stringify!($pg_type)
            ));
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .expect(concat!(
                "Failed to downcast values builder to ",
                stringify!($values_builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        // Reading depends on the data sent by the server, so report it instead of panicking.
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "{}: {}",
                concat!(
                    "Failed to get ",
                    stringify!($primitive_value_ty),
                    " array value for column ",
                    stringify!($pg_type)
                ),
                e
            ))
        })?;

        match v {
            Some(v) => {
                let v = v.into_iter().map(Some);
                values_builder.extend(v);
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}
321
/// Borrowed view of the raw bytes postgres sends for a `geometry` column.
///
/// The bytes are kept exactly as received and are not parsed or validated here.
pub struct GeometryFromSql<'a> {
    // NOTE(review): named `wkb`, so presumably PostGIS (E)WKB — confirm against the server output.
    wkb: &'a [u8],
}
325
326impl<'a> FromSql<'a> for GeometryFromSql<'a> {
327    fn from_sql(
328        _ty: &Type,
329        raw: &'a [u8],
330    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
331        Ok(GeometryFromSql { wkb: raw })
332    }
333
334    fn accepts(ty: &Type) -> bool {
335        matches!(ty.name(), "geometry")
336    }
337}
338
339fn rows_to_batch(
340    rows: &[Row],
341    pg_types: &Vec<Type>,
342    arrow_schema: SchemaRef,
343    projection: Option<&Vec<usize>>,
344) -> DFResult<RecordBatch> {
345    let projected_schema = project_schema(&arrow_schema, projection)?;
346    let mut array_builders = vec![];
347    for field in arrow_schema.fields() {
348        let builder = make_builder(&field.data_type(), rows.len());
349        array_builders.push(builder);
350    }
351    for row in rows {
352        for (idx, pg_type) in pg_types.iter().enumerate() {
353            if !projections_contains(projection, idx) {
354                continue;
355            }
356            let builder = &mut array_builders[idx];
357            match pg_type {
358                &Type::BOOL => {
359                    handle_primitive_type!(builder, Type::BOOL, BooleanBuilder, bool, row, idx);
360                }
361                &Type::CHAR => {
362                    handle_primitive_type!(builder, Type::CHAR, Int8Builder, i8, row, idx);
363                }
364                &Type::INT2 => {
365                    handle_primitive_type!(builder, Type::INT2, Int16Builder, i16, row, idx);
366                }
367                &Type::INT4 => {
368                    handle_primitive_type!(builder, Type::INT4, Int32Builder, i32, row, idx);
369                }
370                &Type::INT8 => {
371                    handle_primitive_type!(builder, Type::INT8, Int64Builder, i64, row, idx);
372                }
373                &Type::FLOAT4 => {
374                    handle_primitive_type!(builder, Type::FLOAT4, Float32Builder, f32, row, idx);
375                }
376                &Type::FLOAT8 => {
377                    handle_primitive_type!(builder, Type::FLOAT8, Float64Builder, f64, row, idx);
378                }
379                &Type::TEXT => {
380                    handle_primitive_type!(builder, Type::TEXT, StringBuilder, &str, row, idx);
381                }
382                &Type::VARCHAR => {
383                    handle_primitive_type!(builder, Type::VARCHAR, StringBuilder, &str, row, idx);
384                }
385                &Type::BYTEA => {
386                    handle_primitive_type!(builder, Type::BYTEA, BinaryBuilder, Vec<u8>, row, idx);
387                }
388                &Type::TIMESTAMP => {
389                    let builder = builder
390                        .as_any_mut()
391                        .downcast_mut::<TimestampNanosecondBuilder>()
392                        .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
393                    let v: Option<SystemTime> = row
394                        .try_get(idx)
395                        .expect("Failed to get SystemTime value for column Type::TIMESTAMP");
396
397                    match v {
398                        Some(v) => {
399                            if let Ok(v) = v.duration_since(UNIX_EPOCH) {
400                                let timestamp: i64 = v
401                                    .as_nanos()
402                                    .try_into()
403                                    .expect("Failed to convert SystemTime to i64");
404                                builder.append_value(timestamp);
405                            }
406                        }
407                        None => builder.append_null(),
408                    }
409                }
410                &Type::TIMESTAMPTZ => {
411                    let builder = builder
412                        .as_any_mut()
413                        .downcast_mut::<TimestampNanosecondBuilder>()
414                        .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
415                    let v: Option<chrono::DateTime<chrono::Utc>> = row.try_get(idx).expect(
416                        "Failed to get chrono::DateTime<chrono::Utc> value for column Type::TIMESTAMPTZ",
417                    );
418
419                    match v {
420                        Some(v) => {
421                            let timestamp: i64 = v.timestamp_nanos_opt().expect(&format!("Failed to get timestamp in nanoseconds from {v} for Type::TIMESTAMP"));
422                            builder.append_value(timestamp);
423                        }
424                        None => {}
425                    }
426                }
427                &Type::TIME => {
428                    let builder = builder
429                        .as_any_mut()
430                        .downcast_mut::<Time64NanosecondBuilder>()
431                        .expect(
432                            "Failed to downcast builder to Time64NanosecondBuilder for Type::TIME",
433                        );
434                    let v: Option<chrono::NaiveTime> = row
435                        .try_get(idx)
436                        .expect("Failed to get chrono::NaiveTime value for column Type::TIME");
437
438                    match v {
439                        Some(v) => {
440                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
441                                * 1_000_000_000
442                                + i64::from(v.nanosecond());
443                            builder.append_value(timestamp);
444                        }
445                        None => builder.append_null(),
446                    }
447                }
448                &Type::DATE => {
449                    let builder = builder
450                        .as_any_mut()
451                        .downcast_mut::<Date32Builder>()
452                        .expect("Failed to downcast builder to Date32Builder for Type::DATE");
453                    let v: Option<chrono::NaiveDate> = row
454                        .try_get(idx)
455                        .expect("Failed to get chrono::NaiveDate value for column Type::DATE");
456
457                    match v {
458                        Some(v) => builder.append_value(Date32Type::from_naive_date(v)),
459                        None => builder.append_null(),
460                    }
461                }
462                &Type::INT2_ARRAY => {
463                    handle_primitive_array_type!(
464                        builder,
465                        Type::INT2_ARRAY,
466                        Int16Builder,
467                        i16,
468                        row,
469                        idx
470                    );
471                }
472                &Type::INT4_ARRAY => {
473                    handle_primitive_array_type!(
474                        builder,
475                        Type::INT4_ARRAY,
476                        Int32Builder,
477                        i32,
478                        row,
479                        idx
480                    );
481                }
482                &Type::INT8_ARRAY => {
483                    handle_primitive_array_type!(
484                        builder,
485                        Type::INT8_ARRAY,
486                        Int64Builder,
487                        i64,
488                        row,
489                        idx
490                    );
491                }
492                &Type::FLOAT4_ARRAY => {
493                    handle_primitive_array_type!(
494                        builder,
495                        Type::FLOAT4_ARRAY,
496                        Float32Builder,
497                        f32,
498                        row,
499                        idx
500                    );
501                }
502                &Type::FLOAT8_ARRAY => {
503                    handle_primitive_array_type!(
504                        builder,
505                        Type::FLOAT8_ARRAY,
506                        Float64Builder,
507                        f64,
508                        row,
509                        idx
510                    );
511                }
512                &Type::TEXT_ARRAY => {
513                    handle_primitive_array_type!(
514                        builder,
515                        Type::TEXT_ARRAY,
516                        StringBuilder,
517                        &str,
518                        row,
519                        idx
520                    );
521                }
522                &Type::VARCHAR_ARRAY => {
523                    handle_primitive_array_type!(
524                        builder,
525                        Type::VARCHAR_ARRAY,
526                        StringBuilder,
527                        &str,
528                        row,
529                        idx
530                    );
531                }
532                &Type::BYTEA_ARRAY => {
533                    handle_primitive_array_type!(
534                        builder,
535                        Type::BYTEA_ARRAY,
536                        BinaryBuilder,
537                        Vec<u8>,
538                        row,
539                        idx
540                    );
541                }
542                other if other.name().eq_ignore_ascii_case("geometry") => {
543                    let builder = builder
544                        .as_any_mut()
545                        .downcast_mut::<BinaryBuilder>()
546                        .expect("Failed to downcast builder to BinaryBuilder for Type::geometry");
547                    let v: Option<GeometryFromSql> = row
548                        .try_get(idx)
549                        .expect("Failed to get GeometryFromSql value for column Type::geometry");
550
551                    match v {
552                        Some(v) => builder.append_value(v.wkb),
553                        None => builder.append_null(),
554                    }
555                }
556                _ => {
557                    return Err(DataFusionError::Execution(format!(
558                        "Unsupported postgres type {pg_type:?}",
559                    )));
560                }
561            }
562        }
563    }
564    let projected_columns = array_builders
565        .into_iter()
566        .enumerate()
567        .filter(|(idx, _)| projections_contains(projection, *idx))
568        .map(|(_, mut builder)| builder.finish())
569        .collect::<Vec<ArrayRef>>();
570    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
571}