use cxx::UniquePtr;
use crate::column::CudfType;
use crate::error::{CudfError, Result};
use crate::types::{DataType, TypeId};
/// A single typed value held by a C++ scalar object behind the FFI boundary.
///
/// Create one with [`Scalar::new`] (holds a value) or [`Scalar::null`] (no value
/// assigned). Inspect it with [`Scalar::value`], [`Scalar::is_valid`] and
/// [`Scalar::data_type`].
pub struct Scalar {
    /// Owning pointer to the C++ scalar object; the methods of `Scalar` call into
    /// the FFI through it.
    pub(crate) inner: UniquePtr<cudf_cxx::scalar::ffi::OwnedScalar>,
}
// SAFETY: `Scalar` exclusively owns its C++ object through a `UniquePtr`, and the
// mutating operations in this file (`new`, `set_valid`) go through `pin_mut()`,
// which requires `&mut` access. Moving a `Scalar` to another thread therefore
// cannot create shared mutable access from Rust.
// NOTE(review): this additionally assumes the underlying C++ scalar (and any
// device/stream resources it references) has no thread affinity; confirm on the
// C++ side. Only `Send` is asserted here, not `Sync`.
unsafe impl Send for Scalar {}
/// Writes `$value` into the scalar `$inner` by matching `$type_id` against the
/// listed `Variant => primitive, setter` arms.
///
/// Contract for the expansion site:
/// - a generic type parameter named `T` must be in scope and be the static type
///   of `$value`;
/// - `$type_id` must be derived from `T` (e.g. `T::TYPE_ID`);
/// - the enclosing function must return a `Result` whose error type is
///   `CudfError`, because unsupported type ids `return Err(InvalidArgument)` and
///   setter failures are propagated with `?`.
macro_rules! scalar_set_dispatch {
    ($type_id:expr, $value:expr, $inner:expr, $($variant:ident => $ty:ty, $set_fn:path);+ $(;)?) => {
        match $type_id {
            $(TypeId::$variant => {
                // Always-on (not `debug_assert!`): the sizes are compile-time
                // constants so this folds away, and a mismatch must never reach
                // `transmute_copy` in release builds, where it would silently
                // reinterpret the wrong bytes.
                assert_eq!(
                    std::mem::size_of::<T>(),
                    std::mem::size_of::<$ty>(),
                    "CudfType size does not match its TYPE_ID ({:?})",
                    TypeId::$variant
                );
                // SAFETY: this arm was selected by a type id derived from `T`, and
                // the sizes were checked above; `transmute_copy` reads unaligned,
                // so differing alignment is fine. NOTE(review): soundness still
                // assumes each `CudfType` impl is bit-compatible with the
                // primitive for its TYPE_ID.
                let v: $ty = unsafe { std::mem::transmute_copy(&$value) };
                $set_fn($inner.pin_mut(), v).map_err(CudfError::from_cxx)?;
            })+
            _ => {
                return Err(CudfError::InvalidArgument(format!(
                    "Scalar::new does not support {:?}",
                    $type_id
                )));
            }
        }
    };
}
/// Reads the value behind `$self_.inner` as the caller's generic type `T` by
/// matching `T::TYPE_ID` against the listed `Variant => primitive, getter` arms.
///
/// Contract for the expansion site:
/// - a generic type parameter named `T` with an associated `TYPE_ID` must be in
///   scope;
/// - the expansion is a `match` expression producing `Result<T, CudfError>`, so it
///   is meant to be used in tail position;
/// - unsupported type ids yield `CudfError::InvalidArgument`, and getter failures
///   are propagated with `?`.
macro_rules! scalar_get_dispatch {
    ($self_:expr, $($variant:ident => $ty:ty, $get_fn:path);+ $(;)?) => {
        match T::TYPE_ID {
            $(TypeId::$variant => {
                // Checked before the FFI read and in every build profile: the sizes
                // are compile-time constants, and a mismatch must never reach
                // `transmute_copy` in release builds.
                assert_eq!(
                    std::mem::size_of::<T>(),
                    std::mem::size_of::<$ty>(),
                    "CudfType size does not match its TYPE_ID ({:?})",
                    TypeId::$variant
                );
                let v = $get_fn(&$self_.inner).map_err(CudfError::from_cxx)?;
                // SAFETY: this arm is selected by `T::TYPE_ID`, and the sizes match
                // (asserted above); `transmute_copy` reads unaligned, so differing
                // alignment is fine. NOTE(review): soundness still assumes each
                // `CudfType` impl is bit-compatible with the primitive for its
                // TYPE_ID.
                Ok(unsafe { std::mem::transmute_copy(&v) })
            })+
            _ => Err(CudfError::InvalidArgument(format!(
                "Scalar::value does not yet support {:?}",
                T::TYPE_ID
            ))),
        }
    };
}
impl Scalar {
pub fn new<T: CudfType>(value: T) -> Result<Self> {
let type_id = T::TYPE_ID;
let mut inner = cudf_cxx::scalar::ffi::make_numeric_scalar(type_id as i32)
.map_err(CudfError::from_cxx)?;
scalar_set_dispatch! {
type_id, value, inner,
Int8 => i8, cudf_cxx::scalar::ffi::scalar_set_i8;
Int16 => i16, cudf_cxx::scalar::ffi::scalar_set_i16;
Int32 => i32, cudf_cxx::scalar::ffi::scalar_set_i32;
Int64 => i64, cudf_cxx::scalar::ffi::scalar_set_i64;
Uint8 => u8, cudf_cxx::scalar::ffi::scalar_set_u8;
Uint16 => u16, cudf_cxx::scalar::ffi::scalar_set_u16;
Uint32 => u32, cudf_cxx::scalar::ffi::scalar_set_u32;
Uint64 => u64, cudf_cxx::scalar::ffi::scalar_set_u64;
Float32 => f32, cudf_cxx::scalar::ffi::scalar_set_f32;
Float64 => f64, cudf_cxx::scalar::ffi::scalar_set_f64;
Bool8 => bool, cudf_cxx::scalar::ffi::scalar_set_bool;
}
Ok(Self { inner })
}
pub fn null(dtype: DataType) -> Result<Self> {
if !dtype.id().is_numeric() && dtype.id() != TypeId::Bool8 {
return Err(CudfError::InvalidArgument(format!(
"Scalar::null() only supports numeric and boolean types, got {:?}",
dtype.id()
)));
}
let inner = cudf_cxx::scalar::ffi::make_numeric_scalar(dtype.id() as i32)
.map_err(CudfError::from_cxx)?;
Ok(Self { inner })
}
pub fn value<T: CudfType>(&self) -> Result<T> {
let actual = self.data_type().id();
if actual != T::TYPE_ID {
return Err(CudfError::TypeMismatch {
expected: format!("{:?}", T::TYPE_ID),
actual: format!("{:?}", actual),
});
}
if !self.is_valid() {
return Err(CudfError::InvalidArgument(
"cannot get value of null scalar".to_string(),
));
}
scalar_get_dispatch! {
self,
Int8 => i8, cudf_cxx::scalar::ffi::scalar_get_i8;
Int16 => i16, cudf_cxx::scalar::ffi::scalar_get_i16;
Int32 => i32, cudf_cxx::scalar::ffi::scalar_get_i32;
Int64 => i64, cudf_cxx::scalar::ffi::scalar_get_i64;
Uint8 => u8, cudf_cxx::scalar::ffi::scalar_get_u8;
Uint16 => u16, cudf_cxx::scalar::ffi::scalar_get_u16;
Uint32 => u32, cudf_cxx::scalar::ffi::scalar_get_u32;
Uint64 => u64, cudf_cxx::scalar::ffi::scalar_get_u64;
Float32 => f32, cudf_cxx::scalar::ffi::scalar_get_f32;
Float64 => f64, cudf_cxx::scalar::ffi::scalar_get_f64;
Bool8 => bool, cudf_cxx::scalar::ffi::scalar_get_bool;
}
}
pub fn is_valid(&self) -> bool {
self.inner.is_valid()
}
pub fn set_valid(&mut self, valid: bool) -> Result<()> {
cudf_cxx::scalar::ffi::scalar_set_valid(self.inner.pin_mut(), valid)
.map_err(CudfError::from_cxx)
}
pub fn data_type(&self) -> DataType {
let raw = self.inner.type_id();
let id = TypeId::from_raw(raw).unwrap_or_else(|| {
panic!(
"cudf: unrecognized type_id {} from FFI — possible libcudf version mismatch",
raw
)
});
DataType::new(id)
}
}
impl TryFrom<i32> for Scalar {
type Error = CudfError;
fn try_from(v: i32) -> Result<Self> {
Self::new(v)
}
}
impl TryFrom<i64> for Scalar {
type Error = CudfError;
fn try_from(v: i64) -> Result<Self> {
Self::new(v)
}
}
impl TryFrom<f32> for Scalar {
type Error = CudfError;
fn try_from(v: f32) -> Result<Self> {
Self::new(v)
}
}
impl TryFrom<f64> for Scalar {
type Error = CudfError;
fn try_from(v: f64) -> Result<Self> {
Self::new(v)
}
}
impl std::fmt::Debug for Scalar {
    /// Uses the same text as the `Display` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as std::fmt::Display>::fmt(self, f)
    }
}
impl std::fmt::Display for Scalar {
    /// Formats as `Scalar(<data type>, valid=<bool>)`.
    ///
    /// Panics under the same conditions as `Scalar::data_type`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dtype = self.data_type();
        let valid = self.is_valid();
        f.write_fmt(format_args!("Scalar({}, valid={})", dtype, valid))
    }
}