1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{
4 project_remote_schema, Connection, DFResult, Pool, PostgresType, RemoteField, RemoteSchema,
5 RemoteType, Transform,
6};
7use bb8_postgres::tokio_postgres::types::{FromSql, Type};
8use bb8_postgres::tokio_postgres::{NoTls, Row};
9use bb8_postgres::PostgresConnectionManager;
10use chrono::Timelike;
11use datafusion::arrow::array::{
12 make_builder, ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder,
13 Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
14 ListBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder, TimestampNanosecondBuilder,
15};
16use datafusion::arrow::datatypes::{Date32Type, SchemaRef};
17use datafusion::common::project_schema;
18use datafusion::error::DataFusionError;
19use datafusion::execution::SendableRecordBatchStream;
20use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
21use futures::StreamExt;
22use std::string::ToString;
23use std::sync::Arc;
24use std::time::{SystemTime, UNIX_EPOCH};
25
/// Options describing how to reach a Postgres server.
///
/// NOTE(review): `derive_with::With` presumably generates `with_*` setters
/// (e.g. `with_database`) for these fields — confirm against the crate docs.
#[derive(Debug, Clone, derive_with::With)]
pub struct PostgresConnectionOptions {
    // Hostname or IP address of the server.
    pub(crate) host: String,
    // TCP port of the server.
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    // Database to connect to; when `None` no dbname is set on the config
    // and the server-side default applies.
    pub(crate) database: Option<String>,
}
34
35impl PostgresConnectionOptions {
36 pub fn new(
37 host: impl Into<String>,
38 port: u16,
39 username: impl Into<String>,
40 password: impl Into<String>,
41 ) -> Self {
42 Self {
43 host: host.into(),
44 port,
45 username: username.into(),
46 password: password.into(),
47 database: None,
48 }
49 }
50}
51
/// A pool of Postgres connections, implementing this crate's [`Pool`] trait
/// on top of `bb8`.
#[derive(Debug)]
pub struct PostgresPool {
    // Underlying bb8 pool; connections are plaintext (NoTls).
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
}
56
57#[async_trait::async_trait]
58impl Pool for PostgresPool {
59 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
60 let conn = self.pool.get_owned().await.map_err(|e| {
61 DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
62 })?;
63 Ok(Arc::new(PostgresConnection { conn }))
64 }
65}
66
67pub(crate) async fn connect_postgres(
68 options: &PostgresConnectionOptions,
69) -> DFResult<PostgresPool> {
70 let mut config = bb8_postgres::tokio_postgres::config::Config::new();
71 config
72 .host(&options.host)
73 .port(options.port)
74 .user(&options.username)
75 .password(&options.password);
76 if let Some(database) = &options.database {
77 config.dbname(database);
78 }
79 let manager = PostgresConnectionManager::new(config, NoTls);
80 let pool = bb8::Pool::builder()
81 .max_size(5)
82 .build(manager)
83 .await
84 .map_err(|e| {
85 DataFusionError::Execution(format!(
86 "Failed to create postgres connection pool due to {e}",
87 ))
88 })?;
89
90 Ok(PostgresPool { pool })
91}
92
/// A single checked-out Postgres connection, produced by [`PostgresPool`].
#[derive(Debug)]
pub(crate) struct PostgresConnection {
    // Owned checkout from the bb8 pool; `'static` because it is not tied to
    // a borrow of the pool itself.
    conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
}
97
#[async_trait::async_trait]
impl Connection for PostgresConnection {
    /// Infers the remote (postgres) schema and the corresponding Arrow schema
    /// for `sql` by executing the query and inspecting its first row's column
    /// metadata.
    ///
    /// When `transform` is supplied, the first row is materialized into a
    /// batch and transformed so the returned Arrow schema reflects the
    /// post-transform layout (the `RemoteSchema` stays pre-transform).
    ///
    /// NOTE(review): a query that returns zero rows fails with
    /// "No data returned from postgres", so schemas cannot be inferred for
    /// empty result sets — confirm this limitation is acceptable to callers.
    async fn infer_schema(
        &self,
        sql: &str,
        transform: Option<Arc<dyn Transform>>,
    ) -> DFResult<(RemoteSchema, SchemaRef)> {
        // Chunk size 1: only the first row is needed for inference.
        let mut stream = self
            .conn
            .query_raw(sql, Vec::<String>::new())
            .await
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute query {sql} on postgres due to {e}",
                ))
            })?
            .chunks(1)
            .boxed();

        let Some(first_chunk) = stream.next().await else {
            return Err(DataFusionError::Execution(
                "No data returned from postgres".to_string(),
            ));
        };
        // Each chunk element is a Result<Row, _>; fail on the first row error.
        let first_chunk: Vec<Row> = first_chunk
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to collect rows from postgres due to {e}",
                ))
            })?;
        let Some(first_row) = first_chunk.first() else {
            return Err(DataFusionError::Execution(
                "No data returned from postgres".to_string(),
            ));
        };
        let remote_schema = build_remote_schema(first_row)?;
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        if let Some(transform) = transform {
            // Run the transform on a one-row batch purely to learn the
            // transformed Arrow schema.
            let batch = rows_to_batch(std::slice::from_ref(first_row), arrow_schema, None)?;
            let transformed_batch = transform_batch(batch, transform.as_ref(), &remote_schema)?;
            Ok((remote_schema, transformed_batch.schema()))
        } else {
            Ok((remote_schema, arrow_schema))
        }
    }

    /// Executes `sql` and returns a stream of Arrow batches (2048 rows each)
    /// plus the projected remote schema.
    ///
    /// The first chunk is consumed eagerly: it is needed to build the remote
    /// schema from column metadata and to fix the output batch schema. It is
    /// then re-emitted as the first element of the returned stream.
    ///
    /// NOTE(review): as with `infer_schema`, a zero-row result errors instead
    /// of yielding an empty stream.
    async fn query(
        &self,
        sql: String,
        projection: Option<Vec<usize>>,
    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
        let mut stream = self
            .conn
            .query_raw(&sql, Vec::<String>::new())
            .await
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute query {sql} on postgres due to {e}",
                ))
            })?
            .chunks(2048)
            .boxed();

        let Some(first_chunk) = stream.next().await else {
            return Err(DataFusionError::Execution(
                "No data returned from postgres".to_string(),
            ));
        };
        let first_chunk: Vec<Row> = first_chunk
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to collect rows from postgres due to {e}",
                ))
            })?;
        let Some(first_row) = first_chunk.first() else {
            return Err(DataFusionError::Execution(
                "No data returned from postgres".to_string(),
            ));
        };
        let remote_schema = build_remote_schema(first_row)?;
        let projected_remote_schema = project_remote_schema(&remote_schema, projection.as_ref());
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        // Convert the already-fetched first chunk; its (projected) schema
        // becomes the schema of the whole output stream.
        let first_chunk = rows_to_batch(
            first_chunk.as_slice(),
            arrow_schema.clone(),
            projection.as_ref(),
        )?;
        let schema = first_chunk.schema();

        // Lazily convert the remaining chunks as they arrive.
        let mut stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from postgres due to {e}",
                    ))
                })?;
            let batch = rows_to_batch(rows.as_slice(), arrow_schema.clone(), projection.as_ref())?;
            Ok::<RecordBatch, DataFusionError>(batch)
        });

        // Re-prepend the eagerly-consumed first batch.
        let output_stream = async_stream::stream! {
            yield Ok(first_chunk);
            while let Some(batch) = stream.next().await {
                yield batch
            }
        };

        Ok((
            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
            projected_remote_schema,
        ))
    }
}
217
218fn pg_type_to_remote_type(pg_type: &Type) -> DFResult<RemoteType> {
219 match pg_type {
220 &Type::BOOL => Ok(RemoteType::Postgres(PostgresType::Bool)),
221 &Type::CHAR => Ok(RemoteType::Postgres(PostgresType::Char)),
222 &Type::INT2 => Ok(RemoteType::Postgres(PostgresType::Int2)),
223 &Type::INT4 => Ok(RemoteType::Postgres(PostgresType::Int4)),
224 &Type::INT8 => Ok(RemoteType::Postgres(PostgresType::Int8)),
225 &Type::FLOAT4 => Ok(RemoteType::Postgres(PostgresType::Float4)),
226 &Type::FLOAT8 => Ok(RemoteType::Postgres(PostgresType::Float8)),
227 &Type::TEXT => Ok(RemoteType::Postgres(PostgresType::Text)),
228 &Type::VARCHAR => Ok(RemoteType::Postgres(PostgresType::Varchar)),
229 &Type::BPCHAR => Ok(RemoteType::Postgres(PostgresType::Bpchar)),
230 &Type::BYTEA => Ok(RemoteType::Postgres(PostgresType::Bytea)),
231 &Type::DATE => Ok(RemoteType::Postgres(PostgresType::Date)),
232 &Type::TIMESTAMP => Ok(RemoteType::Postgres(PostgresType::Timestamp)),
233 &Type::TIMESTAMPTZ => Ok(RemoteType::Postgres(PostgresType::TimestampTz)),
234 &Type::TIME => Ok(RemoteType::Postgres(PostgresType::Time)),
235 &Type::INT2_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int2Array)),
236 &Type::INT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int4Array)),
237 &Type::INT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int8Array)),
238 &Type::FLOAT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float4Array)),
239 &Type::FLOAT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float8Array)),
240 &Type::TEXT_ARRAY => Ok(RemoteType::Postgres(PostgresType::TextArray)),
241 &Type::VARCHAR_ARRAY => Ok(RemoteType::Postgres(PostgresType::VarcharArray)),
242 &Type::BYTEA_ARRAY => Ok(RemoteType::Postgres(PostgresType::ByteaArray)),
243 other if other.name().eq_ignore_ascii_case("geometry") => {
244 Ok(RemoteType::Postgres(PostgresType::PostGisGeometry))
245 }
246 _ => Err(DataFusionError::NotImplemented(format!(
247 "Unsupported postgres type {pg_type:?}",
248 ))),
249 }
250}
251
252fn build_remote_schema(row: &Row) -> DFResult<RemoteSchema> {
253 let mut remote_fields = vec![];
254 for col in row.columns() {
255 remote_fields.push(RemoteField::new(
256 col.name(),
257 pg_type_to_remote_type(col.type_())?,
258 true,
259 ));
260 }
261 Ok(RemoteSchema::new(remote_fields))
262}
263
/// Appends one (possibly NULL) scalar value from `$row` at `$index` to
/// `$builder`.
///
/// `$builder_ty` is the concrete Arrow builder type to downcast to and
/// `$value_ty` is the Rust type requested via `Row::try_get`. Downcast or
/// decode failures panic via `expect`: they indicate a bug in the
/// type-to-builder mapping, not bad data.
macro_rules! handle_primitive_type {
    ($builder:expr, $pg_type:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .expect(concat!(
                "Failed to downcast builder to ",
                stringify!($builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        // SQL NULL decodes as None.
        let v: Option<$value_ty> = $row.try_get($index).expect(concat!(
            "Failed to get ",
            stringify!($value_ty),
            " value for column ",
            stringify!($pg_type)
        ));

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
288
/// Appends one (possibly NULL) postgres array value from `$row` at `$index`
/// to a list builder.
///
/// The outer builder is downcast to `ListBuilder<Box<dyn ArrayBuilder>>`
/// (NOTE(review): presumably the shape `make_builder` produces for Arrow list
/// types — confirm), and its values builder to `$values_builder_ty`. The
/// array is fetched as `Vec<$primitive_value_ty>`; a NULL array appends a
/// null list entry. Individual element NULLs are not representable here
/// (every element is wrapped in `Some`).
macro_rules! handle_primitive_array_type {
    ($builder:expr, $pg_type:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .expect(concat!(
                "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for ",
                stringify!($pg_type)
            ));
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .expect(concat!(
                "Failed to downcast values builder to ",
                stringify!($values_builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).expect(concat!(
            "Failed to get ",
            stringify!($primitive_value_ty),
            " array value for column ",
            stringify!($pg_type)
        ));

        match v {
            Some(v) => {
                let v = v.into_iter().map(Some);
                values_builder.extend(v);
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}
325
/// Zero-copy wrapper used to pull a PostGIS geometry value out of a row as
/// raw bytes (borrowed from the row's wire buffer).
pub struct GeometryFromSql<'a> {
    // Raw geometry bytes as received from the server.
    // NOTE(review): field name suggests WKB/EWKB encoding — the bytes are
    // passed through untouched, so confirm the expected encoding downstream.
    wkb: &'a [u8],
}
329
330impl<'a> FromSql<'a> for GeometryFromSql<'a> {
331 fn from_sql(
332 _ty: &Type,
333 raw: &'a [u8],
334 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
335 Ok(GeometryFromSql { wkb: raw })
336 }
337
338 fn accepts(ty: &Type) -> bool {
339 matches!(ty.name(), "geometry")
340 }
341}
342
343fn rows_to_batch(
344 rows: &[Row],
345 arrow_schema: SchemaRef,
346 projection: Option<&Vec<usize>>,
347) -> DFResult<RecordBatch> {
348 let projected_schema = project_schema(&arrow_schema, projection)?;
349 let mut array_builders = vec![];
350 for field in arrow_schema.fields() {
351 let builder = make_builder(field.data_type(), rows.len());
352 array_builders.push(builder);
353 }
354 for row in rows {
355 for (idx, col) in row.columns().iter().enumerate() {
356 if !projections_contains(projection, idx) {
357 continue;
358 }
359 let builder = &mut array_builders[idx];
360 match col.type_() {
361 &Type::BOOL => {
362 handle_primitive_type!(builder, Type::BOOL, BooleanBuilder, bool, row, idx);
363 }
364 &Type::CHAR => {
365 handle_primitive_type!(builder, Type::CHAR, Int8Builder, i8, row, idx);
366 }
367 &Type::INT2 => {
368 handle_primitive_type!(builder, Type::INT2, Int16Builder, i16, row, idx);
369 }
370 &Type::INT4 => {
371 handle_primitive_type!(builder, Type::INT4, Int32Builder, i32, row, idx);
372 }
373 &Type::INT8 => {
374 handle_primitive_type!(builder, Type::INT8, Int64Builder, i64, row, idx);
375 }
376 &Type::FLOAT4 => {
377 handle_primitive_type!(builder, Type::FLOAT4, Float32Builder, f32, row, idx);
378 }
379 &Type::FLOAT8 => {
380 handle_primitive_type!(builder, Type::FLOAT8, Float64Builder, f64, row, idx);
381 }
382 &Type::TEXT => {
383 handle_primitive_type!(builder, Type::TEXT, StringBuilder, &str, row, idx);
384 }
385 &Type::VARCHAR => {
386 handle_primitive_type!(builder, Type::VARCHAR, StringBuilder, &str, row, idx);
387 }
388 &Type::BPCHAR => {
389 let builder = builder
390 .as_any_mut()
391 .downcast_mut::<StringBuilder>()
392 .expect("Failed to downcast builder to StringBuilder for Type::BPCHAR");
393 let v: Option<&str> = row
394 .try_get(idx)
395 .expect("Failed to get &str value for column Type::BPCHAR");
396
397 match v {
398 Some(v) => builder.append_value(v.trim_end()),
399 None => builder.append_null(),
400 }
401 }
402 &Type::BYTEA => {
403 handle_primitive_type!(builder, Type::BYTEA, BinaryBuilder, Vec<u8>, row, idx);
404 }
405 &Type::TIMESTAMP => {
406 let builder = builder
407 .as_any_mut()
408 .downcast_mut::<TimestampNanosecondBuilder>()
409 .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
410 let v: Option<SystemTime> = row
411 .try_get(idx)
412 .expect("Failed to get SystemTime value for column Type::TIMESTAMP");
413
414 match v {
415 Some(v) => {
416 if let Ok(v) = v.duration_since(UNIX_EPOCH) {
417 let timestamp: i64 = v
418 .as_nanos()
419 .try_into()
420 .expect("Failed to convert SystemTime to i64");
421 builder.append_value(timestamp);
422 }
423 }
424 None => builder.append_null(),
425 }
426 }
427 &Type::TIMESTAMPTZ => {
428 let builder = builder
429 .as_any_mut()
430 .downcast_mut::<TimestampNanosecondBuilder>()
431 .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
432 let v: Option<chrono::DateTime<chrono::Utc>> = row.try_get(idx).expect(
433 "Failed to get chrono::DateTime<chrono::Utc> value for column Type::TIMESTAMPTZ",
434 );
435
436 match v {
437 Some(v) => {
438 let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for Type::TIMESTAMP"));
439 builder.append_value(timestamp);
440 }
441 None => builder.append_null(),
442 }
443 }
444 &Type::TIME => {
445 let builder = builder
446 .as_any_mut()
447 .downcast_mut::<Time64NanosecondBuilder>()
448 .expect(
449 "Failed to downcast builder to Time64NanosecondBuilder for Type::TIME",
450 );
451 let v: Option<chrono::NaiveTime> = row
452 .try_get(idx)
453 .expect("Failed to get chrono::NaiveTime value for column Type::TIME");
454
455 match v {
456 Some(v) => {
457 let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
458 * 1_000_000_000
459 + i64::from(v.nanosecond());
460 builder.append_value(timestamp);
461 }
462 None => builder.append_null(),
463 }
464 }
465 &Type::DATE => {
466 let builder = builder
467 .as_any_mut()
468 .downcast_mut::<Date32Builder>()
469 .expect("Failed to downcast builder to Date32Builder for Type::DATE");
470 let v: Option<chrono::NaiveDate> = row
471 .try_get(idx)
472 .expect("Failed to get chrono::NaiveDate value for column Type::DATE");
473
474 match v {
475 Some(v) => builder.append_value(Date32Type::from_naive_date(v)),
476 None => builder.append_null(),
477 }
478 }
479 &Type::INT2_ARRAY => {
480 handle_primitive_array_type!(
481 builder,
482 Type::INT2_ARRAY,
483 Int16Builder,
484 i16,
485 row,
486 idx
487 );
488 }
489 &Type::INT4_ARRAY => {
490 handle_primitive_array_type!(
491 builder,
492 Type::INT4_ARRAY,
493 Int32Builder,
494 i32,
495 row,
496 idx
497 );
498 }
499 &Type::INT8_ARRAY => {
500 handle_primitive_array_type!(
501 builder,
502 Type::INT8_ARRAY,
503 Int64Builder,
504 i64,
505 row,
506 idx
507 );
508 }
509 &Type::FLOAT4_ARRAY => {
510 handle_primitive_array_type!(
511 builder,
512 Type::FLOAT4_ARRAY,
513 Float32Builder,
514 f32,
515 row,
516 idx
517 );
518 }
519 &Type::FLOAT8_ARRAY => {
520 handle_primitive_array_type!(
521 builder,
522 Type::FLOAT8_ARRAY,
523 Float64Builder,
524 f64,
525 row,
526 idx
527 );
528 }
529 &Type::TEXT_ARRAY => {
530 handle_primitive_array_type!(
531 builder,
532 Type::TEXT_ARRAY,
533 StringBuilder,
534 &str,
535 row,
536 idx
537 );
538 }
539 &Type::VARCHAR_ARRAY => {
540 handle_primitive_array_type!(
541 builder,
542 Type::VARCHAR_ARRAY,
543 StringBuilder,
544 &str,
545 row,
546 idx
547 );
548 }
549 &Type::BYTEA_ARRAY => {
550 handle_primitive_array_type!(
551 builder,
552 Type::BYTEA_ARRAY,
553 BinaryBuilder,
554 Vec<u8>,
555 row,
556 idx
557 );
558 }
559 other if other.name().eq_ignore_ascii_case("geometry") => {
560 let builder = builder
561 .as_any_mut()
562 .downcast_mut::<BinaryBuilder>()
563 .expect("Failed to downcast builder to BinaryBuilder for Type::geometry");
564 let v: Option<GeometryFromSql> = row
565 .try_get(idx)
566 .expect("Failed to get GeometryFromSql value for column Type::geometry");
567
568 match v {
569 Some(v) => builder.append_value(v.wkb),
570 None => builder.append_null(),
571 }
572 }
573 _ => {
574 return Err(DataFusionError::Execution(format!(
575 "Unsupported postgres type {col:?}",
576 )));
577 }
578 }
579 }
580 }
581 let projected_columns = array_builders
582 .into_iter()
583 .enumerate()
584 .filter(|(idx, _)| projections_contains(projection, *idx))
585 .map(|(_, mut builder)| builder.finish())
586 .collect::<Vec<ArrayRef>>();
587 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
588}