use cudarc::driver::{CudaView, DeviceSlice};
use xlog_core::{Result, XlogError};
use xlog_cuda::memory::TrackedCudaSlice;
use xlog_cuda::CudaKernelProvider;
/// Numeric tuning parameters for the neural fast path.
///
/// The type is `#[non_exhaustive]`, so code in other crates cannot build it
/// with a struct literal and should start from [`Default::default`].
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct NeuralFastPathConfig {
    /// Epsilon value; defaults to `1e-7`.
    ///
    /// NOTE(review): presumably a numerical tolerance — how it is applied is
    /// decided by the consumers of this config, which are not visible here.
    pub eps: f64,
    /// Minimum `p` value; defaults to `1e-12`.
    ///
    /// NOTE(review): presumably a lower bound applied to probabilities —
    /// confirm against the code that reads this field.
    pub min_p: f64,
}

impl Default for NeuralFastPathConfig {
    /// Builds the stock configuration: `eps = 1e-7`, `min_p = 1e-12`.
    fn default() -> Self {
        let eps = 1e-7;
        let min_p = 1e-12;
        NeuralFastPathConfig { eps, min_p }
    }
}
/// Grouped weight slots stored as an offsets array plus one flat value array,
/// with the offsets kept on the host as well as on the device.
///
/// Instances are built by `GpuWeightSlots::upload`.
pub struct GpuWeightSlots {
    /// Host copy of the group offsets. `upload` writes a leading `0` followed
    /// by one running total per group, so the vector holds `num_groups + 1`
    /// entries and its last entry equals the total number of slots.
    group_offsets_host: Vec<u32>,
    /// Device buffer that `upload` fills from the same offsets as
    /// `group_offsets_host`.
    group_offsets: TrackedCudaSlice<u32>,
    /// Device buffer holding every group's values concatenated in group order.
    /// NOTE(review): the name suggests these are CNF variable ids — confirm
    /// against the callers of `upload`.
    slot_cnf_var: TrackedCudaSlice<u32>,
}
impl GpuWeightSlots {
pub fn upload(provider: &CudaKernelProvider, groups: &[Vec<u32>]) -> Result<Self> {
if groups.len() > u32::MAX as usize {
return Err(XlogError::Compilation(
"Neural fast-path group count exceeds GPU u32 index space".to_string(),
));
}
let offset_count = groups.len().checked_add(1).ok_or_else(|| {
XlogError::Compilation("Neural fast-path group offset count overflow".to_string())
})?;
let total_slots = groups.iter().try_fold(0usize, |acc, group| {
acc.checked_add(group.len()).ok_or_else(|| {
XlogError::Compilation("Neural fast-path slot count overflow".to_string())
})
})?;
if total_slots > u32::MAX as usize {
return Err(XlogError::Compilation(
"Neural fast-path slot count exceeds GPU u32 index space".to_string(),
));
}
let mut offsets: Vec<u32> = Vec::with_capacity(offset_count);
offsets.push(0);
let mut flat: Vec<u32> = Vec::with_capacity(total_slots);
for g in groups {
flat.extend_from_slice(g);
let next_offset = u32::try_from(flat.len()).map_err(|_| {
XlogError::Compilation(
"Neural fast-path slot offset exceeds GPU u32 index space".to_string(),
)
})?;
offsets.push(next_offset);
}
let memory = provider.memory().clone();
let mut d_offsets = memory.alloc::<u32>(offsets.len())?;
provider
.htod_sync_copy_into_tracked(&offsets, &mut d_offsets)
.map_err(|e| {
XlogError::Kernel(format!("Failed to upload weight slot offsets: {}", e))
})?;
let mut d_vars = memory.alloc::<u32>(flat.len())?;
provider
.htod_sync_copy_into_tracked(&flat, &mut d_vars)
.map_err(|e| XlogError::Kernel(format!("Failed to upload weight slot vars: {}", e)))?;
Ok(Self {
group_offsets_host: offsets,
group_offsets: d_offsets,
slot_cnf_var: d_vars,
})
}
pub fn num_groups(&self) -> u32 {
debug_assert!(!self.group_offsets_host.is_empty());
(self.group_offsets_host.len() - 1) as u32
}
pub fn num_groups_usize(&self) -> usize {
debug_assert!(!self.group_offsets_host.is_empty());
self.group_offsets_host.len() - 1
}
pub fn total_slots(&self) -> u32 {
debug_assert!(!self.group_offsets_host.is_empty());
self.group_offsets_host[self.group_offsets_host.len() - 1]
}
pub fn group_offsets(&self) -> &TrackedCudaSlice<u32> {
&self.group_offsets
}
pub fn slot_cnf_var(&self) -> &TrackedCudaSlice<u32> {
&self.slot_cnf_var
}
pub fn group_slot_cnf_var(&self, group_idx: usize) -> Result<CudaView<'_, u32>> {
let start = *self
.group_offsets_host
.get(group_idx)
.ok_or_else(|| XlogError::Compilation("Group index out of bounds".to_string()))?
as usize;
let end = *self
.group_offsets_host
.get(group_idx + 1)
.ok_or_else(|| XlogError::Compilation("Group index out of bounds".to_string()))?
as usize;
if end < start || end > self.slot_cnf_var.len() {
return Err(XlogError::Compilation(
"Invalid group slot range in GpuWeightSlots".to_string(),
));
}
Ok(self.slot_cnf_var.slice(start..end))
}
}