use crate::{LaunchAsync, LaunchConfig};
use xlog_core::{AggOp, Result, ScalarType, Schema, XlogError};
use super::{
arith_kernels, groupby_kernels, pack_kernels, scan_kernels, ARITH_MODULE, GROUPBY_MODULE,
PACK_MODULE, SCAN_MODULE,
};
use crate::memory::{CudaColumn, TrackedCudaSlice};
use crate::CudaBuffer;
impl super::CudaKernelProvider {
/// Groups `input` by `key_cols` and applies a single aggregation.
///
/// Convenience wrapper: it forwards to `groupby_multi_agg` with a one-element
/// aggregation list made of `(value_col, agg)`, so validation, errors and the
/// layout of the returned buffer are exactly those of that method.
pub fn groupby_agg(
    &self,
    input: &CudaBuffer,
    key_cols: &[usize],
    agg: AggOp,
    value_col: usize,
) -> Result<CudaBuffer> {
    let single_agg = [(value_col, agg)];
    self.groupby_multi_agg(input, key_cols, &single_agg)
}
pub fn groupby_multi_agg(
&self,
buffer: &CudaBuffer,
key_cols: &[usize],
aggs: &[(usize, AggOp)],
) -> Result<CudaBuffer> {
let num_rows = self.device_row_count(buffer)?;
if num_rows == 0 {
let result_schema =
self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
return self.create_empty_buffer(result_schema);
}
if num_rows > u32::MAX as usize {
return Err(XlogError::Kernel(format!(
"GroupBy supports at most {} rows, got {}",
u32::MAX,
num_rows
)));
}
if key_cols.is_empty() {
return Err(XlogError::Kernel(
"GroupBy requires at least one key column".to_string(),
));
}
if aggs.is_empty() {
return Err(XlogError::Kernel(
"GroupBy requires at least one aggregation".to_string(),
));
}
for &key_col in key_cols {
if key_col >= buffer.arity() {
return Err(XlogError::Kernel(format!(
"Key column {} out of bounds (arity {})",
key_col,
buffer.arity()
)));
}
}
for &(value_col, agg_op) in aggs {
if value_col >= buffer.arity() {
return Err(XlogError::Kernel(format!(
"Value column {} out of bounds (arity {})",
value_col,
buffer.arity()
)));
}
let value_ty = buffer
.schema()
.column_type(value_col)
.ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
match agg_op {
AggOp::Count => {}
AggOp::Sum | AggOp::Min | AggOp::Max => {
if value_ty != ScalarType::U32 {
return Err(XlogError::Kernel(format!(
"{:?} currently requires U32 values, got {:?}",
agg_op, value_ty
)));
}
}
AggOp::LogSumExp => {
if value_ty != ScalarType::F64 {
return Err(XlogError::Kernel(format!(
"LogSumExp requires F64 values, got {:?}",
value_ty
)));
}
}
}
}
let num_rows = num_rows as u32;
let sorted = self.sort(buffer, key_cols)?;
let boundary_func = self
.device
.inner()
.get_func(GROUPBY_MODULE, groupby_kernels::DETECT_GROUP_BOUNDARIES)
.ok_or_else(|| {
XlogError::Kernel("detect_group_boundaries kernel not found".to_string())
})?;
let boundaries = self.memory.alloc::<u8>(num_rows as usize)?;
let packed = self.compute_hashes_and_pack_keys(&sorted, key_cols)?;
if packed.key_bytes == 0 || packed.key_bytes % 4 != 0 {
return Err(XlogError::Kernel(format!(
"GroupBy key packing produced {} bytes per row (expected multiple of 4); Bool keys are not supported",
packed.key_bytes
)));
}
let segments_per_row = (packed.key_bytes / 4) as usize;
let total_segments = (num_rows as usize) * segments_per_row;
let packed_u32 = self.bytes_as_u32_view(&packed.packed_keys, total_segments)?;
let block_size = 256u32;
let grid_size = (num_rows + block_size - 1) / block_size;
let config = LaunchConfig {
grid_dim: (grid_size, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
boundary_func.clone().launch(
config,
(
&packed_u32,
num_rows,
segments_per_row as u32,
segments_per_row as u32,
&boundaries,
),
)
}
.map_err(|e| XlogError::Kernel(format!("detect_group_boundaries failed: {}", e)))?;
self.device.synchronize()?;
let device = self.device.inner();
let num_blocks = grid_size;
let d_boundary_pos = self.memory.alloc::<u32>(num_rows as usize)?;
let mut d_block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
let phase1_fn = device
.get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE1)
.ok_or_else(|| {
XlogError::Kernel("Failed to get multiblock_scan_phase1 kernel".to_string())
})?;
unsafe {
phase1_fn.clone().launch(
LaunchConfig {
grid_dim: (num_blocks, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
},
(&boundaries, &d_boundary_pos, &d_block_sums, num_rows),
)
}
.map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase1 failed: {}", e)))?;
if num_blocks > 1 {
self.multiblock_scan_u32_inplace(&mut d_block_sums, num_blocks)?;
let phase3_fn = device
.get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
.ok_or_else(|| {
XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
})?;
unsafe {
phase3_fn.clone().launch(
LaunchConfig {
grid_dim: (num_blocks, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
},
(&d_boundary_pos, &d_block_sums, num_rows),
)
}
.map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;
}
self.device.synchronize()?;
let d_num_groups = self.capture_num_groups(&d_boundary_pos, &boundaries, num_rows)?;
let row_cap = num_rows as u64;
let row_cap_usize = num_rows as usize;
let row_cap_u32 = num_rows;
let mut group_ids = self.memory.alloc::<u32>(num_rows as usize)?;
let mut group_first_idx = self.memory.alloc::<u32>(row_cap_usize)?;
let group_ids_fn = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUP_IDS_FROM_BOUNDARIES)
.ok_or_else(|| {
XlogError::Kernel("group_ids_from_boundaries kernel not found".to_string())
})?;
let group_start_fn = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUP_START_INDICES)
.ok_or_else(|| XlogError::Kernel("group_start_indices kernel not found".to_string()))?;
unsafe {
group_ids_fn.clone().launch(
config,
(&boundaries, &d_boundary_pos, num_rows, &mut group_ids),
)
}
.map_err(|e| XlogError::Kernel(format!("group_ids_from_boundaries failed: {}", e)))?;
unsafe {
group_start_fn.clone().launch(
config,
(&boundaries, &d_boundary_pos, num_rows, &mut group_first_idx),
)
}
.map_err(|e| XlogError::Kernel(format!("group_start_indices failed: {}", e)))?;
self.device.synchronize()?;
let mut agg_columns: Vec<CudaColumn> = Vec::with_capacity(aggs.len());
for &(value_col, agg_op) in aggs {
let values = sorted
.column(value_col)
.ok_or_else(|| XlogError::Kernel("Value column not found".to_string()))?;
match agg_op {
AggOp::Count => {
let output_bytes = row_cap_usize
.checked_mul(std::mem::size_of::<u64>())
.ok_or_else(|| {
XlogError::Kernel("Count output size overflow".to_string())
})?;
let mut output = self.memory.alloc::<u8>(output_bytes)?;
device.memset_zeros(&mut output).map_err(|e| {
XlogError::Kernel(format!("Failed to zero count output: {}", e))
})?;
let count_func = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_COUNT)
.ok_or_else(|| {
XlogError::Kernel("groupby_count kernel not found".to_string())
})?;
unsafe {
count_func
.clone()
.launch(config, (&boundaries, &group_ids, num_rows, &output))
}
.map_err(|e| XlogError::Kernel(format!("groupby_count failed: {}", e)))?;
self.device.synchronize()?;
agg_columns.push(output.into());
}
AggOp::Sum => {
let values_view = self.column_as_u32_view(values, num_rows as usize)?;
let output_bytes = row_cap_usize
.checked_mul(std::mem::size_of::<u64>())
.ok_or_else(|| XlogError::Kernel("Sum output size overflow".to_string()))?;
let mut output = self.memory.alloc::<u8>(output_bytes)?;
device.memset_zeros(&mut output).map_err(|e| {
XlogError::Kernel(format!("Failed to zero sum output: {}", e))
})?;
let sum_func = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM)
.ok_or_else(|| {
XlogError::Kernel("groupby_sum kernel not found".to_string())
})?;
unsafe {
sum_func
.clone()
.launch(config, (&values_view, &group_ids, num_rows, &output))
}
.map_err(|e| XlogError::Kernel(format!("groupby_sum failed: {}", e)))?;
self.device.synchronize()?;
agg_columns.push(output.into());
}
AggOp::Min => {
let values_view = self.column_as_u32_view(values, num_rows as usize)?;
let output_bytes = row_cap_usize
.checked_mul(std::mem::size_of::<u32>())
.ok_or_else(|| XlogError::Kernel("Min output size overflow".to_string()))?;
let mut output = self.memory.alloc::<u8>(output_bytes)?;
let fill_fn = device
.get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U32)
.ok_or_else(|| {
XlogError::Kernel("arith_fill_const_u32 not found".to_string())
})?;
let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
unsafe {
fill_fn
.clone()
.launch(fill_config, (u32::MAX, row_cap_u32, &mut output))
}
.map_err(|e| XlogError::Kernel(format!("Failed to init min output: {}", e)))?;
let min_func = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN)
.ok_or_else(|| {
XlogError::Kernel("groupby_min kernel not found".to_string())
})?;
unsafe {
min_func
.clone()
.launch(config, (&values_view, &group_ids, num_rows, &output))
}
.map_err(|e| XlogError::Kernel(format!("groupby_min failed: {}", e)))?;
self.device.synchronize()?;
agg_columns.push(output.into());
}
AggOp::Max => {
let values_view = self.column_as_u32_view(values, num_rows as usize)?;
let output_bytes = row_cap_usize
.checked_mul(std::mem::size_of::<u32>())
.ok_or_else(|| XlogError::Kernel("Max output size overflow".to_string()))?;
let mut output = self.memory.alloc::<u8>(output_bytes)?;
device.memset_zeros(&mut output).map_err(|e| {
XlogError::Kernel(format!("Failed to zero max output: {}", e))
})?;
let max_func = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX)
.ok_or_else(|| {
XlogError::Kernel("groupby_max kernel not found".to_string())
})?;
unsafe {
max_func
.clone()
.launch(config, (&values_view, &group_ids, num_rows, &output))
}
.map_err(|e| XlogError::Kernel(format!("groupby_max failed: {}", e)))?;
self.device.synchronize()?;
agg_columns.push(output.into());
}
AggOp::LogSumExp => {
let values_f64 = self.column_as_f64_view(values, num_rows as usize)?;
let output_bytes = row_cap_usize
.checked_mul(std::mem::size_of::<f64>())
.ok_or_else(|| {
XlogError::Kernel("LogSumExp output size overflow".to_string())
})?;
let mut maxs = self.memory.alloc::<u8>(output_bytes)?;
let mut sumexps = self.memory.alloc::<u8>(output_bytes)?;
let results = self.memory.alloc::<u8>(output_bytes)?;
let fill_f64 = device
.get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_F64)
.ok_or_else(|| {
XlogError::Kernel("arith_fill_const_f64 not found".to_string())
})?;
let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
unsafe {
fill_f64
.clone()
.launch(fill_config, (f64::NEG_INFINITY, row_cap_u32, &mut maxs))
}
.map_err(|e| XlogError::Kernel(format!("Failed to init maxs: {}", e)))?;
device
.memset_zeros(&mut sumexps)
.map_err(|e| XlogError::Kernel(format!("Failed to init sumexps: {}", e)))?;
let max_func = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_MAX)
.ok_or_else(|| {
XlogError::Kernel("groupby_logsumexp_max kernel not found".to_string())
})?;
unsafe {
max_func
.clone()
.launch(config, (&values_f64, &group_ids, num_rows, &maxs))
}
.map_err(|e| {
XlogError::Kernel(format!("groupby_logsumexp_max failed: {}", e))
})?;
self.device.synchronize()?;
let sumexp_func = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_SUMEXP)
.ok_or_else(|| {
XlogError::Kernel(
"groupby_logsumexp_sumexp kernel not found".to_string(),
)
})?;
unsafe {
sumexp_func
.clone()
.launch(config, (&values_f64, &group_ids, &maxs, num_rows, &sumexps))
}
.map_err(|e| {
XlogError::Kernel(format!("groupby_logsumexp_sumexp failed: {}", e))
})?;
self.device.synchronize()?;
let final_config = LaunchConfig::for_num_elems(row_cap_u32);
let final_func = device
.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_FINAL)
.ok_or_else(|| {
XlogError::Kernel(
"groupby_logsumexp_final kernel not found".to_string(),
)
})?;
unsafe {
final_func.clone().launch(
final_config,
(&maxs, &sumexps, &d_num_groups, row_cap_u32, &results),
)
}
.map_err(|e| {
XlogError::Kernel(format!("groupby_logsumexp_final failed: {}", e))
})?;
self.device.synchronize()?;
agg_columns.push(results.into());
}
}
}
let mut result_columns: Vec<CudaColumn> = Vec::with_capacity(key_cols.len() + aggs.len());
let group_packed_bytes = row_cap_usize
.checked_mul(packed.key_bytes as usize)
.ok_or_else(|| XlogError::Kernel("GroupBy packed size overflow".to_string()))?;
let mut group_packed = self.memory.alloc::<u8>(group_packed_bytes)?;
let gather_fn = device
.get_func(PACK_MODULE, pack_kernels::GATHER_PACKED_ROWS_COUNTED)
.ok_or_else(|| {
XlogError::Kernel("gather_packed_rows_counted kernel not found".to_string())
})?;
let gather_config = LaunchConfig::for_num_elems(row_cap_u32);
unsafe {
gather_fn.clone().launch(
gather_config,
(
&packed.packed_keys,
packed.key_bytes,
&group_first_idx,
&d_num_groups,
row_cap_u32,
&mut group_packed,
),
)
}
.map_err(|e| XlogError::Kernel(format!("gather_packed_rows failed: {}", e)))?;
let mut col_offsets: Vec<u32> = Vec::with_capacity(key_cols.len());
let mut col_sizes: Vec<u32> = Vec::with_capacity(key_cols.len());
let mut offset = 0u32;
for &key_col in key_cols {
let size = buffer
.schema()
.column_type(key_col)
.map(|t| t.size_bytes() as u32)
.unwrap_or(4);
col_offsets.push(offset);
col_sizes.push(size);
offset = offset
.checked_add(size)
.ok_or_else(|| XlogError::Kernel("GroupBy key size overflow".to_string()))?;
}
let unpack_fn = device
.get_func(PACK_MODULE, pack_kernels::UNPACK_COLUMN_COUNTED)
.ok_or_else(|| {
XlogError::Kernel("unpack_column_counted kernel not found".to_string())
})?;
let unpack_config = LaunchConfig::for_num_elems(row_cap_u32);
for idx in 0..key_cols.len() {
let col_size = col_sizes[idx];
let col_offset = col_offsets[idx];
let out_bytes = row_cap_usize
.checked_mul(col_size as usize)
.ok_or_else(|| XlogError::Kernel("GroupBy key column overflow".to_string()))?;
let mut out_col = self.memory.alloc::<u8>(out_bytes)?;
unsafe {
unpack_fn.clone().launch(
unpack_config,
(
&group_packed,
packed.key_bytes,
col_offset,
col_size,
&d_num_groups,
row_cap_u32,
&mut out_col,
),
)
}
.map_err(|e| XlogError::Kernel(format!("unpack_column failed: {}", e)))?;
result_columns.push(out_col.into());
}
result_columns.extend(agg_columns);
let result_schema = self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
Ok(CudaBuffer::from_columns(
result_columns,
row_cap,
d_num_groups,
result_schema,
))
}
/// Launches the `capture_num_groups` kernel and returns the one-element device
/// buffer it writes into.
///
/// `boundary_pos` and `boundaries` are forwarded to the kernel together with
/// `num_rows`; presumably the kernel derives the group count from the last scan
/// position and boundary flag — confirm against the kernel source. The launch
/// uses a single block with a single thread, and this method does not
/// synchronise the device itself.
///
/// # Errors
///
/// Returns `XlogError::Kernel` if the kernel cannot be found or the launch
/// fails; allocation errors are propagated unchanged.
fn capture_num_groups(
    &self,
    boundary_pos: &TrackedCudaSlice<u32>,
    boundaries: &TrackedCudaSlice<u8>,
    num_rows: u32,
) -> Result<TrackedCudaSlice<u32>> {
    // Allocate the output slot first, matching the original ordering of side effects.
    let mut group_count = self.memory.alloc::<u32>(1)?;

    let kernel = match self
        .device
        .inner()
        .get_func(GROUPBY_MODULE, groupby_kernels::CAPTURE_NUM_GROUPS)
    {
        Some(func) => func,
        None => {
            return Err(XlogError::Kernel(
                "capture_num_groups kernel not found".to_string(),
            ))
        }
    };

    let single_thread = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    let launch_args = (boundary_pos, boundaries, num_rows, &mut group_count);
    // SAFETY: `group_count` holds exactly one u32; the argument tuple is assumed
    // to match the kernel's parameter list (not visible from this file).
    let launched = unsafe { kernel.clone().launch(single_thread, launch_args) };
    if let Err(e) = launched {
        return Err(XlogError::Kernel(format!(
            "capture_num_groups failed: {}",
            e
        )));
    }
    Ok(group_count)
}
/// Builds the output schema of a multi-aggregation group-by.
///
/// The schema lists the selected key columns first (copied from `input` in
/// `key_cols` order; indices that do not exist in `input` are silently skipped),
/// followed by one column per entry of `aggs`. Aggregation columns are named
/// `<op>_<position>`, where the position is the entry's index within `aggs`.
/// `Count` and `Sum` yield `U64`, `Min` and `Max` yield `U32`, and `LogSumExp`
/// yields `F64`. The value column index of each aggregation does not influence
/// the schema.
pub(crate) fn groupby_multi_agg_result_schema(
    &self,
    input: &Schema,
    key_cols: &[usize],
    aggs: &[(usize, AggOp)],
) -> Schema {
    let mut columns: Vec<(String, ScalarType)> =
        Vec::with_capacity(key_cols.len() + aggs.len());

    // Key columns keep their original name and type.
    columns.extend(
        key_cols
            .iter()
            .filter_map(|&key_idx| input.columns.get(key_idx).cloned()),
    );

    // One synthetic column per aggregation, named after the op and its position.
    columns.extend(aggs.iter().enumerate().map(|(i, &(_, agg_op))| match agg_op {
        AggOp::Count => (format!("count_{}", i), ScalarType::U64),
        AggOp::Sum => (format!("sum_{}", i), ScalarType::U64),
        AggOp::Min => (format!("min_{}", i), ScalarType::U32),
        AggOp::Max => (format!("max_{}", i), ScalarType::U32),
        AggOp::LogSumExp => (format!("logsumexp_{}", i), ScalarType::F64),
    }));

    Schema::new(columns)
}
}