use std::cmp::Ordering;
use std::hash::Hash;
use std::hash::Hasher;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure_eq;
use vortex_error::vortex_panic;
use crate::dtype::DType;
use crate::dtype::NativeDType;
use crate::dtype::PType;
use crate::scalar::Scalar;
use crate::scalar::ScalarValue;
impl Scalar {
pub fn null(dtype: DType) -> Self {
assert!(
dtype.is_nullable(),
"Cannot create null scalar with non-nullable dtype {dtype}"
);
Self { dtype, value: None }
}
pub fn null_native<T: NativeDType>() -> Self {
Self {
dtype: T::dtype().as_nullable(),
value: None,
}
}
#[cfg(test)]
pub fn new(dtype: DType, value: Option<ScalarValue>) -> Self {
use vortex_error::VortexExpect;
Self::try_new(dtype, value).vortex_expect("Failed to create Scalar")
}
pub fn try_new(dtype: DType, value: Option<ScalarValue>) -> VortexResult<Self> {
Self::validate(&dtype, value.as_ref())?;
Ok(Self { dtype, value })
}
pub unsafe fn new_unchecked(dtype: DType, value: Option<ScalarValue>) -> Self {
#[cfg(debug_assertions)]
{
use vortex_error::VortexExpect;
Self::validate(&dtype, value.as_ref())
.vortex_expect("Scalar::new_unchecked called with incompatible dtype and value");
}
Self { dtype, value }
}
pub fn default_value(dtype: &DType) -> Self {
let value = ScalarValue::default_value(dtype);
unsafe { Self::new_unchecked(dtype.clone(), value) }
}
pub fn zero_value(dtype: &DType) -> Self {
let value = ScalarValue::zero_value(dtype);
unsafe { Self::new_unchecked(dtype.clone(), Some(value)) }
}
pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
}
pub fn into_parts(self) -> (DType, Option<ScalarValue>) {
(self.dtype, self.value)
}
pub fn dtype(&self) -> &DType {
&self.dtype
}
pub fn value(&self) -> Option<&ScalarValue> {
self.value.as_ref()
}
pub fn into_value(self) -> Option<ScalarValue> {
self.value
}
pub fn is_valid(&self) -> bool {
self.value.is_some()
}
pub fn is_null(&self) -> bool {
self.value.is_none()
}
pub fn is_zero(&self) -> Option<bool> {
let value = self.value()?;
let is_zero = match self.dtype() {
DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"),
DType::Bool(_) => !value.as_bool(),
DType::Primitive(..) => value.as_primitive().is_zero(),
DType::Decimal(..) => value.as_decimal().is_zero(),
DType::Utf8(_) => value.as_utf8().is_empty(),
DType::Binary(_) => value.as_binary().is_empty(),
DType::List(..) => value.as_list().is_empty(),
DType::FixedSizeList(_, list_size, _) => value.as_list().len() == *list_size as usize,
DType::Struct(struct_fields, _) => value.as_list().len() == struct_fields.nfields(),
DType::Extension(_) => self.as_extension().to_storage_scalar().is_zero()?,
DType::Variant(_) => self.as_variant().is_zero()?,
};
Some(is_zero)
}
pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult<Self> {
let primitive = self.as_primitive();
if primitive.ptype() == ptype {
return Ok(self.clone());
}
vortex_ensure_eq!(
primitive.ptype().byte_width(),
ptype.byte_width(),
"can't reinterpret cast between integers of two different widths"
);
Scalar::try_new(
DType::Primitive(ptype, self.dtype().nullability()),
primitive
.pvalue()
.map(|p| p.reinterpret_cast(ptype))
.map(ScalarValue::Primitive),
)
}
pub fn approx_nbytes(&self) -> usize {
use crate::dtype::NativeDecimalType;
use crate::dtype::i256;
match self.dtype() {
DType::Null => 0,
DType::Bool(_) => 1,
DType::Primitive(ptype, _) => ptype.byte_width(),
DType::Decimal(dt, _) => {
if dt.precision() <= i128::MAX_PRECISION {
size_of::<i128>()
} else {
size_of::<i256>()
}
}
DType::Utf8(_) => self
.value()
.map_or_else(|| 0, |value| value.as_utf8().len()),
DType::Binary(_) => self
.value()
.map_or_else(|| 0, |value| value.as_binary().len()),
DType::Struct(..) => self
.as_struct()
.fields_iter()
.map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
.unwrap_or_default(),
DType::List(..) | DType::FixedSizeList(..) => self
.as_list()
.elements()
.map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
.unwrap_or_default(),
DType::Extension(_) => self.as_extension().to_storage_scalar().approx_nbytes(),
DType::Variant(_) => self.as_variant().value().map_or(0, Scalar::approx_nbytes),
}
}
}
impl PartialEq for Scalar {
    /// Scalars are equal when their dtypes match ignoring nullability and their values match,
    /// so a nullable and a non-nullable scalar holding the same value compare equal.
    ///
    /// Delegates to the inherent `eq_ignore_nullability` so there is a single definition of
    /// scalar equality; it must stay consistent with the `Hash` impl, which hashes the
    /// non-nullable form of the dtype.
    fn eq(&self, other: &Self) -> bool {
        self.eq_ignore_nullability(other)
    }
}
impl PartialOrd for Scalar {
    /// Orders scalars by value when their dtypes agree up to nullability.
    ///
    /// Scalars of incompatible dtypes are unordered (`None`). Otherwise the standard `Option`
    /// ordering applies to the values, so a null scalar sorts before any non-null one.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        if self.dtype().eq_ignore_nullability(other.dtype()) {
            let (lhs, rhs) = (self.value(), other.value());
            lhs.partial_cmp(&rhs)
        } else {
            None
        }
    }
}
impl Hash for Scalar {
    /// Hashes the non-nullable form of the dtype followed by the value.
    ///
    /// Nullability is deliberately excluded so that scalars comparing equal under
    /// `PartialEq` (which ignores nullability) also hash equally.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { dtype, value } = self;
        let base_dtype = dtype.as_nonnullable();
        base_dtype.hash(state);
        value.hash(state);
    }
}