vyre-foundation 0.6.1

Foundation layer: IR, type system, memory model, wire format. Zero application semantics. Part of the vyre GPU compiler.
Documentation
//! Buffer-target collectors for load/store/atomic walks.
//!
//! `VYRE_IR_HOTSPOTS` HIGH: `fuse_programs_multi` previously called three
//! independent walks (`collect_atomic_targets_from_node`,
//! `collect_load_targets_from_node`, `collect_store_targets_from_node`)
//! per arm  -  three full traversals of the same IR tree. The unified
//! [`collect_buffer_targets`] helper does it in one walk with three
//! mutable target sets. This is the canonical collector API for fusion:
//! adding single-target wrappers would reintroduce duplicate IR walks and
//! make the fusion boundary ambiguous for contributors.

use rustc_hash::FxHashSet;

use crate::ir::{Expr, Ident, Node};

/// One-pass collector: walk `node` once and fan out load / store /
/// atomic buffer targets into the three caller-supplied sets.
pub(super) fn collect_buffer_targets(
    node: &Node,
    loads: &mut FxHashSet<Ident>,
    stores: &mut FxHashSet<Ident>,
    atomics: &mut FxHashSet<Ident>,
) {
    match node {
        Node::Let { value, .. } | Node::Assign { value, .. } => {
            collect_buffer_targets_from_expr(value, loads, atomics);
        }
        Node::Store {
            buffer,
            index,
            value,
        } => {
            stores.insert(Ident::from(buffer));
            collect_buffer_targets_from_expr(index, loads, atomics);
            collect_buffer_targets_from_expr(value, loads, atomics);
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            collect_buffer_targets_from_expr(cond, loads, atomics);
            for n in then.iter().chain(otherwise.iter()) {
                collect_buffer_targets(n, loads, stores, atomics);
            }
        }
        Node::Loop { from, to, body, .. } => {
            collect_buffer_targets_from_expr(from, loads, atomics);
            collect_buffer_targets_from_expr(to, loads, atomics);
            for n in body {
                collect_buffer_targets(n, loads, stores, atomics);
            }
        }
        Node::Block(body) => {
            for n in body {
                collect_buffer_targets(n, loads, stores, atomics);
            }
        }
        Node::Region { body, .. } => {
            for n in body.iter() {
                collect_buffer_targets(n, loads, stores, atomics);
            }
        }
        Node::AllReduce { buffer, .. } | Node::Broadcast { buffer, .. } => {
            loads.insert(buffer.clone());
            stores.insert(buffer.clone());
        }
        Node::AllGather { input, output, .. } | Node::ReduceScatter { input, output, .. } => {
            loads.insert(input.clone());
            stores.insert(output.clone());
        }
        Node::IndirectDispatch { .. }
        | Node::Return
        | Node::Barrier { .. }
        | Node::AsyncLoad { .. }
        | Node::AsyncStore { .. }
        | Node::AsyncWait { .. }
        | Node::Trap { .. }
        | Node::Resume { .. }
        | Node::Opaque(_) => {}
    }
}

fn collect_buffer_targets_from_expr(
    expr: &Expr,
    loads: &mut FxHashSet<Ident>,
    atomics: &mut FxHashSet<Ident>,
) {
    match expr {
        Expr::Load { buffer, index } => {
            loads.insert(Ident::from(buffer));
            collect_buffer_targets_from_expr(index, loads, atomics);
        }
        Expr::Atomic {
            buffer,
            index,
            expected,
            value,
            ..
        } => {
            atomics.insert(Ident::from(buffer));
            collect_buffer_targets_from_expr(index, loads, atomics);
            if let Some(expected) = expected {
                collect_buffer_targets_from_expr(expected, loads, atomics);
            }
            collect_buffer_targets_from_expr(value, loads, atomics);
        }
        Expr::BinOp { left, right, .. } => {
            collect_buffer_targets_from_expr(left, loads, atomics);
            collect_buffer_targets_from_expr(right, loads, atomics);
        }
        Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
            collect_buffer_targets_from_expr(operand, loads, atomics);
        }
        Expr::Fma { a, b, c } => {
            collect_buffer_targets_from_expr(a, loads, atomics);
            collect_buffer_targets_from_expr(b, loads, atomics);
            collect_buffer_targets_from_expr(c, loads, atomics);
        }
        Expr::Call { args, .. } => {
            for arg in args {
                collect_buffer_targets_from_expr(arg, loads, atomics);
            }
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            collect_buffer_targets_from_expr(cond, loads, atomics);
            collect_buffer_targets_from_expr(true_val, loads, atomics);
            collect_buffer_targets_from_expr(false_val, loads, atomics);
        }
        Expr::SubgroupBallot { cond } => collect_buffer_targets_from_expr(cond, loads, atomics),
        Expr::SubgroupShuffle { value, lane } => {
            collect_buffer_targets_from_expr(value, loads, atomics);
            collect_buffer_targets_from_expr(lane, loads, atomics);
        }
        Expr::SubgroupAdd { value } => collect_buffer_targets_from_expr(value, loads, atomics),
        _ => {}
    }
}