use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Node, Program};
use crate::math::scallop_persistent::{
ceil_div_u32, single_word_lineage_body, single_word_lineage_grid_sync_body,
};
#[cfg(any(test, feature = "cpu-parity"))]
use crate::math::semiring_gemm::{semiring_gemm_cpu_into, Semiring};
/// Stable operation identifier for the Scallop lineage join. Used as the op id of
/// invalid-input trap programs, as the `Region` generator name of the built
/// program, and as the harness registration key.
pub const OP_ID: &str = "vyre-primitives::math::scallop_join";
/// Workgroup size of every `scallop_join` program. Relations whose `n * n` cell
/// count is at most `SCALLOP_JOIN_WORKGROUP_SIZE[0]` use the single-workgroup
/// body; larger relations use the grid-sync body.
pub const SCALLOP_JOIN_WORKGROUP_SIZE: [u32; 3] = [256, 1, 1];
/// Dispatch grid for a `scallop_join` program over an `n × n` relation.
///
/// The x dimension is `ceil(n * n / SCALLOP_JOIN_WORKGROUP_SIZE[0])` workgroups,
/// which gives at least one invocation per relation cell. `n * n` saturates at
/// `u32::MAX` instead of wrapping. The result is clamped to at least one
/// workgroup, so `n == 0` still yields `[1, 1, 1]`.
#[must_use]
pub const fn scallop_join_dispatch_grid(n: u32) -> [u32; 3] {
    let cells = n.saturating_mul(n);
    let blocks = ceil_div_u32(cells, SCALLOP_JOIN_WORKGROUP_SIZE[0]);
    // Never report an empty grid; `Ord::max` is not usable in a `const fn`, hence the `if`.
    [if blocks == 0 { 1 } else { blocks }, 1, 1]
}
/// Path of the `vyre-libs` self-substrate module (`scallop_provenance`) that,
/// judging by its name, consumes this primitive.
/// NOTE(review): not referenced anywhere in this file; confirm how callers or
/// registries use it.
pub const PROVENANCE_SELF_CONSUMER: &str = "vyre-libs::self_substrate::scallop_provenance";
/// Build the GPU program that iterates a Scallop-style lineage join over an
/// `n × n` relation until it reaches a fixpoint or runs `max_iterations` rounds.
///
/// Buffer bindings (all `u32`):
/// - `state`: binding 0, read-write, `n * n` words
/// - `next`: binding 1, read-write, `n * n` words
/// - `changed`: binding 2, read-write, 1 word
/// - `join_rules`: binding 3, read-only, `n * n` words
///
/// Relations whose `n * n` cells fit in one workgroup
/// (`<= SCALLOP_JOIN_WORKGROUP_SIZE[0]`) use the single-workgroup body. Larger
/// relations use the grid-sync body; [`scallop_join_dispatch_grid`] computes a
/// grid of matching size.
///
/// Invalid arguments never panic. Each of the following yields a trapping
/// program built by `crate::invalid_output_program`, with a `Fix:` message:
/// - `n == 0`
/// - `max_iterations == 0`
/// - an `n` whose `n * n` word count overflows `u32`
#[must_use]
pub fn scallop_join(
    state: &str,
    next: &str,
    join_rules: &str,
    changed: &str,
    n: u32,
    max_iterations: u32,
) -> Program {
    if n == 0 {
        return crate::invalid_output_program(
            OP_ID,
            state,
            DataType::U32,
            format!("Fix: scallop_join requires n > 0, got {n}."),
        );
    }
    if max_iterations == 0 {
        return crate::invalid_output_program(
            OP_ID,
            state,
            DataType::U32,
            "Fix: scallop_join requires max_iterations > 0, got 0.".to_string(),
        );
    }
    // Report an oversized relation through the same trap channel as the other
    // argument checks, rather than panicking inside a program builder.
    let words = match n.checked_mul(n) {
        Some(words) => words,
        None => {
            return crate::invalid_output_program(
                OP_ID,
                state,
                DataType::U32,
                format!(
                    "Fix: scallop_join requires n*n to fit in u32, got n={n}; shard the relation matrix before GPU dispatch."
                ),
            );
        }
    };
    // A relation that fits in one workgroup needs no cross-workgroup synchronisation.
    let body = if words <= SCALLOP_JOIN_WORKGROUP_SIZE[0] {
        single_word_lineage_body(
            state,
            next,
            join_rules,
            changed,
            n,
            words,
            max_iterations,
            SCALLOP_JOIN_WORKGROUP_SIZE[0],
        )
    } else {
        single_word_lineage_grid_sync_body(
            state,
            next,
            join_rules,
            changed,
            n,
            words,
            max_iterations,
        )
    };
    let entry: Vec<Node> = vec![Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(body),
    }];
    Program::wrapped(
        vec![
            BufferDecl::storage(state, 0, BufferAccess::ReadWrite, DataType::U32).with_count(words),
            BufferDecl::storage(next, 1, BufferAccess::ReadWrite, DataType::U32).with_count(words),
            BufferDecl::storage(changed, 2, BufferAccess::ReadWrite, DataType::U32).with_count(1),
            BufferDecl::storage(join_rules, 3, BufferAccess::ReadOnly, DataType::U32)
                .with_count(words),
        ],
        SCALLOP_JOIN_WORKGROUP_SIZE,
        entry,
    )
}
/// Host-side reference for [`scallop_join`]: runs the lineage join to a fixpoint.
///
/// Convenience wrapper over [`cpu_ref_into`] that owns its scratch buffers.
/// Returns `(final_state, rounds)`, where `rounds` is what [`cpu_ref_into`] reports.
///
/// # Panics
/// Panics under the same conditions as [`cpu_ref_into`].
#[cfg(any(test, feature = "cpu-parity"))]
#[must_use]
pub fn cpu_ref(state: &[u32], join_rules: &[u32], n: u32, max_iterations: u32) -> (Vec<u32>, u32) {
    let (mut fixpoint, mut scratch) = (Vec::new(), Vec::new());
    let rounds = cpu_ref_into(state, join_rules, n, max_iterations, &mut fixpoint, &mut scratch);
    (fixpoint, rounds)
}
/// Allocation-reusing host reference for [`scallop_join`].
///
/// The algorithm:
/// 1. Copy `state` into `current`, reusing its allocation.
/// 2. Each round, run a [`Semiring::Lineage`] GEMM of `current` and `join_rules`
///    into `next`.
/// 3. OR any newly derived bits from `next` into `current`.
/// 4. Stop at the first round that adds no new bits.
///
/// Returns the zero-based index of the first round that added nothing. If every
/// round added something, returns `max_iterations`. On return, `current` holds
/// the final state and `next` holds the output of the last GEMM call.
///
/// # Panics
/// Panics in any of these cases:
/// - `n * n` overflows `usize`
/// - `state.len()` is not `n * n`
/// - `join_rules.len()` is not `n * n`
#[cfg(any(test, feature = "cpu-parity"))]
pub fn cpu_ref_into(
    state: &[u32],
    join_rules: &[u32],
    n: u32,
    max_iterations: u32,
    current: &mut Vec<u32>,
    next: &mut Vec<u32>,
) -> u32 {
    let cells = match usize::try_from(n).ok().and_then(|side| side.checked_mul(side)) {
        Some(cells) => cells,
        None => panic!(
            "scallop_join CPU oracle n={n} overflows relation matrix word count. Fix: shard the relation matrix before parity comparison."
        ),
    };
    assert_eq!(
        state.len(),
        cells,
        "scallop_join CPU oracle received state_len={} for n={n}. Fix: pass a complete n*n state matrix before parity comparison.",
        state.len()
    );
    assert_eq!(
        join_rules.len(),
        cells,
        "scallop_join CPU oracle received join_rules_len={} for n={n}. Fix: pass a complete n*n rule matrix before parity comparison.",
        join_rules.len()
    );

    // Keep the caller's allocation rather than building a fresh Vec.
    current.clear();
    current.extend_from_slice(state);

    for round in 0..max_iterations {
        semiring_gemm_cpu_into(current, join_rules, n, n, n, Semiring::Lineage, next);
        let mut grew = false;
        for (dst, &src) in current.iter_mut().zip(next.iter()) {
            // Only bits not already present count as progress.
            let fresh = src & !*dst;
            if fresh != 0 {
                *dst |= fresh;
                grew = true;
            }
        }
        if !grew {
            return round;
        }
    }
    max_iterations
}
// Registers `scallop_join` with the op harness; compiled only with the
// `inventory-registry` feature.
// NOTE(review): `OpEntry::new`'s arguments are assumed to be
// (id, program builder, input fixtures, expected outputs); confirm against
// `crate::harness`.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        // 2×2 relation, at most 4 fixpoint rounds.
        || scallop_join("state", "next", "join_rules", "changed", 2, 4),
        Some(|| {
            let to_bytes = |w: &[u32]| crate::wire::pack_u32_slice(w);
            // One fixture. Buffers are listed in binding order:
            // state (0), next (1), changed (2), join_rules (3).
            vec![vec![
                to_bytes(&[0, 0b01, 0, 0]),
                to_bytes(&[0, 0, 0, 0]),
                to_bytes(&[0]),
                to_bytes(&[0, 0, 0, 0b10]),
            ]]
        }),
        Some(|| {
            let to_bytes = |w: &[u32]| crate::wire::pack_u32_slice(w);
            // Presumably the read-write buffers after the run, in binding order:
            // state, next, changed (the read-only join_rules is omitted).
            vec![vec![
                to_bytes(&[0, 0b11, 0, 0]),
                to_bytes(&[0, 0b11, 0, 0]),
                to_bytes(&[0]),
            ]]
        }),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::ir::Node;
    use vyre_foundation::MemoryOrdering;

    /// Row-major index of `(row, col)` in an `n × n` relation matrix.
    fn idx(n: usize, row: usize, col: usize) -> usize {
        row * n + col
    }

    #[test]
    fn cpu_ref_one_step_join() {
        // state[0,1] seeded with clause 0; rule[1,1] carries clause 1.
        let state = vec![0u32, 0b01, 0u32, 0u32];
        let join_rules = vec![0u32, 0u32, 0u32, 0b10];
        let (final_state, iters) = cpu_ref(&state, &join_rules, 2, 16);
        assert_eq!(
            final_state[1] & 0b10,
            0b10,
            "Lineage of clause 1 must propagate to state[0,1] after one round"
        );
        assert_eq!(
            final_state[1] & 0b01,
            0b01,
            "seed clause 0 must persist through the fixpoint"
        );
        assert!(
            iters <= 4,
            "small system should converge quickly, got {iters}"
        );
    }

    #[test]
    fn cpu_ref_converges_on_idempotent_input() {
        let state = vec![0b01, 0u32, 0u32, 0b01];
        let join_rules = vec![0u32; 4];
        let (final_state, iters) = cpu_ref(&state, &join_rules, 2, 16);
        assert_eq!(
            final_state, state,
            "idempotent system must not change state"
        );
        assert!(iters <= 2, "idempotent system converges in ≤ 2 iters");
    }

    #[test]
    fn cpu_ref_into_reuses_buffers() {
        let state = vec![0u32, 0b01, 0u32, 0u32];
        let join_rules = vec![0u32, 0u32, 0u32, 0b10];
        let mut current = Vec::with_capacity(128);
        let mut next = Vec::with_capacity(128);
        let current_ptr = current.as_ptr();
        let next_ptr = next.as_ptr();
        let iters = cpu_ref_into(&state, &join_rules, 2, 16, &mut current, &mut next);
        assert!(iters <= 4);
        assert_eq!(current[1] & 0b11, 0b11);
        // Pre-sized buffers must not be reallocated by the oracle.
        assert_eq!(current.as_ptr(), current_ptr);
        assert_eq!(next.as_ptr(), next_ptr);
    }

    #[test]
    fn cpu_ref_transitive_closure() {
        let mut state = vec![0u32; 9];
        state[idx(3, 0, 1)] = 0b001; // edge 0→1, clause 0
        state[idx(3, 1, 2)] = 0b010; // edge 1→2, clause 1
        let join_rules = state.clone();
        let (final_state, iters) = cpu_ref(&state, &join_rules, 3, 16);
        assert_eq!(
            final_state[idx(3, 0, 2)] & 0b011,
            0b011,
            "transitive 0→2 must collect lineage of both edges; got 0x{:x}",
            final_state[idx(3, 0, 2)]
        );
        assert!(iters <= 8, "3-node chain should converge fast");
    }

    #[test]
    fn cpu_ref_zero_absorbing_no_phantom_lineage() {
        let state = vec![0u32; 4]; // no seed facts
        let join_rules = vec![0b11u32; 4]; // every rule present
        let (final_state, _) = cpu_ref(&state, &join_rules, 2, 16);
        assert_eq!(
            final_state, state,
            "no seed facts → no derivations regardless of rule set; \
             zero-absorbing combine prevents phantom lineage"
        );
    }

    #[test]
    fn program_declares_four_buffers() {
        let p = scallop_join("s", "n", "j", "c", 2, 4);
        let bufs = p.buffers();
        assert_eq!(bufs.len(), 4, "scallop_join must declare 4 buffers");
        assert_eq!(p.workgroup_size(), SCALLOP_JOIN_WORKGROUP_SIZE);
        let names: Vec<&str> = bufs.iter().map(|b| b.name()).collect();
        assert!(names.contains(&"s"));
        assert!(names.contains(&"n"));
        assert!(names.contains(&"j"));
        assert!(names.contains(&"c"));
    }

    #[test]
    fn dispatch_grid_scales_large_relations_into_blocks() {
        assert_eq!(scallop_join_dispatch_grid(0), [1, 1, 1]);
        assert_eq!(scallop_join_dispatch_grid(1), [1, 1, 1]);
        assert_eq!(scallop_join_dispatch_grid(16), [1, 1, 1]); // 256 cells
        assert_eq!(scallop_join_dispatch_grid(17), [2, 1, 1]); // 289 cells
        assert_eq!(scallop_join_dispatch_grid(33), [5, 1, 1]); // 1089 cells
    }

    #[test]
    fn large_program_uses_split_visible_grid_sync() {
        // 17 * 17 = 289 cells exceeds one workgroup, so the grid-sync body is used.
        let p = scallop_join("s", "n", "j", "c", 17, 4);
        assert_eq!(count_grid_sync(p.entry()), 7);
    }

    #[test]
    fn rejects_zero_n_with_trap() {
        let p = scallop_join("s", "n", "j", "c", 0, 4);
        assert!(p.stats().trap());
    }

    #[test]
    fn rejects_zero_max_iterations_with_trap() {
        let p = scallop_join("s", "n", "j", "c", 2, 0);
        assert!(p.stats().trap());
    }

    /// Recursively count `GridSync` barriers in a node tree.
    fn count_grid_sync(nodes: &[Node]) -> usize {
        nodes
            .iter()
            .map(|node| match node {
                Node::Barrier {
                    ordering: MemoryOrdering::GridSync,
                } => 1,
                Node::If {
                    then, otherwise, ..
                } => count_grid_sync(then) + count_grid_sync(otherwise),
                Node::Loop { body, .. } | Node::Block(body) => count_grid_sync(body),
                Node::Region { body, .. } => count_grid_sync(body),
                _ => 0,
            })
            .sum()
    }
}