datafusion_remote_table/connection/
postgres.rs

1use crate::connection::{RemoteDbType, just_return, projections_contains};
2use crate::utils::{big_decimal_to_i128, big_decimal_to_i256};
3use crate::{
4    Connection, ConnectionOptions, DFResult, Literalize, Pool, PoolState,
5    PostgresConnectionOptions, PostgresType, RemoteField, RemoteSchema, RemoteSchemaRef,
6    RemoteSource, RemoteType, literalize_array,
7};
8use bb8_postgres::PostgresConnectionManager;
9use bb8_postgres::tokio_postgres::types::{FromSql, Type};
10use bb8_postgres::tokio_postgres::{NoTls, Row, Statement};
11use bigdecimal::BigDecimal;
12use byteorder::{BigEndian, ReadBytesExt};
13use chrono::Timelike;
14use datafusion::arrow::array::{
15    ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
16    Decimal256Builder, FixedSizeBinaryBuilder, Float32Builder, Float64Builder, Int16Builder,
17    Int32Builder, Int64Builder, IntervalMonthDayNanoBuilder, LargeStringBuilder, ListBuilder,
18    RecordBatch, RecordBatchOptions, StringBuilder, Time64MicrosecondBuilder,
19    Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampNanosecondBuilder,
20    UInt32Builder, make_builder,
21};
22use datafusion::arrow::datatypes::{
23    DECIMAL256_MAX_PRECISION, DataType, Date32Type, IntervalMonthDayNanoType, IntervalUnit,
24    SchemaRef, TimeUnit, i256,
25};
26
27use datafusion::common::project_schema;
28use datafusion::error::DataFusionError;
29use datafusion::execution::SendableRecordBatchStream;
30use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
31use futures::StreamExt;
32use log::debug;
33use num_bigint::{BigInt, Sign};
34use std::any::Any;
35use std::string::ToString;
36use std::sync::Arc;
37use uuid::Uuid;
38
39#[derive(Debug)]
40pub struct PostgresPool {
41    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
42    options: Arc<PostgresConnectionOptions>,
43}
44
45#[async_trait::async_trait]
46impl Pool for PostgresPool {
47    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
48        let conn = self.pool.get_owned().await.map_err(|e| {
49            DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
50        })?;
51        Ok(Arc::new(PostgresConnection {
52            conn,
53            options: self.options.clone(),
54        }))
55    }
56
57    async fn state(&self) -> DFResult<PoolState> {
58        let bb8_state = self.pool.state();
59        Ok(PoolState {
60            connections: bb8_state.connections as usize,
61            idle_connections: bb8_state.idle_connections as usize,
62        })
63    }
64}
65
66pub(crate) async fn connect_postgres(
67    options: &PostgresConnectionOptions,
68) -> DFResult<PostgresPool> {
69    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
70    config
71        .host(&options.host)
72        .port(options.port)
73        .user(&options.username)
74        .password(&options.password);
75    if let Some(database) = &options.database {
76        config.dbname(database);
77    }
78    let manager = PostgresConnectionManager::new(config, NoTls);
79    let pool = bb8::Pool::builder()
80        .max_size(options.pool_max_size as u32)
81        .min_idle(Some(options.pool_min_idle as u32))
82        .idle_timeout(Some(options.pool_idle_timeout))
83        .reaper_rate(options.pool_ttl_check_interval)
84        .build(manager)
85        .await
86        .map_err(|e| {
87            DataFusionError::Execution(format!(
88                "Failed to create postgres connection pool due to {e}",
89            ))
90        })?;
91
92    Ok(PostgresPool {
93        pool,
94        options: Arc::new(options.clone()),
95    })
96}
97
98#[derive(Debug)]
99pub struct PostgresConnection {
100    pub conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
101    pub options: Arc<PostgresConnectionOptions>,
102}
103
104#[async_trait::async_trait]
105impl Connection for PostgresConnection {
106    fn as_any(&self) -> &dyn Any {
107        self
108    }
109
110    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
111        match source {
112            RemoteSource::Table(table) => {
113                let db_type = RemoteDbType::Postgres;
114                let where_condition = if table.len() == 1 {
115                    format!("table_name = {}", db_type.sql_string_literal(&table[0]))
116                } else if table.len() == 2 {
117                    format!(
118                        "table_schema = {} AND table_name = {}",
119                        db_type.sql_string_literal(&table[0]),
120                        db_type.sql_string_literal(&table[1])
121                    )
122                } else {
123                    format!(
124                        "table_catalog = {} AND table_schema = {} AND table_name = {}",
125                        db_type.sql_string_literal(&table[0]),
126                        db_type.sql_string_literal(&table[1]),
127                        db_type.sql_string_literal(&table[2])
128                    )
129                };
130                let sql = format!(
131                    "
132select
133	column_name,
134	case
135        when data_type = 'ARRAY'
136        		then data_type || udt_name
137        when data_type = 'USER-DEFINED'
138         		then udt_schema || '.' || udt_name
139		else
140                data_type
141	end as column_type,
142	numeric_precision,
143	numeric_scale,
144	is_nullable
145from information_schema.columns
146where {}
147order by ordinal_position",
148                    where_condition
149                );
150                let rows = self.conn.query(&sql, &[]).await.map_err(|e| {
151                    DataFusionError::Plan(format!(
152                        "Failed to execute query {sql} on postgres: {e:?}",
153                    ))
154                })?;
155                let remote_schema = Arc::new(build_remote_schema_for_table(
156                    rows,
157                    self.options.default_numeric_scale,
158                )?);
159                Ok(remote_schema)
160            }
161            RemoteSource::Query(query) => {
162                let stmt = self.conn.prepare(query).await.map_err(|e| {
163                    DataFusionError::Plan(format!(
164                        "Failed to execute query {query} on postgres: {e:?}",
165                    ))
166                })?;
167                let remote_schema = Arc::new(
168                    build_remote_schema_for_query(stmt, self.options.default_numeric_scale).await?,
169                );
170                Ok(remote_schema)
171            }
172        }
173    }
174
175    async fn query(
176        &self,
177        conn_options: &ConnectionOptions,
178        source: &RemoteSource,
179        table_schema: SchemaRef,
180        projection: Option<&Vec<usize>>,
181        unparsed_filters: &[String],
182        limit: Option<usize>,
183    ) -> DFResult<SendableRecordBatchStream> {
184        let projected_schema = project_schema(&table_schema, projection)?;
185
186        let sql = RemoteDbType::Postgres.rewrite_query(source, unparsed_filters, limit);
187        debug!("[remote-table] executing postgres query: {sql}");
188
189        let projection = projection.cloned();
190        let chunk_size = conn_options.stream_chunk_size();
191        let stream = self
192            .conn
193            .query_raw(&sql, Vec::<String>::new())
194            .await
195            .map_err(|e| {
196                DataFusionError::Execution(format!(
197                    "Failed to execute query {sql} on postgres: {e}",
198                ))
199            })?
200            .chunks(chunk_size)
201            .boxed();
202
203        let stream = stream.map(move |rows| {
204            let rows: Vec<Row> = rows
205                .into_iter()
206                .collect::<Result<Vec<_>, _>>()
207                .map_err(|e| {
208                    DataFusionError::Execution(format!(
209                        "Failed to collect rows from postgres due to {e}",
210                    ))
211                })?;
212            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
213        });
214
215        Ok(Box::pin(RecordBatchStreamAdapter::new(
216            projected_schema,
217            stream,
218        )))
219    }
220
221    async fn insert(
222        &self,
223        _conn_options: &ConnectionOptions,
224        literalizer: Arc<dyn Literalize>,
225        table: &[String],
226        remote_schema: RemoteSchemaRef,
227        batch: RecordBatch,
228    ) -> DFResult<usize> {
229        let mut columns = Vec::with_capacity(remote_schema.fields.len());
230        for i in 0..batch.num_columns() {
231            let input_field = batch.schema_ref().field(i);
232            let remote_field = &remote_schema.fields[i];
233            if remote_field.auto_increment && input_field.is_nullable() {
234                continue;
235            }
236
237            let remote_type = remote_schema.fields[i].remote_type.clone();
238            let array = batch.column(i);
239            let column = literalize_array(literalizer.as_ref(), array, remote_type)?;
240            columns.push(column);
241        }
242
243        let num_rows = columns[0].len();
244        let num_columns = columns.len();
245
246        let mut values = Vec::with_capacity(num_rows);
247        for i in 0..num_rows {
248            let mut value = Vec::with_capacity(num_columns);
249            for col in columns.iter() {
250                value.push(col[i].as_str());
251            }
252            values.push(format!("({})", value.join(",")));
253        }
254
255        let mut col_names = Vec::with_capacity(remote_schema.fields.len());
256        for (remote_field, input_field) in remote_schema
257            .fields
258            .iter()
259            .zip(batch.schema_ref().fields.iter())
260        {
261            if remote_field.auto_increment && input_field.is_nullable() {
262                continue;
263            }
264            col_names.push(RemoteDbType::Postgres.sql_identifier(&remote_field.name));
265        }
266
267        let sql = format!(
268            "INSERT INTO {} ({}) VALUES {}",
269            RemoteDbType::Postgres.sql_table_name(table),
270            col_names.join(","),
271            values.join(",")
272        );
273
274        let count = self.conn.execute(&sql, &[]).await.map_err(|e| {
275            DataFusionError::Execution(format!(
276                "Failed to execute insert statement on postgres: {e:?}, sql: {sql}"
277            ))
278        })?;
279
280        Ok(count as usize)
281    }
282}
283
284async fn build_remote_schema_for_query(
285    stmt: Statement,
286    default_numeric_scale: i8,
287) -> DFResult<RemoteSchema> {
288    let mut remote_fields = Vec::new();
289    for col in stmt.columns().iter() {
290        let pg_type = col.type_();
291        let remote_type = pg_type_to_remote_type(pg_type, default_numeric_scale)?;
292        remote_fields.push(RemoteField::new(
293            col.name(),
294            RemoteType::Postgres(remote_type),
295            true,
296        ));
297    }
298    Ok(RemoteSchema::new(remote_fields))
299}
300
301fn pg_type_to_remote_type(pg_type: &Type, default_numeric_scale: i8) -> DFResult<PostgresType> {
302    match pg_type {
303        &Type::INT2 => Ok(PostgresType::Int2),
304        &Type::INT4 => Ok(PostgresType::Int4),
305        &Type::INT8 => Ok(PostgresType::Int8),
306        &Type::FLOAT4 => Ok(PostgresType::Float4),
307        &Type::FLOAT8 => Ok(PostgresType::Float8),
308        &Type::NUMERIC => Ok(PostgresType::Numeric(
309            DECIMAL256_MAX_PRECISION,
310            default_numeric_scale,
311        )),
312        &Type::OID => Ok(PostgresType::Oid),
313        &Type::NAME => Ok(PostgresType::Name),
314        &Type::VARCHAR => Ok(PostgresType::Varchar),
315        &Type::BPCHAR => Ok(PostgresType::Bpchar),
316        &Type::TEXT => Ok(PostgresType::Text),
317        &Type::BYTEA => Ok(PostgresType::Bytea),
318        &Type::DATE => Ok(PostgresType::Date),
319        &Type::TIMESTAMP => Ok(PostgresType::Timestamp),
320        &Type::TIMESTAMPTZ => Ok(PostgresType::TimestampTz),
321        &Type::TIME => Ok(PostgresType::Time),
322        &Type::INTERVAL => Ok(PostgresType::Interval),
323        &Type::BOOL => Ok(PostgresType::Bool),
324        &Type::JSON => Ok(PostgresType::Json),
325        &Type::JSONB => Ok(PostgresType::Jsonb),
326        &Type::INT2_ARRAY => Ok(PostgresType::Int2Array),
327        &Type::INT4_ARRAY => Ok(PostgresType::Int4Array),
328        &Type::INT8_ARRAY => Ok(PostgresType::Int8Array),
329        &Type::FLOAT4_ARRAY => Ok(PostgresType::Float4Array),
330        &Type::FLOAT8_ARRAY => Ok(PostgresType::Float8Array),
331        &Type::VARCHAR_ARRAY => Ok(PostgresType::VarcharArray),
332        &Type::BPCHAR_ARRAY => Ok(PostgresType::BpcharArray),
333        &Type::TEXT_ARRAY => Ok(PostgresType::TextArray),
334        &Type::BYTEA_ARRAY => Ok(PostgresType::ByteaArray),
335        &Type::BOOL_ARRAY => Ok(PostgresType::BoolArray),
336        &Type::XML => Ok(PostgresType::Xml),
337        &Type::UUID => Ok(PostgresType::Uuid),
338        other if other.name().eq_ignore_ascii_case("geometry") => Ok(PostgresType::PostGisGeometry),
339        _ => Err(DataFusionError::NotImplemented(format!(
340            "Unsupported postgres type {pg_type:?}",
341        ))),
342    }
343}
344
345fn build_remote_schema_for_table(
346    rows: Vec<Row>,
347    default_numeric_scale: i8,
348) -> DFResult<RemoteSchema> {
349    let mut remote_fields = vec![];
350    for row in rows {
351        let columa_name = row.try_get::<_, String>(0).map_err(|e| {
352            DataFusionError::Plan(format!("Failed to get col name from postgres row: {e:?}"))
353        })?;
354        let column_type = row.try_get::<_, String>(1).map_err(|e| {
355            DataFusionError::Plan(format!("Failed to get col type from postgres row: {e:?}"))
356        })?;
357        let numeric_precision = row.try_get::<_, Option<i32>>(2).map_err(|e| {
358            DataFusionError::Plan(format!(
359                "Failed to get numeric precision from postgres row: {e:?}"
360            ))
361        })?;
362        let numeric_scale = row.try_get::<_, Option<i32>>(3).map_err(|e| {
363            DataFusionError::Plan(format!(
364                "Failed to get numeric scale from postgres row: {e:?}"
365            ))
366        })?;
367        let pg_type = parse_pg_type(
368            &column_type,
369            numeric_precision,
370            numeric_scale.unwrap_or(default_numeric_scale as i32),
371        )?;
372        let is_nullable = row.try_get::<_, String>(4).map_err(|e| {
373            DataFusionError::Plan(format!(
374                "Failed to get is_nullable from postgres row: {e:?}"
375            ))
376        })?;
377        let nullable = match is_nullable.as_str() {
378            "YES" => true,
379            "NO" => false,
380            _ => {
381                return Err(DataFusionError::Plan(format!(
382                    "Unsupported postgres is_nullable value {is_nullable}"
383                )));
384            }
385        };
386        remote_fields.push(RemoteField::new(
387            columa_name,
388            RemoteType::Postgres(pg_type),
389            nullable,
390        ));
391    }
392    Ok(RemoteSchema::new(remote_fields))
393}
394
395fn parse_pg_type(
396    pg_type: &str,
397    numeric_precision: Option<i32>,
398    numeric_scale: i32,
399) -> DFResult<PostgresType> {
400    match pg_type {
401        "smallint" => Ok(PostgresType::Int2),
402        "integer" => Ok(PostgresType::Int4),
403        "bigint" => Ok(PostgresType::Int8),
404        "real" => Ok(PostgresType::Float4),
405        "double precision" => Ok(PostgresType::Float8),
406        "numeric" => Ok(PostgresType::Numeric(
407            numeric_precision.unwrap_or(DECIMAL256_MAX_PRECISION as i32) as u8,
408            numeric_scale as i8,
409        )),
410        "character varying" => Ok(PostgresType::Varchar),
411        "character" => Ok(PostgresType::Bpchar),
412        "text" => Ok(PostgresType::Text),
413        "bytea" => Ok(PostgresType::Bytea),
414        "date" => Ok(PostgresType::Date),
415        "time without time zone" => Ok(PostgresType::Time),
416        "timestamp without time zone" => Ok(PostgresType::Timestamp),
417        "timestamp with time zone" => Ok(PostgresType::TimestampTz),
418        "interval" => Ok(PostgresType::Interval),
419        "boolean" => Ok(PostgresType::Bool),
420        "json" => Ok(PostgresType::Json),
421        "jsonb" => Ok(PostgresType::Jsonb),
422        "public.geometry" => Ok(PostgresType::PostGisGeometry),
423        "ARRAY_int2" => Ok(PostgresType::Int2Array),
424        "ARRAY_int4" => Ok(PostgresType::Int4Array),
425        "ARRAY_int8" => Ok(PostgresType::Int8Array),
426        "ARRAY_float4" => Ok(PostgresType::Float4Array),
427        "ARRAY_float8" => Ok(PostgresType::Float8Array),
428        "ARRAY_varchar" => Ok(PostgresType::VarcharArray),
429        "ARRAY_bpchar" => Ok(PostgresType::BpcharArray),
430        "ARRAY_text" => Ok(PostgresType::TextArray),
431        "ARRAY_bytea" => Ok(PostgresType::ByteaArray),
432        "ARRAY_bool" => Ok(PostgresType::BoolArray),
433        "xml" => Ok(PostgresType::Xml),
434        "uuid" => Ok(PostgresType::Uuid),
435        "oid" => Ok(PostgresType::Oid),
436        "name" => Ok(PostgresType::Name),
437        _ => Err(DataFusionError::Execution(format!(
438            "Unsupported postgres type {pg_type}"
439        ))),
440    }
441}
442
/// Appends one scalar cell from `$row` at `$index` to `$builder`.
///
/// `$builder` is downcast to `$builder_ty`. The cell is read as `Option<$value_ty>`. A SQL
/// `NULL` appends a null; otherwise the value is passed through the fallible `$convert`
/// before being appended.
///
/// Panics if the downcast fails: the builder list is derived from the same schema, so a
/// mismatch is a programming error. A failed `try_get` or `$convert` is returned with `?`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            );
        };
        let cell: Option<$value_ty> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = cell {
            typed_builder.append_value($convert(value)?);
        } else {
            typed_builder.append_null();
        }
    }};
}
471
/// Appends one postgres array cell from `$row` at `$index` to a
/// `ListBuilder<Box<dyn ArrayBuilder>>` whose values builder is `$values_builder_ty`.
///
/// The cell is read as `Option<Vec<$primitive_value_ty>>`. A SQL `NULL` array appends a
/// null list entry; otherwise every element is appended to the values builder and the list
/// entry is closed as valid.
///
/// NOTE(review): because the element type is not `Option`, arrays containing NULL elements
/// presumably fail in `try_get` and surface as an `Execution` error — confirm.
///
/// Panics if either downcast fails (builder/schema mismatch is a programming error).
macro_rules! handle_primitive_array_type {
    ($builder:expr, $field:expr, $col:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for {:?} and {:?}",
                    $field, $col
                )
            });
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast values builder to {} for {:?} and {:?}",
                    // Must stringify this macro's own metavariable; an unbound `$name` is
                    // emitted literally instead of the type.
                    stringify!($values_builder_ty),
                    $field,
                    $col,
                )
            });
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} array value for {:?} and {:?}: {e:?}",
                stringify!($primitive_value_ty),
                $field,
                $col,
            ))
        })?;

        match v {
            Some(v) => {
                values_builder.extend(v.into_iter().map(Some));
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}
514
515#[derive(Debug)]
516struct BigDecimalFromSql {
517    inner: BigDecimal,
518}
519
520impl BigDecimalFromSql {
521    fn to_i128_with_scale(&self, scale: i32) -> DFResult<i128> {
522        big_decimal_to_i128(&self.inner, Some(scale))
523    }
524
525    fn to_i256_with_scale(&self, scale: i32) -> DFResult<i256> {
526        big_decimal_to_i256(&self.inner, Some(scale))
527    }
528}
529
530#[allow(clippy::cast_sign_loss)]
531#[allow(clippy::cast_possible_wrap)]
532#[allow(clippy::cast_possible_truncation)]
533impl<'a> FromSql<'a> for BigDecimalFromSql {
534    fn from_sql(
535        _ty: &Type,
536        raw: &'a [u8],
537    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
538        let raw_u16: Vec<u16> = raw
539            .chunks(2)
540            .map(|chunk| {
541                if chunk.len() == 2 {
542                    u16::from_be_bytes([chunk[0], chunk[1]])
543                } else {
544                    u16::from_be_bytes([chunk[0], 0])
545                }
546            })
547            .collect();
548
549        let base_10_000_digit_count = raw_u16[0];
550        let weight = raw_u16[1] as i16;
551        let sign = raw_u16[2];
552        let scale = raw_u16[3];
553
554        let mut base_10_000_digits = Vec::new();
555        for i in 4..4 + base_10_000_digit_count {
556            base_10_000_digits.push(raw_u16[i as usize]);
557        }
558
559        let mut u8_digits = Vec::new();
560        for &base_10_000_digit in base_10_000_digits.iter().rev() {
561            let mut base_10_000_digit = base_10_000_digit;
562            let mut temp_result = Vec::new();
563            while base_10_000_digit > 0 {
564                temp_result.push((base_10_000_digit % 10) as u8);
565                base_10_000_digit /= 10;
566            }
567            while temp_result.len() < 4 {
568                temp_result.push(0);
569            }
570            u8_digits.extend(temp_result);
571        }
572        u8_digits.reverse();
573
574        let value_scale = 4 * (i64::from(base_10_000_digit_count) - i64::from(weight) - 1);
575        let size = i64::try_from(u8_digits.len())? + i64::from(scale) - value_scale;
576        u8_digits.resize(size as usize, 0);
577
578        let sign = match sign {
579            0x4000 => Sign::Minus,
580            0x0000 => Sign::Plus,
581            _ => {
582                return Err(Box::new(DataFusionError::Execution(
583                    "Failed to parse big decimal from postgres numeric value".to_string(),
584                )));
585            }
586        };
587
588        let Some(digits) = BigInt::from_radix_be(sign, u8_digits.as_slice(), 10) else {
589            return Err(Box::new(DataFusionError::Execution(
590                "Failed to parse big decimal from postgres numeric value".to_string(),
591            )));
592        };
593        Ok(BigDecimalFromSql {
594            inner: BigDecimal::new(digits, i64::from(scale)),
595        })
596    }
597
598    fn accepts(ty: &Type) -> bool {
599        matches!(*ty, Type::NUMERIC)
600    }
601}
602
603// interval_send - Postgres C (https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L1032)
604// interval values are internally stored as three integral fields: months, days, and microseconds
605#[derive(Debug)]
606struct IntervalFromSql {
607    time: i64,
608    day: i32,
609    month: i32,
610}
611
612impl<'a> FromSql<'a> for IntervalFromSql {
613    fn from_sql(
614        _ty: &Type,
615        raw: &'a [u8],
616    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
617        let mut cursor = std::io::Cursor::new(raw);
618
619        let time = cursor.read_i64::<BigEndian>()?;
620        let day = cursor.read_i32::<BigEndian>()?;
621        let month = cursor.read_i32::<BigEndian>()?;
622
623        Ok(IntervalFromSql { time, day, month })
624    }
625
626    fn accepts(ty: &Type) -> bool {
627        matches!(*ty, Type::INTERVAL)
628    }
629}
630
631struct GeometryFromSql<'a> {
632    wkb: &'a [u8],
633}
634
635impl<'a> FromSql<'a> for GeometryFromSql<'a> {
636    fn from_sql(
637        _ty: &Type,
638        raw: &'a [u8],
639    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
640        Ok(GeometryFromSql { wkb: raw })
641    }
642
643    fn accepts(ty: &Type) -> bool {
644        matches!(ty.name(), "geometry")
645    }
646}
647
648struct XmlFromSql<'a> {
649    xml: &'a str,
650}
651
652impl<'a> FromSql<'a> for XmlFromSql<'a> {
653    fn from_sql(
654        _ty: &Type,
655        raw: &'a [u8],
656    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
657        let xml = str::from_utf8(raw)?;
658        Ok(XmlFromSql { xml })
659    }
660
661    fn accepts(ty: &Type) -> bool {
662        matches!(*ty, Type::XML)
663    }
664}
665
666fn rows_to_batch(
667    rows: &[Row],
668    table_schema: &SchemaRef,
669    projection: Option<&Vec<usize>>,
670) -> DFResult<RecordBatch> {
671    let projected_schema = project_schema(table_schema, projection)?;
672    let mut array_builders = vec![];
673    for field in table_schema.fields() {
674        let builder = make_builder(field.data_type(), rows.len());
675        array_builders.push(builder);
676    }
677
678    for row in rows {
679        for (idx, field) in table_schema.fields.iter().enumerate() {
680            if !projections_contains(projection, idx) {
681                continue;
682            }
683            let builder = &mut array_builders[idx];
684            let col = row.columns().get(idx);
685            match field.data_type() {
686                DataType::Int16 => {
687                    handle_primitive_type!(
688                        builder,
689                        field,
690                        col,
691                        Int16Builder,
692                        i16,
693                        row,
694                        idx,
695                        just_return
696                    );
697                }
698                DataType::Int32 => {
699                    handle_primitive_type!(
700                        builder,
701                        field,
702                        col,
703                        Int32Builder,
704                        i32,
705                        row,
706                        idx,
707                        just_return
708                    );
709                }
710                DataType::UInt32 => {
711                    handle_primitive_type!(
712                        builder,
713                        field,
714                        col,
715                        UInt32Builder,
716                        u32,
717                        row,
718                        idx,
719                        just_return
720                    );
721                }
722                DataType::Int64 => {
723                    handle_primitive_type!(
724                        builder,
725                        field,
726                        col,
727                        Int64Builder,
728                        i64,
729                        row,
730                        idx,
731                        just_return
732                    );
733                }
734                DataType::Float32 => {
735                    handle_primitive_type!(
736                        builder,
737                        field,
738                        col,
739                        Float32Builder,
740                        f32,
741                        row,
742                        idx,
743                        just_return
744                    );
745                }
746                DataType::Float64 => {
747                    handle_primitive_type!(
748                        builder,
749                        field,
750                        col,
751                        Float64Builder,
752                        f64,
753                        row,
754                        idx,
755                        just_return
756                    );
757                }
758                DataType::Decimal128(_precision, scale) => {
759                    handle_primitive_type!(
760                        builder,
761                        field,
762                        col,
763                        Decimal128Builder,
764                        BigDecimalFromSql,
765                        row,
766                        idx,
767                        |v: BigDecimalFromSql| { v.to_i128_with_scale(*scale as i32) }
768                    );
769                }
770                DataType::Decimal256(_precision, scale) => {
771                    handle_primitive_type!(
772                        builder,
773                        field,
774                        col,
775                        Decimal256Builder,
776                        BigDecimalFromSql,
777                        row,
778                        idx,
779                        |v: BigDecimalFromSql| { v.to_i256_with_scale(*scale as i32) }
780                    );
781                }
782                DataType::Utf8 => {
783                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("xml") {
784                        let convert: for<'a> fn(XmlFromSql<'a>) -> DFResult<&'a str> =
785                            |v| Ok(v.xml);
786                        handle_primitive_type!(
787                            builder,
788                            field,
789                            col,
790                            StringBuilder,
791                            XmlFromSql,
792                            row,
793                            idx,
794                            convert
795                        );
796                    } else {
797                        handle_primitive_type!(
798                            builder,
799                            field,
800                            col,
801                            StringBuilder,
802                            &str,
803                            row,
804                            idx,
805                            just_return
806                        );
807                    }
808                }
809                DataType::LargeUtf8 => {
810                    if col.is_some() && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB) {
811                        handle_primitive_type!(
812                            builder,
813                            field,
814                            col,
815                            LargeStringBuilder,
816                            serde_json::value::Value,
817                            row,
818                            idx,
819                            |v: serde_json::value::Value| {
820                                Ok::<_, DataFusionError>(v.to_string())
821                            }
822                        );
823                    } else {
824                        handle_primitive_type!(
825                            builder,
826                            field,
827                            col,
828                            LargeStringBuilder,
829                            &str,
830                            row,
831                            idx,
832                            just_return
833                        );
834                    }
835                }
836                DataType::Binary => {
837                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("geometry")
838                    {
839                        let convert: for<'a> fn(GeometryFromSql<'a>) -> DFResult<&'a [u8]> =
840                            |v| Ok(v.wkb);
841                        handle_primitive_type!(
842                            builder,
843                            field,
844                            col,
845                            BinaryBuilder,
846                            GeometryFromSql,
847                            row,
848                            idx,
849                            convert
850                        );
851                    } else if col.is_some()
852                        && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB)
853                    {
854                        handle_primitive_type!(
855                            builder,
856                            field,
857                            col,
858                            BinaryBuilder,
859                            serde_json::value::Value,
860                            row,
861                            idx,
862                            |v: serde_json::value::Value| {
863                                Ok::<_, DataFusionError>(v.to_string().into_bytes())
864                            }
865                        );
866                    } else {
867                        handle_primitive_type!(
868                            builder,
869                            field,
870                            col,
871                            BinaryBuilder,
872                            Vec<u8>,
873                            row,
874                            idx,
875                            just_return
876                        );
877                    }
878                }
879                DataType::FixedSizeBinary(_) => {
880                    let builder = builder
881                        .as_any_mut()
882                        .downcast_mut::<FixedSizeBinaryBuilder>()
883                        .unwrap_or_else(|| {
884                            panic!(
885                                "Failed to downcast builder to FixedSizeBinaryBuilder for {field:?}"
886                            )
887                        });
888                    let v = if col.is_some()
889                        && col.unwrap().type_().name().eq_ignore_ascii_case("uuid")
890                    {
891                        let v: Option<Uuid> = row.try_get(idx).map_err(|e| {
892                            DataFusionError::Execution(format!(
893                                "Failed to get Uuid value for field {:?}: {e:?}",
894                                field
895                            ))
896                        })?;
897                        v.map(|v| v.as_bytes().to_vec())
898                    } else {
899                        let v: Option<Vec<u8>> = row.try_get(idx).map_err(|e| {
900                            DataFusionError::Execution(format!(
901                                "Failed to get FixedSizeBinary value for field {:?}: {e:?}",
902                                field
903                            ))
904                        })?;
905                        v
906                    };
907
908                    match v {
909                        Some(v) => builder.append_value(v)?,
910                        None => builder.append_null(),
911                    }
912                }
913                DataType::Timestamp(TimeUnit::Microsecond, None) => {
914                    handle_primitive_type!(
915                        builder,
916                        field,
917                        col,
918                        TimestampMicrosecondBuilder,
919                        chrono::NaiveDateTime,
920                        row,
921                        idx,
922                        |v: chrono::NaiveDateTime| {
923                            let timestamp: i64 = v.and_utc().timestamp_micros();
924
925                            Ok::<i64, DataFusionError>(timestamp)
926                        }
927                    );
928                }
929                DataType::Timestamp(TimeUnit::Microsecond, Some(_tz)) => {
930                    handle_primitive_type!(
931                        builder,
932                        field,
933                        col,
934                        TimestampMicrosecondBuilder,
935                        chrono::DateTime<chrono::Utc>,
936                        row,
937                        idx,
938                        |v: chrono::DateTime<chrono::Utc>| {
939                            let timestamp: i64 = v.timestamp_micros();
940                            Ok::<_, DataFusionError>(timestamp)
941                        }
942                    );
943                }
944                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
945                    handle_primitive_type!(
946                        builder,
947                        field,
948                        col,
949                        TimestampNanosecondBuilder,
950                        chrono::NaiveDateTime,
951                        row,
952                        idx,
953                        |v: chrono::NaiveDateTime| {
954                            let timestamp: i64 = v.and_utc().timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"));
955                            Ok::<i64, DataFusionError>(timestamp)
956                        }
957                    );
958                }
959                DataType::Timestamp(TimeUnit::Nanosecond, Some(_tz)) => {
960                    handle_primitive_type!(
961                        builder,
962                        field,
963                        col,
964                        TimestampNanosecondBuilder,
965                        chrono::DateTime<chrono::Utc>,
966                        row,
967                        idx,
968                        |v: chrono::DateTime<chrono::Utc>| {
969                            let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"));
970                            Ok::<_, DataFusionError>(timestamp)
971                        }
972                    );
973                }
974                DataType::Time64(TimeUnit::Microsecond) => {
975                    handle_primitive_type!(
976                        builder,
977                        field,
978                        col,
979                        Time64MicrosecondBuilder,
980                        chrono::NaiveTime,
981                        row,
982                        idx,
983                        |v: chrono::NaiveTime| {
984                            let seconds = i64::from(v.num_seconds_from_midnight());
985                            let microseconds = i64::from(v.nanosecond()) / 1000;
986                            Ok::<_, DataFusionError>(seconds * 1_000_000 + microseconds)
987                        }
988                    );
989                }
990                DataType::Time64(TimeUnit::Nanosecond) => {
991                    handle_primitive_type!(
992                        builder,
993                        field,
994                        col,
995                        Time64NanosecondBuilder,
996                        chrono::NaiveTime,
997                        row,
998                        idx,
999                        |v: chrono::NaiveTime| {
1000                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
1001                                * 1_000_000_000
1002                                + i64::from(v.nanosecond());
1003                            Ok::<_, DataFusionError>(timestamp)
1004                        }
1005                    );
1006                }
1007                DataType::Date32 => {
1008                    handle_primitive_type!(
1009                        builder,
1010                        field,
1011                        col,
1012                        Date32Builder,
1013                        chrono::NaiveDate,
1014                        row,
1015                        idx,
1016                        |v| { Ok::<_, DataFusionError>(Date32Type::from_naive_date(v)) }
1017                    );
1018                }
1019                DataType::Interval(IntervalUnit::MonthDayNano) => {
1020                    handle_primitive_type!(
1021                        builder,
1022                        field,
1023                        col,
1024                        IntervalMonthDayNanoBuilder,
1025                        IntervalFromSql,
1026                        row,
1027                        idx,
1028                        |v: IntervalFromSql| {
1029                            let interval_month_day_nano = IntervalMonthDayNanoType::make_value(
1030                                v.month,
1031                                v.day,
1032                                v.time * 1_000,
1033                            );
1034                            Ok::<_, DataFusionError>(interval_month_day_nano)
1035                        }
1036                    );
1037                }
1038                DataType::Boolean => {
1039                    handle_primitive_type!(
1040                        builder,
1041                        field,
1042                        col,
1043                        BooleanBuilder,
1044                        bool,
1045                        row,
1046                        idx,
1047                        just_return
1048                    );
1049                }
1050                DataType::List(inner) => match inner.data_type() {
1051                    DataType::Int16 => {
1052                        handle_primitive_array_type!(
1053                            builder,
1054                            field,
1055                            col,
1056                            Int16Builder,
1057                            i16,
1058                            row,
1059                            idx
1060                        );
1061                    }
1062                    DataType::Int32 => {
1063                        handle_primitive_array_type!(
1064                            builder,
1065                            field,
1066                            col,
1067                            Int32Builder,
1068                            i32,
1069                            row,
1070                            idx
1071                        );
1072                    }
1073                    DataType::Int64 => {
1074                        handle_primitive_array_type!(
1075                            builder,
1076                            field,
1077                            col,
1078                            Int64Builder,
1079                            i64,
1080                            row,
1081                            idx
1082                        );
1083                    }
1084                    DataType::Float32 => {
1085                        handle_primitive_array_type!(
1086                            builder,
1087                            field,
1088                            col,
1089                            Float32Builder,
1090                            f32,
1091                            row,
1092                            idx
1093                        );
1094                    }
1095                    DataType::Float64 => {
1096                        handle_primitive_array_type!(
1097                            builder,
1098                            field,
1099                            col,
1100                            Float64Builder,
1101                            f64,
1102                            row,
1103                            idx
1104                        );
1105                    }
1106                    DataType::Utf8 => {
1107                        handle_primitive_array_type!(
1108                            builder,
1109                            field,
1110                            col,
1111                            StringBuilder,
1112                            &str,
1113                            row,
1114                            idx
1115                        );
1116                    }
1117                    DataType::Binary => {
1118                        handle_primitive_array_type!(
1119                            builder,
1120                            field,
1121                            col,
1122                            BinaryBuilder,
1123                            Vec<u8>,
1124                            row,
1125                            idx
1126                        );
1127                    }
1128                    DataType::Boolean => {
1129                        handle_primitive_array_type!(
1130                            builder,
1131                            field,
1132                            col,
1133                            BooleanBuilder,
1134                            bool,
1135                            row,
1136                            idx
1137                        );
1138                    }
1139                    _ => {
1140                        return Err(DataFusionError::NotImplemented(format!(
1141                            "Unsupported list data type {} for col: {:?}",
1142                            field.data_type(),
1143                            col
1144                        )));
1145                    }
1146                },
1147                _ => {
1148                    return Err(DataFusionError::NotImplemented(format!(
1149                        "Unsupported data type {} for col: {:?}",
1150                        field.data_type(),
1151                        col
1152                    )));
1153                }
1154            }
1155        }
1156    }
1157    let projected_columns = array_builders
1158        .into_iter()
1159        .enumerate()
1160        .filter(|(idx, _)| projections_contains(projection, *idx))
1161        .map(|(_, mut builder)| builder.finish())
1162        .collect::<Vec<ArrayRef>>();
1163    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
1164    Ok(RecordBatch::try_new_with_options(
1165        projected_schema,
1166        projected_columns,
1167        &options,
1168    )?)
1169}