proof_of_sql/base/arrow/
owned_and_arrow_conversions.rs

//! This module provides `From` and `TryFrom` implementations to go between arrow and owned types.
//! The mapping is as follows:
//! `OwnedColumn` <-> `Array/ArrayRef`
//! `OwnedTable` <-> `RecordBatch`
//! `Boolean` <-> `Boolean`
//! `BigInt` <-> `Int64`
//! `VarChar` <-> `Utf8/String`
//! `Int128` <-> `Decimal128(38,0)`
//! `Decimal75` <-> `Decimal256` (the values are held as the scalar type `S`)
//!
//! Note: this converts `Int128` values to `Decimal128(38,0)`, which are backed by `i128`.
//! This is because there is no `Int128` type in Arrow.
//! This does not check that the values have fewer than 39 digits.
//! However, the `i128` backing the Arrow array still holds the correct value.
15use super::scalar_and_i256_conversions::{convert_i256_to_scalar, convert_scalar_to_i256};
16use crate::base::{
17    database::{OwnedColumn, OwnedTable, OwnedTableError},
18    map::IndexMap,
19    math::decimal::Precision,
20    scalar::Scalar,
21};
22use alloc::sync::Arc;
23use arrow::{
24    array::{
25        ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, Int16Array,
26        Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray,
27        TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt8Array,
28    },
29    datatypes::{i256, DataType, Schema, SchemaRef, TimeUnit as ArrowTimeUnit},
30    error::ArrowError,
31    record_batch::RecordBatch,
32};
33use proof_of_sql_parser::{
34    posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestampError},
35    ParseError,
36};
37use snafu::Snafu;
38use sqlparser::ast::Ident;
39
40#[derive(Snafu, Debug)]
41#[non_exhaustive]
42/// Errors caused by conversions between Arrow and owned types.
43pub enum OwnedArrowConversionError {
44    /// This error occurs when trying to convert from an unsupported arrow type.
45    #[snafu(display(
46        "unsupported type: attempted conversion from ArrayRef of type {datatype} to OwnedColumn"
47    ))]
48    UnsupportedType {
49        /// The unsupported datatype
50        datatype: DataType,
51    },
52    /// This error occurs when trying to convert from a record batch with duplicate idents(e.g. `"a"` and `"A"`).
53    #[snafu(display("conversion resulted in duplicate idents"))]
54    DuplicateIdents,
55    /// This error occurs when converting from a record batch name to an idents fails. (Which may be impossible.)
56    #[snafu(transparent)]
57    FieldParseFail {
58        /// The underlying source error
59        source: ParseError,
60    },
61    /// This error occurs when creating an owned table fails, which should only occur when there are zero columns.
62    #[snafu(transparent)]
63    InvalidTable {
64        /// The underlying source error
65        source: OwnedTableError,
66    },
67    /// This error occurs when trying to convert from an Arrow array with nulls.
68    #[snafu(display("null values are not supported in OwnedColumn yet"))]
69    NullNotSupportedYet,
70    /// Using `TimeError` to handle all time-related errors
71    #[snafu(transparent)]
72    TimestampConversionError {
73        /// The underlying source error
74        source: PoSQLTimestampError,
75    },
76}
77
78/// # Panics
79///
80/// Will panic if setting precision and scale fails when converting `OwnedColumn::Int128`.
81/// Will panic if setting precision and scale fails when converting `OwnedColumn::Decimal75`.
82/// Will panic if trying to convert `OwnedColumn::Scalar`, as this conversion is not implemented
83impl<S: Scalar> From<OwnedColumn<S>> for ArrayRef {
84    fn from(value: OwnedColumn<S>) -> Self {
85        match value {
86            OwnedColumn::Boolean(col) => Arc::new(BooleanArray::from(col)),
87            OwnedColumn::Uint8(col) => Arc::new(UInt8Array::from(col)),
88            OwnedColumn::TinyInt(col) => Arc::new(Int8Array::from(col)),
89            OwnedColumn::SmallInt(col) => Arc::new(Int16Array::from(col)),
90            OwnedColumn::Int(col) => Arc::new(Int32Array::from(col)),
91            OwnedColumn::BigInt(col) => Arc::new(Int64Array::from(col)),
92            OwnedColumn::Int128(col) => Arc::new(
93                Decimal128Array::from(col)
94                    .with_precision_and_scale(38, 0)
95                    .unwrap(),
96            ),
97            OwnedColumn::Decimal75(precision, scale, col) => {
98                let converted_col: Vec<i256> = col.iter().map(convert_scalar_to_i256).collect();
99
100                Arc::new(
101                    Decimal256Array::from(converted_col)
102                        .with_precision_and_scale(precision.value(), scale)
103                        .unwrap(),
104                )
105            }
106            OwnedColumn::Scalar(_) => unimplemented!("Cannot convert Scalar type to arrow type"),
107            OwnedColumn::VarChar(col) => Arc::new(StringArray::from(col)),
108            OwnedColumn::VarBinary(col) => {
109                Arc::new(BinaryArray::from_iter_values(col.iter().map(Vec::as_slice)))
110            }
111            OwnedColumn::TimestampTZ(time_unit, _, col) => match time_unit {
112                PoSQLTimeUnit::Second => Arc::new(TimestampSecondArray::from(col)),
113                PoSQLTimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(col)),
114                PoSQLTimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(col)),
115                PoSQLTimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(col)),
116            },
117        }
118    }
119}
120
121impl<S: Scalar> TryFrom<OwnedTable<S>> for RecordBatch {
122    type Error = ArrowError;
123    fn try_from(value: OwnedTable<S>) -> Result<Self, Self::Error> {
124        if value.is_empty() {
125            Ok(RecordBatch::new_empty(SchemaRef::new(Schema::empty())))
126        } else {
127            RecordBatch::try_from_iter(
128                value
129                    .into_inner()
130                    .into_iter()
131                    .map(|(identifier, owned_column)| {
132                        (identifier.value, ArrayRef::from(owned_column))
133                    }),
134            )
135        }
136    }
137}
138
139impl<S: Scalar> TryFrom<ArrayRef> for OwnedColumn<S> {
140    type Error = OwnedArrowConversionError;
141    fn try_from(value: ArrayRef) -> Result<Self, Self::Error> {
142        Self::try_from(&value)
143    }
144}
145impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
146    type Error = OwnedArrowConversionError;
147
148    #[allow(clippy::too_many_lines)]
149    /// # Panics
150    ///
151    /// Will panic if downcasting fails for the following types:
152    /// - `BooleanArray` when converting from `DataType::Boolean`.
153    /// - `Int16Array` when converting from `DataType::Int16`.
154    /// - `Int32Array` when converting from `DataType::Int32`.
155    /// - `Int64Array` when converting from `DataType::Int64`.
156    /// - `Decimal128Array` when converting from `DataType::Decimal128(38, 0)`.
157    /// - `Decimal256Array` when converting from `DataType::Decimal256` if precision is less than or equal to 75.
158    /// - `StringArray` when converting from `DataType::Utf8`.
159    fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
160        match &value.data_type() {
161            // Arrow uses a bit-packed representation for booleans.
162            // Hence we need to unpack the bits to get the actual boolean values.
163            DataType::Boolean => Ok(Self::Boolean(
164                value
165                    .as_any()
166                    .downcast_ref::<BooleanArray>()
167                    .unwrap()
168                    .iter()
169                    .collect::<Option<Vec<bool>>>()
170                    .ok_or(OwnedArrowConversionError::NullNotSupportedYet)?,
171            )),
172            DataType::UInt8 => Ok(Self::Uint8(
173                value
174                    .as_any()
175                    .downcast_ref::<UInt8Array>()
176                    .unwrap()
177                    .values()
178                    .to_vec(),
179            )),
180            DataType::Int8 => Ok(Self::TinyInt(
181                value
182                    .as_any()
183                    .downcast_ref::<Int8Array>()
184                    .unwrap()
185                    .values()
186                    .to_vec(),
187            )),
188            DataType::Int16 => Ok(Self::SmallInt(
189                value
190                    .as_any()
191                    .downcast_ref::<Int16Array>()
192                    .unwrap()
193                    .values()
194                    .to_vec(),
195            )),
196            DataType::Int32 => Ok(Self::Int(
197                value
198                    .as_any()
199                    .downcast_ref::<Int32Array>()
200                    .unwrap()
201                    .values()
202                    .to_vec(),
203            )),
204            DataType::Int64 => Ok(Self::BigInt(
205                value
206                    .as_any()
207                    .downcast_ref::<Int64Array>()
208                    .unwrap()
209                    .values()
210                    .to_vec(),
211            )),
212            DataType::Decimal128(38, 0) => Ok(Self::Int128(
213                value
214                    .as_any()
215                    .downcast_ref::<Decimal128Array>()
216                    .unwrap()
217                    .values()
218                    .to_vec(),
219            )),
220            DataType::Decimal256(precision, scale) if *precision <= 75 => Ok(Self::Decimal75(
221                Precision::new(*precision).expect("precision is less than 76"),
222                *scale,
223                value
224                    .as_any()
225                    .downcast_ref::<Decimal256Array>()
226                    .unwrap()
227                    .values()
228                    .iter()
229                    .map(convert_i256_to_scalar)
230                    .map(Option::unwrap)
231                    .collect(),
232            )),
233            DataType::Utf8 => Ok(Self::VarChar(
234                value
235                    .as_any()
236                    .downcast_ref::<StringArray>()
237                    .unwrap()
238                    .iter()
239                    .map(|s| s.unwrap().to_string())
240                    .collect(),
241            )),
242            DataType::Binary => Ok(Self::VarBinary(
243                value
244                    .as_any()
245                    .downcast_ref::<BinaryArray>()
246                    .unwrap()
247                    .iter()
248                    .map(|s| s.map(<[u8]>::to_vec).unwrap())
249                    .collect(),
250            )),
251            DataType::Timestamp(time_unit, timezone) => match time_unit {
252                ArrowTimeUnit::Second => {
253                    let array = value
254                        .as_any()
255                        .downcast_ref::<TimestampSecondArray>()
256                        .expect(
257                            "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
258                        );
259                    let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
260                    Ok(OwnedColumn::TimestampTZ(
261                        PoSQLTimeUnit::Second,
262                        PoSQLTimeZone::try_from(timezone)?,
263                        timestamps,
264                    ))
265                }
266                ArrowTimeUnit::Millisecond => {
267                    let array = value
268                        .as_any()
269                        .downcast_ref::<TimestampMillisecondArray>()
270                        .expect(
271                            "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
272                        );
273                    let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
274                    Ok(OwnedColumn::TimestampTZ(
275                        PoSQLTimeUnit::Millisecond,
276                        PoSQLTimeZone::try_from(timezone)?,
277                        timestamps,
278                    ))
279                }
280                ArrowTimeUnit::Microsecond => {
281                    let array = value
282                        .as_any()
283                        .downcast_ref::<TimestampMicrosecondArray>()
284                        .expect(
285                            "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
286                        );
287                    let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
288                    Ok(OwnedColumn::TimestampTZ(
289                        PoSQLTimeUnit::Microsecond,
290                        PoSQLTimeZone::try_from(timezone)?,
291                        timestamps,
292                    ))
293                }
294                ArrowTimeUnit::Nanosecond => {
295                    let array = value
296                        .as_any()
297                        .downcast_ref::<TimestampNanosecondArray>()
298                        .expect(
299                            "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
300                        );
301                    let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
302                    Ok(OwnedColumn::TimestampTZ(
303                        PoSQLTimeUnit::Nanosecond,
304                        PoSQLTimeZone::try_from(timezone)?,
305                        timestamps,
306                    ))
307                }
308            },
309            &data_type => Err(OwnedArrowConversionError::UnsupportedType {
310                datatype: data_type.clone(),
311            }),
312        }
313    }
314}
315
316impl<S: Scalar> TryFrom<RecordBatch> for OwnedTable<S> {
317    type Error = OwnedArrowConversionError;
318    fn try_from(value: RecordBatch) -> Result<Self, Self::Error> {
319        let num_columns = value.num_columns();
320        let table: Result<IndexMap<_, _>, Self::Error> = value
321            .schema()
322            .fields()
323            .iter()
324            .zip(value.columns())
325            .map(|(field, array_ref)| {
326                let owned_column = OwnedColumn::try_from(array_ref)?;
327                let identifier = Ident::new(field.name());
328                Ok((identifier, owned_column))
329            })
330            .collect();
331        let owned_table = Self::try_new(table?)?;
332        if num_columns == owned_table.num_columns() {
333            Ok(owned_table)
334        } else {
335            Err(OwnedArrowConversionError::DuplicateIdents)
336        }
337    }
338}