xlog-prob 0.9.2

Probabilistic inference engines for XLOG
use std::sync::Arc;

use xlog_core::MemoryBudget;
use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};
use xlog_prob::compilation::{apply_query_vars_device, restore_query_vars_device};

fn try_provider() -> Option<Arc<CudaKernelProvider>> {
    let device = match CudaDevice::new(0) {
        Ok(d) => Arc::new(d),
        Err(e) => {
            eprintln!("Skipping test: CUDA runtime unavailable: {}", e);
            return None;
        }
    };
    let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
    let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
    match CudaKernelProvider::new(device, memory) {
        Ok(p) => Some(Arc::new(p)),
        Err(e) => {
            eprintln!("Skipping test: failed to create provider: {}", e);
            None
        }
    }
}

#[test]
fn query_vars_apply_and_restore_on_device() {
    let Some(provider) = try_provider() else {
        return;
    };

    let host_log_false = vec![0.0f64, -1.0, -2.0, -3.0, -4.0];
    let var_cap = (host_log_false.len() - 1) as u32;

    let mut log_false = provider
        .memory()
        .alloc::<f64>(host_log_false.len())
        .unwrap();
    provider
        .device()
        .inner()
        .htod_sync_copy_into(&host_log_false, &mut log_false)
        .unwrap();

    let query_vars_host = vec![0u32, 2u32, 4u32];
    let mut query_vars = provider
        .memory()
        .alloc::<u32>(query_vars_host.len())
        .unwrap();
    provider
        .device()
        .inner()
        .htod_sync_copy_into(&query_vars_host, &mut query_vars)
        .unwrap();

    let mut saved = provider
        .memory()
        .alloc::<f64>(query_vars_host.len())
        .unwrap();

    apply_query_vars_device(&provider, &query_vars, var_cap, &mut log_false, &mut saved)
        .expect("apply query vars");

    let mut log_false_host = vec![0.0f64; host_log_false.len()];
    let mut saved_host = vec![0.0f64; query_vars_host.len()];
    provider
        .device()
        .inner()
        .dtoh_sync_copy_into(&log_false, &mut log_false_host)
        .unwrap();
    provider
        .device()
        .inner()
        .dtoh_sync_copy_into(&saved, &mut saved_host)
        .unwrap();

    assert_eq!(saved_host[0], 0.0);
    assert_eq!(saved_host[1], host_log_false[2]);
    assert_eq!(saved_host[2], host_log_false[4]);
    assert_eq!(log_false_host[2], f64::NEG_INFINITY);
    assert_eq!(log_false_host[4], f64::NEG_INFINITY);

    restore_query_vars_device(&provider, &query_vars, var_cap, &mut log_false, &saved)
        .expect("restore query vars");

    let mut restored = vec![0.0f64; host_log_false.len()];
    provider
        .device()
        .inner()
        .dtoh_sync_copy_into(&log_false, &mut restored)
        .unwrap();
    assert_eq!(restored, host_log_false);
}