use std::sync::Arc;
use xlog_core::MemoryBudget;
use xlog_cuda::memory::TrackedCudaSlice;
use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};
use xlog_prob::compilation::gpu_cnf::GpuCnfVarTables;
use xlog_prob::compilation::gpu_weights::{build_evidence_by_var_gpu, build_weights_gpu};
fn try_provider() -> Option<Arc<CudaKernelProvider>> {
let device = match CudaDevice::new(0) {
Ok(d) => Arc::new(d),
Err(e) => {
eprintln!("Skipping test: CUDA runtime unavailable: {}", e);
return None;
}
};
let memory = Arc::new(GpuMemoryManager::new(
device.clone(),
MemoryBudget::with_limit(1 << 30),
));
match CudaKernelProvider::new(device, memory) {
Ok(p) => Some(Arc::new(p)),
Err(e) => {
eprintln!(
"Skipping test: failed to create CUDA kernel provider: {}",
e
);
None
}
}
}
/// Copies a host slice of `u32` values into a freshly allocated device buffer.
///
/// The buffer is allocated through the provider's memory manager with room
/// for `host.len()` elements, then filled with a synchronous host-to-device
/// copy.
///
/// # Panics
///
/// Panics with "alloc u32" if the allocation fails, or "upload u32" if the
/// copy fails.
fn upload_u32(provider: &Arc<CudaKernelProvider>, host: &[u32]) -> TrackedCudaSlice<u32> {
    let mut device_buf = provider
        .memory()
        .alloc::<u32>(host.len())
        .expect("alloc u32");
    let device = provider.device();
    device
        .inner()
        .htod_sync_copy_into(host, &mut device_buf)
        .expect("upload u32");
    device_buf
}
/// Copies a host slice of `f64` values into a freshly allocated device buffer.
///
/// The buffer is allocated through the provider's memory manager with room
/// for `host.len()` elements, then filled with a synchronous host-to-device
/// copy.
///
/// # Panics
///
/// Panics with "alloc f64" if the allocation fails, or "upload f64" if the
/// copy fails.
fn upload_f64(provider: &Arc<CudaKernelProvider>, host: &[f64]) -> TrackedCudaSlice<f64> {
    let mut device_buf = provider
        .memory()
        .alloc::<f64>(host.len())
        .expect("alloc f64");
    let device = provider.device();
    device
        .inner()
        .htod_sync_copy_into(host, &mut device_buf)
        .expect("upload f64");
    device_buf
}
/// Copies a host slice of `u8` values into a freshly allocated device buffer.
///
/// The buffer is allocated through the provider's memory manager with room
/// for `host.len()` elements, then filled with a synchronous host-to-device
/// copy.
///
/// # Panics
///
/// Panics with "alloc u8" if the allocation fails, or "upload u8" if the
/// copy fails.
fn upload_u8(provider: &Arc<CudaKernelProvider>, host: &[u8]) -> TrackedCudaSlice<u8> {
    let mut device_buf = provider
        .memory()
        .alloc::<u8>(host.len())
        .expect("alloc u8");
    let device = provider.device();
    device
        .inner()
        .htod_sync_copy_into(host, &mut device_buf)
        .expect("upload u8");
    device_buf
}
/// Natural logarithm of `p`, with zero mapped explicitly to negative infinity.
///
/// Any other input (including negative values and NaN) is passed straight to
/// `f64::ln`, so it follows that function's results.
fn ln_prob(p: f64) -> f64 {
    if p != 0.0 {
        return p.ln();
    }
    // Both +0.0 and -0.0 compare equal to 0.0 and land here.
    f64::NEG_INFINITY
}
/// Builds GPU evidence and weight tables for a small hand-written variable
/// layout and checks the downloaded log-weights against host-side logs.
///
/// Returns early (treated as a skip) when no CUDA provider can be created.
#[test]
fn gpu_weights_builds_log_tables_and_evidence() {
    let provider = match try_provider() {
        Some(p) => p,
        None => return,
    };

    // Variable tables: leaf entries use vars 2 and 5, choice entries use
    // vars 3 and 6, and node_var[0] is var 4.
    let vars = GpuCnfVarTables {
        node_var: upload_u32(&provider, &[4, 0, 0]),
        leaf_var: upload_u32(&provider, &[2, 5, 0]),
        choice_var: upload_u32(&provider, &[3, 6]),
        max_var: 6,
    };

    let leaf_probs = upload_f64(&provider, &[0.2, 0.7]);
    let choice_true = upload_f64(&provider, &[0.1, 0.6]);
    let choice_false = upload_f64(&provider, &[0.9, 0.4]);

    // A single evidence entry: node index 0 with value 1.
    let evidence_nodes = upload_u32(&provider, &[0]);
    let evidence_vals = upload_u8(&provider, &[1]);

    let evidence_by_var = build_evidence_by_var_gpu(
        &vars.node_var,
        &evidence_nodes,
        &evidence_vals,
        vars.max_var,
        &provider,
    )
    .expect("evidence map");

    let weights = build_weights_gpu(
        &vars,
        &leaf_probs,
        &choice_true,
        &choice_false,
        &evidence_by_var,
        &provider,
    )
    .expect("weights");

    // Host mirrors are sized max_var + 1 so they can be indexed directly by
    // variable id.
    let table_len = (vars.max_var as usize) + 1;
    let mut host_true = vec![0.0f64; table_len];
    let mut host_false = vec![0.0f64; table_len];

    let device = provider.device();
    device
        .inner()
        .dtoh_sync_copy_into(&weights.log_true, &mut host_true)
        .expect("read log_true");
    device
        .inner()
        .dtoh_sync_copy_into(&weights.log_false, &mut host_false)
        .expect("read log_false");

    // (variable id, probability expected behind log_true, behind log_false)
    let expected = [
        (2usize, 0.2, 0.8),
        (5, 0.7, 0.3),
        (3, 0.1, 0.9),
        (6, 0.6, 0.4),
    ];
    for (var, p_true, p_false) in expected {
        assert!((host_true[var] - ln_prob(p_true)).abs() < 1e-9);
        assert!((host_false[var] - ln_prob(p_false)).abs() < 1e-9);
    }

    // Var 4 is the one tied to the evidence entry; its false-weight must be
    // exactly negative infinity.
    assert!(host_false[4] == f64::NEG_INFINITY);
}