use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::graph::csr_forward_traverse::bitset_words;
/// Stable identifier for this op. It tags the IR region built by [`scc_decompose`]
/// (as the region `generator`) and is passed to `crate::harness::OpEntry::new` when the
/// `inventory-registry` feature is enabled.
pub const OP_ID: &str = "vyre-primitives::graph::scc_decompose";
/// Builds the IR program for one pivot-stamping pass of SCC decomposition.
///
/// The invocation id on axis 0 is used as the node index `v`; invocations with
/// `v >= node_count` do nothing. For every other node the kernel reads bit `v`
/// from `forward_bitset` and from `backward_bitset`. Both are packed 32 nodes per
/// `u32` word: the word index is `v >> 5` and the bit is `1 << (v & 31)`.
///
/// If the bit is set in both sets and `component_out[v]` still holds the
/// `u32::MAX` "unassigned" sentinel, `pivot` is stored there. Nodes that
/// already carry a label are left untouched. [`cpu_ref`] mirrors this on the host.
///
/// Buffer bindings:
/// - `0`: `forward_bitset`, read-only `u32`, `bitset_words(node_count)` elements
/// - `1`: `backward_bitset`, read-only `u32`, `bitset_words(node_count)` elements
/// - `2`: `component_out`, read-write `u32`, `node_count` elements
#[must_use]
pub fn scc_decompose(
    node_count: u32,
    forward_bitset: &str,
    backward_bitset: &str,
    component_out: &str,
    pivot: u32,
) -> Program {
    let node = Expr::InvocationId { axis: 0 };
    let word_count = bitset_words(node_count);

    // `(word & bit) != 0` for a word variable bound earlier in the body.
    let has_bit = |word: &'static str| {
        Expr::ne(
            Expr::bitand(Expr::var(word), Expr::var("bit")),
            Expr::u32(0),
        )
    };

    // Only claim nodes reachable both ways that no earlier pivot has labelled
    // (u32::MAX marks "unassigned"), so the first pivot to reach a node keeps it.
    let claim_if_unassigned = Node::if_then(
        Expr::and(Expr::var("fwd_set"), Expr::var("bwd_set")),
        vec![
            Node::let_bind("prior", Expr::load(component_out, node.clone())),
            Node::if_then(
                Expr::eq(Expr::var("prior"), Expr::u32(u32::MAX)),
                vec![Node::store(component_out, node.clone(), Expr::u32(pivot))],
            ),
        ],
    );

    let per_node = vec![
        Node::let_bind("word_idx", Expr::shr(node.clone(), Expr::u32(5))),
        Node::let_bind(
            "bit",
            Expr::shl(Expr::u32(1), Expr::bitand(node.clone(), Expr::u32(31))),
        ),
        Node::let_bind(
            "fwd_word",
            Expr::load(forward_bitset, Expr::var("word_idx")),
        ),
        Node::let_bind(
            "bwd_word",
            Expr::load(backward_bitset, Expr::var("word_idx")),
        ),
        Node::let_bind("fwd_set", has_bit("fwd_word")),
        Node::let_bind("bwd_set", has_bit("bwd_word")),
        claim_if_unassigned,
    ];

    // Guard every load/store so out-of-range invocations never touch the buffers.
    let in_range = Node::if_then(Expr::lt(node, Expr::u32(node_count)), per_node);

    let buffers = vec![
        BufferDecl::storage(forward_bitset, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(word_count),
        BufferDecl::storage(backward_bitset, 1, BufferAccess::ReadOnly, DataType::U32)
            .with_count(word_count),
        BufferDecl::storage(component_out, 2, BufferAccess::ReadWrite, DataType::U32)
            .with_count(node_count),
    ];

    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![in_range]),
    };

    Program::wrapped(buffers, [1, 1, 1], vec![region])
}
/// Host-side reference for [`scc_decompose`] that returns a freshly allocated label vector.
///
/// This is a convenience wrapper around [`cpu_ref_into`]; see that function for the
/// exact stamping rules. The result always has the same length as `component_in`.
#[must_use]
pub fn cpu_ref(
    node_count: u32,
    forward: &[u32],
    backward: &[u32],
    component_in: &[u32],
    pivot: u32,
) -> Vec<u32> {
    let mut labels = Vec::with_capacity(component_in.len());
    cpu_ref_into(node_count, forward, backward, component_in, pivot, &mut labels);
    labels
}
/// Host-side reference for [`scc_decompose`] that writes into a caller-owned buffer.
///
/// `out` is cleared and refilled with a copy of `component_in`. Then, for every
/// node `v` with `v < node_count` and `v < component_in.len()`:
/// - if bit `v` is set in both `forward` and `backward`, and
/// - `out[v]` still holds the `u32::MAX` "unassigned" sentinel,
///
/// `out[v]` is set to `pivot`. Nodes that already carry a label keep it.
///
/// Bitsets are packed 32 nodes per `u32` word. Words missing from `forward` or
/// `backward` are treated as all-zero, so short inputs never panic.
pub fn cpu_ref_into(
    node_count: u32,
    forward: &[u32],
    backward: &[u32],
    component_in: &[u32],
    pivot: u32,
    out: &mut Vec<u32>,
) {
    out.clear();
    out.extend_from_slice(component_in);
    // `take` bounds the walk by both `node_count` and the label buffer length. A
    // huge `node_count` with a short `component_in` no longer spins through
    // iterations that can never stamp anything.
    let limit = usize::try_from(node_count).unwrap_or(usize::MAX);
    for (v, label) in out.iter_mut().enumerate().take(limit) {
        let word = v / 32;
        let bit = 1u32 << (v % 32);
        let fwd = (forward.get(word).copied().unwrap_or(0) & bit) != 0;
        let bwd = (backward.get(word).copied().unwrap_or(0) & bit) != 0;
        if fwd && bwd && *label == u32::MAX {
            *label = pivot;
        }
    }
}
#[cfg(test)]
mod regression_tests {
    use super::*;

    /// Sentinel label for "not yet assigned to any component".
    const UNASSIGNED: u32 = u32::MAX;

    #[test]
    fn cpu_ref_first_pivot_wins_when_two_pivots_share_a_node() {
        // Nodes 0..4 are all in both the forward and backward sets.
        let all_four: [u32; 1] = [0b1111];

        let first = cpu_ref(4, &all_four, &all_four, &[UNASSIGNED; 4], 5);
        assert_eq!(first, vec![5, 5, 5, 5]);

        let second = cpu_ref(4, &all_four, &all_four, &first, 9);
        assert_eq!(
            second,
            vec![5, 5, 5, 5],
            "second pivot must NOT overwrite first pivot's assignments"
        );
    }

    #[test]
    fn cpu_ref_unassigned_node_picks_up_second_pivot() {
        let first = cpu_ref(4, &[0b0001], &[0b0001], &[UNASSIGNED; 4], 5);
        assert_eq!(first[0], 5);
        assert_eq!(first[2], UNASSIGNED);

        let second = cpu_ref(4, &[0b0100], &[0b0100], &first, 9);
        assert_eq!(second[0], 5, "first pivot survives");
        assert_eq!(second[2], 9, "second pivot stamps unassigned node");
    }
}
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        || scc_decompose(4, "fwd", "bwd", "comp", 0),
        // Inputs, in binding order: forward = {0, 1, 2}, backward = {0, 2, 3},
        // and every component label unassigned (u32::MAX).
        Some(|| {
            let le_bytes = |words: &[u32]| -> Vec<u8> {
                words.iter().copied().flat_map(u32::to_le_bytes).collect()
            };
            vec![vec![
                le_bytes(&[0b0111]),
                le_bytes(&[0b1101]),
                le_bytes(&[u32::MAX; 4]),
            ]]
        }),
        // Expected labels: only nodes 0 and 2 are in both sets, so only they take pivot 0.
        Some(|| {
            let le_bytes = |words: &[u32]| -> Vec<u8> {
                words.iter().copied().flat_map(u32::to_le_bytes).collect()
            };
            vec![vec![le_bytes(&[0, u32::MAX, 0, u32::MAX])]]
        }),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn intersection_stamps_pivot() {
        let labels = cpu_ref(4, &[0b0011], &[0b0011], &[u32::MAX; 4], 0);
        let (stamped, untouched) = labels.split_at(2);
        assert_eq!(stamped, &[0, 0]);
        assert_eq!(untouched, &[u32::MAX, u32::MAX]);
    }

    #[test]
    fn disjoint_forward_backward_yields_no_change() {
        // Forward holds only node 0, backward only node 3: the intersection is empty.
        let comp_in = vec![u32::MAX; 4];
        assert_eq!(cpu_ref(4, &[0b0001], &[0b1000], &comp_in, 0), comp_in);
    }
}