import typing
import importlib
import sys
import subprocess
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# TOML parser: the standard-library `tomllib` module (available from
# Python 3.11) is loaded dynamically and bound to the name `tomli`.
# It is annotated as Any, so type checkers treat it as opaque.
tomli: typing.Any = importlib.import_module("tomllib")
# Directory that contains this script (symlinks resolved).
SCRIPT_DIR = Path(__file__).resolve().parent
# Parent directory of the script directory.
PROJECT_ROOT = SCRIPT_DIR.parent
# Serialized R model; read with readRDS() by the embedded R script.
R_MODEL_PATH = SCRIPT_DIR / 'gam_model_fit.rds'
# TOML model file parsed by get_gnomon_basis_data().
RUST_MODEL_CONFIG_PATH = PROJECT_ROOT / 'model.toml'
# Number of evenly spaced x values at which both bases are evaluated.
N_POINTS_PLOT = 400
def print_array_summary(name: typing.Any, arr: typing.Any) -> None:
    """Print a short diagnostic digest of a NumPy array to stdout.

    Args:
        name: Label shown in the diagnostic line.
        arr: Array to summarize; it must expose ``ndim`` and ``shape``.

    One-dimensional arrays get shape, min, max, mean and std on one line,
    followed by their first five elements. Every other array gets shape,
    min, max and mean, then the per-column std of the first five columns
    and the top-left 2x5 slice.
    """
    # Statistics shared by both layouts.
    header = (
        f" [DIAGNOSTIC] {name} | Shape: {arr.shape}"
        f" | Min: {np.min(arr):.4f} | Max: {np.max(arr):.4f}"
        f" | Mean: {np.mean(arr):.4f}"
    )

    if arr.ndim != 1:
        print(header)
        column_spread = np.std(arr, axis=0)
        print(f" -> Stds of first 5 columns: {column_spread[:5]}")
        print(f" -> First 2x5 slice:\n{arr[:2, :5]}")
        return

    print(f"{header} | Std: {np.std(arr):.4f}")
    print(f" -> First 5 elements: {arr[:5]}")
def evaluate_bspline_basis(x: typing.Any, knots: typing.Any, degree: typing.Any) -> typing.Any:
    """Evaluate every B-spline basis function at each point of ``x``.

    Args:
        x: Sequence of evaluation points. Values outside
            ``[knots[degree], knots[len(knots) - degree - 1]]`` are clamped
            to that interval before evaluation.
        knots: Knot vector (sequence of floats).
        degree: Polynomial degree of the spline.

    Returns:
        Array of shape ``(len(x), len(knots) - degree - 1)``. Each row holds
        the values of all basis functions at one point; at most
        ``degree + 1`` entries per row are non-zero.
    """
    n_basis = len(knots) - degree - 1
    order = degree + 1
    design = np.zeros((len(x), n_basis))
    lower_bound = knots[degree]
    upper_bound = knots[n_basis]
    samples = np.clip(x, lower_bound, upper_bound)

    for row, t in enumerate(samples):
        # Find the knot span containing t; the right end of the domain is
        # assigned to the last span.
        span = n_basis - 1 if t >= upper_bound else np.searchsorted(knots, t, side='right') - 1
        span = max(degree, span)

        # Build the (degree + 1) non-zero basis values on this span by
        # raising the degree one level at a time.
        local = np.zeros(order)
        local[0] = 1.0
        dist_left = np.zeros(order)
        dist_right = np.zeros(order)
        for level in range(1, order):
            dist_left[level] = t - knots[span + 1 - level]
            dist_right[level] = knots[span + level] - t
            carry = 0.0
            for j in range(level):
                width = dist_right[j + 1] + dist_left[level - j]
                # A (near-)zero width comes from repeated knots; that term
                # contributes nothing.
                ratio = local[j] / width if abs(width) > 1e-12 else 0.0
                local[j] = carry + dist_right[j + 1] * ratio
                carry = dist_left[level - j] * ratio
            local[level] = carry

        # Scatter the local values into the row, trimming anything that
        # would fall outside the column range.
        first = span - degree
        last = first + order
        if first < 0:
            local = local[-first:]
            first = 0
        if last > n_basis:
            local = local[: n_basis - first]
            last = n_basis
        design[row, first:last] = local

    return design
def get_mgcv_basis_data() -> typing.Any:
    """Extract the first smooth's basis and coefficients from the R model.

    Runs an inline R script through ``Rscript`` (working directory
    ``SCRIPT_DIR``) that loads the model at ``R_MODEL_PATH``, builds the
    linear-predictor matrix on an evenly spaced grid of ``N_POINTS_PLOT``
    values of ``variable_one`` (with ``variable_two`` fixed at 0), and
    writes the grid, the lpmatrix columns belonging to the first smooth,
    and the matching coefficients to temporary CSV files. The CSVs are read
    back with pandas and always removed afterwards.

    Returns:
        A dict with keys ``"x_axis"`` (1-D array), ``"basis_matrix"``
        (2-D array, one column per basis function) and ``"coeffs"``
        (1-D array).

    Raises:
        SystemExit: With status 1 when the R script exits with a non-zero
            status.
    """
    print("\n" + "="*80)
    print("--- STAGE 1: Extracting CONSTRAINED basis from R/mgcv model ---")
    print("="*80)
    x_axis_file, basis_file, coeffs_file = [SCRIPT_DIR / f for f in ["t_x.csv", "t_b.csv", "t_c.csv"]]
    # Use forward slashes so backslashes in a Windows path are not treated
    # as escape sequences inside the R string literal.
    r_model_path = R_MODEL_PATH.as_posix()
    # The smooth's columns in the lpmatrix (and its entries in coef(model))
    # are first.para..last.para of the smooth object. The script previously
    # wrote `constrained_basis_functions` and `basis_coeffs` without ever
    # defining them.
    r_script = f"""
suppressPackageStartupMessages(library(mgcv))
model <- readRDS('{r_model_path}')
var_range <- range(model$model$variable_one)
x_seq <- seq(var_range[1], var_range[2], length.out = {N_POINTS_PLOT})
newdata <- data.frame(variable_one = x_seq, variable_two = 0)
lp_matrix <- predict(model, newdata = newdata, type = "lpmatrix")
smooth_info <- model$smooth[[1]]
smooth_cols <- smooth_info$first.para:smooth_info$last.para
constrained_basis_functions <- lp_matrix[, smooth_cols, drop = FALSE]
basis_coeffs <- coef(model)[smooth_cols]
write.csv(data.frame(x=x_seq), '{x_axis_file.name}', row.names=FALSE)
write.csv(constrained_basis_functions, '{basis_file.name}', row.names=FALSE)
write.csv(data.frame(coeffs=basis_coeffs), '{coeffs_file.name}', row.names=FALSE)
cat("R: Extracted main effect basis for '", smooth_info$label, "'.\\n", sep="")
"""
    try:
        result = subprocess.run(["Rscript", "-e", r_script], check=True, text=True, cwd=SCRIPT_DIR, capture_output=True)
        print(f" [INFO] R stdout: {result.stdout.strip()}")
        if result.stderr:
            print(f" [INFO] R stderr: {result.stderr.strip()}")
        x_axis = pd.read_csv(x_axis_file)['x'].values
        basis_matrix = pd.read_csv(basis_file).values
        coeffs = pd.read_csv(coeffs_file)['coeffs'].values
        print(" [PRINT] mgcv: Loaded x-axis vector.")
        print_array_summary("mgcv x_axis", x_axis)
        print(" [PRINT] mgcv: Loaded constrained basis matrix.")
        print_array_summary("mgcv basis_matrix", basis_matrix)
        print(" [PRINT] mgcv: Loaded coefficients.")
        print_array_summary("mgcv coeffs", coeffs)
        return {"x_axis": x_axis, "basis_matrix": basis_matrix, "coeffs": coeffs}
    except subprocess.CalledProcessError as e:
        print(f"\n--- FATAL ERROR: R script execution failed. ---\n{e.stderr}")
        sys.exit(1)
    finally:
        # Remove the temporary CSVs whether or not the run succeeded.
        for temp_file in (x_axis_file, basis_file, coeffs_file):
            temp_file.unlink(missing_ok=True)
def get_gnomon_basis_data() -> typing.Any:
    """Rebuild the constrained spline basis described by the TOML model.

    Reads ``RUST_MODEL_CONFIG_PATH``, searches the parsed document for the
    knot vector, degree, coefficients, constraint transform and optional
    ``x_range``, evaluates the raw B-spline basis on ``N_POINTS_PLOT``
    evenly spaced points, drops its first column and multiplies the rest
    by the constraint transform.

    Returns:
        A dict with keys ``"x_axis"``, ``"basis_matrix"`` (the constrained
        basis) and ``"coeffs"``.

    Raises:
        SystemExit: With status 1 when a required field is missing or the
            matrix dimensions do not line up.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("--- STAGE 2: Reconstructing CONSTRAINED basis from Rust/gnomon model ---")
    print(rule)

    with open(RUST_MODEL_CONFIG_PATH, "rb") as handle:
        toml_data = tomli.load(handle)

    def _find_key(obj: typing.Any, key: typing.Any) -> typing.Any:
        """Return the first value stored under ``key`` in nested dicts/lists.

        A dict's own entry wins over anything nested inside it; returns
        None when nothing is found.
        """
        if isinstance(obj, dict):
            if key in obj:
                return obj[key]
            children = obj.values()
        elif isinstance(obj, list):
            children = obj
        else:
            return None
        for child in children:
            hit = _find_key(child, key)
            if hit is not None:
                return hit
        return None

    knots_data = _find_key(toml_data, "knot_vector") or _find_key(toml_data, "knots")
    degree = _find_key(toml_data, "degree")
    coeffs_data = _find_key(toml_data, "coefficients")
    constraint_info = _find_key(toml_data, "transform")
    x_range = _find_key(toml_data, "x_range")

    # Arrays may be stored as tables that carry their payload under "data".
    if isinstance(knots_data, dict):
        knots_data = knots_data.get("data", knots_data)
    if isinstance(coeffs_data, dict):
        coeffs_data = coeffs_data.get("data", coeffs_data)
    # The transform table may itself wrap another "transform" entry.
    if isinstance(constraint_info, dict) and "transform" in constraint_info:
        constraint_info = constraint_info["transform"]

    required_fields = (knots_data, degree, coeffs_data, constraint_info)
    if any(field is None for field in required_fields):
        print("\n--- FATAL ERROR: Could not locate required spline fields in model.toml. ---")
        print("Required: knots, degree, coefficients, and constraint transform.")
        sys.exit(1)

    knots = np.asarray(knots_data, dtype=float)
    degree = int(degree)
    coeffs = np.asarray(coeffs_data, dtype=float)
    if x_range is None:
        # Fall back to the interval spanned by the interior of the knot vector.
        x_range = [knots[degree], knots[-degree - 1]]

    print_array_summary("Knot Vector", knots)
    print(f" [PRINT] gnomon: Loaded spline degree {degree}.")
    print_array_summary("Coefficients", coeffs)

    num_raw_bases = len(knots) - degree - 1
    print(f" [INFO] gnomon: Derived k={num_raw_bases} total raw B-spline bases from knot vector.")

    z_dims = constraint_info['dim']
    z_data = constraint_info['data']
    z_transform = np.array(z_data).reshape(z_dims)
    print_array_summary("z_transform", z_transform)

    x_axis = np.linspace(x_range[0], x_range[1], N_POINTS_PLOT)
    raw_basis_matrix = evaluate_bspline_basis(x_axis, knots, degree)
    print(" [PRINT] gnomon: Reconstructed FULL raw basis matrix.")
    print_array_summary("raw_basis_matrix", raw_basis_matrix)

    # Drop the first raw basis column before applying the constraint.
    raw_main_basis_functions = raw_basis_matrix[:, 1:]
    print(" [INFO] gnomon: Sliced raw basis to get the non-constant bases for constraining.")
    print_array_summary("raw_main_basis_functions", raw_main_basis_functions)

    main_cols = raw_main_basis_functions.shape[1]
    transform_rows = z_transform.shape[0]
    if main_cols != transform_rows:
        print("\n--- FATAL ERROR: Dimension mismatch for constraint! ---")
        print(f"Raw main basis columns: {main_cols}, Z-transform rows: {transform_rows}")
        sys.exit(1)

    constrained_basis_matrix = raw_main_basis_functions @ z_transform
    print(" [PRINT] gnomon: Created FINAL constrained basis matrix.")
    print_array_summary("constrained_basis_matrix", constrained_basis_matrix)

    final_cols = constrained_basis_matrix.shape[1]
    if final_cols != len(coeffs):
        print("\n--- FATAL ERROR: Final dimension mismatch! ---")
        print(f"Final basis columns: {final_cols}, Coefficients length: {len(coeffs)}")
        sys.exit(1)

    print(" [INFO] All dimension checks passed successfully.")
    return {"x_axis": x_axis, "basis_matrix": constrained_basis_matrix, "coeffs": coeffs}
def create_comparison_plot(mgcv_data: typing.Any, gnomon_data: typing.Any) -> None:
    """Draw a 3x2 figure comparing the mgcv and gnomon spline components.

    Args:
        mgcv_data: Mapping with ``'x_axis'``, ``'basis_matrix'`` and
            ``'coeffs'``; an optional ``'var_name'`` labels the x axis.
        gnomon_data: Mapping with the same three required keys.

    The left column shows mgcv and the right column gnomon. Rows are:
    basis functions, coefficient-weighted basis functions, and the summed
    curve. The bottom-right panel overlays both final curves. The mgcv
    quantities are mean-centred before plotting. Ends by calling
    ``plt.show()``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("--- STAGE 3: Generating the SINGLE 3x2 Comparison Plot ---")
    print(rule)

    # gnomon: weight each basis column by its coefficient, then sum.
    g_basis = gnomon_data['basis_matrix']
    g_weighted = g_basis * gnomon_data['coeffs']
    g_curve = g_weighted.sum(axis=1)

    # mgcv: same computation, followed by mean-centring of both the
    # individual weighted columns and the summed curve.
    m_basis = mgcv_data['basis_matrix']
    m_weighted_raw = m_basis * mgcv_data['coeffs']
    m_curve_raw = m_weighted_raw.sum(axis=1)
    m_curve = m_curve_raw - np.mean(m_curve_raw)
    m_weighted = m_weighted_raw - np.mean(m_weighted_raw, axis=0)

    print(" [PRINT] mgcv: Calculated components.")
    print_array_summary("mgcv_weighted_centered", m_weighted)
    print_array_summary("mgcv_final_curve_centered", m_curve)
    print(" [PRINT] gnomon: Calculated components.")
    print_array_summary("gnomon_weighted", g_weighted)
    print_array_summary("gnomon_final_curve", g_curve)

    fig, axes = plt.subplots(3, 2, figsize=(15, 18), sharex=True, constrained_layout=True)
    fig.suptitle("Internal Component Comparison: mgcv vs gnomon", fontsize=20)
    (ax_m_basis, ax_g_basis), (ax_m_weighted, ax_g_weighted), (ax_m_curve, ax_overlay) = axes

    ax_m_basis.set_title("mgcv Model (Computational Basis)", fontsize=16)
    ax_g_basis.set_title("gnomon Model (Interpretable Basis)", fontsize=16)

    # Row 1: basis functions.
    m_x = mgcv_data['x_axis']
    ax_m_basis.plot(m_x, m_basis, alpha=0.7)
    ax_m_basis.set_ylabel("Basis Value", fontsize=12)
    g_x = gnomon_data['x_axis']
    ax_g_basis.plot(g_x, g_basis, alpha=0.7)

    # Row 2: coefficient-weighted basis functions with a zero reference line.
    ax_m_weighted.plot(m_x, m_weighted, alpha=0.7)
    ax_m_weighted.axhline(0, color='black', linestyle='--', linewidth=1)
    ax_m_weighted.set_ylabel("Centered Weighted Basis Value", fontsize=12)
    ax_g_weighted.plot(g_x, g_weighted, alpha=0.7)
    ax_g_weighted.axhline(0, color='black', linestyle='--', linewidth=1)

    # Row 3, left: the centred mgcv smooth on its own.
    ax_m_curve.plot(m_x, m_curve, color='crimson', linewidth=3)
    ax_m_curve.set_xlabel(mgcv_data.get('var_name', 'variable_one'), fontsize=12)
    ax_m_curve.set_ylabel("Centered Smooth Contribution", fontsize=12)
    ax_m_curve.grid(True, linestyle=':', alpha=0.7)

    # Row 3, right: both final curves overlaid for visual verification.
    ax_overlay.plot(m_x, m_curve, label='mgcv (Centered)', color='blue', linewidth=6, alpha=0.6)
    ax_overlay.plot(g_x, g_curve, label='gnomon', color='red', linewidth=2.5, linestyle='--')
    ax_overlay.legend(title="Model")
    ax_overlay.set_title("Verification: Overlay of Final Curves", fontsize=14)
    ax_overlay.grid(True, linestyle=':', alpha=0.7)

    plt.show()
def main() -> None:
    """Run the full comparison: check inputs, load both models, plot.

    Exits with status 1 as soon as one of the required input files
    (``R_MODEL_PATH``, ``RUST_MODEL_CONFIG_PATH``) is missing.
    """
    required_inputs = (R_MODEL_PATH, RUST_MODEL_CONFIG_PATH)
    first_missing = next((path for path in required_inputs if not path.is_file()), None)
    if first_missing is not None:
        print(f"--- FATAL ERROR: Required file not found: '{first_missing}' ---")
        sys.exit(1)

    mgcv_data = get_mgcv_basis_data()
    gnomon_data = get_gnomon_basis_data()
    create_comparison_plot(mgcv_data, gnomon_data)
    print("\n--- Script finished successfully. ---")


# Run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()