vyre-foundation 0.4.1

Foundation layer: IR, type system, memory model, wire format. Zero application semantics. Part of the vyre GPU compiler.
Documentation
use super::*;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};

/// Use counting on a straight-line body: `x` feeds both operands of the
/// `add` (two uses), `y` feeds the store (one use), and an identifier that
/// never appears in the program (`z`) reports zero.
#[test]
fn derive_use_counts_simple() {
    let buffers = vec![BufferDecl::read_write("out", 0, DataType::U32)];
    let body = vec![
        Node::let_bind("x", Expr::u32(1)),
        Node::let_bind("y", Expr::add(Expr::var("x"), Expr::var("x"))),
        Node::store("out", Expr::u32(0), Expr::var("y")),
    ];
    let program = Program::wrapped(buffers, [1, 1, 1], body);
    let facts = FactSubstrate::derive(&program);

    // (identifier, expected number of reads)
    for (name, expected) in [("x", 2), ("y", 1), ("z", 0)] {
        assert_eq!(facts.use_count_of(&Ident::from(name)), expected);
    }
}

/// The scalar operands of `Node::async_load_ext` (`offset` and `size`) must
/// each be counted as exactly one use.
#[test]
fn derive_use_counts_async_operands() {
    let input = BufferDecl::read("input", 0, DataType::U32).with_count(4);
    let out = BufferDecl::read_write("out", 1, DataType::U32).with_count(4);
    let copy = Node::async_load_ext(
        Ident::from("input"),
        Ident::from("out"),
        Expr::var("offset"),
        Expr::var("size"),
        Ident::from("copy"),
    );
    let program = Program::wrapped(
        vec![input, out],
        [1, 1, 1],
        vec![
            Node::let_bind("offset", Expr::u32(1)),
            Node::let_bind("size", Expr::u32(2)),
            copy,
        ],
    );

    let substrate = FactSubstrate::derive(&program);
    for operand in ["offset", "size"] {
        assert_eq!(substrate.use_count_of(&Ident::from(operand)), 1);
    }
}

/// `derive_use_only` yields fresh use facts without making the whole
/// substrate fresh. The use facts count one read of `input` and one write of
/// `out`, and report the index axis each buffer is addressed by: `input` is
/// indexed with `gid_x` (axis 0) and `out` with `gid_y` (axis 1).
#[test]
fn derive_use_facts_records_buffer_accesses_and_index_axes() {
    let store = Node::store("out", Expr::gid_y(), Expr::load("input", Expr::gid_x()));
    let program = Program::wrapped(
        vec![
            BufferDecl::read("input", 0, DataType::U32).with_count(64),
            BufferDecl::read_write("out", 1, DataType::U32).with_count(64),
        ],
        [8, 8, 1],
        vec![store],
    );

    let substrate = FactSubstrate::derive_use_only(&program);
    assert!(substrate.has_fresh_use_facts_for(&program));
    // Only use facts were derived, so the full substrate is not fresh.
    assert!(!substrate.is_fresh_for(&program));

    let (input, out) = (Ident::from("input"), Ident::from("out"));
    let facts = substrate.use_facts().unwrap();
    assert_eq!(facts.buffer_reads.get(&input), Some(&1));
    assert_eq!(facts.buffer_writes.get(&out), Some(&1));
    assert_eq!(facts.dominant_index_axis(&input), Some(0));
    assert_eq!(facts.dominant_index_axis(&out), Some(1));
}

/// Buffer dependencies propagate through scalar bindings. `x` is loaded
/// from `input`, so it depends on `input`. Storing `x` into `scratch` makes
/// that write depend on `input`. Copying `scratch` into `out` makes the
/// write to `out` depend on `scratch`.
#[test]
fn derive_use_facts_records_scalar_mediated_buffer_dependencies() {
    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32).with_count(1),
        BufferDecl::read_write("scratch", 1, DataType::U32).with_count(1),
        BufferDecl::output("out", 2, DataType::U32).with_count(1),
    ];
    let body = vec![
        Node::let_bind("x", Expr::load("input", Expr::u32(0))),
        Node::store("scratch", Expr::u32(0), Expr::var("x")),
        Node::store("out", Expr::u32(0), Expr::load("scratch", Expr::u32(0))),
    ];
    let program = Program::wrapped(buffers, [1, 1, 1], body);

    let substrate = FactSubstrate::derive_use_only(&program);
    let facts = substrate.use_facts().unwrap();

    let (x, input, scratch, out) = (
        Ident::from("x"),
        Ident::from("input"),
        Ident::from("scratch"),
        Ident::from("out"),
    );

    let x_deps = facts.var_buffer_deps.get(&x);
    assert!(x_deps.is_some_and(|deps| deps.contains(&input)));

    let scratch_deps = facts.buffer_write_deps.get(&scratch);
    assert!(scratch_deps.is_some_and(|deps| deps.contains(&input)));

    let out_deps = facts.buffer_write_deps.get(&out);
    assert!(out_deps.is_some_and(|deps| deps.contains(&scratch)));
}

/// A buffer used as the argument source of `Node::indirect_dispatch` is
/// listed in `indirect_dispatch_buffers` and counted as exactly one read.
#[test]
fn derive_use_facts_records_indirect_dispatch_count_buffers() {
    let counts = Ident::from("counts");
    let program = Program::wrapped(
        vec![BufferDecl::read("counts", 0, DataType::U32).with_count(1)],
        [1, 1, 1],
        vec![Node::indirect_dispatch("counts", 0)],
    );

    let substrate = FactSubstrate::derive_use_only(&program);
    let facts = substrate.use_facts().unwrap();
    assert!(facts.indirect_dispatch_buffers.contains(&counts));
    assert_eq!(facts.buffer_reads.get(&counts), Some(&1));
}

/// Float types propagate through bindings. `a` is bound to an `f32`
/// literal, and `b = a + 2.0f32` must also be typed `F32`.
#[test]
fn derive_type_facts_float_propagation() {
    let body = vec![
        Node::let_bind("a", Expr::f32(1.0)),
        Node::let_bind("b", Expr::add(Expr::var("a"), Expr::f32(2.0))),
    ];
    let program = Program::wrapped(
        vec![BufferDecl::read_write("out", 0, DataType::U32)],
        [1, 1, 1],
        body,
    );

    let substrate = FactSubstrate::derive(&program);
    let types = substrate.type_map.as_ref().unwrap();
    for var in ["a", "b"] {
        assert_eq!(types.var_types.get(&Ident::from(var)), Some(&DataType::F32));
    }
}

/// A variable bound to a load from an `F32` buffer is typed `F32`, and a
/// full `derive` also fills the per-expression type table (`expr_types`).
#[test]
fn derive_type_facts_records_loads_and_expression_types() {
    let input = BufferDecl::read("input", 0, DataType::F32).with_count(1);
    let out = BufferDecl::read_write("out", 1, DataType::F32).with_count(1);
    let body = vec![
        Node::let_bind("x", Expr::load("input", Expr::u32(0))),
        Node::store("out", Expr::u32(0), Expr::var("x")),
    ];
    let program = Program::wrapped(vec![input, out], [1, 1, 1], body);

    let substrate = FactSubstrate::derive(&program);
    let types = substrate.type_map.as_ref().unwrap();

    let x_type = types.var_types.get(&Ident::from("x"));
    assert_eq!(x_type, Some(&DataType::F32));

    let has_expr_types = !types.expr_types.is_empty();
    assert!(
        has_expr_types,
        "FactSubstrate::TypeFacts promises expression type facts; derive() must populate them"
    );
}

/// A freshly derived substrate is fresh for its program. After
/// `invalidate()` it is no longer fresh and its `shape` fact is gone.
#[test]
fn invalidate_clears_all() {
    let store_one = Node::store("out", Expr::u32(0), Expr::u32(1));
    let buffers = vec![BufferDecl::read_write("out", 0, DataType::U32)];
    let program = Program::wrapped(buffers, [1, 1, 1], vec![store_one]);

    let mut substrate = FactSubstrate::derive(&program);
    assert!(substrate.is_fresh_for(&program));

    substrate.invalidate();

    let still_fresh = substrate.is_fresh_for(&program);
    assert!(!still_fresh);
    assert!(substrate.shape.is_none());
}

/// Stress case: a nested block with 4096 `let` bindings that each read `x`
/// must report exactly 4096 uses of `x`.
#[test]
fn derive_use_counts_handles_large_blocks_in_one_pass() {
    // Collection type is inferred from `Node::block`'s parameter.
    let sinks = (0..4096)
        .map(|i| Node::let_bind(format!("sink_{i}"), Expr::var("x")))
        .collect();
    let block = Node::block(sinks);

    let program = Program::wrapped(
        vec![BufferDecl::read_write("out", 0, DataType::U32)],
        [1, 1, 1],
        vec![Node::let_bind("x", Expr::u32(1)), block],
    );

    let substrate = FactSubstrate::derive(&program);
    assert_eq!(substrate.use_count_of(&Ident::from("x")), 4096);
}