datafusion_remote_table/connection/postgres.rs

1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{
4    project_remote_schema, Connection, DFResult, Pool, PostgresType, RemoteField, RemoteSchema,
5    RemoteType, Transform,
6};
7use bb8_postgres::tokio_postgres::types::{FromSql, Type};
8use bb8_postgres::tokio_postgres::{NoTls, Row};
9use bb8_postgres::PostgresConnectionManager;
10use chrono::Timelike;
11use datafusion::arrow::array::{
12    make_builder, ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder,
13    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
14    ListBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder, TimestampNanosecondBuilder,
15};
16use datafusion::arrow::datatypes::{Date32Type, SchemaRef};
17use datafusion::common::project_schema;
18use datafusion::error::DataFusionError;
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use futures::StreamExt;
22use std::string::ToString;
23use std::sync::Arc;
24use std::time::{SystemTime, UNIX_EPOCH};
25
26#[derive(Debug, Clone, derive_with::With)]
27pub struct PostgresConnectionOptions {
28    pub(crate) host: String,
29    pub(crate) port: u16,
30    pub(crate) username: String,
31    pub(crate) password: String,
32    pub(crate) database: Option<String>,
33}
34
35impl PostgresConnectionOptions {
36    pub fn new(
37        host: impl Into<String>,
38        port: u16,
39        username: impl Into<String>,
40        password: impl Into<String>,
41    ) -> Self {
42        Self {
43            host: host.into(),
44            port,
45            username: username.into(),
46            password: password.into(),
47            database: None,
48        }
49    }
50}
51
52#[derive(Debug)]
53pub struct PostgresPool {
54    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
55}
56
57#[async_trait::async_trait]
58impl Pool for PostgresPool {
59    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
60        let conn = self.pool.get_owned().await.map_err(|e| {
61            DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
62        })?;
63        Ok(Arc::new(PostgresConnection { conn }))
64    }
65}
66
67pub(crate) async fn connect_postgres(
68    options: &PostgresConnectionOptions,
69) -> DFResult<PostgresPool> {
70    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
71    config
72        .host(&options.host)
73        .port(options.port)
74        .user(&options.username)
75        .password(&options.password);
76    if let Some(database) = &options.database {
77        config.dbname(database);
78    }
79    let manager = PostgresConnectionManager::new(config, NoTls);
80    let pool = bb8::Pool::builder()
81        .max_size(5)
82        .build(manager)
83        .await
84        .map_err(|e| {
85            DataFusionError::Execution(format!(
86                "Failed to create postgres connection pool due to {e}",
87            ))
88        })?;
89
90    Ok(PostgresPool { pool })
91}
92
93#[derive(Debug)]
94pub(crate) struct PostgresConnection {
95    conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
96}
97
98#[async_trait::async_trait]
99impl Connection for PostgresConnection {
100    async fn infer_schema(
101        &self,
102        sql: &str,
103        transform: Option<Arc<dyn Transform>>,
104    ) -> DFResult<(RemoteSchema, SchemaRef)> {
105        let mut stream = self
106            .conn
107            .query_raw(sql, Vec::<String>::new())
108            .await
109            .map_err(|e| {
110                DataFusionError::Execution(format!(
111                    "Failed to execute query {sql} on postgres due to {e}",
112                ))
113            })?
114            .chunks(1)
115            .boxed();
116
117        let Some(first_chunk) = stream.next().await else {
118            return Err(DataFusionError::Execution(
119                "No data returned from postgres".to_string(),
120            ));
121        };
122        let first_chunk: Vec<Row> = first_chunk
123            .into_iter()
124            .collect::<Result<Vec<_>, _>>()
125            .map_err(|e| {
126                DataFusionError::Execution(format!(
127                    "Failed to collect rows from postgres due to {e}",
128                ))
129            })?;
130        let Some(first_row) = first_chunk.first() else {
131            return Err(DataFusionError::Execution(
132                "No data returned from postgres".to_string(),
133            ));
134        };
135        let remote_schema = build_remote_schema(first_row)?;
136        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
137        if let Some(transform) = transform {
138            let batch = rows_to_batch(std::slice::from_ref(first_row), arrow_schema, None)?;
139            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
140            Ok((remote_schema, transformed_batch.schema()))
141        } else {
142            Ok((remote_schema, arrow_schema))
143        }
144    }
145
146    async fn query(
147        &self,
148        sql: String,
149        projection: Option<Vec<usize>>,
150    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
151        let mut stream = self
152            .conn
153            .query_raw(&sql, Vec::<String>::new())
154            .await
155            .map_err(|e| {
156                DataFusionError::Execution(format!(
157                    "Failed to execute query {sql} on postgres due to {e}",
158                ))
159            })?
160            .chunks(2048)
161            .boxed();
162
163        let Some(first_chunk) = stream.next().await else {
164            return Err(DataFusionError::Execution(
165                "No data returned from postgres".to_string(),
166            ));
167        };
168        let first_chunk: Vec<Row> = first_chunk
169            .into_iter()
170            .collect::<Result<Vec<_>, _>>()
171            .map_err(|e| {
172                DataFusionError::Execution(format!(
173                    "Failed to collect rows from postgres due to {e}",
174                ))
175            })?;
176        let Some(first_row) = first_chunk.first() else {
177            return Err(DataFusionError::Execution(
178                "No data returned from postgres".to_string(),
179            ));
180        };
181        let remote_schema = build_remote_schema(first_row)?;
182        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
183        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
184        let first_chunk = rows_to_batch(
185            first_chunk.as_slice(),
186            arrow_schema.clone(),
187            projection.as_ref(),
188        )?;
189        let schema = first_chunk.schema();
190
191        let mut stream = stream.map(move |rows| {
192            let rows: Vec<Row> = rows
193                .into_iter()
194                .collect::<Result<Vec<_>, _>>()
195                .map_err(|e| {
196                    DataFusionError::Execution(format!(
197                        "Failed to collect rows from postgres due to {e}",
198                    ))
199                })?;
200            let batch = rows_to_batch(rows.as_slice(), arrow_schema.clone(), projection.as_ref())?;
201            Ok::<RecordBatch, DataFusionError>(batch)
202        });
203
204        let output_stream = async_stream::stream! {
205           yield Ok(first_chunk);
206           while let Some(batch) = stream.next().await {
207                yield batch
208           }
209        };
210
211        Ok((
212            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
213            projected_remote_schema,
214        ))
215    }
216}
217
218fn pg_type_to_remote_type(pg_type: &Type) -> DFResult<RemoteType> {
219    match pg_type {
220        &Type::BOOL => Ok(RemoteType::Postgres(PostgresType::Bool)),
221        &Type::CHAR => Ok(RemoteType::Postgres(PostgresType::Char)),
222        &Type::INT2 => Ok(RemoteType::Postgres(PostgresType::Int2)),
223        &Type::INT4 => Ok(RemoteType::Postgres(PostgresType::Int4)),
224        &Type::INT8 => Ok(RemoteType::Postgres(PostgresType::Int8)),
225        &Type::FLOAT4 => Ok(RemoteType::Postgres(PostgresType::Float4)),
226        &Type::FLOAT8 => Ok(RemoteType::Postgres(PostgresType::Float8)),
227        &Type::TEXT => Ok(RemoteType::Postgres(PostgresType::Text)),
228        &Type::VARCHAR => Ok(RemoteType::Postgres(PostgresType::Varchar)),
229        &Type::BPCHAR => Ok(RemoteType::Postgres(PostgresType::Bpchar)),
230        &Type::BYTEA => Ok(RemoteType::Postgres(PostgresType::Bytea)),
231        &Type::DATE => Ok(RemoteType::Postgres(PostgresType::Date)),
232        &Type::TIMESTAMP => Ok(RemoteType::Postgres(PostgresType::Timestamp)),
233        &Type::TIMESTAMPTZ => Ok(RemoteType::Postgres(PostgresType::TimestampTz)),
234        &Type::TIME => Ok(RemoteType::Postgres(PostgresType::Time)),
235        &Type::INT2_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int2Array)),
236        &Type::INT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int4Array)),
237        &Type::INT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int8Array)),
238        &Type::FLOAT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float4Array)),
239        &Type::FLOAT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float8Array)),
240        &Type::TEXT_ARRAY => Ok(RemoteType::Postgres(PostgresType::TextArray)),
241        &Type::VARCHAR_ARRAY => Ok(RemoteType::Postgres(PostgresType::VarcharArray)),
242        &Type::BYTEA_ARRAY => Ok(RemoteType::Postgres(PostgresType::ByteaArray)),
243        other if other.name().eq_ignore_ascii_case("geometry") => {
244            Ok(RemoteType::Postgres(PostgresType::PostGisGeometry))
245        }
246        _ => Err(DataFusionError::NotImplemented(format!(
247            "Unsupported postgres type {pg_type:?}",
248        ))),
249    }
250}
251
252fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
253    let mut remote_fields = vec![];
254    for col in row.columns() {
255        remote_fields.push(RemoteField::new(
256            col.name(),
257            pg_type_to_remote_type(col.type_())?,
258            true,
259        ));
260    }
261    Ok(RemoteSchema::new(remote_fields))
262}
263
/// Appends the value of column `$index` in `$row` to `$builder`.
///
/// `$builder` (a boxed `ArrayBuilder`) is downcast to `$builder_ty`, and the
/// cell is read as `Option<$value_ty>`. SQL `NULL` becomes an Arrow null.
/// `$pg_type` is only used to build the panic messages.
///
/// Panics if the downcast fails or if the cell cannot be decoded as `$value_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $pg_type:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let typed_builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .expect(concat!(
                "Failed to downcast builder to ",
                stringify!($builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        let cell: Option<$value_ty> = $row.try_get($index).expect(concat!(
            "Failed to get ",
            stringify!($value_ty),
            " value for column ",
            stringify!($pg_type)
        ));

        if let Some(value) = cell {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
288
/// Appends the one-dimensional array in column `$index` of `$row` as one list
/// entry of `$builder`.
///
/// `$builder` is downcast to `ListBuilder<Box<dyn ArrayBuilder>>` and its child
/// builder to `$values_builder_ty`. The cell is read as
/// `Option<Vec<$primitive_value_ty>>`. A SQL `NULL` array becomes a null list
/// entry; otherwise every element is appended and the list slot is closed as
/// valid. `$pg_type` is only used in panic messages.
///
/// Panics if either downcast fails or if the cell cannot be decoded.
macro_rules! handle_primitive_array_type {
    ($builder:expr, $pg_type:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let list_builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .expect(concat!(
                "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for ",
                stringify!($pg_type)
            ));
        let items_builder = list_builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .expect(concat!(
                "Failed to downcast values builder to ",
                stringify!($values_builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        let cell: Option<Vec<$primitive_value_ty>> = $row.try_get($index).expect(concat!(
            "Failed to get ",
            stringify!($primitive_value_ty),
            " array value for column ",
            stringify!($pg_type)
        ));

        match cell {
            None => list_builder.append_null(),
            Some(items) => {
                for item in items {
                    items_builder.append_value(item);
                }
                list_builder.append(true);
            }
        }
    }};
}
325
326pub struct GeometryFromSql<'a> {
327    wkb: &'a [u8],
328}
329
330impl<'a> FromSql<'a> for GeometryFromSql<'a> {
331    fn from_sql(
332        _ty: &Type,
333        raw: &'a [u8],
334    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
335        Ok(GeometryFromSql { wkb: raw })
336    }
337
338    fn accepts(ty: &Type) -> bool {
339        matches!(ty.name(), "geometry")
340    }
341}
342
343fn rows_to_batch(
344    rows: &[Row],
345    arrow_schema: SchemaRef,
346    projection: Option<&Vec<usize>>,
347) -> DFResult<RecordBatch> {
348    let projected_schema = project_schema(&arrow_schema, projection)?;
349    let mut array_builders = vec![];
350    for field in arrow_schema.fields() {
351        let builder = make_builder(field.data_type(), rows.len());
352        array_builders.push(builder);
353    }
354    for row in rows {
355        for (idx, col) in row.columns().iter().enumerate() {
356            if !projections_contains(projection, idx) {
357                continue;
358            }
359            let builder = &mut array_builders[idx];
360            match col.type_() {
361                &Type::BOOL => {
362                    handle_primitive_type!(builder, Type::BOOL, BooleanBuilder, bool, row, idx);
363                }
364                &Type::CHAR => {
365                    handle_primitive_type!(builder, Type::CHAR, Int8Builder, i8, row, idx);
366                }
367                &Type::INT2 => {
368                    handle_primitive_type!(builder, Type::INT2, Int16Builder, i16, row, idx);
369                }
370                &Type::INT4 => {
371                    handle_primitive_type!(builder, Type::INT4, Int32Builder, i32, row, idx);
372                }
373                &Type::INT8 => {
374                    handle_primitive_type!(builder, Type::INT8, Int64Builder, i64, row, idx);
375                }
376                &Type::FLOAT4 => {
377                    handle_primitive_type!(builder, Type::FLOAT4, Float32Builder, f32, row, idx);
378                }
379                &Type::FLOAT8 => {
380                    handle_primitive_type!(builder, Type::FLOAT8, Float64Builder, f64, row, idx);
381                }
382                &Type::TEXT => {
383                    handle_primitive_type!(builder, Type::TEXT, StringBuilder, &str, row, idx);
384                }
385                &Type::VARCHAR => {
386                    handle_primitive_type!(builder, Type::VARCHAR, StringBuilder, &str, row, idx);
387                }
388                &Type::BPCHAR => {
389                    let builder = builder
390                        .as_any_mut()
391                        .downcast_mut::<StringBuilder>()
392                        .expect("Failed to downcast builder to StringBuilder for Type::BPCHAR");
393                    let v: Option<&str> = row
394                        .try_get(idx)
395                        .expect("Failed to get &str value for column Type::BPCHAR");
396
397                    match v {
398                        Some(v) => builder.append_value(v.trim_end()),
399                        None => builder.append_null(),
400                    }
401                }
402                &Type::BYTEA => {
403                    handle_primitive_type!(builder, Type::BYTEA, BinaryBuilder, Vec<u8>, row, idx);
404                }
405                &Type::TIMESTAMP => {
406                    let builder = builder
407                        .as_any_mut()
408                        .downcast_mut::<TimestampNanosecondBuilder>()
409                        .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
410                    let v: Option<SystemTime> = row
411                        .try_get(idx)
412                        .expect("Failed to get SystemTime value for column Type::TIMESTAMP");
413
414                    match v {
415                        Some(v) => {
416                            if let Ok(v) = v.duration_since(UNIX_EPOCH) {
417                                let timestamp: i64 = v
418                                    .as_nanos()
419                                    .try_into()
420                                    .expect("Failed to convert SystemTime to i64");
421                                builder.append_value(timestamp);
422                            }
423                        }
424                        None => builder.append_null(),
425                    }
426                }
427                &Type::TIMESTAMPTZ => {
428                    let builder = builder
429                        .as_any_mut()
430                        .downcast_mut::<TimestampNanosecondBuilder>()
431                        .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
432                    let v: Option<chrono::DateTime<chrono::Utc>> = row.try_get(idx).expect(
433                        "Failed to get chrono::DateTime<chrono::Utc> value for column Type::TIMESTAMPTZ",
434                    );
435
436                    match v {
437                        Some(v) => {
438                            let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for Type::TIMESTAMP"));
439                            builder.append_value(timestamp);
440                        }
441                        None => builder.append_null(),
442                    }
443                }
444                &Type::TIME => {
445                    let builder = builder
446                        .as_any_mut()
447                        .downcast_mut::<Time64NanosecondBuilder>()
448                        .expect(
449                            "Failed to downcast builder to Time64NanosecondBuilder for Type::TIME",
450                        );
451                    let v: Option<chrono::NaiveTime> = row
452                        .try_get(idx)
453                        .expect("Failed to get chrono::NaiveTime value for column Type::TIME");
454
455                    match v {
456                        Some(v) => {
457                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
458                                * 1_000_000_000
459                                + i64::from(v.nanosecond());
460                            builder.append_value(timestamp);
461                        }
462                        None => builder.append_null(),
463                    }
464                }
465                &Type::DATE => {
466                    let builder = builder
467                        .as_any_mut()
468                        .downcast_mut::<Date32Builder>()
469                        .expect("Failed to downcast builder to Date32Builder for Type::DATE");
470                    let v: Option<chrono::NaiveDate> = row
471                        .try_get(idx)
472                        .expect("Failed to get chrono::NaiveDate value for column Type::DATE");
473
474                    match v {
475                        Some(v) => builder.append_value(Date32Type::from_naive_date(v)),
476                        None => builder.append_null(),
477                    }
478                }
479                &Type::INT2_ARRAY => {
480                    handle_primitive_array_type!(
481                        builder,
482                        Type::INT2_ARRAY,
483                        Int16Builder,
484                        i16,
485                        row,
486                        idx
487                    );
488                }
489                &Type::INT4_ARRAY => {
490                    handle_primitive_array_type!(
491                        builder,
492                        Type::INT4_ARRAY,
493                        Int32Builder,
494                        i32,
495                        row,
496                        idx
497                    );
498                }
499                &Type::INT8_ARRAY => {
500                    handle_primitive_array_type!(
501                        builder,
502                        Type::INT8_ARRAY,
503                        Int64Builder,
504                        i64,
505                        row,
506                        idx
507                    );
508                }
509                &Type::FLOAT4_ARRAY => {
510                    handle_primitive_array_type!(
511                        builder,
512                        Type::FLOAT4_ARRAY,
513                        Float32Builder,
514                        f32,
515                        row,
516                        idx
517                    );
518                }
519                &Type::FLOAT8_ARRAY => {
520                    handle_primitive_array_type!(
521                        builder,
522                        Type::FLOAT8_ARRAY,
523                        Float64Builder,
524                        f64,
525                        row,
526                        idx
527                    );
528                }
529                &Type::TEXT_ARRAY => {
530                    handle_primitive_array_type!(
531                        builder,
532                        Type::TEXT_ARRAY,
533                        StringBuilder,
534                        &str,
535                        row,
536                        idx
537                    );
538                }
539                &Type::VARCHAR_ARRAY => {
540                    handle_primitive_array_type!(
541                        builder,
542                        Type::VARCHAR_ARRAY,
543                        StringBuilder,
544                        &str,
545                        row,
546                        idx
547                    );
548                }
549                &Type::BYTEA_ARRAY => {
550                    handle_primitive_array_type!(
551                        builder,
552                        Type::BYTEA_ARRAY,
553                        BinaryBuilder,
554                        Vec<u8>,
555                        row,
556                        idx
557                    );
558                }
559                other if other.name().eq_ignore_ascii_case("geometry") => {
560                    let builder = builder
561                        .as_any_mut()
562                        .downcast_mut::<BinaryBuilder>()
563                        .expect("Failed to downcast builder to BinaryBuilder for Type::geometry");
564                    let v: Option<GeometryFromSql> = row
565                        .try_get(idx)
566                        .expect("Failed to get GeometryFromSql value for column Type::geometry");
567
568                    match v {
569                        Some(v) => builder.append_value(v.wkb),
570                        None => builder.append_null(),
571                    }
572                }
573                _ => {
574                    return Err(DataFusionError::Execution(format!(
575                        "Unsupported postgres type {col:?}",
576                    )));
577                }
578            }
579        }
580    }
581    let projected_columns = array_builders
582        .into_iter()
583        .enumerate()
584        .filter(|(idx, _)| projections_contains(projection, *idx))
585        .map(|(_, mut builder)| builder.finish())
586        .collect::<Vec<ArrayRef>>();
587    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
588}