vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::model::types::AtomicOp;
use crate::lower::wgsl::Error;
use std::collections::HashSet;
use std::fmt::Write as _;

/// Emit WGSL helper functions for every atomic operation used by a buffer.
///
/// `ops` is the set of atomic operations buffer `name` needs. One helper per
/// requested operation is appended to `out`. Helpers are always written in the
/// fixed order add, or, and, xor, min, max, exchange, compare-exchange. This
/// order does not depend on the set's iteration order, so the emitted text is
/// deterministic for a given set.
///
/// # Errors
///
/// Stops at, and returns, the first error reported while emitting a single
/// helper. Helpers emitted before that point remain in `out`.
#[inline]
pub fn emit_atomic_helpers(
    out: &mut String,
    name: &str,
    ops: &HashSet<AtomicOp>,
) -> Result<(), Error> {
    // Walk a fixed list rather than the `HashSet` itself: the default hasher is
    // randomly seeded, so set iteration order would vary between runs.
    let emission_order = [
        AtomicOp::Add,
        AtomicOp::Or,
        AtomicOp::And,
        AtomicOp::Xor,
        AtomicOp::Min,
        AtomicOp::Max,
        AtomicOp::Exchange,
        AtomicOp::CompareExchange,
    ];
    for requested in emission_order {
        if !ops.contains(&requested) {
            continue;
        }
        emit_atomic_helper(out, name, requested)?;
    }
    Ok(())
}

/// Emit the WGSL helper function for a single atomic `op` on buffer `name`.
///
/// Each helper is named `_vyre_atomic_<suffix>_<name>` and takes the element
/// index `idx` as its first argument. If `idx >= arrayLength(&<name>.data)`,
/// it returns `0u` without touching memory. Otherwise it returns the element's
/// previous value as reported by the corresponding WGSL atomic builtin.
///
/// The emitted code expects `<name>.data` to be a runtime-sized array of
/// `atomic<u32>`. This is required by `arrayLength` and the `u32` atomic
/// builtins used here.
///
/// # Errors
///
/// Formatting into a `String` is infallible, so this currently always returns
/// `Ok(())`.
fn emit_atomic_helper(out: &mut String, name: &str, op: AtomicOp) -> Result<(), Error> {
    // (helper-name suffix, WGSL builtin) for every read-modify-write op that
    // shares the `(idx: u32, value: u32)` signature.
    let (suffix, builtin) = match op {
        AtomicOp::Add => ("add", "atomicAdd"),
        AtomicOp::Or => ("or", "atomicOr"),
        AtomicOp::And => ("and", "atomicAnd"),
        AtomicOp::Xor => ("xor", "atomicXor"),
        AtomicOp::Min => ("min", "atomicMin"),
        AtomicOp::Max => ("max", "atomicMax"),
        AtomicOp::Exchange => ("exchange", "atomicExchange"),
        AtomicOp::CompareExchange => {
            emit_compare_exchange_helper(out, name);
            return Ok(());
        }
    };
    // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
    let _ = write!(
        out,
        "fn _vyre_atomic_{suffix}_{name}(idx: u32, value: u32) -> u32 {{\n\
         if (idx >= arrayLength(&{name}.data)) {{ return 0u; }}\n\
         return {builtin}(&{name}.data[idx], value);\n\
         }}\n\n"
    );
    Ok(())
}

/// Emit `_vyre_atomic_compare_exchange_<name>`, a *strong* compare-exchange.
///
/// WGSL only offers `atomicCompareExchangeWeak`. Its `exchanged` flag may be
/// `false` even when `old_value == expected` (a spurious failure). The helper
/// returns only the old value, so a caller testing `old == expected` would
/// mistake such a failure for success. Retrying until the exchange succeeds,
/// or until it observes a genuinely different value, makes
/// `returned == expected` a reliable success test.
///
/// NOTE(review): the out-of-bounds path returns `0u`. When `expected == 0`,
/// that is indistinguishable from a successful exchange. This behaviour is
/// unchanged from before.
fn emit_compare_exchange_helper(out: &mut String, name: &str) {
    // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
    let _ = write!(
        out,
        "fn _vyre_atomic_compare_exchange_{name}(idx: u32, expected: u32, value: u32) -> u32 {{\n\
         if (idx >= arrayLength(&{name}.data)) {{ return 0u; }}\n\
         loop {{\n\
         let result = atomicCompareExchangeWeak(&{name}.data[idx], expected, value);\n\
         if (result.exchanged || result.old_value != expected) {{ return result.old_value; }}\n\
         }}\n\
         }}\n\n"
    );
}