datafusion_table_providers/sql/arrow_sql_gen/postgres/
schema.rs

1use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
2use arrow::error::ArrowError;
3use serde_json::json;
4use serde_json::Value;
5use std::sync::Arc;
6
7use crate::UnsupportedTypeAction;
8
/// Settings handed to `pg_data_type_to_arrow_type` and the helpers it dispatches to.
#[derive(Debug, Clone)]
pub(crate) struct ParseContext {
    /// Policy for PostgreSQL types without an Arrow mapping. `ParseContext::new` sets it to
    /// `UnsupportedTypeAction::Error`; when it equals `UnsupportedTypeAction::String`, the
    /// `jsonb` type is mapped to `Utf8` instead of being rejected.
    pub(crate) unsupported_type_action: UnsupportedTypeAction,
    /// Optional JSON description of the type being parsed. The `array` parser reads its
    /// `element_type` key and the `composite` parser reads its `attributes` key.
    pub(crate) type_details: Option<serde_json::Value>,
}
14
15impl ParseContext {
16    pub(crate) fn new() -> Self {
17        Self {
18            unsupported_type_action: UnsupportedTypeAction::Error,
19            type_details: None,
20        }
21    }
22
23    pub(crate) fn with_unsupported_type_action(
24        mut self,
25        unsupported_type_action: UnsupportedTypeAction,
26    ) -> Self {
27        self.unsupported_type_action = unsupported_type_action;
28        self
29    }
30
31    pub(crate) fn with_type_details(mut self, type_details: serde_json::Value) -> Self {
32        self.type_details = Some(type_details);
33        self
34    }
35}
36
37impl Default for ParseContext {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43pub(crate) fn pg_data_type_to_arrow_type(
44    pg_type: &str,
45    context: &ParseContext,
46) -> Result<DataType, ArrowError> {
47    let base_type = pg_type.split('(').next().unwrap_or(pg_type).trim();
48
49    match base_type {
50        "smallint" => Ok(DataType::Int16),
51        "integer" | "int" | "int4" => Ok(DataType::Int32),
52        "bigint" | "int8" | "money" => Ok(DataType::Int64),
53        "oid" | "xid" | "regproc" => Ok(DataType::UInt32),
54        "numeric" | "decimal" => {
55            let (precision, scale) = parse_numeric_type(pg_type)?;
56            Ok(DataType::Decimal128(precision, scale))
57        }
58        "real" | "float4" => Ok(DataType::Float32),
59        "double precision" | "float8" => Ok(DataType::Float64),
60        "\"char\"" => Ok(DataType::Int8),
61        "character" | "char" | "character varying" | "varchar" | "text" | "bpchar" | "uuid"
62        | "name" => Ok(DataType::Utf8),
63        "bytea" => Ok(DataType::Binary),
64        "date" => Ok(DataType::Date32),
65        "time" | "time without time zone" => Ok(DataType::Time64(TimeUnit::Nanosecond)),
66        "timestamp" | "timestamp without time zone" => {
67            Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
68        }
69        "timestamp with time zone" | "timestamptz" => Ok(DataType::Timestamp(
70            TimeUnit::Nanosecond,
71            Some("UTC".into()),
72        )),
73        "interval" => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
74        "boolean" => Ok(DataType::Boolean),
75        "enum" => Ok(DataType::Dictionary(
76            Box::new(DataType::Int8),
77            Box::new(DataType::Utf8),
78        )),
79        "point" => Ok(DataType::FixedSizeList(
80            Arc::new(Field::new("item", DataType::Float64, true)),
81            2,
82        )),
83        "line" | "lseg" | "box" | "path" | "polygon" | "circle" => Ok(DataType::Binary),
84        "inet" | "cidr" | "macaddr" => Ok(DataType::Utf8),
85        "bit" | "bit varying" => Ok(DataType::Binary),
86        "tsvector" | "tsquery" => Ok(DataType::LargeUtf8),
87        "xml" | "json" => Ok(DataType::Utf8),
88        "aclitem" | "pg_node_tree" => Ok(DataType::Utf8),
89        "array" => parse_array_type(context),
90        "anyarray" => Ok(DataType::List(Arc::new(Field::new(
91            "item",
92            DataType::Binary,
93            true,
94        )))),
95        "int4range" => Ok(DataType::Struct(Fields::from(vec![
96            Field::new("lower", DataType::Int32, true),
97            Field::new("upper", DataType::Int32, true),
98        ]))),
99        "composite" => parse_composite_type(context),
100        "geometry" | "geography" => Ok(DataType::Binary),
101
102        // `jsonb` is currently not supported, but if the user has set the `UnsupportedTypeAction` to `String` we'll return `Utf8`.
103        "jsonb" if context.unsupported_type_action == UnsupportedTypeAction::String => {
104            Ok(DataType::Utf8)
105        }
106        _ => Err(ArrowError::ParseError(format!(
107            "Unsupported PostgreSQL type: {}",
108            pg_type
109        ))),
110    }
111}
112
113fn parse_array_type(context: &ParseContext) -> Result<DataType, ArrowError> {
114    let details = context
115        .type_details
116        .as_ref()
117        .ok_or_else(|| ArrowError::ParseError("Missing type details for array type".to_string()))?;
118    let details = details
119        .as_object()
120        .ok_or_else(|| ArrowError::ParseError("Invalid array type details format".to_string()))?;
121    let element_type = details
122        .get("element_type")
123        .and_then(Value::as_str)
124        .ok_or_else(|| {
125            ArrowError::ParseError("Missing or invalid element_type for array".to_string())
126        })?;
127
128    let inner_type = if element_type.ends_with("[]") {
129        let inner_context = context.clone().with_type_details(json!({
130            "type": "array",
131            "element_type": element_type.trim_end_matches("[]"),
132        }));
133        parse_array_type(&inner_context)?
134    } else {
135        pg_data_type_to_arrow_type(element_type, context)?
136    };
137
138    Ok(DataType::List(Arc::new(Field::new(
139        "item", inner_type, true,
140    ))))
141}
142
143fn parse_composite_type(context: &ParseContext) -> Result<DataType, ArrowError> {
144    let details = context.type_details.as_ref().ok_or_else(|| {
145        ArrowError::ParseError("Missing type details for composite type".to_string())
146    })?;
147    let details = details.as_object().ok_or_else(|| {
148        ArrowError::ParseError("Invalid composite type details format".to_string())
149    })?;
150    let attributes = details
151        .get("attributes")
152        .and_then(Value::as_array)
153        .ok_or_else(|| {
154            ArrowError::ParseError("Missing or invalid attributes for composite type".to_string())
155        })?;
156
157    let fields: Result<Vec<Field>, ArrowError> = attributes
158        .iter()
159        .map(|attr| {
160            let attr_obj = attr.as_object().ok_or_else(|| {
161                ArrowError::ParseError("Invalid attribute format in composite type".to_string())
162            })?;
163            let name = attr_obj
164                .get("name")
165                .and_then(Value::as_str)
166                .ok_or_else(|| {
167                    ArrowError::ParseError(
168                        "Missing or invalid name in composite type attribute".to_string(),
169                    )
170                })?;
171            let attr_type = attr_obj
172                .get("type")
173                .and_then(Value::as_str)
174                .ok_or_else(|| {
175                    ArrowError::ParseError(
176                        "Missing or invalid type in composite type attribute".to_string(),
177                    )
178                })?;
179            let field_type = if attr_type == "composite" {
180                let inner_context = context.clone().with_type_details(attr.clone());
181                parse_composite_type(&inner_context)?
182            } else {
183                pg_data_type_to_arrow_type(attr_type, context)?
184            };
185            Ok(Field::new(name, field_type, true))
186        })
187        .collect();
188
189    Ok(DataType::Struct(Fields::from(fields?)))
190}
191
192fn parse_numeric_type(pg_type: &str) -> Result<(u8, i8), ArrowError> {
193    let type_str = pg_type
194        .trim_start_matches("numeric")
195        .trim_start_matches("decimal")
196        .trim();
197
198    if type_str.is_empty() || type_str == "()" {
199        return Ok((38, 20)); // Default precision and scale if not specified
200    }
201
202    let parts: Vec<&str> = type_str
203        .trim_start_matches('(')
204        .trim_end_matches(')')
205        .split(',')
206        .collect();
207
208    match parts.len() {
209        1 => {
210            let precision = parts[0]
211                .trim()
212                .parse::<u8>()
213                .map_err(|_| ArrowError::ParseError("Invalid numeric precision".to_string()))?;
214            Ok((precision, 0))
215        }
216        2 => {
217            let precision = parts[0]
218                .trim()
219                .parse::<u8>()
220                .map_err(|_| ArrowError::ParseError("Invalid numeric precision".to_string()))?;
221            let scale = parts[1]
222                .trim()
223                .parse::<i8>()
224                .map_err(|_| ArrowError::ParseError("Invalid numeric scale".to_string()))?;
225            Ok((precision, scale))
226        }
227        _ => Err(ArrowError::ParseError(
228            "Invalid numeric type format".to_string(),
229        )),
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `inner` in a list whose element field is nullable and named `item`.
    fn list_of(inner: DataType) -> DataType {
        DataType::List(Arc::new(Field::new("item", inner, true)))
    }

    /// Converts each `(pg_type, expected)` pair under `context` and compares the result.
    /// A failed conversion panics with `Failed to convert <pg_type>: <error>`.
    fn assert_conversions(context: &ParseContext, cases: Vec<(&str, DataType)>) {
        for (pg_type, expected) in cases {
            let actual = pg_data_type_to_arrow_type(pg_type, context)
                .unwrap_or_else(|err| panic!("Failed to convert {}: {:?}", pg_type, err));
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_pg_data_type_to_arrow_type() {
        let context = ParseContext::new();

        // Basic types
        assert_conversions(
            &context,
            vec![
                ("smallint", DataType::Int16),
                ("integer", DataType::Int32),
                ("bigint", DataType::Int64),
                ("real", DataType::Float32),
                ("double precision", DataType::Float64),
                ("boolean", DataType::Boolean),
            ],
        );
        assert_eq!(
            pg_data_type_to_arrow_type("\"char\"", &context)
                .expect("Failed to convert single character"),
            DataType::Int8
        );

        // String, date/time and numeric types
        assert_conversions(
            &context,
            vec![
                ("character", DataType::Utf8),
                ("character varying", DataType::Utf8),
                ("name", DataType::Utf8),
                ("text", DataType::Utf8),
                ("date", DataType::Date32),
                (
                    "time without time zone",
                    DataType::Time64(TimeUnit::Nanosecond),
                ),
                (
                    "timestamp without time zone",
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                ),
                (
                    "timestamp with time zone",
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
                ),
                ("interval", DataType::Interval(IntervalUnit::MonthDayNano)),
                ("numeric", DataType::Decimal128(38, 20)),
                ("numeric()", DataType::Decimal128(38, 20)),
                ("numeric(10,2)", DataType::Decimal128(10, 2)),
            ],
        );

        // Array type
        let array_type_context = context.clone().with_type_details(json!({
            "type": "array",
            "element_type": "integer",
        }));
        let array_type = pg_data_type_to_arrow_type("array", &array_type_context)
            .expect("Failed to convert array");
        assert_eq!(array_type, list_of(DataType::Int32));

        // Composite type
        let composite_type_context = context.clone().with_type_details(json!({
            "type": "composite",
            "attributes": [
                {"name": "x", "type": "integer"},
                {"name": "y", "type": "text"}
            ]
        }));
        let composite_type = pg_data_type_to_arrow_type("composite", &composite_type_context)
            .expect("Failed to convert composite");
        assert_eq!(
            composite_type,
            DataType::Struct(Fields::from(vec![
                Field::new("x", DataType::Int32, true),
                Field::new("y", DataType::Utf8, true)
            ]))
        );

        // Unsupported type
        assert!(pg_data_type_to_arrow_type("unsupported_type", &context).is_err());
    }

    #[test]
    fn test_parse_numeric_type() {
        let valid: Vec<(&str, (u8, i8))> = vec![
            ("numeric", (38, 20)),
            ("numeric()", (38, 20)),
            ("numeric(10)", (10, 0)),
            ("numeric(10,2)", (10, 2)),
            ("decimal", (38, 20)),
            ("decimal()", (38, 20)),
            ("decimal(15)", (15, 0)),
            ("decimal(15,5)", (15, 5)),
        ];
        for (input, expected) in valid {
            let parsed = parse_numeric_type(input)
                .unwrap_or_else(|err| panic!("Failed to parse {}: {:?}", input, err));
            assert_eq!(parsed, expected);
        }

        // Invalid formats
        for input in vec!["numeric(invalid)", "numeric(10,2,3)", "numeric(,)"] {
            assert!(parse_numeric_type(input).is_err());
        }
    }

    #[test]
    fn test_pg_data_type_to_arrow_type_with_size() {
        let context = ParseContext::new();
        assert_conversions(
            &context,
            vec![
                ("character(10)", DataType::Utf8),
                ("character varying(255)", DataType::Utf8),
                ("bit(8)", DataType::Binary),
                ("bit varying(64)", DataType::Binary),
                ("numeric(10,2)", DataType::Decimal128(10, 2)),
            ],
        );
    }

    #[test]
    fn test_pg_data_type_to_arrow_type_extended() {
        let context = ParseContext::new();

        // Additional numeric types and time types with precision
        assert_conversions(
            &context,
            vec![
                ("numeric(38,10)", DataType::Decimal128(38, 10)),
                ("decimal(5,0)", DataType::Decimal128(5, 0)),
                (
                    "time(6) without time zone",
                    DataType::Time64(TimeUnit::Nanosecond),
                ),
            ],
        );

        // Array types
        let nested_array_type_details = context.clone().with_type_details(json!({
            "type": "array",
            "element_type": "integer[]",
        }));
        let nested_array_type = pg_data_type_to_arrow_type("array", &nested_array_type_details)
            .expect("Failed to convert nested array");
        assert_eq!(nested_array_type, list_of(list_of(DataType::Int32)));

        // Enum type
        let enum_type_details = context.clone().with_type_details(json!({
            "type": "enum",
            "values": ["small", "medium", "large"]
        }));
        assert_conversions(
            &enum_type_details,
            vec![(
                "enum",
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            )],
        );

        // Geometric, network address, range and JSON types
        assert_conversions(
            &context,
            vec![
                (
                    "point",
                    DataType::FixedSizeList(
                        Arc::new(Field::new("item", DataType::Float64, true)),
                        2,
                    ),
                ),
                ("line", DataType::Binary),
                ("inet", DataType::Utf8),
                ("cidr", DataType::Utf8),
                (
                    "int4range",
                    DataType::Struct(Fields::from(vec![
                        Field::new("lower", DataType::Int32, true),
                        Field::new("upper", DataType::Int32, true),
                    ])),
                ),
                ("json", DataType::Utf8),
            ],
        );

        // `jsonb` converts only when unsupported types are mapped to strings
        let jsonb_context = context
            .clone()
            .with_unsupported_type_action(UnsupportedTypeAction::String);
        assert_conversions(&jsonb_context, vec![("jsonb", DataType::Utf8)]);

        // UUID, text search and bpchar types (with and without length)
        assert_conversions(
            &context,
            vec![
                ("uuid", DataType::Utf8),
                ("tsvector", DataType::LargeUtf8),
                ("tsquery", DataType::LargeUtf8),
                ("bpchar", DataType::Utf8),
                ("bpchar(10)", DataType::Utf8),
            ],
        );
    }

    #[test]
    fn test_parse_array_type_extended() {
        let context = ParseContext::new();

        let single_dim_array = context.clone().with_type_details(json!({
            "type": "array",
            "element_type": "integer",
        }));
        let single_dim_type =
            parse_array_type(&single_dim_array).expect("Failed to parse single dimension array");
        assert_eq!(single_dim_type, list_of(DataType::Int32));

        let multi_dim_array = context.clone().with_type_details(json!({
            "type": "array",
            "element_type": "text[]",
        }));
        let multi_dim_type =
            parse_array_type(&multi_dim_array).expect("Failed to parse multi-dimension array");
        assert_eq!(multi_dim_type, list_of(list_of(DataType::Utf8)));

        let invalid_array = context.clone().with_type_details(json!({"type": "array"}));
        assert!(parse_array_type(&invalid_array).is_err());
    }

    #[test]
    fn test_parse_composite_type_extended() {
        let context = ParseContext::new();

        let simple_composite = context.clone().with_type_details(json!({
            "type": "composite",
            "attributes": [
                {"name": "id", "type": "integer"},
                {"name": "name", "type": "text"},
                {"name": "active", "type": "boolean"}
            ]
        }));
        let simple_type =
            parse_composite_type(&simple_composite).expect("Failed to parse simple composite type");
        assert_eq!(
            simple_type,
            DataType::Struct(Fields::from(vec![
                Field::new("id", DataType::Int32, true),
                Field::new("name", DataType::Utf8, true),
                Field::new("active", DataType::Boolean, true),
            ]))
        );

        let nested_composite = context.clone().with_type_details(json!({
            "type": "composite",
            "attributes": [
                {"name": "id", "type": "integer"},
                {"name": "details", "type": "composite", "attributes": [
                    {"name": "x", "type": "float8"},
                    {"name": "y", "type": "float8"}
                ]}
            ]
        }));
        let details_type = DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("y", DataType::Float64, true),
        ]));
        let nested_type =
            parse_composite_type(&nested_composite).expect("Failed to parse nested composite type");
        assert_eq!(
            nested_type,
            DataType::Struct(Fields::from(vec![
                Field::new("id", DataType::Int32, true),
                Field::new("details", details_type, true),
            ]))
        );

        let invalid_composite = context.clone().with_type_details(json!({
            "type": "composite",
        }));
        assert!(parse_composite_type(&invalid_composite).is_err());
    }
}