#![allow(unsafe_code)]
use std::sync::Arc;
use cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::compile_ptx;
/// Comparison operators that `GpuSelectKernel::scan_csv` can apply to the
/// selected CSV column.
///
/// `Equal`, `NotEqual` and `LikePrefix` are dispatched to the byte-wise kernel
/// (op codes 0, 1 and 2); `GreaterThan` and `LessThan` are dispatched to the
/// i64 kernel (op codes 3 and 4) and require the literal to parse as `i64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
/// Column bytes are exactly equal to the literal.
Equal,
/// Column bytes differ from the literal (complement of `Equal`).
NotEqual,
/// Column parsed as a decimal integer is strictly greater than the literal.
GreaterThan,
/// Column parsed as a decimal integer is strictly less than the literal.
LessThan,
/// Column starts with the literal's bytes.
LikePrefix,
}
/// Errors produced while setting up the GPU kernels or scanning a CSV body.
#[derive(Debug, thiserror::Error)]
pub enum GpuSelectError {
/// A CUDA driver call failed, or a kernel function could not be found after
/// the PTX module was loaded. Carries the `Debug` rendering of the cause.
#[error("CUDA driver error: {0}")]
Cuda(String),
/// NVRTC failed to compile the embedded kernel source. Carries the `Debug`
/// rendering of the compile error.
#[error("NVRTC compile error: {0}")]
Nvrtc(String),
/// The pre-flight allocation estimate is larger than the memory budget.
#[error("estimated GPU allocation {needed} bytes exceeds available {available} bytes")]
BudgetExceeded { needed: u64, available: u64 },
/// The CSV body is longer than `MAX_CSV_BODY_BYTES`, so byte offsets would
/// not fit the `u32` offsets handed to the kernels.
#[error("CSV body {got} bytes exceeds u32 offset limit {limit} bytes")]
BodyTooLarge { got: usize, limit: usize },
/// The requested WHERE column index is not smaller than the number of
/// columns counted in the header row.
#[error("WHERE column index {got} out of range (CSV has {ncols} columns)")]
ColumnOutOfRange { got: usize, ncols: usize },
/// The CSV input could not be indexed (for example, the body is empty).
#[error("CSV input is malformed: {0}")]
MalformedCsv(String),
/// A `GreaterThan` / `LessThan` literal is not valid UTF-8 or does not parse
/// as `i64`. Carries the offending literal bytes.
#[error("literal is not parseable as i64 for numeric comparison: {0:?}")]
LiteralNotNumeric(Vec<u8>),
/// The request is outside what this implementation supports.
#[error("unsupported: {0}")]
Unsupported(String),
}
impl From<cudarc::driver::DriverError> for GpuSelectError {
fn from(e: cudarc::driver::DriverError) -> Self {
Self::Cuda(format!("{e:?}"))
}
}
impl From<cudarc::nvrtc::CompileError> for GpuSelectError {
fn from(e: cudarc::nvrtc::CompileError) -> Self {
Self::Nvrtc(format!("{e:?}"))
}
}
// Fallback memory budget (12 GiB) used by `scan_csv` when the free-memory
// query to the CUDA driver does not return a value.
const DEVICE_BUDGET_BYTES: u64 = 12 * 1024 * 1024 * 1024;
// Largest CSV body `scan_csv` accepts; column offsets are stored as `u32`, so
// every byte offset must fit in that type.
const MAX_CSV_BODY_BYTES: usize = u32::MAX as usize;
// Average row size assumed by `sample_avg_row_size` when the sampled prefix is
// empty or contains no line feed.
const FALLBACK_AVG_ROW_BYTES: u64 = 32;
// Number of leading bytes `sample_avg_row_size` inspects to estimate row size.
const ROW_SIZE_SAMPLE_BYTES: usize = 1024;
/// Estimates the total number of bytes a scan will allocate for a CSV body of
/// `csv_body_len` bytes containing `num_rows` rows.
///
/// The estimate is the saturating sum of six terms: the device copy of the
/// body, two 4-byte-per-row offset/length arrays, one flag byte per row, twice
/// the body for host-side copies, and 24 bytes per row for the host index.
/// All arithmetic saturates at `u64::MAX` instead of wrapping.
fn estimate_total_alloc(csv_body_len: usize, num_rows: usize) -> u64 {
    let body_bytes = csv_body_len as u64;
    let row_count = num_rows as u64;

    let terms = [
        body_bytes,                   // device copy of the CSV body
        row_count.saturating_mul(4),  // device column-start offsets
        row_count.saturating_mul(4),  // device column lengths
        row_count,                    // device match flags
        body_bytes.saturating_mul(2), // host-side copies of the body
        row_count.saturating_mul(24), // host-side row index
    ];

    terms
        .iter()
        .fold(0u64, |total, &term| total.saturating_add(term))
}
/// Estimates the average row size in bytes from the first
/// `ROW_SIZE_SAMPLE_BYTES` bytes of `csv`.
///
/// Returns `FALLBACK_AVG_ROW_BYTES` when the sampled prefix is empty or holds
/// no line feed; otherwise the prefix length divided by the number of line
/// feeds in it, clamped to at least 1.
fn sample_avg_row_size(csv: &[u8]) -> u64 {
    let sample = &csv[..csv.len().min(ROW_SIZE_SAMPLE_BYTES)];
    let newline_count = sample
        .iter()
        .fold(0u64, |seen, &byte| seen + u64::from(byte == b'\n'));

    match newline_count {
        // Covers both the empty input and a prefix without any terminator.
        0 => FALLBACK_AVG_ROW_BYTES,
        n => (sample.len() as u64 / n).max(1),
    }
}
/// Asks the CUDA driver how much memory is currently free.
///
/// Returns `None` when the driver query fails; the reported total is ignored.
fn get_gpu_free_memory() -> Option<u64> {
    match cudarc::driver::result::mem_get_info() {
        Ok((free_bytes, _total_bytes)) => Some(free_bytes as u64),
        Err(_) => None,
    }
}
/// CUDA C source for the two comparison kernels, compiled at runtime by NVRTC.
///
/// Both kernels use one thread per row (`idx` derived from block and thread
/// index, guarded by `num_rows`) and write exactly one byte of `match_flags`
/// per row: 1 when the row matches, 0 otherwise.
///
/// * `column_compare_bytes` — byte-wise `Equal` (0), `NotEqual` (1) and
///   `LikePrefix` (2) against `literal[0..literal_len]`.
/// * `column_compare_i64` — parses the column as an optionally signed decimal
///   integer and compares it with `literal` for `GreaterThan` (3) or
///   `LessThan` (4). Columns that are empty, contain a non-digit, or whose
///   magnitude does not fit in a positive i64 are flagged 0. As a consequence
///   the single value -9223372036854775808 is treated as non-numeric.
const KERNEL_SRC: &str = r#"
extern "C" __global__ void column_compare_bytes(
const unsigned char* csv,
const unsigned int* col_start,
const unsigned int* col_len,
int num_rows,
const unsigned char* literal,
int literal_len,
int op_code, // 0 = Equal, 1 = NotEqual, 2 = LikePrefix
unsigned char* match_flags
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_rows) return;
unsigned int start = col_start[idx];
unsigned int len = col_len[idx];
bool eq;
if (op_code == 2) {
// LIKE prefix: column must be at least literal_len long, and the
// first literal_len bytes must match.
if ((int)len < literal_len) {
eq = false;
} else {
eq = true;
for (int i = 0; i < literal_len; ++i) {
if (csv[start + i] != literal[i]) { eq = false; break; }
}
}
match_flags[idx] = eq ? 1 : 0;
} else {
// Equal / NotEqual: full byte-wise compare.
if ((int)len != literal_len) {
eq = false;
} else {
eq = true;
for (int i = 0; i < literal_len; ++i) {
if (csv[start + i] != literal[i]) { eq = false; break; }
}
}
if (op_code == 0) match_flags[idx] = eq ? 1 : 0; // Equal
else /* op_code == 1 */ match_flags[idx] = eq ? 0 : 1; // NotEqual
}
}
extern "C" __global__ void column_compare_i64(
const unsigned char* csv,
const unsigned int* col_start,
const unsigned int* col_len,
int num_rows,
long long literal, // i64 RHS
int op_code, // 3 = GreaterThan, 4 = LessThan
unsigned char* match_flags
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_rows) return;
unsigned int start = col_start[idx];
unsigned int len = col_len[idx];
// An i64 needs at most 19 digits plus one sign char; anything wider than
// 21 bytes is rejected up front (non-numeric per our contract).
if (len == 0 || len > 21) { match_flags[idx] = 0; return; }
long long acc = 0;
bool neg = false;
int i = 0;
unsigned char first = csv[start];
if (first == '-') { neg = true; i = 1; }
else if (first == '+') { i = 1; }
if (i == (int)len) { match_flags[idx] = 0; return; } // sign with no digits
for (; i < (int)len; ++i) {
unsigned char c = csv[start + i];
if (c < '0' || c > '9') { match_flags[idx] = 0; return; }
long long digit = (long long)(c - '0');
// Exact overflow check BEFORE the multiply-add: acc * 10 + digit must not
// exceed LLONG_MAX. Signed overflow is undefined behaviour, so it must
// never be allowed to happen and then be detected after the fact.
if (acc > (9223372036854775807LL - digit) / 10) { match_flags[idx] = 0; return; }
acc = acc * 10 + digit;
}
if (neg) acc = -acc;
bool m;
if (op_code == 3) m = (acc > literal);
else /* op_code == 4 */ m = (acc < literal);
match_flags[idx] = m ? 1 : 0;
}
"#;
/// Handle to a CUDA device together with the two compiled comparison kernels.
///
/// Built by `GpuSelectKernel::new`; `scan_csv` uses it to filter CSV rows.
pub struct GpuSelectKernel {
// Device used for every upload, allocation, launch and download.
device: Arc<CudaDevice>,
// `column_compare_bytes`: Equal / NotEqual / LikePrefix.
f_bytes: CudaFunction,
// `column_compare_i64`: GreaterThan / LessThan.
f_i64: CudaFunction,
}
/// Manual `Debug`: prints a fixed placeholder for the device and omits the
/// kernel function handles.
impl std::fmt::Debug for GpuSelectKernel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("GpuSelectKernel");
        builder.field("device", &"CudaDevice<cuda:0>");
        builder.finish()
    }
}
impl GpuSelectKernel {
pub fn new() -> Result<Self, GpuSelectError> {
let device = CudaDevice::new(0)?;
let ptx = compile_ptx(KERNEL_SRC)?;
device.load_ptx(
ptx,
"s4_gpu_select",
&["column_compare_bytes", "column_compare_i64"],
)?;
let f_bytes = device
.get_func("s4_gpu_select", "column_compare_bytes")
.ok_or_else(|| {
GpuSelectError::Cuda("column_compare_bytes not found after load_ptx".into())
})?;
let f_i64 = device
.get_func("s4_gpu_select", "column_compare_i64")
.ok_or_else(|| {
GpuSelectError::Cuda("column_compare_i64 not found after load_ptx".into())
})?;
Ok(Self {
device,
f_bytes,
f_i64,
})
}
pub fn scan_csv(
&self,
csv_body: &[u8],
where_column_idx: usize,
op: CompareOp,
literal: &[u8],
) -> Result<Vec<u8>, GpuSelectError> {
if csv_body.len() > MAX_CSV_BODY_BYTES {
return Err(GpuSelectError::BodyTooLarge {
got: csv_body.len(),
limit: MAX_CSV_BODY_BYTES,
});
}
let avg_row_size = sample_avg_row_size(csv_body);
let estimated_rows = (csv_body.len() as u64 / avg_row_size.max(1)) as usize;
let total_alloc = estimate_total_alloc(csv_body.len(), estimated_rows);
let device_budget = get_gpu_free_memory().unwrap_or(DEVICE_BUDGET_BYTES);
if total_alloc > device_budget {
return Err(GpuSelectError::BudgetExceeded {
needed: total_alloc,
available: device_budget,
});
}
let RowIndex {
header_end,
row_starts,
row_ends,
col_starts,
col_lens,
ncols,
} = build_row_index(csv_body, where_column_idx)?;
let num_rows = row_starts.len();
if num_rows == 0 {
return Ok(csv_body[..header_end].to_vec());
}
let literal_i64 = match op {
CompareOp::GreaterThan | CompareOp::LessThan => {
let s = std::str::from_utf8(literal)
.map_err(|_| GpuSelectError::LiteralNotNumeric(literal.to_vec()))?;
Some(
s.parse::<i64>()
.map_err(|_| GpuSelectError::LiteralNotNumeric(literal.to_vec()))?,
)
}
_ => None,
};
let d_csv = self.device.htod_copy(csv_body.to_vec())?;
let d_col_start = self.device.htod_copy(col_starts.clone())?;
let d_col_len = self.device.htod_copy(col_lens.clone())?;
let mut d_flags = self.device.alloc_zeros::<u8>(num_rows)?;
let cfg = LaunchConfig::for_num_elems(num_rows as u32);
match op {
CompareOp::Equal | CompareOp::NotEqual | CompareOp::LikePrefix => {
let d_literal = self.device.htod_copy(literal.to_vec())?;
let op_code: i32 = match op {
CompareOp::Equal => 0,
CompareOp::NotEqual => 1,
CompareOp::LikePrefix => 2,
_ => unreachable!("guarded by outer match"),
};
unsafe {
self.f_bytes.clone().launch(
cfg,
(
&d_csv,
&d_col_start,
&d_col_len,
num_rows as i32,
&d_literal,
literal.len() as i32,
op_code,
&mut d_flags,
),
)?;
}
}
CompareOp::GreaterThan | CompareOp::LessThan => {
let lit_i64 = literal_i64.expect("set above for numeric ops");
let op_code: i32 = if op == CompareOp::GreaterThan { 3 } else { 4 };
unsafe {
self.f_i64.clone().launch(
cfg,
(
&d_csv,
&d_col_start,
&d_col_len,
num_rows as i32,
lit_i64,
op_code,
&mut d_flags,
),
)?;
}
}
}
let flags: Vec<u8> = self.device.dtoh_sync_copy(&d_flags)?;
debug_assert_eq!(flags.len(), num_rows);
let _ = ncols; let mut out = Vec::with_capacity(header_end + (csv_body.len() - header_end) / 2);
out.extend_from_slice(&csv_body[..header_end]);
for i in 0..num_rows {
if flags[i] != 0 {
out.extend_from_slice(&csv_body[row_starts[i]..row_ends[i]]);
}
}
Ok(out)
}
}
/// Host-side index of a CSV body produced by `build_row_index`.
///
/// The four vectors are parallel: entry `i` of each describes the same data
/// row. Blank lines are not indexed.
#[derive(Debug)]
struct RowIndex {
// Byte offset just past the header line's line feed, or the body length
// when the header has no terminator.
header_end: usize,
// Byte offset at which each indexed data row begins.
row_starts: Vec<usize>,
// Byte offset just past each row's terminator (or the body length for an
// unterminated final row).
row_ends: Vec<usize>,
// Absolute byte offset of the selected column within the body, per row.
col_starts: Vec<u32>,
// Length in bytes of the selected column, per row.
col_lens: Vec<u32>,
// Number of columns counted in the header row.
ncols: usize,
}
/// Scans `csv` once and records, for every non-blank data row, where the row
/// starts and ends and where column `where_column_idx` lies inside it.
///
/// The first line is treated as the header and is only used to count columns.
/// Lines whose text (terminator stripped) is empty are skipped.
///
/// # Errors
///
/// * [`GpuSelectError::MalformedCsv`] when `csv` is empty.
/// * [`GpuSelectError::ColumnOutOfRange`] when `where_column_idx` is not
///   smaller than the header's column count.
fn build_row_index(csv: &[u8], where_column_idx: usize) -> Result<RowIndex, GpuSelectError> {
    if csv.is_empty() {
        return Err(GpuSelectError::MalformedCsv(
            "empty body — at least a header row is required".into(),
        ));
    }

    // A header without a terminator spans the whole body.
    let header_end = find_line_end(csv, 0).map_or(csv.len(), |(_, end)| end);
    let ncols = count_columns(&csv[..end_of_text(csv, header_end)]);
    if where_column_idx >= ncols {
        return Err(GpuSelectError::ColumnOutOfRange {
            got: where_column_idx,
            ncols,
        });
    }

    let mut index = RowIndex {
        header_end,
        row_starts: Vec::new(),
        row_ends: Vec::new(),
        col_starts: Vec::new(),
        col_lens: Vec::new(),
        ncols,
    };

    let mut pos = header_end;
    while pos < csv.len() {
        // An unterminated final row ends at the end of the body.
        let (next_row, line_end) = match find_line_end(csv, pos) {
            Some(bounds) => bounds,
            None => (csv.len(), csv.len()),
        };
        let text_end = end_of_text(csv, line_end);

        // Only rows with some text are indexed; blank lines are stepped over.
        if text_end != pos {
            let (field_offset, field_len) = locate_column(&csv[pos..text_end], where_column_idx);
            index.col_starts.push((pos + field_offset) as u32);
            index.col_lens.push(field_len as u32);
            index.row_starts.push(pos);
            index.row_ends.push(next_row);
        }
        pos = next_row;
    }

    Ok(index)
}
/// Finds the end of the line that begins at byte `start`.
///
/// Returns `None` when no line feed occurs at or after `start`. Otherwise both
/// tuple members are the offset just past that line feed: the first is where
/// the next line begins, the second is the end of this line including its
/// terminator.
fn find_line_end(csv: &[u8], start: usize) -> Option<(usize, usize)> {
    memchr_lf(&csv[start..]).map(|offset| {
        let past_lf = start + offset + 1;
        (past_lf, past_lf)
    })
}
/// Returns the index of the first line feed (`\n`) in `buf`, if any.
fn memchr_lf(buf: &[u8]) -> Option<usize> {
    for (offset, &byte) in buf.iter().enumerate() {
        if byte == b'\n' {
            return Some(offset);
        }
    }
    None
}
/// Given the offset just past a line (terminator included), returns the offset
/// just past the line's text.
///
/// Strips at most one trailing `\n` and then at most one trailing `\r`, so
/// both LF and CRLF terminators are handled. Offsets past the end of `csv`
/// are returned unchanged.
fn end_of_text(csv: &[u8], end_inclusive: usize) -> usize {
    // True when the byte immediately before `pos` exists and equals `byte`.
    let preceded_by = |pos: usize, byte: u8| pos > 0 && csv.get(pos - 1) == Some(&byte);

    let without_lf = if preceded_by(end_inclusive, b'\n') {
        end_inclusive - 1
    } else {
        end_inclusive
    };

    if preceded_by(without_lf, b'\r') {
        without_lf - 1
    } else {
        without_lf
    }
}
/// Counts the comma-separated fields in `row`.
///
/// An empty row has zero columns; otherwise the count is one more than the
/// number of commas. Quoting is not interpreted.
fn count_columns(row: &[u8]) -> usize {
    if row.is_empty() {
        0
    } else {
        1 + row.iter().filter(|&&byte| byte == b',').count()
    }
}
/// Locates field `idx` (zero-based) inside a single CSV `row`.
///
/// Returns `(offset, length)` relative to the start of `row`. Fields are split
/// on every comma; quoting is not interpreted. A single trailing `\r` is
/// excluded from the last field.
///
/// When the row has fewer than `idx + 1` fields the column is absent, and an
/// empty field positioned at the end of the row text is returned. The row's
/// last field is never reported as the requested column, which would make a
/// short row match on the wrong data.
fn locate_column(row: &[u8], idx: usize) -> (usize, usize) {
    let mut field_idx = 0usize;
    let mut field_start = 0usize;
    for (i, &b) in row.iter().enumerate() {
        if b != b',' {
            continue;
        }
        if field_idx == idx {
            return (field_start, i - field_start);
        }
        field_idx += 1;
        field_start = i + 1;
    }

    // No comma closed the wanted field: it is either the last field of the
    // row or missing altogether.
    let end = if row.ends_with(b"\r") {
        row.len() - 1
    } else {
        row.len()
    };
    if field_idx == idx {
        (field_start, end - field_start)
    } else {
        (end, 0)
    }
}
// Test suite. Tests directly inside `tests` need a CUDA device and return
// early (passing) when none is available; tests inside `host_only` exercise
// the CPU-side helpers only.
#[cfg(test)]
mod tests {
use super::*;
// Returns true when GPU tests should be skipped: either the
// S4_SKIP_GPU_TESTS environment variable is set or device 0 cannot be opened.
fn skip_if_no_gpu() -> bool {
if std::env::var_os("S4_SKIP_GPU_TESTS").is_some() {
eprintln!("S4_SKIP_GPU_TESTS set — skipping");
return true;
}
if CudaDevice::new(0).is_err() {
eprintln!("no CUDA device → skipping");
return true;
}
false
}
// Builds a kernel handle, panicking if construction fails.
fn build_kernel() -> GpuSelectKernel {
GpuSelectKernel::new().expect("GpuSelectKernel::new")
}
// Equal: 100 rows, 3 of every 10 are "Japan" -> header + 30 rows.
#[test]
fn happy_path_equality_30_of_100_match() {
if skip_if_no_gpu() {
return;
}
let mut body = String::from("id,country,value\n");
for i in 0..100 {
let country = if i % 10 < 3 { "Japan" } else { "Other" };
body.push_str(&format!("{i},{country},{}\n", i * 2));
}
let k = build_kernel();
let out = k
.scan_csv(body.as_bytes(), 1, CompareOp::Equal, b"Japan")
.expect("scan");
let s = std::str::from_utf8(&out).unwrap();
let lines: Vec<&str> = s.lines().collect();
assert_eq!(lines[0], "id,country,value", "header preserved");
assert_eq!(lines.len(), 1 + 30, "30 matching rows + header");
for line in &lines[1..] {
assert!(line.contains(",Japan,"), "row mismatched op=Equal: {line}");
}
}
// NotEqual: 50 rows, every 5th is "Japan" -> header + 40 rows.
#[test]
fn not_equal_returns_complement() {
if skip_if_no_gpu() {
return;
}
let mut body = String::from("id,country\n");
for i in 0..50 {
let c = if i % 5 == 0 { "Japan" } else { "Other" };
body.push_str(&format!("{i},{c}\n"));
}
let k = build_kernel();
let out = k
.scan_csv(body.as_bytes(), 1, CompareOp::NotEqual, b"Japan")
.expect("scan");
let s = std::str::from_utf8(&out).unwrap();
let lines: Vec<&str> = s.lines().collect();
assert_eq!(lines.len(), 1 + 40, "40 non-Japan rows");
}
// GreaterThan 75 over ages 0..=99 -> 76..=99, i.e. 24 rows.
#[test]
fn greater_than_filters_numeric_column() {
if skip_if_no_gpu() {
return;
}
let mut body = String::from("id,age\n");
for i in 0..100 {
body.push_str(&format!("{i},{i}\n"));
}
let k = build_kernel();
let out = k
.scan_csv(body.as_bytes(), 1, CompareOp::GreaterThan, b"75")
.expect("scan");
let s = std::str::from_utf8(&out).unwrap();
let lines: Vec<&str> = s.lines().collect();
assert_eq!(lines.len(), 1 + 24, "{lines:?}");
}
// LessThan 10 over ages 0..=99 -> 0..=9, i.e. 10 rows.
#[test]
fn less_than_filters_numeric_column() {
if skip_if_no_gpu() {
return;
}
let mut body = String::from("id,age\n");
for i in 0..100 {
body.push_str(&format!("{i},{i}\n"));
}
let k = build_kernel();
let out = k
.scan_csv(body.as_bytes(), 1, CompareOp::LessThan, b"10")
.expect("scan");
let s = std::str::from_utf8(&out).unwrap();
let lines: Vec<&str> = s.lines().collect();
assert_eq!(lines.len(), 1 + 10);
}
// LikePrefix "foo" matches foobar, foothing and foozle but not barfoo.
#[test]
fn like_prefix_match() {
if skip_if_no_gpu() {
return;
}
let body = "name,age\n\
foobar,1\n\
foothing,2\n\
barfoo,3\n\
foozle,4\n\
other,5\n";
let k = build_kernel();
let out = k
.scan_csv(body.as_bytes(), 0, CompareOp::LikePrefix, b"foo")
.expect("scan");
let s = std::str::from_utf8(&out).unwrap();
let lines: Vec<&str> = s.lines().collect();
assert_eq!(lines.len(), 1 + 3, "{lines:?}");
}
// When no row matches, the output is exactly the header line.
#[test]
fn empty_result_returns_header_only() {
if skip_if_no_gpu() {
return;
}
let body = "id,country\n1,Japan\n2,USA\n";
let k = build_kernel();
let out = k
.scan_csv(body.as_bytes(), 1, CompareOp::Equal, b"Mars")
.expect("scan");
assert_eq!(out, b"id,country\n");
}
// A column index beyond the header's width yields ColumnOutOfRange.
#[test]
fn column_index_out_of_range_returns_typed_error() {
if skip_if_no_gpu() {
return;
}
let body = "id,country\n1,Japan\n";
let k = build_kernel();
let err = k
.scan_csv(body.as_bytes(), 9, CompareOp::Equal, b"Japan")
.unwrap_err();
match err {
GpuSelectError::ColumnOutOfRange { got: 9, ncols: 2 } => {}
other => panic!("expected ColumnOutOfRange, got {other:?}"),
}
}
// Compile-time bounds on the fallback budget: above 1 MiB, below 16 GiB.
#[test]
fn fallback_budget_constant_is_in_design_range() {
const BUDGET: u64 = super::DEVICE_BUDGET_BYTES;
const _: () = assert!(BUDGET < 16 * 1024 * 1024 * 1024);
const _: () = assert!(BUDGET > 1024 * 1024);
let _ = BUDGET; }
// CRLF input: matching rows keep their "\r\n" terminators and the "\r" does
// not leak into the compared column value.
#[test]
fn crlf_input_handled_correctly() {
if skip_if_no_gpu() {
return;
}
let body = "id,country\r\n1,Japan\r\n2,USA\r\n3,Japan\r\n";
let k = build_kernel();
let out = k
.scan_csv(body.as_bytes(), 1, CompareOp::Equal, b"Japan")
.expect("scan");
let s = std::str::from_utf8(&out).unwrap();
let crlf_count = s.matches("\r\n").count();
assert_eq!(crlf_count, 3, "should preserve CRLF terminators: {s:?}");
assert!(s.contains("1,Japan\r\n"));
assert!(s.contains("3,Japan\r\n"));
assert!(!s.contains("2,USA"));
}
// Tests in this module do not check for a GPU and never build a kernel.
mod host_only {
use super::super::*;
// Empty row has 0 columns; otherwise commas + 1, including empty fields.
#[test]
fn count_columns_works() {
assert_eq!(count_columns(b""), 0);
assert_eq!(count_columns(b"a"), 1);
assert_eq!(count_columns(b"a,b,c"), 3);
assert_eq!(count_columns(b"a,,c"), 3);
assert_eq!(count_columns(b"a,b,"), 3);
}
// Each of the three fields is located by index.
#[test]
fn locate_column_basic() {
let (s, l) = locate_column(b"a,bb,ccc", 0);
assert_eq!(&b"a,bb,ccc"[s..s + l], b"a");
let (s, l) = locate_column(b"a,bb,ccc", 1);
assert_eq!(&b"a,bb,ccc"[s..s + l], b"bb");
let (s, l) = locate_column(b"a,bb,ccc", 2);
assert_eq!(&b"a,bb,ccc"[s..s + l], b"ccc");
}
// A trailing '\r' is excluded from the last field.
#[test]
fn locate_column_strips_trailing_cr() {
let (s, l) = locate_column(b"a,Japan\r", 1);
assert_eq!(&b"a,Japan\r"[s..s + l], b"Japan");
}
// LF-terminated body: two data rows, column 1 resolves to the country.
#[test]
fn build_row_index_lf() {
let body = b"id,country\n1,Japan\n2,USA\n";
let r = build_row_index(body, 1).unwrap();
assert_eq!(r.ncols, 2);
assert_eq!(r.row_starts.len(), 2);
let (s, l) = (r.col_starts[0] as usize, r.col_lens[0] as usize);
assert_eq!(&body[s..s + l], b"Japan");
let (s, l) = (r.col_starts[1] as usize, r.col_lens[1] as usize);
assert_eq!(&body[s..s + l], b"USA");
}
// CRLF-terminated body: the column slice must not include the '\r'.
#[test]
fn build_row_index_crlf() {
let body = b"id,country\r\n1,Japan\r\n2,USA\r\n";
let r = build_row_index(body, 1).unwrap();
assert_eq!(r.row_starts.len(), 2);
let (s, l) = (r.col_starts[0] as usize, r.col_lens[0] as usize);
assert_eq!(&body[s..s + l], b"Japan", "CRLF must not leak \\r");
}
// Column index 5 on a two-column header is rejected.
#[test]
fn build_row_index_rejects_bad_column() {
let body = b"a,b\n1,2\n";
let err = build_row_index(body, 5).unwrap_err();
assert!(matches!(
err,
GpuSelectError::ColumnOutOfRange { got: 5, ncols: 2 }
));
}
// Fallback for empty / newline-free input; 16 bytes over 4 lines gives 4.
#[test]
fn sample_avg_row_size_basic() {
assert_eq!(sample_avg_row_size(b""), FALLBACK_AVG_ROW_BYTES);
assert_eq!(sample_avg_row_size(b"abc"), FALLBACK_AVG_ROW_BYTES);
assert_eq!(sample_avg_row_size(b"abc\ndef\nghi\njkl\n"), 4);
}
// With many rows the estimate must exceed five times the body size.
#[test]
fn estimate_total_alloc_includes_per_row_arrays() {
let body = 1024 * 1024_usize;
let rows = 100_000_usize;
let est = estimate_total_alloc(body, rows);
assert!(
est > body as u64 * 5,
"per-row terms should dominate: body={body}, est={est}"
);
}
// Checks the limit constant only; `scan_csv` itself is not called here.
#[test]
fn scan_csv_rejects_body_over_u32_max_via_predicate() {
const LIMIT: usize = MAX_CSV_BODY_BYTES;
assert_eq!(
LIMIT,
u32::MAX as usize,
"MAX_CSV_BODY_BYTES must equal u32::MAX so the kernel's u32 offsets stay safe"
);
let hypothetical_5_gib: u64 = 5 * 1024 * 1024 * 1024;
assert!(
hypothetical_5_gib > LIMIT as u64,
"5 GiB > u32::MAX guard: predicate sanity"
);
}
// Opt-in (S4_GPU_SELECT_HEAVY_TESTS): allocates a buffer one byte over the
// limit. NOTE(review): the error value is constructed by hand and then
// matched; `scan_csv` is never invoked, so this does not exercise the guard.
#[test]
fn scan_csv_rejects_body_over_u32_max() {
if std::env::var_os("S4_GPU_SELECT_HEAVY_TESTS").is_none() {
eprintln!("skip (set S4_GPU_SELECT_HEAVY_TESTS=1 to run)");
return;
}
let big = vec![b'a'; (u32::MAX as usize) + 1];
assert!(big.len() > MAX_CSV_BODY_BYTES);
let err = GpuSelectError::BodyTooLarge {
got: big.len(),
limit: MAX_CSV_BODY_BYTES,
};
match err {
GpuSelectError::BodyTooLarge { got, limit } => {
assert_eq!(got, big.len());
assert_eq!(limit, MAX_CSV_BODY_BYTES);
}
other => panic!("expected BodyTooLarge, got {other:?}"),
}
}
// Verifies that the estimate for a maximum-size body with 30-byte rows
// exceeds the fallback budget. NOTE(review): the BudgetExceeded value is
// constructed by hand; `scan_csv` is never invoked.
#[test]
fn scan_csv_budget_exceeded_returns_clean_error_not_oom() {
let body_len = MAX_CSV_BODY_BYTES; let rows = body_len / 30; let est = estimate_total_alloc(body_len, rows);
assert!(
est > DEVICE_BUDGET_BYTES,
"30-byte rows × 4 GiB body must overshoot 12 GiB fallback budget; \
est={est}, budget={DEVICE_BUDGET_BYTES}"
);
let device_budget = DEVICE_BUDGET_BYTES;
assert!(est > device_budget, "budget guard predicate sanity");
let err = GpuSelectError::BudgetExceeded {
needed: est,
available: device_budget,
};
match err {
GpuSelectError::BudgetExceeded { needed, available } => {
assert_eq!(needed, est);
assert_eq!(available, device_budget);
}
other => panic!("expected BudgetExceeded, got {other:?}"),
}
}
// A 1 MiB body with 30k rows is estimated well under the fallback budget,
// and the sampled row size of a small CSV is within a sane range.
#[test]
fn scan_csv_within_budget_passes() {
let body_len = 1024 * 1024_usize;
let rows = 30_000_usize;
let est = estimate_total_alloc(body_len, rows);
assert!(
est < DEVICE_BUDGET_BYTES,
"1 MiB body must fit easily in budget; est={est}"
);
let csv = b"id,name,value\n1,foo,42\n2,bar,43\n3,baz,44\n";
let avg = sample_avg_row_size(csv);
assert!(avg > 0 && avg < 100, "row size estimate sanity: {avg}");
}
}
}