1use crate::connection::projections_contains;
2use crate::transform::transform_batch;
3use crate::{Connection, DFResult, PostgresType, RemoteField, RemoteSchema, RemoteType, Transform};
4use bb8_postgres::tokio_postgres::types::{FromSql, Type};
5use bb8_postgres::tokio_postgres::{NoTls, Row};
6use bb8_postgres::PostgresConnectionManager;
7use chrono::Timelike;
8use datafusion::arrow::array::{
9 make_builder, ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder,
10 Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
11 ListBuilder, RecordBatch, StringBuilder, Time64NanosecondBuilder, TimestampNanosecondBuilder,
12};
13use datafusion::arrow::datatypes::{Date32Type, Schema, SchemaRef};
14use datafusion::common::project_schema;
15use datafusion::error::DataFusionError;
16use datafusion::execution::SendableRecordBatchStream;
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use futures::{stream, StreamExt};
19use std::string::ToString;
20use std::sync::Arc;
21use std::time::{SystemTime, UNIX_EPOCH};
22
/// Settings used to open a connection pool against a Postgres server.
///
/// NOTE: the derived `Debug` implementation prints every field, including
/// `password`, so avoid logging this struct verbatim.
#[derive(Clone, Debug)]
pub struct PostgresConnectionOptions {
    /// Host name or address of the Postgres server.
    pub host: String,
    /// TCP port of the Postgres server.
    pub port: u16,
    /// User name used to authenticate.
    pub username: String,
    /// Password used to authenticate.
    pub password: String,
    /// Database to connect to; when `None` no database name is set on the
    /// connection config.
    pub database: Option<String>,
}
31
32#[derive(Debug)]
33pub(crate) struct PostgresConnection {
34 pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
35}
36
37pub(crate) async fn connect_postgres(
38 options: &PostgresConnectionOptions,
39) -> DFResult<PostgresConnection> {
40 let mut config = bb8_postgres::tokio_postgres::config::Config::new();
41 config
42 .host(&options.host)
43 .port(options.port)
44 .user(&options.username)
45 .password(&options.password);
46 if let Some(database) = &options.database {
47 config.dbname(database);
48 }
49 let manager = PostgresConnectionManager::new(config, NoTls);
50 let pool = bb8::Pool::builder()
51 .max_size(5)
52 .build(manager)
53 .await
54 .map_err(|e| {
55 DataFusionError::Execution(format!(
56 "Failed to create postgres connection pool due to {e}",
57 ))
58 })?;
59
60 Ok(PostgresConnection { pool })
61}
62
63#[async_trait::async_trait]
64impl Connection for PostgresConnection {
65 async fn infer_schema(
66 &self,
67 sql: &str,
68 transform: Option<&dyn Transform>,
69 ) -> DFResult<(RemoteSchema, SchemaRef)> {
70 let conn = self.pool.get().await.map_err(|e| {
71 DataFusionError::Execution(format!(
72 "Failed to get connection from postgres connection pool due to {e}",
73 ))
74 })?;
75 let mut stream = conn
76 .query_raw(sql, Vec::<String>::new())
77 .await
78 .map_err(|e| {
79 DataFusionError::Execution(format!(
80 "Failed to execute query {sql} on postgres due to {e}",
81 ))
82 })?
83 .chunks(1)
84 .boxed();
85
86 let Some(first_chunk) = stream.next().await else {
87 return Err(DataFusionError::Execution(
88 "No data returned from postgres".to_string(),
89 ));
90 };
91 let first_chunk: Vec<Row> = first_chunk
92 .into_iter()
93 .collect::<Result<Vec<_>, _>>()
94 .map_err(|e| {
95 DataFusionError::Execution(format!(
96 "Failed to collect rows from postgres due to {e}",
97 ))
98 })?;
99 let Some(first_row) = first_chunk.first() else {
100 return Err(DataFusionError::Execution(
101 "No data returned from postgres".to_string(),
102 ));
103 };
104 let (remote_schema, pg_types) = build_remote_schema(first_row)?;
105 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
106 let batch = rows_to_batch(
107 std::slice::from_ref(first_row),
108 &pg_types,
109 arrow_schema,
110 None,
111 )?;
112 if let Some(transform) = transform {
113 let transformed_batch = transform_batch(batch, transform, &remote_schema)?;
114 Ok((remote_schema, transformed_batch.schema()))
115 } else {
116 Ok((remote_schema, batch.schema()))
117 }
118 }
119
120 async fn query(
121 &self,
122 sql: String,
123 projection: Option<Vec<usize>>,
124 ) -> DFResult<(SendableRecordBatchStream, RemoteSchema)> {
125 let conn = self.pool.get().await.map_err(|e| {
126 DataFusionError::Execution(format!(
127 "Failed to get connection from postgres connection pool due to {e}",
128 ))
129 })?;
130 let mut stream = conn
131 .query_raw(&sql, Vec::<String>::new())
132 .await
133 .map_err(|e| {
134 DataFusionError::Execution(format!(
135 "Failed to execute query {sql} on postgres due to {e}",
136 ))
137 })?
138 .chunks(2048)
139 .boxed();
140
141 let Some(first_chunk) = stream.next().await else {
142 return Ok((
143 Box::pin(RecordBatchStreamAdapter::new(
144 Arc::new(Schema::empty()),
145 stream::empty(),
146 )),
147 RemoteSchema::empty(),
148 ));
149 };
150 let first_chunk: Vec<Row> = first_chunk
151 .into_iter()
152 .collect::<Result<Vec<_>, _>>()
153 .map_err(|e| {
154 DataFusionError::Execution(format!(
155 "Failed to collect rows from postgres due to {e}",
156 ))
157 })?;
158 let Some(first_row) = first_chunk.first() else {
159 return Err(DataFusionError::Execution(
160 "No data returned from postgres".to_string(),
161 ));
162 };
163 let (remote_schema, pg_types) = build_remote_schema(first_row)?;
164 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
165 let first_chunk = rows_to_batch(
166 first_chunk.as_slice(),
167 &pg_types,
168 arrow_schema.clone(),
169 projection.as_ref(),
170 )?;
171 let schema = first_chunk.schema();
172
173 let mut stream = stream.map(move |rows| {
174 let rows: Vec<Row> = rows
175 .into_iter()
176 .collect::<Result<Vec<_>, _>>()
177 .map_err(|e| {
178 DataFusionError::Execution(format!(
179 "Failed to collect rows from postgres due to {e}",
180 ))
181 })?;
182 let batch = rows_to_batch(
183 rows.as_slice(),
184 &pg_types,
185 arrow_schema.clone(),
186 projection.as_ref(),
187 )?;
188 Ok::<RecordBatch, DataFusionError>(batch)
189 });
190
191 let output_stream = async_stream::stream! {
192 yield Ok(first_chunk);
193 while let Some(batch) = stream.next().await {
194 match batch {
195 Ok(batch) => {
196 yield Ok(batch); }
198 Err(e) => {
199 yield Err(DataFusionError::Execution(format!("Failed to fetch batch: {e}")));
200 }
201 }
202 }
203 };
204
205 Ok((
206 Box::pin(RecordBatchStreamAdapter::new(schema, output_stream)),
207 remote_schema,
208 ))
209 }
210}
211
212fn pg_type_to_remote_type(pg_type: &Type) -> DFResult<RemoteType> {
213 match pg_type {
214 &Type::BOOL => Ok(RemoteType::Postgres(PostgresType::Bool)),
215 &Type::CHAR => Ok(RemoteType::Postgres(PostgresType::Char)),
216 &Type::INT2 => Ok(RemoteType::Postgres(PostgresType::Int2)),
217 &Type::INT4 => Ok(RemoteType::Postgres(PostgresType::Int4)),
218 &Type::INT8 => Ok(RemoteType::Postgres(PostgresType::Int8)),
219 &Type::FLOAT4 => Ok(RemoteType::Postgres(PostgresType::Float4)),
220 &Type::FLOAT8 => Ok(RemoteType::Postgres(PostgresType::Float8)),
221 &Type::TEXT => Ok(RemoteType::Postgres(PostgresType::Text)),
222 &Type::VARCHAR => Ok(RemoteType::Postgres(PostgresType::Varchar)),
223 &Type::BYTEA => Ok(RemoteType::Postgres(PostgresType::Bytea)),
224 &Type::DATE => Ok(RemoteType::Postgres(PostgresType::Date)),
225 &Type::TIMESTAMP => Ok(RemoteType::Postgres(PostgresType::Timestamp)),
226 &Type::TIMESTAMPTZ => Ok(RemoteType::Postgres(PostgresType::TimestampTz)),
227 &Type::TIME => Ok(RemoteType::Postgres(PostgresType::Time)),
228 &Type::INT2_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int2Array)),
229 &Type::INT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int4Array)),
230 &Type::INT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Int8Array)),
231 &Type::FLOAT4_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float4Array)),
232 &Type::FLOAT8_ARRAY => Ok(RemoteType::Postgres(PostgresType::Float8Array)),
233 &Type::TEXT_ARRAY => Ok(RemoteType::Postgres(PostgresType::TextArray)),
234 &Type::VARCHAR_ARRAY => Ok(RemoteType::Postgres(PostgresType::VarcharArray)),
235 &Type::BYTEA_ARRAY => Ok(RemoteType::Postgres(PostgresType::ByteaArray)),
236 other if other.name().eq_ignore_ascii_case("geometry") => {
237 Ok(RemoteType::Postgres(PostgresType::PostGisGeometry))
238 }
239 _ => Err(DataFusionError::NotImplemented(format!(
240 "Unsupported postgres type {pg_type:?}",
241 ))),
242 }
243}
244
245fn build_remote_schema(row: &Row) -> DFResult<(RemoteSchema, Vec<Type>)> {
246 let mut remote_fields = vec![];
247 let mut pg_types = vec![];
248 for col in row.columns() {
249 let col_type = col.type_();
250 pg_types.push(col_type.clone());
251 remote_fields.push(RemoteField::new(
252 col.name(),
253 pg_type_to_remote_type(col_type)?,
254 true,
255 ));
256 }
257 Ok((RemoteSchema::new(remote_fields), pg_types))
258}
259
/// Reads the cell at `$index` of `$row` as `Option<$value_ty>` and appends it
/// (or a null) to `$builder`, which must be a `$builder_ty`.
///
/// `$pg_type` is only used to label diagnostics.
///
/// A builder of the wrong concrete type is treated as a programming error and
/// panics. A cell that cannot be read is a data problem: it makes the
/// enclosing function return `DataFusionError::Execution`, so the macro may
/// only be expanded inside a function returning `DFResult<_>`.
macro_rules! handle_primitive_type {
    ($builder:expr, $pg_type:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .expect(concat!(
                "Failed to downcast builder to ",
                stringify!($builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        // Surface read failures as errors instead of panicking.
        let v: Option<$value_ty> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for column {}: {}",
                stringify!($value_ty),
                stringify!($pg_type),
                e
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
284
/// Reads the cell at `$index` of `$row` as `Option<Vec<$primitive_value_ty>>`
/// and appends it as one list entry (or a null list) to `$builder`.
///
/// `$builder` must be a `ListBuilder<Box<dyn ArrayBuilder>>` whose values
/// builder is a `$values_builder_ty`; `$pg_type` is only used to label
/// diagnostics.
///
/// Builders of the wrong concrete type are treated as programming errors and
/// panic. A cell that cannot be read makes the enclosing function return
/// `DataFusionError::Execution`, so the macro may only be expanded inside a
/// function returning `DFResult<_>`. This matters for real data: decoding
/// into `Vec<T>` fails when the array contains a NULL element.
macro_rules! handle_primitive_array_type {
    ($builder:expr, $pg_type:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .expect(concat!(
                "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for ",
                stringify!($pg_type)
            ));
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .expect(concat!(
                "Failed to downcast values builder to ",
                stringify!($values_builder_ty),
                " for ",
                stringify!($pg_type)
            ));
        // Surface read failures as errors instead of panicking.
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} array value for column {}: {}",
                stringify!($primitive_value_ty),
                stringify!($pg_type),
                e
            ))
        })?;

        match v {
            Some(v) => {
                // Push the elements first, then close the list entry.
                let v = v.into_iter().map(Some);
                values_builder.extend(v);
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}
321
322pub struct GeometryFromSql<'a> {
323 wkb: &'a [u8],
324}
325
326impl<'a> FromSql<'a> for GeometryFromSql<'a> {
327 fn from_sql(
328 _ty: &Type,
329 raw: &'a [u8],
330 ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
331 Ok(GeometryFromSql { wkb: raw })
332 }
333
334 fn accepts(ty: &Type) -> bool {
335 matches!(ty.name(), "geometry")
336 }
337}
338
339fn rows_to_batch(
340 rows: &[Row],
341 pg_types: &Vec<Type>,
342 arrow_schema: SchemaRef,
343 projection: Option<&Vec<usize>>,
344) -> DFResult<RecordBatch> {
345 let projected_schema = project_schema(&arrow_schema, projection)?;
346 let mut array_builders = vec![];
347 for field in arrow_schema.fields() {
348 let builder = make_builder(&field.data_type(), rows.len());
349 array_builders.push(builder);
350 }
351 for row in rows {
352 for (idx, pg_type) in pg_types.iter().enumerate() {
353 if !projections_contains(projection, idx) {
354 continue;
355 }
356 let builder = &mut array_builders[idx];
357 match pg_type {
358 &Type::BOOL => {
359 handle_primitive_type!(builder, Type::BOOL, BooleanBuilder, bool, row, idx);
360 }
361 &Type::CHAR => {
362 handle_primitive_type!(builder, Type::CHAR, Int8Builder, i8, row, idx);
363 }
364 &Type::INT2 => {
365 handle_primitive_type!(builder, Type::INT2, Int16Builder, i16, row, idx);
366 }
367 &Type::INT4 => {
368 handle_primitive_type!(builder, Type::INT4, Int32Builder, i32, row, idx);
369 }
370 &Type::INT8 => {
371 handle_primitive_type!(builder, Type::INT8, Int64Builder, i64, row, idx);
372 }
373 &Type::FLOAT4 => {
374 handle_primitive_type!(builder, Type::FLOAT4, Float32Builder, f32, row, idx);
375 }
376 &Type::FLOAT8 => {
377 handle_primitive_type!(builder, Type::FLOAT8, Float64Builder, f64, row, idx);
378 }
379 &Type::TEXT => {
380 handle_primitive_type!(builder, Type::TEXT, StringBuilder, &str, row, idx);
381 }
382 &Type::VARCHAR => {
383 handle_primitive_type!(builder, Type::VARCHAR, StringBuilder, &str, row, idx);
384 }
385 &Type::BYTEA => {
386 handle_primitive_type!(builder, Type::BYTEA, BinaryBuilder, Vec<u8>, row, idx);
387 }
388 &Type::TIMESTAMP => {
389 let builder = builder
390 .as_any_mut()
391 .downcast_mut::<TimestampNanosecondBuilder>()
392 .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
393 let v: Option<SystemTime> = row
394 .try_get(idx)
395 .expect("Failed to get SystemTime value for column Type::TIMESTAMP");
396
397 match v {
398 Some(v) => {
399 if let Ok(v) = v.duration_since(UNIX_EPOCH) {
400 let timestamp: i64 = v
401 .as_nanos()
402 .try_into()
403 .expect("Failed to convert SystemTime to i64");
404 builder.append_value(timestamp);
405 }
406 }
407 None => builder.append_null(),
408 }
409 }
410 &Type::TIMESTAMPTZ => {
411 let builder = builder
412 .as_any_mut()
413 .downcast_mut::<TimestampNanosecondBuilder>()
414 .expect("Failed to downcast builder to TimestampNanosecondBuilder for Type::TIMESTAMP");
415 let v: Option<chrono::DateTime<chrono::Utc>> = row.try_get(idx).expect(
416 "Failed to get chrono::DateTime<chrono::Utc> value for column Type::TIMESTAMPTZ",
417 );
418
419 match v {
420 Some(v) => {
421 let timestamp: i64 = v.timestamp_nanos_opt().expect(&format!("Failed to get timestamp in nanoseconds from {v} for Type::TIMESTAMP"));
422 builder.append_value(timestamp);
423 }
424 None => {}
425 }
426 }
427 &Type::TIME => {
428 let builder = builder
429 .as_any_mut()
430 .downcast_mut::<Time64NanosecondBuilder>()
431 .expect(
432 "Failed to downcast builder to Time64NanosecondBuilder for Type::TIME",
433 );
434 let v: Option<chrono::NaiveTime> = row
435 .try_get(idx)
436 .expect("Failed to get chrono::NaiveTime value for column Type::TIME");
437
438 match v {
439 Some(v) => {
440 let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
441 * 1_000_000_000
442 + i64::from(v.nanosecond());
443 builder.append_value(timestamp);
444 }
445 None => builder.append_null(),
446 }
447 }
448 &Type::DATE => {
449 let builder = builder
450 .as_any_mut()
451 .downcast_mut::<Date32Builder>()
452 .expect("Failed to downcast builder to Date32Builder for Type::DATE");
453 let v: Option<chrono::NaiveDate> = row
454 .try_get(idx)
455 .expect("Failed to get chrono::NaiveDate value for column Type::DATE");
456
457 match v {
458 Some(v) => builder.append_value(Date32Type::from_naive_date(v)),
459 None => builder.append_null(),
460 }
461 }
462 &Type::INT2_ARRAY => {
463 handle_primitive_array_type!(
464 builder,
465 Type::INT2_ARRAY,
466 Int16Builder,
467 i16,
468 row,
469 idx
470 );
471 }
472 &Type::INT4_ARRAY => {
473 handle_primitive_array_type!(
474 builder,
475 Type::INT4_ARRAY,
476 Int32Builder,
477 i32,
478 row,
479 idx
480 );
481 }
482 &Type::INT8_ARRAY => {
483 handle_primitive_array_type!(
484 builder,
485 Type::INT8_ARRAY,
486 Int64Builder,
487 i64,
488 row,
489 idx
490 );
491 }
492 &Type::FLOAT4_ARRAY => {
493 handle_primitive_array_type!(
494 builder,
495 Type::FLOAT4_ARRAY,
496 Float32Builder,
497 f32,
498 row,
499 idx
500 );
501 }
502 &Type::FLOAT8_ARRAY => {
503 handle_primitive_array_type!(
504 builder,
505 Type::FLOAT8_ARRAY,
506 Float64Builder,
507 f64,
508 row,
509 idx
510 );
511 }
512 &Type::TEXT_ARRAY => {
513 handle_primitive_array_type!(
514 builder,
515 Type::TEXT_ARRAY,
516 StringBuilder,
517 &str,
518 row,
519 idx
520 );
521 }
522 &Type::VARCHAR_ARRAY => {
523 handle_primitive_array_type!(
524 builder,
525 Type::VARCHAR_ARRAY,
526 StringBuilder,
527 &str,
528 row,
529 idx
530 );
531 }
532 &Type::BYTEA_ARRAY => {
533 handle_primitive_array_type!(
534 builder,
535 Type::BYTEA_ARRAY,
536 BinaryBuilder,
537 Vec<u8>,
538 row,
539 idx
540 );
541 }
542 other if other.name().eq_ignore_ascii_case("geometry") => {
543 let builder = builder
544 .as_any_mut()
545 .downcast_mut::<BinaryBuilder>()
546 .expect("Failed to downcast builder to BinaryBuilder for Type::geometry");
547 let v: Option<GeometryFromSql> = row
548 .try_get(idx)
549 .expect("Failed to get GeometryFromSql value for column Type::geometry");
550
551 match v {
552 Some(v) => builder.append_value(v.wkb),
553 None => builder.append_null(),
554 }
555 }
556 _ => {
557 return Err(DataFusionError::Execution(format!(
558 "Unsupported postgres type {pg_type:?}",
559 )));
560 }
561 }
562 }
563 }
564 let projected_columns = array_builders
565 .into_iter()
566 .enumerate()
567 .filter(|(idx, _)| projections_contains(projection, *idx))
568 .map(|(_, mut builder)| builder.finish())
569 .collect::<Vec<ArrayRef>>();
570 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
571}