use ndarray::{Array1, Array2, ArrayView2};
use crate::gpu::policy::{PirlsLoopCurvatureKind, PirlsLoopFamilyKind, RemlOuterAdmission};
use crate::solver::estimate::EstimationError;
use crate::solver::gpu::reml_gpu::{RemlGpuEvidence, RemlGpuInput, evidence_derivatives_gpu};
/// Inputs consumed by [`run_reml_outer_on_device`].
#[derive(Clone, Debug)]
pub struct RemlOuterGpuInput {
/// Starting point of the search; the driver clamps it into `bounds` before the first evaluation.
pub seed_rho: Array1<f64>,
/// `(lower, upper)` box constraints. When `seed_rho` is non-empty, both vectors must have the
/// same length as `seed_rho` or the driver returns an error.
pub bounds: (Array1<f64>, Array1<f64>),
/// Convergence threshold on the largest absolute gradient entry; the driver raises negative
/// values to `0.0`.
pub gradient_tolerance: f64,
/// Upper limit on the number of accepted steps.
pub max_iterations: usize,
/// Optional per-axis cap on the magnitude of each search-direction component; `None` disables
/// capping, and non-finite or non-positive entries are ignored by `cap_axiswise`.
pub axis_step_caps: Option<Array1<f64>>,
/// Admission record for this fit. The driver refuses to run unless `family` is `Some` and
/// `gpu_available` is `true`.
pub admission: RemlOuterAdmission,
/// Not read by `run_reml_outer_on_device`.
/// NOTE(review): presumably the penalised Hessian at `seed_rho` -- confirm against the caller.
pub seed_penalised_hessian: Array2<f64>,
/// Not read by `run_reml_outer_on_device`.
/// NOTE(review): presumably one Hessian derivative per `rho` axis -- confirm against the caller.
pub seed_derivative_hessians: Vec<Array2<f64>>,
/// Objective reported unchanged when `seed_rho` is empty; otherwise the driver obtains the
/// objective from its evaluator instead.
pub seed_objective: f64,
}
/// Result returned by [`run_reml_outer_on_device`].
#[derive(Clone, Debug)]
pub struct RemlOuterGpuOutcome {
/// Last accepted iterate (empty when the seed was empty).
pub rho: Array1<f64>,
/// Objective reported by the evaluator at `rho`, or the seed objective when the seed was empty.
pub objective: f64,
/// Number of line-search steps that were accepted.
pub iterations: usize,
/// Largest absolute entry of the gradient at `rho`; the driver in this file always fills it.
pub final_grad_norm: Option<f64>,
/// Gradient at `rho`; the driver in this file always fills it.
pub final_gradient: Option<Array1<f64>>,
/// `true` when the final gradient norm was at or below the (non-negative) tolerance.
pub converged: bool,
}
/// One objective/gradient evaluation, as returned by the evaluator passed to
/// [`run_reml_outer_on_device`] and by [`evaluate_outer_on_device`].
#[derive(Clone, Debug)]
pub struct RemlOuterDeviceEval {
/// Objective value; the driver minimises it.
pub objective: f64,
/// Gradient the driver pairs with `objective` (one entry per `rho` axis is expected by the
/// driver's dot products and BFGS update).
pub gradient: Array1<f64>,
}
pub fn evaluate_outer_on_device(
penalised_hessian: ArrayView2<'_, f64>,
derivative_hessians: &[ArrayView2<'_, f64>],
penalised_log_likelihood: f64,
penalty_logdet: f64,
) -> Result<RemlOuterDeviceEval, EstimationError> {
let input = RemlGpuInput {
penalized_hessian: penalised_hessian,
derivative_hessians: derivative_hessians.to_vec(),
};
let RemlGpuEvidence {
logdet_hessian,
gradient_rho,
} = evidence_derivatives_gpu(input).map_err(|err| {
EstimationError::RemlOptimizationFailed(format!(
"device-resident REML evidence failed: {err}"
))
})?;
let objective = -penalised_log_likelihood + 0.5 * logdet_hessian - 0.5 * penalty_logdet;
Ok(RemlOuterDeviceEval {
objective,
gradient: gradient_rho,
})
}
/// Builds the starting inverse-Hessian approximation: a scaled `num_rho x num_rho` identity.
///
/// The diagonal value is `1 / max(seed_grad_inf_norm, 1)` when the norm is finite and strictly
/// positive, and `1.0` otherwise (zero, negative, infinite or NaN norm).
pub fn initial_inverse_hessian(num_rho: usize, seed_grad_inf_norm: f64) -> Array2<f64> {
    let norm_is_usable = seed_grad_inf_norm.is_finite() && seed_grad_inf_norm > 0.0;
    let diagonal = if norm_is_usable {
        seed_grad_inf_norm.max(1.0).recip()
    } else {
        1.0
    };
    Array2::from_shape_fn((num_rho, num_rho), |(row, col)| {
        if row == col {
            diagonal
        } else {
            0.0
        }
    })
}
/// Clamps each component of `direction` to the magnitude allowed by the matching cap.
///
/// Does nothing when `caps` is `None`. A cap that is non-finite or not strictly positive leaves
/// its component untouched. Components are paired with caps positionally; if the two vectors
/// differ in length the unpaired tail is left as it is. The sign of a clamped component is kept.
pub fn cap_axiswise(direction: &mut Array1<f64>, caps: Option<&Array1<f64>>) {
    if let Some(limits) = caps {
        for (component, &limit) in direction.iter_mut().zip(limits.iter()) {
            let limit_is_usable = limit.is_finite() && limit > 0.0;
            if limit_is_usable && component.abs() > limit {
                // The component is non-zero and not NaN here, so this is `sign(component) * limit`.
                *component = limit.copysign(*component);
            }
        }
    }
}
/// Clamps every entry of `rho` into `[lower, upper]`, where `bounds` is `(lower, upper)`.
///
/// The lower bound is tested first, so an entry below `lower` is set to `lower` even if
/// `lower > upper`. NaN entries compare false against both bounds and are left unchanged.
///
/// # Panics
///
/// Panics if either bound vector is shorter than `rho`.
pub fn project_onto_bounds(rho: &mut Array1<f64>, bounds: &(Array1<f64>, Array1<f64>)) {
    let lower_bounds = &bounds.0;
    let upper_bounds = &bounds.1;
    for (axis, value) in rho.iter_mut().enumerate() {
        let floor = lower_bounds[axis];
        let ceiling = upper_bounds[axis];
        if *value < floor {
            *value = floor;
        } else if *value > ceiling {
            *value = ceiling;
        }
    }
}
pub fn run_reml_outer_on_device<E>(
input: RemlOuterGpuInput,
mut evaluator: E,
) -> Result<RemlOuterGpuOutcome, EstimationError>
where
E: FnMut(&Array1<f64>) -> Result<RemlOuterDeviceEval, EstimationError>,
{
if !matches!(input.admission.family, Some(_)) {
return Err(EstimationError::RemlOptimizationFailed(
"device-resident REML outer driver requires a JIT-cached PIRLS family".to_string(),
));
}
if !input.admission.gpu_available {
return Err(EstimationError::RemlOptimizationFailed(
"device-resident REML outer driver dispatched without GPU runtime".to_string(),
));
}
let num_rho = input.seed_rho.len();
if num_rho == 0 {
return Ok(RemlOuterGpuOutcome {
rho: Array1::<f64>::zeros(0),
objective: input.seed_objective,
iterations: 0,
final_grad_norm: Some(0.0),
final_gradient: Some(Array1::<f64>::zeros(0)),
converged: true,
});
}
if input.bounds.0.len() != num_rho || input.bounds.1.len() != num_rho {
return Err(EstimationError::RemlOptimizationFailed(format!(
"device-resident REML outer driver: bounds shape mismatch (num_rho={num_rho}, \
lower={}, upper={})",
input.bounds.0.len(),
input.bounds.1.len(),
)));
}
let bounds = input.bounds.clone();
let axis_caps = input.axis_step_caps.clone();
let grad_tol = input.gradient_tolerance.max(0.0);
let max_iter = input.max_iterations;
let mut rho = input.seed_rho.clone();
project_onto_bounds(&mut rho, &bounds);
let eval = evaluator(&rho)?;
let mut objective = eval.objective;
let mut gradient = eval.gradient;
let mut grad_inf = inf_norm(&gradient);
let mut h_inv = initial_inverse_hessian(num_rho, grad_inf);
let mut converged = grad_inf <= grad_tol;
let mut iterations = 0_usize;
while !converged && iterations < max_iter {
let mut direction = matvec_neg(h_inv.view(), &gradient);
cap_axiswise(&mut direction, axis_caps.as_ref());
const ARMIJO_C1: f64 = 1.0e-4;
const MIN_ALPHA: f64 = 1.0e-12;
let g_dot_d = dot(&gradient, &direction);
if !g_dot_d.is_finite() || g_dot_d >= 0.0 {
break;
}
let mut alpha = 1.0_f64;
let mut trial_rho;
let mut trial_eval;
loop {
trial_rho = scaled_add(&rho, alpha, &direction);
project_onto_bounds(&mut trial_rho, &bounds);
match evaluator(&trial_rho) {
Ok(e) => {
trial_eval = e;
if trial_eval.objective.is_finite()
&& trial_eval.objective <= objective + ARMIJO_C1 * alpha * g_dot_d
{
break;
}
}
Err(_) => {
}
}
alpha *= 0.5;
if alpha < MIN_ALPHA {
return Ok(RemlOuterGpuOutcome {
rho,
objective,
iterations,
final_grad_norm: Some(grad_inf),
final_gradient: Some(gradient),
converged: false,
});
}
}
let s = sub(&trial_rho, &rho);
let y = sub(&trial_eval.gradient, &gradient);
let sy = dot(&s, &y);
if sy.is_finite() && sy > 1.0e-16 {
bfgs_inverse_hessian_update(&mut h_inv, &s, &y, sy);
}
rho = trial_rho;
objective = trial_eval.objective;
gradient = trial_eval.gradient;
grad_inf = inf_norm(&gradient);
iterations += 1;
converged = grad_inf <= grad_tol;
}
Ok(RemlOuterGpuOutcome {
rho,
objective,
iterations,
final_grad_norm: Some(grad_inf),
final_gradient: Some(gradient),
converged,
})
}
/// Returns the infinity norm of `v`: the largest absolute value among its entries.
///
/// An empty vector yields `0.0`. If any entry is NaN the result is NaN, so that a corrupted
/// vector can never compare as "small" (`NaN <= tol` is false) and be mistaken for a converged
/// gradient. Infinite entries yield `f64::INFINITY`.
fn inf_norm(v: &Array1<f64>) -> f64 {
    let mut largest = 0.0_f64;
    for x in v.iter() {
        let magnitude = x.abs();
        // `NaN > largest` is always false, so NaN has to be detected explicitly.
        if magnitude.is_nan() {
            return f64::NAN;
        }
        if magnitude > largest {
            largest = magnitude;
        }
    }
    largest
}
/// Returns the dot product of `a` and `b`, accumulated left to right over the indices of `a`.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`; extra entries of `b` are ignored.
fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    // Explicit fold keeps the summation order and the `0.0` starting value.
    (0..a.len()).fold(0.0_f64, |running, idx| running + a[idx] * b[idx])
}
/// Returns the element-wise difference `a - b` as a new vector with the length of `a`.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`; extra entries of `b` are ignored.
fn sub(a: &Array1<f64>, b: &Array1<f64>) -> Array1<f64> {
    Array1::from_iter((0..a.len()).map(|idx| a[idx] - b[idx]))
}
fn scaled_add(base: &Array1<f64>, alpha: f64, dir: &Array1<f64>) -> Array1<f64> {
let mut out = base.clone();
for i in 0..out.len() {
out[i] += alpha * dir[i];
}
out
}
/// Returns `-(h_inv * g)` using the leading `n x n` block of `h_inv`, where `n = g.len()`.
///
/// Each row product is accumulated left to right before being negated.
///
/// # Panics
///
/// Panics if `h_inv` has fewer than `n` rows or columns.
fn matvec_neg(h_inv: ArrayView2<'_, f64>, g: &Array1<f64>) -> Array1<f64> {
    let dim = g.len();
    Array1::from_iter((0..dim).map(|row| {
        let row_dot = (0..dim).fold(0.0_f64, |running, col| running + h_inv[[row, col]] * g[col]);
        -row_dot
    }))
}
/// Applies the BFGS update to an inverse-Hessian approximation in place.
///
/// With `r = 1 / sy` and `Hy = h_inv * y`, every entry of the leading `n x n` block
/// (`n = s.len()`) becomes
/// `H[i][j] + (sy + y.Hy) * r^2 * s[i] * s[j] - r * (Hy[i] * s[j] + s[i] * Hy[j])`.
///
/// `sy` is used exactly as supplied by the caller; it is neither recomputed from `s` and `y` nor
/// checked for sign or magnitude here. The formula reuses `Hy` for the transposed term, so it
/// relies on `h_inv` being symmetric on entry.
///
/// # Panics
///
/// Panics if `y` or either dimension of `h_inv` is shorter than `s`.
fn bfgs_inverse_hessian_update(h_inv: &mut Array2<f64>, s: &Array1<f64>, y: &Array1<f64>, sy: f64) {
    let dim = s.len();
    let inv_sy = 1.0 / sy;

    // H * y, one row of `h_inv` at a time.
    let h_times_y = Array1::from_iter((0..dim).map(|row| {
        (0..dim).fold(0.0_f64, |running, col| running + h_inv[[row, col]] * y[col])
    }));
    let y_h_y = (0..dim).fold(0.0_f64, |running, idx| running + y[idx] * h_times_y[idx]);

    let rank_one_weight = (sy + y_h_y) * inv_sy * inv_sy;
    for row in 0..dim {
        // Hoisted `(weight * s[row])`; multiplying by `s[col]` afterwards keeps the rounding of
        // `weight * s[row] * s[col]`.
        let weighted_s_row = rank_one_weight * s[row];
        for col in 0..dim {
            let cross_terms = h_times_y[row] * s[col] + s[row] * h_times_y[col];
            h_inv[[row, col]] += weighted_s_row * s[col] - inv_sy * cross_terms;
        }
    }
}
/// Returns the `(family, curvature)` pair it was given, unchanged.
///
/// NOTE(review): presumably the key that selects a device PIRLS kernel variant -- confirm
/// against the callers; nothing in this function inspects or validates the two values.
pub fn device_pirls_kernel_kind(
family: PirlsLoopFamilyKind,
curvature: PirlsLoopCurvatureKind,
) -> (PirlsLoopFamilyKind, PirlsLoopCurvatureKind) {
(family, curvature)
}
// Unit tests for the outer driver and its public helpers. The evaluators are plain CPU closures,
// so none of these tests call `evaluate_outer_on_device`.
#[cfg(test)]
mod tests {
use super::*;
use crate::gpu::policy::{PirlsLoopCurvatureKind, PirlsLoopFamilyKind};
// Admission record that passes both driver checks: a family is set and the GPU is flagged as
// available.
fn dummy_admission(num_rho: usize) -> RemlOuterAdmission {
RemlOuterAdmission {
n: 200_000,
p: 64,
num_rho,
family: Some(PirlsLoopFamilyKind::BernoulliLogit),
curvature: PirlsLoopCurvatureKind::Fisher,
gpu_available: true,
}
}
// With no smoothing parameters the driver must return immediately with the seed objective
// (42.0), not the evaluator's value (0.0).
#[test]
fn empty_rho_returns_seed_objective() {
let input = RemlOuterGpuInput {
seed_rho: Array1::<f64>::zeros(0),
bounds: (Array1::<f64>::zeros(0), Array1::<f64>::zeros(0)),
gradient_tolerance: 1.0e-6,
max_iterations: 10,
axis_step_caps: None,
admission: dummy_admission(0),
seed_penalised_hessian: Array2::<f64>::zeros((0, 0)),
seed_derivative_hessians: Vec::new(),
seed_objective: 42.0,
};
let evaluator = |_: &Array1<f64>| -> Result<RemlOuterDeviceEval, EstimationError> {
Ok(RemlOuterDeviceEval {
objective: 0.0,
gradient: Array1::<f64>::zeros(0),
})
};
let out = run_reml_outer_on_device(input, evaluator).expect("empty path");
assert_eq!(out.iterations, 0);
assert!(out.converged);
assert_eq!(out.objective, 42.0);
}
// Minimises 0.5 * ||rho - target||^2, whose gradient is `rho - target`; the bounds are wide
// enough that the minimiser `target` lies strictly inside them.
#[test]
fn converges_on_quadratic() {
let target = Array1::from(vec![0.5_f64, -0.25, 1.0, -0.75]);
let target_owned = target.clone();
let input = RemlOuterGpuInput {
seed_rho: Array1::from(vec![2.0, 2.0, 2.0, 2.0]),
bounds: (Array1::from_elem(4, -10.0), Array1::from_elem(4, 10.0)),
gradient_tolerance: 1.0e-8,
max_iterations: 100,
axis_step_caps: None,
admission: dummy_admission(4),
seed_penalised_hessian: Array2::<f64>::eye(4),
seed_derivative_hessians: Vec::new(),
seed_objective: 0.0,
};
let evaluator = move |rho: &Array1<f64>| -> Result<RemlOuterDeviceEval, EstimationError> {
let diff: Array1<f64> = rho - &target_owned;
let value = 0.5 * diff.iter().map(|v| v * v).sum::<f64>();
Ok(RemlOuterDeviceEval {
objective: value,
gradient: diff,
})
};
let out = run_reml_outer_on_device(input, evaluator).expect("quadratic path");
assert!(out.converged, "BFGS should converge on a quadratic");
for (got, want) in out.rho.iter().zip(target.iter()) {
assert!((*got - *want).abs() < 1.0e-4_f64, "got {got} want {want}");
}
}
// Components larger than their cap are clamped with the sign kept; a component already within
// its cap is left alone.
#[test]
fn axis_caps_clamp_search_direction() {
let mut direction = Array1::from(vec![3.0, -4.0, 0.5]);
let caps = Array1::from(vec![1.0, 2.0, 10.0]);
cap_axiswise(&mut direction, Some(&caps));
assert_eq!(direction[0], 1.0);
assert_eq!(direction[1], -2.0);
assert_eq!(direction[2], 0.5);
}
// Covers all three branches of the projection: below the lower bound, above the upper bound,
// and already inside the box.
#[test]
fn projects_onto_bounds() {
let mut rho = Array1::from(vec![-5.0, 7.0, 0.5]);
let bounds = (
Array1::from(vec![-1.0, -1.0, -1.0]),
Array1::from(vec![1.0, 1.0, 1.0]),
);
project_onto_bounds(&mut rho, &bounds);
assert_eq!(rho[0], -1.0);
assert_eq!(rho[1], 1.0);
assert_eq!(rho[2], 0.5);
}
}