use std::sync::Arc;
use cudarc::driver::DeviceSlice;
use xlog_core::{Result, XlogError};
use xlog_cuda::memory::TrackedCudaSlice;
use xlog_cuda::CudaKernelProvider;
use crate::pir::{PirGraph, PirNode, PirNodeId};
// Node-type tags written into `GpuPirGraph::node_type`, one per node, by
// `GpuPirGraph::from_host`. Each tag corresponds to one `PirNode` variant.
// NOTE(review): `PIR_CONST` and `PIR_DECISION` are `pub(crate)` while the other
// tags are `pub` — confirm the visibility difference is intentional.

/// Tag for `PirNode::Const`; the constant itself is stored in `leaf_id`.
pub(crate) const PIR_CONST: u8 = 0;
/// Tag for `PirNode::Lit`; the leaf id is stored in `leaf_id`.
pub const PIR_LIT: u8 = 1;
/// Tag for `PirNode::NegLit`; the leaf id is stored in `leaf_id`.
pub const PIR_NEG_LIT: u8 = 2;
/// Tag for `PirNode::And`; its children live in the flattened `children` array.
pub const PIR_AND: u8 = 3;
/// Tag for `PirNode::Or`; its children live in the flattened `children` array.
pub const PIR_OR: u8 = 4;
/// Tag for `PirNode::Decision`; see the `decision_*` arrays.
pub(crate) const PIR_DECISION: u8 = 5;
/// A PIR graph flattened into parallel per-node arrays held in buffers
/// allocated through the provider's memory manager.
///
/// The field descriptions below describe the contents as written by
/// [`GpuPirGraph::from_host`]; index `i` in every per-node array refers to
/// the `i`-th node of the host graph.
pub struct GpuPirGraph {
/// One `PIR_*` tag per node.
pub node_type: TrackedCudaSlice<u8>,
/// `num_nodes + 1` offsets into `children`, starting at 0; the children of
/// node `i` occupy `children[child_offsets[i]..child_offsets[i + 1]]`.
pub child_offsets: TrackedCudaSlice<u32>,
/// Concatenated child ids of all `And` / `Or` nodes, in node order.
pub children: TrackedCudaSlice<u32>,
/// Leaf id for `Lit` / `NegLit` nodes, the constant converted with
/// `u32::from` for `Const` nodes, and 0 for every other node type.
pub leaf_id: TrackedCudaSlice<u32>,
/// Decision variable for `Decision` nodes, 0 otherwise.
pub decision_var: TrackedCudaSlice<u32>,
/// Id of the `child_false` branch for `Decision` nodes, 0 otherwise.
pub decision_child_false: TrackedCudaSlice<u32>,
/// Id of the `child_true` branch for `Decision` nodes, 0 otherwise.
pub decision_child_true: TrackedCudaSlice<u32>,
}
/// A list of PIR root node ids held in a buffer allocated through the
/// provider's memory manager (see [`GpuPirRoots::from_host`]).
pub struct GpuPirRoots {
/// Root node ids as `u32`, in the order they were supplied.
pub roots: TrackedCudaSlice<u32>,
}
impl GpuPirGraph {
pub fn from_host(pir: &PirGraph, provider: &Arc<CudaKernelProvider>) -> Result<Self> {
let num_nodes = pir.len();
let num_nodes_u32 = u32::try_from(num_nodes).map_err(|_| {
XlogError::Compilation("GpuPirGraph::from_host: node count overflow".to_string())
})?;
let mut node_type: Vec<u8> = Vec::with_capacity(num_nodes);
let mut child_offsets: Vec<u32> = Vec::with_capacity(num_nodes + 1);
let mut children: Vec<u32> = Vec::new();
let mut leaf_id: Vec<u32> = Vec::with_capacity(num_nodes);
let mut decision_var: Vec<u32> = Vec::with_capacity(num_nodes);
let mut decision_child_false: Vec<u32> = Vec::with_capacity(num_nodes);
let mut decision_child_true: Vec<u32> = Vec::with_capacity(num_nodes);
child_offsets.push(0);
for (idx, node) in pir.nodes().iter().enumerate() {
let node_id = u32::try_from(idx).map_err(|_| {
XlogError::Compilation("GpuPirGraph::from_host: node id overflow".to_string())
})?;
match node {
PirNode::Const(value) => {
node_type.push(PIR_CONST);
leaf_id.push(u32::from(*value));
decision_var.push(0);
decision_child_false.push(0);
decision_child_true.push(0);
}
PirNode::Lit { leaf } => {
node_type.push(PIR_LIT);
leaf_id.push(leaf.as_u32());
decision_var.push(0);
decision_child_false.push(0);
decision_child_true.push(0);
}
PirNode::NegLit { leaf } => {
node_type.push(PIR_NEG_LIT);
leaf_id.push(leaf.as_u32());
decision_var.push(0);
decision_child_false.push(0);
decision_child_true.push(0);
}
PirNode::And { children: kids } => {
validate_children_sorted(node_id, kids, num_nodes_u32)?;
node_type.push(PIR_AND);
leaf_id.push(0);
decision_var.push(0);
decision_child_false.push(0);
decision_child_true.push(0);
for &child in kids {
children.push(child.as_u32());
}
}
PirNode::Or { children: kids } => {
validate_children_sorted(node_id, kids, num_nodes_u32)?;
node_type.push(PIR_OR);
leaf_id.push(0);
decision_var.push(0);
decision_child_false.push(0);
decision_child_true.push(0);
for &child in kids {
children.push(child.as_u32());
}
}
PirNode::Decision {
var,
child_false,
child_true,
} => {
validate_child_id(node_id, *child_false, num_nodes_u32)?;
validate_child_id(node_id, *child_true, num_nodes_u32)?;
node_type.push(PIR_DECISION);
leaf_id.push(0);
decision_var.push(var.as_u32());
decision_child_false.push(child_false.as_u32());
decision_child_true.push(child_true.as_u32());
}
}
let next_off = u32::try_from(children.len()).map_err(|_| {
XlogError::Compilation(
"GpuPirGraph::from_host: children count exceeds u32".to_string(),
)
})?;
child_offsets.push(next_off);
}
if child_offsets.len() != num_nodes + 1 {
return Err(XlogError::Compilation(
"GpuPirGraph::from_host: child_offsets length mismatch".to_string(),
));
}
let memory = provider.memory();
let mut d_node_type = memory.alloc::<u8>(node_type.len())?;
let mut d_child_offsets = memory.alloc::<u32>(child_offsets.len())?;
let mut d_children = memory.alloc::<u32>(children.len())?;
let mut d_leaf_id = memory.alloc::<u32>(leaf_id.len())?;
let mut d_decision_var = memory.alloc::<u32>(decision_var.len())?;
let mut d_decision_child_false = memory.alloc::<u32>(decision_child_false.len())?;
let mut d_decision_child_true = memory.alloc::<u32>(decision_child_true.len())?;
provider
.htod_sync_copy_into_tracked(&node_type, &mut d_node_type)
.map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload node_type: {}", e)))?;
provider
.htod_sync_copy_into_tracked(&child_offsets, &mut d_child_offsets)
.map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload child_offsets: {}", e)))?;
provider
.htod_sync_copy_into_tracked(&children, &mut d_children)
.map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload children: {}", e)))?;
provider
.htod_sync_copy_into_tracked(&leaf_id, &mut d_leaf_id)
.map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload leaf_id: {}", e)))?;
provider
.htod_sync_copy_into_tracked(&decision_var, &mut d_decision_var)
.map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload decision_var: {}", e)))?;
provider
.htod_sync_copy_into_tracked(&decision_child_false, &mut d_decision_child_false)
.map_err(|e| {
XlogError::Kernel(format!("GpuPirGraph upload decision_child_false: {}", e))
})?;
provider
.htod_sync_copy_into_tracked(&decision_child_true, &mut d_decision_child_true)
.map_err(|e| {
XlogError::Kernel(format!("GpuPirGraph upload decision_child_true: {}", e))
})?;
Ok(Self {
node_type: d_node_type,
child_offsets: d_child_offsets,
children: d_children,
leaf_id: d_leaf_id,
decision_var: d_decision_var,
decision_child_false: d_decision_child_false,
decision_child_true: d_decision_child_true,
})
}
pub fn num_nodes(&self) -> usize {
self.node_type.len()
}
}
impl GpuPirRoots {
    /// Converts `roots` to `u32` ids and copies them into a buffer allocated
    /// from `provider`, preserving their order.
    ///
    /// # Errors
    ///
    /// Allocation errors are propagated unchanged; a failed copy is reported
    /// as `XlogError::Kernel`.
    pub fn from_host(roots: &[PirNodeId], provider: &Arc<CudaKernelProvider>) -> Result<Self> {
        let root_ids: Vec<u32> = roots.iter().map(|&root| root.as_u32()).collect();

        let memory = provider.memory();
        let mut device_roots = memory.alloc::<u32>(root_ids.len())?;
        provider
            .htod_sync_copy_into_tracked(&root_ids, &mut device_roots)
            .map_err(|e| XlogError::Kernel(format!("GpuPirRoots upload: {}", e)))?;

        Ok(Self {
            roots: device_roots,
        })
    }
}
/// Checks that `child` refers to an existing node, i.e. that its id is
/// strictly less than `num_nodes`.
///
/// `parent` is only used to build the error message.
///
/// # Errors
///
/// Returns `XlogError::Compilation` when the id is out of bounds.
fn validate_child_id(parent: u32, child: PirNodeId, num_nodes: u32) -> Result<()> {
    if child.as_u32() < num_nodes {
        return Ok(());
    }
    Err(XlogError::Compilation(format!(
        "GpuPirGraph::from_host: child {:?} out of bounds for parent {}",
        child, parent
    )))
}
/// Checks that every id in `children` is strictly less than `num_nodes` and
/// that the ids are strictly increasing (sorted with no duplicates).
///
/// Children are examined in order; for each one the bounds check runs before
/// the ordering check. `parent` is only used to build error messages.
///
/// # Errors
///
/// Returns `XlogError::Compilation` for the first child that is out of
/// bounds or not greater than its predecessor.
fn validate_children_sorted(parent: u32, children: &[PirNodeId], num_nodes: u32) -> Result<()> {
    let mut last_seen: Option<u32> = None;
    for child in children.iter().copied() {
        let id = child.as_u32();
        if id >= num_nodes {
            return Err(XlogError::Compilation(format!(
                "GpuPirGraph::from_host: child {:?} out of bounds for parent {}",
                child, parent
            )));
        }
        // Equal ids are rejected as well, which enforces uniqueness.
        if matches!(last_seen, Some(previous) if id <= previous) {
            return Err(XlogError::Compilation(format!(
                "GpuPirGraph::from_host: children of {} must be sorted and unique",
                parent
            )));
        }
        last_seen = Some(id);
    }
    Ok(())
}