use crate::families::custom_family::family_trait::{CustomFamily, OuterEvalContext};
use crate::families::custom_family::psi_design::{
CustomFamilyBlockPsiDerivative, ExactNewtonJointHessianWorkspace,
};
use gam_linalg::RidgePolicy;
use gam_problem::{ParameterBlockSpec, ParameterBlockState};
use ndarray::Array1;
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
pub use gam_problem::{ExactNewtonOuterObjective, ExactOuterDerivativeOrder};
pub use gam_problem::{
CUSTOM_FAMILY_RIDGE_FLOOR, CUSTOM_FAMILY_WEIGHT_FLOOR, validate_blockspec_consistency,
};
/// Panics unless `specs` passes [`validate_blockspec_consistency`].
///
/// `context` is prefixed to the panic message so the failing call site can be
/// identified.
///
/// # Panics
/// Panics when the consistency validator returns `Err`.
pub(crate) fn assert_valid_blockspecs(specs: &[ParameterBlockSpec], context: &str) {
    let consistent = validate_blockspec_consistency(specs).is_ok();
    assert!(consistent, "{context}: inconsistent parameter block specs");
}
/// Panics unless the numeric settings in `options` are well-formed.
///
/// `inner_tol`, `outer_tol`, `minweight` and `ridge_floor` must each be finite
/// and non-negative. `early_exit_threshold` is only checked when it is set, and
/// then it must be finite. `context` is prefixed to every panic message.
///
/// # Panics
/// Panics on the first setting (in the order listed above) that fails its check.
pub(crate) fn assert_valid_options(options: &BlockwiseFitOptions, context: &str) {
    // Shared predicate: NaN and infinities fail `is_finite`, negatives fail `>= 0.0`.
    let finite_non_negative = |value: f64| value.is_finite() && value >= 0.0;

    assert!(
        finite_non_negative(options.inner_tol),
        "{context}: inner_tol must be finite and non-negative"
    );
    assert!(
        finite_non_negative(options.outer_tol),
        "{context}: outer_tol must be finite and non-negative"
    );
    assert!(
        finite_non_negative(options.minweight),
        "{context}: minweight must be finite and non-negative"
    );
    assert!(
        finite_non_negative(options.ridge_floor),
        "{context}: ridge_floor must be finite and non-negative"
    );

    // An unset threshold is always acceptable; a set one may be negative but not NaN/inf.
    if let Some(threshold) = options.early_exit_threshold {
        assert!(
            threshold.is_finite(),
            "{context}: early_exit_threshold must be finite"
        );
    }
}
/// Panics unless `states` lines up with `specs` block by block.
///
/// Checks, in order: the two slices have the same length; then for every block
/// that `beta` has as many entries as the spec's design has columns, and that
/// `eta` has as many entries as the spec's solver design has rows.
///
/// # Panics
/// Panics on the first mismatch, with `context` prefixed to the message.
pub(crate) fn assert_states_match_specs(
    states: &[ParameterBlockState],
    specs: &[ParameterBlockSpec],
    context: &str,
) {
    let state_count = states.len();
    let spec_count = specs.len();
    assert_eq!(
        state_count, spec_count,
        "{context}: state/spec block count mismatch"
    );

    for (block, state) in states.iter().enumerate() {
        // In bounds: the slices were just asserted to have equal length.
        let spec = &specs[block];

        let beta_len = state.beta.len();
        let design_cols = spec.design.ncols();
        assert_eq!(
            beta_len, design_cols,
            "{context}: beta length mismatch in block {block}"
        );

        let eta_len = state.eta.len();
        let solver_rows = spec.solver_design().nrows();
        assert_eq!(
            eta_len, solver_rows,
            "{context}: eta length mismatch in block {block}"
        );
    }
}
/// Panics unless there is exactly one derivative list per parameter block spec.
///
/// Only the outer length is compared; the contents of each inner `Vec` are not
/// inspected here.
///
/// # Panics
/// Panics when the counts differ, with `context` prefixed to the message.
pub(crate) fn assert_derivative_blocks_match_specs(
    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
    specs: &[ParameterBlockSpec],
    context: &str,
) {
    let derivative_block_count = derivative_blocks.len();
    let spec_count = specs.len();
    assert_eq!(
        derivative_block_count, spec_count,
        "{context}: derivative/spec block count mismatch"
    );
}
/// Panics unless `rho` has one entry per penalty across all of `specs`.
///
/// The expected length is the sum of `penalties.len()` over every spec.
///
/// # Panics
/// Panics when `rho.len()` differs from that sum, with `context` prefixed to
/// the message.
pub(crate) fn assert_rho_matches_specs(
    rho: &Array1<f64>,
    specs: &[ParameterBlockSpec],
    context: &str,
) {
    let mut penalty_total = 0usize;
    for spec in specs {
        penalty_total += spec.penalties.len();
    }
    assert_eq!(
        rho.len(),
        penalty_total,
        "{context}: rho length does not match penalty count"
    );
}
/// Warms the outer caches of the Hessian workspace, if one is supplied.
///
/// Returns `Ok(())` immediately when `hessian_workspace` is `None`.
///
/// # Errors
/// Returns the workspace's warm-up error rendered as text and prefixed with
/// `context`.
pub(crate) fn validate_hessian_workspace_ready(
    hessian_workspace: &Option<Arc<dyn ExactNewtonJointHessianWorkspace>>,
    context: &str,
) -> Result<(), String> {
    match hessian_workspace {
        None => Ok(()),
        Some(workspace) => match workspace.warm_up_outer_caches() {
            Ok(_) => Ok(()),
            Err(err) => Err(format!(
                "{context}: failed to warm Hessian workspace caches: {err}"
            )),
        },
    }
}
/// Reports the exact outer derivative order for the given block specs.
///
/// The specs are validated first. The returned order is
/// [`ExactOuterDerivativeOrder::Second`] for every `coefficient_cost`: zero and
/// non-zero costs map to the same result.
///
/// # Panics
/// Panics when `specs` fails the block-spec consistency check.
pub fn exact_outer_order_from_capability(
    specs: &[ParameterBlockSpec],
    coefficient_cost: u64,
) -> ExactOuterDerivativeOrder {
    assert_valid_blockspecs(specs, "exact outer derivative order");
    // The cost does not influence the outcome; acknowledge the parameter explicitly.
    let _ = coefficient_cost;
    ExactOuterDerivativeOrder::Second
}
/// Reports the exact outer derivative order, taking outer HVP availability
/// into account.
///
/// Without an outer hyper-Hessian HVP this defers to
/// [`exact_outer_order_from_capability`]. With one, the specs are validated
/// under an HVP-specific context label and
/// [`ExactOuterDerivativeOrder::Second`] is returned for every
/// `coefficient_cost`.
///
/// # Panics
/// Panics when `specs` fails the block-spec consistency check.
pub fn exact_outer_order_with_outer_hvp(
    specs: &[ParameterBlockSpec],
    coefficient_cost: u64,
    outer_hyper_hessian_hvp_available: bool,
) -> ExactOuterDerivativeOrder {
    if !outer_hyper_hessian_hvp_available {
        return exact_outer_order_from_capability(specs, coefficient_cost);
    }
    assert_valid_blockspecs(specs, "exact outer derivative order with HVP");
    ExactOuterDerivativeOrder::Second
}
/// Capability and predicted-work summary consulted by the methods of this type
/// when deciding how outer derivatives are evaluated.
#[derive(Clone, Copy, Debug)]
pub struct OuterDerivativePolicy {
/// Exact derivative order on offer. `declared_hessian_form` reports
/// `Unavailable` whenever `capability.has_hessian()` is false.
pub capability: ExactOuterDerivativeOrder,
/// Predicted work for the outer Hessian.
/// NOTE(review): units and consumers are not visible in this file — confirm.
pub predicted_hessian_work: u128,
/// Predicted work for the outer gradient; `should_use_staged_kappa` compares
/// it against `OUTER_GRADIENT_WORK_BUDGET`.
pub predicted_gradient_work: u128,
/// Gate for `should_use_staged_kappa`: when false that method always returns
/// `false`.
pub subsample_capable: bool,
}
impl OuterDerivativePolicy {
    /// Predicted gradient work above which [`Self::should_use_staged_kappa`]
    /// opts in (for subsample-capable policies).
    pub const OUTER_GRADIENT_WORK_BUDGET: u128 = 50_000_000_000;
    /// Sample count at or above which [`Self::should_use_staged_kappa`] opts in
    /// (for subsample-capable policies).
    pub const STAGED_KAPPA_TRIGGER_N: usize = 30_000;

    /// Clamps a requested evaluation order to what this policy can deliver.
    ///
    /// `Value` and `ValueAndGradient` pass through unchanged. A request for
    /// `ValueGradientHessian` is downgraded to `ValueAndGradient` when
    /// [`Self::declared_hessian_form`] reports `Unavailable`.
    pub fn order_for_evaluation(&self, requested: crate::OuterEvalOrder) -> crate::OuterEvalOrder {
        use crate::OuterEvalOrder;
        use gam_problem::DeclaredHessianForm;
        match requested {
            // The Hessian form is consulted only when a Hessian was actually requested.
            OuterEvalOrder::ValueGradientHessian
                if matches!(
                    self.declared_hessian_form(),
                    DeclaredHessianForm::Unavailable
                ) =>
            {
                OuterEvalOrder::ValueAndGradient
            }
            OuterEvalOrder::ValueGradientHessian => OuterEvalOrder::ValueGradientHessian,
            OuterEvalOrder::ValueAndGradient => OuterEvalOrder::ValueAndGradient,
            OuterEvalOrder::Value => OuterEvalOrder::Value,
        }
    }

    /// Declares which Hessian form this policy offers.
    ///
    /// Returns `Either` when the capability includes a Hessian and
    /// `Unavailable` otherwise.
    pub fn declared_hessian_form(&self) -> gam_problem::DeclaredHessianForm {
        use gam_problem::DeclaredHessianForm;
        if self.capability.has_hessian() {
            DeclaredHessianForm::Either
        } else {
            DeclaredHessianForm::Unavailable
        }
    }

    /// Decides whether the staged-kappa path should be used for `n` samples.
    ///
    /// Always `false` for policies that are not subsample-capable. Otherwise
    /// `true` when `n` reaches [`Self::STAGED_KAPPA_TRIGGER_N`] or the predicted
    /// gradient work exceeds [`Self::OUTER_GRADIENT_WORK_BUDGET`].
    pub fn should_use_staged_kappa(&self, n: usize) -> bool {
        let large_sample = n >= Self::STAGED_KAPPA_TRIGGER_N;
        let over_budget = self.predicted_gradient_work > Self::OUTER_GRADIENT_WORK_BUDGET;
        self.subsample_capable && (large_sample || over_budget)
    }
}
/// Counts the outer coordinates used for policy costing: the number of
/// penalties over all `specs` plus `psi_dim`.
///
/// All additions saturate at `u128::MAX` rather than overflowing.
#[inline]
pub(crate) fn outer_coord_dim_for_policy(specs: &[ParameterBlockSpec], psi_dim: usize) -> u128 {
    let mut penalty_count = 0u128;
    for spec in specs {
        penalty_count = penalty_count.saturating_add(spec.penalties.len() as u128);
    }
    penalty_count.saturating_add(psi_dim as u128)
}
/// Scales per-coordinate gradient and Hessian costs by the outer coordinate
/// count.
///
/// The multiplier is the penalty count over `specs` plus `psi_dim`, floored at
/// one so an empty outer problem still reports the base costs. Products
/// saturate at `u128::MAX`.
///
/// Returns `(gradient_work, hessian_work)`.
pub fn default_outer_derivative_policy_costs(
    specs: &[ParameterBlockSpec],
    psi_dim: usize,
    grad_cost: u64,
    hess_cost: u64,
) -> (u128, u128) {
    let multiplier = outer_coord_dim_for_policy(specs, psi_dim).max(1);
    let gradient_work = u128::from(grad_cost).saturating_mul(multiplier);
    let hessian_work = u128::from(hess_cost).saturating_mul(multiplier);
    (gradient_work, hessian_work)
}
/// Estimates coefficient-Hessian work as the sum over blocks of
/// `rows * cols * cols` of each block's design.
///
/// Every multiplication and addition saturates at `u64::MAX`.
pub fn default_coefficient_hessian_cost(specs: &[ParameterBlockSpec]) -> u64 {
    let mut total = 0u64;
    for spec in specs {
        let rows = spec.design.nrows() as u64;
        let cols = spec.design.ncols() as u64;
        let block_cost = rows.saturating_mul(cols.saturating_mul(cols));
        total = total.saturating_add(block_cost);
    }
    total
}
/// Estimates coefficient-Hessian work for a jointly coupled system as
/// `n * p_total * p_total`, where `p_total` is the total design column count
/// over all `specs`.
///
/// Every multiplication and addition saturates at `u64::MAX`.
pub fn joint_coupled_coefficient_hessian_cost(n: u64, specs: &[ParameterBlockSpec]) -> u64 {
    let mut total_cols = 0u64;
    for spec in specs {
        total_cols = total_cols.saturating_add(spec.design.ncols() as u64);
    }
    let cols_squared = total_cols.saturating_mul(total_cols);
    n.saturating_mul(cols_squared)
}
/// Estimates coefficient-gradient work as half of
/// [`default_coefficient_hessian_cost`] (integer division, rounding down).
pub fn default_coefficient_gradient_cost(specs: &[ParameterBlockSpec]) -> u64 {
    let hessian_cost = default_coefficient_hessian_cost(specs);
    hessian_cost / 2
}
/// Lays the blocks out back to back and returns each block's coefficient
/// index range.
///
/// Block `i` occupies `start..start + ncols`, where `start` is the total
/// column count of all earlier blocks and `ncols` is the column count of the
/// block's design. The result has one range per spec, in spec order.
pub fn block_offsets_from_specs(specs: &[ParameterBlockSpec]) -> Arc<[Range<usize>]> {
    let mut start = 0usize;
    let ranges: Vec<Range<usize>> = specs
        .iter()
        .map(|spec| {
            let end = start + spec.design.ncols();
            let block_range = start..end;
            // The next block begins where this one ends.
            start = end;
            block_range
        })
        .collect();
    Arc::from(ranges.into_boxed_slice())
}
/// Caps the outer iteration count of a first-order solve by a fixed work
/// budget.
///
/// The request is returned unchanged when an outer Hessian is available, when
/// `requested <= 1`, or when `coefficient_gradient_cost` is zero. Otherwise the
/// affordable count is the budget divided by the per-gradient cost, raised to
/// at least `MIN_FIRST_ORDER_ITERS`, and the result is the smaller of that and
/// `requested` (so the cap never increases the request).
pub fn cost_gated_first_order_max_iter(
    requested: usize,
    coefficient_gradient_cost: u64,
    has_outer_hessian: bool,
) -> usize {
    const FIRST_ORDER_OUTER_WORK_BUDGET: u64 = 80_000_000_000;
    const MIN_FIRST_ORDER_ITERS: usize = 4;

    // `checked_div` yields `None` for a zero cost, which falls through to the
    // "leave the request alone" arm together with the other exemptions.
    match FIRST_ORDER_OUTER_WORK_BUDGET.checked_div(coefficient_gradient_cost) {
        Some(affordable) if !has_outer_hessian && requested > 1 => {
            let capped = (affordable as usize).max(MIN_FIRST_ORDER_ITERS);
            requested.min(capped)
        }
        _ => requested,
    }
}
/// Step cap for the first-order BFGS path.
///
/// Returns `None` (no cap) when an outer Hessian is available and `Some(5.0)`
/// otherwise.
pub const fn first_order_bfgs_loglambda_step_cap(has_outer_hessian: bool) -> Option<f64> {
    const FIRST_ORDER_STEP_CAP: f64 = 5.0;
    if has_outer_hessian {
        return None;
    }
    Some(FIRST_ORDER_STEP_CAP)
}
/// Reports whether `family` declares the `StrictPseudoLaplace` exact-Newton
/// outer objective, which is the only objective for which this returns `true`.
pub fn exact_newton_outer_geometry_supports_second_order_solver<F: CustomFamily + ?Sized>(
    family: &F,
) -> bool {
    let declared_objective = family.exact_newton_outerobjective();
    declared_objective == ExactNewtonOuterObjective::StrictPseudoLaplace
}
/// Options for a blockwise custom-family fit.
///
/// The defaults quoted on each field are the values set by the `Default`
/// implementation. Fields described as "checked by `assert_valid_options`"
/// make that function panic when the stated condition is violated.
#[derive(Clone)]
pub struct BlockwiseFitOptions {
/// Defaults to `DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES`.
/// NOTE(review): presumably the cap on inner-solver cycles — confirm.
pub inner_max_cycles: usize,
/// Defaults to `1e-6`. Checked by `assert_valid_options`: finite and
/// non-negative.
pub inner_tol: f64,
/// Defaults to `60`.
/// NOTE(review): presumably the cap on outer iterations — confirm.
pub outer_max_iter: usize,
/// Defaults to `1e-5`. Checked by `assert_valid_options`: finite and
/// non-negative.
pub outer_tol: f64,
/// Defaults to `None`.
pub outer_rel_cost_tol: Option<f64>,
/// Defaults to `-10.0`.
pub rho_lower_bound: f64,
/// Defaults to `CUSTOM_FAMILY_WEIGHT_FLOOR`. Checked by
/// `assert_valid_options`: finite and non-negative.
pub minweight: f64,
/// Defaults to `CUSTOM_FAMILY_RIDGE_FLOOR`. Checked by
/// `assert_valid_options`: finite and non-negative.
pub ridge_floor: f64,
/// Defaults to `RidgePolicy::explicit_stabilization_pospart()`.
pub ridge_policy: RidgePolicy,
/// Defaults to `true`.
pub use_remlobjective: bool,
/// Defaults to `true`.
pub use_outer_hessian: bool,
/// Defaults to `false`.
pub compute_covariance: bool,
/// Optional shared atomic counter handle. Defaults to `None`.
pub screening_max_inner_iterations: Option<Arc<AtomicUsize>>,
/// Optional shared atomic counter handle. Defaults to `None`.
pub outer_inner_max_iterations: Option<Arc<AtomicUsize>>,
/// Defaults to `None`. When set, `assert_valid_options` requires it to be
/// finite (negative values are accepted).
pub early_exit_threshold: Option<f64>,
/// Optional shared outer-score subsample. Defaults to `None`.
pub outer_score_subsample: Option<Arc<crate::OuterScoreSubsample>>,
/// Defaults to `true`.
pub auto_outer_subsample: bool,
/// Defaults to `None`.
pub outer_eval_context: Option<OuterEvalContext>,
/// Optional shared warm-start session. Defaults to `None`.
pub cache_session: Option<Arc<gam_runtime::warm_start::Session>>,
/// Additional shared warm-start sessions. Defaults to an empty `Vec`.
/// NOTE(review): presumably mirrors of `cache_session` — confirm.
pub cache_mirror_sessions: Vec<Arc<gam_runtime::warm_start::Session>>,
/// Optional shared joint penalty bundle. Defaults to `None`.
pub joint_penalties: Option<Arc<crate::JointPenaltyBundle>>,
/// Defaults to `true`.
pub screen_initial_rho: bool,
/// Defaults to `false`.
pub seed_screening: bool,
}
/// Value that `BlockwiseFitOptions::default()` assigns to `inner_max_cycles`.
pub const DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES: usize = 1200;
impl Default for BlockwiseFitOptions {
    /// Builds the baseline option set. Fields are listed in the same order as
    /// the struct declaration.
    fn default() -> Self {
        Self {
            // Iteration limits and tolerances.
            inner_max_cycles: DEFAULT_CUSTOM_FAMILY_INNER_MAX_CYCLES,
            inner_tol: 1e-6,
            outer_max_iter: 60,
            outer_tol: 1e-5,
            outer_rel_cost_tol: None,
            rho_lower_bound: -10.0,

            // Numerical floors and ridge handling.
            minweight: CUSTOM_FAMILY_WEIGHT_FLOOR,
            ridge_floor: CUSTOM_FAMILY_RIDGE_FLOOR,
            ridge_policy: RidgePolicy::explicit_stabilization_pospart(),

            // Objective and output switches.
            use_remlobjective: true,
            use_outer_hessian: true,
            compute_covariance: false,

            // Optional shared state; nothing is attached by default.
            screening_max_inner_iterations: None,
            outer_inner_max_iterations: None,
            early_exit_threshold: None,
            outer_score_subsample: None,
            auto_outer_subsample: true,
            outer_eval_context: None,
            cache_session: None,
            cache_mirror_sessions: Vec::new(),
            joint_penalties: None,

            // Screening switches.
            screen_initial_rho: true,
            seed_screening: false,
        }
    }
}