proof_of_sql/base/arrow/
owned_and_arrow_conversions.rs

//! This module provides `From` and `TryFrom` implementations to go between arrow and owned types.
//! The mapping is as follows:
//! `OwnedColumn` <-> `Array/ArrayRef`
//! `OwnedTable` <-> `RecordBatch`
//! `Boolean` <-> `Boolean`
//! `Uint8` <-> `UInt8`
//! `TinyInt` <-> `Int8`
//! `SmallInt` <-> `Int16`
//! `Int` <-> `Int32`
//! `BigInt` <-> `Int64`
//! `VarChar` <-> `Utf8/String`
//! `VarBinary` <-> `Binary`
//! `Int128` <-> `Decimal128(38,0)`
//! `Decimal75` <-> `Decimal256(precision, scale)` with `precision <= 75`
//! `TimestampTZ` <-> `Timestamp(unit, timezone)`
//!
//! Note: this converts `Int128` values to `Decimal128(38,0)`, which are backed by `i128`.
//! This is because there is no `Int128` type in Arrow.
//! This does not check that the values are less than 39 digits.
//! However, the actual arrow backing `i128` is the correct value.
15use super::scalar_and_i256_conversions::{convert_i256_to_scalar, convert_scalar_to_i256};
16use crate::base::{
17    database::{OwnedColumn, OwnedTable, OwnedTableError},
18    map::IndexMap,
19    math::decimal::Precision,
20    posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestampError},
21    scalar::Scalar,
22};
23use alloc::sync::Arc;
24use arrow::{
25    array::{
26        ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, Int16Array,
27        Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray,
28        TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt8Array,
29    },
30    datatypes::{i256, DataType, Schema, SchemaRef, TimeUnit as ArrowTimeUnit},
31    error::ArrowError,
32    record_batch::RecordBatch,
33};
34use snafu::Snafu;
35use sqlparser::ast::Ident;
36
37#[derive(Snafu, Debug)]
38#[non_exhaustive]
39/// Errors caused by conversions between Arrow and owned types.
40pub enum OwnedArrowConversionError {
41    /// This error occurs when trying to convert from an unsupported arrow type.
42    #[snafu(display(
43        "unsupported type: attempted conversion from ArrayRef of type {datatype} to OwnedColumn"
44    ))]
45    UnsupportedType {
46        /// The unsupported datatype
47        datatype: DataType,
48    },
49    /// This error occurs when trying to convert from a record batch with duplicate idents(e.g. `"a"` and `"A"`).
50    #[snafu(display("conversion resulted in duplicate idents"))]
51    DuplicateIdents,
52    /// This error occurs when creating an owned table fails, which should only occur when there are zero columns.
53    #[snafu(transparent)]
54    InvalidTable {
55        /// The underlying source error
56        source: OwnedTableError,
57    },
58    /// This error occurs when trying to convert from an Arrow array with nulls.
59    #[snafu(display("null values are not supported in OwnedColumn yet"))]
60    NullNotSupportedYet,
61    /// Using `TimeError` to handle all time-related errors
62    #[snafu(transparent)]
63    TimestampConversionError {
64        /// The underlying source error
65        source: PoSQLTimestampError,
66    },
67}
68
69/// # Panics
70///
71/// Will panic if setting precision and scale fails when converting `OwnedColumn::Int128`.
72/// Will panic if setting precision and scale fails when converting `OwnedColumn::Decimal75`.
73/// Will panic if trying to convert `OwnedColumn::Scalar`, as this conversion is not implemented
74impl<S: Scalar> From<OwnedColumn<S>> for ArrayRef {
75    fn from(value: OwnedColumn<S>) -> Self {
76        match value {
77            OwnedColumn::Boolean(col) => Arc::new(BooleanArray::from(col)),
78            OwnedColumn::Uint8(col) => Arc::new(UInt8Array::from(col)),
79            OwnedColumn::TinyInt(col) => Arc::new(Int8Array::from(col)),
80            OwnedColumn::SmallInt(col) => Arc::new(Int16Array::from(col)),
81            OwnedColumn::Int(col) => Arc::new(Int32Array::from(col)),
82            OwnedColumn::BigInt(col) => Arc::new(Int64Array::from(col)),
83            OwnedColumn::Int128(col) => Arc::new(
84                Decimal128Array::from(col)
85                    .with_precision_and_scale(38, 0)
86                    .unwrap(),
87            ),
88            OwnedColumn::Decimal75(precision, scale, col) => {
89                let converted_col: Vec<i256> = col.iter().map(convert_scalar_to_i256).collect();
90
91                Arc::new(
92                    Decimal256Array::from(converted_col)
93                        .with_precision_and_scale(precision.value(), scale)
94                        .unwrap(),
95                )
96            }
97            OwnedColumn::Scalar(_) => unimplemented!("Cannot convert Scalar type to arrow type"),
98            OwnedColumn::VarChar(col) => Arc::new(StringArray::from(col)),
99            OwnedColumn::VarBinary(col) => {
100                Arc::new(BinaryArray::from_iter_values(col.iter().map(Vec::as_slice)))
101            }
102            OwnedColumn::TimestampTZ(time_unit, _, col) => match time_unit {
103                PoSQLTimeUnit::Second => Arc::new(TimestampSecondArray::from(col)),
104                PoSQLTimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(col)),
105                PoSQLTimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(col)),
106                PoSQLTimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(col)),
107            },
108        }
109    }
110}
111
112impl<S: Scalar> TryFrom<OwnedTable<S>> for RecordBatch {
113    type Error = ArrowError;
114    fn try_from(value: OwnedTable<S>) -> Result<Self, Self::Error> {
115        if value.is_empty() {
116            Ok(RecordBatch::new_empty(SchemaRef::new(Schema::empty())))
117        } else {
118            RecordBatch::try_from_iter(
119                value
120                    .into_inner()
121                    .into_iter()
122                    .map(|(identifier, owned_column)| {
123                        (identifier.value, ArrayRef::from(owned_column))
124                    }),
125            )
126        }
127    }
128}
129
130impl<S: Scalar> TryFrom<ArrayRef> for OwnedColumn<S> {
131    type Error = OwnedArrowConversionError;
132    fn try_from(value: ArrayRef) -> Result<Self, Self::Error> {
133        Self::try_from(&value)
134    }
135}
136impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
137    type Error = OwnedArrowConversionError;
138
139    #[expect(clippy::too_many_lines)]
140    /// # Panics
141    ///
142    /// Will panic if downcasting fails for the following types:
143    /// - `BooleanArray` when converting from `DataType::Boolean`.
144    /// - `Int16Array` when converting from `DataType::Int16`.
145    /// - `Int32Array` when converting from `DataType::Int32`.
146    /// - `Int64Array` when converting from `DataType::Int64`.
147    /// - `Decimal128Array` when converting from `DataType::Decimal128(38, 0)`.
148    /// - `Decimal256Array` when converting from `DataType::Decimal256` if precision is less than or equal to 75.
149    /// - `StringArray` when converting from `DataType::Utf8`.
150    fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
151        match &value.data_type() {
152            // Arrow uses a bit-packed representation for booleans.
153            // Hence we need to unpack the bits to get the actual boolean values.
154            DataType::Boolean => Ok(Self::Boolean(
155                value
156                    .as_any()
157                    .downcast_ref::<BooleanArray>()
158                    .unwrap()
159                    .iter()
160                    .collect::<Option<Vec<bool>>>()
161                    .ok_or(OwnedArrowConversionError::NullNotSupportedYet)?,
162            )),
163            DataType::UInt8 => Ok(Self::Uint8(
164                value
165                    .as_any()
166                    .downcast_ref::<UInt8Array>()
167                    .unwrap()
168                    .values()
169                    .to_vec(),
170            )),
171            DataType::Int8 => Ok(Self::TinyInt(
172                value
173                    .as_any()
174                    .downcast_ref::<Int8Array>()
175                    .unwrap()
176                    .values()
177                    .to_vec(),
178            )),
179            DataType::Int16 => Ok(Self::SmallInt(
180                value
181                    .as_any()
182                    .downcast_ref::<Int16Array>()
183                    .unwrap()
184                    .values()
185                    .to_vec(),
186            )),
187            DataType::Int32 => Ok(Self::Int(
188                value
189                    .as_any()
190                    .downcast_ref::<Int32Array>()
191                    .unwrap()
192                    .values()
193                    .to_vec(),
194            )),
195            DataType::Int64 => Ok(Self::BigInt(
196                value
197                    .as_any()
198                    .downcast_ref::<Int64Array>()
199                    .unwrap()
200                    .values()
201                    .to_vec(),
202            )),
203            DataType::Decimal128(38, 0) => Ok(Self::Int128(
204                value
205                    .as_any()
206                    .downcast_ref::<Decimal128Array>()
207                    .unwrap()
208                    .values()
209                    .to_vec(),
210            )),
211            DataType::Decimal256(precision, scale) if *precision <= 75 => Ok(Self::Decimal75(
212                Precision::new(*precision).expect("precision is less than 76"),
213                *scale,
214                value
215                    .as_any()
216                    .downcast_ref::<Decimal256Array>()
217                    .unwrap()
218                    .values()
219                    .iter()
220                    .map(convert_i256_to_scalar)
221                    .map(Option::unwrap)
222                    .collect(),
223            )),
224            DataType::Utf8 => Ok(Self::VarChar(
225                value
226                    .as_any()
227                    .downcast_ref::<StringArray>()
228                    .unwrap()
229                    .iter()
230                    .map(|s| s.unwrap().to_string())
231                    .collect(),
232            )),
233            DataType::Binary => Ok(Self::VarBinary(
234                value
235                    .as_any()
236                    .downcast_ref::<BinaryArray>()
237                    .unwrap()
238                    .iter()
239                    .map(|s| s.map(<[u8]>::to_vec).unwrap())
240                    .collect(),
241            )),
242            DataType::Timestamp(time_unit, timezone) => match time_unit {
243                ArrowTimeUnit::Second => {
244                    let array = value
245                        .as_any()
246                        .downcast_ref::<TimestampSecondArray>()
247                        .expect(
248                            "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
249                        );
250                    let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
251                    Ok(OwnedColumn::TimestampTZ(
252                        PoSQLTimeUnit::Second,
253                        PoSQLTimeZone::try_from(timezone)?,
254                        timestamps,
255                    ))
256                }
257                ArrowTimeUnit::Millisecond => {
258                    let array = value
259                        .as_any()
260                        .downcast_ref::<TimestampMillisecondArray>()
261                        .expect(
262                            "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
263                        );
264                    let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
265                    Ok(OwnedColumn::TimestampTZ(
266                        PoSQLTimeUnit::Millisecond,
267                        PoSQLTimeZone::try_from(timezone)?,
268                        timestamps,
269                    ))
270                }
271                ArrowTimeUnit::Microsecond => {
272                    let array = value
273                        .as_any()
274                        .downcast_ref::<TimestampMicrosecondArray>()
275                        .expect(
276                            "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
277                        );
278                    let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
279                    Ok(OwnedColumn::TimestampTZ(
280                        PoSQLTimeUnit::Microsecond,
281                        PoSQLTimeZone::try_from(timezone)?,
282                        timestamps,
283                    ))
284                }
285                ArrowTimeUnit::Nanosecond => {
286                    let array = value
287                        .as_any()
288                        .downcast_ref::<TimestampNanosecondArray>()
289                        .expect(
290                            "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
291                        );
292                    let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
293                    Ok(OwnedColumn::TimestampTZ(
294                        PoSQLTimeUnit::Nanosecond,
295                        PoSQLTimeZone::try_from(timezone)?,
296                        timestamps,
297                    ))
298                }
299            },
300            &data_type => Err(OwnedArrowConversionError::UnsupportedType {
301                datatype: data_type.clone(),
302            }),
303        }
304    }
305}
306
307impl<S: Scalar> TryFrom<RecordBatch> for OwnedTable<S> {
308    type Error = OwnedArrowConversionError;
309    fn try_from(value: RecordBatch) -> Result<Self, Self::Error> {
310        let num_columns = value.num_columns();
311        let table: Result<IndexMap<_, _>, Self::Error> = value
312            .schema()
313            .fields()
314            .iter()
315            .zip(value.columns())
316            .map(|(field, array_ref)| {
317                let owned_column = OwnedColumn::try_from(array_ref)?;
318                let identifier = Ident::new(field.name());
319                Ok((identifier, owned_column))
320            })
321            .collect();
322        let owned_table = Self::try_new(table?)?;
323        if num_columns == owned_table.num_columns() {
324            Ok(owned_table)
325        } else {
326            Err(OwnedArrowConversionError::DuplicateIdents)
327        }
328    }
329}