use super::*;
/// Panics with `message` as the panic payload and never returns.
///
/// The message is converted to an owned `String` and handed to
/// [`std::panic::panic_any`], so code that catches the unwind can downcast the
/// payload to `String`.
pub(crate) fn outer_strategy_contract_panic(message: impl Into<String>) -> ! {
    let payload: String = message.into();
    std::panic::panic_any(payload)
}
/// Result of a single outer-objective evaluation: the cost together with
/// whatever derivative information the objective produced.
pub struct OuterEval {
/// Objective value. `OuterEval::infeasible` sets this to `f64::INFINITY`.
pub cost: f64,
/// Gradient vector. Value-only evaluations produced by the default
/// `OuterObjective::eval_with_order` fill it with zeros.
pub gradient: Array1<f64>,
/// Second-order information, or `HessianResult::Unavailable` when none was
/// computed.
pub hessian: HessianResult,
/// Optional coefficient vector attached to the evaluation.
/// `CheckpointingObjective` records it alongside each checkpoint.
/// NOTE(review): presumably the inner solver's beta at this rho — confirm.
pub inner_beta_hint: Option<Array1<f64>>,
}
impl OuterEval {
pub fn infeasible(n_params: usize) -> Self {
Self {
cost: f64::INFINITY,
gradient: Array1::zeros(n_params),
hessian: HessianResult::Unavailable,
inner_beta_hint: None,
}
}
}
/// Second-order information returned by an outer-objective evaluation.
pub enum HessianResult {
/// Dense Hessian matrix.
Analytic(Array2<f64>),
/// Shared Hessian operator. Code in this file only relies on its `dim()`,
/// `materialize_dense()` and `materialization_capability()` methods.
Operator(Arc<dyn OuterHessianOperator>),
/// No Hessian was produced.
Unavailable,
}
impl Clone for OuterEval {
    /// Field-by-field clone; the Hessian follows `HessianResult::clone`.
    fn clone(&self) -> Self {
        let Self {
            cost,
            gradient,
            hessian,
            inner_beta_hint,
        } = self;
        Self {
            cost: *cost,
            gradient: gradient.clone(),
            hessian: hessian.clone(),
            inner_beta_hint: inner_beta_hint.clone(),
        }
    }
}
impl std::fmt::Debug for OuterEval {
    /// Prints `cost`, `gradient` and `hessian`; `inner_beta_hint` is not part
    /// of the debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OuterEval");
        out.field("cost", &self.cost);
        out.field("gradient", &self.gradient);
        out.field("hessian", &self.hessian);
        out.finish()
    }
}
impl Clone for HessianResult {
fn clone(&self) -> Self {
match self {
Self::Analytic(h) => Self::Analytic(h.clone()),
Self::Operator(op) => Self::Operator(Arc::clone(op)),
Self::Unavailable => Self::Unavailable,
}
}
}
impl std::fmt::Debug for HessianResult {
    /// Compact debug output: only the shape (`Analytic`) or the dimension
    /// (`Operator`) is printed, never the matrix entries.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unavailable => f.write_str("Unavailable"),
            Self::Analytic(dense) => {
                let shape = format!("{}x{}", dense.nrows(), dense.ncols());
                f.debug_tuple("Analytic").field(&shape).finish()
            }
            Self::Operator(op) => {
                let label = format!("dim={}", op.dim());
                f.debug_tuple("Operator").field(&label).finish()
            }
        }
    }
}
impl HessianResult {
    /// Consumes the result and returns the dense Hessian.
    ///
    /// # Panics
    /// Panics through `outer_strategy_contract_panic` when the value is
    /// `Operator` or `Unavailable`.
    pub fn unwrap_analytic(self) -> Array2<f64> {
        let complaint = match self {
            HessianResult::Analytic(dense) => return dense,
            HessianResult::Operator(_) => {
                "expected dense analytic Hessian but got HessianResult::Operator"
            }
            HessianResult::Unavailable => {
                "expected analytic Hessian but got HessianResult::Unavailable"
            }
        };
        outer_strategy_contract_panic(complaint)
    }

    /// Returns `true` for both `Analytic` and `Operator`, i.e. whenever some
    /// Hessian representation is present.
    pub fn is_analytic(&self) -> bool {
        match self {
            HessianResult::Analytic(_) | HessianResult::Operator(_) => true,
            HessianResult::Unavailable => false,
        }
    }

    /// Returns the dense matrix for `Analytic`; `Operator` and `Unavailable`
    /// both map to `None` (the operator is not materialized here).
    pub fn into_option(self) -> Option<Array2<f64>> {
        if let HessianResult::Analytic(dense) = self {
            Some(dense)
        } else {
            None
        }
    }

    /// Dimension of the Hessian: row count for `Analytic`, `dim()` for
    /// `Operator`, `None` when unavailable.
    pub fn dim(&self) -> Option<usize> {
        match self {
            HessianResult::Unavailable => None,
            HessianResult::Analytic(dense) => Some(dense.nrows()),
            HessianResult::Operator(op) => Some(op.dim()),
        }
    }

    /// Produces a dense copy of the Hessian when one exists.
    ///
    /// `Analytic` is cloned, `Operator` delegates to the operator's own
    /// `materialize_dense` (whose error is passed through), and `Unavailable`
    /// yields `Ok(None)`.
    pub fn materialize_dense(&self) -> Result<Option<Array2<f64>>, String> {
        match self {
            HessianResult::Unavailable => Ok(None),
            HessianResult::Analytic(dense) => Ok(Some(dense.clone())),
            HessianResult::Operator(op) => {
                let dense = op.materialize_dense()?;
                Ok(Some(dense))
            }
        }
    }

    /// Adds the square matrix `rho_block` to the leading `k x k` corner of the
    /// Hessian, where `k = rho_block.nrows()`.
    ///
    /// * `Analytic`: the corner of the dense matrix is updated in place.
    /// * `Operator`: the operator is replaced by a
    ///   `RhoBlockAdditiveOuterHessian` wrapping the previous operator and a
    ///   copy of `rho_block`.
    /// * `Unavailable`: nothing happens and `Ok(())` is returned.
    ///
    /// # Errors
    /// Returns the string form of `OuterStrategyError::RhoBlockShape` when
    /// `rho_block` is not square or is larger than the Hessian.
    pub fn add_rho_block_dense(&mut self, rho_block: &Array2<f64>) -> Result<(), String> {
        let (block_rows, block_cols) = (rho_block.nrows(), rho_block.ncols());
        if block_rows != block_cols {
            let err = OuterStrategyError::RhoBlockShape {
                reason: format!(
                    "rho-block Hessian update must be square, got {}x{}",
                    block_rows, block_cols
                ),
            };
            return Err(err.into());
        }
        match self {
            HessianResult::Unavailable => Ok(()),
            HessianResult::Analytic(dense) => {
                let fits = block_rows <= dense.nrows() && block_cols <= dense.ncols();
                if !fits {
                    let err = OuterStrategyError::RhoBlockShape {
                        reason: format!(
                            "rho-block Hessian update shape mismatch: got {}x{}, outer Hessian is {}x{}",
                            block_rows,
                            block_cols,
                            dense.nrows(),
                            dense.ncols()
                        ),
                    };
                    return Err(err.into());
                }
                // Only the leading block is touched; the rest of the matrix
                // is left as it was.
                let mut corner = dense.slice_mut(ndarray::s![..block_rows, ..block_rows]);
                corner += rho_block;
                Ok(())
            }
            HessianResult::Operator(op) => {
                let base = Arc::clone(op);
                let dim = base.dim();
                if block_rows > dim {
                    let err = OuterStrategyError::RhoBlockShape {
                        reason: format!(
                            "rho-block Hessian update dimension mismatch: got {}x{}, operator dim is {}",
                            block_rows, block_cols, dim
                        ),
                    };
                    return Err(err.into());
                }
                let combined = RhoBlockAdditiveOuterHessian {
                    base,
                    rho_block: rho_block.clone(),
                    dim,
                };
                *self = HessianResult::Operator(Arc::new(combined));
                Ok(())
            }
        }
    }
}
/// Result of an `OuterObjective::eval_efs` call.
///
/// NOTE(review): the meaning of most fields is defined by the producers of
/// this value, which are not visible in this file — confirm before relying on
/// the notes below.
#[derive(Clone, Debug)]
pub struct EfsEval {
/// Objective value; `CheckpointingObjective::eval_efs` checkpoints it.
pub cost: f64,
/// Proposed steps (presumably one per outer parameter — TODO confirm).
pub steps: Vec<f64>,
/// Optional coefficient vector associated with the evaluation.
pub beta: Option<Array1<f64>>,
/// Optional gradient for the psi parameters (semantics not shown here).
pub psi_gradient: Option<Array1<f64>>,
/// Optional indices accompanying `psi_gradient` (semantics not shown here).
pub psi_indices: Option<Vec<usize>>,
/// Optional scale factor for the inner Hessian (semantics not shown here).
pub inner_hessian_scale: Option<f64>,
/// Optional log-determinant enclosure gap; set through
/// `with_logdet_enclosure_gap`.
pub logdet_enclosure_gap: Option<f64>,
}
impl EfsEval {
    /// Returns `self` with `logdet_enclosure_gap` replaced by `gap`; every
    /// other field is carried over unchanged.
    #[must_use]
    pub fn with_logdet_enclosure_gap(self, gap: Option<f64>) -> Self {
        Self {
            logdet_enclosure_gap: gap,
            ..self
        }
    }
}
/// Outcome reported by `OuterObjective::seed_inner_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedOutcome {
/// The seed was accepted. `ClosureObjective` also reports this for an empty
/// beta without calling its seed closure.
Installed,
/// The objective has nowhere to install a seed (`ClosureObjective` returns
/// this when no seed closure is configured).
NoSlot,
}
/// Objective interface for an outer optimization over `rho`.
///
/// Implementors must provide `capability`, `eval_cost`, `eval`, `reset` and
/// `seed_inner_state`; every other method has a default that implementors may
/// override.
pub trait OuterObjective {
    /// Describes what this objective can compute.
    fn capability(&self) -> OuterCapability;

    /// Evaluates only the cost at `rho`.
    fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError>;

    /// Screening evaluation; by default identical to `eval_cost`.
    fn eval_screening_proxy(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
        self.eval_cost(rho)
    }

    /// Full evaluation at `rho`.
    fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError>;

    /// Evaluates at the requested derivative order.
    ///
    /// The default serves `Value` through `eval_cost`, returning a zero
    /// gradient, no Hessian and no beta hint; both higher orders fall back to
    /// `eval`.
    fn eval_with_order(
        &mut self,
        rho: &Array1<f64>,
        order: OuterEvalOrder,
    ) -> Result<OuterEval, EstimationError> {
        match order {
            OuterEvalOrder::ValueAndGradient | OuterEvalOrder::ValueGradientHessian => {
                self.eval(rho)
            }
            OuterEvalOrder::Value => self.eval_cost(rho).map(|cost| OuterEval {
                cost,
                gradient: Array1::zeros(rho.len()),
                hessian: HessianResult::Unavailable,
                inner_beta_hint: None,
            }),
        }
    }

    /// EFS evaluation; the default reports that it is not implemented.
    fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
        let message = format!(
            "EFS evaluation not implemented for this objective at rho_dim={}",
            rho.len()
        );
        Err(EstimationError::RemlOptimizationFailed(message))
    }

    /// Resets the objective's internal state.
    fn reset(&mut self);

    /// Offers `beta` as a seed for the inner state.
    fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError>;

    /// Whether continuation pre-warming is allowed; `false` by default.
    fn allow_continuation_prewarm(&self) -> bool {
        false
    }

    /// Optional device admission decision; `None` by default.
    fn outer_device_admission(&self) -> Option<crate::gpu::policy::RemlOuterAdmission> {
        None
    }

    /// Whether a continuation path entry is required; `false` by default.
    fn requires_continuation_path_entry(&self) -> bool {
        false
    }

    /// Curvature-homotopy entry hook.
    ///
    /// The default only validates `rho`: the first non-finite component
    /// yields `Some(Err(..))`, otherwise `None` is returned.
    fn curvature_homotopy_entry(
        &mut self,
        rho: &Array1<f64>,
    ) -> Option<Result<bool, EstimationError>> {
        let first_bad = rho.iter().position(|v| !v.is_finite());
        first_bad.map(|idx| {
            Err(EstimationError::RemlOptimizationFailed(format!(
                "curvature-homotopy entry received non-finite rho[{idx}]"
            )))
        })
    }

    /// Lets the objective accept a seed without running outer iterations.
    ///
    /// The default never accepts: it returns `Ok(None)` for every `rho`.
    fn accept_seed_without_outer_iterations(
        &mut self,
        rho: &Array1<f64>,
    ) -> Result<Option<f64>, EstimationError> {
        // `rho` is deliberately unused by the default implementation.
        let _ = rho;
        Ok(None)
    }

    /// Final hook after the outer solve; by default re-evaluates the cost at
    /// `rho` and discards the value, propagating any error.
    fn finalize_outer_result(
        &mut self,
        rho: &Array1<f64>,
        plan: &OuterPlan,
    ) -> Result<(), EstimationError> {
        log::debug!(
            "[OUTER] finalize: re-installing best rho into the objective (solver {:?})",
            plan.solver
        );
        let _cost = self.eval_cost(rho)?;
        Ok(())
    }
}
/// Serialized form of one checkpointed iterate. `encode_iterate` writes it as
/// JSON and `decode_iterate` reads and validates it.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct IteratePayload {
/// Schema tag; `decode_iterate` rejects payloads whose tag differs from
/// `ITERATE_PAYLOAD_SCHEMA`.
schema: u32,
/// Outer parameter vector.
pub(crate) rho: Vec<f64>,
/// Coefficient vector. Empty when none was supplied at encode time; a
/// payload without this field deserializes to an empty vector.
#[serde(default)]
pub(crate) beta: Vec<f64>,
/// Cost recorded with the iterate.
pub(crate) cost: f64,
/// Evaluation counter value supplied by the caller of `encode_iterate`.
eval_id: u64,
}
/// Schema tag written by `encode_iterate` and required by `decode_iterate`.
pub(crate) const ITERATE_PAYLOAD_SCHEMA: u32 = 2;
/// Serializes an iterate to JSON bytes.
///
/// A missing `beta` is stored as an empty vector. Returns `None` when JSON
/// serialization fails.
pub(crate) fn encode_iterate(
    rho: &Array1<f64>,
    beta: Option<&Array1<f64>>,
    cost: f64,
    eval_id: u64,
) -> Option<Vec<u8>> {
    let beta_values = match beta {
        Some(b) => b.to_vec(),
        None => Vec::new(),
    };
    let payload = IteratePayload {
        schema: ITERATE_PAYLOAD_SCHEMA,
        rho: rho.to_vec(),
        beta: beta_values,
        cost,
        eval_id,
    };
    match serde_json::to_vec(&payload) {
        Ok(bytes) => Some(bytes),
        Err(_) => None,
    }
}
/// Parses and validates a checkpoint payload.
///
/// Returns `None` when the bytes are not a valid JSON payload, the schema tag
/// differs from `ITERATE_PAYLOAD_SCHEMA`, `rho` does not have
/// `expected_rho_dim` entries, or any of `rho`, `cost`, `beta` contains a
/// non-finite value.
pub(crate) fn decode_iterate(bytes: &[u8], expected_rho_dim: usize) -> Option<IteratePayload> {
    fn all_finite(values: &[f64]) -> bool {
        values.iter().all(|x| x.is_finite())
    }

    let payload: IteratePayload = serde_json::from_slice(bytes).ok()?;
    let accepted = payload.schema == ITERATE_PAYLOAD_SCHEMA
        && payload.rho.len() == expected_rho_dim
        && all_finite(&payload.rho)
        && payload.cost.is_finite()
        && all_finite(&payload.beta);
    if accepted { Some(payload) } else { None }
}
/// Verdict of `classify_cache_entry_for_outer` on a loaded cache entry.
#[derive(Debug)]
pub(crate) enum CacheSeedDecision {
/// Exact-source entry of kind `Final`: carries the stored solution.
ExactFinal {
rho: Array1<f64>,
beta: Vec<f64>,
/// Entry objective, falling back to the payload cost.
final_value: f64,
/// Entry iteration, falling back to the payload's evaluation id.
iterations: usize,
/// Entry objective for display; NaN when the entry has none.
prior_obj_display: f64,
},
/// Usable entry that is not an exact final result.
Seed {
rho: Array1<f64>,
beta: Vec<f64>,
/// Entry objective for display; NaN when the entry has none.
prior_obj_display: f64,
/// Entry iteration, falling back to the payload's evaluation id.
iteration: u64,
},
/// Entry rejected; `reason` is a short machine-readable tag.
Discard {
reason: &'static str,
/// Entry objective for display; NaN when the entry has none.
prior_obj_display: f64,
/// Whether every rho component was finite; `None` when the payload could
/// not be decoded at all.
all_rho_finite: Option<bool>,
},
}
/// Decides how a loaded cache entry may be used for an outer solve.
///
/// * An undecodable payload is discarded with reason
///   `"payload-shape-mismatch"`.
/// * A non-finite entry objective or a non-finite rho is discarded with
///   reason `"non-finite-payload"`.
/// * An exact-source entry of kind `Final` becomes `ExactFinal`.
/// * Everything else becomes `Seed`.
pub(crate) fn classify_cache_entry_for_outer(
    loaded: &crate::cache::LoadedEntry,
    expected_rho_dim: usize,
) -> CacheSeedDecision {
    let entry = &loaded.entry;
    // NaN marks "no objective stored" in the display value.
    let prior_obj_display = entry.objective.unwrap_or(f64::NAN);

    let payload = match decode_iterate(&entry.payload, expected_rho_dim) {
        Some(decoded) => decoded,
        None => {
            return CacheSeedDecision::Discard {
                reason: "payload-shape-mismatch",
                prior_obj_display,
                all_rho_finite: None,
            };
        }
    };

    let cached_rho = Array1::from_vec(payload.rho);
    let rho_all_finite = cached_rho.iter().all(|v| v.is_finite());
    let objective_non_finite = entry.objective.is_some_and(|v| !v.is_finite());
    if objective_non_finite || !rho_all_finite {
        return CacheSeedDecision::Discard {
            reason: "non-finite-payload",
            prior_obj_display,
            all_rho_finite: Some(rho_all_finite),
        };
    }

    let iteration = entry.iteration.unwrap_or(payload.eval_id);
    let is_exact_final =
        loaded.source == LoadSource::Exact && entry.kind == crate::cache::EntryKind::Final;
    if !is_exact_final {
        return CacheSeedDecision::Seed {
            rho: cached_rho,
            beta: payload.beta,
            prior_obj_display,
            iteration,
        };
    }

    CacheSeedDecision::ExactFinal {
        rho: cached_rho,
        beta: payload.beta,
        final_value: entry.objective.unwrap_or(payload.cost),
        // Saturate rather than wrap if the count does not fit in `usize`.
        iterations: usize::try_from(iteration).unwrap_or(usize::MAX),
        prior_obj_display,
    }
}
/// Returns `true` when `classify_cache_entry_for_outer` would not discard the
/// entry, i.e. it classifies as `ExactFinal` or `Seed`.
pub(crate) fn cache_entry_would_help_outer(
    loaded: &crate::cache::LoadedEntry,
    expected_rho_dim: usize,
) -> bool {
    match classify_cache_entry_for_outer(loaded, expected_rho_dim) {
        CacheSeedDecision::ExactFinal { .. } | CacheSeedDecision::Seed { .. } => true,
        CacheSeedDecision::Discard { .. } => false,
    }
}
/// Wrapper around another `OuterObjective` that writes a checkpoint to one or
/// more cache sessions after successful evaluations.
pub(crate) struct CheckpointingObjective<'a> {
/// The wrapped objective that performs the actual evaluations.
inner: &'a mut dyn OuterObjective,
/// Primary session receiving every checkpoint.
session: Arc<CacheSession>,
/// Additional sessions that receive the same checkpoint bytes.
mirror_sessions: Vec<Arc<CacheSession>>,
/// Number of checkpoints attempted so far; used as the evaluation id.
eval_counter: AtomicU64,
/// Most recent finite beta seen in an evaluation or installed as a seed.
last_inner_beta: std::sync::Mutex<Option<Array1<f64>>>,
}
impl<'a> CheckpointingObjective<'a> {
    /// Wraps `inner`, checkpointing to `session` and to every session in
    /// `mirror_sessions`. The evaluation counter starts at zero and no beta
    /// is remembered yet.
    pub(crate) fn new(
        inner: &'a mut dyn OuterObjective,
        session: Arc<CacheSession>,
        mirror_sessions: Vec<Arc<CacheSession>>,
    ) -> Self {
        let eval_counter = AtomicU64::new(0);
        let last_inner_beta = std::sync::Mutex::new(None);
        Self {
            inner,
            session,
            mirror_sessions,
            eval_counter,
            last_inner_beta,
        }
    }

    /// Returns a copy of the most recently remembered beta, or `None` when
    /// nothing was remembered or the mutex is poisoned.
    pub(crate) fn last_inner_beta(&self) -> Option<Array1<f64>> {
        match self.last_inner_beta.lock() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        }
    }

    /// Records one evaluation as a checkpoint.
    ///
    /// Nothing is recorded when `cost` is non-finite or when a supplied
    /// `beta` contains a non-finite entry. A finite `beta` is remembered
    /// before the checkpoint is written to the primary session and then to
    /// each mirror session.
    fn note(&self, rho: &Array1<f64>, beta: Option<&Array1<f64>>, cost: f64) {
        if !cost.is_finite() {
            return;
        }
        if let Some(candidate) = beta {
            if candidate.iter().any(|v| !v.is_finite()) {
                return;
            }
            if let Ok(mut slot) = self.last_inner_beta.lock() {
                *slot = Some(candidate.clone());
            }
        }
        let eval_id = self.eval_counter.fetch_add(1, Ordering::Relaxed);
        let Some(bytes) = encode_iterate(rho, beta, cost, eval_id) else {
            return;
        };
        let targets = std::iter::once(&self.session).chain(self.mirror_sessions.iter());
        for target in targets {
            target.checkpoint(&bytes, Some(cost), Some(eval_id));
        }
    }
}
impl<'a> OuterObjective for CheckpointingObjective<'a> {
fn capability(&self) -> OuterCapability {
self.inner.capability()
}
fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
let v = self.inner.eval_cost(rho)?;
self.note(rho, None, v);
Ok(v)
}
fn eval_screening_proxy(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
self.inner.eval_screening_proxy(rho)
}
fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
let r = self.inner.eval(rho)?;
self.note(rho, r.inner_beta_hint.as_ref(), r.cost);
Ok(r)
}
fn eval_with_order(
&mut self,
rho: &Array1<f64>,
order: OuterEvalOrder,
) -> Result<OuterEval, EstimationError> {
let r = self.inner.eval_with_order(rho, order)?;
self.note(rho, r.inner_beta_hint.as_ref(), r.cost);
Ok(r)
}
fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
let r = self.inner.eval_efs(rho)?;
self.note(rho, None, r.cost);
Ok(r)
}
fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError> {
let result = self.inner.seed_inner_state(beta);
if matches!(result, Ok(SeedOutcome::Installed))
&& beta.iter().all(|v| v.is_finite())
&& let Ok(mut guard) = self.last_inner_beta.lock()
{
*guard = Some(beta.clone());
}
result
}
fn allow_continuation_prewarm(&self) -> bool {
self.inner.allow_continuation_prewarm()
}
fn requires_continuation_path_entry(&self) -> bool {
self.inner.requires_continuation_path_entry()
}
fn reset(&mut self) {
self.inner.reset();
}
}
/// `OuterObjective` assembled from closures operating on a shared state `S`.
///
/// `Fc` and `Fe` (cost and full evaluation) are mandatory; the remaining
/// closure types are optional hooks whose type parameters default to plain
/// function pointers.
pub struct ClosureObjective<
S,
Fc,
Fe,
Fr = fn(&mut S),
Fefs = fn(&mut S, &Array1<f64>) -> Result<EfsEval, EstimationError>,
Feo = fn(&mut S, &Array1<f64>, OuterEvalOrder) -> Result<OuterEval, EstimationError>,
Fsp = fn(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
Fseed = fn(&mut S, &Array1<f64>) -> Result<(), EstimationError>,
> {
/// State passed mutably to every closure.
pub(crate) state: S,
/// Capability reported by `capability()`.
pub(crate) cap: OuterCapability,
/// Cost-only evaluation.
pub(crate) cost_fn: Fc,
/// Full evaluation.
pub(crate) eval_fn: Fe,
/// Optional order-aware evaluation; `eval_fn` is used when absent.
pub(crate) eval_order_fn: Option<Feo>,
/// Optional reset hook; `reset()` does nothing when absent.
pub(crate) reset_fn: Option<Fr>,
/// Optional EFS evaluation; `eval_efs` fails when absent.
pub(crate) efs_fn: Option<Fefs>,
/// Optional screening proxy; `cost_fn` is used when absent.
pub(crate) screening_proxy_fn: Option<Fsp>,
/// Optional inner-state seeding hook.
pub(crate) seed_fn: Option<Fseed>,
/// Pre-warm flag; only effective together with a seed hook.
pub(crate) continuation_prewarm: bool,
}
impl<S, Fc, Fe, Fr, Fefs, Feo, Fsp, Fseed> OuterObjective
for ClosureObjective<S, Fc, Fe, Fr, Fefs, Feo, Fsp, Fseed>
where
Fc: FnMut(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
Fe: FnMut(&mut S, &Array1<f64>) -> Result<OuterEval, EstimationError>,
Fr: FnMut(&mut S),
Fefs: FnMut(&mut S, &Array1<f64>) -> Result<EfsEval, EstimationError>,
Feo: FnMut(&mut S, &Array1<f64>, OuterEvalOrder) -> Result<OuterEval, EstimationError>,
Fsp: FnMut(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
Fseed: FnMut(&mut S, &Array1<f64>) -> Result<(), EstimationError>,
{
fn capability(&self) -> OuterCapability {
self.cap.clone()
}
fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
crate::solver::estimate::reml::runtime::record_current_outer_theta_for_ift(rho);
(self.cost_fn)(&mut self.state, rho)
}
fn eval_screening_proxy(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
crate::solver::estimate::reml::runtime::record_current_outer_theta_for_ift(rho);
match self.screening_proxy_fn.as_mut() {
Some(f) => f(&mut self.state, rho),
None => (self.cost_fn)(&mut self.state, rho),
}
}
fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
crate::solver::estimate::reml::runtime::record_current_outer_theta_for_ift(rho);
(self.eval_fn)(&mut self.state, rho)
}
fn eval_with_order(
&mut self,
rho: &Array1<f64>,
order: OuterEvalOrder,
) -> Result<OuterEval, EstimationError> {
crate::solver::estimate::reml::runtime::record_current_outer_theta_for_ift(rho);
match self.eval_order_fn.as_mut() {
Some(f) => f(&mut self.state, rho, order),
None => (self.eval_fn)(&mut self.state, rho),
}
}
fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
crate::solver::estimate::reml::runtime::record_current_outer_theta_for_ift(rho);
match self.efs_fn.as_mut() {
Some(f) => f(&mut self.state, rho),
None => Err(EstimationError::RemlOptimizationFailed(
"EFS evaluation not implemented for this objective".to_string(),
)),
}
}
fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError> {
if beta.is_empty() {
return Ok(SeedOutcome::Installed);
}
match self.seed_fn.as_mut() {
Some(f) => f(&mut self.state, beta).map(|()| SeedOutcome::Installed),
None => Ok(SeedOutcome::NoSlot),
}
}
fn allow_continuation_prewarm(&self) -> bool {
self.continuation_prewarm && self.seed_fn.is_some()
}
fn reset(&mut self) {
if let Some(f) = self.reset_fn.as_mut() {
f(&mut self.state);
}
}
}
impl<S, Fc, Fe, Fr, Fefs, Feo, Fsp> ClosureObjective<S, Fc, Fe, Fr, Fefs, Feo, Fsp>
where
    Fc: FnMut(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
    Fe: FnMut(&mut S, &Array1<f64>) -> Result<OuterEval, EstimationError>,
    Fr: FnMut(&mut S),
    Fefs: FnMut(&mut S, &Array1<f64>) -> Result<EfsEval, EstimationError>,
    Feo: FnMut(&mut S, &Array1<f64>, OuterEvalOrder) -> Result<OuterEval, EstimationError>,
    Fsp: FnMut(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
{
    /// Returns the same objective with `seed_fn` installed as the seeding
    /// hook. Any previously stored seed hook is discarded; all other fields
    /// are moved over unchanged.
    pub fn with_seed_inner_state<Fseed>(
        self,
        seed_fn: Fseed,
    ) -> ClosureObjective<S, Fc, Fe, Fr, Fefs, Feo, Fsp, Fseed>
    where
        Fseed: FnMut(&mut S, &Array1<f64>) -> Result<(), EstimationError>,
    {
        let ClosureObjective {
            state,
            cap,
            cost_fn,
            eval_fn,
            eval_order_fn,
            reset_fn,
            efs_fn,
            screening_proxy_fn,
            continuation_prewarm,
            ..
        } = self;
        ClosureObjective {
            state,
            cap,
            cost_fn,
            eval_fn,
            eval_order_fn,
            reset_fn,
            efs_fn,
            screening_proxy_fn,
            seed_fn: Some(seed_fn),
            continuation_prewarm,
        }
    }
}
/// Wraps `err` in a recoverable `ObjectiveEvalError` whose message is
/// `"<context>: <err>"`.
pub(crate) fn into_objective_error(context: &str, err: EstimationError) -> ObjectiveEvalError {
    let message = format!("{context}: {err}");
    ObjectiveEvalError::recoverable(message)
}
/// Passes a finite `cost` through unchanged; a NaN or infinite cost becomes a
/// recoverable error prefixed with `context`.
pub(crate) fn finite_cost_or_error(context: &str, cost: f64) -> Result<f64, ObjectiveEvalError> {
    if !cost.is_finite() {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: objective returned a non-finite cost"
        )));
    }
    Ok(cost)
}
/// Validates a full outer evaluation and returns it unchanged on success.
///
/// Checks, in order: gradient length against `layout`, finite cost, finite
/// gradient, and then the Hessian. A dense Hessian must have the layout's
/// shape and finite entries; an operator must report `layout.n_params` as its
/// dimension. An unavailable Hessian is accepted.
///
/// # Errors
/// Returns a recoverable `ObjectiveEvalError` prefixed with `context` for the
/// first check that fails.
pub(crate) fn finite_outer_eval_or_error(
    context: &str,
    layout: OuterThetaLayout,
    eval: OuterEval,
) -> Result<OuterEval, ObjectiveEvalError> {
    layout.validate_gradient_len(&eval.gradient, context)?;
    if !eval.cost.is_finite() {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: objective returned a non-finite cost"
        )));
    }
    if eval.gradient.iter().any(|v| !v.is_finite()) {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: objective returned a non-finite gradient"
        )));
    }
    match &eval.hessian {
        HessianResult::Unavailable => {}
        HessianResult::Analytic(dense) => {
            layout.validate_hessian_shape(dense, context)?;
            if dense.iter().any(|v| !v.is_finite()) {
                return Err(ObjectiveEvalError::recoverable(format!(
                    "{context}: objective returned a non-finite Hessian"
                )));
            }
        }
        HessianResult::Operator(op) => {
            let reported = op.dim();
            if reported != layout.n_params {
                return Err(ObjectiveEvalError::recoverable(format!(
                    "{context}: outer Hessian operator dimension mismatch: got {}, expected {} (rho_dim={}, psi_dim={})",
                    reported,
                    layout.n_params,
                    layout.rho_dim(),
                    layout.psi_dim
                )));
            }
        }
    }
    Ok(eval)
}
/// Validates only the first-order part of an outer evaluation and returns it
/// unchanged on success.
///
/// Checks the gradient length against `layout`, then that the cost and every
/// gradient entry are finite. The Hessian is not inspected.
///
/// # Errors
/// Returns a recoverable `ObjectiveEvalError` prefixed with `context` for the
/// first check that fails.
pub(crate) fn finite_outer_first_order_eval_or_error(
    context: &str,
    layout: OuterThetaLayout,
    eval: OuterEval,
) -> Result<OuterEval, ObjectiveEvalError> {
    layout.validate_gradient_len(&eval.gradient, context)?;
    let cost_ok = eval.cost.is_finite();
    if !cost_ok {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: objective returned a non-finite cost"
        )));
    }
    let gradient_has_bad_entry = eval.gradient.iter().any(|v| !v.is_finite());
    if gradient_has_bad_entry {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: objective returned a non-finite gradient"
        )));
    }
    Ok(eval)
}
/// Probes the Hessian of a seed evaluation for second-order use.
///
/// The probe is skipped (returning `Ok(())`) when the layout has more than
/// `SECOND_ORDER_GEOMETRY_PROBE_MAX_PARAMS` parameters, when no Hessian is
/// present, when an operator reports that materialization is unavailable, or
/// when materialization yields nothing. Otherwise the dense Hessian must
/// match the layout's shape and contain only finite entries.
///
/// # Errors
/// Returns a recoverable `ObjectiveEvalError` prefixed with `context` when
/// materialization fails, the shape check fails, or an entry is non-finite.
pub(crate) fn validate_second_order_seed_hessian(
    context: &str,
    layout: OuterThetaLayout,
    eval: &OuterEval,
) -> Result<(), ObjectiveEvalError> {
    if layout.n_params > SECOND_ORDER_GEOMETRY_PROBE_MAX_PARAMS || !eval.hessian.is_analytic() {
        return Ok(());
    }
    if let HessianResult::Operator(op) = &eval.hessian {
        if !op.materialization_capability().is_available() {
            return Ok(());
        }
    }
    let dense = match eval.hessian.materialize_dense() {
        Ok(Some(dense)) => dense,
        Ok(None) => return Ok(()),
        Err(message) => {
            return Err(ObjectiveEvalError::recoverable(format!(
                "{context}: analytic outer Hessian materialization failed during second-order seed validation: {message}"
            )));
        }
    };
    layout.validate_hessian_shape(&dense, context)?;
    if dense.iter().any(|value| !value.is_finite()) {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: analytic outer Hessian probe encountered non-finite entries"
        )));
    }
    Ok(())
}