use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Node, Program};
use crate::math::semiring_gemm::{semiring_gemm, semiring_gemm_cpu_into, Semiring};
/// Identifier of this operation. Within this file it tags the `Node::Region`
/// emitted by `scallop_join`, is passed to `crate::invalid_output_program` for
/// rejected arguments, and keys the harness `OpEntry` registration.
pub const OP_ID: &str = "vyre-primitives::math::scallop_join";
/// Identifier string for `vyre-libs::self_substrate::scallop_provenance`.
///
/// NOTE(review): not referenced in the visible part of this file — presumably
/// names the downstream consumer of this primitive; confirm against callers.
pub const PROVENANCE_SELF_CONSUMER: &str = "vyre-libs::self_substrate::scallop_provenance";
/// Builds the IR program that iterates a lineage join until it stops changing.
///
/// One round is a `semiring_gemm` of `state` by `join_rules` into `next`, with
/// all three dimensions equal to `n` and the `Semiring::Lineage` semiring. That
/// transfer body is handed to `persistent_fixpoint` together with the `changed`
/// buffer name, the `n * n` word count and `max_iterations`. The resulting entry
/// nodes are wrapped in a single `Node::Region` whose generator is [`OP_ID`].
///
/// # Buffers
///
/// | binding | name         | access    | element | count   |
/// |---------|--------------|-----------|---------|---------|
/// | 0       | `state`      | ReadWrite | `U32`   | `n * n` |
/// | 1       | `next`       | ReadWrite | `U32`   | `n * n` |
/// | 2       | `changed`    | ReadWrite | `U32`   | `1`     |
/// | 3       | `join_rules` | ReadOnly  | `U32`   | `n * n` |
///
/// # Invalid arguments
///
/// This function never panics on bad arguments. It returns the program produced
/// by `crate::invalid_output_program` when:
///
/// * `n == 0`,
/// * `max_iterations == 0`, or
/// * `n * n` does not fit in a `u32` (previously the count silently saturated
///   to `u32::MAX`, declaring buffers smaller than the `n × n` domain).
#[must_use]
pub fn scallop_join(
    state: &str,
    next: &str,
    join_rules: &str,
    changed: &str,
    n: u32,
    max_iterations: u32,
) -> Program {
    if n == 0 {
        return crate::invalid_output_program(
            OP_ID,
            state,
            DataType::U32,
            format!("Fix: scallop_join requires n > 0, got {n}."),
        );
    }
    if max_iterations == 0 {
        return crate::invalid_output_program(
            OP_ID,
            state,
            DataType::U32,
            "Fix: scallop_join requires max_iterations > 0, got 0.".to_string(),
        );
    }

    // Every matrix buffer holds n * n words. Reject sizes whose word count
    // cannot be represented rather than clamping to a count that is too small.
    let words = match n.checked_mul(n) {
        Some(words) => words,
        None => {
            return crate::invalid_output_program(
                OP_ID,
                state,
                DataType::U32,
                format!("Fix: scallop_join requires n * n to fit in u32, got n = {n}."),
            );
        }
    };

    // Single-round transfer: next = state (x) join_rules under the Lineage semiring.
    let transfer = semiring_gemm(state, join_rules, next, n, n, n, Semiring::Lineage);
    let transfer_body: Vec<Node> = transfer.entry().to_vec();

    // Repeat the transfer under the fixpoint driver, bounded by max_iterations.
    let inner = crate::fixpoint::persistent_fixpoint::persistent_fixpoint(
        transfer_body,
        state,
        next,
        changed,
        words,
        max_iterations,
    );

    // Tag the whole body with this op's identifier.
    let entry: Vec<Node> = vec![Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(inner.entry().to_vec()),
    }];

    Program::wrapped(
        vec![
            BufferDecl::storage(state, 0, BufferAccess::ReadWrite, DataType::U32).with_count(words),
            BufferDecl::storage(next, 1, BufferAccess::ReadWrite, DataType::U32).with_count(words),
            BufferDecl::storage(changed, 2, BufferAccess::ReadWrite, DataType::U32).with_count(1),
            BufferDecl::storage(join_rules, 3, BufferAccess::ReadOnly, DataType::U32)
                .with_count(words),
        ],
        // NOTE(review): presumably the workgroup size — confirm against `Program::wrapped`.
        [256, 1, 1],
        entry,
    )
}
/// Allocating convenience wrapper around [`cpu_ref_into`].
///
/// Creates two fresh, empty work vectors, runs the reference fixpoint with
/// them, and returns `(final_state, rounds)`, where `rounds` is exactly the
/// value reported by [`cpu_ref_into`]. The scratch vector is dropped.
#[must_use]
pub fn cpu_ref(state: &[u32], join_rules: &[u32], n: u32, max_iterations: u32) -> (Vec<u32>, u32) {
    let (mut fixpoint, mut scratch) = (Vec::new(), Vec::new());
    let rounds = cpu_ref_into(state, join_rules, n, max_iterations, &mut fixpoint, &mut scratch);
    (fixpoint, rounds)
}
/// CPU reference for the lineage-join fixpoint, reusing caller-owned buffers.
///
/// `current` is reset to `n * n` zeroed cells and then overlaid with as much of
/// `state` as fits (extra `state` words are ignored; missing ones stay zero).
/// Each round calls `semiring_gemm_cpu_into` on `current` and `join_rules`
/// under `Semiring::Lineage`, with `next` as its last argument, and then ORs
/// `next` into `current` cell by cell.
///
/// # Returns
///
/// The zero-based index of the first round in which the OR-merge changed no
/// cell, or `max_iterations` if every executed round changed something. The
/// accumulated state is left in `current`.
pub fn cpu_ref_into(
    state: &[u32],
    join_rules: &[u32],
    n: u32,
    max_iterations: u32,
    current: &mut Vec<u32>,
    next: &mut Vec<u32>,
) -> u32 {
    let side = n as usize;
    let cells = side * side;

    // Start from an all-zero matrix, then copy in the seed prefix.
    current.clear();
    current.resize(cells, 0);
    let seeded = cells.min(state.len());
    current[..seeded].copy_from_slice(&state[..seeded]);

    let mut round = 0u32;
    while round < max_iterations {
        // The transfer must read `current` before this round's merge touches it.
        semiring_gemm_cpu_into(current, join_rules, n, n, n, Semiring::Lineage, next);

        let mut grew = false;
        for (slot, &derived) in current.iter_mut().zip(next.iter()) {
            // A cell grows exactly when `derived` carries a bit it lacks.
            grew |= (derived & !*slot) != 0;
            *slot |= derived;
        }

        if !grew {
            break;
        }
        round += 1;
    }
    round
}
// Registers this op with the crate harness; compiled only when the
// `inventory-registry` feature is enabled.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
crate::harness::OpEntry::new(
OP_ID,
// Program builder: a 2x2 system with an iteration cap of 4.
|| scallop_join("state", "next", "join_rules", "changed", 2, 4),
// NOTE(review): presumably the input fixture. The four buffers follow the
// binding order declared by `scallop_join` (state, next, changed,
// join_rules) — confirm against `OpEntry::new`.
Some(|| {
// Serialises u32 words as contiguous little-endian bytes.
let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
vec![vec![
to_bytes(&[0, 0b01, 0, 0]),
to_bytes(&[0, 0, 0, 0]),
to_bytes(&[0]),
to_bytes(&[0, 0, 0, 0b10]),
]]
}),
// NOTE(review): presumably the expected contents of the three writable
// buffers (state, next, changed) after the run — confirm against the
// harness contract.
Some(|| {
// Serialises u32 words as contiguous little-endian bytes.
let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
vec![vec![
to_bytes(&[0, 0b11, 0, 0]), to_bytes(&[0, 0b11, 0, 0]), to_bytes(&[0]), ]]
}),
)
}
/// Unit tests for the CPU reference and for the shape and argument validation
/// of the generated program.
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major offset of cell `(row, col)` in a `dim`-wide square matrix.
    ///
    /// Replaces literal `0 * 3 + 1`-style indices, which trip Clippy's
    /// `erasing_op` / `identity_op` lints.
    const fn cell(row: usize, col: usize, dim: usize) -> usize {
        row * dim + col
    }

    /// One seed fact and one rule on a 2x2 system: the seed bit must survive
    /// and the rule's bit must be merged into the same cell.
    #[test]
    fn cpu_ref_one_step_join() {
        let state: [u32; 4] = [0, 0b01, 0, 0];
        let join_rules: [u32; 4] = [0, 0, 0, 0b10];

        let (final_state, iters) = cpu_ref(&state, &join_rules, 2, 16);

        assert_eq!(
            final_state[1] & 0b10,
            0b10,
            "Lineage of clause 1 must propagate to state[0,1] after one round"
        );
        assert_eq!(
            final_state[1] & 0b01,
            0b01,
            "seed clause 0 must persist through the fixpoint"
        );
        assert!(
            iters <= 4,
            "small system should converge quickly, got {iters}"
        );
    }

    /// With an all-zero rule set the state must come back untouched.
    #[test]
    fn cpu_ref_converges_on_idempotent_input() {
        let state: [u32; 4] = [0b01, 0, 0, 0b01];
        let join_rules = [0u32; 4];

        let (final_state, iters) = cpu_ref(&state, &join_rules, 2, 16);

        assert_eq!(
            final_state, state,
            "idempotent system must not change state"
        );
        assert!(iters <= 2, "idempotent system converges in ≤ 2 iters");
    }

    /// Pre-sized work buffers must be reused, not reallocated.
    #[test]
    fn cpu_ref_into_reuses_buffers() {
        let state: [u32; 4] = [0, 0b01, 0, 0];
        let join_rules: [u32; 4] = [0, 0, 0, 0b10];
        let mut current = Vec::<u32>::with_capacity(128);
        let mut next = Vec::<u32>::with_capacity(128);
        let current_ptr = current.as_ptr();
        let next_ptr = next.as_ptr();

        let iters = cpu_ref_into(&state, &join_rules, 2, 16, &mut current, &mut next);

        assert!(iters <= 4);
        assert_eq!(current[1] & 0b11, 0b11);
        // Unchanged data pointers prove that neither vector reallocated.
        assert_eq!(current.as_ptr(), current_ptr);
        assert_eq!(next.as_ptr(), next_ptr);
    }

    /// Edges 0→1 and 1→2 on a 3-node chain must yield 0→2 carrying both bits.
    #[test]
    fn cpu_ref_transitive_closure() {
        const DIM: usize = 3;
        let mut state = vec![0u32; DIM * DIM];
        state[cell(0, 1, DIM)] = 0b001;
        state[cell(1, 2, DIM)] = 0b010;
        let join_rules = state.clone();

        let (final_state, iters) = cpu_ref(&state, &join_rules, 3, 16);

        let derived = final_state[cell(0, 2, DIM)];
        assert_eq!(
            derived & 0b011,
            0b011,
            "transitive 0→2 must collect lineage of both edges; got 0x{:x}",
            derived
        );
        assert!(iters <= 8, "3-node chain should converge fast");
    }

    /// An empty seed state must stay empty even when every rule cell is set.
    #[test]
    fn cpu_ref_zero_absorbing_no_phantom_lineage() {
        let state = [0u32; 4];
        let join_rules = [0b11u32; 4];

        let (final_state, _) = cpu_ref(&state, &join_rules, 2, 16);

        assert_eq!(
            final_state, state,
            "no seed facts → no derivations regardless of rule set; \
zero-absorbing combine prevents phantom lineage"
        );
    }

    /// The built program must declare exactly the four named buffers.
    #[test]
    fn program_declares_four_buffers() {
        let p = scallop_join("s", "n", "j", "c", 2, 4);
        let bufs = p.buffers();
        assert_eq!(bufs.len(), 4, "scallop_join must declare 4 buffers");

        let names: Vec<&str> = bufs.iter().map(|b| b.name()).collect();
        assert!(names.contains(&"s"));
        assert!(names.contains(&"n"));
        assert!(names.contains(&"j"));
        assert!(names.contains(&"c"));
    }

    /// `n == 0` must yield a trapping program rather than a panic.
    #[test]
    fn rejects_zero_n_with_trap() {
        let p = scallop_join("s", "n", "j", "c", 0, 4);
        assert!(p.stats().trap());
    }

    /// `max_iterations == 0` must yield a trapping program rather than a panic.
    #[test]
    fn rejects_zero_max_iterations_with_trap() {
        let p = scallop_join("s", "n", "j", "c", 2, 0);
        assert!(p.stats().trap());
    }
}