1use super::scalar_and_i256_conversions::{convert_i256_to_scalar, convert_scalar_to_i256};
16use crate::base::{
17 database::{OwnedColumn, OwnedTable, OwnedTableError},
18 map::IndexMap,
19 math::decimal::Precision,
20 scalar::Scalar,
21};
22use alloc::sync::Arc;
23use arrow::{
24 array::{
25 ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, Int16Array,
26 Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray,
27 TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt8Array,
28 },
29 datatypes::{i256, DataType, Schema, SchemaRef, TimeUnit as ArrowTimeUnit},
30 error::ArrowError,
31 record_batch::RecordBatch,
32};
33use proof_of_sql_parser::{
34 posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestampError},
35 ParseError,
36};
37use snafu::Snafu;
38use sqlparser::ast::Ident;
39
/// Errors produced when converting Arrow data ([`ArrayRef`], [`RecordBatch`]) into
/// [`OwnedColumn`] / [`OwnedTable`].
#[derive(Snafu, Debug)]
#[non_exhaustive]
pub enum OwnedArrowConversionError {
    /// The Arrow array's data type has no corresponding [`OwnedColumn`] variant.
    #[snafu(display(
        "unsupported type: attempted conversion from ArrayRef of type {datatype} to OwnedColumn"
    ))]
    UnsupportedType {
        /// The Arrow data type that could not be converted.
        datatype: DataType,
    },
    /// After conversion, the table has fewer columns than the source record batch, i.e. two or
    /// more fields mapped to the same identifier.
    #[snafu(display("conversion resulted in duplicate idents"))]
    DuplicateIdents,
    /// Transparent wrapper around a [`ParseError`] from `proof_of_sql_parser`.
    #[snafu(transparent)]
    FieldParseFail {
        /// The underlying parser error.
        source: ParseError,
    },
    /// Transparent wrapper around an [`OwnedTableError`] raised while assembling the
    /// converted columns into an [`OwnedTable`].
    #[snafu(transparent)]
    InvalidTable {
        /// The underlying table-construction error.
        source: OwnedTableError,
    },
    /// The Arrow array contains null slots, which [`OwnedColumn`] cannot represent.
    #[snafu(display("null values are not supported in OwnedColumn yet"))]
    NullNotSupportedYet,
    /// Transparent wrapper around a [`PoSQLTimestampError`], e.g. from converting an Arrow
    /// timestamp's time zone into a `PoSQLTimeZone`.
    #[snafu(transparent)]
    TimestampConversionError {
        /// The underlying timestamp error.
        source: PoSQLTimestampError,
    },
}
77
78impl<S: Scalar> From<OwnedColumn<S>> for ArrayRef {
84 fn from(value: OwnedColumn<S>) -> Self {
85 match value {
86 OwnedColumn::Boolean(col) => Arc::new(BooleanArray::from(col)),
87 OwnedColumn::Uint8(col) => Arc::new(UInt8Array::from(col)),
88 OwnedColumn::TinyInt(col) => Arc::new(Int8Array::from(col)),
89 OwnedColumn::SmallInt(col) => Arc::new(Int16Array::from(col)),
90 OwnedColumn::Int(col) => Arc::new(Int32Array::from(col)),
91 OwnedColumn::BigInt(col) => Arc::new(Int64Array::from(col)),
92 OwnedColumn::Int128(col) => Arc::new(
93 Decimal128Array::from(col)
94 .with_precision_and_scale(38, 0)
95 .unwrap(),
96 ),
97 OwnedColumn::Decimal75(precision, scale, col) => {
98 let converted_col: Vec<i256> = col.iter().map(convert_scalar_to_i256).collect();
99
100 Arc::new(
101 Decimal256Array::from(converted_col)
102 .with_precision_and_scale(precision.value(), scale)
103 .unwrap(),
104 )
105 }
106 OwnedColumn::Scalar(_) => unimplemented!("Cannot convert Scalar type to arrow type"),
107 OwnedColumn::VarChar(col) => Arc::new(StringArray::from(col)),
108 OwnedColumn::VarBinary(col) => {
109 Arc::new(BinaryArray::from_iter_values(col.iter().map(Vec::as_slice)))
110 }
111 OwnedColumn::TimestampTZ(time_unit, _, col) => match time_unit {
112 PoSQLTimeUnit::Second => Arc::new(TimestampSecondArray::from(col)),
113 PoSQLTimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(col)),
114 PoSQLTimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(col)),
115 PoSQLTimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(col)),
116 },
117 }
118 }
119}
120
121impl<S: Scalar> TryFrom<OwnedTable<S>> for RecordBatch {
122 type Error = ArrowError;
123 fn try_from(value: OwnedTable<S>) -> Result<Self, Self::Error> {
124 if value.is_empty() {
125 Ok(RecordBatch::new_empty(SchemaRef::new(Schema::empty())))
126 } else {
127 RecordBatch::try_from_iter(
128 value
129 .into_inner()
130 .into_iter()
131 .map(|(identifier, owned_column)| {
132 (identifier.value, ArrayRef::from(owned_column))
133 }),
134 )
135 }
136 }
137}
138
139impl<S: Scalar> TryFrom<ArrayRef> for OwnedColumn<S> {
140 type Error = OwnedArrowConversionError;
141 fn try_from(value: ArrayRef) -> Result<Self, Self::Error> {
142 Self::try_from(&value)
143 }
144}
145impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
146 type Error = OwnedArrowConversionError;
147
148 #[allow(clippy::too_many_lines)]
149 fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
160 match &value.data_type() {
161 DataType::Boolean => Ok(Self::Boolean(
164 value
165 .as_any()
166 .downcast_ref::<BooleanArray>()
167 .unwrap()
168 .iter()
169 .collect::<Option<Vec<bool>>>()
170 .ok_or(OwnedArrowConversionError::NullNotSupportedYet)?,
171 )),
172 DataType::UInt8 => Ok(Self::Uint8(
173 value
174 .as_any()
175 .downcast_ref::<UInt8Array>()
176 .unwrap()
177 .values()
178 .to_vec(),
179 )),
180 DataType::Int8 => Ok(Self::TinyInt(
181 value
182 .as_any()
183 .downcast_ref::<Int8Array>()
184 .unwrap()
185 .values()
186 .to_vec(),
187 )),
188 DataType::Int16 => Ok(Self::SmallInt(
189 value
190 .as_any()
191 .downcast_ref::<Int16Array>()
192 .unwrap()
193 .values()
194 .to_vec(),
195 )),
196 DataType::Int32 => Ok(Self::Int(
197 value
198 .as_any()
199 .downcast_ref::<Int32Array>()
200 .unwrap()
201 .values()
202 .to_vec(),
203 )),
204 DataType::Int64 => Ok(Self::BigInt(
205 value
206 .as_any()
207 .downcast_ref::<Int64Array>()
208 .unwrap()
209 .values()
210 .to_vec(),
211 )),
212 DataType::Decimal128(38, 0) => Ok(Self::Int128(
213 value
214 .as_any()
215 .downcast_ref::<Decimal128Array>()
216 .unwrap()
217 .values()
218 .to_vec(),
219 )),
220 DataType::Decimal256(precision, scale) if *precision <= 75 => Ok(Self::Decimal75(
221 Precision::new(*precision).expect("precision is less than 76"),
222 *scale,
223 value
224 .as_any()
225 .downcast_ref::<Decimal256Array>()
226 .unwrap()
227 .values()
228 .iter()
229 .map(convert_i256_to_scalar)
230 .map(Option::unwrap)
231 .collect(),
232 )),
233 DataType::Utf8 => Ok(Self::VarChar(
234 value
235 .as_any()
236 .downcast_ref::<StringArray>()
237 .unwrap()
238 .iter()
239 .map(|s| s.unwrap().to_string())
240 .collect(),
241 )),
242 DataType::Binary => Ok(Self::VarBinary(
243 value
244 .as_any()
245 .downcast_ref::<BinaryArray>()
246 .unwrap()
247 .iter()
248 .map(|s| s.map(<[u8]>::to_vec).unwrap())
249 .collect(),
250 )),
251 DataType::Timestamp(time_unit, timezone) => match time_unit {
252 ArrowTimeUnit::Second => {
253 let array = value
254 .as_any()
255 .downcast_ref::<TimestampSecondArray>()
256 .expect(
257 "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
258 );
259 let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
260 Ok(OwnedColumn::TimestampTZ(
261 PoSQLTimeUnit::Second,
262 PoSQLTimeZone::try_from(timezone)?,
263 timestamps,
264 ))
265 }
266 ArrowTimeUnit::Millisecond => {
267 let array = value
268 .as_any()
269 .downcast_ref::<TimestampMillisecondArray>()
270 .expect(
271 "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
272 );
273 let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
274 Ok(OwnedColumn::TimestampTZ(
275 PoSQLTimeUnit::Millisecond,
276 PoSQLTimeZone::try_from(timezone)?,
277 timestamps,
278 ))
279 }
280 ArrowTimeUnit::Microsecond => {
281 let array = value
282 .as_any()
283 .downcast_ref::<TimestampMicrosecondArray>()
284 .expect(
285 "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
286 );
287 let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
288 Ok(OwnedColumn::TimestampTZ(
289 PoSQLTimeUnit::Microsecond,
290 PoSQLTimeZone::try_from(timezone)?,
291 timestamps,
292 ))
293 }
294 ArrowTimeUnit::Nanosecond => {
295 let array = value
296 .as_any()
297 .downcast_ref::<TimestampNanosecondArray>()
298 .expect(
299 "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
300 );
301 let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
302 Ok(OwnedColumn::TimestampTZ(
303 PoSQLTimeUnit::Nanosecond,
304 PoSQLTimeZone::try_from(timezone)?,
305 timestamps,
306 ))
307 }
308 },
309 &data_type => Err(OwnedArrowConversionError::UnsupportedType {
310 datatype: data_type.clone(),
311 }),
312 }
313 }
314}
315
316impl<S: Scalar> TryFrom<RecordBatch> for OwnedTable<S> {
317 type Error = OwnedArrowConversionError;
318 fn try_from(value: RecordBatch) -> Result<Self, Self::Error> {
319 let num_columns = value.num_columns();
320 let table: Result<IndexMap<_, _>, Self::Error> = value
321 .schema()
322 .fields()
323 .iter()
324 .zip(value.columns())
325 .map(|(field, array_ref)| {
326 let owned_column = OwnedColumn::try_from(array_ref)?;
327 let identifier = Ident::new(field.name());
328 Ok((identifier, owned_column))
329 })
330 .collect();
331 let owned_table = Self::try_new(table?)?;
332 if num_columns == owned_table.num_columns() {
333 Ok(owned_table)
334 } else {
335 Err(OwnedArrowConversionError::DuplicateIdents)
336 }
337 }
338}