xlog-prob 0.9.2

Probabilistic inference engines for XLOG
use std::sync::Arc;

use xlog_core::MemoryBudget;
use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};
use xlog_prob::compilation::gpu_cache::{GpuCircuitCache, GpuCircuitCacheConfig};
use xlog_prob::gpu::GpuXgcf;
use xlog_prob::xgcf::{Xgcf, XgcfNodeType};

#[test]
fn gpu_eval_device_only_matches_cached_eval() {
    let device = match CudaDevice::new(0) {
        Ok(d) => Arc::new(d),
        Err(e) => {
            eprintln!("Skipping test: CUDA runtime unavailable: {}", e);
            return;
        }
    };
    let memory = Arc::new(GpuMemoryManager::new(
        device.clone(),
        MemoryBudget::with_limit(1 << 30),
    ));
    let provider = Arc::new(CudaKernelProvider::new(device, memory).expect("provider"));

    let circuit = Xgcf {
        node_type: vec![XgcfNodeType::Lit],
        child_offsets: vec![0, 0],
        child_indices: vec![],
        lit: vec![1],
        decision_var: vec![0],
        decision_child_false: vec![0],
        decision_child_true: vec![0],
        roots: vec![0],
        level_offsets: vec![0, 1],
        level_nodes: vec![0],
    };
    let mut direct = GpuXgcf::upload(&provider, &circuit).expect("upload");

    let weights_len = direct.max_var() as usize + 1;
    let weights = vec![(0.0f64, 0.0f64); weights_len];
    direct
        .set_base_weights(&provider, &weights)
        .expect("set weights");

    let config = {
        let mut config = GpuCircuitCacheConfig::default();
        config.num_slots = 1;
        config.table_size = 4;
        config.node_cap = 16;
        config.edge_cap = 32;
        config.level_cap = 16;
        config.var_cap = 16;
        config
    };
    let mut cache = GpuCircuitCache::new(&provider, config).expect("cache");

    let mut handle = cache.claim_slot(0xabcdefu64).expect("claim");
    cache.store_from_xgcf(&mut handle, &direct).expect("store");
    cache
        .store_weights(&handle, direct.var_log_true(), direct.var_log_false())
        .expect("store weights");

    let mut out_device = provider.memory().alloc::<f64>(1).unwrap();
    let mut out_cached = provider.memory().alloc::<f64>(1).unwrap();

    cache
        .eval_log_wmc_device_only(&handle, &mut out_device)
        .expect("device-only eval");
    cache
        .eval_log_wmc_device_inplace(&handle, &mut out_cached)
        .expect("cached eval");

    let mut hd = vec![0.0f64; 1];
    let mut hc = vec![0.0f64; 1];
    provider
        .device()
        .inner()
        .dtoh_sync_copy_into(&out_device, &mut hd)
        .unwrap();
    provider
        .device()
        .inner()
        .dtoh_sync_copy_into(&out_cached, &mut hc)
        .unwrap();

    assert_eq!(hd[0].to_bits(), hc[0].to_bits());
}