gam 0.3.17

Generalized penalized likelihood engine
Documentation
import Mathlib

namespace Gam.Linalg.Matrix

section WeightedOrthogonality

/-!
### Weighted Orthogonality Constraints

The calibration code applies sum-to-zero and polynomial orthogonality constraints
via nullspace projection. These theorems formalize that the projection is correct.
-/

set_option linter.unusedSectionVars false

variable {n m k : ℕ} [Fintype (Fin n)] [Fintype (Fin m)] [Fintype (Fin k)]
variable [DecidableEq (Fin n)] [DecidableEq (Fin m)] [DecidableEq (Fin k)]

/-- A diagonal weight matrix constructed from a weight vector. -/
def diagonalWeight (w : Fin n → ℝ) : Matrix (Fin n) (Fin n) ℝ :=
  Matrix.diagonal w

/-- Two column spaces are weighted-orthogonal if their weighted inner product is zero.
    Uses explicit transpose to avoid parsing issues. -/
def IsWeightedOrthogonal (A : Matrix (Fin n) (Fin m) ℝ)
    (B : Matrix (Fin n) (Fin k) ℝ) (W : Matrix (Fin n) (Fin n) ℝ) : Prop :=
  Matrix.transpose A * W * B = 0

/-- A matrix Z spans the nullspace of M if MZ = 0 and Z has maximal rank. -/
def SpansNullspace (Z : Matrix (Fin m) (Fin (m - k)) ℝ)
    (M : Matrix (Fin k) (Fin m) ℝ) : Prop :=
  M * Z = 0 ∧ Matrix.rank Z = m - k

/-- **Constraint Projection Correctness**: If Z spans the nullspace of BᵀWC,
    then B' = BZ is weighted-orthogonal to C.
    This validates `apply_weighted_orthogonality_constraint` in basis.rs.

    **Proof**:
    (BZ)ᵀ W C = Zᵀ (Bᵀ W C) = 0 because Z is in the nullspace of (Bᵀ W C)ᵀ.

    More precisely:
    - SpansNullspace Z M means M * Z = 0
    - Here M = (Bᵀ W C)ᵀ = Cᵀ Wᵀ B = Cᵀ W B (if W is symmetric, which diagonal matrices are)
    - We want: (BZ)ᵀ W C = Zᵀ Bᵀ W C
    - By associativity: Zᵀ Bᵀ W C = (Bᵀ W C)ᵀ · Z = M · Z = 0 (by h_spans.1)

    Wait, transpose swap: (Zᵀ (Bᵀ W C))ᵀ = (Bᵀ W C)ᵀ Z
    Actually: Zᵀ · (Bᵀ W C) has shape (m-k) × k, while M · Z = 0 where M = (Bᵀ W C)ᵀ

    The key relation is: Zᵀ · A = (Aᵀ · Z)ᵀ, so if Aᵀ · Z = 0, then Zᵀ · A = 0. -/
theorem constraint_projection_correctness
    (B : Matrix (Fin n) (Fin m) ℝ)
    (C : Matrix (Fin n) (Fin k) ℝ)
    (W : Matrix (Fin n) (Fin n) ℝ)
    (Z : Matrix (Fin m) (Fin (m - k)) ℝ)
    (h_spans : SpansNullspace Z (Matrix.transpose (Matrix.transpose B * W * C))) :
    IsWeightedOrthogonal (B * Z) C W := by
  unfold IsWeightedOrthogonal
  -- Goal: Matrix.transpose (B * Z) * W * C = 0
  -- Expand: (BZ)ᵀ W C = Zᵀ Bᵀ W C
  have h1 : Matrix.transpose (B * Z) = Matrix.transpose Z * Matrix.transpose B := by
    exact Matrix.transpose_mul B Z
  rw [h1]
  -- Now: Zᵀ Bᵀ W C
  -- We need to show: Zᵀ * (Bᵀ W C) = 0
  -- From h_spans: (Bᵀ W C)ᵀ * Z = 0
  -- Taking transpose: Zᵀ * (Bᵀ W C) = ((Bᵀ W C)ᵀ * Z)ᵀ
  -- If (Bᵀ W C)ᵀ * Z = 0, then Zᵀ * (Bᵀ W C) = 0ᵀ = 0
  have h2 : Matrix.transpose Z * Matrix.transpose B * W * C =
            Matrix.transpose Z * (Matrix.transpose B * W * C) := by
    simp only [Matrix.mul_assoc]
  rw [h2]
  -- Now use the nullspace condition
  have h3 : Matrix.transpose (Matrix.transpose B * W * C) * Z = 0 := h_spans.1
  -- Taking transpose of both sides: Zᵀ * (Bᵀ W C) = 0
  have h4 : Matrix.transpose Z * (Matrix.transpose B * W * C) =
            Matrix.transpose (Matrix.transpose (Matrix.transpose B * W * C) * Z) := by
    rw [Matrix.transpose_mul]
    simp only [Matrix.transpose_transpose]
  rw [h4, h3]
  simp only [Matrix.transpose_zero]

/-- The constrained basis preserves the column space spanned by valid coefficients. -/
theorem constrained_basis_spans_subspace
    (B : Matrix (Fin n) (Fin m) ℝ)
    (Z : Matrix (Fin m) (Fin (m - k)) ℝ)
    (β : Fin (m - k) → ℝ) :
    ∃ (β' : Fin m → ℝ), (B * Z).mulVec β = B.mulVec β' := by
  use Z.mulVec β
  rw [Matrix.mulVec_mulVec]

/-- Sum-to-zero constraint: the constraint matrix C is a column of ones. -/
def sumToZeroConstraint (n : ℕ) : Matrix (Fin n) (Fin 1) ℝ :=
  fun _ _ => 1

/-- After applying sum-to-zero constraint, basis evaluations sum to zero at data points.
    Note: This theorem uses a specialized constraint for k=1. -/
theorem sum_to_zero_after_projection
    (B : Matrix (Fin n) (Fin m) ℝ)
    (W : Matrix (Fin n) (Fin n) ℝ) (hW_diag : W = Matrix.diagonal (fun i => W i i))
    (Z : Matrix (Fin m) (Fin (m - 1)) ℝ)
    (h_constraint : SpansNullspace Z (Matrix.transpose (Matrix.transpose B * W * sumToZeroConstraint n)))
    (β : Fin (m - 1) → ℝ) :
    Finset.univ.sum (fun i : Fin n => ((B * Z).mulVec β) i * W i i) = 0 := by
  -- Use constraint_projection_correctness to get weighted orthogonality
  have h_orth : IsWeightedOrthogonal (B * Z) (sumToZeroConstraint n) W :=
    constraint_projection_correctness B (sumToZeroConstraint n) W Z h_constraint
  -- IsWeightedOrthogonal (B * Z) C W means: (BZ)ᵀ * W * C = 0
  -- For C = sumToZeroConstraint n (all ones), the (i,0) entry of (BZ)ᵀ * W * C is:
  --   Σⱼ ((BZ)ᵀ * W)_{i,j} * C_{j,0} = Σⱼ ((BZ)ᵀ * W)_{i,j} * 1 = Σⱼ ((BZ)ᵀ * W)_{i,j}
  -- When we sum over the "first column" being all zeros, we get the constraint.
  -- More directly: the (0,0) entry of Cᵀ * (BZ)ᵀ * W * C = 0
  -- which expands to: Σᵢ Σⱼ C_{i,0} * ((BZ)ᵀ * W)_{j,i} * C_{j,0}
  --                 = Σᵢ Σⱼ 1 * ((BZ)ᵀ * W)_{j,i} * 1
  -- For a diagonal W, ((BZ)ᵀ * W)_{j,i} = (BZ)_{i,j} * W_{i,i}
  --
  -- Actually the goal is: Σᵢ (BZ · β)ᵢ * Wᵢᵢ = 0
  -- This is related to the weighted orthogonality by:
  --   (sumToZeroConstraint n)ᵀ * diag(W) * (BZ · β)
  -- where we interpret W as having diagonal form.
  --
  -- The proof uses that (BZ)ᵀ * W * C = 0 implies the weighted inner product
  -- of any column of BZ with the ones vector is zero.
  unfold IsWeightedOrthogonal at h_orth
  -- h_orth : Matrix.transpose (B * Z) * W * sumToZeroConstraint n = 0
  -- For any column j of (BZ), we have: Σᵢ (BZ)ᵢⱼ * (W * 1)ᵢ = 0
  -- where 1 is the all-ones vector.
  -- The goal is: Σᵢ (Σⱼ (BZ)ᵢⱼ * βⱼ) * Wᵢᵢ = 0
  --            = Σⱼ βⱼ * (Σᵢ (BZ)ᵢⱼ * Wᵢᵢ)
  -- Each inner sum Σᵢ (BZ)ᵢⱼ * Wᵢᵢ corresponds to a column of (BZ)ᵀ * W * 1
  -- Since (BZ)ᵀ * W * C = 0 where C is all ones, each entry is 0.
  -- Therefore the entire sum is 0.

  -- Step 1: Expand mulVec and rewrite the goal as a double sum
  simp only [Matrix.mulVec, dotProduct]
  -- Goal: Σᵢ (Σⱼ (B*Z)ᵢⱼ * βⱼ) * Wᵢᵢ = 0

  -- Step 2: Use diagonal form of W to simplify
  rw [hW_diag]
  simp

  -- Step 3: Swap the order of summation
  -- Σᵢ (Σⱼ aᵢⱼ * βⱼ) * wᵢ = Σⱼ βⱼ * (Σᵢ aᵢⱼ * wᵢ)
  classical
  have h_swap :
      ∑ x, (∑ x_1, (B * Z) x x_1 * β x_1) * W x x
        = ∑ x, ∑ x_1, (B * Z) x x_1 * β x_1 * W x x := by
    refine Finset.sum_congr rfl ?_
    intro x _
    calc
      (∑ x_1, (B * Z) x x_1 * β x_1) * W x x
          = W x x * ∑ x_1, (B * Z) x x_1 * β x_1 := by ring
      _ = ∑ x_1, W x x * ((B * Z) x x_1 * β x_1) := by
          simpa [Finset.mul_sum]
      _ = ∑ x_1, (B * Z) x x_1 * β x_1 * W x x := by
          refine Finset.sum_congr rfl ?_
          intro x_1 _
          ring
  rw [h_swap]
  rw [Finset.sum_comm]

  -- After swap: Σⱼ Σᵢ (B*Z)ᵢⱼ * βⱼ * Wᵢᵢ = Σⱼ βⱼ * (Σᵢ (B*Z)ᵢⱼ * Wᵢᵢ)
  have h_factor :
      ∀ y, ∑ x, (B * Z) x y * β y * W x x = β y * ∑ x, (B * Z) x y * W x x := by
    intro y
    calc
      ∑ x, (B * Z) x y * β y * W x x
          = ∑ x, β y * ((B * Z) x y * W x x) := by
              refine Finset.sum_congr rfl ?_
              intro x _
              ring
      _ = β y * ∑ x, (B * Z) x y * W x x := by
              simpa [Finset.mul_sum]
  simp [h_factor]
  -- Now: Σⱼ βⱼ * (Σᵢ (B*Z)ᵢⱼ * Wᵢᵢ)

  -- Step 4: Show each inner sum Σᵢ (B*Z)ᵢⱼ * Wᵢᵢ = 0 using h_orth
  -- The (j, 0) entry of (BZ)ᵀ * W * C is: Σᵢ (BZ)ᵀⱼᵢ * (W * C)ᵢ₀
  --                                      = Σᵢ (BZ)ᵢⱼ * (Σₖ Wᵢₖ * Cₖ₀)
  -- For diagonal W and C = all ones:    = Σᵢ (BZ)ᵢⱼ * Wᵢᵢ * 1
  --                                      = Σᵢ (BZ)ᵢⱼ * Wᵢᵢ
  -- Since h_orth says the whole matrix is 0, entry (j, 0) = 0.

  apply Finset.sum_eq_zero
  intro j _
  -- Show βⱼ * (Σᵢ (B*Z)ᵢⱼ * Wᵢᵢ) = 0
  -- Suffices to show Σᵢ (B*Z)ᵢⱼ * Wᵢᵢ = 0
  suffices h_inner : Finset.univ.sum (fun i => (B * Z) i j * W i i) = 0 by
    simp [h_inner]

  -- Extract from h_orth: the (j, 0) entry of (BZ)ᵀ * W * C = 0
  have h_entry : (Matrix.transpose (B * Z) * W * sumToZeroConstraint n) j 0 = 0 := by
    rw [h_orth]
    rfl

  -- Expand this entry
  simp only [Matrix.mul_apply, Matrix.transpose_apply, sumToZeroConstraint] at h_entry
  -- (BZ)ᵀ * W * C at (j, 0) = Σₖ ((BZ)ᵀ * W)ⱼₖ * Cₖ₀ = Σₖ ((BZ)ᵀ * W)ⱼₖ * 1
  -- = Σₖ (Σᵢ (BZ)ᵀⱼᵢ * Wᵢₖ) = Σₖ (Σᵢ (BZ)ᵢⱼ * Wᵢₖ)

  -- For diagonal W, Wᵢₖ = 0 unless i = k, so:
  -- = Σᵢ (BZ)ᵢⱼ * Wᵢᵢ (the i=k diagonal terms)

  -- The entry expansion gives us what we need
  convert h_entry using 1
  -- Need to show the sum forms are equal

  -- Expand both sides more carefully
  simp only [Matrix.mul_apply]
  -- LHS: Σᵢ (B*Z)ᵢⱼ * Wᵢᵢ
  -- RHS: Σₖ (Σᵢ (B*Z)ᵢⱼ * Wᵢₖ) * 1

  -- Use diagonal structure: Wᵢₖ = W i i if i = k, else 0
  rw [hW_diag]
  simp [Matrix.diagonal_apply]

  -- Inner sum: Σᵢ (B*Z)ᵢⱼ * (if i = k then W i i else 0)
  -- = (B*Z)ₖⱼ * W k k (only i=k term survives)

end WeightedOrthogonality

end Gam.Linalg.Matrix