use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType};
/// Binding slot of the per-node `u32` buffer named [`NAME_NODES`].
pub const BINDING_NODES: u32 = 0;
/// Buffer name of the per-node `u32` buffer bound at [`BINDING_NODES`].
pub const NAME_NODES: &str = "pg_nodes";

/// Binding slot of the edge-offset buffer named [`NAME_EDGE_OFFSETS`].
pub const BINDING_EDGE_OFFSETS: u32 = 1;
/// Buffer name of the edge-offset buffer bound at [`BINDING_EDGE_OFFSETS`].
pub const NAME_EDGE_OFFSETS: &str = "pg_edge_offsets";

/// Binding slot of the edge-target buffer named [`NAME_EDGE_TARGETS`].
pub const BINDING_EDGE_TARGETS: u32 = 2;
/// Buffer name of the edge-target buffer bound at [`BINDING_EDGE_TARGETS`].
pub const NAME_EDGE_TARGETS: &str = "pg_edge_targets";

/// Binding slot of the per-edge kind-mask buffer named [`NAME_EDGE_KIND_MASK`].
pub const BINDING_EDGE_KIND_MASK: u32 = 3;
/// Buffer name of the per-edge kind-mask buffer bound at [`BINDING_EDGE_KIND_MASK`].
pub const NAME_EDGE_KIND_MASK: &str = "pg_edge_kind_mask";

/// Binding slot of the per-node tag buffer named [`NAME_NODE_TAGS`].
pub const BINDING_NODE_TAGS: u32 = 4;
/// Buffer name of the per-node tag buffer bound at [`BINDING_NODE_TAGS`].
pub const NAME_NODE_TAGS: &str = "pg_node_tags";

/// First binding slot after the five program-graph buffers (slots 0..=4).
/// Presumably where primitive-specific bindings begin — confirm with callers.
pub const BINDING_PRIMITIVE_START: u32 = 5;
/// Element counts that size every program-graph buffer.
///
/// `node_count` sizes the per-node buffers and `edge_count` sizes the
/// per-edge buffers declared by `ProgramGraphShape::read_only_buffers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProgramGraphShape {
    /// Number of nodes in the graph.
    pub node_count: u32,
    /// Number of edges in the graph.
    pub edge_count: u32,
}
impl ProgramGraphShape {
    /// Builds a shape from explicit node and edge counts.
    #[must_use]
    pub fn new(node_count: u32, edge_count: u32) -> Self {
        Self {
            node_count,
            edge_count,
        }
    }

    /// Returns the read-only `u32` storage buffer declarations for this shape.
    ///
    /// The result always has five entries, in binding order:
    /// nodes (`node_count`), edge offsets (`node_count + 1`, saturating),
    /// edge targets and edge kind mask (`max(edge_count, 1)` each), and
    /// node tags (`node_count`).
    #[must_use]
    pub fn read_only_buffers(&self) -> Vec<BufferDecl> {
        // All program-graph buffers are read-only u32 storage; only the name,
        // binding slot, and element count vary between them.
        fn read_only_u32(name: &'static str, binding: u32, count: u32) -> BufferDecl {
            BufferDecl::storage(name, binding, BufferAccess::ReadOnly, DataType::U32)
                .with_count(count)
        }

        let per_node = self.node_count;
        // One offset per node plus a trailing end offset.
        let offsets = self.node_count.saturating_add(1);
        // Per-edge buffers never shrink below one element, even with no edges
        // (presumably because zero-length bindings are rejected — confirm).
        let per_edge = self.edge_count.max(1);

        vec![
            read_only_u32(NAME_NODES, BINDING_NODES, per_node),
            read_only_u32(NAME_EDGE_OFFSETS, BINDING_EDGE_OFFSETS, offsets),
            read_only_u32(NAME_EDGE_TARGETS, BINDING_EDGE_TARGETS, per_edge),
            read_only_u32(NAME_EDGE_KIND_MASK, BINDING_EDGE_KIND_MASK, per_edge),
            read_only_u32(NAME_NODE_TAGS, BINDING_NODE_TAGS, per_node),
        ]
    }
}
/// Reasons `validate_program_graph` can reject a set of program-graph buffers.
///
/// Length variants carry the element count implied by the shape (`expected`)
/// and the length actually supplied (`got`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GraphValidationError {
    /// `edge_offsets` does not hold exactly `expected` entries.
    EdgeOffsetsLen {
        expected: usize,
        got: usize,
    },
    /// `edge_targets` does not hold exactly `expected` entries.
    EdgeTargetsLen {
        expected: usize,
        got: usize,
    },
    /// `edge_kind_mask` does not hold exactly `expected` entries.
    EdgeKindMaskLen {
        expected: usize,
        got: usize,
    },
    /// `node_tags` does not hold exactly `expected` entries.
    NodeTagsLen {
        expected: usize,
        got: usize,
    },
    /// `nodes` does not hold exactly `expected` entries.
    NodesLen {
        expected: usize,
        got: usize,
    },
    /// Edge number `index` points at `target`, which is not below `node_count`.
    EdgeOutOfRange {
        index: usize,
        target: u32,
        node_count: u32,
    },
    /// `edge_offsets` does not start at 0 or decreases; `index` locates the violation.
    NonMonotonicOffsets {
        index: usize,
    },
    /// The last entry of `edge_offsets` (`got`) differs from the edge count (`expected`).
    EdgeCountMismatch {
        expected: usize,
        got: usize,
    },
}

// Human-readable messages so the error can be logged or surfaced directly.
impl std::fmt::Display for GraphValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EdgeOffsetsLen { expected, got } => {
                write!(f, "edge_offsets has {got} entries, expected {expected}")
            }
            Self::EdgeTargetsLen { expected, got } => {
                write!(f, "edge_targets has {got} entries, expected {expected}")
            }
            Self::EdgeKindMaskLen { expected, got } => {
                write!(f, "edge_kind_mask has {got} entries, expected {expected}")
            }
            Self::NodeTagsLen { expected, got } => {
                write!(f, "node_tags has {got} entries, expected {expected}")
            }
            Self::NodesLen { expected, got } => {
                write!(f, "nodes has {got} entries, expected {expected}")
            }
            Self::EdgeOutOfRange {
                index,
                target,
                node_count,
            } => write!(
                f,
                "edge {index} targets node {target}, but node_count is {node_count}"
            ),
            Self::NonMonotonicOffsets { index } => write!(
                f,
                "edge_offsets must start at 0 and be non-decreasing (violation at index {index})"
            ),
            Self::EdgeCountMismatch { expected, got } => {
                write!(f, "edge_offsets ends at {got}, expected edge_count {expected}")
            }
        }
    }
}

// Lets callers propagate validation failures with `?` into `Box<dyn Error>`.
impl std::error::Error for GraphValidationError {}
/// Checks that host-side program-graph buffers match `shape` and form a
/// well-formed offset/target adjacency layout.
///
/// Checks run in this order and the first failure is returned:
/// 1. lengths — `nodes` and `node_tags` hold `node_count` entries,
///    `edge_offsets` holds `node_count + 1`, and `edge_targets` /
///    `edge_kind_mask` hold `max(edge_count, 1)`;
/// 2. `edge_offsets` starts at 0 and never decreases;
/// 3. the last offset equals `edge_count`;
/// 4. each of the first `edge_count` targets is below `node_count`. The single
///    padding slot present when `edge_count == 0` is not inspected.
///
/// # Errors
///
/// Returns the [`GraphValidationError`] for the first violated check. A
/// non-zero first offset is reported as `NonMonotonicOffsets { index: 0 }`.
/// Otherwise `index` is the `i` with `edge_offsets[i + 1] < edge_offsets[i]`.
pub fn validate_program_graph(
    shape: ProgramGraphShape,
    nodes: &[u32],
    edge_offsets: &[u32],
    edge_targets: &[u32],
    edge_kind_mask: &[u32],
    node_tags: &[u32],
) -> Result<(), GraphValidationError> {
    // Turns a length mismatch into the caller-chosen error variant.
    fn require_len(
        got: usize,
        expected: usize,
        make: fn(usize, usize) -> GraphValidationError,
    ) -> Result<(), GraphValidationError> {
        if got == expected {
            Ok(())
        } else {
            Err(make(expected, got))
        }
    }

    let node_len = shape.node_count as usize;
    let edge_len = shape.edge_count as usize;
    let padded_edge_len = edge_len.max(1);

    require_len(nodes.len(), node_len, |expected, got| {
        GraphValidationError::NodesLen { expected, got }
    })?;
    require_len(edge_offsets.len(), node_len + 1, |expected, got| {
        GraphValidationError::EdgeOffsetsLen { expected, got }
    })?;
    require_len(edge_targets.len(), padded_edge_len, |expected, got| {
        GraphValidationError::EdgeTargetsLen { expected, got }
    })?;
    require_len(edge_kind_mask.len(), padded_edge_len, |expected, got| {
        GraphValidationError::EdgeKindMaskLen { expected, got }
    })?;
    require_len(node_tags.len(), node_len, |expected, got| {
        GraphValidationError::NodeTagsLen { expected, got }
    })?;

    // A non-zero starting offset is reported through the monotonicity variant.
    if edge_offsets.first().copied().unwrap_or(0) != 0 {
        return Err(GraphValidationError::NonMonotonicOffsets { index: 0 });
    }
    if let Some(index) = edge_offsets
        .windows(2)
        .position(|pair| pair[1] < pair[0])
    {
        return Err(GraphValidationError::NonMonotonicOffsets { index });
    }

    let final_offset = edge_offsets.last().map_or(0, |&last| last as usize);
    if final_offset != edge_len {
        return Err(GraphValidationError::EdgeCountMismatch {
            expected: edge_len,
            got: final_offset,
        });
    }

    // Only the first `edge_count` slots are real edges; any padding is skipped.
    let first_bad_edge = edge_targets
        .iter()
        .copied()
        .take(edge_len)
        .enumerate()
        .find(|&(_, target)| target >= shape.node_count);

    match first_bad_edge {
        Some((index, target)) => Err(GraphValidationError::EdgeOutOfRange {
            index,
            target,
            node_count: shape.node_count,
        }),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn read_only_buffers_has_canonical_layout() {
        let bufs = ProgramGraphShape::new(4, 6).read_only_buffers();
        assert_eq!(bufs.len(), 5);
        assert_eq!(bufs[0].name(), NAME_NODES);
        assert_eq!(bufs[1].name(), NAME_EDGE_OFFSETS);
        assert_eq!(bufs[2].name(), NAME_EDGE_TARGETS);
        assert_eq!(bufs[3].name(), NAME_EDGE_KIND_MASK);
        assert_eq!(bufs[4].name(), NAME_NODE_TAGS);
        assert_eq!(bufs[1].count(), 5);
        assert_eq!(bufs[2].count(), 6);
    }

    /// Per-edge buffers keep one slot when the graph has no edges.
    #[test]
    fn read_only_buffers_pads_empty_edge_buffers() {
        let bufs = ProgramGraphShape::new(4, 0).read_only_buffers();
        assert_eq!(bufs[0].count(), 4);
        assert_eq!(bufs[1].count(), 5);
        assert_eq!(bufs[2].count(), 1);
        assert_eq!(bufs[3].count(), 1);
        assert_eq!(bufs[4].count(), 4);
    }

    #[test]
    fn validate_rejects_oob_edge_target() {
        let err = validate_program_graph(
            ProgramGraphShape::new(3, 2),
            &[0, 0, 0],
            &[0, 1, 2, 2],
            &[1, 5],
            &[0, 0],
            &[0, 0, 0],
        )
        .unwrap_err();
        assert!(matches!(
            err,
            GraphValidationError::EdgeOutOfRange { target: 5, .. }
        ));
    }

    /// Offsets `[2, 1, 1]` are rejected (via the non-zero first offset check).
    #[test]
    fn validate_rejects_non_monotonic_offsets() {
        let err = validate_program_graph(
            ProgramGraphShape::new(2, 1),
            &[0, 0],
            &[2, 1, 1],
            &[0],
            &[0],
            &[0, 0],
        )
        .unwrap_err();
        assert!(matches!(
            err,
            GraphValidationError::NonMonotonicOffsets { .. }
        ));
    }

    /// A first offset other than 0 is reported at index 0.
    #[test]
    fn validate_rejects_nonzero_first_offset() {
        let err = validate_program_graph(
            ProgramGraphShape::new(2, 1),
            &[0, 0],
            &[1, 1, 1],
            &[0],
            &[0],
            &[0, 0],
        )
        .unwrap_err();
        assert_eq!(err, GraphValidationError::NonMonotonicOffsets { index: 0 });
    }

    /// A decrease in the middle of the offsets is located by its window index.
    #[test]
    fn validate_rejects_decreasing_offsets() {
        let err = validate_program_graph(
            ProgramGraphShape::new(3, 2),
            &[0, 0, 0],
            &[0, 2, 1, 2],
            &[0, 1],
            &[0, 0],
            &[0, 0, 0],
        )
        .unwrap_err();
        assert_eq!(err, GraphValidationError::NonMonotonicOffsets { index: 1 });
    }

    #[test]
    fn validate_rejects_edge_count_mismatch() {
        let err = validate_program_graph(
            ProgramGraphShape::new(2, 2),
            &[0, 0],
            &[0, 1, 1],
            &[0, 1],
            &[0, 0],
            &[0, 0],
        )
        .unwrap_err();
        assert_eq!(
            err,
            GraphValidationError::EdgeCountMismatch {
                expected: 2,
                got: 1
            }
        );
    }

    #[test]
    fn validate_rejects_wrong_buffer_lengths() {
        let nodes_err = validate_program_graph(
            ProgramGraphShape::new(2, 1),
            &[0],
            &[0, 1, 1],
            &[1],
            &[0],
            &[0, 0],
        )
        .unwrap_err();
        assert_eq!(
            nodes_err,
            GraphValidationError::NodesLen {
                expected: 2,
                got: 1
            }
        );

        let tags_err = validate_program_graph(
            ProgramGraphShape::new(2, 1),
            &[0, 0],
            &[0, 1, 1],
            &[1],
            &[0],
            &[0],
        )
        .unwrap_err();
        assert_eq!(
            tags_err,
            GraphValidationError::NodeTagsLen {
                expected: 2,
                got: 1
            }
        );

        // With zero edges the per-edge buffers must still carry one padding slot.
        let targets_err = validate_program_graph(
            ProgramGraphShape::new(2, 0),
            &[0, 0],
            &[0, 0, 0],
            &[],
            &[0],
            &[0, 0],
        )
        .unwrap_err();
        assert_eq!(
            targets_err,
            GraphValidationError::EdgeTargetsLen {
                expected: 1,
                got: 0
            }
        );
    }

    /// The padding slot of an edgeless graph is not range-checked.
    #[test]
    fn validate_accepts_edgeless_graph_with_padding_slot() {
        let ok = validate_program_graph(
            ProgramGraphShape::new(2, 0),
            &[0, 0],
            &[0, 0, 0],
            &[u32::MAX],
            &[0],
            &[0, 0],
        );
        assert_eq!(ok, Ok(()));
    }

    #[test]
    fn validate_passes_canonical_small_graph() {
        let ok = validate_program_graph(
            ProgramGraphShape::new(3, 2),
            &[0, 0, 0],
            &[0, 1, 2, 2],
            &[1, 2],
            &[1, 1],
            &[0, 0, 0],
        );
        assert_eq!(ok, Ok(()));
    }
}