/// Prefix that `is_hessian_unavailable` looks for at the very start of an
/// error string.
pub(crate) const HESSIAN_UNAVAILABLE_PREFIX: &str = "outer Hessian unavailable:";
/// A `p` at or above this value makes both the generic and the callback scale
/// decisions prefer the operator route (reason `"large_p"`).
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_DIM_THRESHOLD: usize = 512;
/// An `n` at or above this value, together with
/// `p >= MATRIX_FREE_OUTER_HESSIAN_DIM_AT_LARGE_N`, makes the generic scale
/// decision prefer the operator route (reason `"large_n_moderate_p"`).
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_LARGE_N_THRESHOLD: usize = 50_000;
/// Lower bound on `p` for the `"large_n_moderate_p"` rule of the generic scale
/// decision.
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_DIM_AT_LARGE_N: usize = 32;
/// A saturating product `n * p` at or above this value makes the generic scale
/// decision prefer the operator route (reason `"large_linear_work"`).
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_NP_THRESHOLD: usize = 4_000_000;
/// A `k` at or above this value makes both the generic and the callback scale
/// decisions prefer the operator route (reason `"large_k"`).
pub(crate) const MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD: usize = 32;
/// A saturating product `n * k * k` at or above this value makes the callback
/// scale decision prefer the operator route (reason `"callback_row_pair_work"`).
pub(crate) const CALLBACK_OUTER_HESSIAN_ROW_PAIR_WORK_THRESHOLD: usize = 25_000_000;
/// NOTE(review): not referenced in the visible part of this file; presumably a
/// dimension cutoff for switching to a stochastic trace estimate -- confirm at
/// the use site.
pub(crate) const STOCHASTIC_TRACE_DIM_THRESHOLD: usize = 500;
/// NOTE(review): not referenced in the visible part of this file; presumably a
/// duration in milliseconds above which REML trace work is logged as slow --
/// confirm at the use site.
pub(crate) const REML_TRACE_SLOW_LOG_MS: f64 = 100.0;
/// Outcome of `outer_hessian_route_plan`: which outer-Hessian route was picked
/// and the label explaining why.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct OuterHessianRoutePlan {
    /// `true` selects the operator route, `false` the dense route.
    pub(crate) use_operator: bool,
    /// Static label naming the rule that determined the route.
    pub(crate) reason: &'static str,
    /// Whether the scale decision, taken on its own, preferred the operator.
    pub(crate) scale_prefers_operator: bool,
    /// Dense workspace size estimate in bytes that was computed for the plan.
    pub(crate) dense_workspace_bytes: usize,
}

impl OuterHessianRoutePlan {
    /// Returns the route as a short label: `"operator"` when `use_operator`
    /// is set, `"dense"` otherwise.
    pub(crate) fn choice(self) -> &'static str {
        match self.use_operator {
            true => "operator",
            false => "dense",
        }
    }
}
/// Result of a scale-based check on whether the operator route is preferred
/// over the dense one, as produced by `generic_outer_hessian_scale_decision`
/// and `callback_outer_hessian_scale_decision`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct OuterHessianScaleDecision {
/// `true` when one of the scale rules fired in favour of the operator route.
pub(crate) prefers_operator: bool,
/// Static label naming the rule that fired, or `"below_crossover"` when none
/// did.
pub(crate) reason: &'static str,
}
/// Returns the byte size of a `rows x cols` matrix of `f64` values.
///
/// Every multiplication saturates, so an oversized request yields
/// `usize::MAX` instead of overflowing.
pub(crate) fn saturating_f64_matrix_bytes(rows: usize, cols: usize) -> usize {
    let element_bytes = std::mem::size_of::<f64>();
    let element_count = rows.saturating_mul(cols);
    element_count.saturating_mul(element_bytes)
}
/// Estimates the dense workspace in bytes as `2 * k + 3` matrices of size
/// `p x p` holding `f64` values.
///
/// All arithmetic saturates, so the result is capped at `usize::MAX`.
pub(crate) fn outer_hessian_dense_workspace_bytes(p: usize, k: usize) -> usize {
    // After the saturating `+ 3` the count is always at least 3, so no
    // additional lower clamp is required.
    let matrix_count = k.saturating_mul(2).saturating_add(3);
    let single_matrix_bytes = saturating_f64_matrix_bytes(p, p);
    single_matrix_bytes.saturating_mul(matrix_count)
}
/// Returns the byte budget for the dense outer-Hessian workspace, read from
/// `max_single_materialization_bytes` of the default library resource policy.
pub(crate) fn outer_hessian_dense_workspace_budget_bytes() -> usize {
    let policy = crate::resource::ResourcePolicy::default_library();
    policy.max_single_materialization_bytes
}
/// Returns `true` when the dense workspace estimate for `(p, k)` does not
/// exceed the dense workspace budget.
pub(crate) fn dense_outer_hessian_workspace_fits(p: usize, k: usize) -> bool {
    let required_bytes = outer_hessian_dense_workspace_bytes(p, k);
    let budget_bytes = outer_hessian_dense_workspace_budget_bytes();
    required_bytes <= budget_bytes
}
/// Decides from `n`, `p` and `k` whether the operator route is preferred.
///
/// Rules are tried in order and the first match supplies the reason:
/// 1. the dense workspace does not fit the budget -> `"dense_memory_budget"`
/// 2. `k >= MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD` -> `"large_k"`
/// 3. `p >= MATRIX_FREE_OUTER_HESSIAN_DIM_THRESHOLD` -> `"large_p"`
/// 4. `n >= MATRIX_FREE_OUTER_HESSIAN_LARGE_N_THRESHOLD` and
///    `p >= MATRIX_FREE_OUTER_HESSIAN_DIM_AT_LARGE_N` -> `"large_n_moderate_p"`
/// 5. saturating `n * p >= MATRIX_FREE_OUTER_HESSIAN_NP_THRESHOLD`
///    -> `"large_linear_work"`
///
/// When no rule matches, the operator is not preferred and the reason is
/// `"below_crossover"`.
pub(crate) fn generic_outer_hessian_scale_decision(
    n: usize,
    p: usize,
    k: usize,
) -> OuterHessianScaleDecision {
    // `Some(label)` names the first rule favouring the operator; `None` means
    // none of them fired.
    let operator_reason: Option<&'static str> = if !dense_outer_hessian_workspace_fits(p, k) {
        Some("dense_memory_budget")
    } else if k >= MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD {
        Some("large_k")
    } else if p >= MATRIX_FREE_OUTER_HESSIAN_DIM_THRESHOLD {
        Some("large_p")
    } else if n >= MATRIX_FREE_OUTER_HESSIAN_LARGE_N_THRESHOLD
        && p >= MATRIX_FREE_OUTER_HESSIAN_DIM_AT_LARGE_N
    {
        Some("large_n_moderate_p")
    } else if n.saturating_mul(p) >= MATRIX_FREE_OUTER_HESSIAN_NP_THRESHOLD {
        Some("large_linear_work")
    } else {
        None
    };

    OuterHessianScaleDecision {
        prefers_operator: operator_reason.is_some(),
        reason: operator_reason.unwrap_or("below_crossover"),
    }
}
/// Scale decision used by `outer_hessian_route_plan` when its
/// `callback_kernel` flag is set.
///
/// Rules are tried in order and the first match supplies the reason:
/// 1. the dense workspace does not fit the budget -> `"dense_memory_budget"`
/// 2. `k >= MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD` -> `"large_k"`
/// 3. `p >= MATRIX_FREE_OUTER_HESSIAN_DIM_THRESHOLD` -> `"large_p"`
/// 4. saturating `n * k * k >= CALLBACK_OUTER_HESSIAN_ROW_PAIR_WORK_THRESHOLD`
///    -> `"callback_row_pair_work"`
///
/// When no rule matches, the operator is not preferred and the reason is
/// `"below_crossover"`.
pub(crate) fn callback_outer_hessian_scale_decision(
    n: usize,
    p: usize,
    k: usize,
) -> OuterHessianScaleDecision {
    // Builds the "operator preferred" result for a given rule label.
    let operator_for = |reason: &'static str| OuterHessianScaleDecision {
        prefers_operator: true,
        reason,
    };

    if !dense_outer_hessian_workspace_fits(p, k) {
        operator_for("dense_memory_budget")
    } else if k >= MATRIX_FREE_OUTER_HESSIAN_K_THRESHOLD {
        operator_for("large_k")
    } else if p >= MATRIX_FREE_OUTER_HESSIAN_DIM_THRESHOLD {
        operator_for("large_p")
    } else if n.saturating_mul(k).saturating_mul(k)
        >= CALLBACK_OUTER_HESSIAN_ROW_PAIR_WORK_THRESHOLD
    {
        operator_for("callback_row_pair_work")
    } else {
        OuterHessianScaleDecision {
            prefers_operator: false,
            reason: "below_crossover",
        }
    }
}
pub(crate) fn outer_hessian_route_plan(
n: usize,
p: usize,
k: usize,
kernel_available: bool,
callback_kernel: bool,
subspace_trace: bool,
) -> OuterHessianRoutePlan {
let dense_workspace_bytes = outer_hessian_dense_workspace_bytes(p, k);
if !kernel_available {
return OuterHessianRoutePlan {
use_operator: false,
reason: "kernel_absent",
scale_prefers_operator: false,
dense_workspace_bytes,
};
}
let scale = if callback_kernel {
callback_outer_hessian_scale_decision(n, p, k)
} else {
generic_outer_hessian_scale_decision(n, p, k)
};
let reason = if subspace_trace && scale.prefers_operator {
"subspace_projected_operator"
} else {
scale.reason
};
OuterHessianRoutePlan {
use_operator: scale.prefers_operator,
reason,
scale_prefers_operator: scale.prefers_operator,
dense_workspace_bytes,
}
}
pub(crate) fn prefer_outer_hessian_operator(n: usize, p: usize, k: usize) -> bool {
generic_outer_hessian_scale_decision(n, p, k).prefers_operator
}
pub(crate) fn is_hessian_unavailable(error: &str) -> bool {
error.starts_with(HESSIAN_UNAVAILABLE_PREFIX)
}