use crate::ir::{Expr, Node, Program};
use crate::optimizer::RefusalReason;
use vyre_spec::op_contract::SideEffectClass;
/// Memory ordering attached to an atomic read/write effect.
///
/// Orderings are combined with [`AtomicOrdering::join`]. The enum is
/// `#[non_exhaustive]`, so code outside this crate must keep a wildcard arm
/// when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AtomicOrdering {
    /// Acquire ordering; joins with [`AtomicOrdering::Release`] to `AcqRel`.
    Acquire,
    /// Release ordering; joins with [`AtomicOrdering::Acquire`] to `AcqRel`.
    Release,
    /// Combined acquire/release ordering.
    AcqRel,
    /// Sequentially consistent ordering; absorbs every other ordering on join.
    SeqCst,
}

impl AtomicOrdering {
    /// Returns the ordering that covers both `self` and `other`.
    ///
    /// Equal inputs are returned unchanged, `SeqCst` absorbs everything, and
    /// any other mix (including `Acquire` with `Release`) yields `AcqRel`.
    /// The operation is symmetric in its arguments.
    #[must_use]
    pub fn join(self, other: Self) -> Self {
        match (self, other) {
            // Identical one-sided orderings stay as they are.
            (Self::Acquire, Self::Acquire) => Self::Acquire,
            (Self::Release, Self::Release) => Self::Release,
            // SeqCst is the top element.
            (Self::SeqCst, _) | (_, Self::SeqCst) => Self::SeqCst,
            // Every remaining pair needs both halves of the guarantee.
            (
                Self::Acquire | Self::Release | Self::AcqRel,
                Self::Acquire | Self::Release | Self::AcqRel,
            ) => Self::AcqRel,
        }
    }
}
/// Scope covered by a synchronizing effect.
///
/// The declaration order is load-bearing: the derived `Ord` ranks
/// `Subgroup < Workgroup < Grid`, and [`SyncScope::join`] relies on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SyncScope {
    /// Smallest scope in the derived ordering.
    Subgroup,
    /// Middle scope in the derived ordering.
    Workgroup,
    /// Largest scope in the derived ordering.
    Grid,
}

impl SyncScope {
    /// Returns the larger of the two scopes under the derived ordering.
    #[must_use]
    pub fn join(self, other: Self) -> Self {
        if other >= self {
            other
        } else {
            self
        }
    }
}
/// Effect classification assigned to IR nodes, expressions and programs.
///
/// Levels are combined with [`lattice_join`] (which always succeeds) and
/// [`compose`] (which can refuse a producer/consumer pairing).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EffectLevel {
    /// No effect was detected.
    Pure,
    /// Reads memory; this is what `SideEffectClass::ReadsMemory` lifts to.
    ReadAtomic,
    /// Writes or atomically updates memory under the given ordering.
    ReadWriteAtomic(AtomicOrdering),
    /// Synchronizes at the given scope.
    Synchronized(SyncScope),
    /// Branches on an invocation id; [`compose`] only accepts this level
    /// next to a `Synchronized(SyncScope::Grid)` partner.
    Diverging,
}

impl EffectLevel {
    /// Lifts a coarse [`SideEffectClass`] into the lattice.
    ///
    /// Writing and atomic classes both become `ReadWriteAtomic(SeqCst)`.
    /// Any class not listed explicitly falls back to workgroup-scope
    /// synchronization, the same level as `SideEffectClass::Synchronizing`.
    #[must_use]
    pub fn from_class(class: SideEffectClass) -> Self {
        match class {
            SideEffectClass::Pure => Self::Pure,
            SideEffectClass::ReadsMemory => Self::ReadAtomic,
            SideEffectClass::WritesMemory | SideEffectClass::Atomic => {
                Self::ReadWriteAtomic(AtomicOrdering::SeqCst)
            }
            SideEffectClass::Synchronizing => Self::Synchronized(SyncScope::Workgroup),
            // Fallback for classes this function does not name.
            _ => Self::Synchronized(SyncScope::Workgroup),
        }
    }

    /// Returns the stable snake_case label for this level.
    ///
    /// The payload (ordering or scope) is not part of the label. [`compose`]
    /// uses these labels in the refusal it reports.
    #[must_use]
    pub fn kind(&self) -> &'static str {
        match *self {
            Self::Pure => "pure",
            Self::ReadAtomic => "read_atomic",
            Self::ReadWriteAtomic(_) => "read_write_atomic",
            Self::Synchronized(_) => "synchronized",
            Self::Diverging => "diverging",
        }
    }
}
pub fn compose(producer: EffectLevel, consumer: EffectLevel) -> Result<EffectLevel, RefusalReason> {
use EffectLevel::{Diverging, Pure, ReadAtomic, ReadWriteAtomic, Synchronized};
if matches!(producer, Diverging) {
if matches!(consumer, Synchronized(SyncScope::Grid)) {
return Ok(Synchronized(SyncScope::Grid));
}
return Err(RefusalReason::EffectLatticeViolation {
producer: "diverging",
consumer: consumer.kind(),
suggested_fix:
"insert Node::Barrier { ordering: MemoryOrdering::GridSync } between the \
divergent producer and the consumer; without it, threads in non-zero blocks \
race on the producer's stores and read stale memory",
});
}
if matches!(consumer, Diverging) {
if matches!(producer, Synchronized(SyncScope::Grid)) {
return Ok(Diverging);
}
return Err(RefusalReason::EffectLatticeViolation {
producer: producer.kind(),
consumer: "diverging",
suggested_fix:
"insert Node::Barrier { ordering: MemoryOrdering::GridSync } between the \
producer and the divergent consumer so all blocks observe the producer's \
writes before the divergent arm reads them",
});
}
let combined = match (producer, consumer) {
(Pure, c) => c,
(p, Pure) => p,
(ReadAtomic, ReadAtomic) => ReadAtomic,
(ReadAtomic, ReadWriteAtomic(o)) | (ReadWriteAtomic(o), ReadAtomic) => ReadWriteAtomic(o),
(ReadWriteAtomic(o1), ReadWriteAtomic(o2)) => ReadWriteAtomic(o1.join(o2)),
(Synchronized(s), ReadAtomic) | (ReadAtomic, Synchronized(s)) => Synchronized(s),
(Synchronized(s), ReadWriteAtomic(_)) | (ReadWriteAtomic(_), Synchronized(s)) => {
Synchronized(s)
}
(Synchronized(s1), Synchronized(s2)) => Synchronized(s1.join(s2)),
(Diverging, _) | (_, Diverging) => unreachable!("Diverging handled above"),
};
Ok(combined)
}
/// Computes the effect level of a whole program.
///
/// Joins the level of every node returned by `program.entry()` with
/// [`lattice_join`], starting from `EffectLevel::Pure`. An empty entry
/// therefore yields `Pure`.
#[must_use]
pub fn program_effect_level(program: &Program) -> EffectLevel {
    program
        .entry()
        .iter()
        .fold(EffectLevel::Pure, |joined, node| {
            lattice_join(joined, node_effect_level(node))
        })
}
#[must_use]
pub fn lattice_join(a: EffectLevel, b: EffectLevel) -> EffectLevel {
use EffectLevel::*;
match (a, b) {
(Diverging, _) | (_, Diverging) => Diverging,
(Synchronized(s1), Synchronized(s2)) => Synchronized(s1.join(s2)),
(Synchronized(s), _) | (_, Synchronized(s)) => Synchronized(s),
(ReadWriteAtomic(o1), ReadWriteAtomic(o2)) => ReadWriteAtomic(o1.join(o2)),
(ReadWriteAtomic(o), _) | (_, ReadWriteAtomic(o)) => ReadWriteAtomic(o),
(ReadAtomic, _) | (_, ReadAtomic) => ReadAtomic,
(Pure, Pure) => Pure,
}
}
/// Classifies a single IR node.
///
/// * `Store` is `ReadWriteAtomic(SeqCst)`; `AsyncStore` is
///   `ReadWriteAtomic(Release)`; `AsyncLoad` is `ReadAtomic`.
/// * `Let` / `Assign` take the level of their value expression.
/// * `AsyncWait` is workgroup-scope sync; `IndirectDispatch` is grid-scope
///   sync; `Barrier` depends on its ordering.
/// * An `If` whose condition compares an invocation id against a `u32`
///   literal is `Diverging`; any other `If` is the join of both arms.
/// * `Loop`, `Block` and `Region` are the join of their bodies.
/// * `Trap`, `Resume`, `Return` and `Opaque` are `Pure`.
///
/// NOTE(review): only the child nodes of `If`/`Loop`/`Region` are inspected.
/// The `If` condition and any other operand expressions are not folded into
/// the result. Confirm that the IR guarantees those expressions are
/// effect-free.
/// NOTE(review): `Opaque` is treated as `Pure` here; verify that this is the
/// intended (non-conservative) choice.
#[must_use]
pub fn node_effect_level(node: &Node) -> EffectLevel {
    match node {
        // Direct memory traffic.
        Node::Store { .. } => EffectLevel::ReadWriteAtomic(AtomicOrdering::SeqCst),
        Node::AsyncStore { .. } => EffectLevel::ReadWriteAtomic(AtomicOrdering::Release),
        Node::AsyncLoad { .. } => EffectLevel::ReadAtomic,
        Node::Let { value, .. } | Node::Assign { value, .. } => expr_effect_level(value),

        // Synchronization points.
        Node::AsyncWait { .. } => EffectLevel::Synchronized(SyncScope::Workgroup),
        Node::IndirectDispatch { .. } => EffectLevel::Synchronized(SyncScope::Grid),
        Node::Barrier { ordering } => barrier_effect(*ordering),

        // Structured control flow.
        Node::If { cond, .. } if is_invocation_id_eq_constant(cond) => EffectLevel::Diverging,
        Node::If {
            then, otherwise, ..
        } => join_arms(then.iter().chain(otherwise.iter())),
        Node::Loop { body, .. } => join_arms(body.iter()),
        Node::Block(body) => join_arms(body.iter()),
        Node::Region { body, .. } => join_arms(body.iter()),

        // Nodes that are classified as effect-free.
        Node::Trap { .. } | Node::Resume { .. } | Node::Return | Node::Opaque(_) => {
            EffectLevel::Pure
        }
    }
}
fn join_arms<'a>(nodes: impl IntoIterator<Item = &'a Node>) -> EffectLevel {
let mut acc = EffectLevel::Pure;
for child in nodes {
acc = lattice_join(acc, node_effect_level(child));
}
acc
}
/// Classifies an expression by its outermost variant only.
///
/// `Load` is `ReadAtomic`, `Atomic` is `ReadWriteAtomic(SeqCst)`, and every
/// other variant is `Pure`.
///
/// NOTE(review): sub-expressions are not visited, so a load nested inside
/// another expression is reported as `Pure`. Confirm this is intended.
fn expr_effect_level(expr: &Expr) -> EffectLevel {
    match expr {
        Expr::Load { .. } => EffectLevel::ReadAtomic,
        Expr::Atomic { .. } => EffectLevel::ReadWriteAtomic(AtomicOrdering::SeqCst),
        _ => EffectLevel::Pure,
    }
}
/// Maps a barrier's memory ordering to an effect level.
///
/// `GridSync` is grid-scope sync and `Relaxed` is `ReadWriteAtomic(Acquire)`.
/// Every other ordering, including any not named here, is workgroup-scope
/// sync.
// The wildcard arm is kept as a default for orderings not named below. It is
// presumably unreachable when every `MemoryOrdering` variant is listed, hence
// the lint allowance.
#[allow(unreachable_patterns)]
fn barrier_effect(ordering: crate::memory_model::MemoryOrdering) -> EffectLevel {
    use crate::memory_model::MemoryOrdering;

    match ordering {
        MemoryOrdering::GridSync => EffectLevel::Synchronized(SyncScope::Grid),
        MemoryOrdering::Relaxed => EffectLevel::ReadWriteAtomic(AtomicOrdering::Acquire),
        MemoryOrdering::Acquire
        | MemoryOrdering::Release
        | MemoryOrdering::AcqRel
        | MemoryOrdering::SeqCst => EffectLevel::Synchronized(SyncScope::Workgroup),
        _ => EffectLevel::Synchronized(SyncScope::Workgroup),
    }
}
/// Reports whether `cond` compares an invocation id against a `u32` literal.
///
/// Matches `Eq` and `Ne` comparisons with the id on either side and an
/// `Expr::LitU32` on the other. Every other expression, including any other
/// binary operator, returns `false`.
fn is_invocation_id_eq_constant(cond: &Expr) -> bool {
    use crate::ir::BinOp;

    if let Expr::BinOp {
        op: BinOp::Eq | BinOp::Ne,
        left,
        right,
    } = cond
    {
        let id_on_left = is_invocation_id_expr(left) && matches!(**right, Expr::LitU32(_));
        id_on_left || (is_invocation_id_expr(right) && matches!(**left, Expr::LitU32(_)))
    } else {
        false
    }
}
/// Reports whether `expr` is one of the id expressions treated as an
/// invocation id: `InvocationId`, `LocalId` or `SubgroupLocalId`.
fn is_invocation_id_expr(expr: &Expr) -> bool {
    matches!(
        expr,
        Expr::SubgroupLocalId | Expr::LocalId { .. } | Expr::InvocationId { .. }
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BinOp, BufferAccess, BufferDecl, DataType, Expr, Node, Program};
    use crate::memory_model::MemoryOrdering;

    /// Read/write `u32` storage buffer named `buf` with a count of 4.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Program whose only node is `if gid_x == 0 { store buf[0] = 1 }`.
    fn divergent_store_program() -> Program {
        let first_invocation_only = Expr::BinOp {
            op: BinOp::Eq,
            left: Box::new(Expr::gid_x()),
            right: Box::new(Expr::u32(0)),
        };
        let guarded_store = Node::if_then(
            first_invocation_only,
            vec![Node::store("buf", Expr::u32(0), Expr::u32(1))],
        );
        Program::wrapped(vec![buf()], [256, 1, 1], vec![guarded_store])
    }

    #[test]
    fn pure_composes_with_pure() {
        assert_eq!(
            compose(EffectLevel::Pure, EffectLevel::Pure),
            Ok(EffectLevel::Pure)
        );
    }

    #[test]
    fn pure_then_diverging_refuses_with_grid_sync_fix() {
        match compose(EffectLevel::Pure, EffectLevel::Diverging) {
            Err(RefusalReason::EffectLatticeViolation {
                producer,
                consumer,
                suggested_fix,
            }) => {
                assert_eq!(producer, "pure");
                assert_eq!(consumer, "diverging");
                assert!(
                    suggested_fix.contains("GridSync"),
                    "fix string must name MemoryOrdering::GridSync; got: {suggested_fix}"
                );
            }
            other => panic!(
                "expected EffectLatticeViolation refusing to fuse Pure with Diverging consumer; \
                 got {other:?}"
            ),
        }
    }

    #[test]
    fn diverging_then_pure_refuses_with_grid_sync_fix() {
        match compose(EffectLevel::Diverging, EffectLevel::Pure) {
            Err(RefusalReason::EffectLatticeViolation {
                producer,
                consumer,
                suggested_fix,
            }) => {
                assert_eq!(producer, "diverging");
                assert_eq!(consumer, "pure");
                assert!(
                    suggested_fix.contains("GridSync"),
                    "fix string must name MemoryOrdering::GridSync; got: {suggested_fix}"
                );
            }
            other => panic!(
                "expected EffectLatticeViolation refusing to fuse Diverging producer with Pure; \
                 got {other:?}"
            ),
        }
    }

    #[test]
    fn diverging_followed_by_grid_sync_composes() {
        let grid_sync = EffectLevel::Synchronized(SyncScope::Grid);
        assert_eq!(compose(EffectLevel::Diverging, grid_sync), Ok(grid_sync));
    }

    #[test]
    fn read_write_atomic_compose_joins_orderings() {
        let acquire = EffectLevel::ReadWriteAtomic(AtomicOrdering::Acquire);
        let release = EffectLevel::ReadWriteAtomic(AtomicOrdering::Release);
        assert_eq!(
            compose(acquire, release),
            Ok(EffectLevel::ReadWriteAtomic(AtomicOrdering::AcqRel)),
            "Acquire ∘ Release must synthesize AcqRel — without this the lattice would lose \
             the Release-side guarantee"
        );
    }

    #[test]
    fn synchronized_compose_joins_to_larger_scope() {
        let workgroup = EffectLevel::Synchronized(SyncScope::Workgroup);
        let grid = EffectLevel::Synchronized(SyncScope::Grid);
        assert_eq!(
            compose(workgroup, grid),
            Ok(EffectLevel::Synchronized(SyncScope::Grid)),
            "Workgroup ∘ Grid must escalate to Grid; the smaller scope is absorbed"
        );
    }

    #[test]
    fn from_class_lifts_every_existing_side_effect_class() {
        assert_eq!(
            EffectLevel::from_class(SideEffectClass::Pure),
            EffectLevel::Pure
        );
        assert_eq!(
            EffectLevel::from_class(SideEffectClass::ReadsMemory),
            EffectLevel::ReadAtomic
        );

        let writes = EffectLevel::from_class(SideEffectClass::WritesMemory);
        assert!(matches!(writes, EffectLevel::ReadWriteAtomic(_)));

        let synchronizing = EffectLevel::from_class(SideEffectClass::Synchronizing);
        assert!(matches!(synchronizing, EffectLevel::Synchronized(_)));

        let atomic = EffectLevel::from_class(SideEffectClass::Atomic);
        assert!(matches!(atomic, EffectLevel::ReadWriteAtomic(_)));
    }

    #[test]
    fn program_effect_level_detects_divergent_store_pattern() {
        let program = divergent_store_program();
        assert_eq!(
            program_effect_level(&program),
            EffectLevel::Diverging,
            "a program containing `if invocation_id == K {{ store ... }}` must surface as \
             Diverging at the program level — without this the fusion-refusal pass cannot \
             catch the recall-zero-past-block-zero shape"
        );
    }

    #[test]
    fn program_effect_level_pure_program_stays_pure() {
        let program = Program::wrapped(vec![buf()], [1, 1, 1], vec![Node::Return]);
        assert_eq!(
            program_effect_level(&program),
            EffectLevel::Pure,
            "a pure program (just Return) must stay Pure — without this every pass would \
             see a stronger effect than the program actually has"
        );
    }

    #[test]
    fn barrier_grid_sync_node_lifts_to_synchronized_grid() {
        let barrier = Node::Barrier {
            ordering: MemoryOrdering::GridSync,
        };
        assert_eq!(
            node_effect_level(&barrier),
            EffectLevel::Synchronized(SyncScope::Grid)
        );
    }

    #[test]
    fn store_node_lifts_to_read_write_atomic() {
        let store = Node::store("buf", Expr::u32(0), Expr::u32(7));
        let level = node_effect_level(&store);
        assert!(matches!(level, EffectLevel::ReadWriteAtomic(_)));
    }

    #[test]
    fn divergent_program_paired_with_grid_sync_composes_cleanly() {
        let divergent = program_effect_level(&divergent_store_program());
        let grid_sync = EffectLevel::Synchronized(SyncScope::Grid);
        assert_eq!(
            compose(divergent, grid_sync),
            Ok(EffectLevel::Synchronized(SyncScope::Grid))
        );
    }

    #[test]
    fn read_atomic_then_pure_stays_read_atomic() {
        assert_eq!(
            compose(EffectLevel::ReadAtomic, EffectLevel::Pure),
            Ok(EffectLevel::ReadAtomic)
        );
    }

    #[test]
    fn synchronized_then_synchronized_subgroup_does_not_swallow_grid() {
        let grid = EffectLevel::Synchronized(SyncScope::Grid);
        let subgroup = EffectLevel::Synchronized(SyncScope::Subgroup);
        assert_eq!(
            compose(grid, subgroup),
            Ok(EffectLevel::Synchronized(SyncScope::Grid)),
            "Grid scope MUST dominate Subgroup scope on join — losing the larger scope would \
             silently downgrade the program's synchronization guarantee"
        );
    }
}