1use super::scalar_and_i256_conversions::{convert_i256_to_scalar, convert_scalar_to_i256};
16use crate::base::{
17 database::{OwnedColumn, OwnedTable, OwnedTableError},
18 map::IndexMap,
19 math::decimal::Precision,
20 posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestampError},
21 scalar::Scalar,
22};
23use alloc::sync::Arc;
24use arrow::{
25 array::{
26 ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, Int16Array,
27 Int32Array, Int64Array, Int8Array, StringArray, TimestampMicrosecondArray,
28 TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt8Array,
29 },
30 datatypes::{i256, DataType, Schema, SchemaRef, TimeUnit as ArrowTimeUnit},
31 error::ArrowError,
32 record_batch::RecordBatch,
33};
34use snafu::Snafu;
35use sqlparser::ast::Ident;
36
/// Errors that can occur when converting Arrow arrays or record batches into
/// [`OwnedColumn`]s or [`OwnedTable`]s.
#[derive(Snafu, Debug)]
#[non_exhaustive]
pub enum OwnedArrowConversionError {
    /// The Arrow array has a data type that has no corresponding [`OwnedColumn`] variant.
    #[snafu(display(
        "unsupported type: attempted conversion from ArrayRef of type {datatype} to OwnedColumn"
    ))]
    UnsupportedType {
        /// The Arrow data type that could not be converted.
        datatype: DataType,
    },
    /// Converting a record batch produced a table with fewer columns than the batch had,
    /// i.e. two or more schema fields mapped to the same identifier.
    #[snafu(display("conversion resulted in duplicate idents"))]
    DuplicateIdents,
    /// The converted columns could not be assembled into a valid [`OwnedTable`].
    #[snafu(transparent)]
    InvalidTable {
        /// The error reported while constructing the table.
        source: OwnedTableError,
    },
    /// A null value was encountered; [`OwnedColumn`] has no representation for nulls yet.
    #[snafu(display("null values are not supported in OwnedColumn yet"))]
    NullNotSupportedYet,
    /// A timestamp-related conversion failed (e.g. converting an Arrow timestamp's time zone).
    #[snafu(transparent)]
    TimestampConversionError {
        /// The underlying timestamp error.
        source: PoSQLTimestampError,
    },
}
68
69impl<S: Scalar> From<OwnedColumn<S>> for ArrayRef {
75 fn from(value: OwnedColumn<S>) -> Self {
76 match value {
77 OwnedColumn::Boolean(col) => Arc::new(BooleanArray::from(col)),
78 OwnedColumn::Uint8(col) => Arc::new(UInt8Array::from(col)),
79 OwnedColumn::TinyInt(col) => Arc::new(Int8Array::from(col)),
80 OwnedColumn::SmallInt(col) => Arc::new(Int16Array::from(col)),
81 OwnedColumn::Int(col) => Arc::new(Int32Array::from(col)),
82 OwnedColumn::BigInt(col) => Arc::new(Int64Array::from(col)),
83 OwnedColumn::Int128(col) => Arc::new(
84 Decimal128Array::from(col)
85 .with_precision_and_scale(38, 0)
86 .unwrap(),
87 ),
88 OwnedColumn::Decimal75(precision, scale, col) => {
89 let converted_col: Vec<i256> = col.iter().map(convert_scalar_to_i256).collect();
90
91 Arc::new(
92 Decimal256Array::from(converted_col)
93 .with_precision_and_scale(precision.value(), scale)
94 .unwrap(),
95 )
96 }
97 OwnedColumn::Scalar(_) => unimplemented!("Cannot convert Scalar type to arrow type"),
98 OwnedColumn::VarChar(col) => Arc::new(StringArray::from(col)),
99 OwnedColumn::VarBinary(col) => {
100 Arc::new(BinaryArray::from_iter_values(col.iter().map(Vec::as_slice)))
101 }
102 OwnedColumn::TimestampTZ(time_unit, _, col) => match time_unit {
103 PoSQLTimeUnit::Second => Arc::new(TimestampSecondArray::from(col)),
104 PoSQLTimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(col)),
105 PoSQLTimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(col)),
106 PoSQLTimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(col)),
107 },
108 }
109 }
110}
111
112impl<S: Scalar> TryFrom<OwnedTable<S>> for RecordBatch {
113 type Error = ArrowError;
114 fn try_from(value: OwnedTable<S>) -> Result<Self, Self::Error> {
115 if value.is_empty() {
116 Ok(RecordBatch::new_empty(SchemaRef::new(Schema::empty())))
117 } else {
118 RecordBatch::try_from_iter(
119 value
120 .into_inner()
121 .into_iter()
122 .map(|(identifier, owned_column)| {
123 (identifier.value, ArrayRef::from(owned_column))
124 }),
125 )
126 }
127 }
128}
129
130impl<S: Scalar> TryFrom<ArrayRef> for OwnedColumn<S> {
131 type Error = OwnedArrowConversionError;
132 fn try_from(value: ArrayRef) -> Result<Self, Self::Error> {
133 Self::try_from(&value)
134 }
135}
136impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
137 type Error = OwnedArrowConversionError;
138
139 #[expect(clippy::too_many_lines)]
140 fn try_from(value: &ArrayRef) -> Result<Self, Self::Error> {
151 match &value.data_type() {
152 DataType::Boolean => Ok(Self::Boolean(
155 value
156 .as_any()
157 .downcast_ref::<BooleanArray>()
158 .unwrap()
159 .iter()
160 .collect::<Option<Vec<bool>>>()
161 .ok_or(OwnedArrowConversionError::NullNotSupportedYet)?,
162 )),
163 DataType::UInt8 => Ok(Self::Uint8(
164 value
165 .as_any()
166 .downcast_ref::<UInt8Array>()
167 .unwrap()
168 .values()
169 .to_vec(),
170 )),
171 DataType::Int8 => Ok(Self::TinyInt(
172 value
173 .as_any()
174 .downcast_ref::<Int8Array>()
175 .unwrap()
176 .values()
177 .to_vec(),
178 )),
179 DataType::Int16 => Ok(Self::SmallInt(
180 value
181 .as_any()
182 .downcast_ref::<Int16Array>()
183 .unwrap()
184 .values()
185 .to_vec(),
186 )),
187 DataType::Int32 => Ok(Self::Int(
188 value
189 .as_any()
190 .downcast_ref::<Int32Array>()
191 .unwrap()
192 .values()
193 .to_vec(),
194 )),
195 DataType::Int64 => Ok(Self::BigInt(
196 value
197 .as_any()
198 .downcast_ref::<Int64Array>()
199 .unwrap()
200 .values()
201 .to_vec(),
202 )),
203 DataType::Decimal128(38, 0) => Ok(Self::Int128(
204 value
205 .as_any()
206 .downcast_ref::<Decimal128Array>()
207 .unwrap()
208 .values()
209 .to_vec(),
210 )),
211 DataType::Decimal256(precision, scale) if *precision <= 75 => Ok(Self::Decimal75(
212 Precision::new(*precision).expect("precision is less than 76"),
213 *scale,
214 value
215 .as_any()
216 .downcast_ref::<Decimal256Array>()
217 .unwrap()
218 .values()
219 .iter()
220 .map(convert_i256_to_scalar)
221 .map(Option::unwrap)
222 .collect(),
223 )),
224 DataType::Utf8 => Ok(Self::VarChar(
225 value
226 .as_any()
227 .downcast_ref::<StringArray>()
228 .unwrap()
229 .iter()
230 .map(|s| s.unwrap().to_string())
231 .collect(),
232 )),
233 DataType::Binary => Ok(Self::VarBinary(
234 value
235 .as_any()
236 .downcast_ref::<BinaryArray>()
237 .unwrap()
238 .iter()
239 .map(|s| s.map(<[u8]>::to_vec).unwrap())
240 .collect(),
241 )),
242 DataType::Timestamp(time_unit, timezone) => match time_unit {
243 ArrowTimeUnit::Second => {
244 let array = value
245 .as_any()
246 .downcast_ref::<TimestampSecondArray>()
247 .expect(
248 "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
249 );
250 let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
251 Ok(OwnedColumn::TimestampTZ(
252 PoSQLTimeUnit::Second,
253 PoSQLTimeZone::try_from(timezone)?,
254 timestamps,
255 ))
256 }
257 ArrowTimeUnit::Millisecond => {
258 let array = value
259 .as_any()
260 .downcast_ref::<TimestampMillisecondArray>()
261 .expect(
262 "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
263 );
264 let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
265 Ok(OwnedColumn::TimestampTZ(
266 PoSQLTimeUnit::Millisecond,
267 PoSQLTimeZone::try_from(timezone)?,
268 timestamps,
269 ))
270 }
271 ArrowTimeUnit::Microsecond => {
272 let array = value
273 .as_any()
274 .downcast_ref::<TimestampMicrosecondArray>()
275 .expect(
276 "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
277 );
278 let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
279 Ok(OwnedColumn::TimestampTZ(
280 PoSQLTimeUnit::Microsecond,
281 PoSQLTimeZone::try_from(timezone)?,
282 timestamps,
283 ))
284 }
285 ArrowTimeUnit::Nanosecond => {
286 let array = value
287 .as_any()
288 .downcast_ref::<TimestampNanosecondArray>()
289 .expect(
290 "This cannot fail, all Arrow TimeUnits are mapped to PoSQL TimeUnits",
291 );
292 let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
293 Ok(OwnedColumn::TimestampTZ(
294 PoSQLTimeUnit::Nanosecond,
295 PoSQLTimeZone::try_from(timezone)?,
296 timestamps,
297 ))
298 }
299 },
300 &data_type => Err(OwnedArrowConversionError::UnsupportedType {
301 datatype: data_type.clone(),
302 }),
303 }
304 }
305}
306
307impl<S: Scalar> TryFrom<RecordBatch> for OwnedTable<S> {
308 type Error = OwnedArrowConversionError;
309 fn try_from(value: RecordBatch) -> Result<Self, Self::Error> {
310 let num_columns = value.num_columns();
311 let table: Result<IndexMap<_, _>, Self::Error> = value
312 .schema()
313 .fields()
314 .iter()
315 .zip(value.columns())
316 .map(|(field, array_ref)| {
317 let owned_column = OwnedColumn::try_from(array_ref)?;
318 let identifier = Ident::new(field.name());
319 Ok((identifier, owned_column))
320 })
321 .collect();
322 let owned_table = Self::try_new(table?)?;
323 if num_columns == owned_table.num_columns() {
324 Ok(owned_table)
325 } else {
326 Err(OwnedArrowConversionError::DuplicateIdents)
327 }
328 }
329}