use chrono::{DateTime, NaiveDate, NaiveDateTime, TimeZone, Utc};
use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use crate::schema::{ArrayType, DataType, PrimitiveType, StructField};
use crate::utils::require;
use crate::{DeltaResult, Error};
#[derive(Debug, Clone, PartialEq)]
pub struct StructData {
fields: Vec<StructField>,
values: Vec<Scalar>,
}
#[derive(Clone, Debug, PartialEq)]
pub struct ArrayData {
tpe: ArrayType,
elements: Vec<Scalar>,
}
impl ArrayData {
#[cfg_attr(feature = "developer-visibility", visibility::make(pub))]
pub fn new(tpe: ArrayType, elements: impl IntoIterator<Item = impl Into<Scalar>>) -> Self {
let elements = elements.into_iter().map(Into::into).collect();
Self { tpe, elements }
}
pub fn array_type(&self) -> &ArrayType {
&self.tpe
}
#[deprecated(
note = "These fields will be removed eventually and are unstable. See https://github.com/delta-io/delta-kernel-rs/issues/291"
)]
pub fn array_elements(&self) -> &[Scalar] {
&self.elements
}
}
impl StructData {
pub fn try_new(fields: Vec<StructField>, values: Vec<Scalar>) -> DeltaResult<Self> {
require!(
fields.len() == values.len(),
Error::invalid_struct_data(format!(
"Incorrect number of values for Struct fields, expected {} got {}",
fields.len(),
values.len()
))
);
for (f, a) in fields.iter().zip(&values) {
require!(
f.data_type() == &a.data_type(),
Error::invalid_struct_data(format!(
"Incorrect datatype for Struct field {:?}, expected {} got {}",
f.name(),
f.data_type(),
a.data_type()
))
);
require!(
f.is_nullable() || !a.is_null(),
Error::invalid_struct_data(format!(
"Value for non-nullable field {:?} cannot be null, got {}",
f.name(),
a
))
);
}
Ok(Self { fields, values })
}
pub fn fields(&self) -> &[StructField] {
&self.fields
}
pub fn values(&self) -> &[Scalar] {
&self.values
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum Scalar {
Integer(i32),
Long(i64),
Short(i16),
Byte(i8),
Float(f32),
Double(f64),
String(String),
Boolean(bool),
Timestamp(i64),
TimestampNtz(i64),
Date(i32),
Binary(Vec<u8>),
Decimal(i128, u8, u8),
Null(DataType),
Struct(StructData),
Array(ArrayData),
}
impl Scalar {
pub fn data_type(&self) -> DataType {
match self {
Self::Integer(_) => DataType::INTEGER,
Self::Long(_) => DataType::LONG,
Self::Short(_) => DataType::SHORT,
Self::Byte(_) => DataType::BYTE,
Self::Float(_) => DataType::FLOAT,
Self::Double(_) => DataType::DOUBLE,
Self::String(_) => DataType::STRING,
Self::Boolean(_) => DataType::BOOLEAN,
Self::Timestamp(_) => DataType::TIMESTAMP,
Self::TimestampNtz(_) => DataType::TIMESTAMP_NTZ,
Self::Date(_) => DataType::DATE,
Self::Binary(_) => DataType::BINARY,
Self::Decimal(_, precision, scale) => DataType::decimal_unchecked(*precision, *scale),
Self::Null(data_type) => data_type.clone(),
Self::Struct(data) => DataType::struct_type(data.fields.clone()),
Self::Array(data) => data.tpe.clone().into(),
}
}
pub fn is_null(&self) -> bool {
matches!(self, Self::Null(_))
}
pub(crate) fn timestamp_ntz_from_millis(millis: i64) -> DeltaResult<Self> {
let Some(timestamp) = DateTime::from_timestamp_millis(millis) else {
return Err(Error::generic(format!(
"Failed to create millisecond timestamp from {millis}"
)));
};
Ok(Self::TimestampNtz(timestamp.timestamp_micros()))
}
}
impl Display for Scalar {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Integer(i) => write!(f, "{}", i),
Self::Long(i) => write!(f, "{}", i),
Self::Short(i) => write!(f, "{}", i),
Self::Byte(i) => write!(f, "{}", i),
Self::Float(fl) => write!(f, "{}", fl),
Self::Double(fl) => write!(f, "{}", fl),
Self::String(s) => write!(f, "'{}'", s),
Self::Boolean(b) => write!(f, "{}", b),
Self::Timestamp(ts) => write!(f, "{}", ts),
Self::TimestampNtz(ts) => write!(f, "{}", ts),
Self::Date(d) => write!(f, "{}", d),
Self::Binary(b) => write!(f, "{:?}", b),
Self::Decimal(value, _, scale) => match scale.cmp(&0) {
Ordering::Equal => {
write!(f, "{}", value)
}
Ordering::Greater => {
let scalar_multiple = 10_i128.pow(*scale as u32);
write!(f, "{}", value / scalar_multiple)?;
write!(f, ".")?;
write!(
f,
"{:0>scale$}",
value % scalar_multiple,
scale = *scale as usize
)
}
Ordering::Less => {
write!(f, "{}", value)?;
for _ in 0..*scale {
write!(f, "0")?;
}
Ok(())
}
},
Self::Null(_) => write!(f, "null"),
Self::Struct(data) => {
write!(f, "{{")?;
let mut delim = "";
for (value, field) in data.values.iter().zip(data.fields.iter()) {
write!(f, "{delim}{}: {value}", field.name)?;
delim = ", ";
}
write!(f, "}}")
}
Self::Array(data) => {
write!(f, "(")?;
let mut delim = "";
for element in &data.elements {
write!(f, "{delim}{element}")?;
delim = ", ";
}
write!(f, ")")
}
}
}
}
impl PartialOrd for Scalar {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
use Scalar::*;
match (self, other) {
(Integer(a), Integer(b)) => a.partial_cmp(b),
(Integer(_), _) => None,
(Long(a), Long(b)) => a.partial_cmp(b),
(Long(_), _) => None,
(Short(a), Short(b)) => a.partial_cmp(b),
(Short(_), _) => None,
(Byte(a), Byte(b)) => a.partial_cmp(b),
(Byte(_), _) => None,
(Float(a), Float(b)) => a.partial_cmp(b),
(Float(_), _) => None,
(Double(a), Double(b)) => a.partial_cmp(b),
(Double(_), _) => None,
(String(a), String(b)) => a.partial_cmp(b),
(String(_), _) => None,
(Boolean(a), Boolean(b)) => a.partial_cmp(b),
(Boolean(_), _) => None,
(Timestamp(a), Timestamp(b)) => a.partial_cmp(b),
(Timestamp(_), _) => None,
(TimestampNtz(a), TimestampNtz(b)) => a.partial_cmp(b),
(TimestampNtz(_), _) => None,
(Date(a), Date(b)) => a.partial_cmp(b),
(Date(_), _) => None,
(Binary(a), Binary(b)) => a.partial_cmp(b),
(Binary(_), _) => None,
(Decimal(_, _, _), _) => None, (Null(_), _) => None, (Struct(_), _) => None, (Array(_), _) => None, }
}
}
impl From<i8> for Scalar {
fn from(i: i8) -> Self {
Self::Byte(i)
}
}
impl From<i16> for Scalar {
fn from(i: i16) -> Self {
Self::Short(i)
}
}
impl From<i32> for Scalar {
fn from(i: i32) -> Self {
Self::Integer(i)
}
}
impl From<i64> for Scalar {
fn from(i: i64) -> Self {
Self::Long(i)
}
}
impl From<f32> for Scalar {
fn from(i: f32) -> Self {
Self::Float(i)
}
}
impl From<f64> for Scalar {
fn from(i: f64) -> Self {
Self::Double(i)
}
}
impl From<bool> for Scalar {
fn from(b: bool) -> Self {
Self::Boolean(b)
}
}
impl From<&str> for Scalar {
fn from(s: &str) -> Self {
Self::String(s.into())
}
}
impl From<String> for Scalar {
fn from(value: String) -> Self {
Self::String(value)
}
}
impl<T: Into<Scalar> + Copy> From<&T> for Scalar {
fn from(t: &T) -> Self {
(*t).into()
}
}
impl From<&[u8]> for Scalar {
fn from(b: &[u8]) -> Self {
Self::Binary(b.into())
}
}
impl PrimitiveType {
pub fn check_decimal(precision: u8, scale: u8) -> DeltaResult<()> {
require!(
0 < precision && precision <= 38,
Error::invalid_decimal(format!(
"precision must in range 1..38 inclusive, found: {}.",
precision
))
);
require!(
scale <= precision,
Error::invalid_decimal(format!(
"scale must be in range 0..precision inclusive, found: {scale}."
))
);
Ok(())
}
fn data_type(&self) -> DataType {
DataType::Primitive(self.clone())
}
pub fn parse_scalar(&self, raw: &str) -> Result<Scalar, Error> {
use PrimitiveType::*;
if raw.is_empty() {
return Ok(Scalar::Null(self.data_type()));
}
match self {
String => Ok(Scalar::String(raw.to_string())),
Binary => Ok(Scalar::Binary(raw.to_string().into_bytes())),
Byte => self.parse_str_as_scalar(raw, Scalar::Byte),
Decimal(precision, scale) => self.parse_decimal(raw, *precision, *scale),
Short => self.parse_str_as_scalar(raw, Scalar::Short),
Integer => self.parse_str_as_scalar(raw, Scalar::Integer),
Long => self.parse_str_as_scalar(raw, Scalar::Long),
Float => self.parse_str_as_scalar(raw, Scalar::Float),
Double => self.parse_str_as_scalar(raw, Scalar::Double),
Boolean => {
if raw.eq_ignore_ascii_case("true") {
Ok(Scalar::Boolean(true))
} else if raw.eq_ignore_ascii_case("false") {
Ok(Scalar::Boolean(false))
} else {
Err(self.parse_error(raw))
}
}
Date => {
let date = NaiveDate::parse_from_str(raw, "%Y-%m-%d")
.map_err(|_| self.parse_error(raw))?
.and_hms_opt(0, 0, 0)
.ok_or(self.parse_error(raw))?;
let date = Utc.from_utc_datetime(&date);
let days = date.signed_duration_since(DateTime::UNIX_EPOCH).num_days() as i32;
Ok(Scalar::Date(days))
}
Timestamp | TimestampNtz => {
let timestamp = NaiveDateTime::parse_from_str(raw, "%Y-%m-%d %H:%M:%S%.f")
.map_err(|_| self.parse_error(raw))?;
let timestamp = Utc.from_utc_datetime(×tamp);
let micros = timestamp
.signed_duration_since(DateTime::UNIX_EPOCH)
.num_microseconds()
.ok_or(self.parse_error(raw))?;
match self {
Timestamp => Ok(Scalar::Timestamp(micros)),
TimestampNtz => Ok(Scalar::TimestampNtz(micros)),
_ => unreachable!(),
}
}
}
}
fn parse_error(&self, raw: &str) -> Error {
Error::ParseError(raw.to_string(), self.data_type())
}
fn parse_str_as_scalar<T: std::str::FromStr>(
&self,
raw: &str,
f: impl FnOnce(T) -> Scalar,
) -> Result<Scalar, Error> {
match raw.parse() {
Ok(val) => Ok(f(val)),
Err(..) => Err(self.parse_error(raw)),
}
}
fn parse_decimal(&self, raw: &str, precision: u8, expected_scale: u8) -> Result<Scalar, Error> {
let (base, exp): (&str, i128) = match raw.find(['e', 'E']) {
None => (raw, 0), Some(pos) => {
let (base, exp) = raw.split_at(pos);
(base, exp[1..].parse()?)
}
};
require!(!base.is_empty(), self.parse_error(raw));
let (int_part, frac_part, frac_digits) = match base.find('.') {
None => {
(base, None, 0)
}
Some(pos) if pos == base.len() - 1 => {
(&base[..pos], None, 0)
}
Some(pos) => {
let (int_part, frac_part) = (&base[..pos], &base[pos + 1..]);
(int_part, Some(frac_part), frac_part.len() as i128)
}
};
let scale = frac_digits - exp;
let scale: u8 = scale.try_into().map_err(|_| self.parse_error(raw))?;
require!(scale == expected_scale, self.parse_error(raw));
Self::check_decimal(precision, scale)?;
let int: i128 = match frac_part {
None => int_part.parse()?,
Some(frac_part) => format!("{}{}", int_part, frac_part).parse()?,
};
Ok(Scalar::Decimal(int, precision, scale))
}
}
#[cfg(test)]
mod tests {
use std::f32::consts::PI;
use crate::expressions::{column_expr, BinaryOperator};
use crate::Expression;
use super::*;
#[test]
fn test_decimal_display() {
let s = Scalar::Decimal(123456789, 9, 2);
assert_eq!(s.to_string(), "1234567.89");
let s = Scalar::Decimal(123456789, 9, 0);
assert_eq!(s.to_string(), "123456789");
let s = Scalar::Decimal(123456789, 9, 9);
assert_eq!(s.to_string(), "0.123456789");
}
fn assert_decimal(
raw: &str,
expect_int: i128,
expect_prec: u8,
expect_scale: u8,
) -> Result<(), Box<dyn std::error::Error>> {
let s = PrimitiveType::Decimal(expect_prec, expect_scale);
match s.parse_scalar(raw)? {
Scalar::Decimal(int, prec, scale) => {
assert_eq!(int, expect_int);
assert_eq!(prec, expect_prec);
assert_eq!(scale, expect_scale);
}
_ => panic!("Didn't parse as decimal"),
};
Ok(())
}
#[test]
fn test_parse_decimal() -> Result<(), Box<dyn std::error::Error>> {
assert_decimal("0", 0, 1, 0)?;
assert_decimal("0.00", 0, 3, 2)?;
assert_decimal("123", 123, 3, 0)?;
assert_decimal("-123", -123, 3, 0)?;
assert_decimal("-123.", -123, 3, 0)?;
assert_decimal("123000", 123000, 6, 0)?;
assert_decimal("12.0", 120, 3, 1)?;
assert_decimal("12.3", 123, 3, 1)?;
assert_decimal("0.00123", 123, 5, 5)?;
assert_decimal("1234.5E-4", 12345, 5, 5)?;
assert_decimal("-0", 0, 1, 0)?;
assert_decimal("12.000000000000000000", 12000000000000000000, 38, 18)?;
Ok(())
}
fn expect_fail_parse(raw: &str, prec: u8, scale: u8) {
let s = PrimitiveType::Decimal(prec, scale);
let res = s.parse_scalar(raw);
assert!(res.is_err(), "Fail on {raw}");
}
#[test]
fn test_parse_decimal_expect_fail() {
expect_fail_parse("iowjef", 0, 0);
expect_fail_parse("123Ef", 0, 0);
expect_fail_parse("1d2E3", 0, 0);
expect_fail_parse("a", 0, 0);
expect_fail_parse("2.a", 1, 1);
expect_fail_parse("E45", 0, 0);
expect_fail_parse("1.2.3", 0, 0);
expect_fail_parse("1.2E1.3", 0, 0);
expect_fail_parse("123.45", 5, 1);
expect_fail_parse(".45", 5, 1);
expect_fail_parse("+", 0, 0);
expect_fail_parse("-", 0, 0);
expect_fail_parse("0.-0", 2, 1);
expect_fail_parse("--1.0", 1, 1);
expect_fail_parse("+-1.0", 1, 1);
expect_fail_parse("-+1.0", 1, 1);
expect_fail_parse("++1.0", 1, 1);
expect_fail_parse("1.0E1+", 1, 1);
expect_fail_parse("0.999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 0, 0);
expect_fail_parse("0.E170141183460469231731687303715884105727", 0, 0);
}
#[test]
fn test_arrays() {
#[allow(deprecated)]
let array = Scalar::Array(ArrayData {
tpe: ArrayType::new(DataType::INTEGER, false),
elements: vec![Scalar::Integer(1), Scalar::Integer(2), Scalar::Integer(3)],
});
let column = column_expr!("item");
let array_op = Expression::binary(BinaryOperator::In, 10, array.clone());
let array_not_op = Expression::binary(BinaryOperator::NotIn, 10, array);
let column_op = Expression::binary(BinaryOperator::In, PI, column.clone());
let column_not_op = Expression::binary(BinaryOperator::NotIn, "Cool", column);
assert_eq!(&format!("{}", array_op), "10 IN (1, 2, 3)");
assert_eq!(&format!("{}", array_not_op), "10 NOT IN (1, 2, 3)");
assert_eq!(&format!("{}", column_op), "3.1415927 IN Column(item)");
assert_eq!(&format!("{}", column_not_op), "'Cool' NOT IN Column(item)");
}
}