use std::hash::{Hash, Hasher as _};
use std::sync::atomic::Ordering;
use std::sync::Arc;
use rustc_hash::FxHasher;
use vyre_spec::bin_op::OpIntensity;
use crate::ir::{Expr, Node};
use crate::ir_inner::model::expr::Ident;
use crate::ir_inner::model::types::BufferAccess;
use crate::transform::visit::{walk_nodes_and_exprs, ExprVisitor, NodeVisitor};
use super::Program;
/// Folds a value that only implements [`Hash`] into the fallback wire hasher.
///
/// The value is first condensed to a 64-bit digest with `FxHasher`; that
/// digest is then fed to `hasher` as eight little-endian bytes.
fn mix_wire_fallback_hashable<T: Hash>(hasher: &mut blake3::Hasher, value: &T) {
    let digest = {
        let mut fx = FxHasher::default();
        value.hash(&mut fx);
        fx.finish()
    };
    hasher.update(&digest.to_le_bytes());
}
/// Visitor that streams a structural description of every node and
/// expression of a program into the borrowed BLAKE3 hasher.
struct FallbackWireHasher<'a>(&'a mut blake3::Hasher);
impl NodeVisitor for FallbackWireHasher<'_> {
fn visit_node(&mut self, node: &Node) {
let h = &mut *self.0;
match node {
Node::Let { name, .. } => {
h.update(b"n:Let\0");
h.update(name.as_bytes());
}
Node::Assign { name, .. } => {
h.update(b"n:Assign\0");
h.update(name.as_bytes());
}
Node::Store { buffer, .. } => {
h.update(b"n:Store\0");
h.update(buffer.as_bytes());
}
Node::If { .. } => {
h.update(b"n:If\0");
}
Node::Loop { var, .. } => {
h.update(b"n:Loop\0");
h.update(var.as_bytes());
}
Node::IndirectDispatch {
count_buffer,
count_offset,
} => {
h.update(b"n:IndirectDispatch\0");
h.update(count_buffer.as_bytes());
h.update(&count_offset.to_le_bytes());
}
Node::AsyncLoad {
source,
destination,
tag,
..
} => {
h.update(b"n:AsyncLoad\0");
h.update(source.as_bytes());
h.update(destination.as_bytes());
h.update(tag.as_bytes());
}
Node::AsyncStore {
source,
destination,
tag,
..
} => {
h.update(b"n:AsyncStore\0");
h.update(source.as_bytes());
h.update(destination.as_bytes());
h.update(tag.as_bytes());
}
Node::AsyncWait { tag } => {
h.update(b"n:AsyncWait\0");
h.update(tag.as_bytes());
}
Node::Trap { tag, .. } => {
h.update(b"n:Trap\0");
h.update(tag.as_bytes());
}
Node::Resume { tag } => {
h.update(b"n:Resume\0");
h.update(tag.as_bytes());
}
Node::Return => {
h.update(b"n:Return\0");
}
Node::Barrier { ordering } => {
h.update(b"n:Barrier\0");
mix_wire_fallback_hashable(h, ordering);
}
Node::Block(_) => {
h.update(b"n:Block\0");
}
Node::Region {
generator,
source_region,
..
} => {
h.update(b"n:Region\0");
h.update(generator.as_bytes());
if let Some(gen) = source_region {
h.update(gen.name.as_bytes());
}
}
Node::Opaque(ext) => {
h.update(b"n:Opaque\0");
h.update(ext.extension_kind().as_bytes());
}
}
}
}
impl ExprVisitor for FallbackWireHasher<'_> {
fn visit_expr(&mut self, expr: &Expr) {
let h = &mut *self.0;
match expr {
Expr::LitU32(v) => {
h.update(b"e:LitU32\0");
h.update(&v.to_le_bytes());
}
Expr::LitI32(v) => {
h.update(b"e:LitI32\0");
h.update(&v.to_le_bytes());
}
Expr::LitF32(v) => {
h.update(b"e:LitF32\0");
h.update(&v.to_le_bytes());
}
Expr::LitBool(v) => {
h.update(b"e:LitBool\0");
h.update(&[u8::from(*v)]);
}
Expr::Var(name) => {
h.update(b"e:Var\0");
h.update(name.as_bytes());
}
Expr::Load { buffer, .. } => {
h.update(b"e:Load\0");
h.update(buffer.as_bytes());
}
Expr::BufLen { buffer } => {
h.update(b"e:BufLen\0");
h.update(buffer.as_bytes());
}
Expr::InvocationId { axis } => {
h.update(b"e:InvocationId\0");
h.update(&[*axis]);
}
Expr::WorkgroupId { axis } => {
h.update(b"e:WorkgroupId\0");
h.update(&[*axis]);
}
Expr::LocalId { axis } => {
h.update(b"e:LocalId\0");
h.update(&[*axis]);
}
Expr::BinOp { op, .. } => {
h.update(b"e:BinOp\0");
mix_wire_fallback_hashable(h, op);
}
Expr::UnOp { op, .. } => {
h.update(b"e:UnOp\0");
mix_wire_fallback_hashable(h, op);
}
Expr::Call { op_id, .. } => {
h.update(b"e:Call\0");
h.update(op_id.as_bytes());
}
Expr::Select { .. } => {
h.update(b"e:Select\0");
}
Expr::Cast { target, .. } => {
h.update(b"e:Cast\0");
mix_wire_fallback_hashable(h, target);
}
Expr::Fma { .. } => {
h.update(b"e:Fma\0");
}
Expr::Atomic {
op,
buffer,
ordering,
..
} => {
h.update(b"e:Atomic\0");
mix_wire_fallback_hashable(h, op);
h.update(buffer.as_bytes());
mix_wire_fallback_hashable(h, ordering);
}
Expr::SubgroupBallot { .. } => {
h.update(b"e:SubgroupBallot\0");
}
Expr::SubgroupShuffle { .. } => {
h.update(b"e:SubgroupShuffle\0");
}
Expr::SubgroupAdd { .. } => {
h.update(b"e:SubgroupAdd\0");
}
Expr::SubgroupLocalId => {
h.update(b"e:SubgroupLocalId\0");
}
Expr::SubgroupSize => {
h.update(b"e:SubgroupSize\0");
}
Expr::Opaque(ext) => {
h.update(b"e:Opaque\0");
h.update(ext.extension_kind().as_bytes());
}
}
}
}
impl Program {
#[must_use]
pub fn reconcile_runnable_top_level(self) -> Self {
if self.is_top_level_region_wrapped() {
return self;
}
let new_entry = Self::wrap_entry(self.entry().to_vec());
self.with_rewritten_entry(new_entry)
}
#[must_use]
#[inline]
pub fn buffer(&self, name: &str) -> Option<&super::BufferDecl> {
self.buffer_index
.get(name)
.and_then(|&index| self.buffers.get(index))
}
#[must_use]
#[inline]
pub fn buffers(&self) -> &[super::BufferDecl] {
self.buffers.as_ref()
}
#[must_use]
#[inline]
#[cfg(test)]
pub(crate) fn buffers_arc(&self) -> &Arc<[super::BufferDecl]> {
&self.buffers
}
#[must_use]
#[inline]
pub fn structural_eq(&self, other: &Self) -> bool {
if std::ptr::eq(self, other)
|| (Arc::ptr_eq(&self.buffers, &other.buffers)
&& Arc::ptr_eq(&self.entry, &other.entry)
&& self.entry_op_id == other.entry_op_id
&& self.non_composable_with_self == other.non_composable_with_self
&& self.workgroup_size == other.workgroup_size)
{
return true;
}
self.entry_op_id == other.entry_op_id
&& self.non_composable_with_self == other.non_composable_with_self
&& buffers_equal_ignoring_declaration_order(&self.buffers, &other.buffers)
&& self.workgroup_size == other.workgroup_size
&& self.entry == other.entry
}
#[must_use]
#[inline]
pub fn workgroup_size(&self) -> [u32; 3] {
self.workgroup_size
}
#[must_use]
#[inline]
pub fn parallel_region_size(&self) -> [u32; 3] {
self.workgroup_size
}
#[must_use]
#[inline]
pub fn is_non_composable_with_self(&self) -> bool {
self.non_composable_with_self
}
#[must_use]
#[inline]
pub fn with_non_composable_with_self(mut self, flag: bool) -> Self {
self.non_composable_with_self = flag;
self.invalidate_caches();
self
}
#[inline]
pub fn set_workgroup_size(&mut self, workgroup_size: [u32; 3]) {
self.workgroup_size = workgroup_size;
self.invalidate_caches();
}
#[inline]
pub fn set_parallel_region_size(&mut self, parallel_region_size: [u32; 3]) {
self.workgroup_size = parallel_region_size;
self.invalidate_caches();
}
#[must_use]
#[inline]
pub fn entry(&self) -> &[Node] {
self.entry.as_ref().as_slice()
}
#[must_use]
#[inline]
pub fn entry_arc(&self) -> &Arc<Vec<Node>> {
&self.entry
}
#[must_use]
#[inline]
pub fn is_explicit_noop(&self) -> bool {
self.buffers().is_empty()
&& matches!(self.entry(), [Node::Region { body, .. }] if body.is_empty())
}
#[must_use]
#[inline]
pub fn is_top_level_region_wrapped(&self) -> bool {
!self.entry.is_empty()
&& self
.entry()
.iter()
.all(|node| matches!(node, Node::Region { .. }))
}
#[must_use]
pub fn top_level_region_violation(&self) -> Option<String> {
if self.entry().is_empty() {
return Some(
"program entry has no top-level Region. Fix: construct runnable programs with Program::wrapped(...) or wrap the body in Node::Region before validation, interpretation, or dispatch."
.to_string(),
);
}
self.entry()
.iter()
.enumerate()
.find(|(_, node)| !matches!(node, Node::Region { .. }))
.map(|(index, node)| {
format!(
"program entry node {index} is `{}` instead of `Node::Region`. Fix: construct runnable programs with Program::wrapped(...) or wrap the top-level body in Node::Region; raw Program::new is reserved for wire decode and negative tests.",
Self::top_level_node_name(node)
)
})
}
#[must_use]
#[inline]
pub fn entry_mut(&mut self) -> &mut Vec<Node> {
self.invalidate_caches();
Arc::make_mut(&mut self.entry)
}
#[must_use]
#[inline]
pub fn fingerprint(&self) -> [u8; 32] {
*self.fingerprint.get_or_init(|| {
let hash = self.compute_wire_hash();
let _ = self.hash.set(hash);
*hash.as_bytes()
})
}
#[must_use]
pub fn vsa_fingerprint(&self) -> Vec<u32> {
self.fingerprint()
.chunks_exact(core::mem::size_of::<u32>())
.map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect()
}
#[must_use]
#[inline]
pub fn output_buffer_indices(&self) -> &[u32] {
self.output_buffer_index
.get_or_init(|| {
Arc::new(
self.buffers()
.iter()
.enumerate()
.filter_map(|(index, buffer)| {
(buffer.access() == BufferAccess::ReadWrite).then_some(index as u32)
})
.collect(),
)
})
.as_slice()
}
#[must_use]
#[inline]
pub fn has_indirect_dispatch(&self) -> bool {
*self.has_indirect_dispatch.get_or_init(|| {
let mut stack: Vec<&Node> = self.entry().iter().rev().collect();
while let Some(node) = stack.pop() {
match node {
Node::IndirectDispatch { .. } => return true,
Node::If {
then, otherwise, ..
} => {
stack.extend(otherwise.iter().rev());
stack.extend(then.iter().rev());
}
Node::Loop { body, .. } | Node::Block(body) => {
stack.extend(body.iter().rev());
}
Node::Region { body, .. } => {
stack.extend(body.iter().rev());
}
Node::Let { .. }
| Node::Assign { .. }
| Node::Store { .. }
| Node::Return
| Node::Barrier { .. }
| Node::AsyncLoad { .. }
| Node::AsyncStore { .. }
| Node::AsyncWait { .. }
| Node::Trap { .. }
| Node::Resume { .. }
| Node::Opaque(_) => {}
}
}
false
})
}
#[must_use]
#[inline]
pub fn has_buffer(&self, name: &str) -> bool {
self.buffer_index.contains_key(name)
}
#[must_use]
#[inline]
pub fn buffer_count(&self) -> usize {
self.buffers.len()
}
#[inline]
pub(super) fn build_buffer_index(
buffers: &[super::BufferDecl],
) -> rustc_hash::FxHashMap<Arc<str>, usize> {
let mut index = rustc_hash::FxHashMap::default();
index.reserve(buffers.len());
for (buffer_index, buffer) in buffers.iter().enumerate() {
index
.entry(Arc::clone(&buffer.name))
.or_insert(buffer_index);
}
index
}
#[inline]
pub fn mark_structurally_validated(&self) {
self.structural_validated.store(true, Ordering::Release);
}
#[must_use]
#[inline]
pub fn is_structurally_validated(&self) -> bool {
self.structural_validated.load(Ordering::Acquire)
}
#[inline]
pub fn mark_validated_on(&self, backend_id: &str) {
self.validation_set
.insert(Arc::from(self.validation_cache_key(backend_id)));
}
#[must_use]
#[inline]
pub fn is_validated_on(&self, backend_id: &str) -> bool {
self.validation_set
.contains(self.validation_cache_key(backend_id).as_str())
}
#[deprecated(note = "use is_structurally_validated or is_validated_on")]
#[must_use]
#[inline]
pub fn is_validated(&self) -> bool {
self.is_structurally_validated()
}
#[deprecated(note = "use mark_structurally_validated or mark_validated_on")]
#[inline]
pub fn mark_validated(&self) {
self.mark_structurally_validated();
}
pub fn validate(&self) -> crate::error::Result<()> {
if self.is_structurally_validated() {
return Ok(());
}
let errors = crate::validate::validate(self);
if errors.is_empty() {
self.mark_structurally_validated();
return Ok(());
}
let mut message = String::new();
for (index, error) in errors.into_iter().enumerate() {
if index > 0 {
message.push_str("; ");
}
message.push_str(error.message());
}
Err(crate::error::Error::WireFormatValidation { message })
}
#[inline]
#[must_use]
pub fn estimate_peak_vram_bytes(&self) -> u64 {
self.buffers
.iter()
.map(|buffer| {
let element_size = buffer.element.size_bytes().unwrap_or(4);
(buffer.count as u64) * (element_size as u64)
})
.sum()
}
#[must_use]
pub fn peak_intensity(&self) -> OpIntensity {
let mut peak = OpIntensity::Free;
for node in self.entry().iter() {
peak = peak.max(self.node_intensity(node));
}
peak
}
fn node_intensity(&self, node: &crate::ir::Node) -> OpIntensity {
use crate::ir::Node;
match node {
Node::Let { value, .. } | Node::Assign { value, .. } => self.expr_intensity(value),
Node::Store { index, value, .. } => {
self.expr_intensity(index).max(self.expr_intensity(value))
}
Node::If {
cond,
then,
otherwise,
} => {
let mut p = self.expr_intensity(cond);
for n in then {
p = p.max(self.node_intensity(n));
}
for n in otherwise {
p = p.max(self.node_intensity(n));
}
p
}
Node::Loop { from, to, body, .. } => {
let mut p = self.expr_intensity(from).max(self.expr_intensity(to));
for n in body.iter() {
p = p.max(self.node_intensity(n));
}
p
}
Node::Block(nodes) => {
let mut p = OpIntensity::Free;
for n in nodes {
p = p.max(self.node_intensity(n));
}
p
}
Node::Region { body, .. } => {
let mut p = OpIntensity::Free;
for n in body.iter() {
p = p.max(self.node_intensity(n));
}
p
}
_ => OpIntensity::Free,
}
}
#[allow(clippy::only_used_in_recursion)]
fn expr_intensity(&self, expr: &crate::ir::Expr) -> OpIntensity {
use crate::ir::Expr;
match expr {
Expr::BinOp { op, left, right } => op
.intensity()
.max(self.expr_intensity(left))
.max(self.expr_intensity(right)),
Expr::UnOp { operand, .. } => self.expr_intensity(operand),
Expr::Load { index, .. } => self.expr_intensity(index),
Expr::Select {
cond,
true_val,
false_val,
} => self
.expr_intensity(cond)
.max(self.expr_intensity(true_val))
.max(self.expr_intensity(false_val)),
Expr::Cast { value, .. } => self.expr_intensity(value),
Expr::Fma { a, b, c } => self
.expr_intensity(a)
.max(self.expr_intensity(b))
.max(self.expr_intensity(c)),
Expr::Atomic {
index,
value,
expected,
..
} => {
let mut p = self.expr_intensity(index).max(self.expr_intensity(value));
if let Some(e) = expected {
p = p.max(self.expr_intensity(e));
}
p.max(OpIntensity::Heavy)
}
Expr::SubgroupBallot { cond } => self.expr_intensity(cond).max(OpIntensity::Heavy),
Expr::SubgroupShuffle { value, lane } => self
.expr_intensity(value)
.max(self.expr_intensity(lane))
.max(OpIntensity::Heavy),
Expr::SubgroupAdd { value } => self.expr_intensity(value).max(OpIntensity::Heavy),
_ => OpIntensity::Free,
}
}
fn compute_wire_hash(&self) -> blake3::Hash {
match self.canonical_wire_hash() {
Ok(hash) => hash,
Err(error) => {
let structural = self.structural_fingerprint_fallback();
let err_msg = error.to_string();
let mut fallback = Vec::with_capacity(96 + err_msg.len() + structural.len());
fallback.extend_from_slice(b"VYRE-PROGRAM-CANONICAL-WIRE-HASH-ERROR\0");
fallback.extend_from_slice(err_msg.as_bytes());
fallback.push(0);
fallback.extend_from_slice(structural.as_bytes());
blake3::hash(&fallback)
}
}
}
fn structural_fingerprint_fallback(&self) -> String {
let mut hasher = blake3::Hasher::new();
hasher.update(b"VYRE-WIRE-FALLBACK-V4\0");
if let Some(id) = self.entry_op_id.as_deref() {
hasher.update(id.as_bytes());
}
hasher.update(b"\0");
for axis in &self.workgroup_size {
hasher.update(&axis.to_le_bytes());
}
hasher.update(&[u8::from(self.non_composable_with_self)]);
let mut keys: Vec<Vec<u8>> = self
.buffers()
.iter()
.map(buffer_decl_canonical_key)
.collect();
keys.sort_unstable();
for key in keys {
hasher.update(&key);
}
let mut visitor = FallbackWireHasher(&mut hasher);
walk_nodes_and_exprs(self, &mut visitor);
hasher.finalize().to_hex().to_string()
}
fn validation_cache_key(&self, backend_id: &str) -> String {
let fingerprint = self.fingerprint();
let mut key = String::with_capacity(backend_id.len() + 1 + 64);
key.push_str(backend_id);
key.push(':');
for byte in fingerprint {
use std::fmt::Write as _;
let _ = write!(&mut key, "{byte:02x}");
}
key
}
#[inline]
pub(super) fn invalidate_caches(&mut self) {
self.structural_validated.store(false, Ordering::Release);
self.validation_set.clear();
let _ = self.hash.take();
let _ = self.fingerprint.take();
let _ = self.output_buffer_index.take();
let _ = self.has_indirect_dispatch.take();
let _ = self.stats.take();
}
#[inline]
pub(super) fn wrap_entry(entry: Vec<Node>) -> Vec<Node> {
if !Self::entry_needs_root_region(&entry) {
return entry;
}
vec![Node::Region {
generator: Ident::from(Self::ROOT_REGION_GENERATOR),
source_region: None,
body: Arc::new(entry),
}]
}
#[inline]
fn entry_needs_root_region(entry: &[Node]) -> bool {
entry.is_empty()
|| entry
.iter()
.any(|node| !matches!(node, Node::Region { .. }))
}
#[inline]
fn top_level_node_name(node: &Node) -> &'static str {
match node {
Node::Let { .. } => "Let",
Node::Assign { .. } => "Assign",
Node::Store { .. } => "Store",
Node::If { .. } => "If",
Node::Loop { .. } => "Loop",
Node::Return => "Return",
Node::Block(_) => "Block",
Node::Barrier { .. } => "Barrier",
Node::Region { .. } => "Region",
Node::IndirectDispatch { .. } => "IndirectDispatch",
Node::AsyncLoad { .. } => "AsyncLoad",
Node::AsyncStore { .. } => "AsyncStore",
Node::AsyncWait { .. } => "AsyncWait",
Node::Trap { .. } => "Trap",
Node::Resume { .. } => "Resume",
Node::Opaque(_) => "Opaque",
}
}
}
/// Returns `true` when both slices declare the same buffers, in any order.
///
/// Identical slices short-circuit. Otherwise each declaration is reduced to
/// its canonical byte key, both key lists are sorted, and the sorted lists
/// are compared, so duplicates must match in multiplicity too.
pub(crate) fn buffers_equal_ignoring_declaration_order(
    left: &[super::BufferDecl],
    right: &[super::BufferDecl],
) -> bool {
    match (left.len() == right.len(), left == right) {
        (false, _) => false,
        (true, true) => true,
        (true, false) => {
            let sorted_keys = |decls: &[super::BufferDecl]| -> Vec<Vec<u8>> {
                let mut keys: Vec<Vec<u8>> =
                    decls.iter().map(buffer_decl_canonical_key).collect();
                keys.sort_unstable();
                keys
            };
            sorted_keys(left) == sorted_keys(right)
        }
    }
}
/// Encodes a buffer declaration's identity-relevant fields into a byte key.
///
/// The key is built with the wire framing helpers. Sorted keys are used to
/// compare buffer lists regardless of declaration order, and to feed the
/// structural fallback fingerprint.
///
/// Encoding never fails. When a field cannot be framed, a sentinel and the
/// error text are written in its place, so the key stays deterministic.
pub(super) fn buffer_decl_canonical_key(buffer: &super::BufferDecl) -> Vec<u8> {
    use crate::serial::wire::framing::{put_len_u32, put_u32, put_u8};
    use crate::serial::wire::tags::put_data_type;

    // Writes a converted u32, or `u32::MAX` followed by the conversion error
    // text when the value did not fit.
    let put_u32_or_overflow = |out: &mut Vec<u8>, converted: Result<u32, String>| match converted {
        Ok(value) => put_u32(out, value),
        Err(reason) => {
            put_u32(out, u32::MAX);
            out.extend_from_slice(reason.as_bytes());
        }
    };

    let mut key = Vec::with_capacity(96);

    // Name: length-framed bytes (an error marker precedes them if the length
    // cannot be framed).
    if let Err(error) = put_len_u32(&mut key, buffer.name.len(), "buffer name length") {
        key.extend_from_slice(b"\0name-length-error\0");
        key.extend_from_slice(error.as_bytes());
    }
    key.extend_from_slice(buffer.name.as_bytes());

    put_u32(&mut key, buffer.binding);

    let access = crate::serial::wire::tags::access_tag::access_tag(buffer.access.clone());
    if let Err(error) = &access {
        put_u8(&mut key, u8::MAX);
        key.extend_from_slice(error.as_bytes());
    } else if let Ok(tag) = access {
        put_u8(&mut key, tag);
    }

    let kind_tag: u8 = match buffer.kind {
        super::MemoryKind::Global => 0,
        super::MemoryKind::Shared => 1,
        super::MemoryKind::Uniform => 2,
        super::MemoryKind::Local => 3,
        super::MemoryKind::Readonly => 4,
        super::MemoryKind::Persistent => 5,
        super::MemoryKind::Push => 6,
    };
    put_u8(&mut key, kind_tag);

    if let Err(error) = put_data_type(&mut key, &buffer.element) {
        key.extend_from_slice(b"\0dtype-error\0");
        key.extend_from_slice(error.as_bytes());
    }

    put_u32(&mut key, buffer.count);
    for flag in [buffer.is_output, buffer.pipeline_live_out] {
        put_u8(&mut key, u8::from(flag));
    }

    // Optional output byte range: presence byte, then start and end.
    if let Some(range) = &buffer.output_byte_range {
        put_u8(&mut key, 1);
        put_u32_or_overflow(&mut key, u32::try_from(range.start).map_err(|e| e.to_string()));
        put_u32_or_overflow(&mut key, u32::try_from(range.end).map_err(|e| e.to_string()));
    } else {
        put_u8(&mut key, 0);
    }

    // Optional coalesce axis: presence byte, then the axis.
    if let Some(axis) = buffer.hints.coalesce_axis {
        put_u8(&mut key, 1);
        put_u8(&mut key, axis);
    } else {
        put_u8(&mut key, 0);
    }

    put_u32(&mut key, buffer.hints.preferred_alignment);

    let locality_tag: u8 = match buffer.hints.cache_locality {
        super::CacheLocality::Streaming => 0,
        super::CacheLocality::Temporal => 1,
        super::CacheLocality::Random => 2,
    };
    put_u8(&mut key, locality_tag);

    put_u8(&mut key, u8::from(buffer.bytes_extraction));
    key
}