1use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
2use arrow::error::ArrowError;
3use serde_json::json;
4use serde_json::Value;
5use std::sync::Arc;
6
7use crate::UnsupportedTypeAction;
8
9#[derive(Debug, Clone)]
10pub(crate) struct ParseContext {
11 pub(crate) unsupported_type_action: UnsupportedTypeAction,
12 pub(crate) type_details: Option<serde_json::Value>,
13}
14
15impl ParseContext {
16 pub(crate) fn new() -> Self {
17 Self {
18 unsupported_type_action: UnsupportedTypeAction::Error,
19 type_details: None,
20 }
21 }
22
23 pub(crate) fn with_unsupported_type_action(
24 mut self,
25 unsupported_type_action: UnsupportedTypeAction,
26 ) -> Self {
27 self.unsupported_type_action = unsupported_type_action;
28 self
29 }
30
31 pub(crate) fn with_type_details(mut self, type_details: serde_json::Value) -> Self {
32 self.type_details = Some(type_details);
33 self
34 }
35}
36
37impl Default for ParseContext {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43pub(crate) fn pg_data_type_to_arrow_type(
44 pg_type: &str,
45 context: &ParseContext,
46) -> Result<DataType, ArrowError> {
47 let base_type = pg_type.split('(').next().unwrap_or(pg_type).trim();
48
49 match base_type {
50 "smallint" => Ok(DataType::Int16),
51 "integer" | "int" | "int4" => Ok(DataType::Int32),
52 "bigint" | "int8" | "money" => Ok(DataType::Int64),
53 "oid" | "xid" | "regproc" => Ok(DataType::UInt32),
54 "numeric" | "decimal" => {
55 let (precision, scale) = parse_numeric_type(pg_type)?;
56 Ok(DataType::Decimal128(precision, scale))
57 }
58 "real" | "float4" => Ok(DataType::Float32),
59 "double precision" | "float8" => Ok(DataType::Float64),
60 "\"char\"" => Ok(DataType::Int8),
61 "character" | "char" | "character varying" | "varchar" | "text" | "bpchar" | "uuid"
62 | "name" => Ok(DataType::Utf8),
63 "bytea" => Ok(DataType::Binary),
64 "date" => Ok(DataType::Date32),
65 "time" | "time without time zone" => Ok(DataType::Time64(TimeUnit::Nanosecond)),
66 "timestamp" | "timestamp without time zone" => {
67 Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
68 }
69 "timestamp with time zone" | "timestamptz" => Ok(DataType::Timestamp(
70 TimeUnit::Nanosecond,
71 Some("UTC".into()),
72 )),
73 "interval" => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
74 "boolean" => Ok(DataType::Boolean),
75 "enum" => Ok(DataType::Dictionary(
76 Box::new(DataType::Int8),
77 Box::new(DataType::Utf8),
78 )),
79 "point" => Ok(DataType::FixedSizeList(
80 Arc::new(Field::new("item", DataType::Float64, true)),
81 2,
82 )),
83 "line" | "lseg" | "box" | "path" | "polygon" | "circle" => Ok(DataType::Binary),
84 "inet" | "cidr" | "macaddr" => Ok(DataType::Utf8),
85 "bit" | "bit varying" => Ok(DataType::Binary),
86 "tsvector" | "tsquery" => Ok(DataType::LargeUtf8),
87 "xml" | "json" => Ok(DataType::Utf8),
88 "aclitem" | "pg_node_tree" => Ok(DataType::Utf8),
89 "array" => parse_array_type(context),
90 "anyarray" => Ok(DataType::List(Arc::new(Field::new(
91 "item",
92 DataType::Binary,
93 true,
94 )))),
95 "int4range" => Ok(DataType::Struct(Fields::from(vec![
96 Field::new("lower", DataType::Int32, true),
97 Field::new("upper", DataType::Int32, true),
98 ]))),
99 "composite" => parse_composite_type(context),
100 "geometry" | "geography" => Ok(DataType::Binary),
101
102 "jsonb" if context.unsupported_type_action == UnsupportedTypeAction::String => {
104 Ok(DataType::Utf8)
105 }
106 _ => Err(ArrowError::ParseError(format!(
107 "Unsupported PostgreSQL type: {}",
108 pg_type
109 ))),
110 }
111}
112
113fn parse_array_type(context: &ParseContext) -> Result<DataType, ArrowError> {
114 let details = context
115 .type_details
116 .as_ref()
117 .ok_or_else(|| ArrowError::ParseError("Missing type details for array type".to_string()))?;
118 let details = details
119 .as_object()
120 .ok_or_else(|| ArrowError::ParseError("Invalid array type details format".to_string()))?;
121 let element_type = details
122 .get("element_type")
123 .and_then(Value::as_str)
124 .ok_or_else(|| {
125 ArrowError::ParseError("Missing or invalid element_type for array".to_string())
126 })?;
127
128 let inner_type = if element_type.ends_with("[]") {
129 let inner_context = context.clone().with_type_details(json!({
130 "type": "array",
131 "element_type": element_type.trim_end_matches("[]"),
132 }));
133 parse_array_type(&inner_context)?
134 } else {
135 pg_data_type_to_arrow_type(element_type, context)?
136 };
137
138 Ok(DataType::List(Arc::new(Field::new(
139 "item", inner_type, true,
140 ))))
141}
142
143fn parse_composite_type(context: &ParseContext) -> Result<DataType, ArrowError> {
144 let details = context.type_details.as_ref().ok_or_else(|| {
145 ArrowError::ParseError("Missing type details for composite type".to_string())
146 })?;
147 let details = details.as_object().ok_or_else(|| {
148 ArrowError::ParseError("Invalid composite type details format".to_string())
149 })?;
150 let attributes = details
151 .get("attributes")
152 .and_then(Value::as_array)
153 .ok_or_else(|| {
154 ArrowError::ParseError("Missing or invalid attributes for composite type".to_string())
155 })?;
156
157 let fields: Result<Vec<Field>, ArrowError> = attributes
158 .iter()
159 .map(|attr| {
160 let attr_obj = attr.as_object().ok_or_else(|| {
161 ArrowError::ParseError("Invalid attribute format in composite type".to_string())
162 })?;
163 let name = attr_obj
164 .get("name")
165 .and_then(Value::as_str)
166 .ok_or_else(|| {
167 ArrowError::ParseError(
168 "Missing or invalid name in composite type attribute".to_string(),
169 )
170 })?;
171 let attr_type = attr_obj
172 .get("type")
173 .and_then(Value::as_str)
174 .ok_or_else(|| {
175 ArrowError::ParseError(
176 "Missing or invalid type in composite type attribute".to_string(),
177 )
178 })?;
179 let field_type = if attr_type == "composite" {
180 let inner_context = context.clone().with_type_details(attr.clone());
181 parse_composite_type(&inner_context)?
182 } else {
183 pg_data_type_to_arrow_type(attr_type, context)?
184 };
185 Ok(Field::new(name, field_type, true))
186 })
187 .collect();
188
189 Ok(DataType::Struct(Fields::from(fields?)))
190}
191
192fn parse_numeric_type(pg_type: &str) -> Result<(u8, i8), ArrowError> {
193 let type_str = pg_type
194 .trim_start_matches("numeric")
195 .trim_start_matches("decimal")
196 .trim();
197
198 if type_str.is_empty() || type_str == "()" {
199 return Ok((38, 20)); }
201
202 let parts: Vec<&str> = type_str
203 .trim_start_matches('(')
204 .trim_end_matches(')')
205 .split(',')
206 .collect();
207
208 match parts.len() {
209 1 => {
210 let precision = parts[0]
211 .trim()
212 .parse::<u8>()
213 .map_err(|_| ArrowError::ParseError("Invalid numeric precision".to_string()))?;
214 Ok((precision, 0))
215 }
216 2 => {
217 let precision = parts[0]
218 .trim()
219 .parse::<u8>()
220 .map_err(|_| ArrowError::ParseError("Invalid numeric precision".to_string()))?;
221 let scale = parts[1]
222 .trim()
223 .parse::<i8>()
224 .map_err(|_| ArrowError::ParseError("Invalid numeric scale".to_string()))?;
225 Ok((precision, scale))
226 }
227 _ => Err(ArrowError::ParseError(
228 "Invalid numeric type format".to_string(),
229 )),
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected Arrow type for a list whose nullable `item` field has type `inner`.
    fn list_of(inner: DataType) -> DataType {
        DataType::List(Arc::new(Field::new("item", inner, true)))
    }

    /// Expected Arrow struct made of nullable fields, in the given order.
    fn struct_of(members: Vec<(&str, DataType)>) -> DataType {
        let fields: Vec<Field> = members
            .into_iter()
            .map(|(name, data_type)| Field::new(name, data_type, true))
            .collect();
        DataType::Struct(Fields::from(fields))
    }

    /// Converts every `(pg_type, context, failure message, expected)` case, in order, and
    /// checks the result against the expected Arrow type.
    fn assert_conversions(cases: Vec<(&str, &ParseContext, &str, DataType)>) {
        for (pg_type, context, failure_message, expected) in cases {
            assert_eq!(
                pg_data_type_to_arrow_type(pg_type, context).expect(failure_message),
                expected
            );
        }
    }

    /// Parses every `(pg_type, failure message, expected)` numeric case, in order.
    fn assert_numeric_parses(cases: Vec<(&str, &str, (u8, i8))>) {
        for (pg_type, failure_message, expected) in cases {
            assert_eq!(
                parse_numeric_type(pg_type).expect(failure_message),
                expected
            );
        }
    }

    #[test]
    fn test_pg_data_type_to_arrow_type() {
        let context = ParseContext::new();
        let array_type_context = context.clone().with_type_details(json!({
            "type": "array",
            "element_type": "integer",
        }));
        let composite_type_context = context.clone().with_type_details(json!({
            "type": "composite",
            "attributes": [
                {"name": "x", "type": "integer"},
                {"name": "y", "type": "text"}
            ]
        }));

        assert_conversions(vec![
            // Numeric, boolean and single-byte character types.
            ("smallint", &context, "Failed to convert smallint", DataType::Int16),
            ("integer", &context, "Failed to convert integer", DataType::Int32),
            ("bigint", &context, "Failed to convert bigint", DataType::Int64),
            ("real", &context, "Failed to convert real", DataType::Float32),
            (
                "double precision",
                &context,
                "Failed to convert double precision",
                DataType::Float64,
            ),
            ("boolean", &context, "Failed to convert boolean", DataType::Boolean),
            (
                "\"char\"",
                &context,
                "Failed to convert single character",
                DataType::Int8,
            ),
            // String types.
            ("character", &context, "Failed to convert character", DataType::Utf8),
            (
                "character varying",
                &context,
                "Failed to convert character varying",
                DataType::Utf8,
            ),
            ("name", &context, "Failed to convert name", DataType::Utf8),
            ("text", &context, "Failed to convert text", DataType::Utf8),
            // Date and time types.
            ("date", &context, "Failed to convert date", DataType::Date32),
            (
                "time without time zone",
                &context,
                "Failed to convert time without time zone",
                DataType::Time64(TimeUnit::Nanosecond),
            ),
            (
                "timestamp without time zone",
                &context,
                "Failed to convert timestamp without time zone",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ),
            (
                "timestamp with time zone",
                &context,
                "Failed to convert timestamp with time zone",
                DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
            ),
            (
                "interval",
                &context,
                "Failed to convert interval",
                DataType::Interval(IntervalUnit::MonthDayNano),
            ),
            // Decimal types, with and without an explicit precision/scale.
            (
                "numeric",
                &context,
                "Failed to convert numeric",
                DataType::Decimal128(38, 20),
            ),
            (
                "numeric()",
                &context,
                "Failed to convert numeric()",
                DataType::Decimal128(38, 20),
            ),
            (
                "numeric(10,2)",
                &context,
                "Failed to convert numeric(10,2)",
                DataType::Decimal128(10, 2),
            ),
            // Structured types resolved through the type details.
            (
                "array",
                &array_type_context,
                "Failed to convert array",
                list_of(DataType::Int32),
            ),
            (
                "composite",
                &composite_type_context,
                "Failed to convert composite",
                struct_of(vec![("x", DataType::Int32), ("y", DataType::Utf8)]),
            ),
        ]);

        // Unknown type names must be rejected.
        assert!(pg_data_type_to_arrow_type("unsupported_type", &context).is_err());
    }

    #[test]
    fn test_parse_numeric_type() {
        assert_numeric_parses(vec![
            ("numeric", "Failed to parse numeric", (38, 20)),
            ("numeric()", "Failed to parse numeric()", (38, 20)),
            ("numeric(10)", "Failed to parse numeric(10)", (10, 0)),
            ("numeric(10,2)", "Failed to parse numeric(10,2)", (10, 2)),
            ("decimal", "Failed to parse decimal", (38, 20)),
            ("decimal()", "Failed to parse decimal()", (38, 20)),
            ("decimal(15)", "Failed to parse decimal(15)", (15, 0)),
            ("decimal(15,5)", "Failed to parse decimal(15,5)", (15, 5)),
        ]);

        // Malformed modifiers must be rejected.
        for malformed in vec!["numeric(invalid)", "numeric(10,2,3)", "numeric(,)"] {
            assert!(parse_numeric_type(malformed).is_err());
        }
    }

    #[test]
    fn test_pg_data_type_to_arrow_type_with_size() {
        let context = ParseContext::new();
        assert_conversions(vec![
            (
                "character(10)",
                &context,
                "Failed to convert character(10)",
                DataType::Utf8,
            ),
            (
                "character varying(255)",
                &context,
                "Failed to convert character varying(255)",
                DataType::Utf8,
            ),
            ("bit(8)", &context, "Failed to convert bit(8)", DataType::Binary),
            (
                "bit varying(64)",
                &context,
                "Failed to convert bit varying(64)",
                DataType::Binary,
            ),
            (
                "numeric(10,2)",
                &context,
                "Failed to convert numeric(10,2)",
                DataType::Decimal128(10, 2),
            ),
        ]);
    }

    #[test]
    fn test_pg_data_type_to_arrow_type_extended() {
        let context = ParseContext::new();
        let nested_array_type_details = context.clone().with_type_details(json!({
            "type": "array",
            "element_type": "integer[]",
        }));
        let enum_type_details = context.clone().with_type_details(json!({
            "type": "enum",
            "values": ["small", "medium", "large"]
        }));
        let jsonb_context = context
            .clone()
            .with_unsupported_type_action(UnsupportedTypeAction::String);

        assert_conversions(vec![
            // Decimals with explicit precision and scale.
            (
                "numeric(38,10)",
                &context,
                "Failed to convert numeric(38,10)",
                DataType::Decimal128(38, 10),
            ),
            (
                "decimal(5,0)",
                &context,
                "Failed to convert decimal(5,0)",
                DataType::Decimal128(5, 0),
            ),
            // Time with a fractional-seconds precision.
            (
                "time(6) without time zone",
                &context,
                "Failed to convert time(6) without time zone",
                DataType::Time64(TimeUnit::Nanosecond),
            ),
            // Nested array.
            (
                "array",
                &nested_array_type_details,
                "Failed to convert nested array",
                list_of(list_of(DataType::Int32)),
            ),
            // Enum.
            (
                "enum",
                &enum_type_details,
                "Failed to convert enum",
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            ),
            // Geometric types.
            (
                "point",
                &context,
                "Failed to convert point",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
            ),
            ("line", &context, "Failed to convert line", DataType::Binary),
            // Network address types.
            ("inet", &context, "Failed to convert inet", DataType::Utf8),
            ("cidr", &context, "Failed to convert cidr", DataType::Utf8),
            // Range type.
            (
                "int4range",
                &context,
                "Failed to convert int4range",
                struct_of(vec![("lower", DataType::Int32), ("upper", DataType::Int32)]),
            ),
            // JSON types; `jsonb` needs the string fallback to be enabled.
            ("json", &context, "Failed to convert json", DataType::Utf8),
            ("jsonb", &jsonb_context, "Failed to convert jsonb", DataType::Utf8),
            // UUID.
            ("uuid", &context, "Failed to convert uuid", DataType::Utf8),
            // Text search types.
            (
                "tsvector",
                &context,
                "Failed to convert tsvector",
                DataType::LargeUtf8,
            ),
            (
                "tsquery",
                &context,
                "Failed to convert tsquery",
                DataType::LargeUtf8,
            ),
            // Blank-padded character type, with and without a length.
            ("bpchar", &context, "Failed to convert bpchar", DataType::Utf8),
            (
                "bpchar(10)",
                &context,
                "Failed to convert bpchar(10)",
                DataType::Utf8,
            ),
        ]);
    }

    #[test]
    fn test_parse_array_type_extended() {
        let context = ParseContext::new();

        let single_dim_array = context.clone().with_type_details(json!({
            "type": "array",
            "element_type": "integer",
        }));
        let single_dim_result =
            parse_array_type(&single_dim_array).expect("Failed to parse single dimension array");
        assert_eq!(single_dim_result, list_of(DataType::Int32));

        let multi_dim_array = context.clone().with_type_details(json!({
            "type": "array",
            "element_type": "text[]",
        }));
        let multi_dim_result =
            parse_array_type(&multi_dim_array).expect("Failed to parse multi-dimension array");
        assert_eq!(multi_dim_result, list_of(list_of(DataType::Utf8)));

        // Details without an `element_type` entry are invalid.
        let invalid_array = context.clone().with_type_details(json!({"type": "array"}));
        assert!(parse_array_type(&invalid_array).is_err());
    }

    #[test]
    fn test_parse_composite_type_extended() {
        let context = ParseContext::new();

        let simple_composite = context.clone().with_type_details(json!({
            "type": "composite",
            "attributes": [
                {"name": "id", "type": "integer"},
                {"name": "name", "type": "text"},
                {"name": "active", "type": "boolean"}
            ]
        }));
        let simple_result =
            parse_composite_type(&simple_composite).expect("Failed to parse simple composite type");
        assert_eq!(
            simple_result,
            struct_of(vec![
                ("id", DataType::Int32),
                ("name", DataType::Utf8),
                ("active", DataType::Boolean),
            ])
        );

        let nested_composite = context.clone().with_type_details(json!({
            "type": "composite",
            "attributes": [
                {"name": "id", "type": "integer"},
                {"name": "details", "type": "composite", "attributes": [
                    {"name": "x", "type": "float8"},
                    {"name": "y", "type": "float8"}
                ]}
            ]
        }));
        let nested_result =
            parse_composite_type(&nested_composite).expect("Failed to parse nested composite type");
        assert_eq!(
            nested_result,
            struct_of(vec![
                ("id", DataType::Int32),
                (
                    "details",
                    struct_of(vec![("x", DataType::Float64), ("y", DataType::Float64)]),
                ),
            ])
        );

        // Details without an `attributes` entry are invalid.
        let invalid_composite = context.clone().with_type_details(json!({
            "type": "composite",
        }));
        assert!(parse_composite_type(&invalid_composite).is_err());
    }
}