use cudarc::driver::{CudaView, DeviceSlice};
use xlog_core::{Result, XlogError};
use xlog_cuda::memory::TrackedCudaSlice;
use xlog_cuda::CudaKernelProvider;
/// Numeric settings for the neural fast path.
///
/// NOTE(review): how these values are consumed is not visible in this part of
/// the file; the field descriptions below are inferred from the names only and
/// should be confirmed against the code that reads them.
#[derive(Clone, Copy, Debug)]
pub struct NeuralFastPathConfig {
    /// Presumably a small numerical tolerance (defaults to `1e-7`) — TODO confirm.
    pub eps: f64,
    /// Presumably a lower bound applied to probabilities (defaults to `1e-12`) — TODO confirm.
    pub min_p: f64,
}
impl Default for NeuralFastPathConfig {
fn default() -> Self {
Self {
eps: 1e-7,
min_p: 1e-12,
}
}
}
/// Grouped `u32` slot values stored as one flat device array plus an offsets table.
///
/// Group `i` occupies the half-open range
/// `group_offsets[i]..group_offsets[i + 1]` of `slot_cnf_var`.
pub struct GpuWeightSlots {
    /// Host-side copy of the offsets table; it starts with `0` and holds one
    /// extra entry per group (the running total of slots after that group).
    group_offsets_host: Vec<u32>,
    /// Device-side copy of the offsets table.
    group_offsets: TrackedCudaSlice<u32>,
    /// Device-side concatenation of every group's values, in group order.
    /// NOTE(review): the name suggests these are CNF variable ids — confirm with callers.
    slot_cnf_var: TrackedCudaSlice<u32>,
}
impl GpuWeightSlots {
    /// Flattens `groups` into one contiguous array plus an offsets table and
    /// copies both to the device.
    ///
    /// The offsets table starts with `0`; entry `i + 1` is the total number of
    /// slots in groups `0..=i`, so group `i` spans `offsets[i]..offsets[i + 1]`
    /// of the flattened array.
    ///
    /// # Errors
    ///
    /// * `XlogError::Compilation` if the number of groups or the total number
    ///   of slots does not fit in a `u32` (offsets are stored as `u32`).
    /// * Whatever error the provider's memory allocator returns on failure.
    /// * `XlogError::Kernel` if a host-to-device copy fails.
    pub fn upload(provider: &CudaKernelProvider, groups: &[Vec<u32>]) -> Result<Self> {
        // `num_groups` reports the group count as `u32`; reject inputs it
        // could not represent instead of letting it silently report 0 later.
        if u32::try_from(groups.len()).is_err() {
            return Err(XlogError::Compilation(
                "Too many groups in GpuWeightSlots (group count exceeds u32 range)".to_string(),
            ));
        }

        let mut offsets: Vec<u32> = Vec::with_capacity(groups.len().saturating_add(1));
        offsets.push(0);
        let mut flat: Vec<u32> = Vec::new();
        for g in groups {
            flat.extend_from_slice(g);
            // A plain `as u32` cast would wrap around here and upload corrupt
            // offsets; fail loudly instead.
            let end = u32::try_from(flat.len()).map_err(|_| {
                XlogError::Compilation(
                    "Too many slots in GpuWeightSlots (total slot count exceeds u32 range)"
                        .to_string(),
                )
            })?;
            offsets.push(end);
        }

        let memory = provider.memory().clone();
        let device = provider.device().inner();

        let mut d_offsets = memory.alloc::<u32>(offsets.len())?;
        device
            .htod_sync_copy_into(&offsets, &mut d_offsets)
            .map_err(|e| {
                XlogError::Kernel(format!("Failed to upload weight slot offsets: {}", e))
            })?;

        let mut d_vars = memory.alloc::<u32>(flat.len())?;
        device
            .htod_sync_copy_into(&flat, &mut d_vars)
            .map_err(|e| XlogError::Kernel(format!("Failed to upload weight slot vars: {}", e)))?;

        Ok(Self {
            group_offsets_host: offsets,
            group_offsets: d_offsets,
            slot_cnf_var: d_vars,
        })
    }

    /// Number of groups, i.e. the offsets-table length minus one.
    ///
    /// Returns `0` if that count does not fit in a `u32`; `upload` rejects such
    /// inputs, so the fallback only matters for values built some other way.
    pub fn num_groups(&self) -> u32 {
        let count = self.group_offsets_host.len().saturating_sub(1);
        u32::try_from(count).unwrap_or(0)
    }

    /// Total number of slots across all groups (the last offsets entry), or
    /// `0` when the offsets table is empty.
    pub fn total_slots(&self) -> u32 {
        self.group_offsets_host.last().copied().unwrap_or(0)
    }

    /// Device-side offsets table.
    pub fn group_offsets(&self) -> &TrackedCudaSlice<u32> {
        &self.group_offsets
    }

    /// Device-side flattened slot array covering every group.
    pub fn slot_cnf_var(&self) -> &TrackedCudaSlice<u32> {
        &self.slot_cnf_var
    }

    /// Returns a view of the device slot array restricted to group `group_idx`.
    ///
    /// # Errors
    ///
    /// * `XlogError::Compilation("Group index out of bounds")` if `group_idx`
    ///   does not name a group.
    /// * `XlogError::Compilation("Invalid group slot range in GpuWeightSlots")`
    ///   if the recorded offsets are decreasing or reach past the device array.
    pub fn group_slot_cnf_var(&self, group_idx: usize) -> Result<CudaView<'_, u32>> {
        let start = *self
            .group_offsets_host
            .get(group_idx)
            .ok_or_else(|| XlogError::Compilation("Group index out of bounds".to_string()))?
            as usize;
        // `group_idx + 1` cannot overflow: the lookup above already proved
        // that `group_idx` is a valid index into the offsets vector.
        let end = *self
            .group_offsets_host
            .get(group_idx + 1)
            .ok_or_else(|| XlogError::Compilation("Group index out of bounds".to_string()))?
            as usize;

        // Validate before slicing so a corrupt table yields an error rather
        // than an out-of-range slice request.
        if end < start || end > self.slot_cnf_var.len() {
            return Err(XlogError::Compilation(
                "Invalid group slot range in GpuWeightSlots".to_string(),
            ));
        }
        Ok(self.slot_cnf_var.slice(start..end))
    }
}