use std::sync::Arc;
#[cfg(feature = "gpu")]
use cudarc::driver::{CudaContext, CudaStream};
use crate::error::{DbxError, DbxResult};
#[cfg(feature = "gpu")]
/// Tuning parameters for the persistent scan kernel managed by `PersistentKernelManager`.
///
/// Use `Default::default()` for the stock settings, or struct-update syntax to override
/// individual fields: `PersistentKernelConfig { num_blocks: 1, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct PersistentKernelConfig {
    /// Upper bound on the number of tasks.
    /// NOTE(review): not read anywhere in this file — confirm where it is enforced.
    pub max_tasks: usize,
    /// Timeout, presumably in milliseconds judging by the name.
    /// NOTE(review): not read anywhere in this file — confirm where it is enforced.
    pub timeout_ms: u64,
    /// Threads per block; returned as the second element of `launch_config()`.
    pub threads_per_block: u32,
    /// Number of blocks; returned as the first element of `launch_config()`.
    ///
    /// The kernel signals task completion from global thread 0 only, so anything other
    /// than `1` lets the completion signal race with the other blocks' work.
    pub num_blocks: u32,
}
#[cfg(feature = "gpu")]
impl Default for PersistentKernelConfig {
    /// Stock settings: `max_tasks = 1000`, `timeout_ms = 100`, `threads_per_block = 256`,
    /// `num_blocks = 1`.
    fn default() -> Self {
        Self {
            max_tasks: 1000,
            timeout_ms: 100,
            threads_per_block: 256,
            num_blocks: 1,
        }
    }
}
#[cfg(feature = "gpu")]
/// CUDA C source of `persistent_scan_kernel`, compiled at runtime with NVRTC by
/// `PersistentKernelManager::compile_kernel`.
///
/// Kernel parameters, in launch order (the order is part of the launch ABI):
/// `input`, `output`, `work_queue`, `control`, `data_size`.
///
/// Host/device protocol as implemented below:
/// - the kernel keeps looping until `control[0]` reads as `0`;
/// - `work_queue[0] < 0` means "no work"; a non-negative value is a task id to run;
/// - a task copies `input[0..data_size]` into `output`, after which `work_queue[0]` is
///   reset to `-1` to report completion.
///
/// Thread 0 of each block takes one snapshot of `control` / `work_queue` per iteration and
/// broadcasts it through shared memory. Every thread of a block therefore takes the same
/// branch and reaches every `__syncthreads()`; a barrier under divergent control flow is
/// undefined behaviour. The completion signal is raised only after the whole block has
/// finished its part of the copy.
///
/// NOTE(review): completion is signalled by global thread 0 (block 0) alone, so the
/// protocol is only correct for a single-block launch (`num_blocks == 1`). With more
/// blocks, the reset can precede, or even prevent, the other blocks' share of the copy.
const PERSISTENT_KERNEL_SRC: &str = r#"
extern "C" __global__ void persistent_scan_kernel(
    const float* __restrict__ input,
    float* __restrict__ output,
    int* __restrict__ work_queue,
    volatile int* __restrict__ control,
    int data_size
) {
    // Per-block snapshot of the host-controlled flags so that every thread of the
    // block agrees on the branch taken (required by the __syncthreads calls below).
    __shared__ int s_running;
    __shared__ int s_task_id;

    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;

    // Persistent loop: keep running until host signals shutdown
    for (;;) {
        if (threadIdx.x == 0) {
            s_running = atomicAdd((int*)control, 0);
            s_task_id = atomicAdd(&work_queue[0], 0);
        }
        __syncthreads();
        int running = s_running;
        int task_id = s_task_id;
        // All threads have read the snapshot before thread 0 may overwrite it.
        __syncthreads();

        if (running == 0) {
            break;
        }
        if (task_id < 0) {
            // No work available, spin-wait
            continue;
        }

        // Process: parallel scan/filter over input data
        for (int i = tid; i < data_size; i += stride) {
            output[i] = input[i];
        }
        __threadfence();
        // The whole block must be done before completion is reported.
        __syncthreads();

        // Signal task completion (first thread only)
        if (tid == 0) {
            atomicExch(&work_queue[0], -1);
        }
    }
}
"#;
#[cfg(feature = "gpu")]
/// Holds the runtime-compiled persistent scan kernel for one CUDA context.
///
/// No module is loaded until `compile_kernel` succeeds; until then `is_ready` is `false`
/// and `get_kernel_function` yields `Ok(None)`.
pub struct PersistentKernelManager {
    /// Context that `compile_kernel` loads the module into.
    device: Arc<CudaContext>,
    /// Launch parameters supplied at construction.
    config: PersistentKernelConfig,
    /// Loaded module; `None` until `compile_kernel` succeeds.
    module: Option<Arc<cudarc::driver::CudaModule>>,
}
#[cfg(feature = "gpu")]
impl PersistentKernelManager {
    /// Entry-point name exported (as `extern "C"`) by the kernel source.
    const KERNEL_NAME: &'static str = "persistent_scan_kernel";

    /// Builds a manager with no module loaded; call `compile_kernel` before launching.
    pub fn new(device: Arc<CudaContext>, config: PersistentKernelConfig) -> Self {
        let module = None;
        Self {
            device,
            config,
            module,
        }
    }

    /// Borrows the CUDA context this manager loads into.
    pub fn device(&self) -> &Arc<CudaContext> {
        &self.device
    }

    /// Borrows the configuration supplied at construction.
    pub fn config(&self) -> &PersistentKernelConfig {
        &self.config
    }

    /// Compiles the kernel source with NVRTC and loads the result into the context.
    ///
    /// On success any previously loaded module is replaced. On failure the stored module
    /// is left as it was, because it is only assigned after both steps succeed.
    ///
    /// # Errors
    /// Returns `DbxError::Gpu` when NVRTC compilation or module loading fails.
    pub fn compile_kernel(&mut self) -> DbxResult<()> {
        use cudarc::nvrtc::Ptx;

        let compiled = Ptx::compile_source(PERSISTENT_KERNEL_SRC);
        let ptx = compiled
            .map_err(|err| DbxError::Gpu(format!("NVRTC compilation failed: {:?}", err)))?;

        let loaded = self.device.load_module(ptx);
        let module =
            loaded.map_err(|err| DbxError::Gpu(format!("Module load failed: {:?}", err)))?;

        self.module = Some(module);
        Ok(())
    }

    /// Reports whether a module has been loaded by `compile_kernel`.
    pub fn is_ready(&self) -> bool {
        matches!(self.module, Some(_))
    }

    /// Looks up the kernel entry point in the loaded module.
    ///
    /// Returns `Ok(None)` when `compile_kernel` has not succeeded yet.
    ///
    /// # Errors
    /// Returns `DbxError::Gpu` when the function lookup in the module fails.
    pub fn get_kernel_function(&self) -> DbxResult<Option<Arc<cudarc::driver::CudaFunction>>> {
        let module = match &self.module {
            None => return Ok(None),
            Some(loaded) => loaded,
        };
        let func = module.load_function(Self::KERNEL_NAME).map_err(|err| {
            DbxError::Gpu(format!("Failed to load kernel function: {:?}", err))
        })?;
        Ok(Some(func))
    }

    /// Returns `(num_blocks, threads_per_block)` from the configuration.
    pub fn launch_config(&self) -> (u32, u32) {
        let PersistentKernelConfig {
            num_blocks,
            threads_per_block,
            ..
        } = self.config;
        (num_blocks, threads_per_block)
    }
}
#[cfg(not(feature = "gpu"))]
/// Placeholder manager used when the crate is built without the `gpu` feature.
///
/// It never holds a kernel, so `is_ready` always reports `false`.
#[derive(Debug)]
pub struct PersistentKernelManager;
#[cfg(not(feature = "gpu"))]
/// Placeholder configuration used when the crate is built without the `gpu` feature.
#[derive(Debug, Clone, Default)]
pub struct PersistentKernelConfig;
#[cfg(not(feature = "gpu"))]
impl PersistentKernelManager {
    /// Creates the placeholder manager; both arguments are accepted and ignored.
    pub fn new(_device: (), _config: PersistentKernelConfig) -> Self {
        Self
    }

    /// Always `false`: without the `gpu` feature no kernel can be compiled.
    ///
    /// Mirrors the GPU build's `is_ready`, so callers can check readiness without
    /// cfg-gating the call site.
    pub fn is_ready(&self) -> bool {
        false
    }
}