use crate::connection::projections_contains;
use crate::transform::transform_batch;
use crate::{
    Connection, DFResult, Pool, PostgresType, RemoteField, RemoteSchema, RemoteType, Transform,
};
use bb8_postgres::tokio_postgres::types::{FromSql, Type};
use bb8_postgres::tokio_postgres::{NoTls, Row};
use bb8_postgres::PostgresConnectionManager;
use chrono::Timelike;
use datafusion::arrow::array::{
    make_builder, ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder,
    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
    ListBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder, TimestampNanosecondBuilder,
};
use datafusion::arrow::datatypes::{Date32Type, Schema, SchemaRef};
use datafusion::common::project_schema;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::{stream, StreamExt};
use std::string::ToString;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

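/// Options for connecting to a PostgreSQL server: host, port, credentials, and an optional
/// database name.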
#[derive(Debug, Clone)]
pub struct PostgresConnectionOptions {
    pub host: String,
    pub port: u16,
    pub username: String,
    pub password: String,
    pub database: Option<String>,
}

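/// A bb8-backed connection pool; `get` wraps each pooled connection in a [`PostgresConnection`].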
#[derive(Debug)]
pub struct PostgresPool {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
}

#[async_trait::async_trait]
impl Pool for PostgresPool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
        })?;
        Ok(Arc::new(PostgresConnection { conn }))
    }
}

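/// Creates a [`PostgresPool`] (bb8, max 5 connections, no TLS) from the given options.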
pub(crate) async fn connect_postgres(
    options: &PostgresConnectionOptions,
) -> DFResult<PostgresPool> {
    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
    config
        .host(&options.host)
        .port(options.port)
        .user(&options.username)
        .password(&options.password);
    if let Some(database) = &options.database {
        config.dbname(database);
    }
    let manager = PostgresConnectionManager::new(config, NoTls);
    let pool = bb8::Pool::builder()
        .max_size(5)
        .build(manager)
        .await
        .map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to create postgres connection pool due to {e}",
            ))
        })?;

    Ok(PostgresPool { pool })
}

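/// A pooled `tokio_postgres` connection that can infer schemas and stream query results as
/// Arrow record batches.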
#[derive(Debug)]
pub(crate) struct PostgresConnection {
    conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
}

#[async_trait::async_trait]
impl Connection for PostgresConnection {
    async fn infer_schema(
        &self,
        sql: &str,
        transform: Option<&dyn Transform>,
    ) -> DFResult<(RemoteSchema, SchemaRef)> {
        let mut stream = self
            .conn
            .query_raw(sql, Vec::<String>::new())
            .await
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute query {sql} on postgres due to {e}",
                ))
            })?
            .chunks(1)
            .boxed();

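        // A single row is enough to determine the column names and Postgres types.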
        let Some(first_chunk) = stream.next().await else {
            return Err(DataFusionError::Execution(
                "No data returned from postgres".to_string(),
            ));
        };
        let first_chunk: Vec<Row> = first_chunk
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to collect rows from postgres due to {e}",
                ))
            })?;
        let Some(first_row) = first_chunk.first() else {
            return Err(DataFusionError::Execution(
                "No data returned from postgres".to_string(),
            ));
        };
        let (remote_schema, pg_types) = build_remote_schema(first_row)?;
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        let batch = rows_to_batch(
            std::slice::from_ref(first_row),
            &pg_types,
            arrow_schema,
            None,
        )?;
        if let Some(transform) = transform {
            let transformed_batch = transform_batch(batch, transform, &remote_schema)?;
            Ok((remote_schema, transformed_batch.schema()))
        } else {
            Ok((remote_schema, batch.schema()))
        }
    }

    async fn query(
        &self,
        sql: String,
        projection: Option<Vec<usize>>,
    ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
        let mut stream = self
            .conn
            .query_raw(&sql, Vec::<String>::new())
            .await
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute query {sql} on postgres due to {e}",
                ))
            })?
            .chunks(2048)
            .boxed();

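        // The first chunk is fetched eagerly so the remote and Arrow schemas can be built from
        // the returned column types before the stream is handed back to DataFusion.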
        let Some(first_chunk) = stream.next().await else {
            return Ok((
                Box::pin(RecordBatchStreamAdapter::new(
                    Arc::new(Schema::empty()),
                    stream::empty(),
                )),
                RemoteSchema::empty(),
            ));
        };
        let first_chunk: Vec<Row> = first_chunk
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to collect rows from postgres due to {e}",
                ))
            })?;
        let Some(first_row) = first_chunk.first() else {
            return Err(DataFusionError::Execution(
                "No data returned from postgres".to_string(),
            ));
        };
        let (remote_schema, pg_types) = build_remote_schema(first_row)?;
        let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
        let first_chunk = rows_to_batch(
            first_chunk.as_slice(),
            &pg_types,
            arrow_schema.clone(),
            projection.as_ref(),
        )?;
        let schema = first_chunk.schema();

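        // Remaining chunks are converted into record batches lazily as the stream is polled.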
        let mut stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from postgres due to {e}",
                    ))
                })?;
            let batch = rows_to_batch(
                rows.as_slice(),
                &pg_types,
                arrow_schema.clone(),
                projection.as_ref(),
            )?;
            Ok::<RecordBatch, DataFusionError>(batch)
        });

        let output_stream = async_stream::stream! {
            yield Ok(first_chunk);
            while let Some(batch) = stream.next().await {
                match batch {
                    Ok(batch) => {
                        yield Ok(batch);
                    }
                    Err(e) => {
                        yield Err(DataFusionError::Execution(format!("Failed to fetch batch: {e}")));
                    }
                }
            }
        };

        Ok((
            Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
            remote_schema,
        ))
    }
}

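/// Maps a `tokio_postgres` [`Type`] to the corresponding [`RemoteType`]; the PostGIS
/// `geometry` type is matched by name.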
fn pg_type_to_remote_type(pg_type: &Type) -> DFResult<RemoteType> {
    match pg_type {
        &Type::BOOL => Ok(RemoteType::Postgres(PostgresType::Bool)),
        &Type::CHAR => Ok(RemoteType::Postgres(PostgresType::Char)),
        &Type::INT2 => Ok(RemoteType::Postgres(PostgresType::Int2)),
        &Type::INT4 => Ok(RemoteType::Postgres(PostgresType::Int4)),
        &Type::INT8 => Ok(RemoteType::Postgres(PostgresType::Int8)),
        &Type::FLOAT4 => Ok(RemoteType::Postgres(PostgresType::Float4)),
        &Type::FLOAT8 => Ok(RemoteType::Postgres(PostgresType::Float8)),
        &Type::TEXT => Ok(RemoteType::Postgres(PostgresType::Text)),
        &Type::VARCHAR => Ok(RemoteType::Postgres(PostgresType::Varchar)),
        &Type::BPCHAR => Ok(RemoteType::Postgres(PostgresType::Bpchar)),
        &Type::BYTEA => Ok(RemoteType::Postgres(PostgresType::Bytea)),
        &Type::DATE => Ok(RemoteType::Postgres(PostgresType::Date)),
        &Type::TIMESTAMP => Ok(RemoteType::Postgres(PostgresType::Timestamp)),
        &Type::TIMESTAMPTZ => Ok(RemoteType::Postgres(PostgresType::TimestampTz)),
        &Type::TIME => Ok(RemoteType::Postgres(PostgresType::Time)),
        &Type::INT2_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int2Array)),
        &Type::INT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int4Array)),
        &Type::INT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int8Array)),
        &Type::FLOAT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float4Array)),
        &Type::FLOAT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float8Array)),
        &Type::TEXT_ARRAY => Ok(RemoteType::Postgres(PostgresType::TextArray)),
        &Type::VARCHAR_ARRAY => Ok(RemoteType::Postgres(PostgresType::VarcharArray)),
        &Type::BYTEA_ARRAY => Ok(RemoteType::Postgres(PostgresType::ByteaArray)),
        other if other.name().eq_ignore_ascii_case("geometry") => {
            Ok(RemoteType::Postgres(PostgresType::PostGisGeometry))
        }
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported postgres type {pg_type:?}",
        ))),
    }
}

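/// Derives a [`RemoteSchema`] (and the raw Postgres column types) from the columns of a
/// single row.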
fn build_remote_schema(row: &Row) -> DFResult<(RemoteSchema, Vec<Type>)> {
    let mut remote_fields = vec![];
    let mut pg_types = vec![];
    for col in row.columns() {
        let col_type = col.type_();
        pg_types.push(col_type.clone());
        remote_fields.push(RemoteField::new(
            col.name(),
            pg_type_to_remote_type(col_type)?,
            true,
        ));
    }
    Ok((RemoteSchema::new(remote_fields), pg_types))
}

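// Downcasts `$builder` to the concrete Arrow builder type, reads an `Option<$value_ty>` from the
// row at `$index`, and appends the value (or null).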
macro_rules! handle_primitive_type {
    ($builder:expr, $pg_type:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .expect(concat!(
                "Failed to downcast builder to ",
                stringify!($builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        let v: Option<$value_ty> = $row.try_get($index).expect(concat!(
            "Failed to get ",
            stringify!($value_ty),
            " value for column ",
            stringify!($pg_type)
        ));

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}

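// Array-column counterpart of `handle_primitive_type!`: the outer `ListBuilder` records list
// offsets and nulls, while the downcast values builder receives the elements.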
macro_rules! handle_primitive_array_type {
    ($builder:expr, $pg_type:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .expect(concat!(
                "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for ",
                stringify!($pg_type)
            ));
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .expect(concat!(
                "Failed to downcast values builder to ",
                stringify!($values_builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).expect(concat!(
            "Failed to get ",
            stringify!($primitive_value_ty),
            " array value for column ",
            stringify!($pg_type)
        ));

        match v {
            Some(v) => {
                let v = v.into_iter().map(Some);
                values_builder.extend(v);
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}

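/// `FromSql` wrapper that captures a PostGIS `geometry` value as the raw bytes sent by the
/// server.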
pub struct GeometryFromSql<'a> {
    wkb: &'a [u8],
}

impl<'a> FromSql<'a> for GeometryFromSql<'a> {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        Ok(GeometryFromSql { wkb: raw })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(ty.name(), "geometry")
    }
}

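/// Converts a slice of Postgres rows into an Arrow [`RecordBatch`], applying the optional
/// column projection.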
fn rows_to_batch(
    rows: &[Row],
    pg_types: &[Type],
    arrow_schema: SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(&arrow_schema, projection)?;
    let mut array_builders = vec![];
    for field in arrow_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }
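    // Fill the builders row by row, skipping columns that are excluded by the projection.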
    for row in rows {
        for (idx, pg_type) in pg_types.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            match pg_type {
                &Type::BOOL => {
                    handle_primitive_type!(builder, Type::BOOL, BooleanBuilder, bool, row, idx);
                }
                &Type::CHAR => {
                    handle_primitive_type!(builder, Type::CHAR, Int8Builder, i8, row, idx);
                }
                &Type::INT2 => {
                    handle_primitive_type!(builder, Type::INT2, Int16Builder, i16, row, idx);
                }
                &Type::INT4 => {
                    handle_primitive_type!(builder, Type::INT4, Int32Builder, i32, row, idx);
                }
                &Type::INT8 => {
                    handle_primitive_type!(builder, Type::INT8, Int64Builder, i64, row, idx);
                }
                &Type::FLOAT4 => {
                    handle_primitive_type!(builder, Type::FLOAT4, Float32Builder, f32, row, idx);
                }
                &Type::FLOAT8 => {
                    handle_primitive_type!(builder, Type::FLOAT8, Float64Builder, f64, row, idx);
                }
                &Type::TEXT => {
                    handle_primitive_type!(builder, Type::TEXT, StringBuilder, &str, row, idx);
                }
                &Type::VARCHAR => {
                    handle_primitive_type!(builder, Type::VARCHAR, StringBuilder, &str, row, idx);
                }
                &Type::BPCHAR => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<StringBuilder>()
                        .expect("Failed to downcast builder to StringBuilder for Type::BPCHAR");
                    let v: Option<&str> = row
                        .try_get(idx)
                        .expect("Failed to get &str value for column Type::BPCHAR");

                    match v {
                        Some(v) => builder.append_value(v.trim_end()),
                        None => builder.append_null(),
                    }
                }
                &Type::BYTEA => {
                    handle_primitive_type!(builder, Type::BYTEA, BinaryBuilder, Vec<u8>, row, idx);
                }
                &Type::TIMESTAMP => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampNanosecondBuilder>()
                        .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
                    let v: Option<SystemTime> = row
                        .try_get(idx)
                        .expect("Failed to get SystemTime value for column Type::TIMESTAMP");

                    match v {
                        Some(v) => match v.duration_since(UNIX_EPOCH) {
                            Ok(d) => {
                                let timestamp: i64 = d
                                    .as_nanos()
                                    .try_into()
                                    .expect("Failed to convert SystemTime to i64");
                                builder.append_value(timestamp);
                            }
                            // Pre-epoch timestamps would otherwise append nothing and leave
                            // this column shorter than the others, so store them as null.
                            Err(_) => builder.append_null(),
                        },
                        None => builder.append_null(),
                    }
                }
                &Type::TIMESTAMPTZ => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<TimestampNanosecondBuilder>()
                        .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMPTZ");
                    let v: Option<chrono::DateTime<chrono::Utc>> = row.try_get(idx).expect(
                        "Failed to get chrono::DateTime<chrono::Utc> value for column Type::TIMESTAMPTZ",
                    );

                    match v {
                        Some(v) => {
                            let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| {
                                panic!("Failed to get timestamp in nanoseconds from {v} for Type::TIMESTAMPTZ")
                            });
                            builder.append_value(timestamp);
                        }
                        None => builder.append_null(),
                    }
                }
                &Type::TIME => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<Time64NanosecondBuilder>()
                        .expect(
                            "Failed to downcast builder to Time64NanosecondBuilder for Type::TIME",
                        );
                    let v: Option<chrono::NaiveTime> = row
                        .try_get(idx)
                        .expect("Failed to get chrono::NaiveTime value for column Type::TIME");

                    match v {
                        Some(v) => {
                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
                                * 1_000_000_000
                                + i64::from(v.nanosecond());
                            builder.append_value(timestamp);
                        }
                        None => builder.append_null(),
                    }
                }
                &Type::DATE => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<Date32Builder>()
                        .expect("Failed to downcast builder to Date32Builder for Type::DATE");
                    let v: Option<chrono::NaiveDate> = row
                        .try_get(idx)
                        .expect("Failed to get chrono::NaiveDate value for column Type::DATE");

                    match v {
                        Some(v) => builder.append_value(Date32Type::from_naive_date(v)),
                        None => builder.append_null(),
                    }
                }
                &Type::INT2_ARRAY => {
                    handle_primitive_array_type!(
                        builder,
                        Type::INT2_ARRAY,
                        Int16Builder,
                        i16,
                        row,
                        idx
                    );
                }
                &Type::INT4_ARRAY => {
                    handle_primitive_array_type!(
                        builder,
                        Type::INT4_ARRAY,
                        Int32Builder,
                        i32,
                        row,
                        idx
                    );
                }
                &Type::INT8_ARRAY => {
                    handle_primitive_array_type!(
                        builder,
                        Type::INT8_ARRAY,
                        Int64Builder,
                        i64,
                        row,
                        idx
                    );
                }
                &Type::FLOAT4_ARRAY => {
                    handle_primitive_array_type!(
                        builder,
                        Type::FLOAT4_ARRAY,
                        Float32Builder,
                        f32,
                        row,
                        idx
                    );
                }
                &Type::FLOAT8_ARRAY => {
                    handle_primitive_array_type!(
                        builder,
                        Type::FLOAT8_ARRAY,
                        Float64Builder,
                        f64,
                        row,
                        idx
                    );
                }
                &Type::TEXT_ARRAY => {
                    handle_primitive_array_type!(
                        builder,
                        Type::TEXT_ARRAY,
                        StringBuilder,
                        &str,
                        row,
                        idx
                    );
                }
                &Type::VARCHAR_ARRAY => {
                    handle_primitive_array_type!(
                        builder,
                        Type::VARCHAR_ARRAY,
                        StringBuilder,
                        &str,
                        row,
                        idx
                    );
                }
                &Type::BYTEA_ARRAY => {
                    handle_primitive_array_type!(
                        builder,
                        Type::BYTEA_ARRAY,
                        BinaryBuilder,
                        Vec<u8>,
                        row,
                        idx
                    );
                }
                other if other.name().eq_ignore_ascii_case("geometry") => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<BinaryBuilder>()
                        .expect("Failed to downcast builder to BinaryBuilder for Type::geometry");
                    let v: Option<GeometryFromSql> = row
                        .try_get(idx)
                        .expect("Failed to get GeometryFromSql value for column Type::geometry");

                    match v {
                        Some(v) => builder.append_value(v.wkb),
                        None => builder.append_null(),
                    }
                }
                _ => {
                    return Err(DataFusionError::Execution(format!(
                        "Unsupported postgres type {pg_type:?}",
                    )));
                }
            }
        }
    }
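    // Only the builders for projected columns are finished; the others are dropped.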
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}