use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier of this primitive; used as the `Region` generator tag and passed to
/// `invalid_output_program` when the builder rejects its arguments.
pub const OP_ID: &str = "vyre-primitives::graph::matroid_exchange_bfs_step";
/// Builds the IR program for one breadth-first expansion step over an `n * n`
/// exchange adjacency buffer.
///
/// Each invocation `t` with `t < n` computes `reached`:
/// it starts at `0` and, when `visited[t] == 0`, becomes `1` if some `k` has both
/// `frontier_in[k] != 0` and `exchange_adj[k * n + t] != 0`. The result is stored
/// in `frontier_out[t]`. Invocation `0` then scans `frontier_out` and writes `1`
/// to `any_change[0]` if any slot is non-zero, `0` otherwise.
///
/// # Parameters
/// * `frontier_in` – name of the read-only buffer (binding 0, `n` elements).
/// * `exchange_adj` – name of the read-only buffer (binding 1, `n * n` elements),
///   indexed as `k * n + t`.
/// * `visited` – name of the read-only buffer (binding 2, `n` elements).
/// * `frontier_out` – name of the read-write buffer (binding 3, `n` elements).
/// * `any_change` – name of the read-write buffer (binding 4, one element).
/// * `n` – number of nodes; must be non-zero and small enough that `n * n` fits
///   in a `u32`.
///
/// # Returns
/// The wrapped [`Program`] with workgroup size `[256, 1, 1]`, or the program
/// produced by `crate::invalid_output_program` when `n` is rejected.
#[must_use]
pub fn matroid_exchange_bfs_step(
    frontier_in: &str,
    exchange_adj: &str,
    visited: &str,
    frontier_out: &str,
    any_change: &str,
    n: u32,
) -> Program {
    if n == 0 {
        return crate::invalid_output_program(
            OP_ID,
            frontier_out,
            DataType::U32,
            format!("Fix: matroid_exchange_bfs_step requires n > 0, got {n}."),
        );
    }
    // The adjacency buffer holds n * n entries and the kernel indexes it with
    // `k * n + t` (at most n * n - 1), so rejecting an overflowing product here
    // keeps both the declared count and the index arithmetic within u32.
    let adj_len = match n.checked_mul(n) {
        Some(len) => len,
        None => {
            return crate::invalid_output_program(
                OP_ID,
                frontier_out,
                DataType::U32,
                format!(
                    "Fix: matroid_exchange_bfs_step requires n * n to fit in u32, got n = {n}."
                ),
            );
        }
    };
    let t = Expr::InvocationId { axis: 0 };
    let body = vec![Node::if_then(
        // Invocations at or beyond n do nothing.
        Expr::lt(t.clone(), Expr::u32(n)),
        vec![
            Node::let_bind("reached", Expr::u32(0)),
            // Only unvisited targets are candidates for the next frontier.
            Node::if_then(
                Expr::eq(Expr::load(visited, t.clone()), Expr::u32(0)),
                vec![Node::loop_for(
                    "k",
                    Expr::u32(0),
                    Expr::u32(n),
                    vec![Node::if_then(
                        Expr::and(
                            Expr::ne(Expr::load(frontier_in, Expr::var("k")), Expr::u32(0)),
                            Expr::ne(
                                Expr::load(
                                    exchange_adj,
                                    Expr::add(Expr::mul(Expr::var("k"), Expr::u32(n)), t.clone()),
                                ),
                                Expr::u32(0),
                            ),
                        ),
                        vec![Node::assign("reached", Expr::u32(1))],
                    )],
                )],
            ),
            Node::store(frontier_out, t.clone(), Expr::var("reached")),
            // NOTE(review): invocation 0 reads every `frontier_out` slot right after
            // its own store and no barrier is emitted in this body. If invocations
            // run concurrently, this scan may miss stores from other invocations —
            // confirm against the backend's execution model.
            Node::if_then(
                Expr::eq(t.clone(), Expr::u32(0)),
                vec![
                    Node::let_bind("changed", Expr::u32(0)),
                    Node::loop_for(
                        "j",
                        Expr::u32(0),
                        Expr::u32(n),
                        vec![Node::if_then(
                            Expr::ne(Expr::load(frontier_out, Expr::var("j")), Expr::u32(0)),
                            vec![Node::assign("changed", Expr::u32(1))],
                        )],
                    ),
                    Node::store(any_change, Expr::u32(0), Expr::var("changed")),
                ],
            ),
        ],
    )];
    Program::wrapped(
        vec![
            BufferDecl::storage(frontier_in, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(n),
            BufferDecl::storage(exchange_adj, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(adj_len),
            BufferDecl::storage(visited, 2, BufferAccess::ReadOnly, DataType::U32).with_count(n),
            BufferDecl::storage(frontier_out, 3, BufferAccess::ReadWrite, DataType::U32)
                .with_count(n),
            BufferDecl::storage(any_change, 4, BufferAccess::ReadWrite, DataType::U32)
                .with_count(1),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference for one exchange-graph BFS step.
///
/// For every target `j` in `0..n` whose `visited[j]` is zero, the output slot is
/// set to `1` when some source `k` in `0..n` has both `frontier_in[k] != 0` and
/// `exchange_adj[k * n + j] != 0`. Entries missing from any input slice are read
/// as zero.
///
/// Returns the next frontier (length `n`) together with a flag that is `true`
/// when at least one slot of that frontier was set.
#[must_use]
pub fn matroid_exchange_bfs_step_cpu(
    frontier_in: &[u32],
    exchange_adj: &[u32],
    visited: &[u32],
    n: usize,
) -> (Vec<u32>, bool) {
    // Out-of-range reads count as "not set" rather than panicking.
    let is_set = |buf: &[u32], idx: usize| buf.get(idx).copied().unwrap_or(0) != 0;

    let next: Vec<u32> = (0..n)
        .map(|target| {
            if is_set(visited, target) {
                return 0;
            }
            let hit = (0..n).any(|src| {
                let active = is_set(frontier_in, src);
                let linked = is_set(exchange_adj, src * n + target);
                active && linked
            });
            u32::from(hit)
        })
        .collect();

    let advanced = next.iter().any(|&slot| slot != 0);
    (next, advanced)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One frontier node with a single outgoing edge reaches exactly that target.
    #[test]
    fn cpu_one_step_advances() {
        let frontier = [1u32, 0, 0];
        // 3x3 matrix indexed as `source * 3 + target`; only edge 0 -> 1 is present.
        let mut edges = [0u32; 9];
        edges[1] = 1;
        let seen = [0u32; 3];

        let (next, changed) = matroid_exchange_bfs_step_cpu(&frontier, &edges, &seen, 3);

        assert_eq!(next, vec![0u32, 1, 0]);
        assert!(changed);
    }

    /// A target that is already marked visited is not put on the next frontier.
    #[test]
    fn cpu_visited_blocks_re_advance() {
        let frontier = [1u32, 0, 0];
        let mut edges = [0u32; 9];
        edges[1] = 1;
        let seen = [0u32, 1, 0];

        let (next, changed) = matroid_exchange_bfs_step_cpu(&frontier, &edges, &seen, 3);

        assert_eq!(next, vec![0u32, 0, 0]);
        assert!(!changed);
    }

    /// With no active frontier node, even a fully connected matrix yields nothing.
    #[test]
    fn cpu_empty_frontier_no_change() {
        let frontier = [0u32; 3];
        let edges = [1u32; 9];
        let seen = [0u32; 3];

        let (next, changed) = matroid_exchange_bfs_step_cpu(&frontier, &edges, &seen, 3);

        assert_eq!(next, vec![0u32; 3]);
        assert!(!changed);
    }

    /// Slices shorter than `n` (or `n * n`) are read as if padded with zeros.
    #[test]
    fn cpu_malformed_inputs_treat_missing_entries_as_zero() {
        let (next, changed) = matroid_exchange_bfs_step_cpu(&[1], &[], &[], 2);

        assert_eq!(next, vec![0u32, 0]);
        assert!(!changed);
    }

    /// Two frontier nodes with distinct targets both contribute to the output.
    #[test]
    fn cpu_multiple_sources_advance_all_targets() {
        let frontier = [1u32, 1, 0, 0];
        // 4x4 matrix with edges 0 -> 2 and 1 -> 3.
        let mut edges = [0u32; 16];
        edges[2] = 1;
        edges[7] = 1;
        let seen = [0u32; 4];

        let (next, _) = matroid_exchange_bfs_step_cpu(&frontier, &edges, &seen, 4);

        assert_eq!(next, vec![0u32, 0, 1, 1]);
    }

    /// The IR program declares the five buffers in binding order with the expected sizes.
    #[test]
    fn ir_program_buffer_layout() {
        let program = matroid_exchange_bfs_step("fi", "adj", "v", "fo", "ch", 4);

        assert_eq!(program.workgroup_size, [256, 1, 1]);

        let names: Vec<&str> = program.buffers.iter().map(|buf| buf.name()).collect();
        assert_eq!(names, vec!["fi", "adj", "v", "fo", "ch"]);

        // The name check above already pins the buffer list to five entries.
        for (buf, expected) in program.buffers.iter().zip([4, 16, 4, 4, 1]) {
            assert_eq!(buf.count(), expected);
        }
    }

    /// `n == 0` is rejected by returning a trapping program.
    #[test]
    fn zero_n_traps() {
        let program = matroid_exchange_bfs_step("fi", "adj", "v", "fo", "ch", 0);
        assert!(program.stats().trap());
    }
}