use crate::connection::{RemoteDbType, just_return, projections_contains};
use crate::utils::{big_decimal_to_i128, big_decimal_to_i256};
use crate::{
    Connection, ConnectionOptions, DFResult, Literalize, Pool, PoolState,
    PostgresConnectionOptions, PostgresType, RemoteField, RemoteSchema, RemoteSchemaRef,
    RemoteSource, RemoteType, literalize_array,
};
use bb8_postgres::PostgresConnectionManager;
use bb8_postgres::tokio_postgres::types::{FromSql, Type};
use bb8_postgres::tokio_postgres::{NoTls, Row, Statement};
use bigdecimal::BigDecimal;
use byteorder::{BigEndian, ReadBytesExt};
use chrono::Timelike;
use datafusion::arrow::array::{
    ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
    Decimal256Builder, FixedSizeBinaryBuilder, Float32Builder, Float64Builder, Int16Builder,
    Int32Builder, Int64Builder, IntervalMonthDayNanoBuilder, LargeStringBuilder, ListBuilder,
    RecordBatch, RecordBatchOptions, StringBuilder, Time64MicrosecondBuilder,
    Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampNanosecondBuilder,
    UInt32Builder, make_builder,
};
use datafusion::arrow::datatypes::{
    DECIMAL256_MAX_PRECISION, DataType, Date32Type, IntervalMonthDayNanoType, IntervalUnit,
    SchemaRef, TimeUnit, i256,
};

use datafusion::common::project_schema;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::StreamExt;
use log::debug;
use num_bigint::{BigInt, Sign};
use std::any::Any;
use std::string::ToString;
use std::sync::Arc;
use uuid::Uuid;

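/// A [`Pool`] implementation backed by a bb8 pool of `tokio-postgres` connections (no TLS).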
#[derive(Debug)]
pub struct PostgresPool {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
    options: Arc<PostgresConnectionOptions>,
}

#[async_trait::async_trait]
impl Pool for PostgresPool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
        })?;
        Ok(Arc::new(PostgresConnection {
            conn,
            options: self.options.clone(),
        }))
    }

    async fn state(&self) -> DFResult<PoolState> {
        let bb8_state = self.pool.state();
        Ok(PoolState {
            connections: bb8_state.connections as usize,
            idle_connections: bb8_state.idle_connections as usize,
        })
    }
}

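/// Builds a bb8 connection pool for Postgres from the given options.
/// Pool sizing, idle timeout, and reaper rate all come from the options.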
pub(crate) async fn connect_postgres(
    options: &PostgresConnectionOptions,
) -> DFResult<PostgresPool> {
    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
    config
        .host(&options.host)
        .port(options.port)
        .user(&options.username)
        .password(&options.password);
    if let Some(database) = &options.database {
        config.dbname(database);
    }
    let manager = PostgresConnectionManager::new(config, NoTls);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .min_idle(Some(options.pool_min_idle as u32))
        .idle_timeout(Some(options.pool_idle_timeout))
        .reaper_rate(options.pool_ttl_check_interval)
        .build(manager)
        .await
        .map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to create postgres connection pool due to {e}",
            ))
        })?;

    Ok(PostgresPool {
        pool,
        options: Arc::new(options.clone()),
    })
}

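/// A single pooled Postgres connection together with the options it was created from.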
#[derive(Debug)]
pub struct PostgresConnection {
    pub conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
    pub options: Arc<PostgresConnectionOptions>,
}

#[async_trait::async_trait]
impl Connection for PostgresConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }

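    /// Infers the remote schema either from `information_schema.columns` (for a
    /// table reference) or from a prepared statement's column metadata (for an
    /// arbitrary query).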
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
        match source {
            RemoteSource::Table(table) => {
                let db_type = RemoteDbType::Postgres;
                let where_condition = if table.len() == 1 {
                    format!("table_name = {}", db_type.sql_string_literal(&table[0]))
                } else if table.len() == 2 {
                    format!(
                        "table_schema = {} AND table_name = {}",
                        db_type.sql_string_literal(&table[0]),
                        db_type.sql_string_literal(&table[1])
                    )
                } else {
                    format!(
                        "table_catalog = {} AND table_schema = {} AND table_name = {}",
                        db_type.sql_string_literal(&table[0]),
                        db_type.sql_string_literal(&table[1]),
                        db_type.sql_string_literal(&table[2])
                    )
                };
                let sql = format!(
                    "
select
    column_name,
    case
        when data_type = 'ARRAY'
            then data_type || udt_name
        when data_type = 'USER-DEFINED'
            then udt_schema || '.' || udt_name
        else
            data_type
    end as column_type,
    numeric_precision,
    numeric_scale,
    is_nullable
from information_schema.columns
where {}
order by ordinal_position",
                    where_condition
                );
                let rows = self.conn.query(&sql, &[]).await.map_err(|e| {
                    DataFusionError::Plan(format!(
                        "Failed to execute query {sql} on postgres: {e:?}",
                    ))
                })?;
                let remote_schema = Arc::new(build_remote_schema_for_table(
                    rows,
                    self.options.default_numeric_scale,
                )?);
                Ok(remote_schema)
            }
            RemoteSource::Query(query) => {
                let stmt = self.conn.prepare(query).await.map_err(|e| {
                    DataFusionError::Plan(format!(
                        "Failed to prepare query {query} on postgres: {e:?}",
                    ))
                })?;
                let remote_schema = Arc::new(
                    build_remote_schema_for_query(stmt, self.options.default_numeric_scale).await?,
                );
                Ok(remote_schema)
            }
        }
    }

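    /// Executes the (possibly rewritten) query and streams the result as Arrow
    /// record batches, converting rows in chunks of `stream_chunk_size`.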
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Postgres.rewrite_query(source, unparsed_filters, limit);
        debug!("[remote-table] executing postgres query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let stream = self
            .conn
            .query_raw(&sql, Vec::<String>::new())
            .await
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute query {sql} on postgres: {e}",
                ))
            })?
            .chunks(chunk_size)
            .boxed();

        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from postgres due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }

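    /// Inserts a record batch as a single multi-row `INSERT INTO ... VALUES ...`
    /// statement. Auto-increment columns paired with nullable input fields are
    /// skipped so the remote database can fill them in.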
    async fn insert(
        &self,
        _conn_options: &ConnectionOptions,
        literalizer: Arc<dyn Literalize>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        batch: RecordBatch,
    ) -> DFResult<usize> {
        let mut columns = Vec::with_capacity(remote_schema.fields.len());
        for i in 0..batch.num_columns() {
            let input_field = batch.schema_ref().field(i);
            let remote_field = &remote_schema.fields[i];
            if remote_field.auto_increment && input_field.is_nullable() {
                continue;
            }

            let remote_type = remote_schema.fields[i].remote_type.clone();
            let array = batch.column(i);
            let column = literalize_array(literalizer.as_ref(), array, remote_type)?;
            columns.push(column);
        }

        let num_rows = columns[0].len();
        let num_columns = columns.len();

        let mut values = Vec::with_capacity(num_rows);
        for i in 0..num_rows {
            let mut value = Vec::with_capacity(num_columns);
            for col in columns.iter() {
                value.push(col[i].as_str());
            }
            values.push(format!("({})", value.join(",")));
        }

        let mut col_names = Vec::with_capacity(remote_schema.fields.len());
        for (remote_field, input_field) in remote_schema
            .fields
            .iter()
            .zip(batch.schema_ref().fields.iter())
        {
            if remote_field.auto_increment && input_field.is_nullable() {
                continue;
            }
            col_names.push(RemoteDbType::Postgres.sql_identifier(&remote_field.name));
        }

        let sql = format!(
            "INSERT INTO {} ({}) VALUES {}",
            RemoteDbType::Postgres.sql_table_name(table),
            col_names.join(","),
            values.join(",")
        );

        let count = self.conn.execute(&sql, &[]).await.map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to execute insert statement on postgres: {e:?}, sql: {sql}"
            ))
        })?;

        Ok(count as usize)
    }
}

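/// Derives a [`RemoteSchema`] from a prepared statement's column metadata.
/// Every column is treated as nullable since the statement alone cannot tell.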
async fn build_remote_schema_for_query(
    stmt: Statement,
    default_numeric_scale: i8,
) -> DFResult<RemoteSchema> {
    let mut remote_fields = Vec::new();
    for col in stmt.columns().iter() {
        let pg_type = col.type_();
        let remote_type = pg_type_to_remote_type(pg_type, default_numeric_scale)?;
        remote_fields.push(RemoteField::new(
            col.name(),
            RemoteType::Postgres(remote_type),
            true,
        ));
    }
    Ok(RemoteSchema::new(remote_fields))
}

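/// Maps a `tokio-postgres` [`Type`] to the corresponding [`PostgresType`].
/// NUMERIC columns fall back to Decimal256's maximum precision and the
/// configured default scale.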
fn pg_type_to_remote_type(pg_type: &Type, default_numeric_scale: i8) -> DFResult<PostgresType> {
    match pg_type {
        &Type::INT2 => Ok(PostgresType::Int2),
        &Type::INT4 => Ok(PostgresType::Int4),
        &Type::INT8 => Ok(PostgresType::Int8),
        &Type::FLOAT4 => Ok(PostgresType::Float4),
        &Type::FLOAT8 => Ok(PostgresType::Float8),
        &Type::NUMERIC => Ok(PostgresType::Numeric(
            DECIMAL256_MAX_PRECISION,
            default_numeric_scale,
        )),
        &Type::OID => Ok(PostgresType::Oid),
        &Type::NAME => Ok(PostgresType::Name),
        &Type::VARCHAR => Ok(PostgresType::Varchar),
        &Type::BPCHAR => Ok(PostgresType::Bpchar),
        &Type::TEXT => Ok(PostgresType::Text),
        &Type::BYTEA => Ok(PostgresType::Bytea),
        &Type::DATE => Ok(PostgresType::Date),
        &Type::TIMESTAMP => Ok(PostgresType::Timestamp),
        &Type::TIMESTAMPTZ => Ok(PostgresType::TimestampTz),
        &Type::TIME => Ok(PostgresType::Time),
        &Type::INTERVAL => Ok(PostgresType::Interval),
        &Type::BOOL => Ok(PostgresType::Bool),
        &Type::JSON => Ok(PostgresType::Json),
        &Type::JSONB => Ok(PostgresType::Jsonb),
        &Type::INT2_ARRAY => Ok(PostgresType::Int2Array),
        &Type::INT4_ARRAY => Ok(PostgresType::Int4Array),
        &Type::INT8_ARRAY => Ok(PostgresType::Int8Array),
        &Type::FLOAT4_ARRAY => Ok(PostgresType::Float4Array),
        &Type::FLOAT8_ARRAY => Ok(PostgresType::Float8Array),
        &Type::VARCHAR_ARRAY => Ok(PostgresType::VarcharArray),
        &Type::BPCHAR_ARRAY => Ok(PostgresType::BpcharArray),
        &Type::TEXT_ARRAY => Ok(PostgresType::TextArray),
        &Type::BYTEA_ARRAY => Ok(PostgresType::ByteaArray),
        &Type::BOOL_ARRAY => Ok(PostgresType::BoolArray),
        &Type::XML => Ok(PostgresType::Xml),
        &Type::UUID => Ok(PostgresType::Uuid),
        other if other.name().eq_ignore_ascii_case("geometry") => Ok(PostgresType::PostGisGeometry),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported postgres type {pg_type:?}",
        ))),
    }
}

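/// Builds a [`RemoteSchema`] from `information_schema.columns` rows in the
/// shape produced by the inference query: name, type, numeric
/// precision/scale, and `is_nullable`.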
fn build_remote_schema_for_table(
    rows: Vec<Row>,
    default_numeric_scale: i8,
) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for row in rows {
        let column_name = row.try_get::<_, String>(0).map_err(|e| {
            DataFusionError::Plan(format!("Failed to get col name from postgres row: {e:?}"))
        })?;
        let column_type = row.try_get::<_, String>(1).map_err(|e| {
            DataFusionError::Plan(format!("Failed to get col type from postgres row: {e:?}"))
        })?;
        let numeric_precision = row.try_get::<_, Option<i32>>(2).map_err(|e| {
            DataFusionError::Plan(format!(
                "Failed to get numeric precision from postgres row: {e:?}"
            ))
        })?;
        let numeric_scale = row.try_get::<_, Option<i32>>(3).map_err(|e| {
            DataFusionError::Plan(format!(
                "Failed to get numeric scale from postgres row: {e:?}"
            ))
        })?;
        let pg_type = parse_pg_type(
            &column_type,
            numeric_precision,
            numeric_scale.unwrap_or(default_numeric_scale as i32),
        )?;
        let is_nullable = row.try_get::<_, String>(4).map_err(|e| {
            DataFusionError::Plan(format!(
                "Failed to get is_nullable from postgres row: {e:?}"
            ))
        })?;
        let nullable = match is_nullable.as_str() {
            "YES" => true,
            "NO" => false,
            _ => {
                return Err(DataFusionError::Plan(format!(
                    "Unsupported postgres is_nullable value {is_nullable}"
                )));
            }
        };
        remote_fields.push(RemoteField::new(
            column_name,
            RemoteType::Postgres(pg_type),
            nullable,
        ));
    }
    Ok(RemoteSchema::new(remote_fields))
}

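/// Parses a type name reported by `information_schema.columns` (with the
/// `ARRAY` and `USER-DEFINED` rewrites applied by the inference query) into a
/// [`PostgresType`].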
fn parse_pg_type(
    pg_type: &str,
    numeric_precision: Option<i32>,
    numeric_scale: i32,
) -> DFResult<PostgresType> {
    match pg_type {
        "smallint" => Ok(PostgresType::Int2),
        "integer" => Ok(PostgresType::Int4),
        "bigint" => Ok(PostgresType::Int8),
        "real" => Ok(PostgresType::Float4),
        "double precision" => Ok(PostgresType::Float8),
        "numeric" => Ok(PostgresType::Numeric(
            numeric_precision.unwrap_or(DECIMAL256_MAX_PRECISION as i32) as u8,
            numeric_scale as i8,
        )),
        "character varying" => Ok(PostgresType::Varchar),
        "character" => Ok(PostgresType::Bpchar),
        "text" => Ok(PostgresType::Text),
        "bytea" => Ok(PostgresType::Bytea),
        "date" => Ok(PostgresType::Date),
        "time without time zone" => Ok(PostgresType::Time),
        "timestamp without time zone" => Ok(PostgresType::Timestamp),
        "timestamp with time zone" => Ok(PostgresType::TimestampTz),
        "interval" => Ok(PostgresType::Interval),
        "boolean" => Ok(PostgresType::Bool),
        "json" => Ok(PostgresType::Json),
        "jsonb" => Ok(PostgresType::Jsonb),
        "public.geometry" => Ok(PostgresType::PostGisGeometry),
        "ARRAY_int2" => Ok(PostgresType::Int2Array),
        "ARRAY_int4" => Ok(PostgresType::Int4Array),
        "ARRAY_int8" => Ok(PostgresType::Int8Array),
        "ARRAY_float4" => Ok(PostgresType::Float4Array),
        "ARRAY_float8" => Ok(PostgresType::Float8Array),
        "ARRAY_varchar" => Ok(PostgresType::VarcharArray),
        "ARRAY_bpchar" => Ok(PostgresType::BpcharArray),
        "ARRAY_text" => Ok(PostgresType::TextArray),
        "ARRAY_bytea" => Ok(PostgresType::ByteaArray),
        "ARRAY_bool" => Ok(PostgresType::BoolArray),
        "xml" => Ok(PostgresType::Xml),
        "uuid" => Ok(PostgresType::Uuid),
        "oid" => Ok(PostgresType::Oid),
        "name" => Ok(PostgresType::Name),
        _ => Err(DataFusionError::Execution(format!(
            "Unsupported postgres type {pg_type}"
        ))),
    }
}

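/// Downcasts `$builder` to `$builder_ty`, reads column `$index` of `$row` as an
/// `Option<$value_ty>`, applies `$convert`, and appends the value (or null).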
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v: Option<$value_ty> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}

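/// Like `handle_primitive_type!`, but for one-dimensional arrays: downcasts to a
/// `ListBuilder`, reads the column as `Option<Vec<$primitive_value_ty>>`, and
/// extends the values builder before sealing the list entry.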
macro_rules! handle_primitive_array_type {
    ($builder:expr, $field:expr, $col:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for {:?} and {:?}",
                    $field, $col
                )
            });
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast values builder to {} for {:?} and {:?}",
                    stringify!($values_builder_ty),
                    $field,
                    $col,
                )
            });
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} array value for {:?} and {:?}: {e:?}",
                stringify!($primitive_value_ty),
                $field,
                $col,
            ))
        })?;

        match v {
            Some(v) => {
                let v = v.into_iter().map(Some);
                values_builder.extend(v);
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}

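/// A [`BigDecimal`] wrapper that decodes Postgres' binary NUMERIC
/// representation via [`FromSql`].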
#[derive(Debug)]
struct BigDecimalFromSql {
    inner: BigDecimal,
}

impl BigDecimalFromSql {
    fn to_i128_with_scale(&self, scale: i32) -> DFResult<i128> {
        big_decimal_to_i128(&self.inner, Some(scale))
    }

    fn to_i256_with_scale(&self, scale: i32) -> DFResult<i256> {
        big_decimal_to_i256(&self.inner, Some(scale))
    }
}

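// Postgres sends NUMERIC as big-endian u16 words: the number of base-10000
// digits, the weight of the first digit, the sign (0x0000 plus, 0x4000 minus),
// the display scale, and then the base-10000 digits themselves. The digits are
// re-expanded to base-10 below before building a `BigInt`.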
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
impl<'a> FromSql<'a> for BigDecimalFromSql {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let raw_u16: Vec<u16> = raw
            .chunks(2)
            .map(|chunk| {
                if chunk.len() == 2 {
                    u16::from_be_bytes([chunk[0], chunk[1]])
                } else {
                    u16::from_be_bytes([chunk[0], 0])
                }
            })
            .collect();

        let base_10_000_digit_count = raw_u16[0];
        let weight = raw_u16[1] as i16;
        let sign = raw_u16[2];
        let scale = raw_u16[3];

        let mut base_10_000_digits = Vec::new();
        for i in 4..4 + base_10_000_digit_count {
            base_10_000_digits.push(raw_u16[i as usize]);
        }

        let mut u8_digits = Vec::new();
        for &base_10_000_digit in base_10_000_digits.iter().rev() {
            let mut base_10_000_digit = base_10_000_digit;
            let mut temp_result = Vec::new();
            while base_10_000_digit > 0 {
                temp_result.push((base_10_000_digit % 10) as u8);
                base_10_000_digit /= 10;
            }
            while temp_result.len() < 4 {
                temp_result.push(0);
            }
            u8_digits.extend(temp_result);
        }
        u8_digits.reverse();

        let value_scale = 4 * (i64::from(base_10_000_digit_count) - i64::from(weight) - 1);
        let size = i64::try_from(u8_digits.len())? + i64::from(scale) - value_scale;
        u8_digits.resize(size as usize, 0);

        let sign = match sign {
            0x4000 => Sign::Minus,
            0x0000 => Sign::Plus,
            _ => {
                return Err(Box::new(DataFusionError::Execution(
                    "Failed to parse big decimal from postgres numeric value".to_string(),
                )));
            }
        };

        let Some(digits) = BigInt::from_radix_be(sign, u8_digits.as_slice(), 10) else {
            return Err(Box::new(DataFusionError::Execution(
                "Failed to parse big decimal from postgres numeric value".to_string(),
            )));
        };
        Ok(BigDecimalFromSql {
            inner: BigDecimal::new(digits, i64::from(scale)),
        })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }
}

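/// Decodes Postgres' binary INTERVAL representation: a big-endian i64 of
/// microseconds, an i32 day count, and an i32 month count.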
#[derive(Debug)]
struct IntervalFromSql {
    time: i64,
    day: i32,
    month: i32,
}

impl<'a> FromSql<'a> for IntervalFromSql {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let mut cursor = std::io::Cursor::new(raw);

        let time = cursor.read_i64::<BigEndian>()?;
        let day = cursor.read_i32::<BigEndian>()?;
        let month = cursor.read_i32::<BigEndian>()?;

        Ok(IntervalFromSql { time, day, month })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::INTERVAL)
    }
}

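/// Borrows the raw binary (WKB) payload of a PostGIS `geometry` value without copying.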
struct GeometryFromSql<'a> {
    wkb: &'a [u8],
}

impl<'a> FromSql<'a> for GeometryFromSql<'a> {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        Ok(GeometryFromSql { wkb: raw })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(ty.name(), "geometry")
    }
}

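/// Borrows a Postgres `xml` value as UTF-8 text without copying.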
struct XmlFromSql<'a> {
    xml: &'a str,
}

impl<'a> FromSql<'a> for XmlFromSql<'a> {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let xml = std::str::from_utf8(raw)?;
        Ok(XmlFromSql { xml })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::XML)
    }
}

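/// Converts a chunk of rows into a single [`RecordBatch`]. Builders are created
/// for every column of `table_schema`, but only projected columns are filled
/// and emitted; the row count is carried explicitly so empty projections work.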
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

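    // Fill builders row by row; columns outside the projection are skipped
    // and their builders stay empty.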
    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.columns().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::UInt32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt32Builder,
                        u32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        BigDecimalFromSql,
                        row,
                        idx,
                        |v: BigDecimalFromSql| { v.to_i128_with_scale(*scale as i32) }
                    );
                }
                DataType::Decimal256(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal256Builder,
                        BigDecimalFromSql,
                        row,
                        idx,
                        |v: BigDecimalFromSql| { v.to_i256_with_scale(*scale as i32) }
                    );
                }
                DataType::Utf8 => {
                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("xml") {
                        let convert: for<'a> fn(XmlFromSql<'a>) -> DFResult<&'a str> =
                            |v| Ok(v.xml);
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            XmlFromSql,
                            row,
                            idx,
                            convert
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            &str,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::LargeUtf8 => {
                    if col.is_some() && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB) {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            LargeStringBuilder,
                            serde_json::value::Value,
                            row,
                            idx,
                            |v: serde_json::value::Value| {
                                Ok::<_, DataFusionError>(v.to_string())
                            }
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            LargeStringBuilder,
                            &str,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::Binary => {
                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("geometry")
                    {
                        let convert: for<'a> fn(GeometryFromSql<'a>) -> DFResult<&'a [u8]> =
                            |v| Ok(v.wkb);
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            GeometryFromSql,
                            row,
                            idx,
                            convert
                        );
                    } else if col.is_some()
                        && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB)
                    {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            serde_json::value::Value,
                            row,
                            idx,
                            |v: serde_json::value::Value| {
                                Ok::<_, DataFusionError>(v.to_string().into_bytes())
                            }
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            Vec<u8>,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::FixedSizeBinary(_) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<FixedSizeBinaryBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to FixedSizeBinaryBuilder for {field:?}"
                            )
                        });
                    let v = if col.is_some()
                        && col.unwrap().type_().name().eq_ignore_ascii_case("uuid")
                    {
                        let v: Option<Uuid> = row.try_get(idx).map_err(|e| {
                            DataFusionError::Execution(format!(
                                "Failed to get Uuid value for field {:?}: {e:?}",
                                field
                            ))
                        })?;
                        v.map(|v| v.as_bytes().to_vec())
                    } else {
                        let v: Option<Vec<u8>> = row.try_get(idx).map_err(|e| {
                            DataFusionError::Execution(format!(
                                "Failed to get FixedSizeBinary value for field {:?}: {e:?}",
                                field
                            ))
                        })?;
                        v
                    };

                    match v {
                        Some(v) => builder.append_value(v)?,
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Microsecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampMicrosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let timestamp: i64 = v.and_utc().timestamp_micros();

                            Ok::<i64, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Microsecond, Some(_tz)) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampMicrosecondBuilder,
                        chrono::DateTime<chrono::Utc>,
                        row,
                        idx,
                        |v: chrono::DateTime<chrono::Utc>| {
                            let timestamp: i64 = v.timestamp_micros();
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let timestamp: i64 = v.and_utc().timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"));
                            Ok::<i64, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, Some(_tz)) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::DateTime<chrono::Utc>,
                        row,
                        idx,
                        |v: chrono::DateTime<chrono::Utc>| {
                            let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"));
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time64MicrosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            let seconds = i64::from(v.num_seconds_from_midnight());
                            let microseconds = i64::from(v.nanosecond()) / 1000;
                            Ok::<_, DataFusionError>(seconds * 1_000_000 + microseconds)
                        }
                    );
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time64NanosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
                                * 1_000_000_000
                                + i64::from(v.nanosecond());
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Date32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date32Builder,
                        chrono::NaiveDate,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(Date32Type::from_naive_date(v)) }
                    );
                }
                DataType::Interval(IntervalUnit::MonthDayNano) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        IntervalMonthDayNanoBuilder,
                        IntervalFromSql,
                        row,
                        idx,
                        |v: IntervalFromSql| {
                            let interval_month_day_nano = IntervalMonthDayNanoType::make_value(
                                v.month,
                                v.day,
                                v.time * 1_000,
                            );
                            Ok::<_, DataFusionError>(interval_month_day_nano)
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::List(inner) => match inner.data_type() {
                    DataType::Int16 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int16Builder,
                            i16,
                            row,
                            idx
                        );
                    }
                    DataType::Int32 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int32Builder,
                            i32,
                            row,
                            idx
                        );
                    }
                    DataType::Int64 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int64Builder,
                            i64,
                            row,
                            idx
                        );
                    }
                    DataType::Float32 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Float32Builder,
                            f32,
                            row,
                            idx
                        );
                    }
                    DataType::Float64 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Float64Builder,
                            f64,
                            row,
                            idx
                        );
                    }
                    DataType::Utf8 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            &str,
                            row,
                            idx
                        );
                    }
                    DataType::Binary => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            Vec<u8>,
                            row,
                            idx
                        );
                    }
                    DataType::Boolean => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            BooleanBuilder,
                            bool,
                            row,
                            idx
                        );
                    }
                    _ => {
                        return Err(DataFusionError::NotImplemented(format!(
                            "Unsupported list data type {} for col: {:?}",
                            field.data_type(),
                            col
                        )));
                    }
                },
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }
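    // Finish only the projected builders, keeping the table schema's column order.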
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}