mod dispatch;
#[cfg(any(test, feature = "cpu-parity"))]
mod reference;
#[cfg(test)]
mod tests;
pub use dispatch::{
build_ifds_csr_via, build_ifds_csr_via_into, build_ifds_csr_via_with_scratch_into,
};
#[cfg(any(test, feature = "cpu-parity"))]
pub use reference::{
reference_build_ifds_csr, reference_canonicalize_csr_within_rows, try_reference_build_ifds_csr,
};
use vyre_primitives::graph::exploded::{
dense_to_encoded, encoded_to_dense, ifds_node_count_saturating, IfdsCsrProgramCacheKey,
IfdsCsrRuleColumns, IfdsCsrRuleInputFingerprint, IfdsCsrStaticInputKey,
};
use crate::graph::dispatch_bridge::{CachedProgram, ProgramCache};
/// Reusable scratch state for building an IFDS CSR through the GPU dispatch path.
///
/// Holds cached rule columns, raw input buffers, `u32` work buffers and a program cache so
/// they can be kept between builds.
/// NOTE(review): presumably the caller creates one of these (via `Default`) and passes it to
/// `build_ifds_csr_via_with_scratch_into` on every call so allocations and compiled programs
/// are reused — confirm against `dispatch.rs`.
#[derive(Debug, Default)]
pub struct IfdsCsrGpuScratch {
/// Cached rule columns (`IfdsCsrRuleColumns` from `vyre_primitives::graph::exploded`).
rule_columns: IfdsCsrRuleColumns,
/// Fingerprint of the rule input; `None` until one has been recorded.
/// NOTE(review): presumably identifies the rules `rule_columns` was built from, so an
/// unchanged rule set can skip re-encoding — confirm in `dispatch.rs`.
rule_fingerprint: Option<IfdsCsrRuleInputFingerprint>,
/// Raw byte buffers, presumably one per dispatch input binding — TODO confirm.
inputs: Vec<Vec<u8>>,
/// Key for the static inputs; `None` until one has been recorded.
/// NOTE(review): presumably used to detect when `inputs` can be reused as-is — confirm.
static_input_key: Option<IfdsCsrStaticInputKey>,
/// `u32` work buffer; presumably per-row write cursors used while filling the CSR — confirm.
row_cursor: Vec<u32>,
/// `u32` work buffer; presumably column lengths measured in 32-bit words — confirm.
col_len_words: Vec<u32>,
/// Cache of built programs keyed by `IfdsCsrProgramCacheKey`; its `builds()` counter is
/// exposed to tests through `program_builds`.
program_cache: ProgramCache<IfdsCsrProgramCacheKey, CachedIfdsCsrProgram>,
}
// Module-local name for the program type stored in `IfdsCsrGpuScratch::program_cache`;
// currently just an alias of the shared `CachedProgram` from `graph::dispatch_bridge`.
type CachedIfdsCsrProgram = CachedProgram;
impl IfdsCsrGpuScratch {
    /// Test-only hook that reports the `builds()` counter of the internal program cache.
    ///
    /// NOTE(review): presumably this counts how many times a program had to be built (cache
    /// misses), letting tests assert that repeated builds reuse the cached program — confirm
    /// against `ProgramCache::builds`.
    #[cfg(test)]
    fn program_builds(&self) -> usize {
        let cache = &self.program_cache;
        cache.builds()
    }
}
/// Returns the IFDS node count for `num_procs` procedures, each with `blocks_per_proc` blocks
/// and `facts_per_proc` facts.
///
/// This is a public re-surface of `ifds_node_count_saturating` from
/// `vyre_primitives::graph::exploded` and adds no logic of its own.
/// NOTE(review): the callee's name suggests the result clamps at `u32::MAX` instead of
/// overflowing — confirm there before relying on it.
#[must_use]
pub fn ifds_node_count(num_procs: u32, blocks_per_proc: u32, facts_per_proc: u32) -> u32 {
    // Pure delegation: the arithmetic lives in the shared primitive.
    ifds_node_count_saturating(
        num_procs,
        blocks_per_proc,
        facts_per_proc,
    )
}
/// Maps a dense node index to its encoded form and then back to a dense index.
///
/// Returns `None` if `dense_to_encoded` yields `None`. The second step is then skipped.
/// Otherwise it returns whatever `encoded_to_dense` yields for the encoded value, which may
/// itself be `None`.
/// NOTE(review): presumably `None` signals that `dense` is outside the layout described by
/// `blocks_per_proc` / `facts_per_proc` — confirm in `vyre_primitives::graph::exploded`.
#[must_use]
pub fn round_trip_dense(dense: u32, blocks_per_proc: u32, facts_per_proc: u32) -> Option<u32> {
    dense_to_encoded(dense, blocks_per_proc, facts_per_proc)
        .and_then(|encoded| encoded_to_dense(encoded, blocks_per_proc, facts_per_proc))
}