use std::sync::Arc;
use smallvec::SmallVec;
use vyre_foundation::ir::{Ident, Node, Program};
use vyre_foundation::memory_model::MemoryOrdering;
use crate::backend::{BackendError, DispatchConfig, OutputBuffers, VyreBackend};
/// Returns the top-level node sequence that grid-sync detection and splitting scan.
///
/// When the program's entry consists of exactly one `Node::Region`, the body of that
/// region is returned, so barriers sitting directly inside the wrapper region are
/// visible. Any other entry shape is returned unchanged.
fn entry_sequence(program: &Program) -> &[Node] {
    match program.entry() {
        [Node::Region { body, .. }] => body.as_slice(),
        whole_entry => whole_entry,
    }
}
#[must_use]
pub fn contains_grid_sync(program: &Program) -> bool {
entry_sequence(program).iter().any(|node| {
matches!(
node,
Node::Barrier {
ordering: MemoryOrdering::GridSync,
..
}
)
})
}
#[must_use]
pub fn split_on_grid_sync(program: &Program) -> Vec<Program> {
let inner = entry_sequence(program);
let split_count = inner
.iter()
.filter(|node| {
matches!(
node,
Node::Barrier {
ordering: MemoryOrdering::GridSync,
..
}
)
})
.count();
if split_count == 0 {
return vec![program.clone()];
}
let outer_generator: Option<Ident> = if let [Node::Region { generator, .. }] = program.entry() {
Some(generator.clone())
} else {
None
};
let segment_count = split_count + 1;
let executable_nodes = inner.len().saturating_sub(split_count);
let segment_capacity = executable_nodes.div_ceil(segment_count);
let mut segments = Vec::with_capacity(segment_count);
let mut current = Vec::with_capacity(segment_capacity);
for node in inner {
match node {
Node::Barrier {
ordering: MemoryOrdering::GridSync,
..
} => {
let entry = std::mem::replace(&mut current, Vec::with_capacity(segment_capacity));
segments.push(wrap_split_segment(program, outer_generator.as_ref(), entry));
}
other => {
current.push(other.clone());
}
}
}
segments.push(wrap_split_segment(
program,
outer_generator.as_ref(),
current,
));
segments
}
/// Builds the program for one split segment from `program` and the segment's nodes.
///
/// With `Some(generator)`, the nodes are placed inside a new `Node::Region` carrying that
/// generator (and `source_region: None`). With `None`, they become the entry directly.
/// The result comes from `program.with_rewritten_entry`.
///
/// NOTE(review): the original outer region's `source_region` is not carried over; confirm
/// that dropping it is intended.
fn wrap_split_segment(
    program: &Program,
    outer_generator: Option<&Ident>,
    entry: Vec<Node>,
) -> Program {
    let new_entry = if let Some(generator) = outer_generator {
        let region = Node::Region {
            generator: generator.clone(),
            source_region: None,
            body: Arc::new(entry),
        };
        vec![region]
    } else {
        entry
    };
    program.with_rewritten_entry(new_entry)
}
/// Dispatches `program` (splitting at grid-sync barriers when needed) and returns the
/// resulting output buffers.
///
/// Allocating convenience wrapper around `dispatch_with_grid_sync_split_into`.
///
/// # Errors
/// Propagates any error returned by `dispatch_with_grid_sync_split_into`.
pub fn dispatch_with_grid_sync_split(
    backend: &dyn VyreBackend,
    program: &Program,
    inputs: &[&[u8]],
    config: &DispatchConfig,
) -> Result<Vec<Vec<u8>>, BackendError> {
    let mut collected = Vec::new();
    dispatch_with_grid_sync_split_into(backend, program, inputs, config, &mut collected)
        .map(|()| collected)
}
pub fn dispatch_with_grid_sync_split_into(
backend: &dyn VyreBackend,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
outputs: &mut OutputBuffers,
) -> Result<(), BackendError> {
if !contains_grid_sync(program) || backend.supports_grid_sync() {
return backend.dispatch_borrowed_into(program, inputs, config, outputs);
}
let segments = split_on_grid_sync(program);
if segments.is_empty() {
return Err(BackendError::InvalidProgram {
fix: "Fix: program contains GridSync barrier but split_on_grid_sync produced 0 \
segments. This is a grid_sync invariant bug — split_on_grid_sync must \
always return at least one segment."
.to_string(),
});
}
outputs.clear();
let mut current_inputs: Vec<GridSyncInput<'_>> = inputs
.iter()
.copied()
.map(GridSyncInput::Borrowed)
.collect();
let mut segment_outputs = Vec::new();
for (segment_idx, segment) in segments.iter().enumerate() {
let borrowed: SmallVec<[&[u8]; 8]> =
current_inputs.iter().map(GridSyncInput::as_slice).collect();
if segment_idx + 1 == segments.len() {
return backend
.dispatch_borrowed_into(segment, borrowed.as_slice(), config, outputs)
.map_err(|error| grid_sync_segment_error(error, segment_idx, segments.len()));
}
backend
.dispatch_borrowed_into(segment, borrowed.as_slice(), config, &mut segment_outputs)
.map_err(|error| grid_sync_segment_error(error, segment_idx, segments.len()))?;
drop(borrowed);
refresh_readwrite_inputs(segment, &mut segment_outputs, &mut current_inputs);
}
Ok(())
}
/// Adds split-segment context to an error returned while dispatching one segment.
///
/// `segment_idx` is the zero-based position of the failing segment, and `segment_count`
/// is the total number of segments. Only `BackendError::InvalidProgram` carries a
/// free-form `fix` string, so only that variant is annotated. Every other variant is
/// returned unchanged.
///
/// The message reports a one-based ordinal, so the first of two segments reads
/// "segment 1 of 2" rather than the ambiguous "segment 0 of 2". A leading `"Fix: "` on
/// the inner message is dropped so the combined text carries a single `Fix:` prefix.
fn grid_sync_segment_error(
    error: BackendError,
    segment_idx: usize,
    segment_count: usize,
) -> BackendError {
    match error {
        BackendError::InvalidProgram { fix } => {
            let ordinal = segment_idx + 1;
            // Inner messages in this codebase already start with "Fix: "; avoid
            // producing "Fix: ... failed: Fix: ...".
            let detail = fix.strip_prefix("Fix: ").unwrap_or(fix.as_str());
            BackendError::InvalidProgram {
                fix: format!(
                    "Fix: grid-sync split segment {ordinal} of {segment_count} dispatch failed: {detail}"
                ),
            }
        }
        other => other,
    }
}
/// Bytes currently bound to one dispatch input while split segments execute.
enum GridSyncInput<'a> {
    /// The caller's original input slice, used until a segment replaces it.
    Borrowed(&'a [u8]),
    /// Bytes moved out of an earlier segment's dispatch outputs.
    Owned(Vec<u8>),
}

impl GridSyncInput<'_> {
    /// Views the held bytes, whichever variant holds them.
    fn as_slice(&self) -> &[u8] {
        match *self {
            Self::Borrowed(bytes) => bytes,
            Self::Owned(ref bytes) => &bytes[..],
        }
    }
}
/// Carries one segment's dispatch outputs forward as the next segment's inputs.
///
/// Walks `segment.buffers()` in order, skipping `Workgroup` buffers, with two cursors:
/// - one into `inputs`, advanced for every non-output buffer;
/// - one into `outputs`, advanced for every `ReadWrite` buffer.
///
/// Each `ReadWrite` buffer that is not an output has its input slot replaced by the
/// matching output bytes, which are moved out of `outputs`. If either cursor is out of
/// range, that slot is left untouched. `outputs` is cleared before returning so the
/// vector can be reused.
///
/// NOTE(review): this presumes the backend emits outputs for exactly the `ReadWrite`
/// buffers, in declaration order. Confirm against the backend implementation.
fn refresh_readwrite_inputs(
    segment: &Program,
    outputs: &mut Vec<Vec<u8>>,
    inputs: &mut [GridSyncInput<'_>],
) {
    use vyre_foundation::ir::BufferAccess;
    let mut next_input = 0usize;
    let mut next_output = 0usize;
    for buffer in segment.buffers() {
        let access = buffer.access();
        if matches!(access, BufferAccess::Workgroup) {
            continue;
        }
        match (buffer.is_output(), matches!(access, BufferAccess::ReadWrite)) {
            // Read-write input: occupies both an input and an output slot; feed it forward.
            (false, true) => {
                if let (Some(slot), Some(bytes)) =
                    (inputs.get_mut(next_input), outputs.get_mut(next_output))
                {
                    *slot = GridSyncInput::Owned(std::mem::take(bytes));
                }
                next_input += 1;
                next_output += 1;
            }
            // Plain input: occupies only an input slot.
            (false, false) => next_input += 1,
            // Read-write output buffer: occupies only an output slot.
            (true, true) => next_output += 1,
            // Non-read-write output buffer: occupies neither cursor.
            (true, false) => {}
        }
    }
    outputs.clear();
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr};

    /// Four-element `u32` read-write storage buffer named `buf`, shared by every case.
    fn buffer() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Region node with the given generator name wrapping `body`.
    fn region(generator: &str, body: Vec<Node>) -> Node {
        Node::Region {
            generator: Ident::from(generator),
            source_region: None,
            body: Arc::new(body),
        }
    }

    /// Barrier node with grid-wide ordering: the point where programs are split.
    fn grid_sync() -> Node {
        Node::barrier_with_ordering(MemoryOrdering::GridSync)
    }

    /// Program over `buffer()` with a `[1, 1, 1]` workgroup and the given entry nodes.
    fn unit_program(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buffer()], [1, 1, 1], entry)
    }

    /// Length of the node sequence the split logic sees for `program`.
    fn inner_len(program: &Program) -> usize {
        entry_sequence(program).len()
    }

    #[test]
    fn no_grid_sync_returns_single_segment() {
        let program = unit_program(vec![region(
            "a",
            vec![Node::store("buf", Expr::u32(0), Expr::u32(1))],
        )]);
        assert!(!contains_grid_sync(&program));
        let segments = split_on_grid_sync(&program);
        assert_eq!(segments.len(), 1);
        assert_eq!(inner_len(&segments[0]), 1);
    }

    #[test]
    fn one_grid_sync_splits_into_two() {
        let program = unit_program(vec![
            region("a", vec![Node::store("buf", Expr::u32(0), Expr::u32(1))]),
            grid_sync(),
            region("b", vec![Node::store("buf", Expr::u32(1), Expr::u32(2))]),
        ]);
        assert!(contains_grid_sync(&program));
        let segments = split_on_grid_sync(&program);
        let lens: Vec<usize> = segments.iter().map(inner_len).collect();
        assert_eq!(lens, [1, 1]);
    }

    #[test]
    fn three_grid_syncs_split_into_four() {
        let program = unit_program(vec![
            region("a", vec![Node::Return]),
            grid_sync(),
            region("b", vec![Node::Return]),
            grid_sync(),
            region("c", vec![Node::Return]),
            grid_sync(),
            region("d", vec![Node::Return]),
        ]);
        assert_eq!(split_on_grid_sync(&program).len(), 4);
    }

    #[test]
    fn workgroup_barrier_does_not_split() {
        let program = unit_program(vec![
            region("a", vec![Node::Return]),
            Node::barrier_with_ordering(MemoryOrdering::SeqCst),
            region("b", vec![Node::Return]),
        ]);
        assert!(!contains_grid_sync(&program));
        let segments = split_on_grid_sync(&program);
        assert_eq!(segments.len(), 1);
        assert_eq!(inner_len(&segments[0]), 3);
    }

    #[test]
    fn buffers_and_workgroup_size_propagate_to_each_segment() {
        let program = Program::wrapped(
            vec![buffer()],
            [256, 1, 1],
            vec![
                region("a", vec![Node::Return]),
                grid_sync(),
                region("b", vec![Node::Return]),
            ],
        );
        for segment in split_on_grid_sync(&program) {
            assert_eq!(segment.workgroup_size(), [256, 1, 1]);
            let buffers = segment.buffers();
            assert_eq!(buffers.len(), 1);
            assert_eq!(buffers[0].name(), "buf");
        }
    }
}