/- gam 0.3.13

Generalized penalized likelihood engine
Documentation -/
import Mathlib

namespace Gam.Solver.Estimate

section WoodReparameterization

/-!
### Wood's Stable Reparameterization

The PIRLS solver in `estimate.rs` uses the stable reparameterization of Wood (2011)
to avoid numerical instability. This section proves that the reparameterized
penalized objective is algebraically equivalent to the original one.
-/

-- `n`: number of rows of the design matrix `X` (the length of `y`);
-- `p`: number of coefficients (the length of `β`, and the size of `S` and `Q`).
-- NOTE(review): Mathlib already provides canonical `Fintype (Fin k)` instances, so these
-- binders add a second, local instance. They are kept because later declarations
-- (e.g. `fitted_values_invariant`) `omit` them explicitly.
variable {n p : ℕ} [Fintype (Fin n)] [Fintype (Fin p)]

/-- The quadratic form `βᵀ S β`, expanded as `∑ᵢ βᵢ · (S β)ᵢ`
    (the dot product of `β` with `S.mulVec β`). -/
noncomputable def quadForm (S : Matrix (Fin p) (Fin p) ℝ) (β : Fin p → ℝ) : ℝ :=
  ∑ i, β i * S.mulVec β i

/-- Penalized least-squares objective `‖y - Xβ‖₂² + βᵀSβ`.

    The data-fit term is written explicitly as the residual sum of squares
    `∑ᵢ (y - Xβ)ᵢ²`. Mathlib equips the function type `Fin n → ℝ` with the
    *sup* norm, so `‖y - X.mulVec β‖ ^ 2` would be the largest squared residual
    rather than the Euclidean (least-squares) quantity this objective models. -/
noncomputable def penalized_objective
    (X : Matrix (Fin n) (Fin p) ℝ) (y : Fin n → ℝ)
    (S : Matrix (Fin p) (Fin p) ℝ) (β : Fin p → ℝ) : ℝ :=
  (∑ i, (y - X.mulVec β) i ^ 2) + quadForm S β

/-- `Q` is orthogonal when it is a two-sided inverse of its transpose.
    The first component (`.1`) is `Q * Qᵀ = 1`; the second (`.2`) is `Qᵀ * Q = 1`.
    For square matrices either equation implies the other, but both are stored
    so that proofs can use whichever one they need directly. -/
def IsOrthogonal (Q : Matrix (Fin p) (Fin p) ℝ) : Prop :=
  Q * Q.transpose = 1 ∧ Q.transpose * Q = 1

/-- Transpose-dot identity: `∑ᵢ (A u)ᵢ · vᵢ = ∑ᵢ uᵢ · (Aᵀ v)ᵢ`, i.e. `(Au) ⬝ v = u ⬝ (Aᵀv)`.
    This is the key algebraic identity for bilinear form transformations. It holds
    for every square `A`; no orthogonality or symmetry is assumed.

    Proof strategy: expand both sides into double sums over `Fin p × Fin p`,
    then match summands via the index swap `(i, j) ↦ (j, i)`. -/
lemma sum_mulVec_mul_eq_sum_mul_transpose_mulVec
    (A : Matrix (Fin p) (Fin p) ℝ) (u v : Fin p → ℝ) :
    ∑ i, (A.mulVec u) i * v i = ∑ i, u i * ((Matrix.transpose A).mulVec v) i := by
  -- Unfold mulVec and dotProduct to get explicit sums
  simp only [Matrix.mulVec, dotProduct, Matrix.transpose_apply]
  -- LHS: ∑ i, (∑ j, A i j * u j) * v i
  -- RHS: ∑ i, u i * (∑ j, A j i * v j)
  -- Distribute the outer multiplication into the inner sums
  simp only [Finset.sum_mul, Finset.mul_sum]
  -- LHS: ∑ i, ∑ j, A i j * u j * v i
  -- RHS: ∑ i, ∑ j, u i * A j i * v j
  -- Convert to sums over Fin p × Fin p using sum_product'
  simp only [← Finset.sum_product']
  -- Now both sides are sums over univ ×ˢ univ
  -- Use Finset.sum_equiv with Equiv.prodComm to swap indices
  refine Finset.sum_equiv (Equiv.prodComm (Fin p) (Fin p)) ?_ ?_
  -- Goal 1: the swap maps `univ ×ˢ univ` onto itself (membership is trivial).
  · intro _; simp
  -- Goal 2: summands agree once `(i, j)` is swapped to `(j, i)`; `ring` reorders factors.
  · intro ⟨i, j⟩ _
    simp only [Equiv.prodComm_apply, Prod.swap_prod_mk]
    ring

/-- The penalty transforms as a congruence under reparameterization:
    `quadForm S (Q β') = quadForm (Qᵀ S Q) β'`.

    **Proof**: (Qβ')ᵀ S (Qβ') = β'ᵀ Qᵀ S Q β' = β'ᵀ (QᵀSQ) β'
    Only associativity of matrix-vector multiplication and the transpose-dot
    identity are used. The orthogonality hypothesis `_h_orth` is accepted but
    never used by the proof, so the identity in fact holds for every square `Q`.

    This is a key step in Wood's (2011) stable reparameterization for GAMs,
    as it shows how the penalty matrix S transforms under an orthogonal change
    of basis Q. By choosing the columns of Q to be orthonormal eigenvectors of S,
    the transformed penalty matrix QᵀSQ becomes diagonal, simplifying the
    optimization problem. -/
theorem penalty_congruence
    (S : Matrix (Fin p) (Fin p) ℝ) (Q : Matrix (Fin p) (Fin p) ℝ)
    (β' : Fin p → ℝ) (_h_orth : IsOrthogonal Q) :
    quadForm S (Q.mulVec β') = quadForm (Matrix.transpose Q * S * Q) β' := by
  -- quadForm S (Qβ') = Σᵢ (Qβ')ᵢ * (S(Qβ'))ᵢ = (Qβ')ᵀ S (Qβ')
  -- = β'ᵀ Qᵀ S Q β' = β'ᵀ (QᵀSQ) β' = quadForm (QᵀSQ) β'
  unfold quadForm
  -- LHS: Σᵢ (Q.mulVec β') i * (S.mulVec (Q.mulVec β')) i
  -- RHS: Σᵢ β' i * ((QᵀSQ).mulVec β') i

  -- Step 1: Simplify RHS using mulVec_mulVec
  -- (re-associate QᵀSQ to Qᵀ(S Q), then split into nested matrix-vector products)
  have h_rhs : (Matrix.transpose Q * S * Q).mulVec β' =
               (Matrix.transpose Q).mulVec (S.mulVec (Q.mulVec β')) := by
    simp only [Matrix.mul_assoc, Matrix.mulVec_mulVec]

  rw [h_rhs]
  -- Now need: Σᵢ (Qβ')ᵢ * (S(Qβ'))ᵢ = Σᵢ β'ᵢ * (Qᵀ(S(Qβ')))ᵢ

  -- Step 2: Apply transpose-dot identity
  -- Let w = Q.mulVec β' and u = S.mulVec w
  -- LHS = Σᵢ w i * u i
  -- RHS = Σᵢ β' i * (Qᵀ.mulVec u) i
  -- By sum_mulVec_mul_eq_sum_mul_transpose_mulVec with A = Q:
  --   Σᵢ (Q.mulVec β') i * u i = Σᵢ β' i * (Qᵀ.mulVec u) i
  exact sum_mulVec_mul_eq_sum_mul_transpose_mulVec Q β' (S.mulVec (Q.mulVec β'))

/-- **Reparameterization equivalence.** Substituting `β = Qβ'` into the penalized
    objective gives the same value as evaluating the objective with design `XQ`
    and penalty `QᵀSQ` at `β'`. This validates `stable_reparameterization` in
    estimate.rs.

    Proof outline:
    1. Residual: `X(Qβ') = (XQ)β'` (`Matrix.mulVec_mulVec`), so the data-fit
       terms coincide.
    2. Penalty: `(Qβ')ᵀS(Qβ') = β'ᵀ(QᵀSQ)β'` (`penalty_congruence`).

    Hence minimizing over `β = Qβ'` is the same problem as minimizing over `β'`
    with the transformed design and penalty. -/
theorem reparameterization_equivalence
    (X : Matrix (Fin n) (Fin p) ℝ) (y : Fin n → ℝ)
    (S : Matrix (Fin p) (Fin p) ℝ) (Q : Matrix (Fin p) (Fin p) ℝ)
    (β' : Fin p → ℝ) (h_orth : IsOrthogonal Q) :
    penalized_objective X y S (Q.mulVec β') =
    penalized_objective (X * Q) y (Matrix.transpose Q * S * Q) β' := by
  unfold penalized_objective
  -- Fold `X.mulVec (Q.mulVec β')` into `(X * Q).mulVec β'` in the residual,
  -- then transport the penalty term; both sides become syntactically equal.
  rw [Matrix.mulVec_mulVec, penalty_congruence S Q β' h_orth]

omit [Fintype (Fin n)] in
/-- Fitted values are unchanged by the reparameterization: whenever `β = Qβ'`,
    `Xβ = (XQ)β'`. The proof only needs associativity of matrix-vector
    products; the orthogonality hypothesis is not used. -/
theorem fitted_values_invariant
    (X : Matrix (Fin n) (Fin p) ℝ) (Q : Matrix (Fin p) (Fin p) ℝ)
    (β : Fin p → ℝ) (_h_orth : IsOrthogonal Q)
    (β' : Fin p → ℝ) (h_relation : β = Q.mulVec β') :
    X.mulVec β = (X * Q).mulVec β' := by
  -- Substitute β = Qβ', then collapse X(Qβ') into (XQ)β'.
  rw [h_relation, Matrix.mulVec_mulVec]

/-- Eigenvalue structure is preserved: if S = QΛQᵀ with Q orthogonal, then QᵀSQ = Λ.
    This is the key insight that makes the reparameterization numerically stable.

    **Proof**: QᵀSQ = Qᵀ(QΛQᵀ)Q = (QᵀQ)Λ(QᵀQ) = IΛI = Λ, using only the
    `QᵀQ = 1` half of `IsOrthogonal` (`h_orth.2`).

    Note: the diagonality hypothesis `_h_diag` is not used by the proof, so the
    identity holds for any `Λ`; presumably it is kept to document the intended
    use (Λ diagonal, i.e. an eigendecomposition of S). -/
theorem eigendecomposition_diagonalizes
    (S : Matrix (Fin p) (Fin p) ℝ) (Q : Matrix (Fin p) (Fin p) ℝ)
    (Λ : Matrix (Fin p) (Fin p) ℝ)
    (h_orth : IsOrthogonal Q)
    (h_decomp : S = Q * Λ * Matrix.transpose Q)
    (_h_diag : ∀ i j : Fin p, i ≠ j → Λ i j = 0) :
    Matrix.transpose Q * S * Q = Λ := by
  rw [h_decomp]
  -- Qᵀ(QΛQᵀ)Q = (QᵀQ)Λ(QᵀQ) = IΛI = Λ
  have h_assoc : Matrix.transpose Q * (Q * Λ * Matrix.transpose Q) * Q
                = Matrix.transpose Q * Q * Λ * (Matrix.transpose Q * Q) := by
    -- Use associativity of matrix multiplication
    -- (both sides normalize to the same right-associated product)
    simp only [Matrix.mul_assoc]
  rw [h_assoc]
  -- By orthogonality: QᵀQ = I (rewrites both occurrences)
  rw [h_orth.2]
  simp only [Matrix.one_mul, Matrix.mul_one]

/-- The optimal β under the reparameterized system transforms back correctly:
    the original minimizer and the reparameterized minimizer give the same
    fitted values, `X β_opt = (XQ) β'_opt`.

    Hypotheses:
    * `h_opt`: `β_opt` minimizes the original objective (not necessarily uniquely).
    * `h_opt'_unique`: for every `β'`, `g β' ≤ g β'_opt ↔ β' = β'_opt`, where `g`
      is the reparameterized objective. Equivalently, `β'_opt` is the unique
      minimizer of `g`: every other point has a strictly larger value.

    Proof idea: `Qᵀ β_opt` also minimizes `g` (via `reparameterization_equivalence`
    and `QQᵀ = 1`). By uniqueness, `Qᵀ β_opt = β'_opt`, and then
    `(XQ)(Qᵀ β_opt) = X(QQᵀ) β_opt = X β_opt`. -/
theorem optimal_solution_transforms
    (X : Matrix (Fin n) (Fin p) ℝ) (y : Fin n → ℝ)
    (S : Matrix (Fin p) (Fin p) ℝ) (Q : Matrix (Fin p) (Fin p) ℝ)
    (h_orth : IsOrthogonal Q) (β_opt : Fin p → ℝ) (β'_opt : Fin p → ℝ)
    (h_opt : ∀ β, penalized_objective X y S β_opt ≤ penalized_objective X y S β)
    (h_opt'_unique :
      ∀ β',
        penalized_objective (X * Q) y (Matrix.transpose Q * S * Q) β' ≤
            penalized_objective (X * Q) y (Matrix.transpose Q * S * Q) β'_opt ↔
          β' = β'_opt) :
    X.mulVec β_opt = (X * Q).mulVec β'_opt := by
  -- Let `g` be the reparameterized objective function
  let g := penalized_objective (X * Q) y (Matrix.transpose Q * S * Q)
  -- Let `β'_test` be the transformed original optimal solution
  let β'_test := (Matrix.transpose Q).mulVec β_opt
  -- We show that `β'_test` is a minimizer for `g`. `h_opt` shows `β_opt` minimizes the original objective `f`.
  -- By `reparameterization_equivalence`, `f(Qβ') = g(β')`.
  -- So `g(β'_test) = f(Qβ'_test) = f(β_opt)`. For any other `β'`, `g(β') = f(Qβ')`.
  -- Since `f(β_opt) ≤ f(Qβ')`, we have `g(β'_test) ≤ g(β')`.
  have h_test_is_opt : ∀ β', g β'_test ≤ g β' := by
    intro β'
    let f := penalized_objective X y S
    have h_g_eq_f : ∀ b, g b = f (Q.mulVec b) :=
      fun b => (reparameterization_equivalence X y S Q b h_orth).symm
    rw [h_g_eq_f, h_g_eq_f]
    -- Q(Qᵀ β_opt) = (QQᵀ) β_opt = β_opt, using the `QQᵀ = 1` half of orthogonality.
    have h_simplify : Q.mulVec β'_test = β_opt := by
      simp only [β'_test, Matrix.mulVec_mulVec, h_orth.1, Matrix.one_mulVec]
    rw [h_simplify]
    exact h_opt (Q.mulVec β')
  -- From `h_test_is_opt`, `g(β'_test) ≤ g(β'_opt)`. By uniqueness `h_opt'_unique`, this implies `β'_test = β'_opt`.
  have h_beta_eq : β'_test = β'_opt := (h_opt'_unique β'_test).mp (h_test_is_opt β'_opt)
  -- The final goal `X.mulVec β_opt = (X * Q).mulVec β'_opt` follows by substituting this equality.
  -- (XQ)(Qᵀ β_opt) = X(QQᵀ) β_opt = X β_opt.
  rw [← h_beta_eq]
  simp only [β'_test, Matrix.mulVec_mulVec, Matrix.mul_assoc, h_orth.1, Matrix.mul_one]

end WoodReparameterization

end Gam.Solver.Estimate