vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
use crate::ir::model::expr::Expr;
use crate::ir::model::node::Node;
use crate::ir::model::program::Program;
use crate::ir::model::types::AtomicOp;
use std::collections::{HashMap, HashSet};

/// Per-buffer access summary for a program, keyed by buffer name.
///
/// Filled in by `HelperUsage::collect`. A buffer only gets an entry once it is
/// loaded from, stored to, or targeted by an atomic expression.
#[derive(Default, Debug)]
pub(crate) struct HelperUsage {
    // Buffer name -> the kinds of access observed on that buffer.
    buffers: HashMap<String, BufferHelperUsage>,
}

/// Access kinds recorded for a single buffer.
#[derive(Default, Debug)]
pub(crate) struct BufferHelperUsage {
    /// `true` once a load expression reading this buffer has been seen.
    pub(crate) load: bool,
    /// `true` once a store statement writing this buffer has been seen.
    pub(crate) store: bool,
    /// Distinct atomic operations applied to this buffer (deduplicated by the set).
    pub(crate) atomic_ops: HashSet<AtomicOp>,
}

impl HelperUsage {
    pub(crate) fn collect(program: &Program) -> Self {
        let mut usage = Self::default();
        for node in program.entry() {
            usage.record_node(node);
        }
        usage
    }

    pub(crate) fn buffer(&self, name: &str) -> BufferHelperUsageView<'_> {
        match self.buffers.get(name) {
            Some(usage) => BufferHelperUsageView {
                load: usage.load,
                store: usage.store,
                atomic_ops: Some(&usage.atomic_ops),
            },
            None => BufferHelperUsageView::default(),
        }
    }

    fn record_node(&mut self, node: &Node) {
        match node {
            Node::Let { value, .. } | Node::Assign { value, .. } => self.record_expr(value),
            Node::Store {
                buffer,
                index,
                value,
            } => {
                self.buffers.entry(buffer.clone()).or_default().store = true;
                self.record_expr(index);
                self.record_expr(value);
            }
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                self.record_expr(cond);
                for node in then.iter().chain(otherwise) {
                    self.record_node(node);
                }
            }
            Node::Loop { from, to, body, .. } => {
                self.record_expr(from);
                self.record_expr(to);
                for node in body {
                    self.record_node(node);
                }
            }
            Node::Block(nodes) => {
                for node in nodes {
                    self.record_node(node);
                }
            }
            Node::Return | Node::Barrier => {}
        }
    }

    fn record_expr(&mut self, expr: &Expr) {
        match expr {
            Expr::Load { buffer, index } => {
                self.buffers.entry(buffer.to_string()).or_default().load = true;
                self.record_expr(index);
            }
            Expr::Atomic {
                op,
                buffer,
                index,
                expected,
                value,
            } => {
                self.buffers
                    .entry(buffer.to_string())
                    .or_default()
                    .atomic_ops
                    .insert(op.clone());
                self.record_expr(index);
                if let Some(expected) = expected {
                    self.record_expr(expected);
                }
                self.record_expr(value);
            }
            Expr::BinOp { left, right, .. } => {
                self.record_expr(left);
                self.record_expr(right);
            }
            Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
                self.record_expr(operand);
            }
            Expr::Call { args, .. } => {
                for arg in args {
                    self.record_expr(arg);
                }
            }
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                self.record_expr(cond);
                self.record_expr(true_val);
                self.record_expr(false_val);
            }
            Expr::Fma { a, b, c } => {
                self.record_expr(a);
                self.record_expr(b);
                self.record_expr(c);
            }
            Expr::BufLen { .. }
            | Expr::InvocationId { .. }
            | Expr::WorkgroupId { .. }
            | Expr::LocalId { .. }
            | Expr::LitU32(_)
            | Expr::LitI32(_)
            | Expr::LitF32(_)
            | Expr::LitBool(_)
            | Expr::Var(_) => {}
        }
    }
}

/// Borrowed, copyable snapshot of one buffer's usage.
///
/// `atomic_ops` is `None` for the default view. Otherwise it borrows the
/// recorded set of atomic operations, which may be empty. `Debug` is derived so
/// this view is printable like the owning `HelperUsage` / `BufferHelperUsage`
/// types.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct BufferHelperUsageView<'a> {
    /// Whether the buffer is read by a load expression.
    pub(crate) load: bool,
    /// Whether the buffer is written by a store statement.
    pub(crate) store: bool,
    /// Borrowed set of atomic operations, or `None` when nothing was recorded.
    pub(crate) atomic_ops: Option<&'a HashSet<AtomicOp>>,
}

impl BufferHelperUsageView<'_> {
    /// Reports whether at least one atomic operation was recorded for the buffer.
    ///
    /// Both a view with `atomic_ops == None` and one borrowing an empty set
    /// return `false`.
    pub(crate) fn uses_atomics(self) -> bool {
        matches!(self.atomic_ops, Some(ops) if !ops.is_empty())
    }
}