/// Private home of the [`Sealed`](sealed::Sealed) marker trait.
///
/// `GpuScalar` names `sealed::Sealed` as a supertrait. Because this module is
/// not `pub`, the set of `GpuScalar` implementors is intended to stay closed
/// to the primitive types listed below (the usual "sealed trait" pattern).
mod sealed {
    /// Marker trait with no items; it exists only to be used as a bound.
    pub trait Sealed {}

    /// Emits one empty `impl Sealed for T {}` per listed type.
    macro_rules! seal {
        ($($ty:ty),* $(,)?) => {
            $(impl Sealed for $ty {})*
        };
    }

    seal!(u8, u32, u64, i32, i64, f32, f64, bool);
}
/// A primitive value type with a fixed-width little-endian byte encoding and
/// a set of associated kernel-name strings.
///
/// Implementors must also be `sealed::Sealed`, `crate::cuda_compat::KernelScalar`,
/// `Copy`, `Send` and `'static`. The `Sealed` bound lives in a private module,
/// which is intended to keep the implementor set closed.
pub trait GpuScalar
where
    Self: sealed::Sealed + crate::cuda_compat::KernelScalar + Copy + Send + 'static,
{
    /// Number of bytes occupied by one encoded value.
    const BYTE_WIDTH: usize;

    /// Decodes a value from its little-endian byte encoding.
    ///
    /// Implementations may panic when `bytes` does not have a suitable length;
    /// see each impl for its exact requirement.
    fn from_le_bytes(bytes: &[u8]) -> Self;

    /// Writes the little-endian byte encoding of `self` into `buf`.
    ///
    /// Implementations may panic when `buf` does not have a suitable length;
    /// see each impl for its exact requirement.
    fn to_le_bytes_into(self, buf: &mut [u8]);

    /// Name of the "filter compare" kernel for this type
    /// (for example `"filter_compare_u32"`).
    // NOTE(review): presumably resolved to a compiled GPU kernel by the
    // caller — the lookup is not visible from this file.
    fn filter_compare_kernel() -> &'static str;

    /// Name of the "compare col" kernel for this type
    /// (for example `"filter_compare_u32_col"`).
    // NOTE(review): presumably the column-vs-column variant — confirm against
    // the kernel sources.
    fn compare_col_kernel() -> &'static str;

    /// The `xlog_core::ScalarType` values associated with this Rust type.
    // NOTE(review): presumably used to validate that a column's scalar type
    // may be handled as `Self` — confirm against callers.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType];

    /// Name of the "filter scan phase 1" kernel, if this type provides one.
    ///
    /// The default is `None`; an impl overrides this to advertise a kernel.
    fn filter_scan_phase1_kernel() -> Option<&'static str> {
        None
    }
}
/// `u8`: a single byte whose encoding is the byte itself.
impl GpuScalar for u8 {
    const BYTE_WIDTH: usize = 1;

    /// Returns the first byte of `bytes`; any bytes after it are ignored.
    ///
    /// # Panics
    /// Panics if `bytes` is empty.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        let first = bytes[0];
        first
    }

    /// Stores `self` in `buf[0]`; any later bytes of `buf` are left untouched.
    ///
    /// # Panics
    /// Panics if `buf` is empty.
    fn to_le_bytes_into(self, buf: &mut [u8]) {
        buf[0] = self;
    }

    /// Kernel name `"filter_compare_u8"`.
    fn filter_compare_kernel() -> &'static str {
        "filter_compare_u8"
    }

    /// Kernel name `"filter_compare_u8_col"`.
    fn compare_col_kernel() -> &'static str {
        "filter_compare_u8_col"
    }

    /// `u8` lists `ScalarType::Bool` as its only allowed scalar type.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
        const ALLOWED: &[xlog_core::ScalarType] = &[xlog_core::ScalarType::Bool];
        ALLOWED
    }
}
/// `u32`: four-byte little-endian unsigned integer.
impl GpuScalar for u32 {
    const BYTE_WIDTH: usize = 4;

    /// Decodes a `u32` from exactly four little-endian bytes.
    ///
    /// # Panics
    /// Panics if `bytes.len() != 4`.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        let raw: [u8; 4] = bytes.try_into().unwrap();
        u32::from_le_bytes(raw)
    }

    /// Fills `buf` with the four little-endian bytes of `self`.
    ///
    /// # Panics
    /// Panics if `buf.len() != 4`.
    fn to_le_bytes_into(self, buf: &mut [u8]) {
        let encoded = self.to_le_bytes();
        buf.copy_from_slice(&encoded);
    }

    /// Kernel name `"filter_compare_u32"`.
    fn filter_compare_kernel() -> &'static str {
        "filter_compare_u32"
    }

    /// Kernel name `"filter_compare_u32_col"`.
    fn compare_col_kernel() -> &'static str {
        "filter_compare_u32_col"
    }

    /// `u32` lists both `ScalarType::U32` and `ScalarType::Symbol`.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
        const ALLOWED: &[xlog_core::ScalarType] =
            &[xlog_core::ScalarType::U32, xlog_core::ScalarType::Symbol];
        ALLOWED
    }

    /// `u32` overrides the default and advertises a scan phase-1 kernel.
    fn filter_scan_phase1_kernel() -> Option<&'static str> {
        Some("filter_compare_u32_scan_phase1")
    }
}
/// `u64`: eight-byte little-endian unsigned integer.
impl GpuScalar for u64 {
    const BYTE_WIDTH: usize = 8;

    /// Decodes a `u64` from exactly eight little-endian bytes.
    ///
    /// # Panics
    /// Panics if `bytes.len() != 8`.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        let raw: [u8; 8] = bytes.try_into().unwrap();
        u64::from_le_bytes(raw)
    }

    /// Fills `buf` with the eight little-endian bytes of `self`.
    ///
    /// # Panics
    /// Panics if `buf.len() != 8`.
    fn to_le_bytes_into(self, buf: &mut [u8]) {
        let encoded = self.to_le_bytes();
        buf.copy_from_slice(&encoded);
    }

    /// Kernel name `"filter_compare_u64"`.
    fn filter_compare_kernel() -> &'static str {
        "filter_compare_u64"
    }

    /// Kernel name `"filter_compare_u64_col"`.
    fn compare_col_kernel() -> &'static str {
        "filter_compare_u64_col"
    }

    /// `u64` lists `ScalarType::U64` as its only allowed scalar type.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
        const ALLOWED: &[xlog_core::ScalarType] = &[xlog_core::ScalarType::U64];
        ALLOWED
    }
}
/// `i32`: four-byte little-endian signed integer.
impl GpuScalar for i32 {
    const BYTE_WIDTH: usize = 4;

    /// Decodes an `i32` from exactly four little-endian bytes.
    ///
    /// # Panics
    /// Panics if `bytes.len() != 4`.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        let raw: [u8; 4] = bytes.try_into().unwrap();
        i32::from_le_bytes(raw)
    }

    /// Fills `buf` with the four little-endian bytes of `self`.
    ///
    /// # Panics
    /// Panics if `buf.len() != 4`.
    fn to_le_bytes_into(self, buf: &mut [u8]) {
        let encoded = self.to_le_bytes();
        buf.copy_from_slice(&encoded);
    }

    /// Kernel name `"filter_compare_i32"`.
    fn filter_compare_kernel() -> &'static str {
        "filter_compare_i32"
    }

    /// Kernel name `"filter_compare_i32_col"`.
    fn compare_col_kernel() -> &'static str {
        "filter_compare_i32_col"
    }

    /// `i32` lists `ScalarType::I32` as its only allowed scalar type.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
        const ALLOWED: &[xlog_core::ScalarType] = &[xlog_core::ScalarType::I32];
        ALLOWED
    }
}
/// `i64`: eight-byte little-endian signed integer.
impl GpuScalar for i64 {
    const BYTE_WIDTH: usize = 8;

    /// Decodes an `i64` from exactly eight little-endian bytes.
    ///
    /// # Panics
    /// Panics if `bytes.len() != 8`.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        let raw: [u8; 8] = bytes.try_into().unwrap();
        i64::from_le_bytes(raw)
    }

    /// Fills `buf` with the eight little-endian bytes of `self`.
    ///
    /// # Panics
    /// Panics if `buf.len() != 8`.
    fn to_le_bytes_into(self, buf: &mut [u8]) {
        let encoded = self.to_le_bytes();
        buf.copy_from_slice(&encoded);
    }

    /// Kernel name `"filter_compare_i64"`.
    fn filter_compare_kernel() -> &'static str {
        "filter_compare_i64"
    }

    /// Kernel name `"filter_compare_i64_col"`.
    fn compare_col_kernel() -> &'static str {
        "filter_compare_i64_col"
    }

    /// `i64` lists `ScalarType::I64` as its only allowed scalar type.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
        const ALLOWED: &[xlog_core::ScalarType] = &[xlog_core::ScalarType::I64];
        ALLOWED
    }
}
/// `f32`: four-byte little-endian floating-point value.
impl GpuScalar for f32 {
    const BYTE_WIDTH: usize = 4;

    /// Decodes an `f32` from exactly four little-endian bytes.
    ///
    /// # Panics
    /// Panics if `bytes.len() != 4`.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        let raw: [u8; 4] = bytes.try_into().unwrap();
        f32::from_le_bytes(raw)
    }

    /// Fills `buf` with the four little-endian bytes of `self`.
    ///
    /// # Panics
    /// Panics if `buf.len() != 4`.
    fn to_le_bytes_into(self, buf: &mut [u8]) {
        let encoded = self.to_le_bytes();
        buf.copy_from_slice(&encoded);
    }

    /// Kernel name `"filter_compare_f32"`.
    fn filter_compare_kernel() -> &'static str {
        "filter_compare_f32"
    }

    /// Kernel name `"filter_compare_f32_col"`.
    fn compare_col_kernel() -> &'static str {
        "filter_compare_f32_col"
    }

    /// `f32` lists `ScalarType::F32` as its only allowed scalar type.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
        const ALLOWED: &[xlog_core::ScalarType] = &[xlog_core::ScalarType::F32];
        ALLOWED
    }
}
/// `f64`: eight-byte little-endian floating-point value.
impl GpuScalar for f64 {
    const BYTE_WIDTH: usize = 8;

    /// Decodes an `f64` from exactly eight little-endian bytes.
    ///
    /// # Panics
    /// Panics if `bytes.len() != 8`.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        let raw: [u8; 8] = bytes.try_into().unwrap();
        f64::from_le_bytes(raw)
    }

    /// Fills `buf` with the eight little-endian bytes of `self`.
    ///
    /// # Panics
    /// Panics if `buf.len() != 8`.
    fn to_le_bytes_into(self, buf: &mut [u8]) {
        let encoded = self.to_le_bytes();
        buf.copy_from_slice(&encoded);
    }

    /// Kernel name `"filter_compare_f64"`.
    fn filter_compare_kernel() -> &'static str {
        "filter_compare_f64"
    }

    /// Kernel name `"filter_compare_f64_col"`.
    fn compare_col_kernel() -> &'static str {
        "filter_compare_f64_col"
    }

    /// `f64` lists `ScalarType::F64` as its only allowed scalar type.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
        const ALLOWED: &[xlog_core::ScalarType] = &[xlog_core::ScalarType::F64];
        ALLOWED
    }

    /// `f64` overrides the default and advertises a scan phase-1 kernel.
    fn filter_scan_phase1_kernel() -> Option<&'static str> {
        Some("filter_compare_f64_scan_phase1")
    }
}
/// `bool`: one byte, written canonically as `0`/`1` and read leniently
/// (any non-zero byte decodes to `true`). Returns the same kernel names as
/// the `u8` impl.
impl GpuScalar for bool {
    const BYTE_WIDTH: usize = 1;

    /// Returns `true` when the first byte of `bytes` is non-zero; any bytes
    /// after it are ignored.
    ///
    /// # Panics
    /// Panics if `bytes` is empty.
    fn from_le_bytes(bytes: &[u8]) -> Self {
        let flag_byte = bytes[0];
        flag_byte != 0
    }

    /// Stores `1` for `true` or `0` for `false` in `buf[0]`; any later bytes
    /// of `buf` are left untouched.
    ///
    /// # Panics
    /// Panics if `buf` is empty.
    fn to_le_bytes_into(self, buf: &mut [u8]) {
        buf[0] = u8::from(self);
    }

    /// Kernel name `"filter_compare_u8"`.
    fn filter_compare_kernel() -> &'static str {
        "filter_compare_u8"
    }

    /// Kernel name `"filter_compare_u8_col"`.
    fn compare_col_kernel() -> &'static str {
        "filter_compare_u8_col"
    }

    /// `bool` lists `ScalarType::Bool` as its only allowed scalar type.
    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
        const ALLOWED: &[xlog_core::ScalarType] = &[xlog_core::ScalarType::Bool];
        ALLOWED
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `val` into a scratch buffer of exactly `T::BYTE_WIDTH` bytes,
    /// decodes it again, and asserts the decoded value equals the input.
    fn roundtrip<T>(val: T)
    where
        T: GpuScalar + PartialEq + std::fmt::Debug,
    {
        let mut scratch = vec![0u8; T::BYTE_WIDTH];
        val.to_le_bytes_into(&mut scratch);
        assert_eq!(T::from_le_bytes(&scratch), val);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u8() {
        for sample in [42u8, 0, 255] {
            roundtrip(sample);
        }
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u32() {
        for sample in [0u32, 42, u32::MAX] {
            roundtrip(sample);
        }
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u64() {
        for sample in [0u64, 42, u64::MAX] {
            roundtrip(sample);
        }
    }

    #[test]
    fn test_gpu_scalar_roundtrip_i32() {
        for sample in [0i32, -1, i32::MAX] {
            roundtrip(sample);
        }
    }

    #[test]
    fn test_gpu_scalar_roundtrip_i64() {
        for sample in [0i64, -1, i64::MAX] {
            roundtrip(sample);
        }
    }

    #[test]
    fn test_gpu_scalar_roundtrip_f32() {
        for sample in [0.0f32, -1.5, f32::INFINITY] {
            roundtrip(sample);
        }
    }

    #[test]
    fn test_gpu_scalar_roundtrip_f64() {
        for sample in [0.0f64, -1.5, f64::INFINITY] {
            roundtrip(sample);
        }
    }

    #[test]
    fn test_gpu_scalar_roundtrip_bool() {
        for sample in [true, false] {
            roundtrip(sample);
        }
    }

    /// Writing a `bool` must overwrite the byte with exactly `0x00` / `0x01`,
    /// regardless of what was there before.
    #[test]
    fn test_bool_canonical_write() {
        let mut cell = [0xFFu8];

        false.to_le_bytes_into(&mut cell);
        assert_eq!(cell[0], 0x00, "false must write canonical 0x00");

        true.to_le_bytes_into(&mut cell);
        assert_eq!(cell[0], 0x01, "true must write canonical 0x01");
    }

    /// Reading a `bool` treats zero as `false` and every non-zero byte as `true`.
    #[test]
    fn test_bool_lenient_read() {
        assert!(!bool::from_le_bytes(&[0x00]));
        for truthy in [0x01u8, 0x02, 0xFF] {
            assert!(bool::from_le_bytes(&[truthy]));
        }
    }

    /// `BYTE_WIDTH` must match the in-memory size of every implementor.
    #[test]
    fn test_byte_width_consistency() {
        fn check<T: GpuScalar>() {
            assert_eq!(T::BYTE_WIDTH, std::mem::size_of::<T>());
        }

        check::<u8>();
        check::<u32>();
        check::<u64>();
        check::<i32>();
        check::<i64>();
        check::<f32>();
        check::<f64>();
        check::<bool>();
    }

    #[test]
    fn test_filter_kernel_names_non_empty() {
        fn check<T: GpuScalar>() {
            assert!(!T::filter_compare_kernel().is_empty());
            assert!(!T::compare_col_kernel().is_empty());
        }

        check::<u8>();
        check::<u32>();
        check::<u64>();
        check::<i32>();
        check::<i64>();
        check::<f32>();
        check::<f64>();
        check::<bool>();
    }

    #[test]
    fn test_allowed_scalar_types_non_empty() {
        fn check<T: GpuScalar>() {
            assert!(!T::allowed_scalar_types().is_empty());
        }

        check::<u8>();
        check::<u32>();
        check::<u64>();
        check::<i32>();
        check::<i64>();
        check::<f32>();
        check::<f64>();
        check::<bool>();
    }

    /// Only `u32` and `f64` override `filter_scan_phase1_kernel`; every other
    /// implementor keeps the `None` default.
    #[test]
    fn test_fused_scan_kernel_only_u32_and_f64() {
        for provided in [
            u32::filter_scan_phase1_kernel(),
            f64::filter_scan_phase1_kernel(),
        ] {
            assert!(provided.is_some());
        }

        for absent in [
            u8::filter_scan_phase1_kernel(),
            u64::filter_scan_phase1_kernel(),
            i32::filter_scan_phase1_kernel(),
            i64::filter_scan_phase1_kernel(),
            f32::filter_scan_phase1_kernel(),
            bool::filter_scan_phase1_kernel(),
        ] {
            assert!(absent.is_none());
        }
    }

    #[test]
    fn test_bool_and_u8_share_gpu_kernels() {
        let u8_kernels = (u8::filter_compare_kernel(), u8::compare_col_kernel());
        let bool_kernels = (bool::filter_compare_kernel(), bool::compare_col_kernel());
        assert_eq!(u8_kernels.0, bool_kernels.0);
        assert_eq!(u8_kernels.1, bool_kernels.1);
    }

    #[test]
    fn test_u32_allowed_includes_symbol() {
        let allowed = u32::allowed_scalar_types();
        for expected in [xlog_core::ScalarType::U32, xlog_core::ScalarType::Symbol] {
            assert!(allowed.contains(&expected));
        }
    }
}