1use std::convert::Infallible;
9use std::ops::ControlFlow::{self, Continue};
10use std::sync::Arc;
11
12use rustc_hash::FxHashMap;
13use smallvec::SmallVec;
14#[cfg(test)]
15use vyre::ir::BufferAccess;
16use vyre::ir::{Expr, Node, Program};
17use vyre::visit::{visit_node_preorder, visit_preorder, ExprVisitor, NodeVisitor};
18use vyre::OpDef;
19use vyre_foundation::ir::model::expr::GeneratorRef;
20
21use vyre::Error;
22
23use crate::{oob::Buffer, value::Value};
24
/// Upper bound, in bytes, on total workgroup memory the reference
/// interpreter will allocate per workgroup (64 MiB), enforced by
/// `workgroup_memory`. NOTE(review): this is a host-side reference budget,
/// far above typical hardware shared-memory limits — confirm it is
/// intentionally generous.
pub const MAX_WORKGROUP_BYTES: usize = 64 * 1024 * 1024;
27
/// Ordered name → buffer map backed by a small inline vector.
///
/// Lookups are linear scans; this trades O(n) access for zero heap
/// allocation while at most 8 buffers are bound.
#[derive(Debug, Default)]
pub struct BufferMap {
    // (name, buffer) pairs in insertion order; lookup returns the first match.
    entries: SmallVec<[(Arc<str>, Buffer); 8]>,
}
40
41impl BufferMap {
42 #[must_use]
44 pub fn new() -> Self {
45 Self {
46 entries: SmallVec::new(),
47 }
48 }
49
50 #[must_use]
52 pub fn get(&self, name: &str) -> Option<&Buffer> {
53 self.entries
54 .iter()
55 .find(|(key, _)| key.as_ref() == name)
56 .map(|(_, buffer)| buffer)
57 }
58
59 pub fn get_mut(&mut self, name: &str) -> Option<&mut Buffer> {
61 self.entries
62 .iter_mut()
63 .find(|(key, _)| key.as_ref() == name)
64 .map(|(_, buffer)| buffer)
65 }
66
67 pub fn insert(&mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Option<Buffer> {
70 let name = name.into();
71 if let Some(entry) = self
72 .entries
73 .iter_mut()
74 .find(|(key, _)| key.as_ref() == name.as_ref())
75 {
76 return Some(std::mem::replace(&mut entry.1, buffer));
77 }
78 self.entries.push((name, buffer));
79 None
80 }
81
82 pub fn iter(&self) -> impl Iterator<Item = (&str, &Buffer)> {
84 self.entries
85 .iter()
86 .map(|(name, buffer)| (name.as_ref(), buffer))
87 }
88
89 pub fn into_iter_pairs(self) -> impl Iterator<Item = (Arc<str>, Buffer)> {
91 self.entries.into_iter()
92 }
93
94 #[must_use]
96 pub fn len(&self) -> usize {
97 self.entries.len()
98 }
99
100 #[must_use]
102 pub fn is_empty(&self) -> bool {
103 self.entries.is_empty()
104 }
105}
106
/// The three invocation-id triples identifying one invocation, mirroring
/// the GPU compute model (see `create_invocations` for how `global` is
/// derived from `workgroup` and `local`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvocationIds {
    /// Global invocation id (workgroup id × workgroup size + local id).
    pub global: [u32; 3],
    /// Id of the enclosing workgroup.
    pub workgroup: [u32; 3],
    /// Local invocation id within the workgroup.
    pub local: [u32; 3],
}

impl InvocationIds {
    /// All three id triples zeroed — the first invocation of the first
    /// workgroup.
    pub const ZERO: Self = Self {
        global: [0; 3],
        workgroup: [0; 3],
        local: [0; 3],
    };
}
126
/// The two buffer address spaces visible to one interpreter run.
#[derive(Debug, Default)]
pub struct Memory {
    // Storage-space buffers, addressable by name.
    pub(crate) storage: BufferMap,
    // Workgroup-shared buffers, addressable by name.
    pub(crate) workgroup: BufferMap,
}
133
134impl Memory {
135 #[must_use]
137 pub fn empty() -> Self {
138 Self::default()
139 }
140
141 #[must_use]
143 pub fn with_storage(mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Self {
144 self.storage.insert(name, buffer);
145 self
146 }
147
148 #[must_use]
150 pub fn with_workgroup(mut self, name: impl Into<Arc<str>>, buffer: Buffer) -> Self {
151 self.workgroup.insert(name, buffer);
152 self
153 }
154
155 #[must_use]
157 pub fn from_bytes(bytes: Vec<u8>) -> Self {
158 let mut storage = BufferMap::new();
159 storage.insert("__value", Buffer::new(bytes, vyre::ir::DataType::Bytes));
160 Self {
161 storage,
162 workgroup: BufferMap::new(),
163 }
164 }
165
166 #[must_use]
168 pub fn bytes(&self) -> Vec<u8> {
169 self.storage.get("__value").map_or_else(Vec::new, |buffer| {
170 buffer
171 .bytes
172 .read()
173 .unwrap_or_else(|error| error.into_inner())
174 .clone()
175 })
176 }
177
178 #[must_use]
180 pub fn into_bytes(self) -> Vec<u8> {
181 self.storage
182 .into_iter_pairs()
183 .find_map(|(name, buffer)| {
184 (name.as_ref() == "__value").then(|| {
185 std::sync::Arc::try_unwrap(buffer.bytes)
186 .map(|rw| rw.into_inner().unwrap_or_else(|error| error.into_inner()))
187 .unwrap_or_else(|a| {
188 a.read().unwrap_or_else(|error| error.into_inner()).clone()
189 })
190 })
191 })
192 .unwrap_or_default()
193 }
194}
195
196#[derive(Debug, Default)]
198pub struct LocalSlots {
199 names: rustc_hash::FxHashMap<Arc<str>, usize>,
200 slot_names: Vec<Arc<str>>,
201}
202
203impl LocalSlots {
204 #[must_use]
206 pub fn for_program(program: &Program) -> Self {
207 Self::for_nodes(program.entry())
208 }
209
210 #[must_use]
212 pub fn for_nodes(nodes: &[Node]) -> Self {
213 let mut slots = Self::default();
214 for node in nodes {
215 match visit_node_preorder(&mut slots, node) {
216 Continue(()) => {}
217 ControlFlow::Break(never) => match never {},
218 }
219 }
220 slots
221 }
222
223 fn slot(&self, name: &str) -> Option<usize> {
224 self.names.get(name).copied()
225 }
226
227 fn len(&self) -> usize {
228 self.slot_names.len()
229 }
230
231 fn intern(&mut self, name: &str) {
232 if self.names.contains_key(name) {
233 return;
234 }
235 let slot = self.slot_names.len();
236 let name: Arc<str> = Arc::from(name);
237 self.slot_names.push(Arc::clone(&name));
238 self.names.insert(name, slot);
239 }
240}
241
// Expression traversal collects nothing: no methods are overridden, so the
// trait's defaults apply, and `Break = Infallible` means the walk can never
// be cut short.
impl ExprVisitor for LocalSlots {
    type Break = Infallible;
}
245
// Node-level collection pass: interns every name introduced by `Let` and
// `Loop` and walks the expressions each node contains. Child node lists
// (if branches, loop bodies, blocks, regions) are not walked here —
// NOTE(review): presumably `visit_node_preorder` itself recurses into child
// nodes, so walking them here would double-visit; confirm against the
// `vyre::visit` implementation. The visit order is significant: it fixes
// the slot index assigned to each name.
impl NodeVisitor for LocalSlots {
    type Break = Infallible;

    // `let name = value` — intern the bound name, then scan the value
    // expression.
    fn visit_let(
        &mut self,
        _: &Node,
        name: &vyre::ir::Ident,
        value: &Expr,
    ) -> ControlFlow<Self::Break> {
        self.intern(name);
        visit_preorder(self, value)
    }

    // Assignment targets an existing binding, so only the value expression
    // is scanned; no new name is interned.
    fn visit_assign(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        value: &Expr,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, value)
    }

    // Buffer store: scan both the index and the stored value.
    fn visit_store(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        index: &Expr,
        value: &Expr,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, index)?;
        visit_preorder(self, value)
    }

    // Conditional: only the condition expression is scanned here (branch
    // node lists are handled by the outer node walk — see impl comment).
    fn visit_if(
        &mut self,
        _: &Node,
        cond: &Expr,
        _: &[Node],
        _: &[Node],
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, cond)
    }

    // Counted loop: the loop variable gets a slot, and both bound
    // expressions are scanned.
    fn visit_loop(
        &mut self,
        _: &Node,
        var: &vyre::ir::Ident,
        from: &Expr,
        to: &Expr,
        _: &[Node],
    ) -> ControlFlow<Self::Break> {
        self.intern(var);
        visit_preorder(self, from)?;
        visit_preorder(self, to)
    }

    // No expressions and no bindings: nothing to collect.
    fn visit_indirect_dispatch(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        _: u64,
    ) -> ControlFlow<Self::Break> {
        Continue(())
    }

    // Async load: scan the offset and size expressions; the buffer/tag
    // identifiers are not locals.
    fn visit_async_load(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        _: &vyre::ir::Ident,
        offset: &Expr,
        size: &Expr,
        _: &vyre::ir::Ident,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, offset)?;
        visit_preorder(self, size)
    }

    // Async store: same shape as async load.
    fn visit_async_store(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        _: &vyre::ir::Ident,
        offset: &Expr,
        size: &Expr,
        _: &vyre::ir::Ident,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, offset)?;
        visit_preorder(self, size)
    }

    fn visit_async_wait(&mut self, _: &Node, _: &vyre::ir::Ident) -> ControlFlow<Self::Break> {
        Continue(())
    }

    // Trap: scan only the address expression.
    fn visit_trap(
        &mut self,
        _: &Node,
        address: &Expr,
        _: &vyre::ir::Ident,
    ) -> ControlFlow<Self::Break> {
        visit_preorder(self, address)
    }

    fn visit_resume(&mut self, _: &Node, _: &vyre::ir::Ident) -> ControlFlow<Self::Break> {
        Continue(())
    }

    fn visit_return(&mut self, _: &Node) -> ControlFlow<Self::Break> {
        Continue(())
    }

    fn visit_barrier(&mut self, _: &Node) -> ControlFlow<Self::Break> {
        Continue(())
    }

    // Child nodes handled by the outer walk (see impl comment).
    fn visit_block(&mut self, _: &Node, _: &[Node]) -> ControlFlow<Self::Break> {
        Continue(())
    }

    // Child nodes handled by the outer walk (see impl comment).
    fn visit_region(
        &mut self,
        _: &Node,
        _: &vyre::ir::Ident,
        _: &Option<GeneratorRef>,
        _: &[Node],
    ) -> ControlFlow<Self::Break> {
        Continue(())
    }

    // Opaque extension nodes cannot declare locals this pass knows about.
    fn visit_opaque_node(
        &mut self,
        _: &Node,
        _: &dyn vyre::ir::NodeExtension,
    ) -> ControlFlow<Self::Break> {
        Continue(())
    }
}
384
/// Interpreter state for a single invocation; `'a` borrows the program's
/// node tree.
pub struct Invocation<'a> {
    /// Global / workgroup / local ids identifying this invocation.
    pub ids: InvocationIds,
    // Shared, immutable slot layout precomputed from the program.
    slots: Arc<LocalSlots>,
    // Current value per slot; `None` = unbound.
    locals: Vec<Option<Value>>,
    // Per-slot flag; set for loop variables, which reject `assign`.
    immutable: Vec<bool>,
    // Scope stack: each entry lists the slots bound in that scope, so
    // `pop_scope` can unbind them.
    scopes: Vec<Vec<usize>>,
    // Execution stack of node sequences / active loops still to run.
    frames: Vec<Frame<'a>>,
    /// Set once a Return has executed; the invocation is finished.
    pub returned: bool,
    /// Set while the invocation is parked at a barrier.
    pub waiting_at_barrier: bool,
    /// Recorded (id, flag) pairs. NOTE(review): presumably per-site branch
    /// outcomes for cross-invocation uniformity checking — confirm against
    /// the interpreter driver.
    pub uniform_checks: Vec<(usize, bool)>,
    // In-flight async transfers keyed by tag, consumed by `finish_async`.
    pub(crate) pending_async: FxHashMap<Arc<str>, AsyncTransfer>,
    // Per-call-site cache of resolved op definitions, keyed by the Expr's
    // address (valid because the node tree is borrowed and immobile for 'a).
    pub(crate) op_cache: FxHashMap<*const Expr, ResolvedCall>,
}
405
/// Cached result of resolving a call expression to its operator definition.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ResolvedCall {
    // `'static` — NOTE(review): presumably op definitions live in a global
    // registry; confirm against `vyre::OpDef`'s provenance.
    pub(crate) def: &'static OpDef,
}
410
/// One entry on an invocation's execution stack.
#[non_exhaustive]
pub enum Frame<'a> {
    /// A sequence of nodes being executed in order.
    Nodes {
        /// The node slice being run.
        nodes: &'a [Node],
        /// Index of the next node to execute within `nodes`.
        index: usize,
        /// True when this frame opened a variable scope that must be
        /// popped when the frame finishes.
        scoped: bool,
    },
    /// An active counted loop.
    Loop {
        /// Loop variable name (bound immutably each iteration).
        var: &'a str,
        /// Value the loop variable takes on the next iteration.
        next: u32,
        /// Loop bound. NOTE(review): whether `to` is inclusive or exclusive
        /// is decided by the interpreter driver, not visible here — confirm.
        to: u32,
        /// Loop body nodes, re-entered each iteration.
        body: &'a [Node],
    },
}
435
436impl<'a> Invocation<'a> {
437 pub fn new(ids: InvocationIds, entry: &'a [Node]) -> Self {
439 Self::with_slots(ids, entry, Arc::new(LocalSlots::for_nodes(entry)))
440 }
441
442 pub(crate) fn with_slots(
443 ids: InvocationIds,
444 entry: &'a [Node],
445 slots: Arc<LocalSlots>,
446 ) -> Self {
447 let slot_count = slots.len();
448 Self {
449 ids,
450 slots,
451 locals: vec![None; slot_count],
452 immutable: vec![false; slot_count],
453 scopes: vec![Vec::new()],
454 frames: vec![Frame::Nodes {
455 nodes: entry,
456 index: 0,
457 scoped: false,
458 }],
459 returned: false,
460 waiting_at_barrier: false,
461 uniform_checks: Vec::new(),
462 pending_async: FxHashMap::default(),
463 op_cache: FxHashMap::default(),
464 }
465 }
466
467 pub fn done(&self) -> bool {
469 self.returned || self.frames.is_empty()
470 }
471
472 pub fn push_scope(&mut self) {
481 self.scopes.push(Vec::new());
482 }
483
484 pub fn pop_scope(&mut self) {
493 if let Some(names) = self.scopes.pop() {
494 for slot in names {
495 self.locals[slot] = None;
496 self.immutable[slot] = false;
497 }
498 }
499 }
500
501 pub(crate) fn begin_async(&mut self, tag: &str, transfer: AsyncTransfer) -> Result<(), Error> {
502 let tag: Arc<str> = Arc::from(tag);
503 if self.pending_async.insert(tag.clone(), transfer).is_some() {
504 return Err(Error::interp(format!(
505 "async tag `{}` was started more than once before a matching wait. \
506 Fix: reuse the tag only after AsyncWait completes.",
507 tag
508 )));
509 }
510 Ok(())
511 }
512
513 pub(crate) fn finish_async(&mut self, tag: &str) -> Result<AsyncTransfer, Error> {
514 self.pending_async.remove(tag).ok_or_else(|| Error::interp(format!(
515 "async wait for tag `{tag}` has no matching async load. Fix: emit AsyncLoad before AsyncWait."
516 )))
517 }
518
519 pub fn local(&self, name: &str) -> Option<&Value> {
521 self.slots
522 .slot(name)
523 .and_then(|slot| self.locals.get(slot))
524 .and_then(Option::as_ref)
525 }
526
527 pub fn bind(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
539 let slot = self.slots.slot(name).ok_or_else(|| {
540 Error::interp(format!(
541 "local binding `{name}` has no preassigned slot. Fix: rebuild the local slot layout from the full Program before interpretation."
542 ))
543 })?;
544 if self.locals[slot].is_some() {
545 return Err(Error::interp(format!(
546 "duplicate local binding `{name}`. Fix: choose a unique local name; shadowing is not allowed."
547 )));
548 }
549 self.locals[slot] = Some(value);
550 if let Some(scope) = self.scopes.last_mut() {
551 scope.push(slot);
552 }
553 Ok(())
554 }
555
556 pub fn bind_loop_var(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
568 self.bind(name, value)?;
569 let slot = self.slots.slot(name).ok_or_else(|| {
570 Error::interp(format!(
571 "local binding `{name}` disappeared after bind. Fix: keep local slot layout immutable during interpretation."
572 ))
573 })?;
574 self.immutable[slot] = true;
575 Ok(())
576 }
577
578 pub fn assign(&mut self, name: &str, value: Value) -> Result<(), vyre::Error> {
580 let slot = self.slots.slot(name).ok_or_else(|| {
581 Error::interp(format!(
582 "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
583 ))
584 })?;
585 if self.immutable[slot] {
586 return Err(Error::interp(format!(
587 "assignment to loop variable `{name}`. Fix: loop variables are immutable."
588 )));
589 }
590 let Some(local) = self.locals.get_mut(slot).and_then(Option::as_mut) else {
591 return Err(Error::interp(format!(
592 "assignment to undeclared variable `{name}`. Fix: add a Let before assigning it."
593 )));
594 };
595 *local = value;
596 Ok(())
597 }
598
599 pub(crate) fn frames_mut(&mut self) -> &mut Vec<Frame<'a>> {
600 &mut self.frames
601 }
602}
603
/// A pending asynchronous memory transfer, registered by `begin_async` and
/// consumed by `finish_async`.
pub(crate) enum AsyncTransfer {
    /// Bytes staged for copying into a named buffer.
    Copy {
        /// Name of the buffer to write into.
        destination: Arc<str>,
        /// Offset in the destination at which `payload` lands —
        /// NOTE(review): presumably a byte offset; confirm against the
        /// interpreter's AsyncWait handling.
        start: usize,
        /// The bytes to write.
        payload: Vec<u8>,
    },
}
613
#[cfg(test)]
#[allow(dead_code)]
/// Builds one [`Invocation`] per point of the program's workgroup size for
/// the given `workgroup` id, all sharing `slots`. Iteration order is
/// z-outermost, x-innermost (x varies fastest).
///
/// # Errors
/// Fails when any global-id component or the invocation count overflows.
pub(crate) fn create_invocations(
    program: &Program,
    workgroup: [u32; 3],
    slots: Arc<LocalSlots>,
) -> Result<Vec<Invocation<'_>>, vyre::Error> {
    // global = workgroup_id * workgroup_size + local_id, checked per component.
    let global_dim = |wgid: u32, size: u32, local: u32| {
        wgid.checked_mul(size)
            .and_then(|base| base.checked_add(local))
            .ok_or_else(|| {
                Error::interp(
                    "workgroup * dispatch dimensions overflow u32 global id. Fix: reduce workgroup id or workgroup size so each global_invocation_id component fits in u32.",
                )
            })
    };
    let [sx, sy, sz] = program.workgroup_size();
    let invocation_count = sx
        .checked_mul(sy)
        .and_then(|count| count.checked_mul(sz))
        .ok_or_else(|| {
            Error::interp(
                "workgroup invocation count overflows u32. Fix: reduce workgroup dimensions before reference execution.",
            )
        })?;
    let capacity = usize::try_from(invocation_count).map_err(|_| {
        Error::interp(
            "workgroup invocation count exceeds host usize. Fix: reduce workgroup dimensions before reference execution.",
        )
    })?;
    let mut invocations = Vec::with_capacity(capacity);
    for z in 0..sz {
        for y in 0..sy {
            for x in 0..sx {
                let ids = InvocationIds {
                    global: [
                        global_dim(workgroup[0], sx, x)?,
                        global_dim(workgroup[1], sy, y)?,
                        global_dim(workgroup[2], sz, z)?,
                    ],
                    workgroup,
                    local: [x, y, z],
                };
                invocations.push(Invocation::with_slots(
                    ids,
                    program.entry(),
                    Arc::clone(&slots),
                ));
            }
        }
    }
    Ok(invocations)
}
666
#[cfg(test)]
#[allow(dead_code)]
/// Allocates a zero-filled buffer for every workgroup-space declaration in
/// `program`, enforcing the [`MAX_WORKGROUP_BYTES`] budget.
///
/// # Errors
/// Fails when any buffer's byte size or the running total overflows
/// `usize`, or when the total exceeds the reference budget.
pub(crate) fn workgroup_memory(program: &Program) -> Result<BufferMap, vyre::Error> {
    let mut workgroup = BufferMap::new();
    let mut allocated = 0usize;
    for decl in program
        .buffers()
        .iter()
        .filter(|decl| decl.access() == BufferAccess::Workgroup)
    {
        let element_size = decl.element().min_bytes();
        // Bug fix: `decl.count() as usize` would silently truncate a count
        // wider than usize (e.g. a u64 count on a 32-bit host); a checked
        // conversion reports it as the existing overflow error instead.
        let count = usize::try_from(decl.count()).map_err(|_| {
            Error::interp(format!(
                "workgroup buffer `{}` byte size overflows usize. Fix: reduce count or element size.",
                decl.name()
            ))
        })?;
        let len = count.checked_mul(element_size).ok_or_else(|| {
            Error::interp(format!(
                "workgroup buffer `{}` byte size overflows usize. Fix: reduce count or element size.",
                decl.name()
            ))
        })?;
        allocated = allocated.checked_add(len).ok_or_else(|| {
            Error::interp(
                "total workgroup memory byte size overflows usize. Fix: reduce workgroup buffer declarations.",
            )
        })?;
        if allocated > MAX_WORKGROUP_BYTES {
            return Err(Error::interp(format!(
                "workgroup memory requires {allocated} bytes, exceeding the {MAX_WORKGROUP_BYTES}-byte reference budget. Fix: reduce workgroup buffer counts."
            )));
        }
        workgroup.insert(decl.name(), Buffer::new(vec![0; len], decl.element()));
    }
    Ok(workgroup)
}