use super::{Type, Value};
use crate::types::{FromSqlError, FromSqlResult, OrderedMap};
use crate::Row;
use rust_decimal::prelude::*;
use arrow::{
array::{
Array, ArrayRef, DictionaryArray, FixedSizeListArray, LargeListArray, ListArray, MapArray, StringArray,
StructArray, UnionArray,
},
datatypes::{UInt8Type, UInt16Type, UInt32Type},
};
/// Resolution in which a raw temporal `i64` (timestamp or time) is expressed.
///
/// Variants are declared from coarsest to finest, so the derived `Ord` orders
/// `Second < Millisecond < Microsecond < Nanosecond`.
#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TimeUnit {
    /// Whole seconds.
    Second,
    /// Milliseconds (1/1_000 s).
    Millisecond,
    /// Microseconds (1/1_000_000 s).
    Microsecond,
    /// Nanoseconds (1/1_000_000_000 s).
    Nanosecond,
}

impl TimeUnit {
    /// Converts `value`, expressed in this unit, to microseconds.
    ///
    /// Coarser units are scaled with saturating multiplication: a result that does not
    /// fit in an `i64` clamps to `i64::MIN` / `i64::MAX` instead of panicking (debug
    /// builds) or silently wrapping (release builds). Nanoseconds are converted with
    /// Rust's truncating integer division, i.e. rounded toward zero.
    pub fn to_micros(&self, value: i64) -> i64 {
        match self {
            Self::Second => value.saturating_mul(1_000_000),
            Self::Millisecond => value.saturating_mul(1_000),
            Self::Microsecond => value,
            Self::Nanosecond => value / 1_000,
        }
    }
}
/// A borrowed view of a single DuckDB value.
///
/// Scalar variants carry their payload directly, `Text`/`Blob` borrow byte slices, and
/// the nested variants borrow an Arrow array together with the row index (`usize`) to
/// read from it. Use [`ValueRef::to_owned`] (or `Value::from`) to get an owned [`Value`].
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum ValueRef<'a> {
/// SQL `NULL`.
Null,
/// A boolean.
Boolean(bool),
/// 8-bit signed integer.
TinyInt(i8),
/// 16-bit signed integer.
SmallInt(i16),
/// 32-bit signed integer.
Int(i32),
/// 64-bit signed integer.
BigInt(i64),
/// 128-bit signed integer.
HugeInt(i128),
/// 8-bit unsigned integer.
UTinyInt(u8),
/// 16-bit unsigned integer.
USmallInt(u16),
/// 32-bit unsigned integer.
UInt(u32),
/// 64-bit unsigned integer.
UBigInt(u64),
/// 32-bit float.
Float(f32),
/// 64-bit float.
Double(f64),
/// A `rust_decimal::Decimal`.
Decimal(Decimal),
/// A timestamp as a raw `i64` count of the given [`TimeUnit`]
/// (see [`TimeUnit::to_micros`] to normalise it).
Timestamp(TimeUnit, i64),
/// Borrowed text bytes, expected to be UTF-8: [`ValueRef::as_str`] validates them and
/// converting to [`Value`] panics on invalid UTF-8.
Text(&'a [u8]),
/// Borrowed binary data.
Blob(&'a [u8]),
/// A date stored as an `i32` (presumably days since the Unix epoch, as in Arrow
/// `Date32` — TODO confirm).
Date32(i32),
/// A time value as a raw `i64` count of the given [`TimeUnit`]
/// (presumably time of day — TODO confirm).
Time64(TimeUnit, i64),
/// An interval split into month, day and nanosecond components.
Interval {
/// Month component.
months: i32,
/// Day component.
days: i32,
/// Nanosecond component.
nanos: i64,
},
/// Row `usize` of a (regular or large) Arrow list array.
List(ListType<'a>, usize),
/// Row `usize` of a dictionary-encoded enum column; the dictionary values are
/// expected to be strings.
Enum(EnumType<'a>, usize),
/// Row `usize` of an Arrow struct array.
Struct(&'a StructArray, usize),
/// Row `usize` of an Arrow fixed-size list array.
Array(&'a FixedSizeListArray, usize),
/// Row `usize` of an Arrow map array.
Map(&'a MapArray, usize),
/// Row `usize` of an array expected to be a `UnionArray` (converting to [`Value`]
/// panics otherwise).
Union(&'a ArrayRef, usize),
}
/// The Arrow list array borrowed by a [`ValueRef::List`].
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ListType<'a> {
/// An Arrow `ListArray`.
Regular(&'a ListArray),
/// An Arrow `LargeListArray` (Arrow's 64-bit-offset list variant).
Large(&'a LargeListArray),
}
/// The dictionary array borrowed by a [`ValueRef::Enum`], distinguished by the width of
/// its unsigned dictionary keys.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum EnumType<'a> {
/// Dictionary with `u8` keys.
UInt8(&'a DictionaryArray<UInt8Type>),
/// Dictionary with `u16` keys.
UInt16(&'a DictionaryArray<UInt16Type>),
/// Dictionary with `u32` keys.
UInt32(&'a DictionaryArray<UInt32Type>),
}
impl ValueRef<'_> {
#[inline]
pub fn data_type(&self) -> Type {
match *self {
ValueRef::Null => Type::Null,
ValueRef::Boolean(_) => Type::Boolean,
ValueRef::TinyInt(_) => Type::TinyInt,
ValueRef::SmallInt(_) => Type::SmallInt,
ValueRef::Int(_) => Type::Int,
ValueRef::BigInt(_) => Type::BigInt,
ValueRef::HugeInt(_) => Type::HugeInt,
ValueRef::UTinyInt(_) => Type::UTinyInt,
ValueRef::USmallInt(_) => Type::USmallInt,
ValueRef::UInt(_) => Type::UInt,
ValueRef::UBigInt(_) => Type::UBigInt,
ValueRef::Float(_) => Type::Float,
ValueRef::Double(_) => Type::Double,
ValueRef::Decimal(_) => Type::Decimal,
ValueRef::Timestamp(..) => Type::Timestamp,
ValueRef::Text(_) => Type::Text,
ValueRef::Blob(_) => Type::Blob,
ValueRef::Date32(_) => Type::Date32,
ValueRef::Time64(..) => Type::Time64,
ValueRef::Interval { .. } => Type::Interval,
ValueRef::Struct(arr, _) => arr.data_type().into(),
ValueRef::Map(arr, _) => arr.data_type().into(),
ValueRef::Array(arr, _) => arr.data_type().into(),
ValueRef::List(arr, _) => match arr {
ListType::Large(arr) => arr.data_type().into(),
ListType::Regular(arr) => arr.data_type().into(),
},
ValueRef::Enum(..) => Type::Enum,
ValueRef::Union(arr, _) => arr.data_type().into(),
}
}
pub fn to_owned(&self) -> Value {
(*self).into()
}
}
impl<'a> ValueRef<'a> {
    /// Borrows this value as UTF-8 text.
    ///
    /// `Text` bytes are validated with `std::str::from_utf8`. An `Enum` is resolved by
    /// looking up its dictionary key in the dictionary's string values.
    ///
    /// # Errors
    ///
    /// - `FromSqlError::InvalidType` for any other variant, or when the enum key is null.
    /// - `FromSqlError::Other` for invalid UTF-8, a dictionary whose values are not a
    ///   `StringArray`, or an enum key past the end of the dictionary.
    #[inline]
    pub fn as_str(&self) -> FromSqlResult<&'a str> {
        let (enum_type, idx) = match *self {
            ValueRef::Text(bytes) => {
                return std::str::from_utf8(bytes).map_err(|err| FromSqlError::Other(Box::new(err)));
            }
            ValueRef::Enum(enum_type, idx) => (enum_type, idx),
            _ => return Err(FromSqlError::InvalidType),
        };

        let (dictionary, maybe_key) = match enum_type {
            EnumType::UInt8(dict) => (dict.values(), dict.key(idx)),
            EnumType::UInt16(dict) => (dict.values(), dict.key(idx)),
            EnumType::UInt32(dict) => (dict.values(), dict.key(idx)),
        };
        // A missing key means the enum slot itself is null.
        let key = maybe_key.ok_or(FromSqlError::InvalidType)?;

        let strings = dictionary
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| FromSqlError::Other("enum dictionary values are not strings".into()))?;
        let len = strings.len();
        if key < len {
            Ok(strings.value(key))
        } else {
            Err(FromSqlError::Other(
                format!("enum key {} out of bounds (len {})", key, len).into(),
            ))
        }
    }

    /// Borrows the raw bytes of a `Blob` or `Text` value.
    ///
    /// # Errors
    ///
    /// `FromSqlError::InvalidType` for every other variant.
    #[inline]
    pub fn as_blob(&self) -> FromSqlResult<&'a [u8]> {
        match *self {
            ValueRef::Blob(bytes) | ValueRef::Text(bytes) => Ok(bytes),
            _ => Err(FromSqlError::InvalidType),
        }
    }
}
impl From<ValueRef<'_>> for Value {
    /// Converts a borrowed [`ValueRef`] into an owned [`Value`].
    ///
    /// Text and blob bytes are copied. Nested values (lists, structs, maps, fixed-size
    /// arrays, unions) are materialised recursively: each child row is read with
    /// `Row::value_ref_internal` and converted with `to_owned`.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `Text` bytes are not valid UTF-8;
    /// - an enum dictionary does not hold strings, or the enum key is null;
    /// - a list/map offset or fixed-size list length cannot be converted to `usize`;
    /// - a `Union` value does not wrap a `UnionArray`.
    #[inline]
    fn from(borrowed: ValueRef<'_>) -> Self {
        match borrowed {
            ValueRef::Null => Self::Null,
            ValueRef::Boolean(i) => Self::Boolean(i),
            ValueRef::TinyInt(i) => Self::TinyInt(i),
            ValueRef::SmallInt(i) => Self::SmallInt(i),
            ValueRef::Int(i) => Self::Int(i),
            ValueRef::BigInt(i) => Self::BigInt(i),
            ValueRef::HugeInt(i) => Self::HugeInt(i),
            ValueRef::UTinyInt(i) => Self::UTinyInt(i),
            ValueRef::USmallInt(i) => Self::USmallInt(i),
            ValueRef::UInt(i) => Self::UInt(i),
            ValueRef::UBigInt(i) => Self::UBigInt(i),
            ValueRef::Float(i) => Self::Float(i),
            ValueRef::Double(i) => Self::Double(i),
            ValueRef::Decimal(i) => Self::Decimal(i),
            ValueRef::Timestamp(tu, t) => Self::Timestamp(tu, t),
            ValueRef::Text(s) => {
                let s = std::str::from_utf8(s).expect("invalid UTF-8");
                Self::Text(s.to_string())
            }
            ValueRef::Blob(b) => Self::Blob(b.to_vec()),
            ValueRef::Date32(d) => Self::Date32(d),
            ValueRef::Time64(t, d) => Self::Time64(t, d),
            ValueRef::Interval { months, days, nanos } => Self::Interval { months, days, nanos },
            ValueRef::List(items, idx) => match items {
                ListType::Regular(items) => {
                    let offsets = items.offsets();
                    from_list(
                        offsets[idx].try_into().unwrap(),
                        offsets[idx + 1].try_into().unwrap(),
                        idx,
                        items.values(),
                    )
                }
                ListType::Large(items) => {
                    let offsets = items.offsets();
                    from_list(
                        offsets[idx].try_into().unwrap(),
                        offsets[idx + 1].try_into().unwrap(),
                        idx,
                        items.values(),
                    )
                }
            },
            ValueRef::Enum(items, idx) => {
                let dict_values = match items {
                    EnumType::UInt8(res) => res.values(),
                    EnumType::UInt16(res) => res.values(),
                    EnumType::UInt32(res) => res.values(),
                }
                .as_any()
                .downcast_ref::<StringArray>()
                .expect("Enum value is not a string");
                let dict_key = match items {
                    EnumType::UInt8(res) => res.key(idx),
                    EnumType::UInt16(res) => res.key(idx),
                    EnumType::UInt32(res) => res.key(idx),
                }
                .unwrap();
                Self::Enum(dict_values.value(dict_key).to_string())
            }
            ValueRef::Struct(items, idx) => {
                let capacity = items.columns().len();
                let mut value = Vec::with_capacity(capacity);
                value.extend(
                    items
                        .columns()
                        .iter()
                        .zip(items.fields().iter().map(|f| f.name().to_owned()))
                        .map(|(column, name)| -> (String, Self) {
                            (name, Row::value_ref_internal(idx, 0, column).to_owned())
                        }),
                );
                Self::Struct(OrderedMap::from(value))
            }
            ValueRef::Map(arr, idx) => {
                let keys = arr.keys();
                let values = arr.values();
                let offsets = arr.offsets();
                let range = offsets[idx]..offsets[idx + 1];
                let capacity = range.len();
                let mut map_vec = Vec::with_capacity(capacity);
                map_vec.extend(range.map(|row| {
                    let row = row.try_into().unwrap();
                    let key = Row::value_ref_internal(row, idx, keys).to_owned();
                    let value = Row::value_ref_internal(row, idx, values).to_owned();
                    (key, value)
                }));
                Self::Map(OrderedMap::from(map_vec))
            }
            ValueRef::Array(items, idx) => {
                let value_length = usize::try_from(items.value_length()).unwrap();
                let range = (idx * value_length)..((idx + 1) * value_length);
                let child = items.values();
                let mut array_vec = Vec::with_capacity(value_length);
                array_vec.extend(range.map(|row| Row::value_ref_internal(row, idx, child).to_owned()));
                Self::Array(array_vec)
            }
            ValueRef::Union(column, idx) => {
                let column = column.as_any().downcast_ref::<UnionArray>().unwrap();
                let type_id = column.type_id(idx);
                // `value_offset` is the slot inside the selected child that holds this row's
                // value. It equals `idx` for sparse unions but comes from the offsets buffer
                // for dense unions, so it, not `idx`, must be passed as the row.
                let value_offset = column.value_offset(idx);
                let tag = Row::value_ref_internal(value_offset, idx, column.child(type_id));
                Self::Union(Box::new(tag.to_owned()))
            }
        }
    }
}
/// Deep-copies child rows `start..end` of a list's `values` array into a `Value::List`.
///
/// `idx` is the row of the parent list and is forwarded as the second argument of
/// `Row::value_ref_internal`. Callers are expected to pass `start <= end` (taken from
/// the list's offsets).
fn from_list(start: usize, end: usize, idx: usize, values: &ArrayRef) -> Value {
    let mut elements = Vec::with_capacity(end - start);
    for child_row in start..end {
        elements.push(Row::value_ref_internal(child_row, idx, values).to_owned());
    }
    Value::List(elements)
}
impl<'a> From<&'a str> for ValueRef<'a> {
#[inline]
fn from(s: &str) -> ValueRef<'_> {
ValueRef::Text(s.as_bytes())
}
}
impl<'a> From<&'a [u8]> for ValueRef<'a> {
#[inline]
fn from(s: &[u8]) -> ValueRef<'_> {
ValueRef::Blob(s)
}
}
impl<'a> From<&'a Value> for ValueRef<'a> {
#[inline]
fn from(value: &'a Value) -> Self {
match *value {
Value::Null => ValueRef::Null,
Value::Boolean(i) => ValueRef::Boolean(i),
Value::TinyInt(i) => ValueRef::TinyInt(i),
Value::SmallInt(i) => ValueRef::SmallInt(i),
Value::Int(i) => ValueRef::Int(i),
Value::BigInt(i) => ValueRef::BigInt(i),
Value::HugeInt(i) => ValueRef::HugeInt(i),
Value::UTinyInt(i) => ValueRef::UTinyInt(i),
Value::USmallInt(i) => ValueRef::USmallInt(i),
Value::UInt(i) => ValueRef::UInt(i),
Value::UBigInt(i) => ValueRef::UBigInt(i),
Value::Float(i) => ValueRef::Float(i),
Value::Double(i) => ValueRef::Double(i),
Value::Decimal(i) => ValueRef::Decimal(i),
Value::Timestamp(tu, t) => ValueRef::Timestamp(tu, t),
Value::Text(ref s) => ValueRef::Text(s.as_bytes()),
Value::Blob(ref b) => ValueRef::Blob(b),
Value::Date32(d) => ValueRef::Date32(d),
Value::Time64(t, d) => ValueRef::Time64(t, d),
Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos },
Value::Enum(..) => todo!(),
Value::List(..) | Value::Struct(..) | Value::Map(..) | Value::Array(..) | Value::Union(..) => {
unimplemented!()
}
}
}
}
impl<'a, T> From<Option<T>> for ValueRef<'a>
where
    T: Into<Self>,
{
    /// Converts `Some(inner)` with `inner.into()` and maps `None` to [`ValueRef::Null`].
    #[inline]
    fn from(s: Option<T>) -> Self {
        s.map_or(ValueRef::Null, |inner| inner.into())
    }
}
#[cfg(test)]
mod tests {
    use super::{TimeUnit, Value, ValueRef};
    use crate::types::Type;
    use crate::{Connection, Result};

    /// List columns report `Type::List` with the element type of the column.
    #[test]
    fn test_list_types() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute(
            "CREATE TABLE test_table (float_list FLOAT[], double_list DOUBLE[], int_list INT[])",
            [],
        )?;
        conn.execute("INSERT INTO test_table VALUES ([1.5, 2.5], [3.5, 4.5], [1, 2])", [])?;
        let mut stmt = conn.prepare("SELECT float_list, double_list, int_list FROM test_table")?;
        let mut rows = stmt.query([])?;
        let row = rows.next()?.unwrap();
        let float_list = row.get_ref_unwrap(0);
        assert!(
            matches!(float_list.data_type(), Type::List(ref inner_type) if **inner_type == Type::Float),
            "Expected Type::List(Type::Float), got {:?}",
            float_list.data_type()
        );
        let double_list = row.get_ref_unwrap(1);
        assert!(
            matches!(double_list.data_type(), Type::List(ref inner_type) if **inner_type == Type::Double),
            "Expected Type::List(Type::Double), got {:?}",
            double_list.data_type()
        );
        let int_list = row.get_ref_unwrap(2);
        assert!(
            matches!(int_list.data_type(), Type::List(ref inner_type) if **inner_type == Type::Int),
            "Expected Type::List(Type::Int), got {:?}",
            int_list.data_type()
        );
        Ok(())
    }

    /// Each time unit scales to microseconds; nanoseconds truncate toward zero.
    #[test]
    fn test_time_unit_to_micros() {
        assert_eq!(TimeUnit::Second.to_micros(2), 2_000_000);
        assert_eq!(TimeUnit::Millisecond.to_micros(3), 3_000);
        assert_eq!(TimeUnit::Microsecond.to_micros(4), 4);
        assert_eq!(TimeUnit::Nanosecond.to_micros(5_999), 5);
    }

    /// `as_str` accepts text only, while `as_blob` accepts both text and blobs.
    #[test]
    fn test_text_and_blob_accessors() {
        let text = ValueRef::from("hello");
        assert_eq!(text.data_type(), Type::Text);
        assert_eq!(text.as_str().unwrap(), "hello");
        assert_eq!(text.as_blob().unwrap(), &b"hello"[..]);

        let blob = ValueRef::from(&b"\x00\xff"[..]);
        assert_eq!(blob.data_type(), Type::Blob);
        assert_eq!(blob.as_blob().unwrap(), &[0x00u8, 0xff][..]);
        assert!(blob.as_str().is_err());
        assert!(ValueRef::Null.as_blob().is_err());
    }

    /// `Option`, `&Value` and owned conversions agree for scalar values.
    #[test]
    fn test_scalar_round_trips() {
        assert_eq!(ValueRef::from(None::<&str>), ValueRef::Null);
        assert_eq!(ValueRef::from(Some("x")), ValueRef::Text(b"x"));

        let owned = Value::BigInt(42);
        assert_eq!(ValueRef::from(&owned), ValueRef::BigInt(42));
        assert!(matches!(ValueRef::from("abc").to_owned(), Value::Text(ref s) if s == "abc"));
    }
}