datafusion_remote_table/connection/postgres.rs

use crate::connection::{RemoteDbType, just_return, projections_contains};
use crate::utils::{big_decimal_to_i128, big_decimal_to_i256};
use crate::{
    Connection, ConnectionOptions, DFResult, Pool, PostgresType, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteSource, RemoteType, Unparse, unparse_array,
};
use bb8_postgres::PostgresConnectionManager;
use bb8_postgres::tokio_postgres::types::{FromSql, Type};
use bb8_postgres::tokio_postgres::{NoTls, Row, Statement};
use bigdecimal::BigDecimal;
use byteorder::{BigEndian, ReadBytesExt};
use chrono::Timelike;
use datafusion::arrow::array::{
    ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
    Decimal256Builder, FixedSizeBinaryBuilder, Float32Builder, Float64Builder, Int16Builder,
    Int32Builder, Int64Builder, IntervalMonthDayNanoBuilder, LargeStringBuilder, ListBuilder,
    RecordBatch, RecordBatchOptions, StringBuilder, Time64MicrosecondBuilder,
    Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampNanosecondBuilder,
    UInt32Builder, make_builder,
};
use datafusion::arrow::datatypes::{
    DECIMAL256_MAX_PRECISION, DataType, Date32Type, IntervalMonthDayNanoType, IntervalUnit,
    SchemaRef, TimeUnit, i256,
};

use datafusion::common::project_schema;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use futures::StreamExt;
use log::debug;
use num_bigint::{BigInt, Sign};
use std::any::Any;
use std::string::ToString;
use std::sync::Arc;
use uuid::Uuid;

#[derive(Debug, Clone, With, Getters)]
pub struct PostgresConnectionOptions {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) username: String,
    pub(crate) password: String,
    pub(crate) database: Option<String>,
    pub(crate) pool_max_size: usize,
    pub(crate) stream_chunk_size: usize,
    pub(crate) default_numeric_scale: i8,
}

impl PostgresConnectionOptions {
    pub fn new(
        host: impl Into<String>,
        port: u16,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            port,
            username: username.into(),
            password: password.into(),
            database: None,
            pool_max_size: 10,
            stream_chunk_size: 2048,
            default_numeric_scale: 10,
        }
    }
}

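// A minimal construction sketch (host and credentials are placeholder values).
// The `With` derive is assumed to generate builder-style `with_*` setters for
// each field, and the `From` impl below lets the options slot into the generic
// `ConnectionOptions` enum:
//
//     let options: ConnectionOptions =
//         PostgresConnectionOptions::new("localhost", 5432, "postgres", "postgres")
//             .with_database(Some("mydb".to_string()))
//             .with_stream_chunk_size(4096)
//             .into();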
impl From<PostgresConnectionOptions> for ConnectionOptions {
    fn from(options: PostgresConnectionOptions) -> Self {
        ConnectionOptions::Postgres(options)
    }
}

#[derive(Debug)]
pub struct PostgresPool {
    pool: bb8::Pool<PostgresConnectionManager<NoTls>>,
    options: Arc<PostgresConnectionOptions>,
}

#[async_trait::async_trait]
impl Pool for PostgresPool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.get_owned().await.map_err(|e| {
            DataFusionError::Execution(format!("Failed to get postgres connection due to {e:?}"))
        })?;
        Ok(Arc::new(PostgresConnection {
            conn,
            options: self.options.clone(),
        }))
    }
}

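// Builds a bb8 connection pool for the given options; TLS is not used (`NoTls`)
// and the pool size is capped at `pool_max_size`. Rough call sketch, assuming a
// tokio runtime is already running:
//
//     let pool = connect_postgres(&options).await?;
//     let conn = pool.get().await?;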
pub(crate) async fn connect_postgres(
    options: &PostgresConnectionOptions,
) -> DFResult<PostgresPool> {
    let mut config = bb8_postgres::tokio_postgres::config::Config::new();
    config
        .host(&options.host)
        .port(options.port)
        .user(&options.username)
        .password(&options.password);
    if let Some(database) = &options.database {
        config.dbname(database);
    }
    let manager = PostgresConnectionManager::new(config, NoTls);
    let pool = bb8::Pool::builder()
        .max_size(options.pool_max_size as u32)
        .build(manager)
        .await
        .map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to create postgres connection pool due to {e}",
            ))
        })?;

    Ok(PostgresPool {
        pool,
        options: Arc::new(options.clone()),
    })
}

#[derive(Debug)]
pub(crate) struct PostgresConnection {
    conn: bb8::PooledConnection<'static, PostgresConnectionManager<NoTls>>,
    options: Arc<PostgresConnectionOptions>,
}

#[async_trait::async_trait]
impl Connection for PostgresConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }

    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
        match source {
            RemoteSource::Table(table) => {
                let db_type = RemoteDbType::Postgres;
                let where_condition = if table.len() == 1 {
                    format!("table_name = {}", db_type.sql_string_literal(&table[0]))
                } else if table.len() == 2 {
                    format!(
                        "table_schema = {} AND table_name = {}",
                        db_type.sql_string_literal(&table[0]),
                        db_type.sql_string_literal(&table[1])
                    )
                } else {
                    format!(
                        "table_catalog = {} AND table_schema = {} AND table_name = {}",
                        db_type.sql_string_literal(&table[0]),
                        db_type.sql_string_literal(&table[1]),
                        db_type.sql_string_literal(&table[2])
                    )
                };
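                // The query below synthesizes a single `column_type` string from
                // information_schema: plain types keep `data_type`, arrays become
                // `data_type || udt_name` (e.g. "ARRAY_int4"), and user-defined
                // types become `udt_schema || '.' || udt_name` (e.g.
                // "public.geometry"). These strings are decoded by `parse_pg_type`.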
                let sql = format!(
                    "
select
    column_name,
    case
        when data_type = 'ARRAY'
            then data_type || udt_name
        when data_type = 'USER-DEFINED'
            then udt_schema || '.' || udt_name
        else
            data_type
    end as column_type,
    numeric_precision,
    numeric_scale,
    is_nullable
from information_schema.columns
where {}
order by ordinal_position",
                    where_condition
                );
                let rows = self.conn.query(&sql, &[]).await.map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to execute query {sql} on postgres: {e:?}",
                    ))
                })?;
                let remote_schema = Arc::new(build_remote_schema_for_table(
                    rows,
                    self.options.default_numeric_scale,
                )?);
                Ok(remote_schema)
            }
            RemoteSource::Query(_query) => {
                let sql = source.query(RemoteDbType::Postgres);
                let stmt = self.conn.prepare(&sql).await.map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to prepare query {sql} on postgres: {e:?}",
                    ))
                })?;
                let remote_schema = Arc::new(
                    build_remote_schema_for_query(stmt, self.options.default_numeric_scale).await?,
                );
                Ok(remote_schema)
            }
        }
    }

    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;

        let sql = RemoteDbType::Postgres.rewrite_query(source, unparsed_filters, limit);
        debug!("[remote-table] executing postgres query: {sql}");

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();
        let stream = self
            .conn
            .query_raw(&sql, Vec::<String>::new())
            .await
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute query {sql} on postgres: {e}",
                ))
            })?
            .chunks(chunk_size)
            .boxed();

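        // Each chunk of `chunk_size` rows is materialized into one Arrow
        // RecordBatch; any per-row error from the wire aborts the stream.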
        let stream = stream.map(move |rows| {
            let rows: Vec<Row> = rows
                .into_iter()
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to collect rows from postgres due to {e}",
                    ))
                })?;
            rows_to_batch(rows.as_slice(), &table_schema, projection.as_ref())
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }

    async fn insert(
        &self,
        _conn_options: &ConnectionOptions,
        unparser: Arc<dyn Unparse>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        mut input: SendableRecordBatchStream,
    ) -> DFResult<usize> {
        let input_schema = input.schema();

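        // For every input batch: render each column to SQL literals via the
        // unparser, skip auto-increment columns whose input side is nullable,
        // and issue one multi-row INSERT statement per batch.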
        let mut total_count = 0;
        while let Some(batch) = input.next().await {
            let batch = batch?;

            let mut columns = Vec::with_capacity(remote_schema.fields.len());
            for i in 0..batch.num_columns() {
                let input_field = input_schema.field(i);
                let remote_field = &remote_schema.fields[i];
                if remote_field.auto_increment && input_field.is_nullable() {
                    continue;
                }

                let remote_type = remote_schema.fields[i].remote_type.clone();
                let array = batch.column(i);
                let column = unparse_array(unparser.as_ref(), array, remote_type)?;
                columns.push(column);
            }

            let num_rows = columns[0].len();
            let num_columns = columns.len();

            let mut values = Vec::with_capacity(num_rows);
            for i in 0..num_rows {
                let mut value = Vec::with_capacity(num_columns);
                for col in columns.iter() {
                    value.push(col[i].as_str());
                }
                values.push(format!("({})", value.join(",")));
            }

            let mut col_names = Vec::with_capacity(remote_schema.fields.len());
            for (remote_field, input_field) in
                remote_schema.fields.iter().zip(input_schema.fields.iter())
            {
                if remote_field.auto_increment && input_field.is_nullable() {
                    continue;
                }
                col_names.push(RemoteDbType::Postgres.sql_identifier(&remote_field.name));
            }

            let sql = format!(
                "INSERT INTO {} ({}) VALUES {}",
                RemoteDbType::Postgres.sql_table_name(table),
                col_names.join(","),
                values.join(",")
            );

            let count = self.conn.execute(&sql, &[]).await.map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute insert statement on postgres: {e:?}, sql: {sql}"
                ))
            })?;
            total_count += count as usize;
        }

        Ok(total_count)
    }
}

async fn build_remote_schema_for_query(
    stmt: Statement,
    default_numeric_scale: i8,
) -> DFResult<RemoteSchema> {
    let mut remote_fields = Vec::new();
    for col in stmt.columns().iter() {
        let pg_type = col.type_();
        let remote_type = pg_type_to_remote_type(pg_type, default_numeric_scale)?;
        remote_fields.push(RemoteField::new(
            col.name(),
            RemoteType::Postgres(remote_type),
            true,
        ));
    }
    Ok(RemoteSchema::new(remote_fields))
}

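// Maps a tokio_postgres `Type` (taken from prepared-statement column metadata)
// to the crate's PostgresType. The statement metadata does not carry a numeric
// typmod, so NUMERIC columns fall back to DECIMAL256_MAX_PRECISION and the
// configured default scale.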
fn pg_type_to_remote_type(pg_type: &Type, default_numeric_scale: i8) -> DFResult<PostgresType> {
    match pg_type {
        &Type::INT2 => Ok(PostgresType::Int2),
        &Type::INT4 => Ok(PostgresType::Int4),
        &Type::INT8 => Ok(PostgresType::Int8),
        &Type::FLOAT4 => Ok(PostgresType::Float4),
        &Type::FLOAT8 => Ok(PostgresType::Float8),
        &Type::NUMERIC => Ok(PostgresType::Numeric(
            DECIMAL256_MAX_PRECISION,
            default_numeric_scale,
        )),
        &Type::OID => Ok(PostgresType::Oid),
        &Type::NAME => Ok(PostgresType::Name),
        &Type::VARCHAR => Ok(PostgresType::Varchar),
        &Type::BPCHAR => Ok(PostgresType::Bpchar),
        &Type::TEXT => Ok(PostgresType::Text),
        &Type::BYTEA => Ok(PostgresType::Bytea),
        &Type::DATE => Ok(PostgresType::Date),
        &Type::TIMESTAMP => Ok(PostgresType::Timestamp),
        &Type::TIMESTAMPTZ => Ok(PostgresType::TimestampTz),
        &Type::TIME => Ok(PostgresType::Time),
        &Type::INTERVAL => Ok(PostgresType::Interval),
        &Type::BOOL => Ok(PostgresType::Bool),
        &Type::JSON => Ok(PostgresType::Json),
        &Type::JSONB => Ok(PostgresType::Jsonb),
        &Type::INT2_ARRAY => Ok(PostgresType::Int2Array),
        &Type::INT4_ARRAY => Ok(PostgresType::Int4Array),
        &Type::INT8_ARRAY => Ok(PostgresType::Int8Array),
        &Type::FLOAT4_ARRAY => Ok(PostgresType::Float4Array),
        &Type::FLOAT8_ARRAY => Ok(PostgresType::Float8Array),
        &Type::VARCHAR_ARRAY => Ok(PostgresType::VarcharArray),
        &Type::BPCHAR_ARRAY => Ok(PostgresType::BpcharArray),
        &Type::TEXT_ARRAY => Ok(PostgresType::TextArray),
        &Type::BYTEA_ARRAY => Ok(PostgresType::ByteaArray),
        &Type::BOOL_ARRAY => Ok(PostgresType::BoolArray),
        &Type::XML => Ok(PostgresType::Xml),
        &Type::UUID => Ok(PostgresType::Uuid),
        other if other.name().eq_ignore_ascii_case("geometry") => Ok(PostgresType::PostGisGeometry),
        _ => Err(DataFusionError::NotImplemented(format!(
            "Unsupported postgres type {pg_type:?}",
        ))),
    }
}

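// Builds a RemoteSchema from the information_schema rows fetched in
// `infer_schema`; the expected column order is column_name, column_type,
// numeric_precision, numeric_scale, is_nullable.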
fn build_remote_schema_for_table(
    rows: Vec<Row>,
    default_numeric_scale: i8,
) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    for row in rows {
        let column_name = row.try_get::<_, String>(0).map_err(|e| {
            DataFusionError::Execution(format!("Failed to get col name from postgres row: {e:?}"))
        })?;
        let column_type = row.try_get::<_, String>(1).map_err(|e| {
            DataFusionError::Execution(format!("Failed to get col type from postgres row: {e:?}"))
        })?;
        let numeric_precision = row.try_get::<_, Option<i32>>(2).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get numeric precision from postgres row: {e:?}"
            ))
        })?;
        let numeric_scale = row.try_get::<_, Option<i32>>(3).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get numeric scale from postgres row: {e:?}"
            ))
        })?;
        let pg_type = parse_pg_type(
            &column_type,
            numeric_precision,
            numeric_scale.unwrap_or(default_numeric_scale as i32),
        )?;
        let is_nullable = row.try_get::<_, String>(4).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get is_nullable from postgres row: {e:?}"
            ))
        })?;
        let nullable = match is_nullable.as_str() {
            "YES" => true,
            "NO" => false,
            _ => {
                return Err(DataFusionError::Execution(format!(
                    "Unsupported postgres is_nullable value {is_nullable}"
                )));
            }
        };
        remote_fields.push(RemoteField::new(
            column_name,
            RemoteType::Postgres(pg_type),
            nullable,
        ));
    }
    Ok(RemoteSchema::new(remote_fields))
}

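// Decodes the `column_type` strings synthesized by the information_schema query
// in `infer_schema`: plain data_type names, `data_type || udt_name` values such
// as "ARRAY_int4", and schema-qualified names such as "public.geometry".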
fn parse_pg_type(
    pg_type: &str,
    numeric_precision: Option<i32>,
    numeric_scale: i32,
) -> DFResult<PostgresType> {
    match pg_type {
        "smallint" => Ok(PostgresType::Int2),
        "integer" => Ok(PostgresType::Int4),
        "bigint" => Ok(PostgresType::Int8),
        "real" => Ok(PostgresType::Float4),
        "double precision" => Ok(PostgresType::Float8),
        "numeric" => Ok(PostgresType::Numeric(
            numeric_precision.unwrap_or(DECIMAL256_MAX_PRECISION as i32) as u8,
            numeric_scale as i8,
        )),
        "character varying" => Ok(PostgresType::Varchar),
        "character" => Ok(PostgresType::Bpchar),
        "text" => Ok(PostgresType::Text),
        "bytea" => Ok(PostgresType::Bytea),
        "date" => Ok(PostgresType::Date),
        "time without time zone" => Ok(PostgresType::Time),
        "timestamp without time zone" => Ok(PostgresType::Timestamp),
        "timestamp with time zone" => Ok(PostgresType::TimestampTz),
        "interval" => Ok(PostgresType::Interval),
        "boolean" => Ok(PostgresType::Bool),
        "json" => Ok(PostgresType::Json),
        "jsonb" => Ok(PostgresType::Jsonb),
        "public.geometry" => Ok(PostgresType::PostGisGeometry),
        "ARRAY_int2" => Ok(PostgresType::Int2Array),
        "ARRAY_int4" => Ok(PostgresType::Int4Array),
        "ARRAY_int8" => Ok(PostgresType::Int8Array),
        "ARRAY_float4" => Ok(PostgresType::Float4Array),
        "ARRAY_float8" => Ok(PostgresType::Float8Array),
        "ARRAY_varchar" => Ok(PostgresType::VarcharArray),
        "ARRAY_bpchar" => Ok(PostgresType::BpcharArray),
        "ARRAY_text" => Ok(PostgresType::TextArray),
        "ARRAY_bytea" => Ok(PostgresType::ByteaArray),
        "ARRAY_bool" => Ok(PostgresType::BoolArray),
        "xml" => Ok(PostgresType::Xml),
        "uuid" => Ok(PostgresType::Uuid),
        "oid" => Ok(PostgresType::Oid),
        "name" => Ok(PostgresType::Name),
        _ => Err(DataFusionError::Execution(format!(
            "Unsupported postgres type {pg_type}"
        ))),
    }
}

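// Downcasts `$builder` to `$builder_ty`, reads column `$index` of `$row` as an
// Option<$value_ty>, pipes a present value through `$convert` (which may fail
// with a DataFusionError), and appends the result or a null to the builder.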
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr, $convert:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });
        let v: Option<$value_ty> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value($convert(v)?),
            None => builder.append_null(),
        }
    }};
}

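// List-column counterpart of `handle_primitive_type!`: downcasts the outer
// builder to ListBuilder<Box<dyn ArrayBuilder>>, downcasts its values builder
// to `$values_builder_ty`, reads the cell as Option<Vec<$primitive_value_ty>>,
// and appends either the element values plus a list entry or a null list.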
macro_rules! handle_primitive_array_type {
    ($builder:expr, $field:expr, $col:expr, $values_builder_ty:ty, $primitive_value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to ListBuilder<Box<dyn ArrayBuilder>> for {:?} and {:?}",
                    $field, $col
                )
            });
        let values_builder = builder
            .values()
            .as_any_mut()
            .downcast_mut::<$values_builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast values builder to {} for {:?} and {:?}",
                    stringify!($values_builder_ty),
                    $field,
                    $col,
                )
            });
        let v: Option<Vec<$primitive_value_ty>> = $row.try_get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get {} array value for {:?} and {:?}: {e:?}",
                stringify!($primitive_value_ty),
                $field,
                $col,
            ))
        })?;

        match v {
            Some(v) => {
                let v = v.into_iter().map(Some);
                values_builder.extend(v);
                builder.append(true);
            }
            None => builder.append_null(),
        }
    }};
}

#[derive(Debug)]
struct BigDecimalFromSql {
    inner: BigDecimal,
}

impl BigDecimalFromSql {
    fn to_i128_with_scale(&self, scale: i32) -> DFResult<i128> {
        big_decimal_to_i128(&self.inner, Some(scale))
    }

    fn to_i256_with_scale(&self, scale: i32) -> DFResult<i256> {
        big_decimal_to_i256(&self.inner, Some(scale))
    }
}

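// Decodes the binary (send/recv) representation of a Postgres NUMERIC: an i16
// count of base-10000 digits, an i16 weight, a u16 sign flag (0x0000 plus,
// 0x4000 minus, 0xC000 NaN), a u16 display scale, then the base-10000 digits
// themselves. A NaN sign flag falls through to the error branch below, since
// Arrow decimals cannot represent it.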
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
impl<'a> FromSql<'a> for BigDecimalFromSql {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let raw_u16: Vec<u16> = raw
            .chunks(2)
            .map(|chunk| {
                if chunk.len() == 2 {
                    u16::from_be_bytes([chunk[0], chunk[1]])
                } else {
                    u16::from_be_bytes([chunk[0], 0])
                }
            })
            .collect();

        let base_10_000_digit_count = raw_u16[0];
        let weight = raw_u16[1] as i16;
        let sign = raw_u16[2];
        let scale = raw_u16[3];

        let mut base_10_000_digits = Vec::new();
        for i in 4..4 + base_10_000_digit_count {
            base_10_000_digits.push(raw_u16[i as usize]);
        }

        let mut u8_digits = Vec::new();
        for &base_10_000_digit in base_10_000_digits.iter().rev() {
            let mut base_10_000_digit = base_10_000_digit;
            let mut temp_result = Vec::new();
            while base_10_000_digit > 0 {
                temp_result.push((base_10_000_digit % 10) as u8);
                base_10_000_digit /= 10;
            }
            while temp_result.len() < 4 {
                temp_result.push(0);
            }
            u8_digits.extend(temp_result);
        }
        u8_digits.reverse();

        let value_scale = 4 * (i64::from(base_10_000_digit_count) - i64::from(weight) - 1);
        let size = i64::try_from(u8_digits.len())? + i64::from(scale) - value_scale;
        u8_digits.resize(size as usize, 0);

        let sign = match sign {
            0x4000 => Sign::Minus,
            0x0000 => Sign::Plus,
            _ => {
                return Err(Box::new(DataFusionError::Execution(
                    "Failed to parse big decimal from postgres numeric value".to_string(),
                )));
            }
        };

        let Some(digits) = BigInt::from_radix_be(sign, u8_digits.as_slice(), 10) else {
            return Err(Box::new(DataFusionError::Execution(
                "Failed to parse big decimal from postgres numeric value".to_string(),
            )));
        };
        Ok(BigDecimalFromSql {
            inner: BigDecimal::new(digits, i64::from(scale)),
        })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }
}

// interval_send - Postgres C (https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L1032)
// interval values are internally stored as three integral fields: months, days, and microseconds
#[derive(Debug)]
struct IntervalFromSql {
    time: i64,
    day: i32,
    month: i32,
}

impl<'a> FromSql<'a> for IntervalFromSql {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let mut cursor = std::io::Cursor::new(raw);

        let time = cursor.read_i64::<BigEndian>()?;
        let day = cursor.read_i32::<BigEndian>()?;
        let month = cursor.read_i32::<BigEndian>()?;

        Ok(IntervalFromSql { time, day, month })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::INTERVAL)
    }
}

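// Thin FromSql wrappers for values tokio_postgres has no built-in Rust mapping
// for: PostGIS geometry is kept as its raw binary payload (WKB/EWKB bytes) and
// xml as its UTF-8 text.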
struct GeometryFromSql<'a> {
    wkb: &'a [u8],
}

impl<'a> FromSql<'a> for GeometryFromSql<'a> {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        Ok(GeometryFromSql { wkb: raw })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(ty.name(), "geometry")
    }
}

struct XmlFromSql<'a> {
    xml: &'a str,
}

impl<'a> FromSql<'a> for XmlFromSql<'a> {
    fn from_sql(
        _ty: &Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let xml = std::str::from_utf8(raw)?;
        Ok(XmlFromSql { xml })
    }

    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::XML)
    }
}

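// Converts a chunk of postgres rows into a single RecordBatch: one Arrow
// builder is allocated per column of the full table schema, but only columns
// present in the projection are decoded and emitted. The explicit row count in
// RecordBatchOptions keeps batches with an empty projection (e.g. COUNT(*))
// correct.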
fn rows_to_batch(
    rows: &[Row],
    table_schema: &SchemaRef,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), rows.len());
        array_builders.push(builder);
    }

    for row in rows {
        for (idx, field) in table_schema.fields.iter().enumerate() {
            if !projections_contains(projection, idx) {
                continue;
            }
            let builder = &mut array_builders[idx];
            let col = row.columns().get(idx);
            match field.data_type() {
                DataType::Int16 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int16Builder,
                        i16,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int32Builder,
                        i32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::UInt32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        UInt32Builder,
                        u32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Int64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Int64Builder,
                        i64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float32Builder,
                        f32,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Float64 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Float64Builder,
                        f64,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::Decimal128(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal128Builder,
                        BigDecimalFromSql,
                        row,
                        idx,
                        |v: BigDecimalFromSql| { v.to_i128_with_scale(*scale as i32) }
                    );
                }
                DataType::Decimal256(_precision, scale) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Decimal256Builder,
                        BigDecimalFromSql,
                        row,
                        idx,
                        |v: BigDecimalFromSql| { v.to_i256_with_scale(*scale as i32) }
                    );
                }
                DataType::Utf8 => {
                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("xml") {
                        let convert: for<'a> fn(XmlFromSql<'a>) -> DFResult<&'a str> =
                            |v| Ok(v.xml);
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            XmlFromSql,
                            row,
                            idx,
                            convert
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            &str,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::LargeUtf8 => {
                    if col.is_some() && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB) {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            LargeStringBuilder,
                            serde_json::value::Value,
                            row,
                            idx,
                            |v: serde_json::value::Value| {
                                Ok::<_, DataFusionError>(v.to_string())
                            }
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            LargeStringBuilder,
                            &str,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::Binary => {
                    if col.is_some() && col.unwrap().type_().name().eq_ignore_ascii_case("geometry")
                    {
                        let convert: for<'a> fn(GeometryFromSql<'a>) -> DFResult<&'a [u8]> =
                            |v| Ok(v.wkb);
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            GeometryFromSql,
                            row,
                            idx,
                            convert
                        );
                    } else if col.is_some()
                        && matches!(col.unwrap().type_(), &Type::JSON | &Type::JSONB)
                    {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            serde_json::value::Value,
                            row,
                            idx,
                            |v: serde_json::value::Value| {
                                Ok::<_, DataFusionError>(v.to_string().into_bytes())
                            }
                        );
                    } else {
                        handle_primitive_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            Vec<u8>,
                            row,
                            idx,
                            just_return
                        );
                    }
                }
                DataType::FixedSizeBinary(_) => {
                    let builder = builder
                        .as_any_mut()
                        .downcast_mut::<FixedSizeBinaryBuilder>()
                        .unwrap_or_else(|| {
                            panic!(
                                "Failed to downcast builder to FixedSizeBinaryBuilder for {field:?}"
                            )
                        });
                    let v = if col.is_some()
                        && col.unwrap().type_().name().eq_ignore_ascii_case("uuid")
                    {
                        let v: Option<Uuid> = row.try_get(idx).map_err(|e| {
                            DataFusionError::Execution(format!(
                                "Failed to get Uuid value for field {:?}: {e:?}",
                                field
                            ))
                        })?;
                        v.map(|v| v.as_bytes().to_vec())
                    } else {
                        let v: Option<Vec<u8>> = row.try_get(idx).map_err(|e| {
                            DataFusionError::Execution(format!(
                                "Failed to get FixedSizeBinary value for field {:?}: {e:?}",
                                field
                            ))
                        })?;
                        v
                    };

                    match v {
                        Some(v) => builder.append_value(v)?,
                        None => builder.append_null(),
                    }
                }
                DataType::Timestamp(TimeUnit::Microsecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampMicrosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let timestamp: i64 = v.and_utc().timestamp_micros();

                            Ok::<i64, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Microsecond, Some(_tz)) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampMicrosecondBuilder,
                        chrono::DateTime<chrono::Utc>,
                        row,
                        idx,
                        |v: chrono::DateTime<chrono::Utc>| {
                            let timestamp: i64 = v.timestamp_micros();
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::NaiveDateTime,
                        row,
                        idx,
                        |v: chrono::NaiveDateTime| {
                            let timestamp: i64 = v.and_utc().timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"));
                            Ok::<i64, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Timestamp(TimeUnit::Nanosecond, Some(_tz)) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        TimestampNanosecondBuilder,
                        chrono::DateTime<chrono::Utc>,
                        row,
                        idx,
                        |v: chrono::DateTime<chrono::Utc>| {
                            let timestamp: i64 = v.timestamp_nanos_opt().unwrap_or_else(|| panic!("Failed to get timestamp in nanoseconds from {v} for {field:?} and {col:?}"));
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Time64(TimeUnit::Microsecond) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time64MicrosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            let seconds = i64::from(v.num_seconds_from_midnight());
                            let microseconds = i64::from(v.nanosecond()) / 1000;
                            Ok::<_, DataFusionError>(seconds * 1_000_000 + microseconds)
                        }
                    );
                }
                DataType::Time64(TimeUnit::Nanosecond) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Time64NanosecondBuilder,
                        chrono::NaiveTime,
                        row,
                        idx,
                        |v: chrono::NaiveTime| {
                            let timestamp: i64 = i64::from(v.num_seconds_from_midnight())
                                * 1_000_000_000
                                + i64::from(v.nanosecond());
                            Ok::<_, DataFusionError>(timestamp)
                        }
                    );
                }
                DataType::Date32 => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        Date32Builder,
                        chrono::NaiveDate,
                        row,
                        idx,
                        |v| { Ok::<_, DataFusionError>(Date32Type::from_naive_date(v)) }
                    );
                }
                DataType::Interval(IntervalUnit::MonthDayNano) => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        IntervalMonthDayNanoBuilder,
                        IntervalFromSql,
                        row,
                        idx,
                        |v: IntervalFromSql| {
                            let interval_month_day_nano = IntervalMonthDayNanoType::make_value(
                                v.month,
                                v.day,
                                v.time * 1_000,
                            );
                            Ok::<_, DataFusionError>(interval_month_day_nano)
                        }
                    );
                }
                DataType::Boolean => {
                    handle_primitive_type!(
                        builder,
                        field,
                        col,
                        BooleanBuilder,
                        bool,
                        row,
                        idx,
                        just_return
                    );
                }
                DataType::List(inner) => match inner.data_type() {
                    DataType::Int16 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int16Builder,
                            i16,
                            row,
                            idx
                        );
                    }
                    DataType::Int32 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int32Builder,
                            i32,
                            row,
                            idx
                        );
                    }
                    DataType::Int64 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Int64Builder,
                            i64,
                            row,
                            idx
                        );
                    }
                    DataType::Float32 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Float32Builder,
                            f32,
                            row,
                            idx
                        );
                    }
                    DataType::Float64 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            Float64Builder,
                            f64,
                            row,
                            idx
                        );
                    }
                    DataType::Utf8 => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            StringBuilder,
                            &str,
                            row,
                            idx
                        );
                    }
                    DataType::Binary => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            BinaryBuilder,
                            Vec<u8>,
                            row,
                            idx
                        );
                    }
                    DataType::Boolean => {
                        handle_primitive_array_type!(
                            builder,
                            field,
                            col,
                            BooleanBuilder,
                            bool,
                            row,
                            idx
                        );
                    }
                    _ => {
                        return Err(DataFusionError::NotImplemented(format!(
                            "Unsupported list data type {} for col: {:?}",
                            field.data_type(),
                            col
                        )));
                    }
                },
                _ => {
                    return Err(DataFusionError::NotImplemented(format!(
                        "Unsupported data type {} for col: {:?}",
                        field.data_type(),
                        col
                    )));
                }
            }
        }
    }
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    let options = RecordBatchOptions::new().with_row_count(Some(rows.len()));
    Ok(RecordBatch::try_new_with_options(
        projected_schema,
        projected_columns,
        &options,
    )?)
}