Skip to main content

vyre_reference/
workgroup.rs

1//! Workgroup simulation — the parity engine's model of invocation coordination.
2//!
3//! GPU backends must reproduce the exact barrier synchronization, shared-memory
4//! layout, and invocation-ID arithmetic that this module defines. The conform gate
5//! compares GPU dispatch output against this deterministic CPU simulation; any
6//! divergence in control flow uniformity or workgroup memory semantics is a bug.
7
8use std::convert::Infallible;
9use std::ops::ControlFlow::{self, Continue};
10use std::sync::Arc;
11
12use rustc_hash::FxHashMap;
13use smallvec::SmallVec;
14#[cfg(test)]
15use vyre::ir::BufferAccess;
16use vyre::ir::{Expr, Node, Program};
17use vyre::visit::{visit_node_preorder, visit_preorder, ExprVisitor, NodeVisitor};
18use vyre::OpDef;
19use vyre_foundation::ir::model::expr::GeneratorRef;
20
21use vyre::Error;
22
23use crate::{oob::Buffer, value::Value};
24
/// Maximum per-workgroup shared memory the reference interpreter will allocate.
///
/// A CPU-side simulation budget of 64 MiB; `workgroup_memory` rejects any
/// program whose combined workgroup buffer declarations exceed it.
pub const MAX_WORKGROUP_BYTES: usize = 64 * 1024 * 1024;
27
/// Small-N buffer lookup keyed by interned `Arc<str>` names.
///
/// Typical reference interpreter programs have ≤ 8 declared buffers. A
/// linear scan over 8 entries is branch-predicted and hits L1 cache; hashing
/// each access (as `HashMap<String, Buffer>` did) burned a SipHash-1-3 on
/// every load/store in the inner interpreter loop. This struct preserves
/// the public `get` / `get_mut` / `insert` shape consumers depend on while
/// eliminating the per-lookup hash + heap traffic.
#[derive(Debug, Default)]
pub struct BufferMap {
    // Insertion-ordered (name, buffer) pairs; inline storage for up to 8
    // entries before spilling to the heap.
    entries: SmallVec<[(Arc<str>, Buffer); 8]>,
}
40
41impl BufferMap {
42    /// Construct an empty map.
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            entries: SmallVec::new(),
47        }
48    }
49
50    /// Look up a buffer by name.
51    #[must_use]
52    pub fn get(&self, name: &str) -> Option<&Buffer> {
53        self.entries
54            .iter()
55            .find(|(key, _)| key.as_ref() == name)
56            .map(|(_, buffer)| buffer)
57    }
58
59    /// Look up a mutable buffer by name.
60    pub fn get_mut(&mut self, name: &str) -> Option<&mut Buffer> {
61        self.entries
62            .iter_mut()
63            .find(|(key, _)| key.as_ref() == name)
64            .map(|(_, buffer)| buffer)
65    }
66
67    /// Insert or overwrite a buffer. Returns the previous value when the
68    /// key already existed.
69    pub fn insert(&mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Option<Buffer> {
70        let name = name.into();
71        if let Some(entry) = self
72            .entries
73            .iter_mut()
74            .find(|(key, _)| key.as_ref() == name.as_ref())
75        {
76            return Some(std::mem::replace(&mut entry.1, buffer));
77        }
78        self.entries.push((name, buffer));
79        None
80    }
81
82    /// Iterate `(name, buffer)` pairs in insertion order.
83    pub fn iter(&self) -> impl Iterator<Item = (&str, &Buffer)> {
84        self.entries
85            .iter()
86            .map(|(name, buffer)| (name.as_ref(), buffer))
87    }
88
89    /// Move-iterate `(name, buffer)` pairs.
90    pub fn into_iter_pairs(self) -> impl Iterator<Item = (Arc<str>, Buffer)> {
91        self.entries.into_iter()
92    }
93
94    /// Number of entries.
95    #[must_use]
96    pub fn len(&self) -> usize {
97        self.entries.len()
98    }
99
100    /// True when empty.
101    #[must_use]
102    pub fn is_empty(&self) -> bool {
103        self.entries.is_empty()
104    }
105}
106
/// Identity of one compute invocation.
///
/// Holds the three builtin id vectors the interpreter exposes to a kernel;
/// `create_invocations` derives `global` as
/// `workgroup_id * workgroup_size + local_id` per axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvocationIds {
    /// Global invocation id.
    pub global: [u32; 3],
    /// Workgroup id.
    pub workgroup: [u32; 3],
    /// Local invocation id.
    pub local: [u32; 3],
}
117
118impl InvocationIds {
119    /// Zero-valued invocation ids for examples and unit tests.
120    pub const ZERO: Self = Self {
121        global: [0, 0, 0],
122        workgroup: [0, 0, 0],
123        local: [0, 0, 0],
124    };
125}
126
/// Shared execution memory for storage and current workgroup buffers.
#[derive(Debug, Default)]
pub struct Memory {
    // Storage buffers, keyed by declared name.
    pub(crate) storage: BufferMap,
    // Buffers scoped to the workgroup currently being simulated.
    pub(crate) workgroup: BufferMap,
}
133
134impl Memory {
135    /// Create empty memory for test fixtures.
136    #[must_use]
137    pub fn empty() -> Self {
138        Self::default()
139    }
140
141    /// Add a storage buffer.
142    #[must_use]
143    pub fn with_storage(mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Self {
144        self.storage.insert(name, buffer);
145        self
146    }
147
148    /// Add a workgroup buffer.
149    #[must_use]
150    pub fn with_workgroup(mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Self {
151        self.workgroup.insert(name, buffer);
152        self
153    }
154
155    /// Build a single byte payload memory used by canonical primitive evaluators.
156    #[must_use]
157    pub fn from_bytes(bytes: Vec<u8>) -> Self {
158        let mut storage = BufferMap::new();
159        storage.insert("__value", Buffer::new(bytes, vyre::ir::DataType::Bytes));
160        Self {
161            storage,
162            workgroup: BufferMap::new(),
163        }
164    }
165
166    /// Return the byte payload for canonical primitive evaluators.
167    #[must_use]
168    pub fn bytes(&self) -> Vec<u8> {
169        self.storage.get("__value").map_or_else(Vec::new, |buffer| {
170            buffer
171                .bytes
172                .read()
173                .unwrap_or_else(|error| error.into_inner())
174                .clone()
175        })
176    }
177
178    /// Consume this memory and return the byte payload for canonical primitives.
179    #[must_use]
180    pub fn into_bytes(self) -> Vec<u8> {
181        self.storage
182            .into_iter_pairs()
183            .find_map(|(name, buffer)| {
184                (name.as_ref() == "__value").then(|| {
185                    std::sync::Arc::try_unwrap(buffer.bytes)
186                        .map(|rw| rw.into_inner().unwrap_or_else(|error| error.into_inner()))
187                        .unwrap_or_else(|a| {
188                            a.read().unwrap_or_else(|error| error.into_inner()).clone()
189                        })
190                })
191            })
192            .unwrap_or_default()
193    }
194}
195
196/// Shared slot layout for all locals in one program.
197#[derive(Debug, Default)]
198pub struct LocalSlots {
199    names: rustc_hash::FxHashMap<Arc<str>, usize>,
200    slot_names: Vec<Arc<str>>,
201}
202
203impl LocalSlots {
204    /// Build a slot layout from every binding site in a program.
205    #[must_use]
206    pub fn for_program(program: &Program) -> Self {
207        Self::for_nodes(program.entry())
208    }
209
210    /// Build a slot layout from a node slice.
211    #[must_use]
212    pub fn for_nodes(nodes: &[Node]) -> Self {
213        let mut slots = Self::default();
214        for node in nodes {
215            match visit_node_preorder(&mut slots, node) {
216                Continue(()) => {}
217                ControlFlow::Break(never) => match never {},
218            }
219        }
220        slots
221    }
222
223    fn slot(&self, name: &str) -> Option<usize> {
224        self.names.get(name).copied()
225    }
226
227    fn len(&self) -> usize {
228        self.slot_names.len()
229    }
230
231    fn intern(&mut self, name: &str) {
232        if self.names.contains_key(name) {
233            return;
234        }
235        let slot = self.slot_names.len();
236        let name: Arc<str> = Arc::from(name);
237        self.slot_names.push(Arc::clone(&name));
238        self.names.insert(name, slot);
239    }
240}
241
// Expression traversal never aborts: the `Infallible` break type means every
// `visit_preorder` call statically runs to completion.
impl ExprVisitor for LocalSlots {
    type Break = Infallible;
}
245
// Node traversal that interns binding sites (`Let` names and loop variables)
// and walks the expressions attached to each node form. Child *node* lists
// (branch bodies, loop bodies, blocks, regions) are not walked here;
// NOTE(review): `visit_node_preorder` appears to recurse into them itself —
// confirm against `vyre::visit` before relying on this.
impl NodeVisitor for LocalSlots {
    type Break = Infallible;

    // `Let` is a binding site: intern the name, then scan the initializer.
    fn visit_let(
        &mut self,
        _: &Node,
        name: &vyre::ir::Ident,
        value: &Expr,
    ) -> ControlFlow<Self::Break> {
        self.intern(name);
        visit_preorder(self, value)
    }

    // Assignment targets an already-interned slot; only the value matters.
    fn visit_assign(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        value: &Expr,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, value)
    }

    // Stores bind nothing; scan both the index and the stored value.
    fn visit_store(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        index: &Expr,
        value: &Expr,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, index)?;
        visit_preorder(self, value)
    }

    // Only the condition is an expression; branch bodies are node lists.
    fn visit_if(
        &mut self,
        _: &Node,
        cond: &Expr,
        _: &[Node],
        _: &[Node],
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, cond)
    }

    // The loop variable is a binding site; both bounds are expressions.
    fn visit_loop(
        &mut self,
        _: &Node,
        var: &vyre::ir::Ident,
        from: &Expr,
        to: &Expr,
        _: &[Node],
    ) -> ControlFlow<Self::Break> {
        self.intern(var);
        visit_preorder(self, from)?;
        visit_preorder(self, to)
    }

    // Carries no expressions and binds no locals.
    fn visit_indirect_dispatch(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        _: u64,
    ) -> ControlFlow<Self::Break> {
        Continue(())
    }

    // Async transfers carry offset/size expressions but bind no locals.
    fn visit_async_load(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        _: &vyre::ir::Ident,
        offset: &Expr,
        size: &Expr,
        _: &vyre::ir::Ident,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, offset)?;
        visit_preorder(self, size)
    }

    // Mirror of `visit_async_load` for the store direction.
    fn visit_async_store(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        _: &vyre::ir::Ident,
        offset: &Expr,
        size: &Expr,
        _: &vyre::ir::Ident,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, offset)?;
        visit_preorder(self, size)
    }

    fn visit_async_wait(&mut self, _: &Node, _: &vyre::ir::Ident) -> ControlFlow<Self::Break> {
        Continue(())
    }

    // Traps evaluate an address expression; the handler ident binds nothing.
    fn visit_trap(
        &mut self,
        _: &Node,
        address: &Expr,
        _: &vyre::ir::Ident,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, address)
    }

    // The remaining node forms carry no expressions and bind no locals.
    fn visit_resume(&mut self, _: &Node, _: &vyre::ir::Ident) -> ControlFlow<Self::Break> {
        Continue(())
    }

    fn visit_return(&mut self, _: &Node) -> ControlFlow<Self::Break> {
        Continue(())
    }

    fn visit_barrier(&mut self, _: &Node) -> ControlFlow<Self::Break> {
        Continue(())
    }

    fn visit_block(&mut self, _: &Node, _: &[Node]) -> ControlFlow<Self::Break> {
        Continue(())
    }

    fn visit_region(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        _: &Option<GeneratorRef>,
        _: &[Node],
    ) -> ControlFlow<Self::Break> {
        Continue(())
    }

    fn visit_opaque_node(
        &mut self,
        _: &Node,
        _: &dyn vyre::ir::NodeExtension,
    ) -> ControlFlow<Self::Break> {
        Continue(())
    }
}
384
/// One paused or running invocation.
pub struct Invocation<'a> {
    /// Builtin ids for this invocation.
    pub ids: InvocationIds,
    // Shared slot layout; all invocations of one program reuse the same Arc.
    slots: Arc<LocalSlots>,
    // Current value per slot; `None` means the slot is not bound in scope.
    locals: Vec<Option<Value>>,
    // Parallel to `locals`: true for loop variables, which `assign` rejects.
    immutable: Vec<bool>,
    // Stack of lexical scopes, each recording the slots it bound so
    // `pop_scope` can unbind them.
    scopes: Vec<Vec<usize>>,
    // Continuation stack; an empty stack means execution has finished.
    frames: Vec<Frame<'a>>,
    /// True after `return`.
    pub returned: bool,
    /// True when paused at a barrier.
    pub waiting_at_barrier: bool,
    /// Uniform-if observations for branches that contain a barrier.
    pub uniform_checks: Vec<(usize, bool)>,
    /// Async transfers started by `AsyncLoad`/`AsyncStore` and pending
    /// observation by `AsyncWait`.
    pub(crate) pending_async: FxHashMap<Arc<str>, AsyncTransfer>,
    // Per-invocation cache of resolved op definitions, keyed by the call
    // expression's address.
    pub(crate) op_cache: FxHashMap<*const Expr, ResolvedCall>,
}
405
/// Cached resolution of a call site to its operator definition.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ResolvedCall {
    /// The resolved operator definition; `'static` keeps the cache entry `Copy`.
    pub(crate) def: &'static OpDef,
}
410
/// Interpreter continuation stack.
///
/// The top frame describes what the scheduler executes next; the stepping
/// logic itself lives outside this module (see `frames_mut`).
#[non_exhaustive]
pub enum Frame<'a> {
    /// Sequence of nodes.
    Nodes {
        /// Nodes being executed.
        nodes: &'a [Node],
        /// Next node index.
        index: usize,
        /// Whether completion pops a lexical scope.
        scoped: bool,
    },
    /// Bounded `u32` loop.
    Loop {
        /// Loop variable name.
        var: &'a str,
        /// Next induction value.
        next: u32,
        /// Exclusive upper bound.
        to: u32,
        /// Loop body.
        body: &'a [Node],
    },
}
435
436impl<'a> Invocation<'a> {
437    /// Create an invocation at the start of the entry point.
438    pub fn new(ids: InvocationIds, entry: &'a [Node]) -> Self {
439        Self::with_slots(ids, entry, Arc::new(LocalSlots::for_nodes(entry)))
440    }
441
442    pub(crate) fn with_slots(
443        ids: InvocationIds,
444        entry: &'a [Node],
445        slots: Arc<LocalSlots>,
446    ) -> Self {
447        let slot_count = slots.len();
448        Self {
449            ids,
450            slots,
451            locals: vec![None; slot_count],
452            immutable: vec![false; slot_count],
453            scopes: vec![Vec::new()],
454            frames: vec![Frame::Nodes {
455                nodes: entry,
456                index: 0,
457                scoped: false,
458            }],
459            returned: false,
460            waiting_at_barrier: false,
461            uniform_checks: Vec::new(),
462            pending_async: FxHashMap::default(),
463            op_cache: FxHashMap::default(),
464        }
465    }
466
467    /// Return true when no further execution can occur.
468    pub fn done(&self) -> bool {
469        self.returned || self.frames.is_empty()
470    }
471
472    /// Push a lexical scope.
473    ///
474    ///
475    /// ```rust,no_run
476    /// use vyre_reference::workgroup::{Invocation, InvocationIds};
477    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
478    /// invocation.push_scope();
479    /// ```
480    pub fn push_scope(&mut self) {
481        self.scopes.push(Vec::new());
482    }
483
484    /// Pop a lexical scope and remove bindings declared in it.
485    ///
486    ///
487    /// ```rust,no_run
488    /// use vyre_reference::workgroup::{Invocation, InvocationIds};
489    /// let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
490    /// invocation.pop_scope();
491    /// ```
492    pub fn pop_scope(&mut self) {
493        if let Some(names) = self.scopes.pop() {
494            for slot in names {
495                self.locals[slot] = None;
496                self.immutable[slot] = false;
497            }
498        }
499    }
500
501    pub(crate) fn begin_async(&mut self, tag: &str, transfer: AsyncTransfer) -> Result<(), Error> {
502        let tag: Arc<str> = Arc::from(tag);
503        if self.pending_async.insert(tag.clone(), transfer).is_some() {
504            return Err(Error::interp(format!(
505                "async tag `{}` was started more than once before a matching wait. \
506                 Fix: reuse the tag only after AsyncWait completes.",
507                tag
508            )));
509        }
510        Ok(())
511    }
512
513    pub(crate) fn finish_async(&mut self, tag: &str) -> Result<AsyncTransfer, Error> {
514        self.pending_async.remove(tag).ok_or_else(|| Error::interp(format!(
515            "async wait for tag `{tag}` has no matching async load. Fix: emit AsyncLoad before AsyncWait."
516        )))
517    }
518
519    /// Look up an active local by name.
520    pub fn local(&self, name: &str) -> Option<&Value> {
521        self.slots
522            .slot(name)
523            .and_then(|slot| self.locals.get(slot))
524            .and_then(Option::as_ref)
525    }
526
527    /// Bind a mutable local.
528    ///
529    ///
530    /// ```rust,no_run
531    /// use vyre_reference::{value::Value, workgroup::{Invocation, InvocationIds}};
532    /// fn main() -> Result<(), vyre_foundation::Error> {
533    ///     let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
534    ///     invocation.bind("example", Value::U32(1))?;
535    ///     Ok(())
536    /// }
537    /// ```
538    pub fn bind(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
539        let slot = self.slots.slot(name).ok_or_else(|| {
540            Error::interp(format!(
541                "local binding `{name}` has no preassigned slot. Fix: rebuild the local slot layout from the full Program before interpretation."
542            ))
543        })?;
544        if self.locals[slot].is_some() {
545            return Err(Error::interp(format!(
546                "duplicate local binding `{name}`. Fix: choose a unique local name; shadowing is not allowed."
547            )));
548        }
549        self.locals[slot] = Some(value);
550        if let Some(scope) = self.scopes.last_mut() {
551            scope.push(slot);
552        }
553        Ok(())
554    }
555
556    /// Bind an immutable loop variable.
557    ///
558    ///
559    /// ```rust,no_run
560    /// use vyre_reference::{value::Value, workgroup::{Invocation, InvocationIds}};
561    /// fn main() -> Result<(), vyre_foundation::Error> {
562    ///     let mut invocation = Invocation::new(InvocationIds::ZERO, &[]);
563    ///     invocation.bind_loop_var("example", Value::U32(1))?;
564    ///     Ok(())
565    /// }
566    /// ```
567    pub fn bind_loop_var(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
568        self.bind(name, value)?;
569        let slot = self.slots.slot(name).ok_or_else(|| {
570            Error::interp(format!(
571                "local binding `{name}` disappeared after bind. Fix: keep local slot layout immutable during interpretation."
572            ))
573        })?;
574        self.immutable[slot] = true;
575        Ok(())
576    }
577
578    /// Assign an existing mutable local.
579    pub fn assign(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
580        let slot = self.slots.slot(name).ok_or_else(|| {
581            Error::interp(format!(
582                "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
583            ))
584        })?;
585        if self.immutable[slot] {
586            return Err(Error::interp(format!(
587                "assignment to loop variable `{name}`. Fix: loop variables are immutable."
588            )));
589        }
590        let Some(local) = self.locals.get_mut(slot).and_then(Option::as_mut) else {
591            return Err(Error::interp(format!(
592                "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
593            )));
594        };
595        *local = value;
596        Ok(())
597    }
598
599    pub(crate) fn frames_mut(&mut self) -> &mut Vec<Frame<'a>> {
600        &mut self.frames
601    }
602}
603
/// Deferred byte-copy transfer for the workgroup reference scheduler.
///
/// Registered by `Invocation::begin_async` and handed back by
/// `finish_async` when the matching wait is observed.
pub(crate) enum AsyncTransfer {
    /// Copy `payload` into `destination` starting at byte offset `start`.
    Copy {
        /// Name of the buffer receiving the bytes.
        destination: Arc<str>,
        /// Byte offset within the destination at which the copy begins.
        start: usize,
        /// Bytes to write.
        payload: Vec<u8>,
    },
}
613
/// Build every invocation of one workgroup, sharing `slots` across them.
#[cfg(test)]
#[allow(dead_code)]
pub(crate) fn create_invocations(
    program: &Program,
    workgroup: [u32; 3],
    slots: Arc<LocalSlots>,
) -> Result<Vec<Invocation<'_>>, vyre::Error> {
    // Per-axis global id = workgroup id * workgroup size + local id, with
    // overflow diagnosed instead of wrapped.
    let global_component = |wgid: u32, size: u32, local: u32| {
        wgid
            .checked_mul(size)
            .and_then(|base| base.checked_add(local))
            .ok_or_else(|| Error::interp(
                "workgroup * dispatch dimensions overflow u32 global id. Fix: reduce workgroup id or workgroup size so each global_invocation_id component fits in u32.",
            ))
    };
    let [sx, sy, sz] = program.workgroup_size();
    let total = sx
        .checked_mul(sy)
        .and_then(|partial| partial.checked_mul(sz))
        .ok_or_else(|| {
            Error::interp(
                "workgroup invocation count overflows u32. Fix: reduce workgroup dimensions before reference execution.",
            )
        })?;
    let capacity = usize::try_from(total).map_err(|_| {
        Error::interp(
            "workgroup invocation count exceeds host usize. Fix: reduce workgroup dimensions before reference execution.",
        )
    })?;
    let mut invocations = Vec::with_capacity(capacity);
    for z in 0..sz {
        for y in 0..sy {
            for x in 0..sx {
                let ids = InvocationIds {
                    global: [
                        global_component(workgroup[0], sx, x)?,
                        global_component(workgroup[1], sy, y)?,
                        global_component(workgroup[2], sz, z)?,
                    ],
                    workgroup,
                    local: [x, y, z],
                };
                invocations.push(Invocation::with_slots(
                    ids,
                    program.entry(),
                    Arc::clone(&slots),
                ));
            }
        }
    }
    Ok(invocations)
}
666
/// Allocate zero-initialized buffers for every workgroup-access declaration
/// in `program`, enforcing the `MAX_WORKGROUP_BYTES` budget.
///
/// # Errors
///
/// Fails when a buffer's byte size or the running total overflows, or when
/// the total exceeds the reference budget.
#[cfg(test)]
#[allow(dead_code)]
pub(crate) fn workgroup_memory(program: &Program) -> Result<BufferMap, vyre::Error> {
    let mut workgroup = BufferMap::new();
    let mut allocated = 0usize;
    for decl in program
        .buffers()
        .iter()
        .filter(|decl| decl.access() == BufferAccess::Workgroup)
    {
        let element_size = decl.element().min_bytes();
        // Checked conversion instead of `as usize`: an `as` cast would
        // silently truncate a count wider than the host pointer size.
        let count = usize::try_from(decl.count()).map_err(|_| {
            Error::interp(format!(
                "workgroup buffer `{}` element count exceeds host usize. Fix: reduce count.",
                decl.name()
            ))
        })?;
        let len = count
            .checked_mul(element_size)
            .ok_or_else(|| Error::interp(format!(
                "workgroup buffer `{}` byte size overflows usize. Fix: reduce count or element size.",
                decl.name()
            )))?;
        allocated = allocated
            .checked_add(len)
            .ok_or_else(|| Error::interp(
                "total workgroup memory byte size overflows usize. Fix: reduce workgroup buffer declarations.",
            ))?;
        if allocated > MAX_WORKGROUP_BYTES {
            return Err(Error::interp(format!(
                "workgroup memory requires {allocated} bytes, exceeding the {MAX_WORKGROUP_BYTES}-byte reference budget. Fix: reduce workgroup buffer counts."
            )));
        }
        workgroup.insert(decl.name(), Buffer::new(vec![0; len], decl.element()));
    }
    Ok(workgroup)
}