use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier string for the bitset fixpoint op; [`bitset_fixpoint`] uses it as the
/// `generator` of the region it emits.
pub const OP_ID: &str = "vyre-primitives::fixpoint::bitset_fixpoint";
/// Buffer name `"fp_changed"`; this file's registry entry passes it as the `changed`
/// buffer name when building the op.
pub const NAME_CHANGED_FLAG: &str = "fp_changed";
/// Builds the [`Program`] that compares two bitset buffers word by word.
///
/// For every invocation index on axis 0 that is below `words`, the emitted body loads
/// the word at that index from `current` and from `next`; when the two differ it issues
/// `atomic_or(changed, 0, 1)` (presumably: OR the value 1 into element 0 of `changed` —
/// confirm against `Expr::atomic_or`).
///
/// # Parameters
/// - `current`: name of the read-only `U32` buffer bound at slot 0, `words` elements.
/// - `next`: name of the read-only `U32` buffer bound at slot 1, `words` elements.
/// - `changed`: name of the read-write `U32` buffer bound at slot 2, one element.
/// - `words`: element count of both bitsets and the upper bound of the index guard.
///
/// # Returns
/// A program wrapping a single region tagged with [`OP_ID`].
#[must_use]
pub fn bitset_fixpoint(current: &str, next: &str, changed: &str, words: u32) -> Program {
    let lane = Expr::InvocationId { axis: 0 };

    // Per-index work: read both words, then flag a difference.
    let load_current = Node::let_bind("c", Expr::load(current, lane.clone()));
    let load_next = Node::let_bind("n", Expr::load(next, lane.clone()));
    let words_differ = Expr::ne(Expr::var("c"), Expr::var("n"));
    // The result of the atomic is bound to "_" because only its side effect matters here.
    let raise_flag = Node::let_bind("_", Expr::atomic_or(changed, Expr::u32(0), Expr::u32(1)));
    let compare_step = vec![
        load_current,
        load_next,
        Node::if_then(words_differ, vec![raise_flag]),
    ];

    let buffers = vec![
        BufferDecl::storage(current, 0, BufferAccess::ReadOnly, DataType::U32).with_count(words),
        BufferDecl::storage(next, 1, BufferAccess::ReadOnly, DataType::U32).with_count(words),
        BufferDecl::storage(changed, 2, BufferAccess::ReadWrite, DataType::U32).with_count(1),
    ];

    // Indices at or beyond `words` skip the comparison entirely.
    let in_range = Expr::lt(lane, Expr::u32(words));
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![Node::if_then(in_range, compare_step)]),
    };

    // NOTE(review): [256, 1, 1] is presumably the workgroup size — confirm against
    // `Program::wrapped`.
    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// CPU reference for the change flag: compares the two bitsets as whole slices.
///
/// # Returns
/// `0` when `current` and `next` are equal (same length and same words), `1` otherwise.
#[must_use]
pub fn reference_eval(current: &[u32], next: &[u32]) -> u32 {
    u32::from(current != next)
}
/// Buffer name `"fp_warm_seed"`.
///
/// NOTE(review): not referenced in the visible part of this file; presumably the
/// conventional name for the `seed` buffer of [`bitset_fixpoint_warm_start`] — confirm
/// against callers.
pub const NAME_WARM_SEED: &str = "fp_warm_seed";
/// Builds the warm-start variant of the bitset fixpoint [`Program`].
///
/// For every invocation index on axis 0 that is below `words`, the emitted body:
/// 1. loads the word at that index from `seed` and from `current`,
/// 2. stores `current | seed` back into `current` at the same index,
/// 3. loads the word from `next`, and
/// 4. if the *pre-seed* value of `current` differs from `next`, issues
///    `atomic_or(changed, 0, 1)` (presumably: OR the value 1 into element 0 of
///    `changed` — confirm against `Expr::atomic_or`).
///
/// # Parameters
/// - `current`: name of the read-write `U32` buffer bound at slot 0, `words` elements.
/// - `next`: name of the read-only `U32` buffer bound at slot 1, `words` elements.
/// - `changed`: name of the read-write `U32` buffer bound at slot 2, one element.
/// - `seed`: name of the read-only `U32` buffer bound at slot 3, `words` elements.
/// - `words`: element count of the bitsets and the upper bound of the index guard.
///
/// # Returns
/// A program wrapping a single region tagged with [`OP_ID_WARM_START`].
#[must_use]
pub fn bitset_fixpoint_warm_start(
    current: &str,
    next: &str,
    changed: &str,
    seed: &str,
    words: u32,
) -> Program {
    let lane = Expr::InvocationId { axis: 0 };

    // Fold the seed into `current` first...
    let load_seed = Node::let_bind("s", Expr::load(seed, lane.clone()));
    let load_current = Node::let_bind("c0", Expr::load(current, lane.clone()));
    let merge_seed = Node::let_bind("c1", Expr::bitor(Expr::var("c0"), Expr::var("s")));
    let write_back = Node::store(current, lane.clone(), Expr::var("c1"));
    // ...then compare the original (unseeded) word against `next`.
    let load_next = Node::let_bind("n", Expr::load(next, lane.clone()));
    let words_differ = Expr::ne(Expr::var("c0"), Expr::var("n"));
    // The result of the atomic is bound to "_" because only its side effect matters here.
    let raise_flag = Node::let_bind("_", Expr::atomic_or(changed, Expr::u32(0), Expr::u32(1)));
    let seeded_step = vec![
        load_seed,
        load_current,
        merge_seed,
        write_back,
        load_next,
        Node::if_then(words_differ, vec![raise_flag]),
    ];

    let buffers = vec![
        BufferDecl::storage(current, 0, BufferAccess::ReadWrite, DataType::U32).with_count(words),
        BufferDecl::storage(next, 1, BufferAccess::ReadOnly, DataType::U32).with_count(words),
        BufferDecl::storage(changed, 2, BufferAccess::ReadWrite, DataType::U32).with_count(1),
        BufferDecl::storage(seed, 3, BufferAccess::ReadOnly, DataType::U32).with_count(words),
    ];

    // Indices at or beyond `words` skip the whole step.
    let in_range = Expr::lt(lane, Expr::u32(words));
    let region = Node::Region {
        generator: Ident::from(OP_ID_WARM_START),
        source_region: None,
        body: Arc::new(vec![Node::if_then(in_range, seeded_step)]),
    };

    // NOTE(review): [256, 1, 1] is presumably the workgroup size — confirm against
    // `Program::wrapped`.
    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// Identifier string for the warm-start op; [`bitset_fixpoint_warm_start`] uses it as
/// the `generator` of the region it emits.
pub const OP_ID_WARM_START: &str = "vyre-primitives::fixpoint::bitset_fixpoint_warm_start";
/// CPU reference for the warm-start step.
///
/// # Parameters
/// - `current`: bitset words before seeding.
/// - `next`: bitset words to compare against.
/// - `seed`: words OR-ed into `current`, position by position.
///
/// # Returns
/// A pair `(updated, flag)` where `updated[i] == current[i] | seed[i]` and `flag` is `0`
/// when the *unseeded* `current` equals `next`, `1` otherwise.
///
/// # Panics
/// In builds with debug assertions enabled, panics if `seed` or `next` differ in length
/// from `current`. Without them, `updated` is as long as the shorter of `current` and
/// `seed`.
#[must_use]
pub fn reference_eval_warm_start(current: &[u32], next: &[u32], seed: &[u32]) -> (Vec<u32>, u32) {
    debug_assert_eq!(current.len(), seed.len());
    debug_assert_eq!(current.len(), next.len());

    let mut updated = Vec::with_capacity(current.len().min(seed.len()));
    for (&word, &hint) in current.iter().zip(seed) {
        updated.push(word | hint);
    }

    // The flag deliberately looks at `current` as passed in, not at `updated`.
    let flag = u32::from(current != next);
    (updated, flag)
}
// Submits an `OpEntry` for `OP_ID` to the `inventory` registry; compiled only when the
// `inventory-registry` feature is enabled.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
crate::harness::OpEntry::new(
OP_ID,
// Program builder: a one-word instance using the buffer names "current", "next" and
// NAME_CHANGED_FLAG.
|| bitset_fixpoint("current", "next", NAME_CHANGED_FLAG, 1),
// NOTE(review): presumably the input fixture, one byte vector per declared buffer in
// binding order (current = 0b0001, next = 0b0011, changed = 0) — confirm against
// `OpEntry::new`.
Some(|| {
// Serializes each u32 word as little-endian bytes.
let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
vec![vec![to_bytes(&[0b0001]), to_bytes(&[0b0011]), to_bytes(&[0])]]
}),
// NOTE(review): presumably the expected output, a single word with value 1 (the two
// input words above differ) — confirm against `OpEntry::new`.
Some(|| {
// Serializes each u32 word as little-endian bytes.
let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
vec![vec![to_bytes(&[1])]]
}),
)
}
/// Unit tests for the CPU reference evaluators.
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical bitsets produce a cleared flag.
    #[test]
    fn flag_clears_when_bitsets_equal() {
        assert_eq!(reference_eval(&[0b0011], &[0b0011]), 0);
    }

    /// Two-iteration walk-through: the first comparison sees new bits, the second sees
    /// none.
    #[test]
    fn cold_transfer_step_signals_change_then_converges() {
        let current = vec![0b0001];
        let next_after_transfer = vec![0b0011];
        assert_eq!(
            reference_eval(&current, &next_after_transfer),
            1,
            "transfer added bits → flag must set"
        );
        let current2 = next_after_transfer.clone();
        let next2 = next_after_transfer;
        assert_eq!(
            reference_eval(&current2, &next2),
            0,
            "no bits added on iteration 2 → converged"
        );
    }

    /// The seed is OR-ed into `current`, but the flag is still computed from the
    /// unseeded `current`, so it stays set here.
    // NOTE(review): the test name says "short-circuits" yet the asserted flag is 1 —
    // confirm the intended naming.
    #[test]
    fn warm_start_short_circuits_when_seed_anticipates_transfer() {
        let (updated, flag) = reference_eval_warm_start(&[0b0001], &[0b0011], &[0b0010]);
        assert_eq!(updated, vec![0b0011]);
        assert_eq!(flag, 1);
    }

    /// Differing bitsets produce a set flag.
    #[test]
    fn flag_sets_when_bitsets_diverge() {
        assert_eq!(reference_eval(&[0b0001], &[0b0011]), 1);
    }

    /// Seeding rewrites `current` while the flag reflects the pre-seed comparison.
    #[test]
    fn warm_start_ors_seed_into_current() {
        let (updated, flag) = reference_eval_warm_start(&[0b0001], &[0b0011], &[0b0010]);
        assert_eq!(updated, vec![0b0011], "seed OR still rewrites current");
        assert_eq!(
            flag, 1,
            "c0 (0b0001) != next (0b0011) → transfer added bits → flag set",
        );
    }

    /// A zero seed leaves `current` untouched; differing `next` still sets the flag.
    #[test]
    fn warm_start_flags_when_transfer_added_bits() {
        let (updated, flag) = reference_eval_warm_start(&[0b0001], &[0b0011], &[0b0000]);
        assert_eq!(updated, vec![0b0001]);
        assert_eq!(flag, 1);
    }

    /// With a zero seed the warm-start flag agrees with the cold reference.
    #[test]
    fn warm_start_with_zero_seed_matches_cold_semantics() {
        let (updated, flag) = reference_eval_warm_start(&[0b0001], &[0b0001], &[0b0000]);
        assert_eq!(updated, vec![0b0001]);
        assert_eq!(flag, reference_eval(&[0b0001], &[0b0001]));
    }
}