# gam 0.1.16 — generalized penalized likelihood engine.
"""High-level polygenic score calibration helpers.

This module provides :class:`PgsCalibration`, a one-object wrapper around the
Stage-1 preprocessing pattern used in genotype-score calibration: fitting a
conditional transformation-normal model of a raw polygenic score on a basis
expansion over joint principal-component space, then transforming new samples
to population-calibrated z-scores.

The helper encodes the following default choices, each of which can be
overridden at construction time:

* A Duchon radial basis over the PC columns with ``len(pc_columns) + 1``
  centers, order ``1``, power ``2``, and triple operator regularization.
* A fixed Duchon ``length_scale`` so per-axis anisotropy is identifiable
  separately from the global smoothing scale.
* Per-axis anisotropic scaling (``scale_dimensions=True``).
* A transformation-normal likelihood so the fitted response is interpretable
  as a conditional standard normal z-score for each row.
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Sequence

from ._api import fit as fit_model
from ._api import load as load_model
from ._model import Model


__all__ = ["PgsCalibration"]


@dataclass
class PgsCalibration:
    """Fit-and-transform helper for polygenic-score calibration on PC space.

    Parameters
    ----------
    pc_columns:
        Ordered list of principal-component column names (e.g.
        ``["pc1", "pc2", "pc3", "pc4"]``).
    pgs_column:
        Name of the raw polygenic-score column to calibrate.
    duchon_centers:
        Number of Duchon basis centers. Defaults to ``len(pc_columns) + 1``.
    duchon_order:
        Duchon radial-basis order (``m``). Defaults to ``1``.
    duchon_power:
        Duchon radial-basis power (``s``). Defaults to ``2``.
    duchon_length_scale:
        Fixed Duchon radial length scale. Defaults to ``1.0`` so anisotropic
        PC scaling is active but does not duplicate the global smoothing scale.
    scale_dimensions:
        Forwarded to :func:`gam.fit`. When ``True`` (the default), enables
        learned per-axis anisotropic length scales on the Duchon smooth.
    out_column:
        Name of the calibrated column appended by :meth:`transform`. Defaults
        to ``"pgs_ctn_z"``.
    extra_fit_kwargs:
        Additional kwargs forwarded verbatim to :func:`gam.fit` (e.g.
        ``{"firth": True}``).

    Examples
    --------
    >>> calib = PgsCalibration(
    ...     pc_columns=["pc1", "pc2", "pc3", "pc4"],
    ...     pgs_column="PGS",
    ... )
    >>> calib.fit(df_train)
    >>> df_train = calib.transform(df_train)
    >>> df_test = calib.transform(df_test)
    """

    pc_columns: Sequence[str]
    pgs_column: str = "PGS"
    duchon_centers: int | None = None
    duchon_order: int = 1
    duchon_power: int = 2
    duchon_length_scale: float = 1.0
    scale_dimensions: bool | None = True
    out_column: str = "pgs_ctn_z"
    extra_fit_kwargs: dict[str, Any] = field(default_factory=dict)

    # Populated by fit()/load(); None until then.
    _model: Model | None = field(default=None, init=False, repr=False)
    # duchon_centers with the len(pc_columns) + 1 default applied.
    _resolved_centers: int | None = field(default=None, init=False, repr=False)

    def __post_init__(self) -> None:
        # Validate eagerly so a misconfigured helper fails at construction
        # rather than deep inside gam.fit.
        if not self.pc_columns:
            raise ValueError("pc_columns must be a non-empty sequence")
        if not self.pgs_column:
            raise ValueError("pgs_column must be provided")
        self._resolved_centers = (
            self.duchon_centers
            if self.duchon_centers is not None
            else len(self.pc_columns) + 1
        )

    @property
    def formula(self) -> str:
        """The Wilkinson-style formula used for the Stage-1 fit."""
        pc_args = ", ".join(self.pc_columns)
        # :g keeps the length scale compact (1.0 -> "1", 0.5 -> "0.5").
        duchon = (
            f"duchon({pc_args}, centers={self._resolved_centers}, "
            f"order={self.duchon_order}, power={self.duchon_power}, "
            f"length_scale={self.duchon_length_scale:g})"
        )
        return f"{self.pgs_column} ~ {duchon}"

    @property
    def model(self) -> Model:
        """The underlying fitted :class:`gam.Model`. Raises if not yet fit."""
        if self._model is None:
            raise RuntimeError(
                "PgsCalibration has not been fit yet; call .fit(data) first"
            )
        return self._model

    def fit(self, data: Any) -> "PgsCalibration":
        """Fit the Stage-1 transformation-normal calibration model.

        Returns ``self`` so calls can be chained.
        """
        fit_kwargs: dict[str, Any] = {
            "transformation_normal": True,
            "scale_dimensions": self.scale_dimensions,
        }
        # extra_fit_kwargs is applied last so callers may override the
        # defaults above deliberately.
        fit_kwargs.update(self.extra_fit_kwargs)
        self._model = fit_model(data, self.formula, **fit_kwargs)
        return self

    def transform(self, data: Any) -> Any:
        """Append a calibrated z-score column to ``data``.

        When ``data`` is a pandas DataFrame the returned object is a new
        DataFrame with ``self.out_column`` appended. For other input kinds
        (pyarrow table, dict of columns, list of records) the return type
        mirrors the input.
        """
        z = self.predict(data)
        return _attach_z_column(data, z, self.out_column)

    def fit_transform(self, data: Any) -> Any:
        """Convenience: :meth:`fit` then :meth:`transform` on the same data."""
        self.fit(data)
        return self.transform(data)

    def predict(self, data: Any) -> Any:
        """Return the raw calibrated z-score array without attaching it."""
        return self.model.predict(data)

    def save(self, path: str | Path) -> None:
        """Persist the fitted model and wrapper metadata.

        The model is saved at ``path``; wrapper metadata goes to a JSON
        sidecar next to it (see :func:`_manifest_path`).
        """
        model_path = Path(path)
        self.model.save(model_path)
        _manifest_path(model_path).write_text(
            json.dumps(self._manifest(), indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )

    @classmethod
    def load(
        cls,
        path: str | Path,
        *,
        pc_columns: Sequence[str] | None = None,
        pgs_column: str | None = None,
        **kwargs: Any,
    ) -> "PgsCalibration":
        """Load a previously-saved calibration model.

        Wrapper metadata is restored from the sidecar manifest written by
        :meth:`save`. ``pc_columns`` and ``pgs_column`` may still be supplied
        to override the manifest when loading older model-only artifacts.
        ``pgs_column=None`` (the default) means "use the manifest value,
        falling back to ``\"PGS\"``", so an explicitly passed value — even
        ``"PGS"`` itself — always wins over the manifest.

        Raises
        ------
        ValueError
            If the sidecar manifest is missing and ``pc_columns`` was not
            supplied (or an explicitly supplied ``pc_columns`` is empty).
        """
        model_path = Path(path)
        manifest_path = _manifest_path(model_path)
        if manifest_path.exists():
            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
            # Only fill in values the caller did not supply; explicit caller
            # values (including an empty pc_columns, which fails validation
            # below) take precedence over the manifest.
            if pc_columns is None:
                pc_columns = manifest["pc_columns"]
            if pgs_column is None:
                pgs_column = manifest["pgs_column"]
            kwargs.setdefault("out_column", manifest.get("out_column", "pgs_ctn_z"))
            kwargs.setdefault("duchon_centers", manifest.get("duchon_centers"))
            kwargs.setdefault("duchon_order", manifest.get("duchon_order", 1))
            kwargs.setdefault("duchon_power", manifest.get("duchon_power", 2))
            kwargs.setdefault("duchon_length_scale", manifest.get("duchon_length_scale", 1.0))
            kwargs.setdefault("scale_dimensions", manifest.get("scale_dimensions", True))
        if pc_columns is None:
            raise ValueError(
                "PgsCalibration.load requires pc_columns when the sidecar manifest is missing"
            )
        if pgs_column is None:
            # Model-only artifact with no manifest: keep the historical default.
            pgs_column = "PGS"
        instance = cls(pc_columns=pc_columns, pgs_column=pgs_column, **kwargs)
        instance._model = load_model(model_path)
        return instance

    def _manifest(self) -> dict[str, Any]:
        """Wrapper metadata serialized alongside the model by :meth:`save`."""
        return {
            "kind": "PgsCalibration",
            "version": 1,
            "pc_columns": list(self.pc_columns),
            "pgs_column": self.pgs_column,
            "out_column": self.out_column,
            "formula": self.formula,
            "duchon_centers": self.duchon_centers,
            "resolved_duchon_centers": self._resolved_centers,
            "duchon_order": self.duchon_order,
            "duchon_power": self.duchon_power,
            "duchon_length_scale": self.duchon_length_scale,
            "scale_dimensions": self.scale_dimensions,
        }


def _manifest_path(model_path: Path) -> Path:
    """Return the JSON sidecar-manifest path for a saved model file."""
    sidecar_name = f"{model_path.name}.pgs.json"
    return model_path.with_name(sidecar_name)


def _attach_z_column(data: Any, z: Any, out_column: str) -> Any:
    """Attach the 1-D z-score values ``z`` to ``data`` as ``out_column``.

    The return type mirrors the input: pandas DataFrame, dict of columns,
    list of record dicts, or pyarrow Table. Any other input is wrapped in a
    dict under ``"_original"`` with the z-scores alongside.
    """
    pd: Any | None
    try:
        import pandas as pd
    except ImportError:
        pd = None

    if pd is not None and isinstance(data, pd.DataFrame):
        # Copy so the caller's frame is not mutated; pandas itself raises on
        # a length mismatch during column assignment.
        result = data.copy()
        result[out_column] = _to_1d_list(z)
        return result

    if isinstance(data, dict):
        values = _to_1d_list(z)
        result = {key: list(value) for key, value in data.items()}
        # Validate length against existing columns (consistent with the
        # list-of-records branch below) instead of silently attaching a
        # ragged column.
        for key, column in result.items():
            if len(column) != len(values):
                raise ValueError(
                    f"predicted z-score length {len(values)} does not match "
                    f"column {key!r} length {len(column)}"
                )
        result[out_column] = values
        return result

    if isinstance(data, list):
        values = _to_1d_list(z)
        if len(values) != len(data):
            raise ValueError(
                f"predicted z-score length {len(values)} does not match record "
                f"count {len(data)}"
            )
        result = []
        for record, value in zip(data, values):
            enriched = dict(record)
            enriched[out_column] = value
            result.append(enriched)
        return result

    pa: Any | None
    try:
        import pyarrow as pa
    except ImportError:
        pa = None

    if pa is not None and isinstance(data, pa.Table):
        return data.append_column(out_column, pa.array(_to_1d_list(z)))

    # Unknown input kind: preserve the original object next to the scores.
    return {
        "_original": data,
        out_column: _to_1d_list(z),
    }


def _to_1d_list(values: Any) -> list[float]:
    """Coerce ``values`` into a flat list of Python floats.

    Accepts a numpy ndarray (flattened), a prediction-result dict (first
    matching z-score key wins), or any iterable of numbers.
    """
    try:
        import numpy
    except ImportError:
        numpy = None

    if numpy is not None and isinstance(values, numpy.ndarray):
        flattened = values.reshape(-1).tolist()
        return [float(item) for item in flattened]

    if isinstance(values, dict):
        # Probe the known z-score column names in priority order.
        for candidate in ("z", "z_score", "transformed", "eta", "mean"):
            if candidate in values:
                return [float(item) for item in values[candidate]]
        raise KeyError(
            "prediction result dict does not contain a z-score column"
        )

    return [float(item) for item in values]