use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for the persistent-fixpoint primitive.
///
/// Used as the `generator` of the IR region emitted by `persistent_fixpoint` and,
/// when the `inventory-registry` feature is enabled, as the key of this op's
/// harness registration.
pub const OP_ID: &str = "vyre-primitives::fixpoint::persistent_fixpoint";
/// Builds a program that repeats `transfer_body` until the `current` buffer stops
/// changing, or until `max_iterations` rounds have run.
///
/// Each round:
/// 1. invocation 0 clears `changed[0]`, followed by a barrier;
/// 2. `transfer_body` runs (presumably computing the next state into `next`);
/// 3. every invocation `t < words` copies `next[t]` into `current[t]`, atomically
///    OR-ing `1` into `changed[0]` when the word differed;
/// 4. after a barrier, every invocation snapshots `changed[0]`; a second barrier
///    guarantees all snapshots are taken before any invocation can reach the
///    next round's reset, so all invocations see the same value and either
///    all return or all continue.
///
/// Declared buffers (all `u32`, read-write): `current` (slot 0, `words`
/// elements), `next` (slot 1, `words` elements), `changed` (slot 2, one element).
/// The workgroup size is `[256, 1, 1]`.
///
/// NOTE(review): the rounds are synchronized only with `Node::Barrier`; if that is
/// a workgroup-scoped barrier, `words` must not exceed the 256-wide workgroup for
/// the fixpoint to be coherent — confirm against the backend.
#[must_use]
pub fn persistent_fixpoint(
    transfer_body: Vec<Node>,
    current: &str,
    next: &str,
    changed: &str,
    words: u32,
    max_iterations: u32,
) -> Program {
    let t = Expr::InvocationId { axis: 0 };
    let barrier = || Node::Barrier {
        ordering: vyre_foundation::MemoryOrdering::SeqCst,
    };

    let mut body: Vec<Node> = Vec::new();

    // Step 1: reset the convergence flag and publish the reset before anyone can
    // set it again during the commit.
    body.push(Node::if_then(
        Expr::eq(t.clone(), Expr::u32(0)),
        vec![Node::store(changed, Expr::u32(0), Expr::u32(0))],
    ));
    body.push(barrier());

    // Step 2: caller-provided transfer function.
    body.extend(transfer_body);

    // Step 3: commit `next` into `current`, flagging any word that changed. The
    // locals are prefixed so they cannot collide with names bound by
    // `transfer_body`.
    body.push(Node::if_then(
        Expr::lt(t.clone(), Expr::u32(words)),
        vec![
            Node::let_bind("__pf_cur__", Expr::load(current, t.clone())),
            Node::let_bind("__pf_next__", Expr::load(next, t.clone())),
            Node::if_then(
                Expr::ne(Expr::var("__pf_cur__"), Expr::var("__pf_next__")),
                vec![Node::let_bind(
                    "_pf_set",
                    Expr::atomic_or(changed, Expr::u32(0), Expr::u32(1)),
                )],
            ),
            Node::store(current, t.clone(), Expr::var("__pf_next__")),
        ],
    ));

    // Step 4: make all commits and flag updates visible, snapshot the flag, then
    // barrier again so no invocation can race ahead into the next round's reset
    // while another is still reading the flag. Branching on the snapshot keeps
    // the early return uniform across invocations.
    body.push(barrier());
    body.push(Node::let_bind(
        "__pf_changed__",
        Expr::load(changed, Expr::u32(0)),
    ));
    body.push(barrier());
    body.push(Node::if_then(
        Expr::eq(Expr::var("__pf_changed__"), Expr::u32(0)),
        vec![Node::Return],
    ));

    let outer = vec![Node::loop_for(
        "__pf_iter__",
        Expr::u32(0),
        Expr::u32(max_iterations),
        body,
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(current, 0, BufferAccess::ReadWrite, DataType::U32)
                .with_count(words),
            BufferDecl::storage(next, 1, BufferAccess::ReadWrite, DataType::U32).with_count(words),
            BufferDecl::storage(changed, 2, BufferAccess::ReadWrite, DataType::U32).with_count(1),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(outer),
        }],
    )
}
/// Runs the host-side fixpoint iteration starting from `seed`, using freshly
/// allocated working buffers.
///
/// Thin convenience wrapper over `cpu_ref_into`: returns the final state
/// together with the round count that `cpu_ref_into` reports.
#[must_use]
pub fn cpu_ref<F>(seed: &[u32], max_iterations: u32, mut transfer_step: F) -> (Vec<u32>, u32)
where
    F: FnMut(&[u32], &mut [u32]),
{
    let (mut state, mut scratch) = (Vec::new(), Vec::new());
    let rounds = cpu_ref_into(seed, max_iterations, &mut transfer_step, &mut state, &mut scratch);
    (state, rounds)
}
/// Host-side fixpoint iteration over a `u32` word vector, writing into
/// caller-owned ping-pong buffers so their allocations can be reused.
///
/// `current` is reset to `seed` and `next` is resized to `seed.len()`. Each round
/// zeroes `next`, calls `transfer_step(current, next)`, and stops as soon as
/// `next` equals `current`; otherwise the two buffers are swapped so the freshly
/// computed state becomes `current`. Note that `next` is zeroed before every
/// step, so the step sees no values left over from earlier rounds.
///
/// Returns the number of rounds that changed the state before convergence was
/// observed, or `max_iterations` if no fixed point was reached within that
/// budget. On return `current` holds the final state.
pub fn cpu_ref_into<F>(
    seed: &[u32],
    max_iterations: u32,
    transfer_step: &mut F,
    current: &mut Vec<u32>,
    next: &mut Vec<u32>,
) -> u32
where
    F: FnMut(&[u32], &mut [u32]),
{
    current.clear();
    current.extend_from_slice(seed);
    // The scratch buffer must match the state length so the step can write every word.
    next.resize(seed.len(), 0);

    let mut round = 0;
    while round < max_iterations {
        next.iter_mut().for_each(|word| *word = 0);
        transfer_step(current.as_slice(), next.as_mut_slice());
        if current[..] == next[..] {
            return round;
        }
        std::mem::swap(current, next);
        round += 1;
    }
    max_iterations
}
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        || {
            // Minimal transfer body: every invocation writes `next[0] = 0`.
            let body = vec![Node::store("next", Expr::u32(0), Expr::u32(0))];
            persistent_fixpoint(body, "current", "next", "changed", 2, 8)
        },
        Some(|| {
            let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
            // Initial contents of `current`, `next`, `changed` (declaration order).
            vec![vec![
                to_bytes(&[0b0001, 0b0000]),
                to_bytes(&[0b0000, 0b0000]),
                to_bytes(&[0]),
            ]]
        }),
        Some(|| {
            let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
            // Expected final contents. The transfer body only ever stores 0 into
            // `next`, so `next` stays [0, 0]. Round 0 commits it over
            // current = [1, 0] (setting the flag). Round 1 observes no change,
            // leaves `changed` at 0 and returns.
            vec![vec![
                to_bytes(&[0b0000, 0b0000]),
                to_bytes(&[0b0000, 0b0000]),
                to_bytes(&[0]),
            ]]
        }),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_ref_converges_when_step_is_idempotent() {
        let seed = vec![0b1010, 0b0101];
        let (out, iters) = cpu_ref(&seed, 100, |cur, next| next.copy_from_slice(cur));
        assert_eq!(out, seed);
        assert_eq!(iters, 0);
    }

    #[test]
    fn cpu_ref_converges_on_or_to_fixed_point() {
        let seed = vec![0u32];
        let (out, iters) = cpu_ref(&seed, 100, |cur, next| {
            next[0] = cur[0] | 0b1010;
        });
        assert_eq!(out, vec![0b1010]);
        // One state-changing round; the confirming round that observes no change
        // is not counted by `cpu_ref_into`.
        assert_eq!(iters, 1, "OR-with-const converges after exactly one changing round");
    }

    #[test]
    fn cpu_ref_caps_at_max_iterations() {
        let seed = vec![0u32];
        let max = 16;
        let (out, iters) = cpu_ref(&seed, max, |cur, next| {
            next[0] = cur[0].wrapping_add(1);
        });
        assert_eq!(iters, max);
        // Every round committed one increment.
        assert_eq!(out, vec![max]);
    }

    #[test]
    fn cpu_ref_with_zero_iterations_returns_seed() {
        let seed = vec![3u32, 5];
        let (out, iters) = cpu_ref(&seed, 0, |_, next| next.fill(9));
        assert_eq!(out, seed);
        assert_eq!(iters, 0);
    }

    #[test]
    fn cpu_ref_handles_empty_seed() {
        let (out, iters) = cpu_ref(&[], 10, |_, _| {});
        assert!(out.is_empty());
        assert_eq!(iters, 0);
    }

    #[test]
    fn cpu_ref_into_reuses_ping_pong_buffers() {
        let seed = vec![0u32];
        let mut current = Vec::with_capacity(16);
        let mut next = Vec::with_capacity(16);
        let current_ptr = current.as_ptr();
        let next_ptr = next.as_ptr();
        let mut transfer = |cur: &[u32], out: &mut [u32]| {
            out[0] = cur[0] | 0b1010;
        };
        let iters = cpu_ref_into(&seed, 16, &mut transfer, &mut current, &mut next);
        assert_eq!(iters, 1);
        assert_eq!(current, vec![0b1010]);
        assert!(current.as_ptr() == current_ptr || current.as_ptr() == next_ptr);
        assert!(next.as_ptr() == current_ptr || next.as_ptr() == next_ptr);
        assert_ne!(current.as_ptr(), next.as_ptr());
    }

    #[test]
    fn program_shape_matches_contract() {
        let body = vec![Node::store("next", Expr::u32(0), Expr::u32(0))];
        let program = persistent_fixpoint(body, "current", "next", "changed", 16, 64);
        assert_eq!(
            program.buffers.iter().count(),
            3,
            "exactly current/next/changed must be declared"
        );
        assert!(
            program.buffers.iter().any(|b| b.name() == "current"),
            "current buffer must be declared"
        );
        assert!(
            program.buffers.iter().any(|b| b.name() == "next"),
            "next buffer must be declared"
        );
        assert!(
            program.buffers.iter().any(|b| b.name() == "changed"),
            "changed buffer must be declared"
        );
    }
}