vyre-primitives 0.4.1

Compositional primitives for vyre — marker types (always on) + Tier 2.5 LEGO substrate (feature-gated per domain).
Documentation
use std::sync::Arc;

use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};

/// Canonical op id for the fused Bellman-Ford shortest-path primitive.
///
/// Within this file it tags the trap programs produced for invalid
/// dimensions, names the generator of the wrapping `Node::Region`, and keys
/// the optional `inventory` harness registration.
pub const OP_ID: &str = "vyre-primitives::math::bellman_shortest_path";

/// Build a fused Bellman-Ford shortest-path Program: relax edges
/// until convergence, all inside ONE GPU dispatch.
///
/// Composes `persistent_fixpoint` over an edge list to perform
/// graph distances without host round-trips.
///
/// Invalid dimensions lower to an explicit trap program.
#[must_use]
#[allow(clippy::too_many_arguments)]
pub fn bellman_shortest_path(
    src: &str,
    dst: &str,
    weight: &str,
    dist: &str,
    next_dist: &str,
    changed: &str,
    n_nodes: u32,
    n_edges: u32,
    max_iterations: u32,
) -> Program {
    if n_nodes == 0 {
        return crate::invalid_output_program(
            OP_ID,
            dist,
            DataType::U32,
            format!("Fix: bellman_shortest_path requires n_nodes > 0, got {n_nodes}."),
        );
    }
    if max_iterations == 0 {
        return crate::invalid_output_program(
            OP_ID,
            dist,
            DataType::U32,
            format!(
                "Fix: bellman_shortest_path requires max_iterations > 0, got {max_iterations}."
            ),
        );
    }

    let t = Expr::InvocationId { axis: 0 };

    let transfer_body = vec![Node::if_then(
        Expr::lt(t.clone(), Expr::u32(n_edges)),
        vec![
            Node::let_bind("u", Expr::load(src, t.clone())),
            Node::let_bind("v", Expr::load(dst, t.clone())),
            Node::let_bind("w", Expr::load(weight, t.clone())),
            Node::let_bind("du", Expr::load(dist, Expr::var("u"))),
            Node::if_then(
                Expr::ne(Expr::var("du"), Expr::u32(u32::MAX)),
                vec![
                    Node::let_bind("alt", Expr::add(Expr::var("du"), Expr::var("w"))),
                    Node::let_bind(
                        "_relax",
                        Expr::atomic_min(next_dist, Expr::var("v"), Expr::var("alt")),
                    ),
                ],
            ),
        ],
    )];

    let inner = crate::fixpoint::persistent_fixpoint::persistent_fixpoint(
        transfer_body,
        dist,
        next_dist,
        changed,
        n_nodes,
        max_iterations,
    );

    let entry: Vec<Node> = vec![Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(inner.entry().to_vec()),
    }];

    Program::wrapped(
        vec![
            BufferDecl::storage(dist, 0, BufferAccess::ReadWrite, DataType::U32)
                .with_count(n_nodes),
            BufferDecl::storage(next_dist, 1, BufferAccess::ReadWrite, DataType::U32)
                .with_count(n_nodes),
            BufferDecl::storage(changed, 2, BufferAccess::ReadWrite, DataType::U32).with_count(1),
            BufferDecl::storage(src, 3, BufferAccess::ReadOnly, DataType::U32).with_count(n_edges),
            BufferDecl::storage(dst, 4, BufferAccess::ReadOnly, DataType::U32).with_count(n_edges),
            BufferDecl::storage(weight, 5, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n_edges),
        ],
        [256, 1, 1],
        entry,
    )
}

/// CPU reference for the Bellman-Ford relaxation loop.
///
/// The working distance vector has exactly `n_nodes` entries: it is seeded
/// from `dist`, truncated if `dist` is longer and padded with `u32::MAX`
/// ("unreached") if it is shorter. Only the first
/// `min(src.len(), dst.len(), weight.len())` edges are considered, and any
/// edge with an endpoint `>= n_nodes` is ignored.
///
/// Each sweep reads distances from the previous sweep and lowers
/// `next[v]` to `dist[u].saturating_add(w)` for every edge whose source has
/// been reached. Sweeping stops at the first sweep that changes nothing.
///
/// # Returns
///
/// `(distances, sweeps)` where `sweeps` is the number of sweeps that changed
/// at least one distance before a fixed point was observed, or
/// `max_iterations` if no unchanged sweep occurred within the budget.
#[must_use]
pub fn cpu_ref(
    src: &[u32],
    dst: &[u32],
    weight: &[u32],
    dist: &[u32],
    n_nodes: u32,
    max_iterations: u32,
) -> (Vec<u32>, u32) {
    let node_count = n_nodes as usize;

    // Seed from `dist`, then pad with the sentinel up to `node_count`.
    let mut settled: Vec<u32> = dist
        .iter()
        .copied()
        .chain(std::iter::repeat(u32::MAX))
        .take(node_count)
        .collect();
    let mut scratch = settled.clone();

    let mut sweep = 0;
    while sweep < max_iterations {
        // Zipping stops at the shortest of the three edge arrays.
        for ((&from, &to), &cost) in src.iter().zip(dst).zip(weight) {
            let (from, to) = (from as usize, to as usize);
            if from < node_count && to < node_count {
                let base = settled[from];
                if base != u32::MAX {
                    let candidate = base.saturating_add(cost);
                    if candidate < scratch[to] {
                        scratch[to] = candidate;
                    }
                }
            }
        }

        // Nothing improved during this sweep: fixed point reached.
        if scratch == settled {
            return (settled, sweep);
        }
        settled.clone_from(&scratch);
        sweep += 1;
    }

    (settled, max_iterations)
}

// Harness registration (only with the `inventory-registry` feature): a
// 4-node / 4-edge sample graph with edges 0->1 (10), 1->2 (20), 2->3 (30)
// and 0->3 (100), starting from node 0. All buffers are little-endian u32.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        // Program under test: 4 nodes, 4 edges, at most 10 iterations.
        || {
            bellman_shortest_path(
                "src",
                "dst",
                "weight",
                "dist",
                "next_dist",
                "changed",
                4,
                4,
                10,
            )
        },
        // Input buffers, listed in binding order 0..=5.
        Some(|| {
            let pack = |words: &[u32]| -> Vec<u8> {
                words.iter().flat_map(|word| word.to_le_bytes()).collect()
            };
            let dist = pack(&[0, u32::MAX, u32::MAX, u32::MAX]);
            let next_dist = pack(&[0, u32::MAX, u32::MAX, u32::MAX]);
            let changed = pack(&[0]);
            let src = pack(&[0, 1, 2, 0]);
            let dst = pack(&[1, 2, 3, 3]);
            let weight = pack(&[10, 20, 30, 100]);
            vec![vec![dist, next_dist, changed, src, dst, weight]]
        }),
        // Expected contents of dist, next_dist and changed after the run:
        // the path 0->1->2->3 costs 60 and beats the direct 0->3 edge (100).
        Some(|| {
            let pack = |words: &[u32]| -> Vec<u8> {
                words.iter().flat_map(|word| word.to_le_bytes()).collect()
            };
            let dist = pack(&[0, 10, 30, 60]);
            let next_dist = pack(&[0, 10, 30, 60]);
            let changed = pack(&[0]);
            vec![vec![dist, next_dist, changed]]
        }),
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Distance value marking a node that has not been reached.
    const UNREACHED: u32 = u32::MAX;

    /// A single edge 0 -> 1 settles after exactly one changing sweep.
    #[test]
    fn test_cpu_ref_trivial() {
        let (distances, sweeps) = cpu_ref(&[0], &[1], &[5], &[0, UNREACHED], 2, 10);
        assert_eq!(distances, vec![0, 5]);
        assert_eq!(sweeps, 1);
    }

    /// With no edges the very first sweep is already a fixed point.
    #[test]
    fn test_cpu_ref_single_node() {
        let (distances, sweeps) = cpu_ref(&[], &[], &[], &[0], 1, 10);
        assert_eq!(distances, vec![0]);
        assert_eq!(sweeps, 0);
    }

    /// A positive-weight cycle 0 -> 1 -> 2 -> 0 must not lower node 0 below 0.
    #[test]
    fn test_cpu_ref_cycle() {
        let seed = [0, UNREACHED, UNREACHED];
        let (distances, _) = cpu_ref(&[0, 1, 2], &[1, 2, 0], &[10, 10, 10], &seed, 3, 10);
        assert_eq!(distances, vec![0, 10, 20]);
    }

    /// A 50-node unit-weight chain needs one sweep per hop.
    #[test]
    fn test_cpu_ref_large_line() {
        const NODES: u32 = 50;
        let src: Vec<u32> = (0..NODES - 1).collect();
        let dst: Vec<u32> = (1..NODES).collect();
        let weight = vec![1u32; src.len()];

        let mut seed = vec![UNREACHED; NODES as usize];
        seed[0] = 0;

        let (distances, sweeps) = cpu_ref(&src, &dst, &weight, &seed, NODES, NODES * 2);
        assert_eq!(distances[NODES as usize - 1], NODES - 1);
        assert_eq!(sweeps, NODES - 1);
    }

    /// The two-hop route 0 -> 1 -> 3 (10 + 20) beats the direct edge 0 -> 3 (100).
    #[test]
    fn test_cpu_ref_asymmetric() {
        let src = [0, 0, 1, 2];
        let dst = [1, 3, 3, 3];
        let weight = [10, 100, 20, 5];
        let seed = [0, UNREACHED, UNREACHED, UNREACHED];

        let (distances, _) = cpu_ref(&src, &dst, &weight, &seed, 4, 10);
        assert_eq!(distances[3], 30);
    }

    /// Mismatched edge arrays are cut to the shortest one, out-of-range
    /// endpoints are skipped, and a short `dist` is padded with the sentinel.
    #[test]
    fn test_cpu_ref_ignores_malformed_edges_and_pads_distances() {
        let src = [0, 9, 1];
        let dst = [1, 2];
        let weight = [5, 99, 7];

        let (distances, _) = cpu_ref(&src, &dst, &weight, &[0], 3, 10);
        assert_eq!(distances, vec![0, 5, UNREACHED]);
    }

    /// The generated program, run through the reference interpreter, must
    /// produce the same distances as `cpu_ref` on a small graph.
    #[test]
    fn test_parity_small_graph() {
        use vyre_reference::reference_eval;
        use vyre_reference::value::Value;

        let src = vec![0, 1, 2, 0];
        let dst = vec![1, 2, 3, 3];
        let weight = vec![10, 20, 30, 100];
        let dist_init = vec![0, UNREACHED, UNREACHED, UNREACHED];

        let (expected_dist, _) = cpu_ref(&src, &dst, &weight, &dist_init, 4, 10);

        let program = bellman_shortest_path(
            "src",
            "dst",
            "weight",
            "dist",
            "next_dist",
            "changed",
            4,
            4,
            10,
        );

        // Encode a u32 slice as a little-endian byte buffer value.
        let to_value = |words: &[u32]| {
            let raw: Vec<u8> = words.iter().flat_map(|word| word.to_le_bytes()).collect();
            Value::Bytes(Arc::from(raw))
        };

        // Same order as the program's buffer declarations.
        let inputs = vec![
            to_value(&dist_init),
            to_value(&dist_init),
            to_value(&[0]),
            to_value(&src),
            to_value(&dst),
            to_value(&weight),
        ];

        let outputs = reference_eval(&program, &inputs).expect("Fix: interpreter failed");
        let dist_bytes = outputs[0].to_bytes();
        let actual_dist: Vec<u32> = dist_bytes
            .chunks_exact(4)
            .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
            .collect();

        assert_eq!(actual_dist, expected_dist);
    }

    /// dist, next_dist, changed, src, dst and weight are all declared.
    #[test]
    fn program_declares_six_buffers() {
        let program = bellman_shortest_path("s", "d", "w", "di", "nd", "c", 4, 4, 10);
        assert_eq!(program.buffers().len(), 6);
    }

    /// `n_nodes == 0` lowers to a trap program.
    #[test]
    fn rejects_zero_nodes_with_trap() {
        let program = bellman_shortest_path("s", "d", "w", "di", "nd", "c", 0, 4, 10);
        assert!(program.stats().trap());
    }

    /// `max_iterations == 0` lowers to a trap program.
    #[test]
    fn rejects_zero_max_iterations_with_trap() {
        let program = bellman_shortest_path("s", "d", "w", "di", "nd", "c", 4, 4, 0);
        assert!(program.stats().trap());
    }
}