use crate::connection::{RemoteDbType, just_return, projections_contains};
use crate::utils::{big_decimal_to_i128, big_decimal_to_i256};
use crate::{
    Connection, ConnectionOptions, DFResult, Pool, PostgresType, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteSource, RemoteType, Unparse, unparse_array,
};
use bb8_postgres::PostgresConnectionManager;
use bb8_postgres::tokio_postgres::types::{FromSql, Type};
use bb8_postgres::tokio_postgres::{NoTls, Row, Statement};
use bigdecimal::BigDecimal;
use byteorder::{BigEndian, ReadBytesExt};
use chrono::Timelike;
use datafusion::arrow::array::{
    ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
    Decimal256Builder, FixedSizeBinaryBuilder, Float32Builder, Float64Builder, Int16Builder,
    Int32Builder, Int64Builder, IntervalMonthDayNanoBuilder, LargeStringBuilder, ListBuilder,
    RecordBatch, RecordBatchOptions, StringBuilder, Time64MicrosecondBuilder,
    Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampNanosecondBuilder,
    UInt32Builder, make_builder,
};
use datafusion::arrow::datatypes::{
    DECIMAL256_MAX_PRECISION, DataType, Date32Type, IntervalMonthDayNanoType, IntervalUnit,
    SchemaRef, TimeUnit, i256,
};
use datafusion::common::project_schema;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use log::debug;
use num_bigint::{BigInt, Sign};
use std::any::Any;
use std::string::ToString;
use std::sync::Arc;
use uuid::Uuid;

#[derive(Debug, Clone, With, Getters)]
pub struct PostgresConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) database: Option<String>,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
    pub(crate) default_numeric_scale: i8,
}

impl PostgresConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            database: None,
            pool_max_size: 10,
            stream_chunk_size: 2048,
            default_numeric_scale: 10,
        }
    }
}
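
// Usage sketch, assuming the `With` derive from `derive_with` generates
// `with_*` setters for each field (treat the exact setter names as an
// assumption rather than a confirmed API):
//
//     let options = PostgresConnectionOptions::new("localhost", 5432, "postgres", "password")
//         .with_pool_max_size(5)
//         .with_stream_chunk_size(4096);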

impl From<PostgresConnectionOptions> for ConnectionOptions {
    fn from(options: PostgresConnectionOptions) -> Self {
        ConnectionOptions::Postgres(options)
    }
}

#[derive(Debug)]
pub struct PostgresPool {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
    options: Arc<PostgresConnectionOptions>,
}

#[async_trait::async_trait]
impl Pool for PostgresPool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
        })?;
        Ok(Arc::new(PostgresConnection {
            conn,
            options: self.options.clone(),
        }))
    }
}

pub(crate) async fn connect_postgres(
    options: &PostgresConnectionOptions,
) -> DFResult<PostgresPool> {
    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
    config
        .host(&options.host)
        .port(options.port)
        .user(&options.username)
        .password(&options.password);
    if let Some(database) = &options.database {
        config.dbname(database);
    }
    let manager = PostgresConnectionManager::new(config, NoTls);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to create postgres connection pool due to {e}",
            ))
        })?;

    Ok(PostgresPool {
        pool,
        options: Arc::new(options.clone()),
    })
}
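
// Pool usage sketch (hypothetical call site inside this crate, since
// `connect_postgres` is crate-private):
//
//     let pool = connect_postgres(&options).await?;
//     let conn = pool.get().await?;
//     let schema = conn.infer_schema(&source).await?;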

#[derive(Debug)]
pub(crate) struct PostgresConnection {
    conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
    options: Arc<PostgresConnectionOptions>,
}

#[async_trait::async_trait]
impl Connection for PostgresConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }

    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
        match source {
            RemoteSource::Table(table) => {
                let db_type = RemoteDbType::Postgres;
                let where_condition = if table.len() == 1 {
                    format!("table_name = {}", db_type.sql_string_literal(&table[0]))
                } else if table.len() == 2 {
                    format!(
                        "table_schema = {} AND table_name = {}",
                        db_type.sql_string_literal(&table[0]),
                        db_type.sql_string_literal(&table[1])
                    )
                } else {
                    format!(
                        "table_catalog = {} AND table_schema = {} AND table_name = {}",
                        db_type.sql_string_literal(&table[0]),
                        db_type.sql_string_literal(&table[1]),
                        db_type.sql_string_literal(&table[2])
                    )
                };
                let sql = format!(
                    "
select
    column_name,
    case
        when data_type = 'ARRAY'
            then data_type || udt_name
        when data_type = 'USER-DEFINED'
            then udt_schema || '.' || udt_name
        else
            data_type
    end as column_type,
    numeric_precision,
    numeric_scale,
    is_nullable
from information_schema.columns
where {}
order by ordinal_position",
                    where_condition
                );
                let rows = self.conn.query(&sql, &[]).await.map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to execute query {sql} on postgres: {e:?}",
                    ))
                })?;
                let remote_schema = Arc::new(build_remote_schema_for_table(
                    rows,
                    self.options.default_numeric_scale,
                )?);
                Ok(remote_schema)
            }
            RemoteSource::Query(_query) => {
                let sql = source.query(RemoteDbType::Postgres);
                let stmt = self.conn.prepare(&sql).await.map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to prepare query {sql} on postgres: {e:?}",
                    ))
                })?;
                let remote_schema = Arc::new(
                    build_remote_schema_for_query(stmt, self.options.default_numeric_scale).await?,
                );
                Ok(remote_schema)
            }
        }
    }

    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Postgres.rewrite_query(source, unparsed_filters, limit);
        debug!("[remote-table] executing postgres query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let stream = self
            .conn
            .query_raw(&sql, Vec::<String>::new())
            .await
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute query {sql} on postgres: {e}",
                ))
            })?
            .chunks(chunk_size)
            .boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from postgres due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }

    async fn insert(
        &self,
        _conn_options: &ConnectionOptions,
        unparser: Arc<dyn Unparse>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        mut input: SendableRecordBatchStream,
    ) -> DFResult<usize> {
        let input_schema = input.schema();

        let mut total_count = 0;
        while let Some(batch) = input.next().await {
            let batch = batch?;

            let mut columns = Vec::with_capacity(remote_schema.fields.len());
            for i in 0..batch.num_columns() {
                let input_field = input_schema.field(i);
                let remote_field = &remote_schema.fields[i];
                if remote_field.auto_increment && input_field.is_nullable() {
                    continue;
                }

                let remote_type = remote_schema.fields[i].remote_type.clone();
                let array = batch.column(i);
                let column = unparse_array(unparser.as_ref(), array, remote_type)?;
                columns.push(column);
            }

            let num_rows = columns[0].len();
            let num_columns = columns.len();

            let mut values = Vec::with_capacity(num_rows);
            for i in 0..num_rows {
                let mut value = Vec::with_capacity(num_columns);
                for col in columns.iter() {
                    value.push(col[i].as_str());
                }
                values.push(format!("({})", value.join(",")));
            }

            let mut col_names = Vec::with_capacity(remote_schema.fields.len());
            for (remote_field, input_field) in
                remote_schema.fields.iter().zip(input_schema.fields.iter())
            {
                if remote_field.auto_increment && input_field.is_nullable() {
                    continue;
                }
                col_names.push(RemoteDbType::Postgres.sql_identifier(&remote_field.name));
            }

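            // Illustrative shape of the generated statement (hypothetical table
            // and values): INSERT INTO "t" ("a","b") VALUES (1,'x'),(2,'y')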
            let sql = format!(
                "INSERT INTO {} ({}) VALUES {}",
                RemoteDbType::Postgres.sql_table_name(table),
                col_names.join(","),
                values.join(",")
            );

            let count = self.conn.execute(&sql, &[]).await.map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute insert statement on postgres: {e:?}, sql: {sql}"
                ))
            })?;
            total_count += count as usize;
        }

        Ok(total_count)
    }
}

async fn build_remote_schema_for_query(
    stmt: Statement,
    default_numeric_scale: i8,
) -> DFResult<RemoteSchema> {
    let mut remote_fields = Vec::new();
    for col in stmt.columns().iter() {
        let pg_type = col.type_();
        let remote_type = pg_type_to_remote_type(pg_type, default_numeric_scale)?;
        remote_fields.push(RemoteField::new(
            col.name(),
            RemoteType::Postgres(remote_type),
            true,
        ));
    }
    Ok(RemoteSchema::new(remote_fields))
}

fn pg_type_to_remote_type(pg_type: &Type, default_numeric_scale: i8) -> DFResult<PostgresType> {
    match pg_type {
        &Type::INT2 => Ok(PostgresType::Int2),
        &Type::INT4 => Ok(PostgresType::Int4),
        &Type::INT8 => Ok(PostgresType::Int8),
        &Type::FLOAT4 => Ok(PostgresType::Float4),
        &Type::FLOAT8 => Ok(PostgresType::Float8),
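        // Prepared-statement metadata does not expose NUMERIC precision/scale
        // here, so fall back to the widest precision and the configured
        // default scale.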
        &Type::NUMERIC => Ok(PostgresType::Numeric(
            DECIMAL256_MAX_PRECISION,
            default_numeric_scale,
        )),
        &Type::OID => Ok(PostgresType::Oid),
        &Type::NAME => Ok(PostgresType::Name),
        &Type::VARCHAR => Ok(PostgresType::Varchar),
        &Type::BPCHAR => Ok(PostgresType::Bpchar),
        &Type::TEXT => Ok(PostgresType::Text),
        &Type::BYTEA => Ok(PostgresType::Bytea),
        &Type::DATE => Ok(PostgresType::Date),
        &Type::TIMESTAMP => Ok(PostgresType::Timestamp),
        &Type::TIMESTAMPTZ => Ok(PostgresType::TimestampTz),
        &Type::TIME => Ok(PostgresType::Time),
        &Type::INTERVAL => Ok(PostgresType::Interval),
        &Type::BOOL => Ok(PostgresType::Bool),
        &Type::JSON => Ok(PostgresType::Json),
        &Type::JSONB => Ok(PostgresType::Jsonb),
        &Type::INT2_ARRAY => Ok(PostgresType::Int2Array),
        &Type::INT4_ARRAY => Ok(PostgresType::Int4Array),
        &Type::INT8_ARRAY => Ok(PostgresType::Int8Array),
        &Type::FLOAT4_ARRAY => Ok(PostgresType::Float4Array),
        &Type::FLOAT8_ARRAY => Ok(PostgresType::Float8Array),
        &Type::VARCHAR_ARRAY => Ok(PostgresType::VarcharArray),
        &Type::BPCHAR_ARRAY => Ok(PostgresType::BpcharArray),
        &Type::TEXT_ARRAY => Ok(PostgresType::TextArray),
        &Type::BYTEA_ARRAY => Ok(PostgresType::ByteaArray),
        &Type::BOOL_ARRAY => Ok(PostgresType::BoolArray),
        &Type::XML => Ok(PostgresType::Xml),
        &Type::UUID => Ok(PostgresType::Uuid),
        other if other.name().eq_ignore_ascii_case("geometry") => Ok(PostgresType::PostGisGeometry),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported postgres type {pg_type:?}",
        ))),
    }
}

fn build_remote_schema_for_table(
    rows: Vec<Row>,
    default_numeric_scale: i8,
) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for row in rows {
        let column_name = row.try_get::<_, String>(0).map_err(|e| {
            DataFusionError::Execution(format!("Failed to get col name from postgres row: {e:?}"))
        })?;
        let column_type = row.try_get::<_, String>(1).map_err(|e| {
            DataFusionError::Execution(format!("Failed to get col type from postgres row: {e:?}"))
        })?;
        let numeric_precision = row.try_get::<_, Option<i32>>(2).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get numeric precision from postgres row: {e:?}"
            ))
        })?;
        let numeric_scale = row.try_get::<_, Option<i32>>(3).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get numeric scale from postgres row: {e:?}"
            ))
        })?;
        let pg_type = parse_pg_type(
            &column_type,
            numeric_precision,
            numeric_scale.unwrap_or(default_numeric_scale as i32),
        )?;
        let is_nullable = row.try_get::<_, String>(4).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get is_nullable from postgres row: {e:?}"
            ))
        })?;
        let nullable = match is_nullable.as_str() {
            "YES" => true,
            "NO" => false,
            _ => {
                return Err(DataFusionError::Execution(format!(
                    "Unsupported postgres is_nullable value {is_nullable}"
                )));
            }
        };
        remote_fields.push(RemoteField::new(
            column_name,
            RemoteType::Postgres(pg_type),
            nullable,
        ));
    }
    Ok(RemoteSchema::new(remote_fields))
}

fn parse_pg_type(
    pg_type: &str,
    numeric_precision: Option<i32>,
    numeric_scale: i32,
) -> DFResult<PostgresType> {
    match pg_type {
        "smallint" => Ok(PostgresType::Int2),
        "integer" => Ok(PostgresType::Int4),
        "bigint" => Ok(PostgresType::Int8),
        "real" => Ok(PostgresType::Float4),
        "double precision" => Ok(PostgresType::Float8),
        "numeric" => Ok(PostgresType::Numeric(
            numeric_precision.unwrap_or(DECIMAL256_MAX_PRECISION as i32) as u8,
            numeric_scale as i8,
        )),
        "character varying" => Ok(PostgresType::Varchar),
        "character" => Ok(PostgresType::Bpchar),
        "text" => Ok(PostgresType::Text),
        "bytea" => Ok(PostgresType::Bytea),
        "date" => Ok(PostgresType::Date),
        "time without time zone" => Ok(PostgresType::Time),
        "timestamp without time zone" => Ok(PostgresType::Timestamp),
        "timestamp with time zone" => Ok(PostgresType::TimestampTz),
        "interval" => Ok(PostgresType::Interval),
        "boolean" => Ok(PostgresType::Bool),
        "json" => Ok(PostgresType::Json),
        "jsonb" => Ok(PostgresType::Jsonb),
        "public.geometry" => Ok(PostgresType::PostGisGeometry),
        "ARRAY_int2" => Ok(PostgresType::Int2Array),
        "ARRAY_int4" => Ok(PostgresType::Int4Array),
        "ARRAY_int8" => Ok(PostgresType::Int8Array),
        "ARRAY_float4" => Ok(PostgresType::Float4Array),
        "ARRAY_float8" => Ok(PostgresType::Float8Array),
        "ARRAY_varchar" => Ok(PostgresType::VarcharArray),
        "ARRAY_bpchar" => Ok(PostgresType::BpcharArray),
        "ARRAY_text" => Ok(PostgresType::TextArray),
        "ARRAY_bytea" => Ok(PostgresType::ByteaArray),
        "ARRAY_bool" => Ok(PostgresType::BoolArray),
        "xml" => Ok(PostgresType::Xml),
        "uuid" => Ok(PostgresType::Uuid),
        "oid" => Ok(PostgresType::Oid),
        "name" => Ok(PostgresType::Name),
        _ => Err(DataFusionError::Execution(format!(
            "Unsupported postgres type {pg_type}"
        ))),
    }
}

macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v: Option<$value_ty> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}
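
// For reference, an invocation such as
//     handle_primitive_type!(builder, field, col, Int16Builder, i16, row, idx, just_return)
// expands to roughly: downcast `builder` to `Int16Builder`, read the column as
// `Option<i16>` via `row.try_get(idx)`, then append the converted value or a null.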

macro_rules! handle_primitive_array_type {
    ($builder:expr, $field:expr, $col:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for {:?} and {:?}",
                    $field, $col
                )
            });
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast values builder to {} for {:?} and {:?}",
                    stringify!($values_builder_ty),
                    $field,
                    $col,
                )
            });
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} array value for {:?} and {:?}: {e:?}",
                stringify!($primitive_value_ty),
                $field,
                $col,
            ))
        })?;

        match v {
            Some(v) => {
                let v = v.into_iter().map(Some);
                values_builder.extend(v);
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}

#[derive(Debug)]
struct BigDecimalFromSql {
    inner: BigDecimal,
}

impl BigDecimalFromSql {
    fn to_i128_with_scale(&self, scale: i32) -> DFResult<i128> {
        big_decimal_to_i128(&self.inner, Some(scale))
    }

    fn to_i256_with_scale(&self, scale: i32) -> DFResult<i256> {
        big_decimal_to_i256(&self.inner, Some(scale))
    }
}

#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
impl<'a> FromSql<'a> for BigDecimalFromSql {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let raw_u16: Vec<u16> = raw
            .chunks(2)
            .map(|chunk| {
                if chunk.len() == 2 {
                    u16::from_be_bytes([chunk[0], chunk[1]])
                } else {
                    u16::from_be_bytes([chunk[0], 0])
                }
            })
            .collect();

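        // NUMERIC wire header (big-endian u16s): [ndigits, weight, sign, dscale],
        // followed by `ndigits` base-10000 digits. For example, 123.45 arrives as
        // ndigits=2, weight=0, sign=0x0000, dscale=2, digits=[123, 4500].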
        let base_10_000_digit_count = raw_u16[0];
        let weight = raw_u16[1] as i16;
        let sign = raw_u16[2];
        let scale = raw_u16[3];

        let mut base_10_000_digits = Vec::new();
        for i in 4..4 + base_10_000_digit_count {
            base_10_000_digits.push(raw_u16[i as usize]);
        }

        let mut u8_digits = Vec::new();
        for &base_10_000_digit in base_10_000_digits.iter().rev() {
            let mut base_10_000_digit = base_10_000_digit;
            let mut temp_result = Vec::new();
            while base_10_000_digit > 0 {
                temp_result.push((base_10_000_digit % 10) as u8);
                base_10_000_digit /= 10;
            }
            while temp_result.len() < 4 {
                temp_result.push(0);
            }
            u8_digits.extend(temp_result);
        }
        u8_digits.reverse();

        let value_scale = 4 * (i64::from(base_10_000_digit_count) - i64::from(weight) - 1);
        let size = i64::try_from(u8_digits.len())? + i64::from(scale) - value_scale;
        u8_digits.resize(size as usize, 0);

        let sign = match sign {
            0x4000 => Sign::Minus,
            0x0000 => Sign::Plus,
            _ => {
                return Err(Box::new(DataFusionError::Execution(
                    "Failed to parse big decimal from postgres numeric value".to_string(),
                )));
            }
        };

        let Some(digits) = BigInt::from_radix_be(sign, u8_digits.as_slice(), 10) else {
            return Err(Box::new(DataFusionError::Execution(
                "Failed to parse big decimal from postgres numeric value".to_string(),
            )));
        };
        Ok(BigDecimalFromSql {
            inner: BigDecimal::new(digits, i64::from(scale)),
        })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }
}

#[derive(Debug)]
struct IntervalFromSql {
    time: i64,
    day: i32,
    month: i32,
}

impl<'a> FromSql<'a> for IntervalFromSql {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let mut cursor = std::io::Cursor::new(raw);

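        // INTERVAL wire layout (big-endian): i64 microseconds, then i32 days,
        // then i32 months.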
        let time = cursor.read_i64::<BigEndian>()?;
        let day = cursor.read_i32::<BigEndian>()?;
        let month = cursor.read_i32::<BigEndian>()?;

        Ok(IntervalFromSql { time, day, month })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::INTERVAL)
    }
}

struct GeometryFromSql<'a> {
    wkb: &'a [u8],
}

impl<'a> FromSql<'a> for GeometryFromSql<'a> {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        Ok(GeometryFromSql { wkb: raw })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(ty.name(), "geometry")
    }
}

struct XmlFromSql<'a> {
    xml: &'a str,
}

impl<'a> FromSql<'a> for XmlFromSql<'a> {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let xml = std::str::from_utf8(raw)?;
        Ok(XmlFromSql { xml })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::XML)
    }
}

fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.columns().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::UInt32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt32Builder,
                        u32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        BigDecimalFromSql,
                        row,
                        idx,
                        |v: BigDecimalFromSql| { v.to_i128_with_scale(*scale as i32) }
                    );
                }
                DataType::Decimal256(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal256Builder,
                        BigDecimalFromSql,
                        row,
                        idx,
                        |v: BigDecimalFromSql| { v.to_i256_with_scale(*scale as i32) }
                    );
                }
                DataType::Utf8 => {
                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("xml") {
                        let convert: for<'a> fn(XmlFromSql<'a>) -> DFResult<&'a str> =
                            |v| Ok(v.xml);
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            XmlFromSql,
                            row,
                            idx,
                            convert
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            &str,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::LargeUtf8 => {
                    if col.is_some() && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB) {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            LargeStringBuilder,
                            serde_json::value::Value,
                            row,
                            idx,
                            |v: serde_json::value::Value| {
                                Ok::<_, DataFusionError>(v.to_string())
                            }
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            LargeStringBuilder,
                            &str,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::Binary => {
                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("geometry")
                    {
                        let convert: for<'a> fn(GeometryFromSql<'a>) -> DFResult<&'a [u8]> =
                            |v| Ok(v.wkb);
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            GeometryFromSql,
                            row,
                            idx,
                            convert
                        );
                    } else if col.is_some()
                        && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB)
                    {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            serde_json::value::Value,
                            row,
                            idx,
                            |v: serde_json::value::Value| {
                                Ok::<_, DataFusionError>(v.to_string().into_bytes())
                            }
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            Vec<u8>,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::FixedSizeBinary(_) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<FixedSizeBinaryBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to FixedSizeBinaryBuilder for {field:?}"
                            )
                        });
                    let v = if col.is_some()
                        && col.unwrap().type_().name().eq_ignore_ascii_case("uuid")
                    {
                        let v: Option<Uuid> = row.try_get(idx).map_err(|e| {
                            DataFusionError::Execution(format!(
                                "Failed to get Uuid value for field {field:?}: {e:?}"
                            ))
                        })?;
                        v.map(|v| v.as_bytes().to_vec())
                    } else {
                        let v: Option<Vec<u8>> = row.try_get(idx).map_err(|e| {
                            DataFusionError::Execution(format!(
                                "Failed to get FixedSizeBinary value for field {field:?}: {e:?}"
                            ))
                        })?;
                        v
                    };

                    match v {
                        Some(v) => builder.append_value(v)?,
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Microsecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampMicrosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let timestamp: i64 = v.and_utc().timestamp_micros();
                            Ok::<i64, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Microsecond, Some(_tz)) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampMicrosecondBuilder,
                        chrono::DateTime<chrono::Utc>,
                        row,
                        idx,
                        |v: chrono::DateTime<chrono::Utc>| {
                            let timestamp: i64 = v.timestamp_micros();
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let timestamp: i64 =
                                v.and_utc().timestamp_nanos_opt().unwrap_or_else(|| {
                                    panic!(
                                        "Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"
                                    )
                                });
                            Ok::<i64, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, Some(_tz)) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::DateTime<chrono::Utc>,
                        row,
                        idx,
                        |v: chrono::DateTime<chrono::Utc>| {
                            let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| {
                                panic!(
                                    "Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"
                                )
                            });
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time64MicrosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            let seconds = i64::from(v.num_seconds_from_midnight());
                            let microseconds = i64::from(v.nanosecond()) / 1000;
                            Ok::<_, DataFusionError>(seconds * 1_000_000 + microseconds)
                        }
                    );
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time64NanosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
                                * 1_000_000_000
                                + i64::from(v.nanosecond());
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Date32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date32Builder,
                        chrono::NaiveDate,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(Date32Type::from_naive_date(v)) }
                    );
                }
                DataType::Interval(IntervalUnit::MonthDayNano) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        IntervalMonthDayNanoBuilder,
                        IntervalFromSql,
                        row,
                        idx,
                        |v: IntervalFromSql| {
                            let interval_month_day_nano = IntervalMonthDayNanoType::make_value(
                                v.month,
                                v.day,
                                v.time * 1_000,
                            );
                            Ok::<_, DataFusionError>(interval_month_day_nano)
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::List(inner) => match inner.data_type() {
                    DataType::Int16 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int16Builder,
                            i16,
                            row,
                            idx
                        );
                    }
                    DataType::Int32 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int32Builder,
                            i32,
                            row,
                            idx
                        );
                    }
                    DataType::Int64 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int64Builder,
                            i64,
                            row,
                            idx
                        );
                    }
                    DataType::Float32 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Float32Builder,
                            f32,
                            row,
                            idx
                        );
                    }
                    DataType::Float64 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Float64Builder,
                            f64,
                            row,
                            idx
                        );
                    }
                    DataType::Utf8 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            &str,
                            row,
                            idx
                        );
                    }
                    DataType::Binary => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            Vec<u8>,
                            row,
                            idx
                        );
                    }
                    DataType::Boolean => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            BooleanBuilder,
                            bool,
                            row,
                            idx
                        );
                    }
                    _ => {
                        return Err(DataFusionError::NotImplemented(format!(
                            "Unsupported list data type {} for col: {:?}",
                            field.data_type(),
                            col
                        )));
                    }
                },
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}
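
// A minimal test sketch for the information_schema type mapping. This module
// is an illustrative addition, not part of the original file; it exercises
// only the pure parsing helper and uses `matches!` so it does not assume
// `PostgresType` implements `PartialEq`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_pg_type_maps_common_types() {
        assert!(matches!(parse_pg_type("integer", None, 10), Ok(PostgresType::Int4)));
        // NUMERIC keeps the reported precision and the resolved scale.
        assert!(matches!(
            parse_pg_type("numeric", Some(10), 2),
            Ok(PostgresType::Numeric(10, 2))
        ));
        // Unknown type names are surfaced as execution errors.
        assert!(parse_pg_type("money", None, 10).is_err());
    }
}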