// xlog-gpu 0.9.2
//
// High-level Rust API for running XLOG programs on NVIDIA GPUs.
#![allow(clippy::arc_with_non_send_sync)]

use std::collections::HashMap;
use std::sync::Arc;

use xlog_core::{MemoryBudget, Result};
use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};

fn create_test_provider() -> Option<Arc<CudaKernelProvider>> {
    let device = Arc::new(CudaDevice::new(0).ok()?);
    let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
    let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
    Some(Arc::new(CudaKernelProvider::new(device, memory).ok()?))
}

/// Downloads column `col` of `buffer` as `u32` values via `provider`.
///
/// If the download does not succeed, the default (empty) vector is returned
/// instead of propagating the failure.
fn read_u32_col(
    provider: &CudaKernelProvider,
    buffer: &xlog_cuda::CudaBuffer,
    col: usize,
) -> Vec<u32> {
    let downloaded = provider.download_column::<u32>(buffer, col);
    // Best-effort read: a failed download collapses to an empty column.
    downloaded.unwrap_or_default()
}

/// Reads columns 0 and 1 of `buffer` and returns them as `(col0, col1)` rows.
///
/// The rows are sorted and de-duplicated so that callers can compare the
/// result against an expected list regardless of the order (or repetition)
/// in which rows were produced. If the two columns differ in length, the
/// extra trailing values of the longer column are dropped by `zip`.
fn read_pairs(provider: &CudaKernelProvider, buffer: &xlog_cuda::CudaBuffer) -> Vec<(u32, u32)> {
    let firsts = read_u32_col(provider, buffer, 0);
    let seconds = read_u32_col(provider, buffer, 1);

    let mut pairs = Vec::with_capacity(firsts.len().min(seconds.len()));
    pairs.extend(firsts.into_iter().zip(seconds));

    // Canonical form: ascending order with adjacent duplicates removed.
    pairs.sort_unstable();
    pairs.dedup();
    pairs
}

/// Compiles `source` as a logic program and evaluates it on a test provider.
///
/// Returns `Ok(None)` (after printing a skip notice to stderr) when no test
/// provider could be created, `Ok(Some((provider, result)))` on success, and
/// propagates any error raised by compilation or evaluation. The program is
/// evaluated with an empty `HashMap` as its second argument.
fn run_query(
    source: &str,
) -> Result<Option<(Arc<CudaKernelProvider>, xlog_gpu::logic::LogicEvalResult)>> {
    match create_test_provider() {
        None => {
            eprintln!("Skipping: no CUDA device");
            Ok(None)
        }
        Some(gpu) => {
            let compiled = xlog_gpu::logic::LogicProgram::compile(source)?;
            // The provider is handed back to the caller so it can download
            // result buffers, hence the clone passed to `evaluate`.
            let outcome = compiled.evaluate(gpu.clone(), HashMap::new())?;
            Ok(Some((gpu, outcome)))
        }
    }
}

/// Runs a program exercising the list builtins (`member`, `memberchk`,
/// `length`, `nth`, head/tail destructuring, `append`, `sort`, `msort`,
/// `list_to_set`, `is_list`) and checks each query's downloaded rows.
///
/// Query results are read positionally from `result.queries`; the expected
/// values below are listed in the same order as the `?-` lines in the program.
/// The test returns early without asserting anything when `run_query` reports
/// that no provider is available.
#[test]
fn v085_list_builtins_execute_through_gpu_relations() -> Result<()> {
    let program_text = r#"
        pred bag(id: u32, xs: list<u32>).
        bag(1, [10, 20, 10]).
        bag(2, [30]).

        out_member(Id, X) :- bag(Id, L), member(X, L).
        out_memberchk(Id, X) :- bag(Id, L), memberchk(X, L).
        out_length(Id, N) :- bag(Id, L), length(L, N).
        out_nth(Id, X) :- bag(Id, L), nth(1, L, X).
        out_head(Id, H) :- bag(Id, [H | T]).
        out_append_len(N) :- append([10], [20], L), length(L, N).
        out_sort_member(X) :- sort([20, 10, 10], L), member(X, L).
        out_msort_len(N) :- msort([20, 10, 10], L), length(L, N).
        out_set_len(N) :- list_to_set([20, 10, 20], L), length(L, N).
        out_is_list(1) :- is_list([10, 20]).

        ?- out_member(Id, X).
        ?- out_memberchk(Id, X).
        ?- out_length(Id, N).
        ?- out_nth(Id, X).
        ?- out_head(Id, H).
        ?- out_append_len(N).
        ?- out_sort_member(X).
        ?- out_msort_len(N).
        ?- out_set_len(N).
        ?- out_is_list(Flag).
    "#;

    let (provider, result) = match run_query(program_text)? {
        Some(ready) => ready,
        None => return Ok(()),
    };

    assert_eq!(result.queries.len(), 10);

    // Local shorthands: two-column results as sorted, de-duplicated pairs;
    // single-column results as the raw contents of column 0.
    let pairs_at = |q: usize| read_pairs(&provider, &result.queries[q].buffer);
    let column_at = |q: usize| read_u32_col(&provider, &result.queries[q].buffer, 0);

    // out_member / out_memberchk
    assert_eq!(pairs_at(0), vec![(1, 10), (1, 20), (2, 30)]);
    assert_eq!(pairs_at(1), vec![(1, 10), (1, 20), (2, 30)]);
    // out_length
    assert_eq!(pairs_at(2), vec![(1, 3), (2, 1)]);
    // out_nth: only bag 1 is expected to yield a row for position 1
    assert_eq!(pairs_at(3), vec![(1, 20)]);
    // out_head
    assert_eq!(pairs_at(4), vec![(1, 10), (2, 30)]);

    // out_append_len
    assert_eq!(column_at(5), vec![2]);
    // out_sort_member
    assert_eq!(column_at(6), vec![10, 20]);
    // out_msort_len
    assert_eq!(column_at(7), vec![3]);
    // out_set_len
    assert_eq!(column_at(8), vec![2]);
    // out_is_list
    assert_eq!(column_at(9), vec![1]);

    Ok(())
}