use super::hessian_paths::{
BernoulliMarginalSlopeRowExactContext, BlockSlices, PrimarySlices, RowCellMomentsBundle,
};
use super::*;
/// Returns `true` once `n` reaches `EXACT_WORK_LOG_MIN_ROWS`.
///
/// NOTE(review): judging by the name, `n` is a row count and the result gates
/// exact-work logging — confirm against callers.
#[inline]
pub(super) fn log_exact_work(n: usize) -> bool {
    EXACT_WORK_LOG_MIN_ROWS <= n
}
/// Returns the available-memory figure currently reported by `sysinfo`.
///
/// One `sysinfo::System` handle, configured to track memory only, is created
/// on first use and kept behind a process-wide mutex. Every call refreshes the
/// memory statistics before reading them, so the value is never stale.
///
/// # Panics
///
/// Panics if the shared mutex has been poisoned.
pub(super) fn runtime_available_memory_bytes() -> u64 {
    static SYSTEM: OnceLock<Mutex<sysinfo::System>> = OnceLock::new();

    // Builds the shared handle; only memory information is ever requested.
    fn build_memory_probe() -> Mutex<sysinfo::System> {
        let kinds =
            sysinfo::RefreshKind::new().with_memory(sysinfo::MemoryRefreshKind::everything());
        Mutex::new(sysinfo::System::new_with_specifics(kinds))
    }

    let mut probe = SYSTEM
        .get_or_init(build_memory_probe)
        .lock()
        .expect("sysinfo system mutex poisoned");
    probe.refresh_memory_specifics(sysinfo::MemoryRefreshKind::everything());
    probe.available_memory()
}
pub(super) fn bms_row_primary_hessian_pinned_bytes() -> &'static AtomicU64 {
static PINNED: OnceLock<AtomicU64> = OnceLock::new();
PINNED.get_or_init(|| AtomicU64::new(0))
}
pub(super) fn bms_row_primary_hessian_capacity_floor() -> &'static AtomicU64 {
static FLOOR: OnceLock<AtomicU64> = OnceLock::new();
FLOOR.get_or_init(|| AtomicU64::new(0))
}
/// Folds a fresh available-memory reading into the shared capacity floor and
/// returns the resulting value.
///
/// The stored value only ever grows: it is replaced by
/// `runtime_available_bytes` when that reading is larger. The return value is
/// the larger of the previously stored value and the new reading.
pub(super) fn observe_capacity_floor(runtime_available_bytes: u64) -> u64 {
    let floor = bms_row_primary_hessian_capacity_floor();
    // `fetch_max` yields the value held *before* the update, so take the max
    // again to obtain the value held after it.
    let previous = floor.fetch_max(runtime_available_bytes, Ordering::AcqRel);
    std::cmp::max(previous, runtime_available_bytes)
}
/// Outcome label explaining why a row primary Hessian cache was, or was not,
/// selected for materialization.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) enum RowPrimaryHessianCacheReason {
    /// The expected number of reuse passes is below the configured minimum.
    ReuseTooLow,
    /// The cache alone does not fit inside the single-cache budget.
    SingleCacheExceedsRamFraction,
    /// Already-pinned bytes plus this cache would exceed the global pin budget.
    GlobalPinExceedsRamFraction,
    /// All checks passed; the cache is worth building.
    ReuseAmortizesBuild,
}

impl RowPrimaryHessianCacheReason {
    /// Returns the stable snake_case identifier for this reason.
    pub(super) const fn as_str(self) -> &'static str {
        match self {
            RowPrimaryHessianCacheReason::ReuseAmortizesBuild => "reuse_amortizes_build",
            RowPrimaryHessianCacheReason::GlobalPinExceedsRamFraction => {
                "global_pin_exceeds_ram_fraction"
            }
            RowPrimaryHessianCacheReason::SingleCacheExceedsRamFraction => {
                "single_cache_exceeds_ram_fraction"
            }
            RowPrimaryHessianCacheReason::ReuseTooLow => "reuse_too_low",
        }
    }
}
/// Result of `decide_row_primary_hessian_cache`: the decision itself plus
/// every input and derived quantity that went into it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) struct RowPrimaryHessianCachePlan {
/// `true` exactly when `reason` is `ReuseAmortizesBuild`.
pub(super) materialize: bool,
/// Estimated cache size: `n * (r * r + r + 1)` `f64` values, saturating.
pub(super) bytes: u64,
/// Capacity figure the single-cache budget was derived from (echoed input).
pub(super) stable_capacity_bytes: u64,
/// Available-memory figure the global pin budget was derived from (echoed input).
pub(super) runtime_available_bytes: u64,
/// Bytes already pinned when the decision was made (echoed input).
pub(super) workspace_pinned_bytes: u64,
/// Fraction of `stable_capacity_bytes` a single cache may occupy.
pub(super) single_cache_budget_bytes: u64,
/// Fraction of `runtime_available_bytes` all pinned caches together may occupy.
pub(super) global_pin_budget_bytes: u64,
/// Expected number of passes that would reuse the cache (echoed input).
pub(super) expected_reuse_passes: usize,
/// Row evaluations needed when the cache is built: `n`.
pub(super) materialized_row_hessian_evals: usize,
/// Row evaluations needed without a cache: `n * expected_reuse_passes`, saturating.
pub(super) streamed_row_hessian_evals: usize,
/// Which check determined the outcome.
pub(super) reason: RowPrimaryHessianCacheReason,
}
/// Decides whether a per-row primary Hessian cache should be materialized.
///
/// The cache footprint is estimated as `n * (r * r + r + 1)` `f64` values,
/// with every step saturating instead of overflowing. The checks are applied
/// in priority order and the first failing one becomes the reported reason:
///
/// 1. `expected_reuse_passes` must reach `BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES`;
/// 2. the footprint must be strictly below the single-cache budget, a fixed
///    fraction of `stable_capacity_bytes`;
/// 3. `workspace_pinned_bytes` plus the footprint must not exceed the global
///    pin budget, a fixed fraction of `runtime_available_bytes`.
///
/// The returned plan echoes all inputs together with the derived budgets and
/// the streamed / materialized evaluation counts.
pub(super) fn decide_row_primary_hessian_cache(
    n: usize,
    r: usize,
    expected_reuse_passes: usize,
    stable_capacity_bytes: u64,
    runtime_available_bytes: u64,
    workspace_pinned_bytes: u64,
) -> RowPrimaryHessianCachePlan {
    // Per row: r*r + r + 1 f64 values.
    let width = r as u64;
    let f64_slots_per_row = width
        .saturating_mul(width)
        .saturating_add(width)
        .saturating_add(1);
    let f64_size = std::mem::size_of::<f64>() as u64;
    let cache_bytes = (n as u64)
        .saturating_mul(f64_slots_per_row)
        .saturating_mul(f64_size);

    // Budgets are NUM/DEN fractions; a zero denominator is treated as one.
    let single_budget = stable_capacity_bytes
        .saturating_mul(BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_NUM)
        / BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_DEN.max(1);
    let global_budget = runtime_available_bytes
        .saturating_mul(BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_NUM)
        / BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_DEN.max(1);

    let reuse_too_low = expected_reuse_passes < BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES;
    // NOTE(review): this bound is inclusive (`>=`) while the global bound below
    // is exclusive (`>`) — confirm the asymmetry is intended.
    let over_single_budget = cache_bytes >= single_budget;
    let over_global_budget = workspace_pinned_bytes.saturating_add(cache_bytes) > global_budget;

    // The first failing check, in priority order, names the outcome.
    let reason = match (reuse_too_low, over_single_budget, over_global_budget) {
        (true, _, _) => RowPrimaryHessianCacheReason::ReuseTooLow,
        (false, true, _) => RowPrimaryHessianCacheReason::SingleCacheExceedsRamFraction,
        (false, false, true) => RowPrimaryHessianCacheReason::GlobalPinExceedsRamFraction,
        (false, false, false) => RowPrimaryHessianCacheReason::ReuseAmortizesBuild,
    };
    // Materialize only when no check failed.
    let materialize = !(reuse_too_low || over_single_budget || over_global_budget);

    RowPrimaryHessianCachePlan {
        materialize,
        bytes: cache_bytes,
        stable_capacity_bytes,
        runtime_available_bytes,
        workspace_pinned_bytes,
        single_cache_budget_bytes: single_budget,
        global_pin_budget_bytes: global_budget,
        expected_reuse_passes,
        materialized_row_hessian_evals: n,
        streamed_row_hessian_evals: n.saturating_mul(expected_reuse_passes),
        reason,
    }
}
/// Host-side arrays of cached per-row primary evaluations, together with the
/// byte count they contribute to the global pinned-bytes counter.
///
/// `RowPrimaryEvalPin::new` adds `bytes` to that counter and the `Drop` impl
/// subtracts it again, so values should be created through `new` (not a
/// struct literal) and `bytes` should not be changed afterwards.
pub struct RowPrimaryEvalPin {
// Array whose length `tile_for_row` uses as the pin's row count.
// NOTE(review): presumably one negative log-likelihood value per row — confirm.
pub(super) neglog: Array1<f64>,
// NOTE(review): presumably one gradient row per data row — confirm layout.
pub(super) grad: Array2<f64>,
// NOTE(review): presumably one flattened Hessian per data row — confirm layout.
pub(super) hess: Array2<f64>,
// Bytes accounted for in the global pinned-bytes counter while this value lives.
pub(super) bytes: u64,
}
/// One contiguous run of cached rows inside a tiled cache.
pub(super) struct RowPrimaryEvalTile {
// Global index of the first row in this tile; `tile_for_row` subtracts it to
// obtain the tile-local row offset.
pub(super) row_start: usize,
// Cached evaluations for the rows of this tile.
pub(super) rows: RowPrimaryEvalPin,
}
/// Row primary evaluation cache split into several tiles.
pub(crate) struct RowPrimaryEvalTiles {
// NOTE(review): presumably the total number of rows across all tiles — confirm.
pub(super) n_rows: usize,
// NOTE(review): presumably the primary dimension of each row — confirm.
pub(super) r: usize,
// Nominal rows per tile; `tile_for_row` divides by it to guess the tile index
// and skips the guess when it is zero.
pub(super) tile_rows: usize,
// The tiles themselves; each knows its own starting row.
pub(super) tiles: Vec<RowPrimaryEvalTile>,
}
impl RowPrimaryEvalTiles {
    /// Bundles already-built tiles with their shared dimensions. No validation
    /// is performed on the arguments.
    pub(super) fn new(
        n_rows: usize,
        r: usize,
        tile_rows: usize,
        tiles: Vec<RowPrimaryEvalTile>,
    ) -> Self {
        Self {
            n_rows,
            r,
            tile_rows,
            tiles,
        }
    }

    /// Returns `true` when there are no tiles at all.
    #[inline]
    pub(super) fn is_empty(&self) -> bool {
        self.tiles.is_empty()
    }

    /// Finds the tile containing global row index `row`.
    ///
    /// Returns the tile and the row's offset inside it, or `None` when no tile
    /// covers `row`. The tile at index `row / tile_rows` is tried first (when
    /// `tile_rows` is non-zero); if that misses, all tiles are scanned in order.
    #[inline]
    pub(super) fn tile_for_row(&self, row: usize) -> Option<(&RowPrimaryEvalTile, usize)> {
        // Pairs `tile` with the local offset of `row` when the tile covers it.
        fn locate(tile: &RowPrimaryEvalTile, row: usize) -> Option<(&RowPrimaryEvalTile, usize)> {
            let first = tile.row_start;
            let count = tile.rows.neglog().len();
            if row >= first && row < first + count {
                Some((tile, row - first))
            } else {
                None
            }
        }

        if self.tile_rows > 0 {
            let direct = self
                .tiles
                .get(row / self.tile_rows)
                .and_then(|tile| locate(tile, row));
            if direct.is_some() {
                return direct;
            }
        }
        // The guess was unavailable or wrong: fall back to a linear scan.
        self.tiles.iter().find_map(|tile| locate(tile, row))
    }

    /// Sums the `bytes` recorded by every tile.
    #[inline]
    pub(super) fn total_bytes(&self) -> u64 {
        self.tiles
            .iter()
            .fold(0u64, |total, tile| total + tile.rows.bytes)
    }
}
impl RowPrimaryEvalPin {
    /// Wraps the given arrays and registers `bytes` with the global
    /// pinned-bytes counter.
    ///
    /// The same amount is subtracted again when the value is dropped. `bytes`
    /// is taken as given; it is not recomputed from the arrays.
    pub(super) fn new(
        neglog: Array1<f64>,
        grad: Array2<f64>,
        hess: Array2<f64>,
        bytes: u64,
    ) -> Self {
        let pinned_total = bms_row_primary_hessian_pinned_bytes();
        pinned_total.fetch_add(bytes, Ordering::AcqRel);
        Self {
            neglog,
            grad,
            hess,
            bytes,
        }
    }

    /// Borrows the stored `neglog` array.
    pub(super) fn neglog(&self) -> &Array1<f64> {
        &self.neglog
    }

    /// Borrows the stored `grad` array.
    pub(super) fn grad(&self) -> &Array2<f64> {
        &self.grad
    }

    /// Borrows the stored `hess` array.
    pub(super) fn hess(&self) -> &Array2<f64> {
        &self.hess
    }
}
impl Drop for RowPrimaryEvalPin {
    /// Removes this pin's `bytes` from the global pinned-bytes counter,
    /// mirroring the addition made in `RowPrimaryEvalPin::new`.
    fn drop(&mut self) {
        let released = self.bytes;
        let pinned_total = bms_row_primary_hessian_pinned_bytes();
        pinned_total.fetch_sub(released, Ordering::AcqRel);
    }
}
/// Where (if anywhere) the cached row primary evaluations live.
pub enum RowPrimaryEvalCache {
/// Nothing is cached.
Empty,
/// A single host-side pin.
Host(RowPrimaryEvalPin),
/// A collection of host-side tiles.
Tiled(RowPrimaryEvalTiles),
/// Device-resident storage; this variant only exists on Linux builds.
#[cfg(target_os = "linux")]
Device(crate::gpu::bms_flex_row::DeviceResidentRowHess),
}
impl RowPrimaryEvalCache {
#[inline]
pub(crate) fn is_some(&self) -> bool {
!matches!(self, Self::Empty)
}
#[inline]
pub(crate) fn is_tiled(&self) -> bool {
matches!(self, Self::Tiled(_))
}
#[inline]
pub(crate) fn tiles(&self) -> Option<&RowPrimaryEvalTiles> {
match self {
Self::Tiled(tiles) => Some(tiles),
_ => None,
}
}
#[inline]
pub(crate) fn host_pin(&self) -> Option<&RowPrimaryEvalPin> {
match self {
Self::Host(pin) => Some(pin),
Self::Tiled(_) => None,
Self::Empty => None,
#[cfg(target_os = "linux")]
Self::Device(_) => None,
}
}
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn device(&self) -> Option<&crate::gpu::bms_flex_row::DeviceResidentRowHess> {
match self {
Self::Device(hess) => Some(hess),
_ => None,
}
}
}
/// Pair of two-dimensional arrays held per flex axis.
///
/// NOTE(review): by name these are third-order derivative tensors, one per
/// entry of the pair — confirm the meaning of the two slots and the array
/// layout against the code that fills them.
pub(super) struct FlexAxisThirdRowTensors {
pub(super) third: [Array2<f64>; 2],
}
/// Three two-dimensional arrays held per flex axis.
///
/// NOTE(review): by name these are fourth-order derivative tensors, with
/// `qq` / `qg` / `gg` presumably labelling the pairing of two underlying
/// quantities `q` and `g` — confirm against the code that fills them.
pub(super) struct FlexAxisFourthRowTensors {
pub(super) qq: Array2<f64>,
pub(super) qg: Array2<f64>,
pub(super) gg: Array2<f64>,
}
/// Aggregate of cached state for one exact evaluation of the Bernoulli
/// marginal-slope model.
///
/// NOTE(review): the field notes below are inferred from names and types only;
/// confirm them against the code that builds and reads this struct.
pub(super) struct BernoulliMarginalSlopeExactEvalCache {
// Block-level slice layout.
pub(super) slices: BlockSlices,
// Slice layout of the primary quantities.
pub(super) primary: PrimarySlices,
// One exact-evaluation context per row (presumably indexed by row).
pub(super) row_contexts: Vec<BernoulliMarginalSlopeRowExactContext>,
// Row cell moments, when available.
pub(super) row_cell_moments: Option<RowCellMomentsBundle>,
// `d15` variant of the row cell moments, held in a `RayonSafeOnce` cell that
// stores either the bundle (possibly absent) or an error message.
pub(super) row_cell_moments_d15:
crate::resource::RayonSafeOnce<Result<Option<RowCellMomentsBundle>, String>>,
// `d21` variant of the row cell moments; same cell shape as `d15`.
pub(super) row_cell_moments_d21:
crate::resource::RayonSafeOnce<Result<Option<RowCellMomentsBundle>, String>>,
// Cached row primary evaluations (empty, host, tiled, or device).
pub(super) row_primary_hessians: RowPrimaryEvalCache,
// Per-row 2x2x2 arrays, or an error message.
pub(super) rigid_third_full:
crate::resource::RayonSafeOnce<Result<Vec<[[[f64; 2]; 2]; 2]>, String>>,
// Per-row 2x2x2x2 arrays, or an error message.
pub(super) rigid_fourth_full:
crate::resource::RayonSafeOnce<Result<Vec<[[[[f64; 2]; 2]; 2]; 2]>, String>>,
// Outer cell holding a vector of inner cells, each yielding third-order flex
// tensors or an error message.
pub(super) flex_axis_third_tensors: crate::resource::RayonSafeOnce<
Vec<crate::resource::RayonSafeOnce<Result<FlexAxisThirdRowTensors, String>>>,
>,
// Same nesting as above for the fourth-order flex tensors.
pub(super) flex_axis_fourth_tensors: crate::resource::RayonSafeOnce<
Vec<crate::resource::RayonSafeOnce<Result<FlexAxisFourthRowTensors, String>>>,
>,
}