proof_of_sql/base/arrow/
column_arrow_conversions.rs1use crate::base::{
2 database::{ColumnField, ColumnType},
3 math::decimal::Precision,
4 posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
5};
6use alloc::sync::Arc;
7use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
8
9impl From<&ColumnType> for DataType {
11 fn from(column_type: &ColumnType) -> Self {
12 match column_type {
13 ColumnType::Boolean => DataType::Boolean,
14 ColumnType::Uint8 => DataType::UInt8,
15 ColumnType::TinyInt => DataType::Int8,
16 ColumnType::SmallInt => DataType::Int16,
17 ColumnType::Int => DataType::Int32,
18 ColumnType::BigInt => DataType::Int64,
19 ColumnType::Int128 => DataType::Decimal128(38, 0),
20 ColumnType::Decimal75(precision, scale) => {
21 DataType::Decimal256(precision.value(), *scale)
22 }
23 ColumnType::VarChar => DataType::Utf8,
24 ColumnType::VarBinary => DataType::Binary,
25 ColumnType::Scalar => unimplemented!("Cannot convert Scalar type to arrow type"),
26 ColumnType::TimestampTZ(timeunit, timezone) => {
27 let arrow_timezone = Some(Arc::from(timezone.to_string()));
28 let arrow_timeunit = match timeunit {
29 PoSQLTimeUnit::Second => ArrowTimeUnit::Second,
30 PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
31 PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
32 PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
33 };
34 DataType::Timestamp(arrow_timeunit, arrow_timezone)
35 }
36 }
37 }
38}
39
40impl TryFrom<DataType> for ColumnType {
42 type Error = String;
43
44 fn try_from(data_type: DataType) -> Result<Self, Self::Error> {
45 match data_type {
46 DataType::Boolean => Ok(ColumnType::Boolean),
47 DataType::UInt8 => Ok(ColumnType::Uint8),
48 DataType::Int8 => Ok(ColumnType::TinyInt),
49 DataType::Int16 => Ok(ColumnType::SmallInt),
50 DataType::Int32 => Ok(ColumnType::Int),
51 DataType::Int64 => Ok(ColumnType::BigInt),
52 DataType::Decimal128(38, 0) => Ok(ColumnType::Int128),
53 DataType::Decimal256(precision, scale) if precision <= 75 => {
54 Ok(ColumnType::Decimal75(Precision::new(precision)?, scale))
55 }
56 DataType::Timestamp(time_unit, timezone_option) => {
57 let posql_time_unit = match time_unit {
58 ArrowTimeUnit::Second => PoSQLTimeUnit::Second,
59 ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond,
60 ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond,
61 ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond,
62 };
63 Ok(ColumnType::TimestampTZ(
64 posql_time_unit,
65 PoSQLTimeZone::try_from(&timezone_option)?,
66 ))
67 }
68 DataType::Utf8 => Ok(ColumnType::VarChar),
69 DataType::Binary => Ok(ColumnType::VarBinary),
70 _ => Err(format!("Unsupported arrow data type {data_type:?}")),
71 }
72 }
73}
74impl From<&ColumnField> for Field {
76 fn from(column_field: &ColumnField) -> Self {
77 Field::new(
78 column_field.name().value.as_str(),
79 (&column_field.data_type()).into(),
80 false,
81 )
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Converting a generated `ColumnType` to arrow and back must yield
        // the value we started with.
        #[test]
        fn we_can_roundtrip_arbitrary_column_type(column_type: ColumnType) {
            let round_tripped = ColumnType::try_from(DataType::from(&column_type)).unwrap();

            prop_assert_eq!(round_tripped, column_type);
        }
    }
}