use vyre_primitives::graph::persistent_bfs::PERSISTENT_BFS_WORKGROUP_SIZE;
use vyre_self_substrate::optimizer::dispatcher::{DispatchError, OptimizerDispatcher};
use vyre_self_substrate::persistent_bfs::{
bfs_expand_resident_graph_batch_with_scratch_into, upload_resident_bfs_graph,
PersistentBfsResidentScratch,
};
use super::{BorrowFacts, Conflict, ConflictKind, LoanKind};
/// All-ones `u32` passed as the mask argument to both batched BFS dispatches in `batch_shard`.
/// NOTE(review): presumably combined with the per-edge masks (which `build_csr` sets to 1)
/// so that every edge is traversable — confirm against the persistent-BFS API.
const ALLOW_ALL: u32 = 0xFFFF_FFFF;
/// Returns how many 32-bit words are needed to store `n` bits, rounding up.
fn bitset_words(n: u32) -> usize {
    let bits = n as usize;
    let full_words = bits / 32;
    // A partially filled trailing word still needs its own slot.
    if bits % 32 == 0 {
        full_words
    } else {
        full_words + 1
    }
}
/// Sets bit number `bit` in the little-endian-by-word bitset `words`.
///
/// # Panics
/// Panics if `bit / 32` is not a valid index into `words`.
fn set_bit(words: &mut [u32], bit: u32) {
    let word_index = (bit >> 5) as usize;
    let mask = 1u32 << (bit & 31);
    words[word_index] |= mask;
}
/// Reports whether bit number `bit` is set in the bitset `words`.
///
/// # Panics
/// Panics if `bit / 32` is not a valid index into `words`.
fn test_bit(words: &[u32], bit: u32) -> bool {
    let word = words[(bit >> 5) as usize];
    let mask = 1u32 << (bit & 31);
    (word & mask) != 0
}
/// Tests bit `bit` inside row `q` of a flat buffer made of consecutive rows of `words` words each.
///
/// # Panics
/// Panics if row `q` does not lie fully inside `flat`, or if `bit / 32` is not below `words`.
fn slice_test_bit(flat: &[u32], words: usize, q: usize, bit: u32) -> bool {
    let row_start = q * words;
    let row_end = (q + 1) * words;
    let row = &flat[row_start..row_end];
    test_bit(row, bit)
}
/// Builds a CSR adjacency structure over `n` nodes from an edge list.
///
/// When `reverse` is true every edge `(a, b)` is stored as `b -> a`. Edges with either
/// endpoint outside `0..n` are dropped. Within one source node, targets keep the order in
/// which their edges appear in `edges`.
///
/// Returns `(offsets, targets, masks)`: `offsets` has `n + 1` entries, node `i` owns
/// `targets[offsets[i]..offsets[i + 1]]`, and `masks` holds a `1` for every target.
fn build_csr(n: u32, edges: &[(u32, u32)], reverse: bool) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
    let node_count = n as usize;

    // Orient every edge, then keep only those whose endpoints are both valid nodes.
    let kept: Vec<(usize, u32)> = edges
        .iter()
        .map(|&(a, b)| if reverse { (b, a) } else { (a, b) })
        .filter(|&(src, dst)| (src as usize) < node_count && (dst as usize) < node_count)
        .map(|(src, dst)| (src as usize, dst))
        .collect();

    // First pass: out-degree of each node.
    let mut out_degree = vec![0usize; node_count];
    for &(src, _) in &kept {
        out_degree[src] += 1;
    }

    // Prefix sums give the row offsets; `cursor` tracks the next free slot in each row.
    let mut offsets = Vec::with_capacity(node_count + 1);
    let mut cursor = Vec::with_capacity(node_count);
    let mut running = 0usize;
    offsets.push(0u32);
    for &degree in &out_degree {
        cursor.push(running);
        running += degree;
        offsets.push(running as u32);
    }

    // Second pass: scatter targets in input order so each row preserves edge order.
    let mut targets = vec![0u32; running];
    for &(src, dst) in &kept {
        targets[cursor[src]] = dst;
        cursor[src] += 1;
    }

    let masks = vec![1u32; running];
    (offsets, targets, masks)
}
/// Runs the batched conflict analysis over `functions` with the default shard cap.
///
/// Delegates to [`analyze_crate_batched_with_shard_cap`] with `DEFAULT_SHARD_WORDS` as
/// `max_shard_words`.
///
/// # Errors
/// Returns whatever `DispatchError` the delegated call returns.
pub fn analyze_crate_batched(
dispatcher: &dyn OptimizerDispatcher,
functions: &[BorrowFacts],
) -> Result<Vec<Vec<Conflict>>, DispatchError> {
analyze_crate_batched_with_shard_cap(dispatcher, functions, DEFAULT_SHARD_WORDS)
}
/// Default `max_shard_words` used by `analyze_crate_batched`: 8 * 1024 * 1024 words.
///
/// The shard packer compares this cap against `loans * ceil(points / 32)` for the shard,
/// which is the length (in `u32` words) of each per-loan seed buffer built in `batch_shard`.
pub const DEFAULT_SHARD_WORDS: usize = 8 * 1024 * 1024;
/// Upper bound on a shard's combined point count when the packer considers adding a further
/// function; taken from the first component of `PERSISTENT_BFS_WORKGROUP_SIZE`.
/// NOTE(review): the name suggests that a shard no larger than one workgroup avoids a
/// grid-wide synchronization — confirm against the persistent-BFS kernel.
const MAX_GRID_SYNC_FREE_SHARD_POINTS: u32 = PERSISTENT_BFS_WORKGROUP_SIZE[0];
/// Runs the batched conflict analysis over `functions`, splitting them into consecutive shards.
///
/// Functions are packed greedily in input order. A function joins the current shard unless
/// doing so would push the shard's seed-buffer size (`loans * ceil(points / 32)` words) past
/// `max_shard_words`, or its combined point count past `MAX_GRID_SYNC_FREE_SHARD_POINTS`.
/// The first function of a shard is always accepted, even when it exceeds those limits alone.
///
/// Returns one conflict list per input function, in input order.
///
/// # Errors
/// Stops at, and returns, the first `DispatchError` produced while processing a shard.
pub fn analyze_crate_batched_with_shard_cap(
    dispatcher: &dyn OptimizerDispatcher,
    functions: &[BorrowFacts],
    max_shard_words: usize,
) -> Result<Vec<Vec<Conflict>>, DispatchError> {
    let word_budget = max_shard_words as u64;
    let point_budget = u64::from(MAX_GRID_SYNC_FREE_SHARD_POINTS);

    let mut results: Vec<Vec<Conflict>> = Vec::with_capacity(functions.len());
    let mut pending = functions;
    while !pending.is_empty() {
        // Running totals for the shard being assembled from the front of `pending`.
        let mut points: u64 = 0;
        let mut loans: u64 = 0;
        let mut taken = 0usize;
        for facts in pending {
            let grown_points = points + u64::from(facts.point_count);
            let grown_loans = loans + facts.loan_count() as u64;
            let grown_buffer = grown_loans.saturating_mul(grown_points.div_ceil(32));
            let fits = grown_buffer <= word_budget && grown_points <= point_budget;
            // An empty shard always takes the next function so the loop makes progress.
            if taken > 0 && !fits {
                break;
            }
            points = grown_points;
            loans = grown_loans;
            taken += 1;
        }

        let (shard, rest) = pending.split_at(taken);
        results.extend(batch_shard(dispatcher, shard)?);
        pending = rest;
    }
    Ok(results)
}
/// Analyzes one shard of functions with a single pair of batched BFS dispatches.
///
/// All functions' points are laid out back to back in one combined index space, and all
/// loans are numbered globally in the same way. A forward BFS is seeded at each loan's issue
/// point and a reverse-graph BFS is seeded at each loan's use points; each loan is one query
/// row of `words` words in the flat result buffers. For every pair of loans in the same
/// function that share a `loan_place` and where at least one is `LoanKind::Mut`, a conflict is
/// reported when one loan's forward and backward results both contain the other's issue point.
///
/// Returns one conflict list per function in `functions`, in order.
///
/// # Errors
/// Propagates any `DispatchError` from uploading the graphs or running either BFS.
///
/// NOTE(review): CFG edges, issue points and use points are rebased by adding the function's
/// point base without checking them against `point_count`; an out-of-range value would land
/// in another function's index range (or index past the bitset row). Presumably `BorrowFacts`
/// is validated upstream — confirm.
fn batch_shard(
dispatcher: &dyn OptimizerDispatcher,
functions: &[BorrowFacts],
) -> Result<Vec<Vec<Conflict>>, DispatchError> {
// Result used when there is nothing to analyze: one empty conflict list per function.
let empty = || functions.iter().map(|_| Vec::new()).collect::<Vec<_>>();
// point_base[fi] = index of function fi's point 0 in the combined point space.
let mut point_base = Vec::with_capacity(functions.len());
let mut combined_points: u32 = 0;
for f in functions {
point_base.push(combined_points);
combined_points = combined_points.saturating_add(f.point_count);
}
// loan_base[fi] = global index (query row) of function fi's loan 0.
let mut loan_base = Vec::with_capacity(functions.len());
let mut total_loans: usize = 0;
for f in functions {
loan_base.push(total_loans);
total_loans += f.loan_count();
}
// No points or no loans: skip the graph upload and dispatches entirely.
if combined_points == 0 || total_loans == 0 {
return Ok(empty());
}
// Width, in u32 words, of one query row covering every combined point.
let words = bitset_words(combined_points);
// Iteration bound handed to the BFS: the largest single function's point count, at least 1.
// NOTE(review): presumably sufficient because edges stay inside one function — confirm.
let max_iters = functions
.iter()
.map(|f| f.point_count)
.max()
.unwrap_or(1)
.max(1);
// Rebase every function's CFG edges into the combined point space.
let mut edges: Vec<(u32, u32)> = Vec::new();
for (fi, f) in functions.iter().enumerate() {
let base = point_base[fi];
for &(a, b) in &f.cfg_edges {
edges.push((a + base, b + base));
}
}
// Forward and reversed CSR views of the same edge set, uploaded as two resident graphs.
let (fwd_off, fwd_tgt, fwd_msk) = build_csr(combined_points, &edges, false);
let (rev_off, rev_tgt, rev_msk) = build_csr(combined_points, &edges, true);
let fwd_graph =
upload_resident_bfs_graph(dispatcher, combined_points, &fwd_off, &fwd_tgt, &fwd_msk)?;
let rev_graph =
upload_resident_bfs_graph(dispatcher, combined_points, &rev_off, &rev_tgt, &rev_msk)?;
// Seed bitsets, one row of `words` words per global loan index.
let mut issue_seeds = vec![0u32; total_loans * words];
let mut use_seeds = vec![0u32; total_loans * words];
for (fi, f) in functions.iter().enumerate() {
let pbase = point_base[fi];
let lbase = loan_base[fi];
// Forward seed: the point where each loan is issued.
for a in 0..f.loan_count() {
let g = lbase + a;
set_bit(
&mut issue_seeds[g * words..(g + 1) * words],
pbase + f.loan_issued_at[a],
);
}
// Backward seeds: every recorded use point; entries naming a loan index that is
// not below loan_count() are ignored.
for &(loan, point) in &f.loan_used_at {
let a = loan as usize;
if a < f.loan_count() {
let g = lbase + a;
set_bit(&mut use_seeds[g * words..(g + 1) * words], pbase + point);
}
}
}
// The same scratch object is passed to both dispatches.
let mut scratch = PersistentBfsResidentScratch::default();
// Forward pass over the forward graph, seeded with issue points.
// `forward_changed` is an output argument of the helper and is not read afterwards.
let mut forward = Vec::new();
let mut forward_changed = Vec::new();
bfs_expand_resident_graph_batch_with_scratch_into(
dispatcher,
&fwd_graph,
&issue_seeds,
total_loans,
ALLOW_ALL,
max_iters,
&mut scratch,
&mut forward,
&mut forward_changed,
)?;
// Backward pass over the reversed graph, seeded with use points.
// `backward_changed` is likewise not read afterwards.
let mut backward = Vec::new();
let mut backward_changed = Vec::new();
bfs_expand_resident_graph_batch_with_scratch_into(
dispatcher,
&rev_graph,
&use_seeds,
total_loans,
ALLOW_ALL,
max_iters,
&mut scratch,
&mut backward,
&mut backward_changed,
)?;
// Pairwise conflict check, done per function over that function's own loans only.
let mut per_function = Vec::with_capacity(functions.len());
for (fi, f) in functions.iter().enumerate() {
let pbase = point_base[fi];
let lbase = loan_base[fi];
let loans = f.loan_count();
let mut conflicts = Vec::new();
for a in 0..loans {
for b in (a + 1)..loans {
// Only loans of the same place can conflict.
if f.loan_place[a] != f.loan_place[b] {
continue;
}
let a_mut = f.loan_kind[a] == LoanKind::Mut;
let b_mut = f.loan_kind[b] == LoanKind::Mut;
// Pairs with no mutable loan are skipped.
if !(a_mut || b_mut) {
continue;
}
let issue_a = pbase + f.loan_issued_at[a];
let issue_b = pbase + f.loan_issued_at[b];
let ga = lbase + a;
let gb = lbase + b;
// A loan counts as live at a point when that point is in both its forward
// result (from its issue seed) and its backward result (from its use seeds).
let a_live_at_b = slice_test_bit(&forward, words, ga, issue_b)
&& slice_test_bit(&backward, words, ga, issue_b);
let b_live_at_a = slice_test_bit(&forward, words, gb, issue_a)
&& slice_test_bit(&backward, words, gb, issue_a);
if a_live_at_b || b_live_at_a {
// Order the pair by issue point; on a tie the lower loan index comes first.
let (first, second) = if f.loan_issued_at[a] <= f.loan_issued_at[b] {
(a, b)
} else {
(b, a)
};
conflicts.push(Conflict {
first: first as u32,
second: second as u32,
// The reported offset is taken from the later-issued loan.
offset: f.loan_offset[second],
kind: if a_mut && b_mut {
ConflictKind::TwoMutable
} else {
ConflictKind::MutableAndShared
},
});
}
}
}
per_function.push(conflicts);
}
Ok(per_function)
}
/// Runs the batched conflict analysis for a single function.
///
/// Wraps `facts` in a one-element slice, calls [`analyze_crate_batched`], and returns the
/// last per-function conflict list it produced (an empty list if it produced none).
///
/// # Errors
/// Returns the `DispatchError` from the underlying batched analysis unchanged.
pub fn analyze_batched(
    dispatcher: &dyn OptimizerDispatcher,
    facts: &BorrowFacts,
) -> Result<Vec<Conflict>, DispatchError> {
    let single = std::slice::from_ref(facts);
    analyze_crate_batched(dispatcher, single)
        .map(|mut per_function| per_function.pop().unwrap_or_default())
}