use super::hessian_paths::{
BernoulliMarginalSlopeRowExactContext, BlockSlices, PrimarySlices, RowCellMomentsBundle,
};
use super::*;
/// Returns `true` once `n` reaches `EXACT_WORK_LOG_MIN_ROWS` (inclusive).
///
/// NOTE(review): going by the name, callers presumably use this to gate
/// exact-work logging on the row count — confirm at the call sites.
#[inline]
pub(super) fn log_exact_work(n: usize) -> bool {
    EXACT_WORK_LOG_MIN_ROWS <= n
}
/// Returns the figure reported by `sysinfo::System::available_memory` after a
/// fresh memory refresh.
///
/// A single process-wide `sysinfo::System` is created on first use, configured
/// to track memory only, and kept behind a mutex so concurrent callers take
/// turns refreshing it.
///
/// # Panics
///
/// Panics if the mutex guarding the shared `System` has been poisoned.
pub(super) fn runtime_available_memory_bytes() -> u64 {
    static SYSTEM: OnceLock<Mutex<sysinfo::System>> = OnceLock::new();

    // Builds the shared handle with memory tracking only.
    let build_memory_only_system = || {
        let memory_kind = sysinfo::MemoryRefreshKind::everything();
        let refresh_kind = sysinfo::RefreshKind::new().with_memory(memory_kind);
        Mutex::new(sysinfo::System::new_with_specifics(refresh_kind))
    };

    let mut guard = SYSTEM
        .get_or_init(build_memory_only_system)
        .lock()
        .expect("sysinfo system mutex poisoned");
    // Refresh on every call so the answer reflects the current state.
    guard.refresh_memory_specifics(sysinfo::MemoryRefreshKind::everything());
    guard.available_memory()
}
pub(super) fn bms_row_primary_hessian_pinned_bytes() -> &'static AtomicU64 {
static PINNED: OnceLock<AtomicU64> = OnceLock::new();
PINNED.get_or_init(|| AtomicU64::new(0))
}
/// Why the row primary-Hessian cache planner reached its verdict.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) enum RowPrimaryHessianCacheReason {
    /// The expected number of reuse passes is below
    /// `BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES`.
    ReuseTooLow,
    /// The estimated cache size did not fit the single-cache budget.
    SingleCacheExceedsRamFraction,
    /// Already-pinned bytes plus this cache would go over the global pin budget.
    GlobalPinExceedsRamFraction,
    /// None of the rejections applied; the only reason that leads to
    /// materialization.
    ReuseAmortizesBuild,
}

impl RowPrimaryHessianCacheReason {
    /// Fixed snake_case label for this reason.
    pub(super) const fn as_str(self) -> &'static str {
        match self {
            RowPrimaryHessianCacheReason::ReuseAmortizesBuild => "reuse_amortizes_build",
            RowPrimaryHessianCacheReason::GlobalPinExceedsRamFraction => {
                "global_pin_exceeds_ram_fraction"
            }
            RowPrimaryHessianCacheReason::SingleCacheExceedsRamFraction => {
                "single_cache_exceeds_ram_fraction"
            }
            RowPrimaryHessianCacheReason::ReuseTooLow => "reuse_too_low",
        }
    }
}
/// Result of `decide_row_primary_hessian_cache`: the materialize-or-stream verdict together
/// with the inputs and derived quantities it was based on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) struct RowPrimaryHessianCachePlan {
/// `true` only when `reason` is `RowPrimaryHessianCacheReason::ReuseAmortizesBuild`.
pub(super) materialize: bool,
/// Estimated cache size: `n * (r*r + r + 1) * size_of::<f64>()`, computed with
/// saturating arithmetic.
pub(super) bytes: u64,
/// Available-memory figure the caller passed to the planner (echoed unchanged).
pub(super) runtime_available_bytes: u64,
/// Already-pinned byte count the caller passed to the planner (echoed unchanged).
pub(super) workspace_pinned_bytes: u64,
/// `runtime_available_bytes` scaled by the `BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_*` ratio.
pub(super) single_cache_budget_bytes: u64,
/// `runtime_available_bytes` scaled by the `BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_*` ratio.
pub(super) global_pin_budget_bytes: u64,
/// Expected number of reuse passes the caller passed to the planner (echoed unchanged).
pub(super) expected_reuse_passes: usize,
/// Set to `n` by the planner.
pub(super) materialized_row_hessian_evals: usize,
/// Set to `n * expected_reuse_passes` (saturating) by the planner.
pub(super) streamed_row_hessian_evals: usize,
/// Which branch of the planner produced this plan.
pub(super) reason: RowPrimaryHessianCacheReason,
}
/// Plans whether the row primary-Hessian cache should be materialized or the
/// values streamed on each pass.
///
/// # Parameters
///
/// * `n` – number of rows.
/// * `r` – per-row dimension; each row is costed at `r*r + r + 1` `f64` values.
/// * `expected_reuse_passes` – how many passes are expected to reuse the cache.
/// * `runtime_available_bytes` – currently available memory, in bytes.
/// * `workspace_pinned_bytes` – bytes already pinned by existing caches.
///
/// # Decision order
///
/// 1. Fewer than `BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES` expected passes
///    gives `ReuseTooLow`.
/// 2. A cache larger than the single-cache budget gives
///    `SingleCacheExceedsRamFraction`.
/// 3. Pinned bytes plus this cache larger than the global budget gives
///    `GlobalPinExceedsRamFraction`.
/// 4. Otherwise `ReuseAmortizesBuild`, the only outcome with `materialize == true`.
///
/// Both budget checks are strict: a footprint exactly equal to its budget is
/// within budget. All size arithmetic saturates instead of overflowing.
pub(super) fn decide_row_primary_hessian_cache(
    n: usize,
    r: usize,
    expected_reuse_passes: usize,
    runtime_available_bytes: u64,
    workspace_pinned_bytes: u64,
) -> RowPrimaryHessianCachePlan {
    // r*r + r + 1 values per row.
    // NOTE(review): presumably the Hessian, gradient and scalar stored by
    // `RowPrimaryEvalPin` — confirm against the code that builds the pin.
    let r_u64 = r as u64;
    let floats_per_row = r_u64
        .saturating_mul(r_u64)
        .saturating_add(r_u64)
        .saturating_add(1);
    let bytes = (n as u64)
        .saturating_mul(floats_per_row)
        .saturating_mul(std::mem::size_of::<f64>() as u64);

    // Budgets are fractions of available memory; `.max(1)` guards against a
    // zero denominator.
    let single_cache_budget_bytes = runtime_available_bytes
        .saturating_mul(BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_NUM)
        / BMS_ROW_PRIMARY_HESSIAN_SINGLE_FRACTION_DEN.max(1);
    let global_pin_budget_bytes = runtime_available_bytes
        .saturating_mul(BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_NUM)
        / BMS_ROW_PRIMARY_HESSIAN_GLOBAL_FRACTION_DEN.max(1);

    // Work estimates: streaming re-evaluates every row on every pass, while a
    // materialized cache evaluates each row once.
    let streamed_row_hessian_evals = n.saturating_mul(expected_reuse_passes);
    let materialized_row_hessian_evals = n;

    let reason = if expected_reuse_passes < BMS_ROW_PRIMARY_HESSIAN_MIN_REUSE_PASSES {
        RowPrimaryHessianCacheReason::ReuseTooLow
    } else if bytes > single_cache_budget_bytes {
        // Strict comparison: only a cache that actually exceeds its budget is
        // rejected, matching the global check below.
        RowPrimaryHessianCacheReason::SingleCacheExceedsRamFraction
    } else if workspace_pinned_bytes.saturating_add(bytes) > global_pin_budget_bytes {
        RowPrimaryHessianCacheReason::GlobalPinExceedsRamFraction
    } else {
        RowPrimaryHessianCacheReason::ReuseAmortizesBuild
    };

    RowPrimaryHessianCachePlan {
        materialize: matches!(reason, RowPrimaryHessianCacheReason::ReuseAmortizesBuild),
        bytes,
        runtime_available_bytes,
        workspace_pinned_bytes,
        single_cache_budget_bytes,
        global_pin_budget_bytes,
        expected_reuse_passes,
        materialized_row_hessian_evals,
        streamed_row_hessian_evals,
        reason,
    }
}
/// Owns three arrays together with the byte count it registers in the
/// process-wide pinned-bytes counter.
///
/// Construction through [`RowPrimaryEvalPin::new`] adds `bytes` to the counter
/// returned by `bms_row_primary_hessian_pinned_bytes`; dropping the value
/// subtracts the same amount again.
pub struct RowPrimaryEvalPin {
    /// One-dimensional `f64` array.
    /// NOTE(review): presumably the per-row negative log-likelihood — confirm.
    pub(super) neglog: Array1<f64>,
    /// Two-dimensional `f64` array.
    /// NOTE(review): presumably the per-row gradients — confirm layout.
    pub(super) grad: Array2<f64>,
    /// Two-dimensional `f64` array.
    /// NOTE(review): presumably the per-row Hessians — confirm layout.
    pub(super) hess: Array2<f64>,
    /// Amount added to the pinned-bytes counter on construction and removed on drop.
    pub(super) bytes: u64,
}

impl RowPrimaryEvalPin {
    /// Wraps the arrays and registers `bytes` with the pinned-bytes counter.
    ///
    /// `bytes` is taken as given; it is not derived from the array sizes here.
    pub(super) fn new(
        neglog: Array1<f64>,
        grad: Array2<f64>,
        hess: Array2<f64>,
        bytes: u64,
    ) -> Self {
        let pinned_counter = bms_row_primary_hessian_pinned_bytes();
        pinned_counter.fetch_add(bytes, Ordering::AcqRel);
        Self {
            neglog,
            grad,
            hess,
            bytes,
        }
    }

    /// Borrows the `neglog` array.
    pub(super) fn neglog(&self) -> &Array1<f64> {
        &self.neglog
    }

    /// Borrows the `grad` array.
    pub(super) fn grad(&self) -> &Array2<f64> {
        &self.grad
    }

    /// Borrows the `hess` array.
    pub(super) fn hess(&self) -> &Array2<f64> {
        &self.hess
    }
}

impl Drop for RowPrimaryEvalPin {
    /// Removes this pin's `bytes` from the pinned-bytes counter.
    fn drop(&mut self) {
        let released = self.bytes;
        bms_row_primary_hessian_pinned_bytes().fetch_sub(released, Ordering::AcqRel);
    }
}
/// Where, if anywhere, the row primary evaluation cache is held.
pub enum RowPrimaryEvalCache {
    /// Nothing is cached.
    Empty,
    /// Cached in a [`RowPrimaryEvalPin`].
    Host(RowPrimaryEvalPin),
    /// Cached in a `DeviceResidentRowHess`; this variant exists only on Linux builds.
    #[cfg(target_os = "linux")]
    Device(crate::gpu::bms_flex_row::DeviceResidentRowHess),
}

impl RowPrimaryEvalCache {
    /// Returns `true` for every variant except `Empty`.
    #[inline]
    pub(crate) fn is_some(&self) -> bool {
        match self {
            Self::Empty => false,
            Self::Host(_) => true,
            #[cfg(target_os = "linux")]
            Self::Device(_) => true,
        }
    }

    /// Returns the pin when this is the `Host` variant, otherwise `None`.
    #[inline]
    pub(crate) fn host_pin(&self) -> Option<&RowPrimaryEvalPin> {
        if let Self::Host(pin) = self {
            Some(pin)
        } else {
            None
        }
    }

    /// Returns the payload when this is the `Device` variant, otherwise `None`.
    #[cfg(target_os = "linux")]
    #[inline]
    pub(crate) fn device(&self) -> Option<&crate::gpu::bms_flex_row::DeviceResidentRowHess> {
        if let Self::Device(resident) = self {
            Some(resident)
        } else {
            None
        }
    }
}
/// Groups the state declared in the fields below.
///
/// NOTE(review): the name suggests this is the cache built for one exact Bernoulli
/// marginal-slope evaluation — confirm against the code that constructs it.
pub(super) struct BernoulliMarginalSlopeExactEvalCache {
/// `BlockSlices` value imported from `hessian_paths`.
pub(super) slices: BlockSlices,
/// `PrimarySlices` value imported from `hessian_paths`.
pub(super) primary: PrimarySlices,
/// Exact row contexts; presumably one entry per row — TODO confirm.
pub(super) row_contexts: Vec<BernoulliMarginalSlopeRowExactContext>,
/// Optional row cell-moments bundle stored directly in the struct.
pub(super) row_cell_moments: Option<RowCellMomentsBundle>,
/// Row cell-moments bundle held in a `RayonSafeOnce`, with a `String` error on failure.
/// NOTE(review): presumably filled on first use; the meaning of `d15` is not visible
/// here — confirm.
pub(super) row_cell_moments_d15:
crate::resource::RayonSafeOnce<Result<Option<RowCellMomentsBundle>, String>>,
/// Same shape as `row_cell_moments_d15`.
/// NOTE(review): the meaning of `d21` is not visible here — confirm.
pub(super) row_cell_moments_d21:
crate::resource::RayonSafeOnce<Result<Option<RowCellMomentsBundle>, String>>,
/// Row primary evaluation cache: empty, host-pinned, or (Linux only) the device variant.
pub(super) row_primary_hessians: RowPrimaryEvalCache,
/// `Vec` of 2x2x2 `f64` arrays held in a `RayonSafeOnce`, with a `String` error on failure.
/// NOTE(review): presumably one third-order tensor per row — confirm.
pub(super) rigid_third_full:
crate::resource::RayonSafeOnce<Result<Vec<[[[f64; 2]; 2]; 2]>, String>>,
/// `Vec` of 2x2x2x2 `f64` arrays held in a `RayonSafeOnce`, with a `String` error on failure.
/// NOTE(review): presumably one fourth-order tensor per row — confirm.
pub(super) rigid_fourth_full:
crate::resource::RayonSafeOnce<Result<Vec<[[[[f64; 2]; 2]; 2]; 2]>, String>>,
}