use super::scalar_and_i256_conversions::convert_scalar_to_i256;
use crate::base::{
database::{
scalar_and_i256_conversions::convert_i256_to_scalar, OwnedColumn, OwnedTable,
OwnedTableError,
},
math::decimal::Precision,
scalar::Scalar,
};
use arrow::{
array::{
ArrayRef, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array,
Int64Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
},
datatypes::{i256, DataType, Schema, SchemaRef, TimeUnit as ArrowTimeUnit},
error::ArrowError,
record_batch::RecordBatch,
};
use indexmap::IndexMap;
use proof_of_sql_parser::{
posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestampError},
Identifier, ParseError,
};
use std::sync::Arc;
use thiserror::Error;
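/// Errors caused by conversions between Arrow and owned types.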
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum OwnedArrowConversionError {
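    /// This error occurs when trying to convert from an unsupported arrow type.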
#[error("unsupported type: attempted conversion from ArrayRef of type {0} to OwnedColumn")]
UnsupportedType(DataType),
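    /// This error occurs when trying to convert from a record batch with duplicate identifiers.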
#[error("conversion resulted in duplicate identifiers")]
DuplicateIdentifiers,
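    /// This error occurs when parsing a field name into an [`Identifier`] fails.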
#[error(transparent)]
FieldParseFail(#[from] ParseError),
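    /// This error occurs when creating the [`OwnedTable`] fails.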
#[error(transparent)]
InvalidTable(#[from] OwnedTableError),
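    /// This error occurs when the arrow array contains null values, which [`OwnedColumn`] does not support yet.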
#[error("null values are not supported in OwnedColumn yet")]
NullNotSupportedYet,
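    /// This error occurs when a timestamp or timezone conversion fails.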
#[error(transparent)]
TimestampConversionError(#[from] PoSQLTimestampError),
}
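/// Converts an [`OwnedColumn`] into the equivalent Arrow array.
///
/// # Panics
/// Panics on [`OwnedColumn::Scalar`], which has no Arrow counterpart.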
impl<S: Scalar> From<OwnedColumn<S>> for ArrayRef {
fn from(value: OwnedColumn<S>) -> Self {
match value {
OwnedColumn::Boolean(col) => Arc::new(BooleanArray::from(col)),
OwnedColumn::SmallInt(col) => Arc::new(Int16Array::from(col)),
OwnedColumn::Int(col) => Arc::new(Int32Array::from(col)),
OwnedColumn::BigInt(col) => Arc::new(Int64Array::from(col)),
OwnedColumn::Int128(col) => Arc::new(
Decimal128Array::from(col)
.with_precision_and_scale(38, 0)
.unwrap(),
),
OwnedColumn::Decimal75(precision, scale, col) => {
let converted_col: Vec<i256> = col.iter().map(convert_scalar_to_i256).collect();
Arc::new(
Decimal256Array::from(converted_col)
.with_precision_and_scale(precision.value(), scale)
.unwrap(),
)
}
OwnedColumn::Scalar(_) => unimplemented!("Cannot convert Scalar type to arrow type"),
OwnedColumn::VarChar(col) => Arc::new(StringArray::from(col)),
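            // NOTE: the column's timezone is discarded here; the resulting
            // timestamp arrays carry no timezone metadata.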
OwnedColumn::TimestampTZ(time_unit, _, col) => match time_unit {
PoSQLTimeUnit::Second => Arc::new(TimestampSecondArray::from(col)),
PoSQLTimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(col)),
PoSQLTimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(col)),
PoSQLTimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(col)),
},
}
}
}
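/// Converts an [`OwnedTable`] into a [`RecordBatch`], preserving column order.
/// An empty table becomes an empty [`RecordBatch`] with an empty schema.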
impl<S: Scalar> TryFrom<OwnedTable<S>> for RecordBatch {
type Error = ArrowError;
fn try_from(value: OwnedTable<S>) -> Result<Self, Self::Error> {
if value.is_empty() {
Ok(RecordBatch::new_empty(SchemaRef::new(Schema::empty())))
} else {
RecordBatch::try_from_iter(
value
.into_inner()
.into_iter()
.map(|(identifier, owned_column)| (identifier, ArrayRef::from(owned_column))),
)
}
}
}
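/// Converts an Arrow array into an [`OwnedColumn`] by delegating to the
/// implementation for `&ArrayRef` below.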
impl<S: Scalar> TryFrom<ArrayRef> for OwnedColumn<S> {
type Error = OwnedArrowConversionError;
fn try_from(value: ArrayRef) -> Result<Self, Self::Error> {
Self::try_from(&value)
}
}
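/// Converts an Arrow array into an [`OwnedColumn`], rejecting unsupported
/// data types. Note that the primitive numeric and timestamp branches read
/// the raw value buffer directly, so null slots in those arrays pass through
/// as their underlying buffer values rather than producing an error.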
impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
type Error = OwnedArrowConversionError;
fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
        match value.data_type() {
DataType::Boolean => Ok(Self::Boolean(
value
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.iter()
.collect::<Option<Vec<bool>>>()
.ok_or(OwnedArrowConversionError::NullNotSupportedYet)?,
)),
DataType::Int16 => Ok(Self::SmallInt(
value
.as_any()
.downcast_ref::<Int16Array>()
.unwrap()
.values()
.to_vec(),
)),
DataType::Int32 => Ok(Self::Int(
value
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.values()
.to_vec(),
)),
DataType::Int64 => Ok(Self::BigInt(
value
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
.values()
.to_vec(),
)),
DataType::Decimal128(38, 0) => Ok(Self::Int128(
value
.as_any()
.downcast_ref::<Decimal128Array>()
.unwrap()
.values()
.to_vec(),
)),
DataType::Decimal256(precision, scale) if *precision <= 75 => Ok(Self::Decimal75(
Precision::new(*precision).expect("precision is less than 76"),
*scale,
value
.as_any()
.downcast_ref::<Decimal256Array>()
.unwrap()
.values()
.iter()
.map(convert_i256_to_scalar)
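                    // Unwrap is safe assuming the array honors its declared
                    // precision (at most 75 digits), so each i256 value fits
                    // in the scalar field.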
.map(Option::unwrap)
.collect(),
)),
            DataType::Utf8 => Ok(Self::VarChar(
                value
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .unwrap()
                    .iter()
                    .map(|s| s.map(String::from))
                    .collect::<Option<Vec<String>>>()
                    .ok_or(OwnedArrowConversionError::NullNotSupportedYet)?,
            )),
            DataType::Timestamp(time_unit, timezone) => {
                let timezone = PoSQLTimeZone::try_from(timezone)?;
                // Every Arrow timestamp unit stores its values as i64, so the
                // four cases differ only in the concrete array type that the
                // downcast targets.
                let (time_unit, timestamps) = match time_unit {
                    ArrowTimeUnit::Second => (
                        PoSQLTimeUnit::Second,
                        value
                            .as_any()
                            .downcast_ref::<TimestampSecondArray>()
                            .expect("array type matches the Timestamp(Second, _) data type")
                            .values()
                            .to_vec(),
                    ),
                    ArrowTimeUnit::Millisecond => (
                        PoSQLTimeUnit::Millisecond,
                        value
                            .as_any()
                            .downcast_ref::<TimestampMillisecondArray>()
                            .expect("array type matches the Timestamp(Millisecond, _) data type")
                            .values()
                            .to_vec(),
                    ),
                    ArrowTimeUnit::Microsecond => (
                        PoSQLTimeUnit::Microsecond,
                        value
                            .as_any()
                            .downcast_ref::<TimestampMicrosecondArray>()
                            .expect("array type matches the Timestamp(Microsecond, _) data type")
                            .values()
                            .to_vec(),
                    ),
                    ArrowTimeUnit::Nanosecond => (
                        PoSQLTimeUnit::Nanosecond,
                        value
                            .as_any()
                            .downcast_ref::<TimestampNanosecondArray>()
                            .expect("array type matches the Timestamp(Nanosecond, _) data type")
                            .values()
                            .to_vec(),
                    ),
                };
                Ok(OwnedColumn::TimestampTZ(time_unit, timezone, timestamps))
            }
            data_type => Err(OwnedArrowConversionError::UnsupportedType(
                data_type.clone(),
            )),
}
}
}
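/// Converts a [`RecordBatch`] into an [`OwnedTable`], failing if a column
/// type is unsupported, a field name is not a valid [`Identifier`], or two
/// field names parse to the same identifier.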
impl<S: Scalar> TryFrom<RecordBatch> for OwnedTable<S> {
type Error = OwnedArrowConversionError;
fn try_from(value: RecordBatch) -> Result<Self, Self::Error> {
let num_columns = value.num_columns();
let table: Result<IndexMap<_, _>, Self::Error> = value
.schema()
.fields()
.iter()
.zip(value.columns())
.map(|(field, array_ref)| {
let owned_column = OwnedColumn::try_from(array_ref)?;
                let identifier = Identifier::try_new(field.name())?;
                Ok((identifier, owned_column))
})
.collect();
let owned_table = Self::try_new(table?)?;
if num_columns == owned_table.num_columns() {
Ok(owned_table)
} else {
Err(OwnedArrowConversionError::DuplicateIdentifiers)
}
}
}
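#[cfg(test)]
mod owned_and_arrow_conversion_round_trip_tests {
    //! Minimal round-trip sketches for the conversions above. They assume the
    //! crate exposes a concrete `Scalar` implementation named
    //! `Curve25519Scalar` and that `OwnedColumn` derives `Clone`, `Debug`,
    //! and `PartialEq`; substitute whichever scalar type this crate actually
    //! provides.
    use super::*;
    use crate::base::scalar::Curve25519Scalar;

    #[test]
    fn we_can_round_trip_a_bigint_column_through_arrow() {
        // OwnedColumn -> ArrayRef -> OwnedColumn should be lossless.
        let column: OwnedColumn<Curve25519Scalar> = OwnedColumn::BigInt(vec![1, 2, 3]);
        let array = ArrayRef::from(column.clone());
        assert_eq!(OwnedColumn::try_from(array).unwrap(), column);
    }

    #[test]
    fn we_reject_boolean_arrays_containing_nulls() {
        // Nulls are not representable in OwnedColumn, so the conversion
        // should surface the dedicated error variant.
        let array: ArrayRef = Arc::new(BooleanArray::from(vec![Some(true), None]));
        assert!(matches!(
            OwnedColumn::<Curve25519Scalar>::try_from(&array),
            Err(OwnedArrowConversionError::NullNotSupportedYet)
        ));
    }
}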