1use vyre::ir::{Expr, Node, Program};
10
11use crate::{
12 execution::expr as eval_expr,
13 oob,
14 workgroup::{AsyncTransfer, Frame, Invocation, Memory},
15};
16use vyre::Error;
17
18pub fn step<'a>(
25 invocation: &mut Invocation<'a>,
26 memory: &mut Memory,
27 program: &'a Program,
28) -> Result<(), vyre::Error> {
29 if invocation.done() || invocation.waiting_at_barrier {
30 return Ok(());
31 }
32
33 loop {
34 let Some(frame) = invocation.frames_mut().pop() else {
35 return Ok(());
36 };
37 match frame {
38 Frame::Nodes {
39 nodes,
40 index,
41 scoped,
42 } => {
43 if step_nodes_frame(invocation, memory, program, nodes, index, scoped)? {
44 return Ok(());
45 }
46 }
47 Frame::Loop {
48 var,
49 next,
50 to,
51 body,
52 } => step_loop_frame(invocation, var, next, to, body)?,
53 }
54 }
55}
56
57fn step_nodes_frame<'a>(
58 invocation: &mut Invocation<'a>,
59 memory: &mut Memory,
60 program: &'a Program,
61 nodes: &'a [Node],
62 index: usize,
63 scoped: bool,
64) -> Result<bool, vyre::Error> {
65 if index >= nodes.len() {
66 if scoped {
67 invocation.pop_scope();
68 }
69 return Ok(false);
70 }
71
72 invocation.frames_mut().push(Frame::Nodes {
73 nodes,
74 index: index + 1,
75 scoped,
76 });
77 execute_node(&nodes[index], invocation, memory, program)?;
78 Ok(true)
79}
80
81fn step_loop_frame<'a>(
82 invocation: &mut Invocation<'a>,
83 var: &'a str,
84 next: u32,
85 to: u32,
86 body: &'a [Node],
87) -> Result<(), vyre::Error> {
88 if next >= to {
89 return Ok(());
90 }
91 invocation.frames_mut().push(Frame::Loop {
92 var,
93 next: next.wrapping_add(1),
94 to,
95 body,
96 });
97 invocation.push_scope();
98 invocation.bind_loop_var(var, crate::value::Value::U32(next))?;
99 invocation.frames_mut().push(Frame::Nodes {
100 nodes: body,
101 index: 0,
102 scoped: true,
103 });
104 Ok(())
105}
106
107fn execute_node<'a>(
108 node: &'a Node,
109 invocation: &mut Invocation<'a>,
110 memory: &mut Memory,
111 program: &'a Program,
112) -> Result<(), vyre::Error> {
113 match node {
114 Node::Let { name, value } => eval_let(name, value, invocation, memory, program),
115 Node::Assign { name, value } => eval_assign(name, value, invocation, memory, program),
116 Node::Store {
117 buffer,
118 index,
119 value,
120 } => eval_store(buffer, index, value, invocation, memory, program),
121 Node::If {
122 cond,
123 then,
124 otherwise,
125 } => eval_if(cond, then, otherwise, node, invocation, memory, program),
126 Node::Loop {
127 var,
128 from,
129 to,
130 body,
131 } => eval_loop(var, from, to, body, invocation, memory, program),
132 Node::Return => eval_return(invocation),
133 Node::Block(nodes) => eval_block(nodes, invocation),
134 Node::Barrier { .. } => eval_barrier(invocation),
135 Node::IndirectDispatch {
136 count_buffer,
137 count_offset,
138 } => eval_indirect_dispatch(count_buffer, *count_offset, memory, program),
139 Node::AsyncLoad {
140 source,
141 destination,
142 offset,
143 size,
144 tag,
145 } => eval_async_load(
146 AsyncLoadEval {
147 source,
148 destination,
149 offset,
150 size,
151 tag,
152 },
153 invocation,
154 memory,
155 program,
156 ),
157 Node::AsyncStore {
158 source,
159 destination,
160 offset,
161 size,
162 tag,
163 } => eval_async_store(
164 AsyncStoreEval {
165 source,
166 destination,
167 offset,
168 size,
169 tag,
170 },
171 invocation,
172 memory,
173 program,
174 ),
175 Node::AsyncWait { tag } => eval_async_wait(tag, invocation, memory, program),
176 Node::Trap { address, tag } => {
177 let address = eval_expr::eval(address, invocation, memory, program)?
178 .try_as_u32()
179 .ok_or_else(|| {
180 Error::interp(format!(
181 "reference trap `{tag}` address is not a u32. Fix: pass a scalar u32 trap address."
182 ))
183 })?;
184 Err(vyre::Error::interp(format!(
185 "reference dispatch trapped: address={address}, tag=`{tag}`. Fix: handle the trap condition or route this Program through a backend/runtime with replay support."
186 )))
187 }
188 Node::Resume { tag } => Err(vyre::Error::interp(format!(
189 "reference dispatch reached Resume `{tag}` without a replay runtime. Fix: lower Resume through a runtime-owned replay path before reference execution."
190 ))),
191 Node::Region { body, .. } => eval_block(body, invocation),
192 Node::Opaque(extension) => Err(vyre::Error::interp(format!(
193 "reference interpreter does not support opaque node extension `{}`/`{}`. Fix: provide a reference evaluator for this NodeExtension or lower it to core Node variants before evaluation.",
194 extension.extension_kind(),
195 extension.debug_identity()
196 ))),
197 _ => Err(vyre::Error::interp(
198 "reference interpreter encountered an unknown Node variant. Fix: update vyre-reference before executing this IR.",
199 )),
200 }
201}
202
203fn eval_let(
204 name: &str,
205 value: &Expr,
206 invocation: &mut Invocation<'_>,
207 memory: &mut Memory,
208 program: &Program,
209) -> Result<(), vyre::Error> {
210 let value = eval_expr::eval(value, invocation, memory, program)?;
211 invocation.bind(name, value)
212}
213
214fn eval_assign(
215 name: &str,
216 value: &Expr,
217 invocation: &mut Invocation<'_>,
218 memory: &mut Memory,
219 program: &Program,
220) -> Result<(), vyre::Error> {
221 let value = eval_expr::eval(value, invocation, memory, program)?;
222 invocation.assign(name, value)
223}
224
225fn eval_store(
226 buffer: &str,
227 index: &Expr,
228 value: &Expr,
229 invocation: &mut Invocation<'_>,
230 memory: &mut Memory,
231 program: &Program,
232) -> Result<(), vyre::Error> {
233 let index = eval_expr::eval(index, invocation, memory, program)?;
234 let index = index
235 .try_as_u32()
236 .ok_or_else(|| Error::interp(format!(
237 "store index {index:?} cannot be represented as u32. Fix: use a non-negative scalar index within u32."
238 )))?;
239 let value = eval_expr::eval(value, invocation, memory, program)?;
240 let target = eval_expr::buffer_mut(memory, program, buffer)?;
241 oob::store(target, index, &value);
242 Ok(())
243}
244
245fn eval_indirect_dispatch(
246 count_buffer: &str,
247 count_offset: u64,
248 memory: &Memory,
249 program: &Program,
250) -> Result<(), vyre::Error> {
251 if count_offset % 4 != 0 {
252 return Err(Error::interp(format!(
253 "indirect dispatch offset {count_offset} is not 4-byte aligned. Fix: use a u32-aligned dispatch tuple."
254 )));
255 }
256 let decl = program.buffer(count_buffer).ok_or_else(|| {
257 Error::interp(format!(
258 "indirect dispatch references unknown buffer `{count_buffer}`. Fix: declare the count buffer before execution."
259 ))
260 })?;
261 let buffer = if decl.access() == vyre::ir::BufferAccess::Workgroup {
262 memory.workgroup.get(count_buffer)
263 } else {
264 memory.storage.get(count_buffer)
265 }
266 .ok_or_else(|| {
267 Error::interp(format!(
268 "indirect dispatch buffer `{count_buffer}` is missing. Fix: initialize the count buffer before execution."
269 ))
270 })?;
271 let required_end = count_offset.checked_add(12).ok_or_else(|| {
272 Error::interp(
273 "indirect dispatch byte range overflowed u64. Fix: shrink the count offset."
274 .to_string(),
275 )
276 })?;
277 let byte_len = buffer
278 .bytes
279 .read()
280 .map_err(|_| {
281 Error::interp(format!(
282 "indirect dispatch buffer `{count_buffer}` lock is poisoned. Fix: rebuild the interpreter memory state before execution."
283 ))
284 })?
285 .len();
286 if u64::try_from(byte_len).unwrap_or(u64::MAX) < required_end {
287 return Err(Error::interp(format!(
288 "indirect dispatch buffer `{count_buffer}` is too short for a 3-word dispatch tuple at byte offset {count_offset}. Fix: provide 12 readable bytes starting at that offset."
289 )));
290 }
291 Ok(())
292}
293
/// Borrowed operands of an `AsyncLoad` node, bundled for `eval_async_load`.
struct AsyncLoadEval<'a> {
    /// Name of the buffer the payload is read from.
    source: &'a str,
    /// Name of the buffer the payload is later copied into.
    destination: &'a str,
    /// Expression evaluated as the byte offset into `source`.
    offset: &'a Expr,
    /// Expression evaluated as the number of bytes to transfer.
    size: &'a Expr,
    /// Tag the pending transfer is registered under.
    tag: &'a str,
}
301
/// Borrowed operands of an `AsyncStore` node, bundled for `eval_async_store`.
struct AsyncStoreEval<'a> {
    /// Name of the buffer the payload is read from, starting at byte 0.
    source: &'a str,
    /// Name of the buffer the payload is later copied into.
    destination: &'a str,
    /// Expression evaluated as the byte offset into `destination`.
    offset: &'a Expr,
    /// Expression evaluated as the number of bytes to transfer.
    size: &'a Expr,
    /// Tag the pending transfer is registered under.
    tag: &'a str,
}
309
310fn eval_async_load(
311 request: AsyncLoadEval<'_>,
312 invocation: &mut Invocation<'_>,
313 memory: &mut Memory,
314 program: &Program,
315) -> Result<(), vyre::Error> {
316 let start = eval_byte_count(
317 request.offset,
318 "async load source offset",
319 invocation,
320 memory,
321 program,
322 )?;
323 let byte_count = eval_byte_count(request.size, "async load size", invocation, memory, program)?;
324 let payload = read_bytes(memory, program, request.source, start, byte_count)?;
325 ensure_writable_buffer(memory, program, request.destination)?;
326 invocation.begin_async(
327 request.tag,
328 AsyncTransfer::Copy {
329 destination: request.destination.into(),
330 start: 0,
331 payload,
332 },
333 )
334}
335
336fn eval_async_store(
337 request: AsyncStoreEval<'_>,
338 invocation: &mut Invocation<'_>,
339 memory: &mut Memory,
340 program: &Program,
341) -> Result<(), vyre::Error> {
342 let start = eval_byte_count(
343 request.offset,
344 "async store destination offset",
345 invocation,
346 memory,
347 program,
348 )?;
349 let byte_count = eval_byte_count(
350 request.size,
351 "async store size",
352 invocation,
353 memory,
354 program,
355 )?;
356 let payload = read_bytes(memory, program, request.source, 0, byte_count)?;
357 ensure_writable_buffer(memory, program, request.destination)?;
358 invocation.begin_async(
359 request.tag,
360 AsyncTransfer::Copy {
361 destination: request.destination.into(),
362 start,
363 payload,
364 },
365 )
366}
367
368fn eval_async_wait(
369 tag: &str,
370 invocation: &mut Invocation<'_>,
371 memory: &mut Memory,
372 program: &Program,
373) -> Result<(), vyre::Error> {
374 apply_async_transfer(invocation.finish_async(tag)?, memory, program)
375}
376
377fn eval_byte_count(
378 expr: &Expr,
379 label: &str,
380 invocation: &mut Invocation<'_>,
381 memory: &mut Memory,
382 program: &Program,
383) -> Result<usize, Error> {
384 let value = eval_expr::eval(expr, invocation, memory, program)?;
385 usize::try_from(value.try_as_u64().ok_or_else(|| {
386 Error::interp(format!(
387 "{label} cannot be represented as u64. Fix: use an in-range non-negative byte count."
388 ))
389 })?)
390 .map_err(|_| {
391 Error::interp(format!(
392 "{label} exceeds host usize. Fix: reduce the async transfer span."
393 ))
394 })
395}
396
397fn read_bytes(
398 memory: &Memory,
399 program: &Program,
400 source: &str,
401 start: usize,
402 byte_count: usize,
403) -> Result<Vec<u8>, Error> {
404 let buffer = resolve_buffer(memory, program, source)?;
405 let bytes = buffer
406 .bytes
407 .read()
408 .unwrap_or_else(|error| error.into_inner());
409 let mut payload = vec![0; byte_count];
410 if start < bytes.len() {
411 let available = (bytes.len() - start).min(byte_count);
412 payload[..available].copy_from_slice(&bytes[start..start + available]);
413 }
414 Ok(payload)
415}
416
417fn ensure_writable_buffer(memory: &mut Memory, program: &Program, name: &str) -> Result<(), Error> {
418 eval_expr::buffer_mut(memory, program, name).map(|_| ())
419}
420
421fn apply_async_transfer(
422 transfer: AsyncTransfer,
423 memory: &mut Memory,
424 program: &Program,
425) -> Result<(), Error> {
426 match transfer {
427 AsyncTransfer::Copy {
428 destination,
429 start,
430 payload,
431 } => {
432 let buffer = eval_expr::buffer_mut(memory, program, &destination)?;
433 let mut bytes = buffer
434 .bytes
435 .write()
436 .unwrap_or_else(|error| error.into_inner());
437 if start >= bytes.len() {
438 return Ok(());
439 }
440 let write_len = payload.len().min(bytes.len() - start);
441 bytes[start..start + write_len].copy_from_slice(&payload[..write_len]);
442 Ok(())
443 }
444 }
445}
446
447fn resolve_buffer<'a>(
448 memory: &'a Memory,
449 program: &Program,
450 name: &str,
451) -> Result<&'a oob::Buffer, Error> {
452 let decl = program.buffer(name).ok_or_else(|| {
453 Error::interp(format!(
454 "missing buffer declaration `{name}`. Fix: declare every async transfer buffer."
455 ))
456 })?;
457 if decl.access() == vyre::ir::BufferAccess::Workgroup {
458 memory.workgroup.get(name)
459 } else {
460 memory.storage.get(name)
461 }
462 .ok_or_else(|| {
463 Error::interp(format!(
464 "missing buffer `{name}`. Fix: initialize every declared async transfer buffer."
465 ))
466 })
467}
468
469fn eval_if<'a>(
470 cond: &Expr,
471 then: &'a [Node],
472 otherwise: &'a [Node],
473 node: &Node,
474 invocation: &mut Invocation<'a>,
475 memory: &mut Memory,
476 program: &Program,
477) -> Result<(), vyre::Error> {
478 let cond_value = eval_expr::eval(cond, invocation, memory, program)?.truthy();
479 if contains_barrier(then) || contains_barrier(otherwise) {
480 invocation.uniform_checks.push((node_id(node), cond_value));
481 }
482 let branch = if cond_value { then } else { otherwise };
483 invocation.push_scope();
484 invocation.frames_mut().push(Frame::Nodes {
485 nodes: branch,
486 index: 0,
487 scoped: true,
488 });
489 Ok(())
490}
491
492fn eval_loop<'a>(
493 var: &'a str,
494 from: &Expr,
495 to: &Expr,
496 body: &'a [Node],
497 invocation: &mut Invocation<'a>,
498 memory: &mut Memory,
499 program: &Program,
500) -> Result<(), vyre::Error> {
501 let from_value = eval_expr::eval(from, invocation, memory, program)?;
502 let to_value = eval_expr::eval(to, invocation, memory, program)?;
503 let from = from_value.try_as_u32().ok_or_else(|| {
504 Error::interp(format!(
505 "loop lower bound {from_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
506 ))
507 })?;
508 let to = to_value.try_as_u32().ok_or_else(|| Error::interp(format!(
509 "loop upper bound {to_value:?} cannot be represented as u32. Fix: use an in-range unsigned loop bound."
510 )))?;
511 invocation.frames_mut().push(Frame::Loop {
512 var,
513 next: from,
514 to,
515 body,
516 });
517 Ok(())
518}
519
520fn eval_return(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
521 invocation.frames_mut().clear();
522 invocation.returned = true;
523 Ok(())
524}
525
526fn eval_block<'a>(nodes: &'a [Node], invocation: &mut Invocation<'a>) -> Result<(), vyre::Error> {
527 invocation.push_scope();
528 invocation.frames_mut().push(Frame::Nodes {
529 nodes,
530 index: 0,
531 scoped: true,
532 });
533 Ok(())
534}
535
536fn eval_barrier(invocation: &mut Invocation<'_>) -> Result<(), vyre::Error> {
537 invocation.waiting_at_barrier = true;
538 Ok(())
539}
540
541fn contains_barrier(nodes: &[Node]) -> bool {
544 nodes.iter().any(node_contains_barrier)
545}
546
547fn node_contains_barrier(node: &Node) -> bool {
548 match node {
549 Node::Barrier { .. } => true,
550 Node::Let { .. }
551 | Node::Assign { .. }
552 | Node::Store { .. }
553 | Node::Return
554 | Node::IndirectDispatch { .. }
555 | Node::AsyncLoad { .. }
556 | Node::AsyncStore { .. }
557 | Node::AsyncWait { .. }
558 | Node::Trap { .. }
559 | Node::Resume { .. }
560 | Node::Opaque(_) => false,
561 Node::If {
562 then, otherwise, ..
563 } => contains_barrier(then) || contains_barrier(otherwise),
564 Node::Loop { body, .. } => contains_barrier(body),
565 Node::Block(body) => contains_barrier(body),
566 _ => false,
567 }
568}
569
570fn node_id(node: &Node) -> usize {
571 std::ptr::from_ref(node).addr()
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oob::Buffer;
    use crate::workgroup::InvocationIds;
    use vyre::ir::{BufferDecl, DataType};

    /// Steps a single zero-id invocation over the program entry until it is done.
    fn run_program(program: &Program, memory: &mut Memory) -> Result<(), vyre::Error> {
        let mut invocation = Invocation::new(InvocationIds::ZERO, program.entry());
        loop {
            if invocation.done() {
                return Ok(());
            }
            step(&mut invocation, memory, program)?;
        }
    }

    /// Snapshots the current contents of the storage buffer `name`.
    fn bytes(memory: &Memory, name: &str) -> Vec<u8> {
        let buffer = memory.storage.get(name).expect("test buffer exists");
        let contents = buffer
            .bytes
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        contents.clone()
    }

    #[test]
    fn async_load_wait_copies_payload_into_destination() {
        let buffers = vec![
            BufferDecl::read("src", 0, DataType::Bytes).with_count(8),
            BufferDecl::output("dst", 1, DataType::Bytes).with_count(8),
        ];
        let entry = vec![
            Node::async_load_ext("src", "dst", Expr::u32(2), Expr::u32(4), "copy"),
            Node::AsyncWait { tag: "copy".into() },
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], entry);

        let source = Buffer::new(vec![10, 11, 12, 13, 14, 15, 16, 17], DataType::Bytes);
        let sink = Buffer::new(vec![0; 8], DataType::Bytes);
        let mut memory = Memory::empty()
            .with_storage("src", source)
            .with_storage("dst", sink);

        run_program(&program, &mut memory).unwrap();

        // Four bytes from source offset 2 land at the start of the destination.
        assert_eq!(bytes(&memory, "dst"), vec![12, 13, 14, 15, 0, 0, 0, 0]);
    }

    #[test]
    fn async_store_wait_copies_payload_at_destination_offset() {
        let buffers = vec![
            BufferDecl::read("src", 0, DataType::Bytes).with_count(4),
            BufferDecl::output("dst", 1, DataType::Bytes).with_count(8),
        ];
        let entry = vec![
            Node::async_store("src", "dst", Expr::u32(3), Expr::u32(4), "store"),
            Node::AsyncWait {
                tag: "store".into(),
            },
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], entry);

        let source = Buffer::new(vec![21, 22, 23, 24], DataType::Bytes);
        let sink = Buffer::new(vec![0; 8], DataType::Bytes);
        let mut memory = Memory::empty()
            .with_storage("src", source)
            .with_storage("dst", sink);

        run_program(&program, &mut memory).unwrap();

        // Four bytes from the start of the source land at destination offset 3.
        assert_eq!(bytes(&memory, "dst"), vec![0, 0, 0, 21, 22, 23, 24, 0]);
    }
}