proof_of_sql/base/arrow/column_arrow_conversions.rs

1use crate::base::{
2    database::{ColumnField, ColumnType},
3    math::decimal::Precision,
4    posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
5};
6use alloc::sync::Arc;
7use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit};
8
9/// Convert [`ColumnType`] values to some arrow [`DataType`]
10impl From<&ColumnType> for DataType {
11    fn from(column_type: &ColumnType) -> Self {
12        match column_type {
13            ColumnType::Boolean => DataType::Boolean,
14            ColumnType::Uint8 => DataType::UInt8,
15            ColumnType::TinyInt => DataType::Int8,
16            ColumnType::SmallInt => DataType::Int16,
17            ColumnType::Int => DataType::Int32,
18            ColumnType::BigInt => DataType::Int64,
19            ColumnType::Int128 => DataType::Decimal128(38, 0),
20            ColumnType::Decimal75(precision, scale) => {
21                DataType::Decimal256(precision.value(), *scale)
22            }
23            ColumnType::VarChar => DataType::Utf8,
24            ColumnType::VarBinary => DataType::Binary,
25            ColumnType::Scalar => unimplemented!("Cannot convert Scalar type to arrow type"),
26            ColumnType::TimestampTZ(timeunit, timezone) => {
27                let arrow_timezone = Some(Arc::from(timezone.to_string()));
28                let arrow_timeunit = match timeunit {
29                    PoSQLTimeUnit::Second => ArrowTimeUnit::Second,
30                    PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond,
31                    PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond,
32                    PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond,
33                };
34                DataType::Timestamp(arrow_timeunit, arrow_timezone)
35            }
36        }
37    }
38}
39
40/// Convert arrow [`DataType`] values to some [`ColumnType`]
41impl TryFrom<DataType> for ColumnType {
42    type Error = String;
43
44    fn try_from(data_type: DataType) -> Result<Self, Self::Error> {
45        match data_type {
46            DataType::Boolean => Ok(ColumnType::Boolean),
47            DataType::UInt8 => Ok(ColumnType::Uint8),
48            DataType::Int8 => Ok(ColumnType::TinyInt),
49            DataType::Int16 => Ok(ColumnType::SmallInt),
50            DataType::Int32 => Ok(ColumnType::Int),
51            DataType::Int64 => Ok(ColumnType::BigInt),
52            DataType::Decimal128(38, 0) => Ok(ColumnType::Int128),
53            DataType::Decimal256(precision, scale) if precision <= 75 => {
54                Ok(ColumnType::Decimal75(Precision::new(precision)?, scale))
55            }
56            DataType::Timestamp(time_unit, timezone_option) => {
57                let posql_time_unit = match time_unit {
58                    ArrowTimeUnit::Second => PoSQLTimeUnit::Second,
59                    ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond,
60                    ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond,
61                    ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond,
62                };
63                Ok(ColumnType::TimestampTZ(
64                    posql_time_unit,
65                    PoSQLTimeZone::try_from(&timezone_option)?,
66                ))
67            }
68            DataType::Utf8 => Ok(ColumnType::VarChar),
69            DataType::Binary => Ok(ColumnType::VarBinary),
70            _ => Err(format!("Unsupported arrow data type {data_type:?}")),
71        }
72    }
73}
74/// Convert [`ColumnField`] values to arrow Field
75impl From<&ColumnField> for Field {
76    fn from(column_field: &ColumnField) -> Self {
77        Field::new(
78            column_field.name().value.as_str(),
79            (&column_field.data_type()).into(),
80            false,
81        )
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Every generated `ColumnType` must survive a trip through Arrow and back.
        #[test]
        fn we_can_roundtrip_arbitrary_column_type(column_type: ColumnType) {
            let arrow = DataType::from(&column_type);
            let actual = ColumnType::try_from(arrow).unwrap();

            prop_assert_eq!(actual, column_type);
        }
    }

    #[test]
    fn we_can_convert_int128_to_and_from_decimal128() {
        assert_eq!(
            DataType::from(&ColumnType::Int128),
            DataType::Decimal128(38, 0)
        );
        assert_eq!(
            ColumnType::try_from(DataType::Decimal128(38, 0)),
            Ok(ColumnType::Int128)
        );
    }

    #[test]
    fn we_cannot_convert_decimal128_with_other_precision_or_scale() {
        // Only the exact `(38, 0)` shape maps back to `Int128`.
        assert!(ColumnType::try_from(DataType::Decimal128(10, 2)).is_err());
        assert!(ColumnType::try_from(DataType::Decimal128(38, 1)).is_err());
    }

    #[test]
    fn we_cannot_convert_decimal256_with_precision_above_75() {
        assert!(ColumnType::try_from(DataType::Decimal256(76, 0)).is_err());
    }

    #[test]
    fn we_get_a_descriptive_error_for_unsupported_arrow_types() {
        let err = ColumnType::try_from(DataType::Float64).unwrap_err();
        assert_eq!(err, "Unsupported arrow data type Float64");
    }

    #[test]
    #[should_panic(expected = "Cannot convert Scalar type to arrow type")]
    fn we_cannot_convert_scalar_to_arrow() {
        let _ = DataType::from(&ColumnType::Scalar);
    }
}