-- gam 0.3.15
--
-- Generalized penalized likelihood engine
-- Documentation
import Mathlib

namespace Gam.Solver.Reml

section GradientDescentVerification

open Matrix

variable {n p k : ℕ} [Fintype (Fin n)] [Fintype (Fin p)] [Fintype (Fin k)]

/-!
### Matrix Calculus: Log-Determinant Derivatives

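We define `H(rho) = A + exp(rho) * B` and prove that, whenever `H(rho)` is positive definite
(so that `det(H(rho)) > 0`), the derivative of `log(det(H(rho)))` with respect to `rho` is
`exp(rho) * trace(H(rho)⁻¹ * B)`. This uses Jacobi's formula
`d/dρ det M(ρ) = trace(adj(M(ρ)) * M'(ρ))` for the derivative of the determinant.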
-/

variable {m : Type*} [Fintype m] [DecidableEq m]

/-- Matrix function `H(ρ) = A + exp(ρ) • B`: the fixed matrix `A` plus `B` scaled by the
    positive factor `exp ρ`.
    `derivative_log_det_H_matrix` restates its goal with this definition unfolded and closes
    it by `exact`, so the term shape `A + Real.exp rho • B` must be kept as written. -/
noncomputable def H_matrix (A B : Matrix m m ℝ) (rho : ℝ) : Matrix m m ℝ := A + Real.exp rho • B

/-- The log-determinant function `f(ρ) = log(det(H(ρ)))`.
    Mathlib's `Real.log` is total (`Real.log 0 = 0` and `Real.log x = Real.log |x|`), so this
    is the usual log-determinant only when `det(H(ρ)) > 0`, e.g. when `H(ρ)` is positive
    definite. -/
noncomputable def log_det_H (A B : Matrix m m ℝ) (rho : ℝ) := Real.log (H_matrix A B rho).det

/-- The derivative of `log(det(H(ρ))) = log(det(A + exp(ρ)B))` with respect to `ρ`
    is `exp(ρ) * trace(H(ρ)⁻¹ * B)`.

    Proof outline:
    * Jacobi's formula `d/dρ det M(ρ) = trace(adj(M(ρ)) * M'(ρ))` is proved inline for an
      arbitrary matrix curve `M` differentiable at `rho`. It expands the determinant over
      permutations and applies the product rule term by term.
    * The formula is specialised to `ρ ↦ A + exp ρ • B`, whose derivative is `exp ρ • B`.
    * The chain rule for `Real.log` finishes the proof.

    Hypotheses:
    * `h_pos` is used only to obtain `det(H(ρ)) ≠ 0`. That fact makes `log ∘ det`
      differentiable at `ρ` and lets the adjugate be rewritten through the inverse.
    * `_hB` (symmetry of `B`) is not used by the proof; the identity holds for any `B`. -/
theorem derivative_log_det_H_matrix (A B : Matrix m m ℝ)
    (_hB : B.IsSymm)
    (rho : ℝ) (h_pos : (H_matrix A B rho).PosDef) :
    deriv (log_det_H A B) rho = Real.exp rho * ((H_matrix A B rho)⁻¹ * B).trace := by
  -- Positive definiteness is needed only through `det H(ρ) ≠ 0`.
  have h_inv : (H_matrix A B rho).det ≠ 0 := h_pos.det_pos.ne'
  -- The goal with `log_det_H` / `H_matrix` unfolded; it closes the theorem via `exact h_det`.
  have h_det : deriv (fun rho => Real.log (Matrix.det (A + Real.exp rho • B))) rho = Real.exp rho * Matrix.trace ((A + Real.exp rho • B)⁻¹ * B) := by
    -- Step 1: d/dρ det H(ρ) = det H(ρ) * trace(H(ρ)⁻¹ * B) * exp ρ.
    have h_det_step1 : deriv (fun rho => Matrix.det (A + Real.exp rho • B)) rho = Matrix.det (A + Real.exp rho • B) * Matrix.trace ((A + Real.exp rho • B)⁻¹ * B) * Real.exp rho := by
      -- Jacobi's formula instantiated at the curve ρ ↦ A + exp ρ • B.
      have h_jacobi : deriv (fun rho => Matrix.det (A + Real.exp rho • B)) rho = Matrix.trace (Matrix.adjugate (A + Real.exp rho • B) * deriv (fun rho => A + Real.exp rho • B) rho) := by
        -- Jacobi's formula for an arbitrary matrix curve `M` differentiable at `rho`.
        have h_jacobi : ∀ (M : ℝ → Matrix m m ℝ), DifferentiableAt ℝ M rho → deriv (fun rho => Matrix.det (M rho)) rho = Matrix.trace (Matrix.adjugate (M rho) * deriv M rho) := by
          intro M hM_diff
          -- Entrywise form: d/dρ det M = Σᵢ Σⱼ adj(M)ᵢⱼ * d/dρ Mⱼᵢ, via the Leibniz expansion.
          have h_jacobi : deriv (fun rho => Matrix.det (M rho)) rho = ∑ i, ∑ j, (Matrix.adjugate (M rho)) i j * deriv (fun rho => (M rho) j i) rho := by
            simp +decide [ Matrix.det_apply', Matrix.adjugate_apply, Matrix.mul_apply ]
            -- Differentiate Σ_σ sign(σ) * Πᵢ M(σ i, i) term by term (product rule per σ).
            have h_jacobi : deriv (fun rho => ∑ σ : Equiv.Perm m, (↑(↑((Equiv.Perm.sign : Equiv.Perm m → ℤˣ) σ) : ℤ) : ℝ) * ∏ i : m, M rho ((σ : m → m) i) i) rho = ∑ σ : Equiv.Perm m, (↑(↑((Equiv.Perm.sign : Equiv.Perm m → ℤˣ) σ) : ℤ) : ℝ) * ∑ i : m, (∏ j ∈ Finset.univ.erase i, M rho ((σ : m → m) j) j) * deriv (fun rho => M rho ((σ : m → m) i) i) rho := by
              have h_jacobi : ∀ σ : Equiv.Perm m, deriv (fun rho => ∏ i : m, M rho ((σ : m → m) i) i) rho = ∑ i : m, (∏ j ∈ Finset.univ.erase i, M rho ((σ : m → m) j) j) * deriv (fun rho => M rho ((σ : m → m) i) i) rho := by
                intro σ
                have h_prod_rule : ∀ (f : m → ℝ → ℝ), (∀ i, DifferentiableAt ℝ (f i) rho) → deriv (fun rho => ∏ i, f i rho) rho = ∑ i, (∏ j ∈ Finset.univ.erase i, f j rho) * deriv (f i) rho := by
                  intro f hf
                  convert deriv_finset_prod (u := Finset.univ) (f := f) (x := rho) (fun i _ => hf i)
                  simp
                apply h_prod_rule
                intro i
                exact DifferentiableAt.comp rho ( differentiableAt_pi.1 ( differentiableAt_pi.1 hM_diff _ ) _ ) differentiableAt_id
              -- Linearity of `deriv` over the finite sum of permutation terms.
              have h_deriv_sum : deriv (fun rho => ∑ σ : Equiv.Perm m, (↑(↑((Equiv.Perm.sign : Equiv.Perm m → ℤˣ) σ) : ℤ) : ℝ) * ∏ i : m, M rho ((σ : m → m) i) i) rho = ∑ σ : Equiv.Perm m, (↑(↑((Equiv.Perm.sign : Equiv.Perm m → ℤˣ) σ) : ℤ) : ℝ) * deriv (fun rho => ∏ i : m, M rho ((σ : m → m) i) i) rho := by
                have h_diff : ∀ σ : Equiv.Perm m, DifferentiableAt ℝ (fun rho => ∏ i : m, M rho ((σ : m → m) i) i) rho := by
                  intro σ
                  have h_diff : ∀ i : m, DifferentiableAt ℝ (fun rho => M rho ((σ : m → m) i) i) rho := by
                    intro i
                    exact DifferentiableAt.comp rho ( differentiableAt_pi.1 ( differentiableAt_pi.1 hM_diff _ ) _ ) differentiableAt_id
                  convert DifferentiableAt.finset_prod (u := Finset.univ) (f := fun i rho => M rho ((σ : m → m) i) i) (x := rho) (fun i _ => h_diff i)
                  simp
                norm_num [ h_diff ]
              simpa only [ h_jacobi ] using h_deriv_sum
            -- Match each product-rule term with the cofactor expansion of `adjugate`.
            simp +decide only [h_jacobi, Finset.mul_sum _ _ _]
            simp +decide [ Finset.sum_mul _ _ _, Matrix.updateRow_apply ]
            rw [ Finset.sum_comm ]
            refine' Finset.sum_congr rfl fun i hi => _
            rw [ Finset.sum_comm, Finset.sum_congr rfl ] ; intros ; simp +decide [ Finset.prod_ite, Finset.filter_ne', Finset.filter_eq' ] ; ring
            rw [ Finset.sum_eq_single ( ( ‹Equiv.Perm m› : m → m ) i ) ] <;> simp +decide [ Finset.prod_ite, Finset.filter_ne', Finset.filter_eq' ] ; ring
            intro j hj; simp +decide [ Pi.single_apply, hj ]
            rw [ Finset.prod_eq_zero_iff.mpr ] <;> simp +decide [ hj ]
            exact ⟨ ( ‹Equiv.Perm m›.symm j ), by simp +decide, by simpa [ Equiv.symm_apply_eq ] using hj ⟩
          -- Reassemble the double sum as trace(adj(M) * M'), differentiating entrywise (`deriv_pi`).
          rw [ h_jacobi, Matrix.trace ]
          rw [ deriv_pi ]
          · simp +decide [ Matrix.mul_apply, Finset.mul_sum _ _ _ ]
            refine' Finset.sum_congr rfl fun i _ => Finset.sum_congr rfl fun j _ => _
            rw [ deriv_pi ]
            intro i; exact (by
            exact DifferentiableAt.comp rho ( differentiableAt_pi.1 ( differentiableAt_pi.1 hM_diff j ) i ) differentiableAt_id)
          · exact fun i => DifferentiableAt.comp rho ( differentiableAt_pi.1 hM_diff i ) differentiableAt_id
        -- Apply the general formula; the curve ρ ↦ A + exp ρ • B is entrywise differentiable.
        apply h_jacobi
        exact differentiableAt_pi.2 fun i => differentiableAt_pi.2 fun j => DifferentiableAt.add ( differentiableAt_const _ ) ( DifferentiableAt.smul ( Real.differentiableAt_exp ) ( differentiableAt_const _ ) )
      -- Rewrite adj(H) through H⁻¹ (using det H ≠ 0) and substitute
      -- d/dρ (A + exp ρ • B) = exp ρ • B.
      simp_all +decide [ Matrix.inv_def, mul_assoc, mul_left_comm, mul_comm, Matrix.trace_mul_comm ( Matrix.adjugate _ ) ]
      rw [ show deriv ( fun rho => A + Real.exp rho • B ) rho = Real.exp rho • B from ?_ ]
      · by_cases h : Matrix.det ( A + Real.exp rho • B ) = 0 <;> simp_all +decide [ Matrix.trace_smul, mul_assoc, mul_comm, mul_left_comm ]
        exact False.elim <| h_inv h
      · rw [ deriv_pi ] <;> norm_num [ Real.differentiableAt_exp, mul_comm ]
        ext i; rw [ deriv_pi ] <;> norm_num [ Real.differentiableAt_exp, mul_comm ]
    -- Step 2: chain rule for `Real.log` (needs det H ≠ 0 and differentiability of det).
    by_cases h_det : DifferentiableAt ℝ ( fun rho => Matrix.det ( A + Real.exp rho • B ) ) rho <;> simp_all +decide [ Real.exp_ne_zero, mul_assoc, mul_comm, mul_left_comm ]
    · convert HasDerivAt.deriv ( HasDerivAt.log ( h_det.hasDerivAt ) h_inv ) using 1 ; ring!
      exact eq_div_of_mul_eq ( by aesop ) ( by linear_combination' h_det_step1.symm )
    -- The non-differentiable case is impossible: det is a polynomial in the entries.
    · contrapose! h_det
      simp +decide [ Matrix.det_apply' ]
      fun_prop (disch := norm_num)
  exact h_det

-- 1. Model Functions
/-- Total penalty matrix `S_λ(ρ) = Σᵢ exp(ρᵢ) • Sᵢ`. Each basis penalty `S_basis i` is
    weighted by its smoothing parameter `λᵢ = exp(ρᵢ)`. -/
noncomputable def S_lambda_fn (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (rho : Fin k → ℝ) : Matrix (Fin p) (Fin p) ℝ :=
  ∑ i, (Real.exp (rho i) • S_basis i)

/-- Penalized objective `-log_lik(β) + ½ βᵀ S_λ(ρ) β`.
    `beta` is a `p × 1` column, so the quadratic form is a `1 × 1` matrix and `trace` extracts
    its single entry. Later proofs unfold this definition by `rfl`/`change`, so keep the term
    shape. -/
noncomputable def L_pen_fn (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ) (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (rho : Fin k → ℝ) (beta : Matrix (Fin p) (Fin 1) ℝ) : ℝ :=
  - (log_lik beta) + 0.5 * (beta.transpose * (S_lambda_fn S_basis rho) * beta).trace

/-- Penalized Hessian `Xᵀ W(β) X + S_λ(ρ)`.
    `W beta` is an `n × n` matrix supplied by the caller. NOTE(review): presumably it is the
    likelihood's working-weight matrix at `β` — confirm against the Rust side. -/
noncomputable def Hessian_fn (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (X : Matrix (Fin n) (Fin p) ℝ) (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ) (rho : Fin k → ℝ) (beta : Matrix (Fin p) (Fin 1) ℝ) : Matrix (Fin p) (Fin p) ℝ :=
  X.transpose * (W beta) * X + S_lambda_fn S_basis rho

/-- Algebraic matrix inverse via the adjugate (Cramer's rule): `det(M)⁻¹ • adj(M)`.

    Over `ℝ` this agrees with Mathlib's `M⁻¹`, but only *propositionally*, not
    definitionally. `Matrix.inv_def` is stated with `Ring.inverse M.det` rather than the field
    inverse; the equality is `matrixInvAlg_eq_inv`. When `det M = 0` both sides are `0`,
    since `(0 : ℝ)⁻¹ = 0`. The Rust-model definitions below use this form, and the proofs
    convert it to `⁻¹` with `matrixInvAlg_eq_inv`. -/
noncomputable def matrixInvAlg {α : Type*} [Fintype α] [DecidableEq α] (M : Matrix α α ℝ) : Matrix α α ℝ :=
  (M.det)⁻¹ • M.adjugate

/-- `matrixInvAlg` agrees with Mathlib's matrix inverse over `ℝ`.
    Both sides unfold to a scalar multiple of the adjugate; `Matrix.inv_def` uses
    `Ring.inverse M.det`. The case split on `M.det = 0` hands `simp` the same lemma set in
    the singular and the nonsingular case. -/
theorem matrixInvAlg_eq_inv {α : Type*} [Fintype α] [DecidableEq α] (M : Matrix α α ℝ) :
    matrixInvAlg M = M⁻¹ := by
  -- Both branches were previously closed by the identical `simp` call; run it once on each.
  by_cases h_det : M.det = 0 <;> simp [matrixInvAlg, Matrix.inv_def, h_det]

/-- Left-inverse law `M⁻¹ * M = 1` for a real square matrix with nonzero determinant.
    A nonzero real is a unit, so Mathlib's `Matrix.nonsing_inv_mul` applies directly. -/
theorem inv_mul_self_of_det_ne_zero {α : Type*} [Fintype α] [DecidableEq α]
    (M : Matrix α α ℝ) (h_det : M.det ≠ 0) : M⁻¹ * M = 1 :=
  M.nonsing_inv_mul (isUnit_iff_ne_zero.mpr h_det)

/-- Explicit LAML objective at a given `β`:
    `L_pen(ρ, β) + ½ log det H(ρ, β) - ½ log det S_λ(ρ)`, with `H = Hessian_fn`.

    Mathlib's `Real.log` is total (`Real.log 0 = 0` and `Real.log x = Real.log |x|`), so the
    log-determinant terms have their usual meaning only when the determinants are positive.
    NOTE(review): a singular `S_λ` (a penalty with a null space) makes the last term `0`
    rather than a pseudo-determinant — confirm this matches the Rust implementation.
    The verification theorems unfold this definition by `rfl`/`change`, so keep the term
    shape. -/
noncomputable def LAML_explicit (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ) (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (X : Matrix (Fin n) (Fin p) ℝ) (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ) (rho : Fin k → ℝ) (beta : Matrix (Fin p) (Fin 1) ℝ) : ℝ :=
  let H := Hessian_fn S_basis X W rho beta
  L_pen_fn log_lik S_basis rho beta + 0.5 * Real.log (H.det) - 0.5 * Real.log ((S_lambda_fn S_basis rho).det)

/-- LAML objective as a function of `ρ` alone: `LAML_explicit` evaluated at the inner
    solution `beta_hat rho`, so `β` moves with `ρ`. -/
noncomputable def LAML_fn (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ) (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (X : Matrix (Fin n) (Fin p) ℝ) (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ) (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ) (rho : Fin k → ℝ) : ℝ :=
  LAML_explicit log_lik S_basis X W rho (beta_hat rho)

/-- LAML objective as a function of `ρ` with `β` frozen at `b`. Differentiating it in `ρ`
    gives the "direct" part of the gradient, with no implicit `dβ/dρ` contribution. -/
noncomputable def LAML_fixed_beta_fn (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ) (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (X : Matrix (Fin n) (Fin p) ℝ) (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ) (b : Matrix (Fin p) (Fin 1) ℝ) (rho : Fin k → ℝ) : ℝ :=
  LAML_explicit log_lik S_basis X W rho b

-- 2. Rust Code Components
/-- Model of the Rust implicit-derivative vector for coordinate `i`:
    `δᵢ = -H⁻¹ * ((exp(ρᵢ) • Sᵢ) * β̂)`, where `β̂ = beta_hat rho` and `H` is the Hessian
    evaluated at `β̂`. The inverse is the adjugate form `matrixInvAlg`; `rust_delta_correctness`
    restates the result with Mathlib's `⁻¹`. -/
noncomputable def rust_delta_fn (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (X : Matrix (Fin n) (Fin p) ℝ) (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ) (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ) (rho : Fin k → ℝ) (i : Fin k) : Matrix (Fin p) (Fin 1) ℝ :=
  let b := beta_hat rho
  let H := Hessian_fn S_basis X W rho b
  let H_inv := matrixInvAlg H
  let lambda := Real.exp (rho i)
  -- d S_λ / d ρᵢ = λᵢ • Sᵢ
  let dS := lambda • S_basis i
  (-H_inv) * (dS * b)

/-- Model of the Rust implicit-correction term for coordinate `i`:
    `tr((∇_β V(β̂))ᵀ * δᵢ)`, with `V(β) = ½ log det H(ρ, β)` and `δᵢ = rust_delta_fn`.
    The gradient comes from the caller-supplied operator `grad_op`. Both factors are
    `p × 1`, so the product is `1 × 1` and the trace is the inner product `⟨∇_β V, δᵢ⟩`. -/
noncomputable def rust_correction_fn (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (X : Matrix (Fin n) (Fin p) ℝ) (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ) (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ) (grad_op : (Matrix (Fin p) (Fin 1) ℝ → ℝ) → Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin p) (Fin 1) ℝ) (rho : Fin k → ℝ) (i : Fin k) : ℝ :=
  let b := beta_hat rho
  let delta := rust_delta_fn S_basis X W beta_hat rho i
  -- Despite its name, `dV_dbeta` is the function `V` itself; `grad_op` differentiates it.
  let dV_dbeta := (fun b_val => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W rho b_val)))
  ((grad_op dV_dbeta b).transpose * delta).trace

/-- Model of the Rust direct (fixed-`β`) gradient for coordinate `i`, with `λᵢ = exp(ρᵢ)`:
    `½ λᵢ β̂ᵀ Sᵢ β̂ + ½ λᵢ tr(H⁻¹ Sᵢ) - ½ λᵢ tr(S_λ⁻¹ Sᵢ)`.
    Both inverses use `matrixInvAlg`. The `log_lik` argument does not appear in the formula
    (presumably it is kept for signature parity with `LAML_fn` — confirm before removing). -/
noncomputable def rust_direct_gradient_fn (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (X : Matrix (Fin n) (Fin p) ℝ) (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ) (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ) (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ) (rho : Fin k → ℝ) (i : Fin k) : ℝ :=
  let b := beta_hat rho
  let H := Hessian_fn S_basis X W rho b
  let H_inv := matrixInvAlg H
  let S := S_lambda_fn S_basis rho
  let S_inv := matrixInvAlg S
  let lambda := Real.exp (rho i)
  let Si := S_basis i
  -- penalty term + d(½ log det H) - d(½ log det S_λ), each at fixed β
  0.5 * lambda * (b.transpose * Si * b).trace +
  0.5 * lambda * (H_inv * Si).trace -
  0.5 * lambda * (S_inv * Si).trace

-- 3. Verification Theorem

/-- `HasGradientAt f g x`: `f` has Fréchet derivative `L` at `x`, and `L` is represented by
    `g` through the trace pairing `L h = tr(gᵀ h)`. For `p × 1` columns this pairing is the
    ordinary dot product `gᵀ h`. -/
def HasGradientAt (f : Matrix (Fin p) (Fin 1) ℝ → ℝ) (g : Matrix (Fin p) (Fin 1) ℝ) (x : Matrix (Fin p) (Fin 1) ℝ) :=
  ∃ (L : Matrix (Fin p) (Fin 1) ℝ →L[ℝ] ℝ),
    (∀ h, L h = (g.transpose * h).trace) ∧ HasFDerivAt f L x

/-- The point `rho` with its `i`-th coordinate replaced by `r`; this is the one-dimensional
    slice along which the coordinate derivative in `ρᵢ` is taken. -/
noncomputable def laml_u (rho : Fin k → ℝ) (i : Fin k) (r : ℝ) := Function.update rho i r

/-- First LAML component along coordinate `i`: the penalized term `L_pen`, with `β` following
    `beta_hat` as `ρᵢ` varies. -/
noncomputable def laml_L1 (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ) (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ) (rho : Fin k → ℝ) (i : Fin k) (r : ℝ) : ℝ :=
  L_pen_fn log_lik S_basis (laml_u rho i r) (beta_hat (laml_u rho i r))

/-- Second LAML component along coordinate `i`: `½ log det H`, with `β` following `beta_hat`
    as `ρᵢ` varies. -/
noncomputable def laml_L2 (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (X : Matrix (Fin n) (Fin p) ℝ) (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ) (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ) (rho : Fin k → ℝ) (i : Fin k) (r : ℝ) : ℝ :=
  0.5 * Real.log ((Hessian_fn S_basis X W (laml_u rho i r) (beta_hat (laml_u rho i r))).det)

/-- Third LAML component along coordinate `i`: `½ log det S_λ`. It is independent of `β` and
    is subtracted in the objective. -/
noncomputable def laml_L3 (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ) (rho : Fin k → ℝ) (i : Fin k) (r : ℝ) : ℝ :=
  0.5 * Real.log ((S_lambda_fn S_basis (laml_u rho i r)).det)

/-- Compositional verification of the LAML gradient assembly.

    Along coordinate `i` the LAML objective is, by definition, `L1 + L2 - L3`, where the
    components are `laml_L1`, `laml_L2` and `laml_L3`. Given each component's derivative at
    `rho i` (`h_deriv_L1`, `h_deriv_L2`, `h_deriv_L3`) and differentiability there, the
    sum/difference rules for `deriv` show that the total derivative equals
    `rust_direct_gradient_fn + rust_correction_fn`.

    The component derivatives are hypotheses, not proved here. `h_deriv_L1` is the *total*
    derivative of `laml_L1`, which also moves `β` through `beta_hat`. Its stated value has no
    `dβ̂` term, which presumably relies on `∇_β L_pen = 0` at `β̂`; establishing that is the
    caller's obligation. -/
theorem laml_gradient_composition_verification
    (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ)
    (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ)
    (X : Matrix (Fin n) (Fin p) ℝ)
    (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ)
    (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ)
    (grad_op : (Matrix (Fin p) (Fin 1) ℝ → ℝ) → Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin p) (Fin 1) ℝ)
    (rho : Fin k → ℝ) (i : Fin k)
    (h_deriv_L1 : deriv (laml_L1 log_lik S_basis beta_hat rho i) (rho i) =
      0.5 * Real.exp (rho i) * trace ((beta_hat rho).transpose * (S_basis i) * (beta_hat rho)))
    (h_deriv_L2 : deriv (laml_L2 S_basis X W beta_hat rho i) (rho i) =
      0.5 * Real.exp (rho i) * trace ((Hessian_fn S_basis X W rho (beta_hat rho))⁻¹ * (S_basis i)) +
      trace ((grad_op (fun b_val => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W rho b_val))) (beta_hat rho)).transpose * rust_delta_fn S_basis X W beta_hat rho i))
    (h_deriv_L3 : deriv (laml_L3 S_basis rho i) (rho i) =
      0.5 * Real.exp (rho i) * trace ((S_lambda_fn S_basis rho)⁻¹ * (S_basis i)))
    (h_diff_L1 : DifferentiableAt ℝ (laml_L1 log_lik S_basis beta_hat rho i) (rho i))
    (h_diff_L2 : DifferentiableAt ℝ (laml_L2 S_basis X W beta_hat rho i) (rho i))
    (h_diff_L3 : DifferentiableAt ℝ (laml_L3 S_basis rho i) (rho i)) :
    deriv (fun r => LAML_fn log_lik S_basis X W beta_hat (laml_u rho i r)) (rho i) =
      rust_direct_gradient_fn S_basis X W beta_hat log_lik rho i +
      rust_correction_fn S_basis X W beta_hat grad_op rho i := by
  let L1 := laml_L1 log_lik S_basis beta_hat rho i
  let L2 := laml_L2 S_basis X W beta_hat rho i
  let L3 := laml_L3 S_basis rho i

  have h_diff_L1' : DifferentiableAt ℝ L1 (rho i) := h_diff_L1
  have h_diff_L2' : DifferentiableAt ℝ L2 (rho i) := h_diff_L2
  have h_diff_L3' : DifferentiableAt ℝ L3 (rho i) := h_diff_L3

  -- `LAML_fn` unfolds definitionally to `L1 + L2 - L3`.
  have h_split : ∀ r, LAML_fn log_lik S_basis X W beta_hat (laml_u rho i r) = L1 r + L2 r - L3 r := by
    intro r
    unfold LAML_fn
    rfl

  rw [show (fun r => LAML_fn log_lik S_basis X W beta_hat (laml_u rho i r)) = fun r => L1 r + L2 r - L3 r by
    funext r
    exact h_split r]
  -- Present the function as `(L1 + L2) - L3` so `deriv_sub` applies.
  change deriv ((fun r => L1 r + L2 r) - L3) (rho i) = _

  have h_diff_sum : DifferentiableAt ℝ (fun r => L1 r + L2 r) (rho i) := by
    exact DifferentiableAt.add h_diff_L1' h_diff_L2'
  have h_deriv_sum :
      deriv (fun r => L1 r + L2 r) (rho i) = deriv L1 (rho i) + deriv L2 (rho i) := by
    exact deriv_add h_diff_L1' h_diff_L2'

  -- Difference and sum rules, then substitute the supplied component derivatives.
  rw [deriv_sub h_diff_sum h_diff_L3']
  rw [h_deriv_sum]
  rw [h_deriv_L1, h_deriv_L2, h_deriv_L3]
  -- Unfold the Rust assembly and replace `matrixInvAlg` by `⁻¹`.
  unfold rust_direct_gradient_fn rust_correction_fn
  simp [matrixInvAlg_eq_inv]
  ring_nf

/-- Fixed-`β` verification: with `β` frozen at `b` (where `h_b : b = beta_hat rho`), the
    derivative of the LAML objective with respect to `rho i` equals `rust_direct_gradient_fn`.

    The objective splits into three terms: the penalized term, `½ log det H`, and
    `-½ log det S_λ`. Their derivatives and differentiability are supplied as hypotheses, and
    the proof only applies the sum rule and rearranges. `h_b` rewrites `beta_hat rho` back
    to `b` inside the Rust expression. -/
theorem laml_fixed_beta_gradient_is_exact
    (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ)
    (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ)
    (X : Matrix (Fin n) (Fin p) ℝ)
    (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ)
    (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ)
    (rho : Fin k → ℝ) (i : Fin k)
    (b : Matrix (Fin p) (Fin 1) ℝ)
    (h_b : b = beta_hat rho)
    (h_diff_pen : DifferentiableAt ℝ (fun r => L_pen_fn log_lik S_basis (Function.update rho i r) b) (rho i))
    (h_diff_log_H : DifferentiableAt ℝ (fun r => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b))) (rho i))
    (h_diff_log_S : DifferentiableAt ℝ (fun r => -0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r)))) (rho i))
    (h_deriv_pen : deriv (fun r => L_pen_fn log_lik S_basis (Function.update rho i r) b) (rho i) =
      0.5 * Real.exp (rho i) * trace (b.transpose * S_basis i * b))
    (h_deriv_log_H : deriv (fun r => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b))) (rho i) =
      0.5 * Real.exp (rho i) * trace ((Hessian_fn S_basis X W rho b)⁻¹ * S_basis i))
    (h_deriv_log_S : deriv (fun r => -0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r)))) (rho i) =
      -0.5 * Real.exp (rho i) * trace ((S_lambda_fn S_basis rho)⁻¹ * S_basis i)) :
  deriv (fun r => LAML_fixed_beta_fn log_lik S_basis X W b (Function.update rho i r)) (rho i) =
  rust_direct_gradient_fn S_basis X W beta_hat log_lik rho i := by
  -- Unfold `LAML_fixed_beta_fn` / `LAML_explicit` (definitionally) into the three terms.
  change deriv (fun r =>
      L_pen_fn log_lik S_basis (Function.update rho i r) b +
      0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b)) -
      0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r)))) (rho i) = _
  dsimp only [rust_direct_gradient_fn]
  -- Sum rule: split off the penalized term.
  have h_add1 : deriv (fun r => L_pen_fn log_lik S_basis (Function.update rho i r) b +
      (0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b)) +
      -0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r))))) (rho i) =
    deriv (fun r => L_pen_fn log_lik S_basis (Function.update rho i r) b) (rho i) +
    deriv (fun r => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b)) +
      -0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r)))) (rho i) := by
    apply deriv_add h_diff_pen
    exact DifferentiableAt.add h_diff_log_H h_diff_log_S
  -- Sum rule for the two log-determinant terms.
  have h_add2 : deriv (fun r => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b)) +
      -0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r)))) (rho i) =
    deriv (fun r => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b))) (rho i) +
    deriv (fun r => -0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r)))) (rho i) := by
    exact deriv_add h_diff_log_H h_diff_log_S
  -- Rewrite `a + b - c` as `a + (b + -c)` so the shape matches `h_add1`.
  have h_sub_to_add : (fun r => L_pen_fn log_lik S_basis (Function.update rho i r) b +
      0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b)) -
      0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r)))) =
    (fun r => L_pen_fn log_lik S_basis (Function.update rho i r) b +
      (0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W (Function.update rho i r) b)) +
      -0.5 * Real.log (Matrix.det (S_lambda_fn S_basis (Function.update rho i r))))) := by
    ext r
    ring
  -- Substitute the supplied derivatives, replace `matrixInvAlg` by `⁻¹`, and identify
  -- `beta_hat rho` with `b`.
  rw [h_sub_to_add, h_add1, h_add2, h_deriv_pen, h_deriv_log_H, h_deriv_log_S]
  simp [matrixInvAlg_eq_inv]
  rw [← h_b]
  ring_nf

/-- Structural verification: `rust_delta_fn` implements the implicit-derivative formula.

    Suppose `∇_β L_pen(ρ, β̂(ρ)) = 0` for all `ρ`. Differentiating along `ρᵢ` then gives
    `H * dβ̂ + (λᵢ • Sᵢ) * β̂ = 0` with `λᵢ = exp(ρᵢ)`, so `dβ̂ = -H⁻¹ * ((λᵢ • Sᵢ) * β̂)`.
    This theorem shows that `rust_delta_fn` computes exactly that expression, with Mathlib's
    `⁻¹` in place of `matrixInvAlg`. It is a purely algebraic identity: the optimality
    condition is motivation only and is not a hypothesis here. -/
theorem rust_delta_correctness
    (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ)
    (X : Matrix (Fin n) (Fin p) ℝ)
    (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ)
    (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ)
    (rho : Fin k → ℝ) (i : Fin k) :
    rust_delta_fn S_basis X W beta_hat rho i =
    -(Hessian_fn S_basis X W rho (beta_hat rho))⁻¹ *
    ((Real.exp (rho i) • S_basis i) * beta_hat rho) := by
  unfold rust_delta_fn
  simp [matrixInvAlg_eq_inv, neg_mul, Matrix.smul_mul]

/-- Structural verification of the full LAML gradient (`laml_gradient_validity`).

    Shows that the total derivative of `LAML_fn` along `rho i` equals
    `rust_direct_gradient_fn + rust_correction_fn`. The result is assembled from the partial
    derivative in `ρᵢ` and the implicit derivative of `β̂`.

    Hypotheses, in signature order:
    1. `h_hess_pos`: the Hessian at `β̂(ρ)` is positive definite, hence `det H ≠ 0`.
    2. `h_implicit`: the differentiated optimality condition, stated without inversion:
       `H * dβ̂/dρᵢ = -(λᵢ • Sᵢ) * β̂`. The proof solves it using (1) and recovers
       `rust_delta_fn`.
    3. `h_partial_rho`: the partial derivative in `ρᵢ`, with `β` frozen at `β̂(ρ)`, equals
       `rust_direct_gradient_fn`.
    4. `h_grad_beta`: the `β`-gradient of LAML at `β̂(ρ)` is the `grad_op` term used in
       `rust_correction_fn`. It documents the intended meaning of `grad_op`; the proof below
       does not use it.
    5. `h_chain`: the chain rule,
       `d(LAML)/dρᵢ = ∂(LAML)/∂ρᵢ + ⟨∇_β(LAML), dβ̂/dρᵢ⟩`. -/
theorem laml_gradient_validity
    (log_lik : Matrix (Fin p) (Fin 1) ℝ → ℝ)
    (S_basis : Fin k → Matrix (Fin p) (Fin p) ℝ)
    (X : Matrix (Fin n) (Fin p) ℝ)
    (W : Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin n) (Fin n) ℝ)
    (beta_hat : (Fin k → ℝ) → Matrix (Fin p) (Fin 1) ℝ)
    (grad_op : (Matrix (Fin p) (Fin 1) ℝ → ℝ) → Matrix (Fin p) (Fin 1) ℝ → Matrix (Fin p) (Fin 1) ℝ)
    (rho : Fin k → ℝ) (i : Fin k)
    -- 1. Hessian solvability at the evaluation point
    (h_hess_pos : (Hessian_fn S_basis X W rho (beta_hat rho)).PosDef)
    -- 2. Implicit differentiation of the optimality condition, stated without inversion
    (h_implicit : Hessian_fn S_basis X W rho (beta_hat rho) *
                  deriv (fun r => beta_hat (Function.update rho i r)) (rho i) =
                  - (Real.exp (rho i) • S_basis i) * (beta_hat rho))
    -- 3. Partial derivative wrt rho matches rust_direct_gradient_fn
    (h_partial_rho : deriv (fun r => LAML_fn log_lik S_basis X W (fun _ => beta_hat rho) (Function.update rho i r)) (rho i) =
                     rust_direct_gradient_fn S_basis X W beta_hat log_lik rho i)
    -- 4. Gradient wrt beta matches the term used in rust_correction_fn
    --    Note: rust_correction_fn uses `grad_op dV_dbeta`.
    --    Optimality of beta implies grad(L_pen) = 0, and log det S_λ does not depend on beta,
    --    so grad(LAML) = grad(0.5 log det H).
    --    Documentary only: this hypothesis is not used by the proof below.
    (h_grad_beta : HasGradientAt (fun b => LAML_fn log_lik S_basis X W (fun _ => b) rho)
                                 (grad_op (fun b_val => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W rho b_val))) (beta_hat rho))
                                 (beta_hat rho))
    -- 5. Chain rule holds for the total derivative
    (h_chain : deriv (fun r => LAML_fn log_lik S_basis X W beta_hat (Function.update rho i r)) (rho i) =
               deriv (fun r => LAML_fn log_lik S_basis X W (fun _ => beta_hat rho) (Function.update rho i r)) (rho i) +
               ( (grad_op (fun b_val => 0.5 * Real.log (Matrix.det (Hessian_fn S_basis X W rho b_val))) (beta_hat rho)).transpose *
                 deriv (fun r => beta_hat (Function.update rho i r)) (rho i) ).trace) :
  deriv (fun r => LAML_fn log_lik S_basis X W beta_hat (Function.update rho i r)) (rho i) =
  rust_direct_gradient_fn S_basis X W beta_hat log_lik rho i +
  rust_correction_fn S_basis X W beta_hat grad_op rho i :=
by
  -- Positive definiteness is used only to make the Hessian invertible.
  have h_hess_det : (Hessian_fn S_basis X W rho (beta_hat rho)).det ≠ 0 := h_hess_pos.det_pos.ne'
  -- Solve `H * dβ̂ = -(λᵢ • Sᵢ) * β̂` for `dβ̂` and identify it with `rust_delta_fn`.
  have h_deriv_beta : deriv (fun r => beta_hat (Function.update rho i r)) (rho i) =
      rust_delta_fn S_basis X W beta_hat rho i := by
    let H := Hessian_fn S_basis X W rho (beta_hat rho)
    let dbeta := deriv (fun r => beta_hat (Function.update rho i r)) (rho i)
    have h_solved :
        dbeta = -H⁻¹ * ((Real.exp (rho i) • S_basis i) * (beta_hat rho)) := by
      -- Left-multiply the implicit equation by `H⁻¹`.
      have h_mul := congrArg (fun M => H⁻¹ * M) h_implicit
      have h_left : H⁻¹ * (H * dbeta) = dbeta := by
        rw [← Matrix.mul_assoc, inv_mul_self_of_det_ne_zero H h_hess_det, Matrix.one_mul]
      calc
        dbeta = H⁻¹ * (H * dbeta) := by
          symm
          exact h_left
        _ = H⁻¹ * (- (Real.exp (rho i) • S_basis i) * beta_hat rho) := by
          simpa [H] using h_mul
        _ = -H⁻¹ * ((Real.exp (rho i) • S_basis i) * beta_hat rho) := by
          simp [Matrix.mul_assoc, neg_mul]
    calc
      deriv (fun r => beta_hat (Function.update rho i r)) (rho i)
          = -H⁻¹ * ((Real.exp (rho i) • S_basis i) * beta_hat rho) := h_solved
      _ = rust_delta_fn S_basis X W beta_hat rho i := by
        symm
        simpa [H] using rust_delta_correctness S_basis X W beta_hat rho i
  -- Substitute the chain rule, the partial derivative and the implicit derivative; what remains
  -- is `rust_correction_fn` by definition.
  rw [h_chain, h_partial_rho, h_deriv_beta]
  unfold rust_correction_fn
  rfl

end GradientDescentVerification

end Gam.Solver.Reml