use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType};
pub const BINDING_NODES: u32 = 0;
pub const BINDING_EDGE_OFFSETS: u32 = 1;
pub const BINDING_EDGE_TARGETS: u32 = 2;
pub const BINDING_EDGE_KIND_MASK: u32 = 3;
pub const BINDING_NODE_TAGS: u32 = 4;
pub const BINDING_PRIMITIVE_START: u32 = 5;
pub const NAME_NODES: &str = "pg_nodes";
pub const NAME_EDGE_OFFSETS: &str = "pg_edge_offsets";
pub const NAME_EDGE_TARGETS: &str = "pg_edge_targets";
pub const NAME_EDGE_KIND_MASK: &str = "pg_edge_kind_mask";
pub const NAME_NODE_TAGS: &str = "pg_node_tags";
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ProgramGraphShape {
pub node_count: u32,
pub edge_count: u32,
}
impl ProgramGraphShape {
#[must_use]
pub fn new(node_count: u32, edge_count: u32) -> Self {
Self {
node_count,
edge_count,
}
}
#[must_use]
pub fn read_only_buffers(&self) -> Vec<BufferDecl> {
self.try_read_only_buffers()
.unwrap_or_else(|error| panic!("{error}"))
}
pub fn try_read_only_buffers(&self) -> Result<Vec<BufferDecl>, String> {
let edge_offset_count = self.node_count.checked_add(1).ok_or_else(|| {
format!(
"ProgramGraphShape node_count={} overflows edge-offset buffer count. Fix: shard the graph before GPU dispatch.",
self.node_count
)
})?;
Ok(read_only_buffers_with_counts(
self.node_count,
edge_offset_count,
self.edge_count,
self.node_count,
))
}
}
fn read_only_buffers_with_counts(
node_count: u32,
edge_offset_count: u32,
edge_count: u32,
node_tag_count: u32,
) -> Vec<BufferDecl> {
vec![
BufferDecl::storage(
NAME_NODES,
BINDING_NODES,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(node_count),
BufferDecl::storage(
NAME_EDGE_OFFSETS,
BINDING_EDGE_OFFSETS,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(edge_offset_count),
BufferDecl::storage(
NAME_EDGE_TARGETS,
BINDING_EDGE_TARGETS,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(edge_count.max(1)),
BufferDecl::storage(
NAME_EDGE_KIND_MASK,
BINDING_EDGE_KIND_MASK,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(edge_count.max(1)),
BufferDecl::storage(
NAME_NODE_TAGS,
BINDING_NODE_TAGS,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(node_tag_count),
]
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GraphValidationError {
EdgeOffsetsLen {
expected: usize,
got: usize,
},
EdgeTargetsLen {
expected: usize,
got: usize,
},
EdgeKindMaskLen {
expected: usize,
got: usize,
},
NodeTagsLen {
expected: usize,
got: usize,
},
NodesLen {
expected: usize,
got: usize,
},
EdgeOutOfRange {
index: usize,
target: u32,
node_count: u32,
},
NonMonotonicOffsets {
index: usize,
},
EdgeCountMismatch {
expected: usize,
got: usize,
},
}
pub fn validate_program_graph(
shape: ProgramGraphShape,
nodes: &[u32],
edge_offsets: &[u32],
edge_targets: &[u32],
edge_kind_mask: &[u32],
node_tags: &[u32],
) -> Result<(), GraphValidationError> {
let n = shape.node_count as usize;
let e = shape.edge_count as usize;
if nodes.len() != n {
return Err(GraphValidationError::NodesLen {
expected: n,
got: nodes.len(),
});
}
if edge_offsets.len() != n + 1 {
return Err(GraphValidationError::EdgeOffsetsLen {
expected: n + 1,
got: edge_offsets.len(),
});
}
let expected_edge_len = e.max(1);
if edge_targets.len() != expected_edge_len {
return Err(GraphValidationError::EdgeTargetsLen {
expected: expected_edge_len,
got: edge_targets.len(),
});
}
if edge_kind_mask.len() != expected_edge_len {
return Err(GraphValidationError::EdgeKindMaskLen {
expected: expected_edge_len,
got: edge_kind_mask.len(),
});
}
if node_tags.len() != n {
return Err(GraphValidationError::NodeTagsLen {
expected: n,
got: node_tags.len(),
});
}
if let Some(&first) = edge_offsets.first() {
if first != 0 {
return Err(GraphValidationError::NonMonotonicOffsets { index: 0 });
}
}
for window in edge_offsets.windows(2).enumerate() {
let (index, pair) = window;
if pair[1] < pair[0] {
return Err(GraphValidationError::NonMonotonicOffsets { index });
}
}
let final_offset = edge_offsets.last().copied().unwrap_or_default() as usize;
if final_offset != e {
return Err(GraphValidationError::EdgeCountMismatch {
expected: e,
got: final_offset,
});
}
for (index, &target) in edge_targets.iter().take(e).enumerate() {
if target >= shape.node_count {
return Err(GraphValidationError::EdgeOutOfRange {
index,
target,
node_count: shape.node_count,
});
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn read_only_buffers_has_canonical_layout() {
let bufs = ProgramGraphShape::new(4, 6).read_only_buffers();
assert_eq!(bufs.len(), 5);
assert_eq!(bufs[0].name(), NAME_NODES);
assert_eq!(bufs[1].name(), NAME_EDGE_OFFSETS);
assert_eq!(bufs[2].name(), NAME_EDGE_TARGETS);
assert_eq!(bufs[3].name(), NAME_EDGE_KIND_MASK);
assert_eq!(bufs[4].name(), NAME_NODE_TAGS);
assert_eq!(bufs[1].count(), 5); assert_eq!(bufs[2].count(), 6); }
#[test]
fn checked_read_only_buffers_rejects_edge_offset_overflow() {
let error = ProgramGraphShape::new(u32::MAX, 0)
.try_read_only_buffers()
.expect_err("checked ProgramGraphShape buffers must reject offset overflow");
assert!(
error.contains("overflows edge-offset buffer count"),
"error should describe the graph shape overflow: {error}"
);
}
#[test]
fn legacy_read_only_buffers_fail_fast_on_edge_offset_overflow() {
let panic = std::panic::catch_unwind(|| {
let _ = ProgramGraphShape::new(u32::MAX, 0).read_only_buffers();
})
.expect_err("legacy read_only_buffers must fail fast on edge-offset overflow");
let message = panic_payload_message(panic);
assert!(
message.contains("overflows edge-offset buffer count"),
"error should describe the graph shape overflow: {message}"
);
}
fn panic_payload_message(payload: Box<dyn std::any::Any + Send>) -> String {
if let Some(message) = payload.downcast_ref::<&str>() {
message.to_string()
} else if let Some(message) = payload.downcast_ref::<String>() {
message.clone()
} else {
format!("{payload:?}")
}
}
#[test]
fn program_graph_shape_source_has_checked_buffers_without_panics() {
let source = include_str!("program_graph.rs");
let production = source
.split("/// Error kinds surfaced")
.next()
.expect("Fix: ProgramGraphShape source must precede validation errors");
assert!(
production.contains("pub fn try_read_only_buffers(")
&& !production.contains("inert_")
&& !production.contains("Err(_) =>"),
"Fix: ProgramGraphShape buffer ABI must expose checked sizing and must not emit inert placeholder buffers."
);
}
#[test]
fn validate_rejects_oob_edge_target() {
let err = validate_program_graph(
ProgramGraphShape::new(3, 2),
&[0, 0, 0],
&[0, 1, 2, 2],
&[1, 5],
&[0, 0],
&[0, 0, 0],
)
.unwrap_err();
assert!(matches!(
err,
GraphValidationError::EdgeOutOfRange { target: 5, .. }
));
}
#[test]
fn validate_rejects_non_monotonic_offsets() {
let err = validate_program_graph(
ProgramGraphShape::new(2, 1),
&[0, 0],
&[2, 1, 1], &[0],
&[0],
&[0, 0],
)
.unwrap_err();
assert!(matches!(
err,
GraphValidationError::NonMonotonicOffsets { .. }
));
}
#[test]
fn validate_passes_canonical_small_graph() {
let ok = validate_program_graph(
ProgramGraphShape::new(3, 2),
&[0, 0, 0],
&[0, 1, 2, 2],
&[1, 2],
&[1, 1],
&[0, 0, 0],
);
assert_eq!(ok, Ok(()));
}
}