datafusion_remote_table/connection/postgres.rs

1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::{big_decimal_to_i128, big_decimal_to_i256};
3use crate::{
4    Connection, ConnectionOptions, DFResult, Literalize, Pool, PoolState,
5    PostgresConnectionOptions, PostgresType, RemoteField, RemoteSchema, RemoteSchemaRef,
6    RemoteSource, RemoteType, literalize_array,
7};
8use arrow::array::{
9    ArrayBuilder, ArrayRef, BinaryBuilder, BinaryViewBuilder, BooleanBuilder, Date32Builder,
10    Decimal128Builder, Decimal256Builder, FixedSizeBinaryBuilder, Float32Builder, Float64Builder,
11    Int16Builder, Int32Builder, Int64Builder, IntervalMonthDayNanoBuilder, LargeStringBuilder,
12    ListBuilder, RecordBatch, RecordBatchOptions, StringBuilder, StringViewBuilder,
13    Time64MicrosecondBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder,
14    TimestampNanosecondBuilder, UInt32Builder, make_builder,
15};
16use arrow::datatypes::{
17    DECIMAL256_MAX_PRECISION, DataType, Date32Type, IntervalMonthDayNanoType, IntervalUnit,
18    SchemaRef, TimeUnit, i256,
19};
20use bb8_postgres::PostgresConnectionManager;
21use bb8_postgres::tokio_postgres::types::{FromSql, Type};
22use bb8_postgres::tokio_postgres::{NoTls, Row, Statement};
23use bigdecimal::BigDecimal;
24use byteorder::{BigEndian, ReadBytesExt};
25use chrono::Timelike;
26
27use datafusion_common::DataFusionError;
28use datafusion_common::project_schema;
29use datafusion_execution::SendableRecordBatchStream;
30use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
31use futures::StreamExt;
32use log::debug;
33use num_bigint::{BigInt, Sign};
34use std::any::Any;
35use std::string::ToString;
36use std::sync::Arc;
37use uuid::Uuid;
38
39#[derive(Debug)]
40pub struct PostgresPool {
41    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
42    options: Arc<PostgresConnectionOptions>,
43}
44
45#[async_trait::async_trait]
46impl Pool for PostgresPool {
47    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
48        let conn = self.pool.get_owned().await.map_err(|e| {
49            DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
50        })?;
51        Ok(Arc::new(PostgresConnection {
52            conn,
53            options: self.options.clone(),
54        }))
55    }
56
57    async fn state(&self) -> DFResult<PoolState> {
58        let bb8_state = self.pool.state();
59        Ok(PoolState {
60            connections: bb8_state.connections as usize,
61            idle_connections: bb8_state.idle_connections as usize,
62        })
63    }
64}
65
66pub(crate) async fn connect_postgres(
67    options: &PostgresConnectionOptions,
68) -> DFResult<PostgresPool> {
69    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
70    config
71        .host(&options.host)
72        .port(options.port)
73        .user(&options.username)
74        .password(&options.password);
75    if let Some(database) = &options.database {
76        config.dbname(database);
77    }
78    let manager = PostgresConnectionManager::new(config, NoTls);
79    let pool = bb8::Pool::builder()
80        .max_size(options.pool_max_size as u32)
81        .min_idle(Some(options.pool_min_idle as u32))
82        .idle_timeout(Some(options.pool_idle_timeout))
83        .reaper_rate(options.pool_ttl_check_interval)
84        .build(manager)
85        .await
86        .map_err(|e| {
87            DataFusionError::Execution(format!(
88                "Failed to create postgres connection pool due to {e}",
89            ))
90        })?;
91
92    Ok(PostgresPool {
93        pool,
94        options: Arc::new(options.clone()),
95    })
96}
97
98#[derive(Debug)]
99pub struct PostgresConnection {
100    pub conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
101    pub options: Arc<PostgresConnectionOptions>,
102}
103
104#[async_trait::async_trait]
105impl Connection for PostgresConnection {
106    fn as_any(&self) -> &dyn Any {
107        self
108    }
109
110    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
111        match source {
112            RemoteSource::Table(table) => {
113                let db_type = RemoteDbType::Postgres;
114                let where_condition = if table.len() == 1 {
115                    format!("table_name = {}", db_type.sql_string_literal(&table[0]))
116                } else if table.len() == 2 {
117                    format!(
118                        "table_schema = {} AND table_name = {}",
119                        db_type.sql_string_literal(&table[0]),
120                        db_type.sql_string_literal(&table[1])
121                    )
122                } else {
123                    format!(
124                        "table_catalog = {} AND table_schema = {} AND table_name = {}",
125                        db_type.sql_string_literal(&table[0]),
126                        db_type.sql_string_literal(&table[1]),
127                        db_type.sql_string_literal(&table[2])
128                    )
129                };
130                let sql = format!(
131                    "
132select
133	column_name,
134	case
135        when data_type = 'ARRAY'
136        		then data_type || udt_name
137        when data_type = 'USER-DEFINED'
138         		then udt_schema || '.' || udt_name
139		else
140                data_type
141	end as column_type,
142	numeric_precision,
143	numeric_scale,
144	is_nullable
145from information_schema.columns
146where {}
147order by ordinal_position",
148                    where_condition
149                );
150                let rows = self.conn.query(&sql, &[]).await.map_err(|e| {
151                    DataFusionError::Plan(format!(
152                        "Failed to execute query {sql} on postgres: {e:?}",
153                    ))
154                })?;
155                let remote_schema = Arc::new(build_remote_schema_for_table(
156                    rows,
157                    self.options.default_numeric_scale,
158                )?);
159                Ok(remote_schema)
160            }
161            RemoteSource::Query(query) => {
162                let stmt = self.conn.prepare(query).await.map_err(|e| {
163                    DataFusionError::Plan(format!(
164                        "Failed to execute query {query} on postgres: {e:?}",
165                    ))
166                })?;
167                let remote_schema = Arc::new(
168                    build_remote_schema_for_query(stmt, self.options.default_numeric_scale).await?,
169                );
170                Ok(remote_schema)
171            }
172        }
173    }
174
175    async fn query(
176        &self,
177        conn_options: &ConnectionOptions,
178        source: &RemoteSource,
179        table_schema: SchemaRef,
180        projection: Option<&Vec<usize>>,
181        unparsed_filters: &[String],
182        limit: Option<usize>,
183    ) -> DFResult<SendableRecordBatchStream> {
184        let projected_schema = project_schema(&table_schema, projection)?;
185
186        let sql = RemoteDbType::Postgres.rewrite_query(source, unparsed_filters, limit);
187        debug!("[remote-table] executing postgres query: {sql}");
188
189        let projection = projection.cloned();
190        let chunk_size = conn_options.stream_chunk_size();
191        let stream = self
192            .conn
193            .query_raw(&sql, Vec::<String>::new())
194            .await
195            .map_err(|e| {
196                DataFusionError::Execution(format!(
197                    "Failed to execute query {sql} on postgres: {e}",
198                ))
199            })?
200            .chunks(chunk_size)
201            .boxed();
202
203        let stream = stream.map(move |rows| {
204            let rows: Vec<Row> = rows
205                .into_iter()
206                .collect::<Result<Vec<_>, _>>()
207                .map_err(|e| {
208                    DataFusionError::Execution(format!(
209                        "Failed to collect rows from postgres due to {e}",
210                    ))
211                })?;
212            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
213        });
214
215        Ok(Box::pin(RecordBatchStreamAdapter::new(
216            projected_schema,
217            stream,
218        )))
219    }
220
221    async fn insert(
222        &self,
223        _conn_options: &ConnectionOptions,
224        literalizer: Arc<dyn Literalize>,
225        table: &[String],
226        remote_schema: RemoteSchemaRef,
227        batch: RecordBatch,
228    ) -> DFResult<usize> {
229        let mut columns = Vec::with_capacity(remote_schema.fields.len());
230        for i in 0..batch.num_columns() {
231            let input_field = batch.schema_ref().field(i);
232            let remote_field = &remote_schema.fields[i];
233            if remote_field.auto_increment && input_field.is_nullable() {
234                continue;
235            }
236
237            let remote_type = remote_schema.fields[i].remote_type.clone();
238            let array = batch.column(i);
239            let column = literalize_array(literalizer.as_ref(), array, remote_type)?;
240            columns.push(column);
241        }
242
243        let num_rows = columns[0].len();
244        let num_columns = columns.len();
245
246        let mut values = Vec::with_capacity(num_rows);
247        for i in 0..num_rows {
248            let mut value = Vec::with_capacity(num_columns);
249            for col in columns.iter() {
250                value.push(col[i].as_str());
251            }
252            values.push(format!("({})", value.join(",")));
253        }
254
255        let mut col_names = Vec::with_capacity(remote_schema.fields.len());
256        for (remote_field, input_field) in remote_schema
257            .fields
258            .iter()
259            .zip(batch.schema_ref().fields.iter())
260        {
261            if remote_field.auto_increment && input_field.is_nullable() {
262                continue;
263            }
264            col_names.push(RemoteDbType::Postgres.sql_identifier(&remote_field.name));
265        }
266
267        let sql = format!(
268            "INSERT INTO {} ({}) VALUES {}",
269            RemoteDbType::Postgres.sql_table_name(table),
270            col_names.join(","),
271            values.join(",")
272        );
273
274        let count = self.conn.execute(&sql, &[]).await.map_err(|e| {
275            DataFusionError::Execution(format!(
276                "Failed to execute insert statement on postgres: {e:?}, sql: {sql}"
277            ))
278        })?;
279
280        Ok(count as usize)
281    }
282}
283
284async fn build_remote_schema_for_query(
285    stmt: Statement,
286    default_numeric_scale: i8,
287) -> DFResult<RemoteSchema> {
288    let mut remote_fields = Vec::new();
289    for col in stmt.columns().iter() {
290        let pg_type = col.type_();
291        let remote_type = pg_type_to_remote_type(pg_type, default_numeric_scale)?;
292        remote_fields.push(RemoteField::new(
293            col.name(),
294            RemoteType::Postgres(remote_type),
295            true,
296        ));
297    }
298    Ok(RemoteSchema::new(remote_fields))
299}
300
301fn pg_type_to_remote_type(pg_type: &Type, default_numeric_scale: i8) -> DFResult<PostgresType> {
302    match pg_type {
303        &Type::INT2 => Ok(PostgresType::Int2),
304        &Type::INT4 => Ok(PostgresType::Int4),
305        &Type::INT8 => Ok(PostgresType::Int8),
306        &Type::FLOAT4 => Ok(PostgresType::Float4),
307        &Type::FLOAT8 => Ok(PostgresType::Float8),
308        &Type::NUMERIC => Ok(PostgresType::Numeric(
309            DECIMAL256_MAX_PRECISION,
310            default_numeric_scale,
311        )),
312        &Type::OID => Ok(PostgresType::Oid),
313        &Type::NAME => Ok(PostgresType::Name),
314        &Type::VARCHAR => Ok(PostgresType::Varchar),
315        &Type::BPCHAR => Ok(PostgresType::Bpchar),
316        &Type::TEXT => Ok(PostgresType::Text),
317        &Type::BYTEA => Ok(PostgresType::Bytea),
318        &Type::DATE => Ok(PostgresType::Date),
319        &Type::TIMESTAMP => Ok(PostgresType::Timestamp),
320        &Type::TIMESTAMPTZ => Ok(PostgresType::TimestampTz),
321        &Type::TIME => Ok(PostgresType::Time),
322        &Type::INTERVAL => Ok(PostgresType::Interval),
323        &Type::BOOL => Ok(PostgresType::Bool),
324        &Type::JSON => Ok(PostgresType::Json),
325        &Type::JSONB => Ok(PostgresType::Jsonb),
326        &Type::INT2_ARRAY => Ok(PostgresType::Int2Array),
327        &Type::INT4_ARRAY => Ok(PostgresType::Int4Array),
328        &Type::INT8_ARRAY => Ok(PostgresType::Int8Array),
329        &Type::FLOAT4_ARRAY => Ok(PostgresType::Float4Array),
330        &Type::FLOAT8_ARRAY => Ok(PostgresType::Float8Array),
331        &Type::VARCHAR_ARRAY => Ok(PostgresType::VarcharArray),
332        &Type::BPCHAR_ARRAY => Ok(PostgresType::BpcharArray),
333        &Type::TEXT_ARRAY => Ok(PostgresType::TextArray),
334        &Type::BYTEA_ARRAY => Ok(PostgresType::ByteaArray),
335        &Type::BOOL_ARRAY => Ok(PostgresType::BoolArray),
336        &Type::XML => Ok(PostgresType::Xml),
337        &Type::UUID => Ok(PostgresType::Uuid),
338        other if other.name().eq_ignore_ascii_case("geometry") => Ok(PostgresType::PostGisGeometry),
339        _ => Err(DataFusionError::NotImplemented(format!(
340            "Unsupported postgres type {pg_type:?}",
341        ))),
342    }
343}
344
345fn build_remote_schema_for_table(
346    rows: Vec<Row>,
347    default_numeric_scale: i8,
348) -> DFResult<RemoteSchema> {
349    let mut remote_fields = vec![];
350    for row in rows {
351        let columa_name = row.try_get::<_, String>(0).map_err(|e| {
352            DataFusionError::Plan(format!("Failed to get col name from postgres row: {e:?}"))
353        })?;
354        let column_type = row.try_get::<_, String>(1).map_err(|e| {
355            DataFusionError::Plan(format!("Failed to get col type from postgres row: {e:?}"))
356        })?;
357        let numeric_precision = row.try_get::<_, Option<i32>>(2).map_err(|e| {
358            DataFusionError::Plan(format!(
359                "Failed to get numeric precision from postgres row: {e:?}"
360            ))
361        })?;
362        let numeric_scale = row.try_get::<_, Option<i32>>(3).map_err(|e| {
363            DataFusionError::Plan(format!(
364                "Failed to get numeric scale from postgres row: {e:?}"
365            ))
366        })?;
367        let pg_type = parse_pg_type(
368            &column_type,
369            numeric_precision,
370            numeric_scale.unwrap_or(default_numeric_scale as i32),
371        )?;
372        let is_nullable = row.try_get::<_, String>(4).map_err(|e| {
373            DataFusionError::Plan(format!(
374                "Failed to get is_nullable from postgres row: {e:?}"
375            ))
376        })?;
377        let nullable = match is_nullable.as_str() {
378            "YES" => true,
379            "NO" => false,
380            _ => {
381                return Err(DataFusionError::Plan(format!(
382                    "Unsupported postgres is_nullable value {is_nullable}"
383                )));
384            }
385        };
386        remote_fields.push(RemoteField::new(
387            columa_name,
388            RemoteType::Postgres(pg_type),
389            nullable,
390        ));
391    }
392    Ok(RemoteSchema::new(remote_fields))
393}
394
395fn parse_pg_type(
396    pg_type: &str,
397    numeric_precision: Option<i32>,
398    numeric_scale: i32,
399) -> DFResult<PostgresType> {
400    match pg_type {
401        "smallint" => Ok(PostgresType::Int2),
402        "integer" => Ok(PostgresType::Int4),
403        "bigint" => Ok(PostgresType::Int8),
404        "real" => Ok(PostgresType::Float4),
405        "double precision" => Ok(PostgresType::Float8),
406        "numeric" => Ok(PostgresType::Numeric(
407            numeric_precision.unwrap_or(DECIMAL256_MAX_PRECISION as i32) as u8,
408            numeric_scale as i8,
409        )),
410        "character varying" => Ok(PostgresType::Varchar),
411        "character" => Ok(PostgresType::Bpchar),
412        "text" => Ok(PostgresType::Text),
413        "bytea" => Ok(PostgresType::Bytea),
414        "date" => Ok(PostgresType::Date),
415        "time without time zone" => Ok(PostgresType::Time),
416        "timestamp without time zone" => Ok(PostgresType::Timestamp),
417        "timestamp with time zone" => Ok(PostgresType::TimestampTz),
418        "interval" => Ok(PostgresType::Interval),
419        "boolean" => Ok(PostgresType::Bool),
420        "json" => Ok(PostgresType::Json),
421        "jsonb" => Ok(PostgresType::Jsonb),
422        "public.geometry" => Ok(PostgresType::PostGisGeometry),
423        "ARRAY_int2" => Ok(PostgresType::Int2Array),
424        "ARRAY_int4" => Ok(PostgresType::Int4Array),
425        "ARRAY_int8" => Ok(PostgresType::Int8Array),
426        "ARRAY_float4" => Ok(PostgresType::Float4Array),
427        "ARRAY_float8" => Ok(PostgresType::Float8Array),
428        "ARRAY_varchar" => Ok(PostgresType::VarcharArray),
429        "ARRAY_bpchar" => Ok(PostgresType::BpcharArray),
430        "ARRAY_text" => Ok(PostgresType::TextArray),
431        "ARRAY_bytea" => Ok(PostgresType::ByteaArray),
432        "ARRAY_bool" => Ok(PostgresType::BoolArray),
433        "xml" => Ok(PostgresType::Xml),
434        "uuid" => Ok(PostgresType::Uuid),
435        "oid" => Ok(PostgresType::Oid),
436        "name" => Ok(PostgresType::Name),
437        _ => Err(DataFusionError::Execution(format!(
438            "Unsupported postgres type {pg_type}"
439        ))),
440    }
441}
442
/// Appends the value at `$index` of `$row` to `$builder`, which must downcast to `$builder_ty`.
///
/// - The cell is read as `Option<$value_ty>`. NULL appends a null; otherwise the
///   value goes through `$convert`, whose error is propagated with `?`.
/// - A failed `try_get` is returned as `DataFusionError::Execution`.
/// - A failed builder downcast is a programming error and panics.
/// - `$field` and `$col` are only used in diagnostics.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            );
        };
        let maybe_value: Option<$value_ty> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = maybe_value {
            typed_builder.append_value($convert(value)?);
        } else {
            typed_builder.append_null();
        }
    }};
}
471
/// Appends the array value at `$index` of `$row` to a list column.
///
/// - `$builder` must downcast to `ListBuilder<Box<dyn ArrayBuilder>>`.
/// - The list's values builder must downcast to `$values_builder_ty`.
/// - The cell is read as `Option<Vec<$primitive_value_ty>>`. NULL appends a null
///   list; otherwise every element is appended as a non-null value and the list
///   slot is closed.
/// - A failed `try_get` is returned as `DataFusionError::Execution`.
/// - A failed builder downcast is a programming error and panics.
macro_rules! handle_primitive_array_type {
    ($builder:expr, $field:expr, $col:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for {:?} and {:?}",
                    $field, $col
                )
            });
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast values builder to {} for {:?} and {:?}",
                    // Must name this macro's own metavariable; `$builder_ty` is not
                    // bound here and would be printed literally.
                    stringify!($values_builder_ty),
                    $field,
                    $col,
                )
            });
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} array value for {:?} and {:?}: {e:?}",
                stringify!($primitive_value_ty),
                $field,
                $col,
            ))
        })?;

        match v {
            Some(v) => {
                values_builder.extend(v.into_iter().map(Some));
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}
514
515#[derive(Debug)]
516struct BigDecimalFromSql {
517    inner: BigDecimal,
518}
519
520impl BigDecimalFromSql {
521    fn to_i128_with_scale(&self, scale: i32) -> DFResult<i128> {
522        big_decimal_to_i128(&self.inner, Some(scale))
523    }
524
525    fn to_i256_with_scale(&self, scale: i32) -> DFResult<i256> {
526        big_decimal_to_i256(&self.inner, Some(scale))
527    }
528}
529
530#[allow(clippy::cast_sign_loss)]
531#[allow(clippy::cast_possible_wrap)]
532#[allow(clippy::cast_possible_truncation)]
533impl<'a> FromSql<'a> for BigDecimalFromSql {
534    fn from_sql(
535        _ty: &Type,
536        raw: &'a [u8],
537    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
538        let raw_u16: Vec<u16> = raw
539            .chunks(2)
540            .map(|chunk| {
541                if chunk.len() == 2 {
542                    u16::from_be_bytes([chunk[0], chunk[1]])
543                } else {
544                    u16::from_be_bytes([chunk[0], 0])
545                }
546            })
547            .collect();
548
549        let base_10_000_digit_count = raw_u16[0];
550        let weight = raw_u16[1] as i16;
551        let sign = raw_u16[2];
552        let scale = raw_u16[3];
553
554        let mut base_10_000_digits = Vec::new();
555        for i in 4..4 + base_10_000_digit_count {
556            base_10_000_digits.push(raw_u16[i as usize]);
557        }
558
559        let mut u8_digits = Vec::new();
560        for &base_10_000_digit in base_10_000_digits.iter().rev() {
561            let mut base_10_000_digit = base_10_000_digit;
562            let mut temp_result = Vec::new();
563            while base_10_000_digit > 0 {
564                temp_result.push((base_10_000_digit % 10) as u8);
565                base_10_000_digit /= 10;
566            }
567            while temp_result.len() < 4 {
568                temp_result.push(0);
569            }
570            u8_digits.extend(temp_result);
571        }
572        u8_digits.reverse();
573
574        let value_scale = 4 * (i64::from(base_10_000_digit_count) - i64::from(weight) - 1);
575        let size = i64::try_from(u8_digits.len())? + i64::from(scale) - value_scale;
576        u8_digits.resize(size as usize, 0);
577
578        let sign = match sign {
579            0x4000 => Sign::Minus,
580            0x0000 => Sign::Plus,
581            _ => {
582                return Err(Box::new(DataFusionError::Execution(
583                    "Failed to parse big decimal from postgres numeric value".to_string(),
584                )));
585            }
586        };
587
588        let Some(digits) = BigInt::from_radix_be(sign, u8_digits.as_slice(), 10) else {
589            return Err(Box::new(DataFusionError::Execution(
590                "Failed to parse big decimal from postgres numeric value".to_string(),
591            )));
592        };
593        Ok(BigDecimalFromSql {
594            inner: BigDecimal::new(digits, i64::from(scale)),
595        })
596    }
597
598    fn accepts(ty: &Type) -> bool {
599        matches!(*ty, Type::NUMERIC)
600    }
601}
602
603// interval_send - Postgres C (https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L1032)
604// interval values are internally stored as three integral fields: months, days, and microseconds
605#[derive(Debug)]
606struct IntervalFromSql {
607    time: i64,
608    day: i32,
609    month: i32,
610}
611
612impl<'a> FromSql<'a> for IntervalFromSql {
613    fn from_sql(
614        _ty: &Type,
615        raw: &'a [u8],
616    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
617        let mut cursor = std::io::Cursor::new(raw);
618
619        let time = cursor.read_i64::<BigEndian>()?;
620        let day = cursor.read_i32::<BigEndian>()?;
621        let month = cursor.read_i32::<BigEndian>()?;
622
623        Ok(IntervalFromSql { time, day, month })
624    }
625
626    fn accepts(ty: &Type) -> bool {
627        matches!(*ty, Type::INTERVAL)
628    }
629}
630
631struct GeometryFromSql<'a> {
632    wkb: &'a [u8],
633}
634
635impl<'a> FromSql<'a> for GeometryFromSql<'a> {
636    fn from_sql(
637        _ty: &Type,
638        raw: &'a [u8],
639    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
640        Ok(GeometryFromSql { wkb: raw })
641    }
642
643    fn accepts(ty: &Type) -> bool {
644        matches!(ty.name(), "geometry")
645    }
646}
647
648struct XmlFromSql<'a> {
649    xml: &'a str,
650}
651
652impl<'a> FromSql<'a> for XmlFromSql<'a> {
653    fn from_sql(
654        _ty: &Type,
655        raw: &'a [u8],
656    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
657        let xml = str::from_utf8(raw)?;
658        Ok(XmlFromSql { xml })
659    }
660
661    fn accepts(ty: &Type) -> bool {
662        matches!(*ty, Type::XML)
663    }
664}
665
666fn rows_to_batch(
667    rows: &[Row],
668    table_schema: &SchemaRef,
669    projection: Option<&Vec<usize>>,
670) -> DFResult<RecordBatch> {
671    let projected_schema = project_schema(table_schema, projection)?;
672    let mut array_builders = vec![];
673    for field in table_schema.fields() {
674        let builder = make_builder(field.data_type(), rows.len());
675        array_builders.push(builder);
676    }
677
678    for row in rows {
679        for (idx, field) in table_schema.fields.iter().enumerate() {
680            if !projections_contains(projection, idx) {
681                continue;
682            }
683            let builder = &mut array_builders[idx];
684            let col = row.columns().get(idx);
685            match field.data_type() {
686                DataType::Int16 => {
687                    handle_primitive_type!(
688                        builder,
689                        field,
690                        col,
691                        Int16Builder,
692                        i16,
693                        row,
694                        idx,
695                        just_return
696                    );
697                }
698                DataType::Int32 => {
699                    handle_primitive_type!(
700                        builder,
701                        field,
702                        col,
703                        Int32Builder,
704                        i32,
705                        row,
706                        idx,
707                        just_return
708                    );
709                }
710                DataType::UInt32 => {
711                    handle_primitive_type!(
712                        builder,
713                        field,
714                        col,
715                        UInt32Builder,
716                        u32,
717                        row,
718                        idx,
719                        just_return
720                    );
721                }
722                DataType::Int64 => {
723                    handle_primitive_type!(
724                        builder,
725                        field,
726                        col,
727                        Int64Builder,
728                        i64,
729                        row,
730                        idx,
731                        just_return
732                    );
733                }
734                DataType::Float32 => {
735                    handle_primitive_type!(
736                        builder,
737                        field,
738                        col,
739                        Float32Builder,
740                        f32,
741                        row,
742                        idx,
743                        just_return
744                    );
745                }
746                DataType::Float64 => {
747                    handle_primitive_type!(
748                        builder,
749                        field,
750                        col,
751                        Float64Builder,
752                        f64,
753                        row,
754                        idx,
755                        just_return
756                    );
757                }
758                DataType::Decimal128(_precision, scale) => {
759                    handle_primitive_type!(
760                        builder,
761                        field,
762                        col,
763                        Decimal128Builder,
764                        BigDecimalFromSql,
765                        row,
766                        idx,
767                        |v: BigDecimalFromSql| { v.to_i128_with_scale(*scale as i32) }
768                    );
769                }
770                DataType::Decimal256(_precision, scale) => {
771                    handle_primitive_type!(
772                        builder,
773                        field,
774                        col,
775                        Decimal256Builder,
776                        BigDecimalFromSql,
777                        row,
778                        idx,
779                        |v: BigDecimalFromSql| { v.to_i256_with_scale(*scale as i32) }
780                    );
781                }
782                DataType::Utf8 => {
783                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("xml") {
784                        let convert: for<'a> fn(XmlFromSql<'a>) -> DFResult<&'a str> =
785                            |v| Ok(v.xml);
786                        handle_primitive_type!(
787                            builder,
788                            field,
789                            col,
790                            StringBuilder,
791                            XmlFromSql,
792                            row,
793                            idx,
794                            convert
795                        );
796                    } else {
797                        handle_primitive_type!(
798                            builder,
799                            field,
800                            col,
801                            StringBuilder,
802                            &str,
803                            row,
804                            idx,
805                            just_return
806                        );
807                    }
808                }
809                DataType::LargeUtf8 => {
810                    if col.is_some() && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB) {
811                        handle_primitive_type!(
812                            builder,
813                            field,
814                            col,
815                            LargeStringBuilder,
816                            serde_json::value::Value,
817                            row,
818                            idx,
819                            |v: serde_json::value::Value| {
820                                Ok::<_, DataFusionError>(v.to_string())
821                            }
822                        );
823                    } else {
824                        handle_primitive_type!(
825                            builder,
826                            field,
827                            col,
828                            LargeStringBuilder,
829                            &str,
830                            row,
831                            idx,
832                            just_return
833                        );
834                    }
835                }
836                DataType::Utf8View => {
837                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("xml") {
838                        let convert: for<'a> fn(XmlFromSql<'a>) -> DFResult<&'a str> =
839                            |v| Ok(v.xml);
840                        handle_primitive_type!(
841                            builder,
842                            field,
843                            col,
844                            StringViewBuilder,
845                            XmlFromSql,
846                            row,
847                            idx,
848                            convert
849                        );
850                    } else if col.is_some()
851                        && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB)
852                    {
853                        handle_primitive_type!(
854                            builder,
855                            field,
856                            col,
857                            StringViewBuilder,
858                            serde_json::value::Value,
859                            row,
860                            idx,
861                            |v: serde_json::value::Value| {
862                                Ok::<_, DataFusionError>(v.to_string())
863                            }
864                        );
865                    } else {
866                        handle_primitive_type!(
867                            builder,
868                            field,
869                            col,
870                            StringViewBuilder,
871                            &str,
872                            row,
873                            idx,
874                            just_return
875                        );
876                    }
877                }
878                DataType::Binary => {
879                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("geometry")
880                    {
881                        let convert: for<'a> fn(GeometryFromSql<'a>) -> DFResult<&'a [u8]> =
882                            |v| Ok(v.wkb);
883                        handle_primitive_type!(
884                            builder,
885                            field,
886                            col,
887                            BinaryBuilder,
888                            GeometryFromSql,
889                            row,
890                            idx,
891                            convert
892                        );
893                    } else if col.is_some()
894                        && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB)
895                    {
896                        handle_primitive_type!(
897                            builder,
898                            field,
899                            col,
900                            BinaryBuilder,
901                            serde_json::value::Value,
902                            row,
903                            idx,
904                            |v: serde_json::value::Value| {
905                                Ok::<_, DataFusionError>(v.to_string().into_bytes())
906                            }
907                        );
908                    } else {
909                        handle_primitive_type!(
910                            builder,
911                            field,
912                            col,
913                            BinaryBuilder,
914                            Vec<u8>,
915                            row,
916                            idx,
917                            just_return
918                        );
919                    }
920                }
921                DataType::BinaryView => {
922                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("geometry")
923                    {
924                        let convert: for<'a> fn(GeometryFromSql<'a>) -> DFResult<&'a [u8]> =
925                            |v| Ok(v.wkb);
926                        handle_primitive_type!(
927                            builder,
928                            field,
929                            col,
930                            BinaryViewBuilder,
931                            GeometryFromSql,
932                            row,
933                            idx,
934                            convert
935                        );
936                    } else if col.is_some()
937                        && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB)
938                    {
939                        handle_primitive_type!(
940                            builder,
941                            field,
942                            col,
943                            BinaryViewBuilder,
944                            serde_json::value::Value,
945                            row,
946                            idx,
947                            |v: serde_json::value::Value| {
948                                Ok::<_, DataFusionError>(v.to_string().into_bytes())
949                            }
950                        );
951                    } else {
952                        handle_primitive_type!(
953                            builder,
954                            field,
955                            col,
956                            BinaryViewBuilder,
957                            Vec<u8>,
958                            row,
959                            idx,
960                            just_return
961                        );
962                    }
963                }
964                DataType::FixedSizeBinary(_) => {
965                    let builder = builder
966                        .as_any_mut()
967                        .downcast_mut::<FixedSizeBinaryBuilder>()
968                        .unwrap_or_else(|| {
969                            panic!(
970                                "Failed to downcast builder to FixedSizeBinaryBuilder for {field:?}"
971                            )
972                        });
973                    let v = if col.is_some()
974                        && col.unwrap().type_().name().eq_ignore_ascii_case("uuid")
975                    {
976                        let v: Option<Uuid> = row.try_get(idx).map_err(|e| {
977                            DataFusionError::Execution(format!(
978                                "Failed to get Uuid value for field {:?}: {e:?}",
979                                field
980                            ))
981                        })?;
982                        v.map(|v| v.as_bytes().to_vec())
983                    } else {
984                        let v: Option<Vec<u8>> = row.try_get(idx).map_err(|e| {
985                            DataFusionError::Execution(format!(
986                                "Failed to get FixedSizeBinary value for field {:?}: {e:?}",
987                                field
988                            ))
989                        })?;
990                        v
991                    };
992
993                    match v {
994                        Some(v) => builder.append_value(v)?,
995                        None => builder.append_null(),
996                    }
997                }
998                DataType::Timestamp(TimeUnit::Microsecond, None) => {
999                    handle_primitive_type!(
1000                        builder,
1001                        field,
1002                        col,
1003                        TimestampMicrosecondBuilder,
1004                        chrono::NaiveDateTime,
1005                        row,
1006                        idx,
1007                        |v: chrono::NaiveDateTime| {
1008                            let timestamp: i64 = v.and_utc().timestamp_micros();
1009
1010                            Ok::<i64, DataFusionError>(timestamp)
1011                        }
1012                    );
1013                }
1014                DataType::Timestamp(TimeUnit::Microsecond, Some(_tz)) => {
1015                    handle_primitive_type!(
1016                        builder,
1017                        field,
1018                        col,
1019                        TimestampMicrosecondBuilder,
1020                        chrono::DateTime<chrono::Utc>,
1021                        row,
1022                        idx,
1023                        |v: chrono::DateTime<chrono::Utc>| {
1024                            let timestamp: i64 = v.timestamp_micros();
1025                            Ok::<_, DataFusionError>(timestamp)
1026                        }
1027                    );
1028                }
1029                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
1030                    handle_primitive_type!(
1031                        builder,
1032                        field,
1033                        col,
1034                        TimestampNanosecondBuilder,
1035                        chrono::NaiveDateTime,
1036                        row,
1037                        idx,
1038                        |v: chrono::NaiveDateTime| {
1039                            let timestamp: i64 = v.and_utc().timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"));
1040                            Ok::<i64, DataFusionError>(timestamp)
1041                        }
1042                    );
1043                }
1044                DataType::Timestamp(TimeUnit::Nanosecond, Some(_tz)) => {
1045                    handle_primitive_type!(
1046                        builder,
1047                        field,
1048                        col,
1049                        TimestampNanosecondBuilder,
1050                        chrono::DateTime<chrono::Utc>,
1051                        row,
1052                        idx,
1053                        |v: chrono::DateTime<chrono::Utc>| {
1054                            let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"));
1055                            Ok::<_, DataFusionError>(timestamp)
1056                        }
1057                    );
1058                }
1059                DataType::Time64(TimeUnit::Microsecond) => {
1060                    handle_primitive_type!(
1061                        builder,
1062                        field,
1063                        col,
1064                        Time64MicrosecondBuilder,
1065                        chrono::NaiveTime,
1066                        row,
1067                        idx,
1068                        |v: chrono::NaiveTime| {
1069                            let seconds = i64::from(v.num_seconds_from_midnight());
1070                            let microseconds = i64::from(v.nanosecond()) / 1000;
1071                            Ok::<_, DataFusionError>(seconds * 1_000_000 + microseconds)
1072                        }
1073                    );
1074                }
1075                DataType::Time64(TimeUnit::Nanosecond) => {
1076                    handle_primitive_type!(
1077                        builder,
1078                        field,
1079                        col,
1080                        Time64NanosecondBuilder,
1081                        chrono::NaiveTime,
1082                        row,
1083                        idx,
1084                        |v: chrono::NaiveTime| {
1085                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
1086                                * 1_000_000_000
1087                                + i64::from(v.nanosecond());
1088                            Ok::<_, DataFusionError>(timestamp)
1089                        }
1090                    );
1091                }
1092                DataType::Date32 => {
1093                    handle_primitive_type!(
1094                        builder,
1095                        field,
1096                        col,
1097                        Date32Builder,
1098                        chrono::NaiveDate,
1099                        row,
1100                        idx,
1101                        |v| { Ok::<_, DataFusionError>(Date32Type::from_naive_date(v)) }
1102                    );
1103                }
1104                DataType::Interval(IntervalUnit::MonthDayNano) => {
1105                    handle_primitive_type!(
1106                        builder,
1107                        field,
1108                        col,
1109                        IntervalMonthDayNanoBuilder,
1110                        IntervalFromSql,
1111                        row,
1112                        idx,
1113                        |v: IntervalFromSql| {
1114                            let interval_month_day_nano = IntervalMonthDayNanoType::make_value(
1115                                v.month,
1116                                v.day,
1117                                v.time * 1_000,
1118                            );
1119                            Ok::<_, DataFusionError>(interval_month_day_nano)
1120                        }
1121                    );
1122                }
1123                DataType::Boolean => {
1124                    handle_primitive_type!(
1125                        builder,
1126                        field,
1127                        col,
1128                        BooleanBuilder,
1129                        bool,
1130                        row,
1131                        idx,
1132                        just_return
1133                    );
1134                }
1135                DataType::List(inner) => match inner.data_type() {
1136                    DataType::Int16 => {
1137                        handle_primitive_array_type!(
1138                            builder,
1139                            field,
1140                            col,
1141                            Int16Builder,
1142                            i16,
1143                            row,
1144                            idx
1145                        );
1146                    }
1147                    DataType::Int32 => {
1148                        handle_primitive_array_type!(
1149                            builder,
1150                            field,
1151                            col,
1152                            Int32Builder,
1153                            i32,
1154                            row,
1155                            idx
1156                        );
1157                    }
1158                    DataType::Int64 => {
1159                        handle_primitive_array_type!(
1160                            builder,
1161                            field,
1162                            col,
1163                            Int64Builder,
1164                            i64,
1165                            row,
1166                            idx
1167                        );
1168                    }
1169                    DataType::Float32 => {
1170                        handle_primitive_array_type!(
1171                            builder,
1172                            field,
1173                            col,
1174                            Float32Builder,
1175                            f32,
1176                            row,
1177                            idx
1178                        );
1179                    }
1180                    DataType::Float64 => {
1181                        handle_primitive_array_type!(
1182                            builder,
1183                            field,
1184                            col,
1185                            Float64Builder,
1186                            f64,
1187                            row,
1188                            idx
1189                        );
1190                    }
1191                    DataType::Utf8 => {
1192                        handle_primitive_array_type!(
1193                            builder,
1194                            field,
1195                            col,
1196                            StringBuilder,
1197                            &str,
1198                            row,
1199                            idx
1200                        );
1201                    }
1202                    DataType::Binary => {
1203                        handle_primitive_array_type!(
1204                            builder,
1205                            field,
1206                            col,
1207                            BinaryBuilder,
1208                            Vec<u8>,
1209                            row,
1210                            idx
1211                        );
1212                    }
1213                    DataType::Boolean => {
1214                        handle_primitive_array_type!(
1215                            builder,
1216                            field,
1217                            col,
1218                            BooleanBuilder,
1219                            bool,
1220                            row,
1221                            idx
1222                        );
1223                    }
1224                    _ => {
1225                        return Err(DataFusionError::NotImplemented(format!(
1226                            "Unsupported list data type {} for col: {:?}",
1227                            field.data_type(),
1228                            col
1229                        )));
1230                    }
1231                },
1232                _ => {
1233                    return Err(DataFusionError::NotImplemented(format!(
1234                        "Unsupported data type {} for col: {:?}",
1235                        field.data_type(),
1236                        col
1237                    )));
1238                }
1239            }
1240        }
1241    }
1242    let projected_columns = array_builders
1243        .into_iter()
1244        .enumerate()
1245        .filter(|(idx, _)| projections_contains(projection, *idx))
1246        .map(|(_, mut builder)| builder.finish())
1247        .collect::<Vec<ArrayRef>>();
1248    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
1249    Ok(RecordBatch::try_new_with_options(
1250        projected_schema,
1251        projected_columns,
1252        &options,
1253    )?)
1254}