from __future__ import annotations
import argparse
import difflib
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Set, Tuple
# Regexes matching invocations of the project's Rust declaration macros.
# The Rust sources are scanned purely textually: nothing in these patterns
# excludes matches that occur inside Rust comments.

# impl_quantity!(Quantity, Unit, ...): the first two arguments are read as the
# quantity type name and the name of its unit enum.
IMPL_QUANTITY_RE = re.compile(
    r"impl_quantity!\(\s*(?P<quantity>\w+)\s*,\s*(?P<unit>\w+)\s*,",
    # NOTE(review): MULTILINE has no effect here; the pattern has no ^/$ anchors.
    re.MULTILINE,
)
# impl_const!(Quantity, NAME, ...): a named constant of a quantity, rendered as
# a static method on the quantity's stub class.
IMPL_CONST_RE = re.compile(r"impl_const!\(\s*(?P<quantity>\w+)\s*,\s*(?P<name>\w+)\s*,")
# make_alias!(BaseQuantity, BaseUnit, AliasQuantity, AliasUnit).
MAKE_ALIAS_RE = re.compile(
    r"make_alias!\(\s*(?P<base_quantity>\w+)\s*,\s*(?P<base_unit>\w+)\s*,\s*"
    r"(?P<alias_quantity>\w+)\s*,\s*(?P<alias_unit>\w+)\s*\)"
)
# impl_mul_relation_with_self!(Lhs, Res): treated as Lhs * Lhs = Res
# (which also implies Res / Lhs = Lhs and sqrt(Res) = Lhs).
IMPL_MUL_REL_SELF_RE = re.compile(
    r"impl_mul_relation_with_self!\(\s*(?P<lhs>\w+)\s*,\s*(?P<res>\w+)\s*\)"
)
# impl_mul_relation_with_other!(Lhs, Rhs, Res): treated as Lhs * Rhs = Res in
# both operand orders, plus Res / Lhs = Rhs and Res / Rhs = Lhs.
IMPL_MUL_REL_OTHER_RE = re.compile(
    r"impl_mul_relation_with_other!\(\s*(?P<lhs>\w+)\s*,\s*(?P<rhs>\w+)\s*,\s*(?P<res>\w+)\s*\)"
)
# impl_sqrt!(Lhs, Res): Lhs.sqrt() returns Res.
IMPL_SQRT_RE = re.compile(r"impl_sqrt!\(\s*(?P<lhs>\w+)\s*,\s*(?P<res>\w+)\s*\)")
# impl_mul!(Lhs, Rhs, Res): Lhs * Rhs = Res (registered for both operand orders
# by parse_operator_maps).
IMPL_MUL_RE = re.compile(
    r"impl_mul!\(\s*(?P<lhs>\w+)\s*,\s*(?P<rhs>\w+)\s*,\s*(?P<res>\w+)\s*\)"
)
# impl_div!(Lhs, Rhs, Res): Lhs / Rhs = Res.
IMPL_DIV_RE = re.compile(
    r"impl_div!\(\s*(?P<lhs>\w+)\s*,\s*(?P<rhs>\w+)\s*,\s*(?P<res>\w+)\s*\)"
)
# impl_mul_with_self!(Lhs, Res): Lhs * Lhs = Res.
IMPL_MUL_WITH_SELF_RE = re.compile(
    r"impl_mul_with_self!\(\s*(?P<lhs>\w+)\s*,\s*(?P<res>\w+)\s*\)"
)
# impl_div_with_self_to_f64!(Lhs): Lhs / Lhs yields a plain f64.
IMPL_DIV_SELF_TO_F64_RE = re.compile(r"impl_div_with_self_to_f64!\(\s*(?P<lhs>\w+)\s*\)")
# Signature of a hand-written `fn __rtruediv__(rhs: PyRef<Self>, lhs: f64)`;
# captures the Rust return type wrapped in PyResult<...>.
RTRUE_DIV_RE = re.compile(
    r"fn\s+__rtruediv__\s*\(\s*rhs:\s*PyRef<Self>\s*,\s*lhs:\s*f64\s*\)\s*->\s*PyResult<(?P<res>\w+)>"
)


@dataclass
class QuantityInfo:
    """Facts about one quantity, extracted from a single Rust source file."""

    name: str  # quantity type name (first argument of impl_quantity!)
    unit: str  # name of the quantity's unit enum (second argument of impl_quantity!)
    unit_variants: List[str]  # variant names parsed from the unit enum body
    constants: List[str]  # impl_const! names declared for this quantity
    rtruediv_return: Optional[str]  # Rust return type of __rtruediv__, or None if absent


@dataclass
class OperatorMaps:
    """Binary-operator result tables, indexed as ``table[lhs][rhs] -> result``."""

    mul: Dict[str, Dict[str, str]]  # lhs * rhs -> result type name
    div: Dict[str, Dict[str, str]]  # lhs / rhs -> result type name ("f64" from impl_div_with_self_to_f64!)


def _strip_line_comment(line: str) -> str:
return line.split("//", 1)[0].strip()
def _extract_braced_block(source: str, start_index: int) -> str:
brace_start = source.find("{", start_index)
if brace_start == -1:
raise ValueError("opening brace not found")
depth = 0
for index in range(brace_start, len(source)):
char = source[index]
if char == "{":
depth += 1
elif char == "}":
depth -= 1
if depth == 0:
return source[brace_start + 1 : index]
raise ValueError("closing brace not found")
def parse_enum_variants(source: str, enum_name: str) -> List[str]:
    """Return the variant names of the Rust enum *enum_name* declared in *source*.

    Only fieldless variants are collected: tokens that are plain identifiers,
    optionally followed by an explicit discriminant (``Meter = 1`` yields
    ``Meter``). Tuple/struct variants are skipped, as are lines starting
    with ``#`` (attributes). ``//`` line comments are removed first.

    Raises:
        ValueError: if no ``pub enum <enum_name> {`` declaration is found, or
            the enum body's braces are unbalanced.
    """
    # Allow any whitespace between ``pub`` and ``enum``.
    enum_match = re.search(rf"pub\s+enum\s+{re.escape(enum_name)}\s*\{{", source)
    if not enum_match:
        # Name the enum so a failure points at the offending declaration.
        raise ValueError(f"enum not found: {enum_name}")
    body = _extract_braced_block(source, enum_match.start())
    variants: List[str] = []
    for raw_line in body.splitlines():
        line = _strip_line_comment(raw_line)
        # Attribute lines (e.g. ``#[default]``) carry no variant name.
        if not line or line.startswith("#"):
            continue
        for segment in line.split(","):
            # Drop an explicit discriminant; previously ``Meter = 1`` failed the
            # identifier check and the variant was silently lost.
            token = segment.split("=", 1)[0].strip()
            if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", token):
                variants.append(token)
    return variants


def parse_quantity_file(path: Path) -> Optional[QuantityInfo]:
    """Extract the quantity declared in the Rust file at *path*.

    Only the first ``impl_quantity!`` match is used. Returns ``None`` when
    the file has no such invocation. Constants are kept only when their
    ``impl_const!`` names this quantity. ``__rtruediv__`` is taken from the
    first matching signature anywhere in the file.

    Raises:
        ValueError: propagated from ``parse_enum_variants`` when the unit
            enum cannot be found or parsed.
    """
    text = path.read_text(encoding="utf-8")
    header = IMPL_QUANTITY_RE.search(text)
    if header is None:
        return None
    name, unit = header.group("quantity", "unit")
    variants = parse_enum_variants(text, unit)
    consts = [
        hit.group("name")
        for hit in IMPL_CONST_RE.finditer(text)
        if hit.group("quantity") == name
    ]
    rtruediv = RTRUE_DIV_RE.search(text)
    return QuantityInfo(
        name=name,
        unit=unit,
        unit_variants=variants,
        constants=consts,
        rtruediv_return=None if rtruediv is None else rtruediv.group("res"),
    )


def parse_quantities(quantities_dir: Path) -> List[QuantityInfo]:
    """Parse every ``*.rs`` file directly in *quantities_dir*.

    Files without an ``impl_quantity!`` are ignored. The result is sorted by
    quantity name; ties keep file-path order because the sort is stable.
    """
    candidates = (parse_quantity_file(rs_file) for rs_file in sorted(quantities_dir.glob("*.rs")))
    return sorted(
        (info for info in candidates if info is not None),
        key=lambda info: info.name,
    )


def parse_aliases(quantities_dir: Path) -> List[Tuple[str, str, str, str]]:
    """Collect unique ``make_alias!`` invocations from ``*.rs`` files in *quantities_dir*.

    Each alias is ``(base_quantity, base_unit, alias_quantity, alias_unit)``.
    Duplicates keep their first occurrence. The result is sorted by
    ``(alias_quantity, alias_unit)``.
    """
    # A dict doubles as an insertion-ordered set.
    unique: Dict[Tuple[str, str, str, str], None] = {}
    for rs_file in sorted(quantities_dir.glob("*.rs")):
        text = rs_file.read_text(encoding="utf-8")
        for hit in MAKE_ALIAS_RE.finditer(text):
            key = hit.group("base_quantity", "base_unit", "alias_quantity", "alias_unit")
            unique.setdefault(key, None)
    return sorted(unique, key=lambda entry: (entry[2], entry[3]))


def parse_sqrt_map(src_root: Path) -> Dict[str, str]:
    """Map each quantity name to the type returned by its ``sqrt``.

    Every ``*.rs`` file under *src_root* is scanned recursively, in sorted
    path order. Later entries silently overwrite earlier ones for the same key.
    """
    roots: Dict[str, str] = {}
    for rs_file in sorted(src_root.rglob("*.rs")):
        text = rs_file.read_text(encoding="utf-8")
        # x * x = y  implies  sqrt(y) = x
        roots.update(
            (hit.group("res"), hit.group("lhs")) for hit in IMPL_MUL_REL_SELF_RE.finditer(text)
        )
        # impl_sqrt!(x, y) states sqrt(x) = y directly.
        roots.update(
            (hit.group("lhs"), hit.group("res")) for hit in IMPL_SQRT_RE.finditer(text)
        )
    return roots


def _add_operator(
table: Dict[str, Dict[str, str]],
lhs: str,
rhs: str,
result: str,
) -> None:
if lhs not in table:
table[lhs] = {}
prev = table[lhs].get(rhs)
if prev is not None and prev != result:
raise ValueError(
f"Conflicting operator mapping for {lhs} and {rhs}: {prev} vs {result}"
)
table[lhs][rhs] = result
def parse_operator_maps(src_root: Path) -> OperatorMaps:
    """Build the ``*`` and ``/`` result tables from the relation macros.

    Reads ``src_root/relations.rs`` first, then every ``*.rs`` file directly
    inside ``src_root/quantities`` in sorted order. Within each file, the
    macro kinds are processed in a fixed order. Every entry goes through
    ``_add_operator``.

    Raises:
        OSError: if one of the listed files (e.g. ``relations.rs``) cannot be read.
        ValueError: if two macros map the same operand pair to different results.
    """
    mul: Dict[str, Dict[str, str]] = {}
    div: Dict[str, Dict[str, str]] = {}

    def relation_with_other(hit: re.Match) -> None:
        # lhs * rhs = res in either order; each factor divides out of res.
        lhs, rhs, res = hit.group("lhs", "rhs", "res")
        _add_operator(mul, lhs, rhs, res)
        _add_operator(mul, rhs, lhs, res)
        _add_operator(div, res, lhs, rhs)
        _add_operator(div, res, rhs, lhs)

    def relation_with_self(hit: re.Match) -> None:
        # lhs * lhs = res, hence res / lhs = lhs.
        lhs, res = hit.group("lhs", "res")
        _add_operator(mul, lhs, lhs, res)
        _add_operator(div, res, lhs, lhs)

    def plain_mul(hit: re.Match) -> None:
        # impl_mul! is registered for both operand orders.
        lhs, rhs, res = hit.group("lhs", "rhs", "res")
        _add_operator(mul, lhs, rhs, res)
        _add_operator(mul, rhs, lhs, res)

    def plain_div(hit: re.Match) -> None:
        lhs, rhs, res = hit.group("lhs", "rhs", "res")
        _add_operator(div, lhs, rhs, res)

    def mul_with_self(hit: re.Match) -> None:
        lhs, res = hit.group("lhs", "res")
        _add_operator(mul, lhs, lhs, res)

    def div_self_to_float(hit: re.Match) -> None:
        lhs = hit.group("lhs")
        _add_operator(div, lhs, lhs, "f64")

    # This order determines which mapping counts as "previous" when a
    # conflict is reported, so keep it stable.
    handlers = (
        (IMPL_MUL_REL_OTHER_RE, relation_with_other),
        (IMPL_MUL_REL_SELF_RE, relation_with_self),
        (IMPL_MUL_RE, plain_mul),
        (IMPL_DIV_RE, plain_div),
        (IMPL_MUL_WITH_SELF_RE, mul_with_self),
        (IMPL_DIV_SELF_TO_F64_RE, div_self_to_float),
    )
    sources = [src_root / "relations.rs", *sorted((src_root / "quantities").glob("*.rs"))]
    for rs_file in sources:
        text = rs_file.read_text(encoding="utf-8")
        for pattern, handle in handlers:
            for hit in pattern.finditer(text):
                handle(hit)
    return OperatorMaps(mul=mul, div=div)


def _format_result_annotation(result: str, quantity_names: Set[str]) -> Optional[str]:
if result == "f64":
return "float"
if result in quantity_names:
return f'"{result}"'
return None
def render_vector3_stub() -> List[str]:
    """Return the stub lines for the ``Vector3`` class.

    ``Vector3`` is not discovered from any macro, so its Python interface is
    listed explicitly. Element and unit types are annotated as ``Any``.
    """
    static = " @staticmethod"
    lines = [
        "class Vector3:",
        " def __init__(self, x: Any, y: Any, z: Any, unit: Optional[Any] = ...) -> None: ...",
    ]
    # zero/x/y/z are argument-less static factories with identical signatures.
    for factory in ("zero", "x", "y", "z"):
        lines.append(static)
        lines.append(f' def {factory}() -> "Vector3": ...')
    lines += [
        " def __getitem__(self, index: int) -> Any: ...",
        " def __setitem__(self, index: int, value: Any) -> None: ...",
        ' def __neg__(self) -> "Vector3": ...',
        ' def __add__(self, other: "Vector3") -> "Vector3": ...',
        ' def __sub__(self, other: "Vector3") -> "Vector3": ...',
        " def __repr__(self) -> str: ...",
        " def __str__(self) -> str: ...",
        " def to_list(self) -> List[Any]: ...",
        static,
        ' def from_list(lst: Sequence[Any], unit: Optional[Any] = ...) -> "Vector3": ...',
        static,
        ' def from_array(array: Any, unit: Optional[Any] = ...) -> "Vector3": ...',
        " def to_array(self, unit: Optional[Any] = ...) -> Any: ...",
        " def norm(self) -> Any: ...",
        ' def to_unit_vector(self) -> "Vector3": ...',
        ' def dot_vec(self, other: "Vector3") -> Any: ...',
        ' def cross(self, other: "Vector3") -> "Vector3": ...',
        " def __mul__(self, rhs: Any) -> Vector3: ...",
        " def __truediv__(self, rhs: Any) -> Vector3: ...",
    ]
    return lines


def render_matrix3_stub() -> List[str]:
    """Return the stub lines for the ``Matrix3`` class.

    Like ``Vector3``, ``Matrix3`` is described explicitly rather than parsed
    from the Rust sources. Rows/columns are typed as ``Vector3``.
    """

    def _static(signature: str) -> List[str]:
        # A static method is its decorator line followed by its signature line.
        return [" @staticmethod", signature]

    return (
        [
            "class Matrix3:",
            " def __init__(self, data: Optional[Any] = ..., unit: Optional[Any] = ...) -> None: ...",
        ]
        + _static(' def zero() -> "Matrix3": ...')
        + _static(' def identity() -> "Matrix3": ...')
        + [
            " def __getitem__(self, indices: Tuple[int, int]) -> Any: ...",
            " def __setitem__(self, indices: Tuple[int, int], value: Any) -> None: ...",
            ' def __neg__(self) -> "Matrix3": ...',
            " def __repr__(self) -> str: ...",
            " def __str__(self) -> str: ...",
        ]
        + _static(' def from_list(lst: Sequence[Sequence[Any]], unit: Optional[Any] = ...) -> "Matrix3": ...')
        + [" def to_list(self, unit: Optional[Any] = ...) -> List[List[Any]]: ..."]
        + _static(' def from_array(array: Any, unit: Optional[Any] = ...) -> "Matrix3": ...')
        + [" def to_array(self, unit: Optional[Any] = ...) -> Any: ..."]
        + _static(' def from_rows(rows: Sequence[Vector3]) -> "Matrix3": ...')
        + _static(' def from_columns(columns: Sequence[Vector3]) -> "Matrix3": ...')
        + [
            ' def transpose(self) -> "Matrix3": ...',
            " def get_column(self, index: int) -> Vector3: ...",
            " def set_column(self, index: int, value: Vector3) -> None: ...",
            " def get_row(self, index: int) -> Vector3: ...",
            " def set_row(self, index: int, value: Vector3) -> None: ...",
        ]
    )


def render_unit_stub(quantity: QuantityInfo) -> List[str]:
    """Render the unit enum of *quantity* as a stub class.

    Each parsed variant becomes a ``ClassVar`` of the unit class's own type.
    An enum with no parsed variants gets a ``pass`` body so the class is
    still syntactically valid.
    """
    unit_name = quantity.unit
    header = f"class {unit_name}:"
    if not quantity.unit_variants:
        return [header, " pass"]
    return [header] + [
        f' {member}: ClassVar["{unit_name}"]' for member in quantity.unit_variants
    ]


def _format_return_name(type_name: str) -> str:
    """Render a Rust return type for the stub: ``f64`` -> ``float``, else quoted.

    Unlike ``_format_result_annotation`` this never returns ``None``. Names
    not in the quantity set (e.g. aliases defined at the end of the stub)
    are still emitted as forward references.
    """
    if type_name == "f64":
        return "float"
    return f'"{type_name}"'


def _render_overloads(
    method: str,
    q: str,
    ops: Dict[str, str],
    quantity_names: Set[str],
) -> List[str]:
    """Render the ``@overload`` stack of binary dunder *method* on quantity *q*.

    The stack starts with a ``float`` operand returning *q*. It continues
    with one overload per quantity operand in *ops* whose result can be
    annotated, in sorted order. It ends with an ``Any -> Any`` catch-all.
    """
    lines = [" @overload", f' def {method}(self, rhs: float) -> "{q}": ...']
    for rhs in sorted(ops):
        if rhs not in quantity_names:
            continue
        res = _format_result_annotation(ops[rhs], quantity_names)
        if res is None:
            continue
        lines.append(" @overload")
        lines.append(f' def {method}(self, rhs: "{rhs}") -> {res}: ...')
    lines.append(" @overload")
    lines.append(f" def {method}(self, rhs: Any) -> Any: ...")
    return lines


def render_quantity_stub(
    quantity: QuantityInfo,
    sqrt_map: Dict[str, str],
    operator_maps: OperatorMaps,
    quantity_names: Set[str],
) -> List[str]:
    """Render the stub class for one quantity.

    Args:
        quantity: Parsed quantity (name, unit, constants, ``__rtruediv__`` type).
        sqrt_map: Quantity name -> Rust type returned by its ``sqrt``.
        operator_maps: ``*`` and ``/`` result tables keyed by operand names.
        quantity_names: Names rendered as stub classes; operands outside this
            set get no dedicated overload.

    Returns:
        The class's stub lines, without a trailing blank line.
    """
    q = quantity.name
    u = quantity.unit
    lines = [
        f"class {q}:",
        f' def __init__(self, value: float, unit: "{u}") -> None: ...',
        " @classmethod",
        f' def zero(cls: Type["{q}"]) -> "{q}": ...',
        f' def to(self, unit: "{u}") -> float: ...',
        f' def close_abs(self, other: "{q}", threshold: "{q}") -> bool: ...',
        f' def close_rel(self, other: "{q}", threshold: float) -> bool: ...',
        f' def __neg__(self) -> "{q}": ...',
        f' def __abs__(self) -> "{q}": ...',
        f' def __add__(self, rhs: "{q}") -> "{q}": ...',
        f' def __sub__(self, rhs: "{q}") -> "{q}": ...',
    ]
    lines += _render_overloads("__mul__", q, operator_maps.mul.get(q, {}), quantity_names)
    lines += _render_overloads("__truediv__", q, operator_maps.div.get(q, {}), quantity_names)
    lines.append(f' def __rmul__(self, rhs: float) -> "{q}": ...')
    lines.append(" def __repr__(self) -> str: ...")
    lines.append(" def __str__(self) -> str: ...")
    if quantity.rtruediv_return:
        # Map a Rust f64 to float; the old code emitted the undefined name "f64".
        rtruediv = _format_return_name(quantity.rtruediv_return)
        lines.append(f" def __rtruediv__(self, lhs: float) -> {rtruediv}: ...")
    for const_name in quantity.constants:
        lines.append(" @staticmethod")
        lines.append(f' def {const_name}() -> "{q}": ...')
    sqrt_return = sqrt_map.get(q)
    if sqrt_return:
        lines.append(f" def sqrt(self) -> {_format_return_name(sqrt_return)}: ...")
    return lines


def render_stub(
    quantities: Sequence[QuantityInfo],
    aliases: Sequence[Tuple[str, str, str, str]],
    sqrt_map: Dict[str, str],
    operator_maps: OperatorMaps,
) -> str:
    """Assemble the complete text of the generated ``__init__.pyi``.

    Layout, in order:
    - the generated-file banner and the ``typing`` import;
    - the fixed ``Vector3`` and ``Matrix3`` stubs;
    - for each quantity, its unit class then its quantity class;
    - module-level alias assignments.

    Sections are separated by blank lines, and the text always ends with a
    newline.
    """
    known_quantities = {info.name for info in quantities}
    chunks: List[str] = [
        "# This file is generated by scripts/generate_python_stubs.py.",
        "# Do not edit manually.",
        "",
        "from typing import Any, ClassVar, List, Optional, Sequence, Tuple, Type, overload",
        "",
        *render_vector3_stub(),
        "",
        *render_matrix3_stub(),
        "",
    ]
    for info in quantities:
        chunks += render_unit_stub(info)
        chunks.append("")
        chunks += render_quantity_stub(
            quantity=info,
            sqrt_map=sqrt_map,
            operator_maps=operator_maps,
            quantity_names=known_quantities,
        )
        chunks.append("")
    # Each alias re-exports both the quantity class and its unit class.
    chunks += [
        line
        for base_quantity, base_unit, alias_quantity, alias_unit in aliases
        for line in (f"{alias_quantity} = {base_quantity}", f"{alias_unit} = {base_unit}")
    ]
    if aliases:
        chunks.append("")
    return "\n".join(chunks)


def _print_diff(path: Path, expected: str, current: str) -> None:
diff = difflib.unified_diff(
current.splitlines(),
expected.splitlines(),
fromfile=f"{path} (current)",
tofile=f"{path} (expected)",
lineterm="",
)
for line in diff:
print(line)
def sync_file(path: Path, content: str, check: bool) -> bool:
    """Bring *path* in line with *content*, or report that it is not.

    If the file already holds *content*, nothing happens and False is
    returned.

    With *check* set, nothing is written. "File is out of date" is printed,
    followed by a diff when the file exists, and True is returned.

    Without *check*, missing parent directories are created, the file is
    (re)written, and False is returned.
    """
    current = path.read_text(encoding="utf-8") if path.exists() else None
    if current == content:
        return False
    if not check:
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(content, encoding="utf-8")
        return False
    print(f"File is out of date: {path}")
    if current is not None:
        _print_diff(path, content, current)
    return True


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Arguments to parse. ``None`` (the default) means ``sys.argv[1:]``,
            exactly as ``argparse.ArgumentParser.parse_args`` does.

    Returns:
        Namespace with a boolean ``check`` attribute.
    """
    # Without a module docstring ``__doc__`` is None and ``--help`` would show
    # no description, so fall back to an explicit one.
    description = __doc__ or (
        "Generate Python type stubs (__init__.pyi, py.typed) for the unitforge "
        "package from the Rust sources."
    )
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--check",
        action="store_true",
        help="Fail if generated files are missing or outdated.",
    )
    return parser.parse_args(argv)


def main() -> int:
    """Regenerate the Python stub files, or verify them with ``--check``.

    Paths are resolved relative to the parent of the directory that contains
    this script.

    Returns:
        The process exit code: 1 when ``--check`` finds a missing or outdated
        file, 0 otherwise.
    """
    args = parse_args()
    repo_root = Path(__file__).resolve().parents[1]
    src_root = repo_root / "src"
    quantities_dir = src_root / "quantities"
    out_pkg = repo_root / "python" / "unitforge"
    stub_content = render_stub(
        parse_quantities(quantities_dir),
        parse_aliases(quantities_dir),
        parse_sqrt_map(src_root),
        parse_operator_maps(src_root),
    )
    targets = (
        (out_pkg / "__init__.pyi", stub_content),
        (out_pkg / "py.typed", ""),
    )
    # Build the full list so every target is synced or checked, even after a
    # stale one is found.
    stale = [sync_file(target, content, args.check) for target, content in targets]
    return 1 if args.check and any(stale) else 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())