use super::*;
pub use crate::solver_contract::EfsEval;
/// Result of offering a starting inner state (`beta`) to an objective via
/// [`OuterObjective::seed_inner_state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedOutcome {
    /// The seed was accepted by the objective.
    Installed,
    /// The objective has no place to store a seed (e.g. a `ClosureObjective`
    /// built without a seed callback).
    NoSlot,
    /// NOTE(review): presumably the seed was rejected as not matching the
    /// objective's inner state (e.g. wrong length); not produced in this file.
    Incompatible,
}
/// Objective evaluated by the outer optimizer over the parameter vector `rho`.
///
/// Implementors must provide [`Self::capability`], [`Self::eval_cost`],
/// [`Self::eval`], [`Self::reset`] and [`Self::seed_inner_state`]; every other
/// method has a conservative default.
pub trait OuterObjective {
    /// Returns this objective's [`OuterCapability`].
    fn capability(&self) -> OuterCapability;

    /// Evaluates only the objective value at `rho`.
    ///
    /// # Errors
    /// Propagates any [`EstimationError`] raised by the evaluation.
    fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError>;

    /// Evaluates the value used to screen candidate `rho`s.
    ///
    /// The default is the exact cost from [`Self::eval_cost`]; implementors may
    /// override it with a different proxy.
    fn eval_screening_proxy(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
        self.eval_cost(rho)
    }

    /// Fully evaluates the objective at `rho`, returning an [`OuterEval`]
    /// (cost, gradient, Hessian result and an optional inner-beta hint).
    fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError>;

    /// Evaluates at `rho` for the requested derivative order.
    ///
    /// Default: `Value` calls [`Self::eval_cost`] and wraps it as a value-only
    /// [`OuterEval`]; both derivative orders delegate to [`Self::eval`], so
    /// whether a Hessian is present depends on that implementation.
    fn eval_with_order(
        &mut self,
        rho: &Array1<f64>,
        order: OuterEvalOrder,
    ) -> Result<OuterEval, EstimationError> {
        match order {
            OuterEvalOrder::Value => {
                let cost = self.eval_cost(rho)?;
                Ok(OuterEval::value_only(cost, rho.len(), None))
            }
            OuterEvalOrder::ValueAndGradient | OuterEvalOrder::ValueGradientHessian => {
                self.eval(rho)
            }
        }
    }

    /// Evaluation used by the EFS solver variants.
    ///
    /// # Errors
    /// The default always fails with [`EstimationError::RemlOptimizationFailed`].
    fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
        Err(EstimationError::RemlOptimizationFailed(format!(
            "EFS evaluation not implemented for this objective at rho_dim={}",
            rho.len()
        )))
    }

    /// Resets any state the implementor keeps between evaluations.
    fn reset(&mut self);

    /// Offers `beta` as the starting inner state and reports what happened.
    fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError>;

    /// Opt-in flag for continuation prewarming. Default: `false`.
    fn allow_continuation_prewarm(&self) -> bool {
        false
    }

    /// Optional device-admission decision for the outer loop. Default: `None`.
    fn outer_device_admission(&self) -> Option<crate::gpu::policy::RemlOuterAdmission> {
        None
    }

    /// Whether this objective must be entered through the continuation path.
    /// Default: `false`.
    fn requires_continuation_path_entry(&self) -> bool {
        false
    }

    /// Hook for a curvature-homotopy entry at `rho`.
    ///
    /// Returns `None` when the objective does not handle the entry. The default
    /// rejects a non-finite `rho` with an error and otherwise returns `None`.
    /// NOTE(review): the meaning of the `bool` in `Some(Ok(_))` is defined by
    /// overriding implementors and callers; confirm there.
    fn curvature_homotopy_entry(
        &mut self,
        rho: &Array1<f64>,
    ) -> Option<Result<bool, EstimationError>> {
        if let Some(idx) = rho.iter().position(|v| !v.is_finite()) {
            return Some(Err(EstimationError::RemlOptimizationFailed(format!(
                "curvature-homotopy entry received non-finite rho[{idx}]"
            ))));
        }
        None
    }

    /// Lets the objective accept `rho` without running outer iterations.
    ///
    /// Returns `Some(value)` when accepted (presumably the accepted objective
    /// value — confirm with overriding implementors). Default: never accepts.
    fn accept_seed_without_outer_iterations(
        &mut self,
        _rho: &Array1<f64>,
    ) -> Result<Option<f64>, EstimationError> {
        Ok(None)
    }

    /// Re-installs the best `rho` into the objective after optimization.
    ///
    /// The default re-evaluates `rho` in the way the chosen solver uses:
    /// - EFS variants call [`Self::eval_efs`].
    /// - BFGS requests value and gradient.
    /// - ARC requests value, gradient and Hessian.
    ///
    /// Evaluation results are discarded; only errors propagate.
    fn finalize_outer_result(
        &mut self,
        rho: &Array1<f64>,
        plan: &OuterPlan,
    ) -> Result<(), EstimationError> {
        log::debug!(
            "[OUTER] finalize: re-installing best rho into the objective (solver {:?})",
            plan.solver
        );
        match plan.solver {
            Solver::Efs | Solver::HybridEfs => self.eval_efs(rho).map(|_| ()),
            Solver::Bfgs => self
                .eval_with_order(rho, OuterEvalOrder::ValueAndGradient)
                .map(|_| ()),
            Solver::Arc => self
                .eval_with_order(rho, OuterEvalOrder::ValueGradientHessian)
                .map(|_| ()),
        }
    }
}
/// JSON checkpoint payload for one outer iterate, written by `encode_iterate`
/// and read back (with validation) by `decode_iterate`.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct IteratePayload {
    /// Format version; `decode_iterate` rejects anything other than
    /// `ITERATE_PAYLOAD_SCHEMA`.
    schema: u32,
    /// Outer parameter vector at this iterate.
    pub(crate) rho: Vec<f64>,
    /// Inner-state vector; empty when none was recorded (and when absent from
    /// the JSON, via `serde(default)`).
    #[serde(default)]
    pub(crate) beta: Vec<f64>,
    /// Square Hessian flattened in `Array2::iter` order (logical row-major for
    /// ndarray — confirm); empty when not stored.
    #[serde(default)]
    pub(crate) hessian: Vec<f64>,
    /// Side length of `hessian`; `0` means "no Hessian stored".
    #[serde(default)]
    pub(crate) hessian_dim: usize,
    /// Objective value at `rho`.
    pub(crate) cost: f64,
    /// Evaluation id supplied by the writer (the checkpointing counter).
    eval_id: u64,
}
/// Payload schema version written by `encode_iterate` and required by
/// `decode_iterate`.
pub(crate) const ITERATE_PAYLOAD_SCHEMA: u32 = 2;
/// Serializes one outer iterate into the JSON checkpoint payload format.
///
/// Returns `None` in these cases:
/// - `cost` or any entry of `rho` is non-finite.
/// - `beta` is provided and contains a non-finite entry.
/// - JSON serialization fails.
///
/// serde_json writes non-finite `f64` values as `null`, which cannot be
/// deserialized back into `f64`. Such a payload could never be decoded by
/// `decode_iterate`, so it is not produced at all.
///
/// A Hessian is stored only when it is square and entirely finite. Otherwise
/// it is omitted (`hessian_dim = 0`), which decoding treats as "no Hessian".
pub(crate) fn encode_iterate(
    rho: &Array1<f64>,
    beta: Option<&Array1<f64>>,
    hessian: Option<&Array2<f64>>,
    cost: f64,
    eval_id: u64,
) -> Option<Vec<u8>> {
    let all_finite = |values: &Array1<f64>| values.iter().all(|v| v.is_finite());
    // Refuse to emit a checkpoint that decode_iterate would always reject.
    if !cost.is_finite() || !all_finite(rho) || beta.is_some_and(|b| !all_finite(b)) {
        return None;
    }
    let (hessian_flat, hessian_dim) = match hessian {
        Some(h) if h.nrows() == h.ncols() && h.iter().all(|v| v.is_finite()) => {
            (h.iter().copied().collect::<Vec<f64>>(), h.nrows())
        }
        // Dropping an unusable Hessian keeps the rest of the iterate restorable.
        _ => (Vec::new(), 0),
    };
    let p = IteratePayload {
        schema: ITERATE_PAYLOAD_SCHEMA,
        rho: rho.to_vec(),
        beta: beta.map(|b| b.to_vec()).unwrap_or_default(),
        hessian: hessian_flat,
        hessian_dim,
        cost,
        eval_id,
    };
    serde_json::to_vec(&p).ok()
}
/// Parses and validates a checkpoint payload produced by `encode_iterate`.
///
/// Returns `None` in these cases:
/// - The bytes are not valid JSON for [`IteratePayload`].
/// - The schema differs from [`ITERATE_PAYLOAD_SCHEMA`].
/// - `rho` does not have exactly `expected_rho_dim` entries.
/// - Any of `rho`, `cost` or `beta` is non-finite.
///
/// An inconsistent or non-finite Hessian does not reject the payload. It is
/// cleared instead (empty vector, `hessian_dim = 0`).
pub(crate) fn decode_iterate(bytes: &[u8], expected_rho_dim: usize) -> Option<IteratePayload> {
    fn finite(values: &[f64]) -> bool {
        values.iter().all(|x| x.is_finite())
    }

    let mut payload: IteratePayload = serde_json::from_slice(bytes).ok()?;

    let header_matches =
        payload.schema == ITERATE_PAYLOAD_SCHEMA && payload.rho.len() == expected_rho_dim;
    let values_finite = payload.cost.is_finite() && finite(&payload.rho) && finite(&payload.beta);
    if !(header_matches && values_finite) {
        return None;
    }

    // saturating_mul: an absurd hessian_dim can never match a real length.
    let hessian_usable = payload.hessian_dim.saturating_mul(payload.hessian_dim)
        == payload.hessian.len()
        && finite(&payload.hessian);
    if !hessian_usable {
        payload.hessian = Vec::new();
        payload.hessian_dim = 0;
    }
    Some(payload)
}
/// How a loaded warm-start cache entry should be used by the outer optimizer,
/// as decided by `classify_cache_entry_for_outer`.
#[derive(Debug)]
pub(crate) enum CacheSeedDecision {
    /// Exact-match entry of kind `Final`: its result can be reused as-is.
    ExactFinal {
        rho: Array1<f64>,
        beta: Vec<f64>,
        /// Entry objective when present, otherwise the payload cost.
        final_value: f64,
        /// Entry iteration when present, otherwise the payload eval id.
        iterations: usize,
        /// Entry objective for display (`NaN` when absent).
        prior_obj_display: f64,
    },
    /// Usable as a starting point (warm start) but not as a final result.
    Seed {
        rho: Array1<f64>,
        beta: Vec<f64>,
        /// `(dim, flattened dim*dim values)` when the payload carried a Hessian.
        hessian: Option<(usize, Vec<f64>)>,
        /// Entry objective for display (`NaN` when absent).
        prior_obj_display: f64,
        /// Entry iteration when present, otherwise the payload eval id.
        iteration: u64,
    },
    /// Entry is unusable.
    Discard {
        /// Short machine-readable reason (e.g. `"payload-shape-mismatch"`,
        /// `"non-finite-payload"`).
        reason: &'static str,
        /// Entry objective for display (`NaN` when absent).
        prior_obj_display: f64,
        /// `None` when the payload could not be decoded; otherwise whether all
        /// of the decoded `rho` was finite.
        all_rho_finite: Option<bool>,
    },
}
pub(crate) fn classify_cache_entry_for_outer(
loaded: &crate::warm_start::LoadedEntry,
expected_rho_dim: usize,
) -> CacheSeedDecision {
let entry = &loaded.entry;
let Some(payload) = decode_iterate(&entry.payload, expected_rho_dim) else {
return CacheSeedDecision::Discard {
reason: "payload-shape-mismatch",
prior_obj_display: entry.objective.unwrap_or(f64::NAN),
all_rho_finite: None,
};
};
let cached_rho = Array1::from_vec(payload.rho);
let prior_obj_display = entry.objective.unwrap_or(f64::NAN);
if matches!(entry.objective, Some(v) if !v.is_finite()) {
return CacheSeedDecision::Discard {
reason: "non-finite-payload",
prior_obj_display,
all_rho_finite: Some(cached_rho.iter().all(|v| v.is_finite())),
};
}
if !cached_rho.iter().all(|v| v.is_finite()) {
return CacheSeedDecision::Discard {
reason: "non-finite-payload",
prior_obj_display,
all_rho_finite: Some(false),
};
}
if loaded.source == LoadSource::Exact && entry.kind == crate::warm_start::EntryKind::Final {
return CacheSeedDecision::ExactFinal {
rho: cached_rho,
beta: payload.beta,
final_value: entry.objective.unwrap_or(payload.cost),
iterations: entry
.iteration
.unwrap_or(payload.eval_id)
.min(usize::MAX as u64) as usize,
prior_obj_display,
};
}
let hessian = if payload.hessian_dim > 0
&& payload.hessian.len() == payload.hessian_dim * payload.hessian_dim
{
Some((payload.hessian_dim, payload.hessian))
} else {
None
};
CacheSeedDecision::Seed {
rho: cached_rho,
beta: payload.beta,
hessian,
prior_obj_display,
iteration: entry.iteration.unwrap_or(payload.eval_id),
}
}
/// Returns `true` when [`classify_cache_entry_for_outer`] would use `loaded`
/// (as an exact final result or as a warm-start seed) instead of discarding it.
pub(crate) fn cache_entry_would_help_outer(
    loaded: &crate::warm_start::LoadedEntry,
    expected_rho_dim: usize,
) -> bool {
    let decision = classify_cache_entry_for_outer(loaded, expected_rho_dim);
    !matches!(decision, CacheSeedDecision::Discard { .. })
}
/// Decorator around another [`OuterObjective`] that writes a cache checkpoint
/// for evaluations with a finite cost.
pub(crate) struct CheckpointingObjective<'a> {
    /// The wrapped objective.
    inner: &'a mut dyn OuterObjective,
    /// Primary cache session receiving every checkpoint.
    session: Arc<CacheSession>,
    /// Extra sessions that receive the same checkpoint bytes as `session`.
    mirror_sessions: Vec<Arc<CacheSession>>,
    /// Source of the per-checkpoint evaluation id (incremented once per write attempt).
    eval_counter: AtomicU64,
    /// Most recent finite `beta` observed, either from an evaluation's
    /// `inner_beta_hint` or from a successfully installed seed.
    last_inner_beta: std::sync::Mutex<Option<Array1<f64>>>,
}
impl<'a> CheckpointingObjective<'a> {
pub(crate) fn new(
inner: &'a mut dyn OuterObjective,
session: Arc<CacheSession>,
mirror_sessions: Vec<Arc<CacheSession>>,
) -> Self {
Self {
inner,
session,
mirror_sessions,
eval_counter: AtomicU64::new(0),
last_inner_beta: std::sync::Mutex::new(None),
}
}
pub(crate) fn last_inner_beta(&self) -> Option<Array1<f64>> {
self.last_inner_beta.lock().ok().and_then(|g| g.clone())
}
fn note(&self, rho: &Array1<f64>, beta: Option<&Array1<f64>>, cost: f64) {
if !cost.is_finite() {
return;
}
if let Some(b) = beta {
if !b.iter().all(|v| v.is_finite()) {
return;
}
if let Ok(mut guard) = self.last_inner_beta.lock() {
*guard = Some(b.clone());
}
}
let i = self.eval_counter.fetch_add(1, Ordering::Relaxed);
if let Some(bytes) = encode_iterate(rho, beta, None, cost, i) {
self.session.checkpoint(&bytes, Some(cost), Some(i));
for mirror in &self.mirror_sessions {
mirror.checkpoint(&bytes, Some(cost), Some(i));
}
}
}
}
impl<'a> OuterObjective for CheckpointingObjective<'a> {
fn capability(&self) -> OuterCapability {
self.inner.capability()
}
fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
let v = self.inner.eval_cost(rho)?;
self.note(rho, None, v);
Ok(v)
}
fn eval_screening_proxy(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
self.inner.eval_screening_proxy(rho)
}
fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
let r = self.inner.eval(rho)?;
self.note(rho, r.inner_beta_hint.as_ref(), r.cost);
Ok(r)
}
fn eval_with_order(
&mut self,
rho: &Array1<f64>,
order: OuterEvalOrder,
) -> Result<OuterEval, EstimationError> {
let r = self.inner.eval_with_order(rho, order)?;
self.note(rho, r.inner_beta_hint.as_ref(), r.cost);
Ok(r)
}
fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
let r = self.inner.eval_efs(rho)?;
self.note(rho, None, r.cost);
Ok(r)
}
fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError> {
let result = self.inner.seed_inner_state(beta);
if matches!(result, Ok(SeedOutcome::Installed))
&& beta.iter().all(|v| v.is_finite())
&& let Ok(mut guard) = self.last_inner_beta.lock()
{
*guard = Some(beta.clone());
}
result
}
fn allow_continuation_prewarm(&self) -> bool {
self.inner.allow_continuation_prewarm()
}
fn requires_continuation_path_entry(&self) -> bool {
self.inner.requires_continuation_path_entry()
}
fn reset(&mut self) {
self.inner.reset();
}
}
/// [`OuterObjective`] built from closures that all operate on a shared `state`.
///
/// The optional callbacks fall back as follows when absent:
/// - `eval_order_fn` → `eval_fn`
/// - `screening_proxy_fn` → `cost_fn`
/// - `reset_fn` → no-op
/// - `efs_fn` → error
/// - `seed_fn` → `SeedOutcome::NoSlot`
pub struct ClosureObjective<
    S,
    Fc,
    Fe,
    Fr = fn(&mut S),
    Fefs = fn(&mut S, &Array1<f64>) -> Result<EfsEval, EstimationError>,
    Feo = fn(&mut S, &Array1<f64>, OuterEvalOrder) -> Result<OuterEval, EstimationError>,
    Fsp = fn(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
    Fseed = fn(&mut S, &Array1<f64>) -> Result<SeedOutcome, EstimationError>,
> {
    /// Mutable state passed as `&mut S` to every callback.
    pub(crate) state: S,
    /// Returned (cloned) by `capability()`.
    pub(crate) cap: OuterCapability,
    /// Value-only evaluation (`eval_cost`, and the screening fallback).
    pub(crate) cost_fn: Fc,
    /// Full evaluation (`eval`, and the `eval_with_order` fallback).
    pub(crate) eval_fn: Fe,
    /// Optional order-aware evaluation used by `eval_with_order`.
    pub(crate) eval_order_fn: Option<Feo>,
    /// Optional reset callback; `reset()` is a no-op without it.
    pub(crate) reset_fn: Option<Fr>,
    /// Optional EFS evaluation; `eval_efs` errors without it.
    pub(crate) efs_fn: Option<Fefs>,
    /// Optional screening proxy; falls back to `cost_fn`.
    pub(crate) screening_proxy_fn: Option<Fsp>,
    /// Optional inner-state seeding; `NoSlot` is reported without it.
    pub(crate) seed_fn: Option<Fseed>,
    /// Requested prewarm flag; only honored when `seed_fn` is set.
    pub(crate) continuation_prewarm: bool,
}
/// Dispatches every trait method to the matching closure.
///
/// Each evaluating method first records `rho` via
/// `record_current_outer_theta_for_ift`.
impl<S, Fc, Fe, Fr, Fefs, Feo, Fsp, Fseed> OuterObjective
    for ClosureObjective<S, Fc, Fe, Fr, Fefs, Feo, Fsp, Fseed>
where
    Fc: FnMut(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
    Fe: FnMut(&mut S, &Array1<f64>) -> Result<OuterEval, EstimationError>,
    Fr: FnMut(&mut S),
    Fefs: FnMut(&mut S, &Array1<f64>) -> Result<EfsEval, EstimationError>,
    Feo: FnMut(&mut S, &Array1<f64>, OuterEvalOrder) -> Result<OuterEval, EstimationError>,
    Fsp: FnMut(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
    Fseed: FnMut(&mut S, &Array1<f64>) -> Result<SeedOutcome, EstimationError>,
{
    fn capability(&self) -> OuterCapability {
        self.cap.clone()
    }

    fn eval_cost(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
        crate::solver::estimate::reml::outer_eval::record_current_outer_theta_for_ift(rho);
        (self.cost_fn)(&mut self.state, rho)
    }

    /// Uses the screening proxy when configured, else the exact cost.
    fn eval_screening_proxy(&mut self, rho: &Array1<f64>) -> Result<f64, EstimationError> {
        crate::solver::estimate::reml::outer_eval::record_current_outer_theta_for_ift(rho);
        if let Some(proxy) = self.screening_proxy_fn.as_mut() {
            return proxy(&mut self.state, rho);
        }
        (self.cost_fn)(&mut self.state, rho)
    }

    fn eval(&mut self, rho: &Array1<f64>) -> Result<OuterEval, EstimationError> {
        crate::solver::estimate::reml::outer_eval::record_current_outer_theta_for_ift(rho);
        (self.eval_fn)(&mut self.state, rho)
    }

    /// Uses the order-aware callback when configured; otherwise ignores
    /// `order` and runs the full evaluation.
    fn eval_with_order(
        &mut self,
        rho: &Array1<f64>,
        order: OuterEvalOrder,
    ) -> Result<OuterEval, EstimationError> {
        crate::solver::estimate::reml::outer_eval::record_current_outer_theta_for_ift(rho);
        if let Some(ordered) = self.eval_order_fn.as_mut() {
            return ordered(&mut self.state, rho, order);
        }
        (self.eval_fn)(&mut self.state, rho)
    }

    fn eval_efs(&mut self, rho: &Array1<f64>) -> Result<EfsEval, EstimationError> {
        crate::solver::estimate::reml::outer_eval::record_current_outer_theta_for_ift(rho);
        let Some(efs) = self.efs_fn.as_mut() else {
            return Err(EstimationError::RemlOptimizationFailed(
                "EFS evaluation not implemented for this objective".to_string(),
            ));
        };
        efs(&mut self.state, rho)
    }

    /// An empty `beta` is reported as installed without calling any callback.
    fn seed_inner_state(&mut self, beta: &Array1<f64>) -> Result<SeedOutcome, EstimationError> {
        if beta.is_empty() {
            return Ok(SeedOutcome::Installed);
        }
        let Some(seed) = self.seed_fn.as_mut() else {
            return Ok(SeedOutcome::NoSlot);
        };
        seed(&mut self.state, beta)
    }

    /// Prewarm is only allowed when a seed callback exists to receive it.
    fn allow_continuation_prewarm(&self) -> bool {
        self.seed_fn.is_some() && self.continuation_prewarm
    }

    fn reset(&mut self) {
        if let Some(reset) = &mut self.reset_fn {
            reset(&mut self.state);
        }
    }
}
impl<S, Fc, Fe, Fr, Fefs, Feo, Fsp> ClosureObjective<S, Fc, Fe, Fr, Fefs, Feo, Fsp>
where
    Fc: FnMut(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
    Fe: FnMut(&mut S, &Array1<f64>) -> Result<OuterEval, EstimationError>,
    Fr: FnMut(&mut S),
    Fefs: FnMut(&mut S, &Array1<f64>) -> Result<EfsEval, EstimationError>,
    Feo: FnMut(&mut S, &Array1<f64>, OuterEvalOrder) -> Result<OuterEval, EstimationError>,
    Fsp: FnMut(&mut S, &Array1<f64>) -> Result<f64, EstimationError>,
{
    /// Rebuilds this objective with `seed_fn` as its seed callback.
    ///
    /// All other fields are moved over unchanged; any previously stored seed
    /// callback is dropped.
    pub fn with_seed_inner_state<Fseed>(
        self,
        seed_fn: Fseed,
    ) -> ClosureObjective<S, Fc, Fe, Fr, Fefs, Feo, Fsp, Fseed>
    where
        Fseed: FnMut(&mut S, &Array1<f64>) -> Result<SeedOutcome, EstimationError>,
    {
        let Self {
            state,
            cap,
            cost_fn,
            eval_fn,
            eval_order_fn,
            reset_fn,
            efs_fn,
            screening_proxy_fn,
            seed_fn: _previous_seed,
            continuation_prewarm,
        } = self;
        ClosureObjective {
            state,
            cap,
            cost_fn,
            eval_fn,
            eval_order_fn,
            reset_fn,
            efs_fn,
            screening_proxy_fn,
            seed_fn: Some(seed_fn),
            continuation_prewarm,
        }
    }
}
/// Wraps an [`EstimationError`] as a recoverable objective error, prefixed by `context`.
pub(crate) fn into_objective_error(context: &str, err: EstimationError) -> ObjectiveEvalError {
    let message = format!("{context}: {err}");
    ObjectiveEvalError::recoverable(message)
}
/// Passes `cost` through when finite; otherwise returns a recoverable error.
pub(crate) fn finite_cost_or_error(context: &str, cost: f64) -> Result<f64, ObjectiveEvalError> {
    if !cost.is_finite() {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: objective returned a non-finite cost"
        )));
    }
    Ok(cost)
}
/// Checks the value and gradient parts of an [`OuterEval`].
///
/// The checks run in this order, and the first failure is returned:
/// 1. the gradient length must match `layout`;
/// 2. the cost must be finite;
/// 3. every gradient entry must be finite.
fn validate_outer_first_order(
    context: &str,
    layout: OuterThetaLayout,
    eval: &OuterEval,
) -> Result<(), ObjectiveEvalError> {
    layout.validate_gradient_len(&eval.gradient, context)?;

    let cost_ok = eval.cost.is_finite();
    if !cost_ok {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: objective returned a non-finite cost"
        )));
    }

    let gradient_has_non_finite = eval.gradient.iter().any(|v| !v.is_finite());
    if gradient_has_non_finite {
        return Err(ObjectiveEvalError::recoverable(format!(
            "{context}: objective returned a non-finite gradient"
        )));
    }
    Ok(())
}
/// Validates a full [`OuterEval`] and returns it unchanged when valid.
///
/// Beyond the first-order checks, the Hessian is checked by variant:
/// - A dense analytic Hessian must have the layout's shape and be all-finite.
/// - An operator Hessian must have dimension `layout.n_params`.
/// - An unavailable Hessian is accepted.
pub(crate) fn finite_outer_eval_or_error(
    context: &str,
    layout: OuterThetaLayout,
    eval: OuterEval,
) -> Result<OuterEval, ObjectiveEvalError> {
    validate_outer_first_order(context, layout, &eval)?;
    match &eval.hessian {
        HessianResult::Analytic(hessian) => {
            layout.validate_hessian_shape(hessian, context)?;
            if hessian.iter().any(|v| !v.is_finite()) {
                return Err(ObjectiveEvalError::recoverable(format!(
                    "{context}: objective returned a non-finite Hessian"
                )));
            }
        }
        HessianResult::Operator(op) if op.dim() != layout.n_params => {
            return Err(ObjectiveEvalError::recoverable(format!(
                "{context}: outer Hessian operator dimension mismatch: got {}, expected {} (rho_dim={}, psi_dim={})",
                op.dim(),
                layout.n_params,
                layout.rho_dim(),
                layout.psi_dim
            )));
        }
        HessianResult::Operator(_) | HessianResult::Unavailable => {}
    }
    Ok(eval)
}
/// Validates only the value and gradient of `eval`, ignoring its Hessian, and
/// returns it unchanged when valid.
pub(crate) fn finite_outer_first_order_eval_or_error(
    context: &str,
    layout: OuterThetaLayout,
    eval: OuterEval,
) -> Result<OuterEval, ObjectiveEvalError> {
    validate_outer_first_order(context, layout, &eval).map(|()| eval)
}
pub(crate) fn validate_second_order_seed_hessian(
context: &str,
layout: OuterThetaLayout,
eval: &OuterEval,
) -> Result<(), ObjectiveEvalError> {
if layout.n_params > SECOND_ORDER_GEOMETRY_PROBE_MAX_PARAMS || !eval.hessian.is_analytic() {
return Ok(());
}
if matches!(
&eval.hessian,
HessianResult::Operator(op) if !op.materialization_capability().is_available()
) {
return Ok(());
}
let Some(hessian) = eval.hessian.materialize_dense().map_err(|message| {
ObjectiveEvalError::recoverable(format!(
"{context}: analytic outer Hessian materialization failed during second-order seed validation: {message}"
))
})?
else {
return Ok(());
};
layout.validate_hessian_shape(&hessian, context)?;
if !hessian.iter().all(|value| value.is_finite()) {
return Err(ObjectiveEvalError::recoverable(format!(
"{context}: analytic outer Hessian probe encountered non-finite entries"
)));
}
Ok(())
}