use super::*;
/// Parameter-count threshold for the EFS-family routes.
///
/// `OuterCapability::efs_plan_eligible` and `hybrid_efs_plan_eligible` both
/// require `n_params` to be strictly greater than this value. Problems at or
/// below it are therefore planned as ARC (analytic Hessian) or BFGS.
pub(crate) const SMALL_OUTER_BFGS_MAX_PARAMS: usize = 8;
/// Parameter-count cap associated with second-order geometry probing.
/// NOTE(review): not referenced in this section — confirm its exact meaning
/// (inclusive vs. exclusive bound) at the use site.
pub(crate) const SECOND_ORDER_GEOMETRY_PROBE_MAX_PARAMS: usize = 64;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct OuterThetaLayout {
pub n_params: usize,
pub psi_dim: usize,
}
impl OuterThetaLayout {
pub const fn new(n_params: usize, psi_dim: usize) -> Self {
Self { n_params, psi_dim }
}
pub const fn rho_dim(&self) -> usize {
self.n_params.saturating_sub(self.psi_dim)
}
fn validate_capability(&self, context: &str) -> Result<(), EstimationError> {
if self.psi_dim > self.n_params {
return Err(EstimationError::RemlOptimizationFailed(format!(
"{context}: invalid outer theta layout (psi_dim={} exceeds n_params={})",
self.psi_dim, self.n_params
)));
}
Ok::<(), _>(())
}
pub(crate) fn validate_point_len(
&self,
theta: &Array1<f64>,
context: &str,
) -> Result<(), ObjectiveEvalError> {
if theta.len() != self.n_params {
return Err(ObjectiveEvalError::recoverable(format!(
"{context}: outer theta length mismatch: got {}, expected {} (rho_dim={}, psi_dim={})",
theta.len(),
self.n_params,
self.rho_dim(),
self.psi_dim
)));
}
Ok::<(), _>(())
}
pub(crate) fn validate_gradient_len(
&self,
gradient: &Array1<f64>,
context: &str,
) -> Result<(), ObjectiveEvalError> {
if gradient.len() != self.n_params {
return Err(ObjectiveEvalError::recoverable(format!(
"{context}: outer gradient length mismatch: got {}, expected {} (rho_dim={}, psi_dim={})",
gradient.len(),
self.n_params,
self.rho_dim(),
self.psi_dim
)));
}
Ok::<(), _>(())
}
pub(crate) fn validate_hessian_shape(
&self,
hessian: &Array2<f64>,
context: &str,
) -> Result<(), ObjectiveEvalError> {
if hessian.nrows() != self.n_params || hessian.ncols() != self.n_params {
return Err(ObjectiveEvalError::recoverable(format!(
"{context}: outer Hessian shape mismatch: got {}x{}, expected {}x{} (rho_dim={}, psi_dim={})",
hessian.nrows(),
hessian.ncols(),
self.n_params,
self.n_params,
self.rho_dim(),
self.psi_dim
)));
}
Ok::<(), _>(())
}
pub(crate) fn validate_efs_eval(
&self,
eval: &EfsEval,
context: &str,
) -> Result<(), ObjectiveEvalError> {
if eval.steps.len() != self.n_params {
return Err(ObjectiveEvalError::recoverable(format!(
"{context}: outer EFS step length mismatch: got {}, expected {} (rho_dim={}, psi_dim={})",
eval.steps.len(),
self.n_params,
self.rho_dim(),
self.psi_dim
)));
}
if let Some(ref psi_gradient) = eval.psi_gradient
&& psi_gradient.len() != self.psi_dim
{
return Err(ObjectiveEvalError::recoverable(format!(
"{context}: outer EFS psi-gradient length mismatch: got {}, expected {}",
psi_gradient.len(),
self.psi_dim
)));
}
if let Some(ref psi_indices) = eval.psi_indices {
if psi_indices.len() != self.psi_dim {
return Err(ObjectiveEvalError::recoverable(format!(
"{context}: outer EFS psi-index count mismatch: got {}, expected {}",
psi_indices.len(),
self.psi_dim
)));
}
if psi_indices.iter().any(|&idx| idx >= self.n_params) {
return Err(ObjectiveEvalError::recoverable(format!(
"{context}: outer EFS psi index out of range for n_params={}",
self.n_params
)));
}
}
Ok(())
}
}
#[derive(Clone, Debug)]
pub struct OuterCapability {
pub gradient: Derivative,
pub hessian: DeclaredHessianForm,
pub n_params: usize,
pub psi_dim: usize,
pub fixed_point_available: bool,
pub barrier_config: Option<BarrierConfig>,
pub prefer_gradient_only: bool,
pub disable_fixed_point: bool,
}
impl OuterCapability {
pub const fn theta_layout(&self) -> OuterThetaLayout {
OuterThetaLayout::new(self.n_params, self.psi_dim)
}
pub fn validate_layout(&self, context: &str) -> Result<(), EstimationError> {
self.theta_layout().validate_capability(context)
}
pub const fn all_penalty_like(&self) -> bool {
self.psi_dim == 0
}
pub const fn has_psi_coords(&self) -> bool {
self.psi_dim > 0
}
fn efs_plan_eligible(&self) -> bool {
self.fixed_point_available
&& !self.disable_fixed_point
&& self.all_penalty_like()
&& self.n_params > SMALL_OUTER_BFGS_MAX_PARAMS
}
fn hybrid_efs_plan_eligible(&self) -> bool {
self.fixed_point_available
&& !self.disable_fixed_point
&& self.has_psi_coords()
&& self.n_params > SMALL_OUTER_BFGS_MAX_PARAMS
}
fn declared_hessian_for_planning(&self) -> Derivative {
if self.hessian.is_analytic() {
Derivative::Analytic
} else {
Derivative::Unavailable
}
}
}
/// Outer optimiser selected by `plan`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Solver {
/// ARC. `plan` picks it only when both gradient and Hessian are declared analytic.
Arc,
/// BFGS. `plan` falls back to it when neither ARC nor an EFS route applies.
Bfgs,
/// EFS fixed-point iteration. Planned only when there are no ψ coordinates.
Efs,
/// EFS on ρ combined with a preconditioned-gradient update on ψ
/// (per the annotation emitted by `log_plan`).
HybridEfs,
}
/// Where the outer solver's second-order information comes from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HessianSource {
/// Hessian declared analytic by the capability. `plan` pairs it with `Solver::Arc`.
Analytic,
/// BFGS approximation. `plan` pairs it with `Solver::Bfgs`.
BfgsApprox,
/// EFS fixed-point route. `plan` pairs it with `Solver::Efs`.
EfsFixedPoint,
/// Hybrid EFS fixed-point route. `plan` pairs it with `Solver::HybridEfs`.
HybridEfsFixedPoint,
}
pub use gam_model_api::OuterEvalOrder;
/// Routing decision produced by `plan`: which solver runs and where its
/// Hessian information comes from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct OuterPlan {
/// Outer optimiser to run.
pub solver: Solver,
/// Source of second-order information for that optimiser.
pub hessian_source: HessianSource,
}
/// Substring that, when present anywhere in an error message, makes
/// `requests_immediate_first_order_fallback` return `true`.
pub(crate) const EFS_FIRST_ORDER_FALLBACK_MARKER: &str = "[outer-efs-first-order-fallback]";
/// Whether automatic fallback routes are allowed.
///
/// With `Disabled`, `primary_capability_for_config` routes analytic-gradient
/// hybrid-EFS problems straight to BFGS, because the escape path is unavailable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FallbackPolicy {
Automatic,
Disabled,
}
/// Renders the plan as `solver=<Solver>, hessian_source=<HessianSource>`,
/// using the `Debug` names of both enums.
impl std::fmt::Display for OuterPlan {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            solver,
            hessian_source,
        } = self;
        f.write_fmt(format_args!(
            "solver={:?}, hessian_source={:?}",
            solver, hessian_source
        ))
    }
}
impl OuterPlan {
    /// Compact `key=value` routing summary separated by `;`, for log lines.
    ///
    /// `matrix-free` is always reported as `false` by this implementation.
    pub fn routing_log_line(&self) -> String {
        const MATRIX_FREE: bool = false;
        let (solver, hessian) = (self.solver, self.hessian_source);
        format!(
            "solver={:?};hessian={:?};matrix-free={}",
            solver, hessian, MATRIX_FREE
        )
    }
}
pub fn plan(cap: &OuterCapability) -> OuterPlan {
use Derivative::*;
use HessianSource as H;
use Solver as S;
match (cap.gradient, cap.declared_hessian_for_planning()) {
(Analytic, Analytic) => OuterPlan {
solver: S::Arc,
hessian_source: H::Analytic,
},
(Analytic, Unavailable) if cap.efs_plan_eligible() => OuterPlan {
solver: S::Efs,
hessian_source: H::EfsFixedPoint,
},
(Unavailable, Unavailable) if cap.efs_plan_eligible() => OuterPlan {
solver: S::Efs,
hessian_source: H::EfsFixedPoint,
},
(Analytic, Unavailable) if cap.hybrid_efs_plan_eligible() => OuterPlan {
solver: S::HybridEfs,
hessian_source: H::HybridEfsFixedPoint,
},
(Unavailable, Unavailable) if cap.hybrid_efs_plan_eligible() => OuterPlan {
solver: S::HybridEfs,
hessian_source: H::HybridEfsFixedPoint,
},
(Analytic, Unavailable) => OuterPlan {
solver: S::Bfgs,
hessian_source: H::BfgsApprox,
},
(Unavailable, _) => OuterPlan {
solver: S::Bfgs,
hessian_source: H::BfgsApprox,
},
}
}
/// Emits one `info`-level line describing how the outer optimiser was routed.
///
/// The line reports `n_params`, the declared gradient and Hessian forms, and
/// the plan (via its `Display` impl and `OuterPlan::routing_log_line`). It may
/// also carry these annotations:
/// * `[no Hessian: BFGS approximation]` when the plan approximates the
///   Hessian with BFGS and θ is non-empty;
/// * `[EFS with runtime barrier-curvature guard]` when a barrier config is
///   present and the plan actually runs the EFS solver;
/// * `[hybrid EFS(ρ) + preconditioned-gradient(ψ)]` for hybrid-EFS plans.
pub fn log_plan(context: &str, cap: &OuterCapability, the_plan: &OuterPlan) {
    // Static annotations: borrow the literals instead of allocating Strings.
    let hess_warning = match the_plan.hessian_source {
        HessianSource::BfgsApprox if cap.n_params > 0 => " [no Hessian: BFGS approximation]",
        _ => "",
    };
    // The barrier note describes the EFS solver, so the plan must actually
    // have selected EFS. Eligibility alone is not enough: an analytic Hessian
    // routes an EFS-eligible capability to ARC, and the note would then
    // mislabel an ARC run as "EFS".
    let barrier_note = if cap.barrier_config.is_some()
        && cap.efs_plan_eligible()
        && the_plan.solver == Solver::Efs
    {
        " [EFS with runtime barrier-curvature guard]"
    } else {
        ""
    };
    let hybrid_note = if the_plan.solver == Solver::HybridEfs {
        " [hybrid EFS(ρ) + preconditioned-gradient(ψ)]"
    } else {
        ""
    };
    log::info!(
        "[OUTER] {context}: n_params={}, gradient={:?}, hessian={:?} -> {} [{}]{hess_warning}{barrier_note}{hybrid_note}",
        cap.n_params,
        cap.gradient,
        cap.hessian,
        the_plan,
        the_plan.routing_log_line(),
    );
}
/// Returns `true` when `message` contains `EFS_FIRST_ORDER_FALLBACK_MARKER`
/// anywhere in its text.
pub(crate) fn requests_immediate_first_order_fallback(message: &str) -> bool {
    let marker: &str = EFS_FIRST_ORDER_FALLBACK_MARKER;
    str::contains(message, marker)
}
pub(crate) fn disable_fixed_point(cap: &OuterCapability) -> Option<OuterCapability> {
(!cap.disable_fixed_point && (cap.efs_plan_eligible() || cap.hybrid_efs_plan_eligible())).then(
|| {
let mut degraded = cap.clone();
degraded.disable_fixed_point = true;
degraded
},
)
}
/// Capabilities to retry with, in order, after the primary outer attempt.
///
/// There is only one automatic fallback. When the gradient is analytic and
/// `plan` selects EFS or hybrid EFS, retry once with fixed-point iteration
/// disabled (see `disable_fixed_point`). That retry routes to
/// analytic-gradient BFGS. Every other plan, including ARC and BFGS, gets no
/// fallback attempts.
pub(crate) fn automatic_fallback_attempts(cap: &OuterCapability) -> Vec<OuterCapability> {
    // Compute the plan once. The previous version also had an `Arc` early
    // return that yielded the same empty list as the fall-through, and it
    // cloned an already-owned capability before pushing it.
    let fixed_point_primary = cap.gradient == Derivative::Analytic
        && matches!(plan(cap).solver, Solver::Efs | Solver::HybridEfs);
    if !fixed_point_primary {
        return Vec::new();
    }
    disable_fixed_point(cap).into_iter().collect()
}
/// `true` when the primary attempt must avoid hybrid EFS because no automatic
/// escape path exists.
///
/// That requires all three of: fallback policy is `Disabled`, the gradient is
/// analytic, and `plan` would choose hybrid EFS for `cap`.
pub(crate) fn disabled_fallback_hybrid_efs_has_standalone_bfgs_primary(
    cap: &OuterCapability,
    config: &OuterConfig,
) -> bool {
    if config.fallback_policy != FallbackPolicy::Disabled || cap.gradient != Derivative::Analytic {
        return false;
    }
    plan(cap).solver == Solver::HybridEfs
}
pub(crate) fn primary_capability_for_config(
mut cap: OuterCapability,
config: &OuterConfig,
context: &str,
) -> OuterCapability {
if disabled_fallback_hybrid_efs_has_standalone_bfgs_primary(&cap, config) {
log::info!(
"[OUTER] {context}: HybridEFS requires the automatic first-order \
escape path for ψ coordinates; fallback is disabled, so routing the \
primary attempt to analytic-gradient BFGS"
);
cap.disable_fixed_point = true;
}
cap
}