use ordered_float::OrderedFloat;
use std::convert::{From, TryFrom};
use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;
/// Hashable, totally-comparable representation of a single non-null scalar,
/// used as a group-by key (or a component of one).
///
/// Rust's `f32`/`f64` implement neither `Eq` nor `Hash`, so the float variants
/// wrap them in `OrderedFloat`, which does; this lets the whole enum derive
/// `Eq` and `Hash`. There is no NULL variant: the `TryFrom<&ScalarValue>`
/// conversion returns an error for NULL inputs.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub(crate) enum GroupByScalar {
/// `f32` wrapped in `OrderedFloat` so it can be compared and hashed.
Float32(OrderedFloat<f32>),
/// `f64` wrapped in `OrderedFloat` so it can be compared and hashed.
Float64(OrderedFloat<f64>),
UInt8(u8),
UInt16(u16),
UInt32(u32),
UInt64(u64),
Int8(i8),
Int16(i16),
Int32(i32),
Int64(i64),
/// Owned string, boxed — presumably to keep the enum small (`String` itself
/// is three words); the `size_of_group_by_scalar` test asserts the enum is
/// 16 bytes.
Utf8(Box<String>),
Boolean(bool),
/// Mirrors the `i64` payload of `ScalarValue::TimestampMillisecond`.
TimeMillisecond(i64),
/// Mirrors the `i64` payload of `ScalarValue::TimestampMicrosecond`.
TimeMicrosecond(i64),
/// Mirrors the `i64` payload of `ScalarValue::TimestampNanosecond`.
TimeNanosecond(i64),
/// Mirrors the `i32` payload of `ScalarValue::Date32`.
Date32(i32),
}
impl TryFrom<&ScalarValue> for GroupByScalar {
    type Error = DataFusionError;

    /// Converts a non-null [`ScalarValue`] into the hashable [`GroupByScalar`]
    /// used as a group-by key.
    ///
    /// Every `GroupByScalar` variant has a source here, so that this
    /// conversion and `From<&GroupByScalar> for ScalarValue` round-trip.
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Internal`] when `scalar_value` holds NULL
    /// (for any of the supported types), or when its data type has no
    /// `GroupByScalar` counterpart.
    fn try_from(scalar_value: &ScalarValue) -> Result<Self> {
        Ok(match scalar_value {
            ScalarValue::Float32(Some(v)) => {
                GroupByScalar::Float32(OrderedFloat::from(*v))
            }
            ScalarValue::Float64(Some(v)) => {
                GroupByScalar::Float64(OrderedFloat::from(*v))
            }
            ScalarValue::Boolean(Some(v)) => GroupByScalar::Boolean(*v),
            ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v),
            ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v),
            ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v),
            ScalarValue::Int64(Some(v)) => GroupByScalar::Int64(*v),
            ScalarValue::UInt8(Some(v)) => GroupByScalar::UInt8(*v),
            ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v),
            ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
            ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
            ScalarValue::TimestampMillisecond(Some(v)) => {
                GroupByScalar::TimeMillisecond(*v)
            }
            ScalarValue::TimestampMicrosecond(Some(v)) => {
                GroupByScalar::TimeMicrosecond(*v)
            }
            ScalarValue::TimestampNanosecond(Some(v)) => {
                GroupByScalar::TimeNanosecond(*v)
            }
            // Previously missing: the enum and the reverse conversion both
            // support Date32, so accept it here too.
            ScalarValue::Date32(Some(v)) => GroupByScalar::Date32(*v),
            ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
            // NULLs of every supported type get the "holding NULL" message,
            // rather than falling through to the "unsupported type" arm.
            ScalarValue::Float32(None)
            | ScalarValue::Float64(None)
            | ScalarValue::Boolean(None)
            | ScalarValue::Int8(None)
            | ScalarValue::Int16(None)
            | ScalarValue::Int32(None)
            | ScalarValue::Int64(None)
            | ScalarValue::UInt8(None)
            | ScalarValue::UInt16(None)
            | ScalarValue::UInt32(None)
            | ScalarValue::UInt64(None)
            | ScalarValue::TimestampMillisecond(None)
            | ScalarValue::TimestampMicrosecond(None)
            | ScalarValue::TimestampNanosecond(None)
            | ScalarValue::Date32(None)
            | ScalarValue::Utf8(None) => {
                return Err(DataFusionError::Internal(format!(
                    "Cannot convert a ScalarValue holding NULL ({:?})",
                    scalar_value
                )));
            }
            v => {
                return Err(DataFusionError::Internal(format!(
                    "Cannot convert a ScalarValue with associated DataType {:?}",
                    v.get_datatype()
                )))
            }
        })
    }
}
impl From<&GroupByScalar> for ScalarValue {
    /// Turns a group-by key back into the equivalent non-null [`ScalarValue`].
    ///
    /// Each variant becomes the `Some(_)` form of its `ScalarValue`
    /// counterpart. The float variants are unwrapped from `OrderedFloat`, and
    /// the `Time*` variants map onto the matching `Timestamp*` variants.
    fn from(group_by_scalar: &GroupByScalar) -> Self {
        // Every payload except the boxed string is `Copy`, so matching on the
        // dereferenced value binds plain copies; only `Utf8` needs `ref`.
        match *group_by_scalar {
            GroupByScalar::Float32(f) => ScalarValue::Float32(Some(f.0)),
            GroupByScalar::Float64(f) => ScalarValue::Float64(Some(f.0)),
            GroupByScalar::UInt8(n) => ScalarValue::UInt8(Some(n)),
            GroupByScalar::UInt16(n) => ScalarValue::UInt16(Some(n)),
            GroupByScalar::UInt32(n) => ScalarValue::UInt32(Some(n)),
            GroupByScalar::UInt64(n) => ScalarValue::UInt64(Some(n)),
            GroupByScalar::Int8(n) => ScalarValue::Int8(Some(n)),
            GroupByScalar::Int16(n) => ScalarValue::Int16(Some(n)),
            GroupByScalar::Int32(n) => ScalarValue::Int32(Some(n)),
            GroupByScalar::Int64(n) => ScalarValue::Int64(Some(n)),
            GroupByScalar::Utf8(ref s) => ScalarValue::Utf8(Some(s.as_str().to_owned())),
            GroupByScalar::Boolean(b) => ScalarValue::Boolean(Some(b)),
            GroupByScalar::TimeMillisecond(t) => ScalarValue::TimestampMillisecond(Some(t)),
            GroupByScalar::TimeMicrosecond(t) => ScalarValue::TimestampMicrosecond(Some(t)),
            GroupByScalar::TimeNanosecond(t) => ScalarValue::TimestampNanosecond(Some(t)),
            GroupByScalar::Date32(d) => ScalarValue::Date32(Some(d)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::DataFusionError;

    // Converts two scalars built from the same constructor and value, and
    // asserts that the resulting group-by keys compare equal.
    macro_rules! scalar_eq_test {
        ($TYPE:expr, $VALUE:expr) => {{
            let scalar_value = $TYPE($VALUE);
            let a = GroupByScalar::try_from(&scalar_value).unwrap();
            let scalar_value = $TYPE($VALUE);
            let b = GroupByScalar::try_from(&scalar_value).unwrap();
            assert_eq!(a, b);
        }};
    }

    // Named after what it checks: equal float inputs give equal keys.
    // (Previously this test was called `test_scalar_ne_non_std`.)
    #[test]
    fn test_scalar_eq_non_std() {
        scalar_eq_test!(ScalarValue::Float32, Some(1.0));
        scalar_eq_test!(ScalarValue::Float64, Some(1.0));
    }

    // Converts two scalars built from the same constructor but different
    // values, and asserts that the resulting group-by keys differ.
    macro_rules! scalar_ne_test {
        ($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{
            let scalar_value = $TYPE($LVALUE);
            let a = GroupByScalar::try_from(&scalar_value).unwrap();
            let scalar_value = $TYPE($RVALUE);
            let b = GroupByScalar::try_from(&scalar_value).unwrap();
            assert_ne!(a, b);
        }};
    }

    // Named after what it checks: different float inputs give different keys.
    // (Previously this test was called `test_scalar_eq_non_std`.)
    #[test]
    fn test_scalar_ne_non_std() {
        scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0));
        scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0));
    }

    // ScalarValue -> GroupByScalar -> ScalarValue must reproduce the input.
    // Compared through `Debug`, which the NULL-error test already relies on.
    #[test]
    fn round_trip_non_null_values() {
        let values = [
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Int8(Some(-3)),
            ScalarValue::UInt64(Some(42)),
            ScalarValue::Float64(Some(1.5)),
            ScalarValue::Utf8(Some("abc".to_string())),
            ScalarValue::TimestampNanosecond(Some(7)),
        ];
        for value in values.iter() {
            let key = GroupByScalar::try_from(value).unwrap();
            let back = ScalarValue::from(&key);
            assert_eq!(format!("{:?}", back), format!("{:?}", value));
        }
    }

    #[test]
    fn from_scalar_holding_none() {
        let scalar_value = ScalarValue::Int8(None);
        let result = GroupByScalar::try_from(&scalar_value);
        match result {
            Err(DataFusionError::Internal(error_message)) => assert_eq!(
                error_message,
                String::from("Cannot convert a ScalarValue holding NULL (Int8(NULL))")
            ),
            _ => panic!("Unexpected result"),
        }
    }

    #[test]
    fn from_scalar_unsupported() {
        let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
        let result = GroupByScalar::try_from(&scalar_value);
        match result {
            Err(DataFusionError::Internal(error_message)) => assert_eq!(
                error_message,
                String::from(
                    "Cannot convert a ScalarValue with associated DataType LargeUtf8"
                )
            ),
            _ => panic!("Unexpected result"),
        }
    }

    // Guards the enum's memory footprint; group-by keys are stored per row.
    #[test]
    fn size_of_group_by_scalar() {
        assert_eq!(std::mem::size_of::<GroupByScalar>(), 16);
    }
}