from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Sequence
from ._api import fit as fit_model
from ._api import load as load_model
from ._model import Model
__all__ = ["PgsCalibration"]
@dataclass
class PgsCalibration:
    """Fit, apply and persist a calibration model for a single score column.

    The model is described by a formula of the form
    ``<pgs_column> ~ duchon(<pc_columns...>, centers=..., order=...,
    power=..., length_scale=...)`` and is fitted through the package-level
    ``fit`` function with ``transformation_normal=True``.  Predictions of the
    fitted model are returned by :meth:`predict` and attached to the input
    data under ``out_column`` by :meth:`transform`.

    Attributes:
        pc_columns: Non-empty sequence of column names placed inside the
            ``duchon(...)`` term.  A bare string is rejected because it would
            be split into single characters.
        pgs_column: Column name used as the left-hand side of the formula.
        duchon_centers: Value for ``centers=``; when ``None`` it resolves to
            ``len(pc_columns) + 1``.
        duchon_order: Value for ``order=`` in the formula.
        duchon_power: Value for ``power=`` in the formula.
        duchon_length_scale: Value for ``length_scale=`` in the formula.
        scale_dimensions: Forwarded to the fit call as ``scale_dimensions``.
        out_column: Name of the column added by :meth:`transform`.
        extra_fit_kwargs: Extra keyword arguments for the fit call; they are
            applied last and therefore override the defaults set by
            :meth:`fit`.
    """

    pc_columns: Sequence[str]
    pgs_column: str = "PGS"
    duchon_centers: int | None = None
    duchon_order: int = 1
    duchon_power: int = 2
    duchon_length_scale: float = 1.0
    scale_dimensions: bool | None = True
    out_column: str = "pgs_ctn_z"
    extra_fit_kwargs: dict[str, Any] = field(default_factory=dict)
    # Fitted/loaded model; stays None until fit() or load() populates it.
    _model: "Model | None" = field(default=None, init=False, repr=False)
    # Effective ``centers=`` value, computed once in __post_init__.
    _resolved_centers: int | None = field(default=None, init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate the column configuration and resolve the centers count.

        Raises:
            TypeError: If ``pc_columns`` is a single string.
            ValueError: If ``pc_columns`` is empty/None or ``pgs_column`` is
                empty.
        """
        if isinstance(self.pc_columns, str):
            # ", ".join("PC1") would yield "P, C, 1" and a nonsensical formula.
            raise TypeError(
                "pc_columns must be a sequence of column names, not a single string"
            )
        # len() instead of truthiness: array-like containers of names (e.g. a
        # pandas Index or NumPy array) raise on bool() when they hold more
        # than one element.
        if self.pc_columns is None or len(self.pc_columns) == 0:
            raise ValueError("pc_columns must be a non-empty sequence")
        if not self.pgs_column:
            raise ValueError("pgs_column must be provided")
        if self.duchon_centers is not None:
            self._resolved_centers = self.duchon_centers
        else:
            self._resolved_centers = len(self.pc_columns) + 1

    @property
    def formula(self) -> str:
        """Return the model formula string built from the configuration."""
        pc_args = ", ".join(self.pc_columns)
        # ``:g`` keeps the compact spelling ("1", "0.5") but rounds to six
        # significant digits; fall back to repr() whenever that rounding would
        # change the configured value.
        scale_text = f"{self.duchon_length_scale:g}"
        if float(scale_text) != self.duchon_length_scale:
            scale_text = repr(float(self.duchon_length_scale))
        duchon = (
            f"duchon({pc_args}, centers={self._resolved_centers}, "
            f"order={self.duchon_order}, power={self.duchon_power}, "
            f"length_scale={scale_text})"
        )
        return f"{self.pgs_column} ~ {duchon}"

    @property
    def model(self) -> "Model":
        """Return the fitted model.

        Raises:
            RuntimeError: If neither :meth:`fit` nor :meth:`load` has set a
                model yet.
        """
        if self._model is None:
            raise RuntimeError(
                "PgsCalibration has not been fit yet; call .fit(data) first"
            )
        return self._model

    def fit(self, data: Any) -> "PgsCalibration":
        """Fit the model on ``data`` using :attr:`formula` and return ``self``."""
        fit_kwargs: dict[str, Any] = {
            "transformation_normal": True,
            "scale_dimensions": self.scale_dimensions,
        }
        # User-supplied kwargs win over the defaults above.
        fit_kwargs.update(self.extra_fit_kwargs)
        self._model = fit_model(data, self.formula, **fit_kwargs)
        return self

    def transform(self, data: Any) -> Any:
        """Predict on ``data`` and return it with ``out_column`` attached."""
        z = self.predict(data)
        return _attach_z_column(data, z, self.out_column)

    def fit_transform(self, data: Any) -> Any:
        """Fit on ``data`` and return the transformed ``data``."""
        self.fit(data)
        return self.transform(data)

    def predict(self, data: Any) -> Any:
        """Return the fitted model's raw prediction for ``data``.

        Raises:
            RuntimeError: If the calibration has not been fit or loaded.
        """
        return self.model.predict(data)

    def save(self, path: str | Path) -> None:
        """Save the fitted model to ``path`` plus a JSON sidecar manifest.

        The manifest is written next to the model file under the name
        ``<model file name>.pgs.json``.

        Raises:
            RuntimeError: If the calibration has not been fit or loaded.
        """
        model_path = Path(path)
        self.model.save(model_path)
        _manifest_path(model_path).write_text(
            json.dumps(self._manifest(), indent=2, sort_keys=True) + "\n",
            encoding="utf-8",
        )

    @classmethod
    def load(
        cls,
        path: str | Path,
        *,
        pc_columns: Sequence[str] | None = None,
        pgs_column: str = "PGS",
        **kwargs: Any,
    ) -> "PgsCalibration":
        """Load a saved model and rebuild its calibration configuration.

        When the sidecar manifest exists, it supplies every setting the caller
        did not pass explicitly.  Without a manifest, ``pc_columns`` is
        required.

        Args:
            path: Location of the saved model file.
            pc_columns: Column names; None or empty defers to the manifest.
            pgs_column: Score column name.  The default ``"PGS"`` is treated
                as "not specified", so a manifest value takes precedence over
                it.
            **kwargs: Remaining constructor arguments; explicit values
                override manifest values.

        Raises:
            ValueError: If no manifest exists and ``pc_columns`` is None.
        """
        model_path = Path(path)
        manifest_path = _manifest_path(model_path)
        if manifest_path.exists():
            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
            # Explicit emptiness test rather than ``or`` so that array-like
            # column containers do not raise on truth-value evaluation.
            if pc_columns is None or len(pc_columns) == 0:
                pc_columns = manifest["pc_columns"]
            if pgs_column == "PGS":
                pgs_column = manifest["pgs_column"]
            kwargs.setdefault("out_column", manifest.get("out_column", "pgs_ctn_z"))
            kwargs.setdefault("duchon_centers", manifest.get("duchon_centers"))
            kwargs.setdefault("duchon_order", manifest.get("duchon_order", 1))
            kwargs.setdefault("duchon_power", manifest.get("duchon_power", 2))
            kwargs.setdefault(
                "duchon_length_scale", manifest.get("duchon_length_scale", 1.0)
            )
            kwargs.setdefault(
                "scale_dimensions", manifest.get("scale_dimensions", True)
            )
        if pc_columns is None:
            raise ValueError(
                "PgsCalibration.load requires pc_columns when the sidecar manifest is missing"
            )
        instance = cls(pc_columns=pc_columns, pgs_column=pgs_column, **kwargs)
        instance._model = load_model(model_path)
        return instance

    def _manifest(self) -> dict[str, Any]:
        """Return the JSON-serialisable description written by :meth:`save`."""
        return {
            "kind": "PgsCalibration",
            "version": 1,
            "pc_columns": list(self.pc_columns),
            "pgs_column": self.pgs_column,
            "out_column": self.out_column,
            "formula": self.formula,
            "duchon_centers": self.duchon_centers,
            "resolved_duchon_centers": self._resolved_centers,
            "duchon_order": self.duchon_order,
            "duchon_power": self.duchon_power,
            "duchon_length_scale": self.duchon_length_scale,
            "scale_dimensions": self.scale_dimensions,
        }
def _manifest_path(model_path: Path) -> Path:
    """Return the sidecar manifest path that belongs to ``model_path``.

    The manifest lives in the same directory as the model file and is named
    after it with ``.pgs.json`` appended to the full file name.
    """
    sidecar_name = f"{model_path.name}.pgs.json"
    return model_path.with_name(sidecar_name)
def _attach_z_column(data: Any, z: Any, out_column: str) -> Any:
    """Return a copy of ``data`` with the predictions added as ``out_column``.

    Supported containers, checked in this order:

    * pandas ``DataFrame`` (only if pandas is importable): copied, then the
      column is assigned.
    * ``dict`` of columns: every column is copied into a list and the new
      column is added.
    * ``list`` of records: each record is copied via ``dict(record)`` and the
      matching value is stored under ``out_column``.
    * pyarrow ``Table`` (only if pyarrow is importable): the column is
      appended.

    Any other input is wrapped as ``{"_original": data, out_column: values}``.

    Args:
        data: The container the predictions were computed from.
        z: Prediction result in any form accepted by ``_to_1d_list``.
        out_column: Name of the column/key to add.

    Raises:
        ValueError: If ``data`` is a dict of columns or a list of records
            whose length differs from the number of predictions.
    """
    # Every branch needs the flat list, so convert once up front.
    values = _to_1d_list(z)

    pd: Any | None
    try:
        import pandas as pd
    except ImportError:
        pd = None
    if pd is not None and isinstance(data, pd.DataFrame):
        result = data.copy()
        result[out_column] = values
        return result

    if isinstance(data, dict):
        columns = {key: list(column) for key, column in data.items()}
        # Mirror the record-list check below: a column of a different length
        # would otherwise produce a silently ragged table.
        for key, column in columns.items():
            if len(column) != len(values):
                raise ValueError(
                    f"predicted z-score length {len(values)} does not match "
                    f"column {key!r} length {len(column)}"
                )
        columns[out_column] = values
        return columns

    if isinstance(data, list):
        if len(values) != len(data):
            raise ValueError(
                f"predicted z-score length {len(values)} does not match record "
                f"count {len(data)}"
            )
        return [
            {**dict(record), out_column: value}
            for record, value in zip(data, values)
        ]

    pa: Any | None
    try:
        import pyarrow as pa
    except ImportError:
        pa = None
    if pa is not None and isinstance(data, pa.Table):
        return data.append_column(out_column, pa.array(values))

    # Unknown container: hand back the input untouched next to the values.
    return {
        "_original": data,
        out_column: values,
    }
def _to_1d_list(values: Any) -> list[float]:
    """Convert a prediction result into a flat list of Python floats.

    Accepted inputs:

    * NumPy ``ndarray`` (only if NumPy is importable): flattened with
      ``reshape(-1)``.
    * ``dict``: the first key present among ``"z"``, ``"z_score"``,
      ``"transformed"``, ``"eta"``, ``"mean"`` is selected and its value is
      converted by the same rules, so an array stored inside a dict is
      flattened exactly like a bare array.
    * Any other iterable: each element is passed through ``float``.

    Raises:
        KeyError: If ``values`` is a dict without any of the known keys.
    """
    np: Any | None
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None and isinstance(values, np.ndarray):
        return [float(v) for v in values.reshape(-1).tolist()]
    if isinstance(values, dict):
        for key in ("z", "z_score", "transformed", "eta", "mean"):
            if key in values:
                # Recurse instead of iterating directly: a multi-dimensional
                # array payload would otherwise be iterated row by row and
                # fail in float().
                return _to_1d_list(values[key])
        raise KeyError(
            "prediction result dict does not contain a z-score column"
        )
    return [float(v) for v in values]