from __future__ import annotations
import argparse
import json
import re
import sys
from pathlib import Path
from typing import Any
# Root of the Rust sources: the "src" directory next to this script's parent folder.
SRC = Path(__file__).resolve().parent.parent.joinpath("src")
def parse_pyo3_signature(sig_str: str) -> list[dict[str, Any]]:
    """Split the body of a PyO3 ``signature = (...)`` attribute into arguments.

    The text is cut at top-level commas only. Commas nested inside
    ``()``/``[]``/``{}``, inside a double-quoted string, or inside a
    three-character char literal (for example ``sep=","`` or ``sep=','``)
    do not separate arguments.

    Args:
        sig_str: The text between the signature's outer parentheses.

    Returns:
        One record per non-empty fragment, in source order, each built by
        ``_parse_sig_arg``. Blank input and a literal ``"()"`` give ``[]``.
    """
    sig_str = sig_str.strip()
    if not sig_str or sig_str == "()":
        return []

    fragments: list[str] = []
    depth = 0
    start = 0
    i = 0
    n = len(sig_str)
    while i < n:
        c = sig_str[i]
        if c == '"':
            # Jump to the closing quote so punctuation inside the string
            # literal is ignored; a backslash escapes the next character.
            i += 1
            while i < n and sig_str[i] != '"':
                i += 2 if sig_str[i] == "\\" else 1
        elif c == "'" and sig_str[i + 2 : i + 3] == "'":
            # Char literal such as ',' or '(': skip its one-character payload.
            i += 2
        elif c in "([{":
            depth += 1
        elif c in ")]}":
            # Clamp so a stray closer cannot disable splitting for the rest
            # of the signature.
            depth = max(0, depth - 1)
        elif c == "," and depth == 0:
            fragments.append(sig_str[start:i])
            start = i + 1
        i += 1
    fragments.append(sig_str[start:])

    return [_parse_sig_arg(piece.strip()) for piece in fragments if piece.strip()]
def _parse_sig_arg(part: str) -> dict[str, Any]:
if "=" in part:
name, default = part.split("=", 1)
name = name.strip()
default = default.strip()
if default in ("None", "None)"):
default = "None"
elif default in ("True", "true"):
default = "True"
elif default in ("False", "false"):
default = "False"
return {"name": name, "default": default, "kind": "POSITIONAL_OR_KEYWORD"}
return {"name": part.strip(), "default": None, "kind": "POSITIONAL_OR_KEYWORD"}
def _rust_param_name(part: str) -> str | None:
    """Return the parameter name in one Rust parameter, or None for a receiver."""
    part = part.strip()
    if not part or part.startswith("&self") or part == "self":
        return None
    name = part.split(":")[0].strip().lstrip("&")
    if name.startswith("mut "):
        name = name[4:].strip()
    # Covers "&mut self", "mut self" and typed receivers like "self: PyRef<..>".
    return None if name == "self" else name


def parse_rust_fn_params(fn_sig: str) -> list[dict[str, Any]]:
    """Extract parameter names from a Rust ``fn name(...)`` signature.

    The parameter list runs from the parenthesis after the function name to
    its *matching* closing parenthesis, so tuple types and ``Fn(..)`` bounds
    inside a parameter do not cut the list short. If the parentheses never
    balance (a truncated signature), the text up to the last ``)`` is used.

    Parameters are separated at commas that are outside both parentheses and
    angle brackets. ``self`` receivers are omitted.

    Args:
        fn_sig: Rust source text containing a ``fn name(`` header.

    Returns:
        One ``{"name", "default": None, "kind": "POSITIONAL_OR_KEYWORD"}``
        dict per parameter, in order. ``[]`` when no header is found, no
        closing parenthesis follows it, or the parameter list is empty.
    """
    head = re.search(r"fn\s+\w+\s*\(", fn_sig)
    if not head:
        return []

    open_end = head.end()
    close = -1
    depth = 1
    for i in range(open_end, len(fn_sig)):
        ch = fn_sig[i]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                close = i
                break
    if close < 0:
        # Truncated input: use whatever precedes the last closing parenthesis.
        close = fn_sig.rfind(")")
        if close < open_end:
            return []

    params_str = fn_sig[open_end:close].strip()
    if not params_str:
        return []

    pieces: list[str] = []
    paren_depth = 0
    angle_depth = 0
    start = 0
    for i, ch in enumerate(params_str):
        if ch == "(":
            paren_depth += 1
        elif ch == ")":
            paren_depth -= 1
        elif ch == "<":
            angle_depth += 1
        elif ch == ">" and params_str[i - 1 : i] != "-":
            # The ">" of a "->" return arrow does not close a generic list.
            angle_depth = max(0, angle_depth - 1)
        elif ch == "," and paren_depth == 0 and angle_depth == 0:
            pieces.append(params_str[start:i])
            start = i + 1
    pieces.append(params_str[start:])

    names = (_rust_param_name(piece) for piece in pieces)
    return [
        {"name": name, "default": None, "kind": "POSITIONAL_OR_KEYWORD"}
        for name in names
        if name is not None
    ]
def extract_pyo3_signature_and_fn(
    text: str, match_start: int, match_end: int
) -> tuple[list[dict[str, Any]], str | None]:
    """Resolve the arguments and Rust name of the function at a match position.

    A ``#[pyo3(signature = (...))]`` attribute is looked for in the 250
    characters before ``match_start`` up to ``match_end``. When several are
    present, the one nearest the function is used. It is accepted only if it
    ends at most 100 characters before ``match_start`` and no other
    ``fn name(`` header lies between it and ``match_start``. Otherwise the
    arguments come from the Rust parameter list.

    Args:
        text: Full Rust source text.
        match_start: Offset where the caller's match for the function begins.
        match_end: Offset where that match ends.

    Returns:
        ``(args, fn_name)``. ``fn_name`` is the identifier of the first
        ``fn name(`` at or after ``match_start``, or ``None`` if there is none.
    """
    sig_re = re.compile(
        r"#\[\s*pyo3\s*\(\s*signature\s*=\s*\((.*?)\)\s*\)\s*\]",
        re.DOTALL,
    )
    fn_head_re = re.compile(r"fn\s+(\w+)\s*\(")

    search_start = max(0, match_start - 250)
    chunk = text[search_start:match_end]
    fn_offset_in_chunk = match_start - search_start

    # Several functions can fall inside the look-behind window; the attribute
    # belonging to this one is the last in the chunk, not the first.
    candidates = list(sig_re.finditer(chunk))
    sig_m = candidates[-1] if candidates else None

    args: list[dict[str, Any]] = []
    if sig_m is not None and (fn_offset_in_chunk - sig_m.end()) <= 100:
        # Empty when the attribute sits inside the caller's match itself.
        between = chunk[sig_m.end() : fn_offset_in_chunk]
        # Another function header in between means the attribute is not ours.
        if not fn_head_re.search(between):
            inner = re.sub(r"\s+", " ", sig_m.group(1).strip())
            args = parse_pyo3_signature(inner)

    fn_m = fn_head_re.search(text, match_start)
    fn_name = fn_m.group(1) if fn_m else None

    if not args and fn_m:
        # Walk to the matching ")" so nested parentheses in parameter types
        # stay inside the signature handed to the parameter parser.
        depth = 1
        end = fn_m.end()
        while end < len(text) and depth > 0:
            ch = text[end]
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            end += 1
        if depth > 0:
            # Unbalanced source: stop at the first ")" if there is one.
            first_close = text.find(")", fn_m.end())
            end = first_close + 1 if first_close >= 0 else fn_m.end()
        args = parse_rust_fn_params(text[fn_m.start() : end])

    return (args, fn_name)
def extract_module_functions(mod_rs: Path) -> list[dict[str, Any]]:
    """List the functions a PyO3 module file registers with ``m.add``.

    Signatures are first collected for every Rust ``fn py_*`` in the file,
    then each ``m.add("name", wrap_pyfunction!(py_fn, ...`` registration is
    paired with the signature of the function it wraps.

    Args:
        mod_rs: Path to the Rust module source file.

    Returns:
        ``{"name", "args", "kind": "function"}`` dicts sorted by name.
        Registered names starting with ``_`` are skipped and a repeated
        name keeps only its first registration. ``args`` is ``[]`` when the
        wrapped function was not found in the file.
    """
    source = mod_rs.read_text(encoding="utf-8", errors="replace")

    # Pass 1 matches functions together with any leading #[pyo3...] attribute
    # lines, and a later duplicate name overwrites an earlier one.
    # Pass 2 matches bare headers and only fills names pass 1 did not record.
    scan_passes = (
        (r"(?:#\[pyo3[^\]]*\]\s*\n\s*)*fn\s+(py_\w+)\s*\(", True),
        (r"fn\s+(py_\w+)\s*\(", False),
    )
    signatures: dict[str, list[dict[str, Any]]] = {}
    for pattern, overwrite in scan_passes:
        for hit in re.finditer(pattern, source):
            rust_name = hit.group(1)
            if overwrite or rust_name not in signatures:
                parsed, _ = extract_pyo3_signature_and_fn(
                    source, hit.start(), hit.end()
                )
                signatures[rust_name] = parsed

    registration_re = re.compile(
        r'm\.add\s*\(\s*["\']([^"\']+)["\']\s*,\s*wrap_pyfunction!\s*\(\s*(\w+)\s*,'
    )
    exported: dict[str, dict[str, Any]] = {}
    for registration in registration_re.finditer(source):
        public_name, rust_name = registration.groups()
        if public_name.startswith("_") or public_name in exported:
            continue
        exported[public_name] = {
            "name": public_name,
            "args": signatures.get(rust_name, []),
            "kind": "function",
        }

    return sorted(exported.values(), key=lambda entry: entry["name"])
def _impl_body(text: str, start: int) -> str:
    """Return the brace-delimited body whose opening brace ends at ``start``."""
    depth = 1
    i = start
    while i < len(text) and depth > 0:
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
        i += 1
    # When the braces balance, text[i - 1] is the closing brace and is left
    # out; an unterminated body is kept through to the end of the text.
    return text[start : i - 1] if depth == 0 else text[start:i]


def _method_signature(block: str, m: re.Match[str]) -> str:
    """Return the text from a method match through its parameter list's ")"."""
    depth = 1
    end = m.end()
    while end < len(block) and depth > 0:
        ch = block[end]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        end += 1
    if depth == 0:
        return block[m.start() : end]
    # Unbalanced parentheses: stop at the first ")" if there is one.
    first_close = block.find(")", m.end())
    if first_close >= 0:
        return block[m.start() : first_close + 1]
    return m.group(0) + "..."


def extract_impl_methods(file_path: Path, struct_name: str) -> list[dict[str, Any]]:
    """Collect the methods declared in every ``impl <struct_name> { ... }`` block.

    Each ``fn`` inside a matching impl body is reported under the name from a
    preceding ``#[pyo3(name = "...")]`` attribute when there is one, otherwise
    under its Rust name. Names that start and end with ``__`` are skipped.
    Arguments come from a ``#[pyo3(signature = (...))]`` attribute when
    present, otherwise from the Rust parameter list.

    Args:
        file_path: Rust source file to scan.
        struct_name: Rust type name that follows ``impl``.

    Returns:
        ``{"name", "args", "class"}`` dicts in source order, where ``class``
        is ``struct_name`` without a leading ``Py``.
    """
    text = file_path.read_text(encoding="utf-8", errors="replace")
    impl_re = re.compile(
        rf"(?:#\[pymethods\]\s*\n\s*)?impl\s+{re.escape(struct_name)}\s*\{{"
    )
    # Compiled once here rather than once per impl block.
    fn_block_re = re.compile(
        r'(?:#\[pyo3\s*\(\s*name\s*=\s*["\']([^"\']+)["\']\s*\)\]\s*\n\s*)?'
        r"(?:#\[pyo3\s*\(\s*signature\s*=\s*\((.*?)\)\s*\)\s*\]\s*\n\s*)?"
        r"fn\s+(\w+)\s*\(",
        re.DOTALL,
    )
    # Drop only a leading "Py"; a "Py" elsewhere in the name is part of it.
    class_name = struct_name[2:] if struct_name.startswith("Py") else struct_name

    methods: list[dict[str, Any]] = []
    for impl_match in impl_re.finditer(text):
        block = _impl_body(text, impl_match.end())
        for m in fn_block_re.finditer(block):
            py_name = m.group(1) or m.group(3)
            if py_name.startswith("__") and py_name.endswith("__"):
                continue
            sig_inner = m.group(2)
            if sig_inner is not None:
                sig_inner = re.sub(r"\s+", " ", sig_inner.strip())
                args = parse_pyo3_signature(sig_inner)
            else:
                args = parse_rust_fn_params(_method_signature(block, m))
            methods.append({"name": py_name, "args": args, "class": class_name})
    return methods
def extract_class_methods_from_file(
    file_path: Path, class_map: list[tuple[str, str]]
) -> dict[str, list[dict[str, Any]]]:
    """Extract the methods of several Rust types from one source file.

    Args:
        file_path: Rust source file handed to ``extract_impl_methods``.
        class_map: ``(rust_name, out_name)`` pairs, where ``rust_name`` is the
            type to look up and ``out_name`` is the key used in the result.

    Returns:
        A mapping from each ``out_name`` to its methods, each reduced to
        ``{"name", "args"}``.
    """

    def name_and_args(method: dict[str, Any]) -> dict[str, Any]:
        return {"name": method["name"], "args": method.get("args", [])}

    return {
        out_name: [
            name_and_args(method)
            for method in extract_impl_methods(file_path, rust_name)
        ]
        for rust_name, out_name in class_map
    }
def main() -> int:
    """Extract the API from the Rust sources under ``SRC`` and write it as JSON.

    Module-level functions are read from ``python/mod.rs``. Class methods are
    read from ``column.rs``, ``dataframe.rs`` and ``session.rs`` in the same
    directory; a file that does not exist is skipped.

    Command line:
        ``--output`` / ``-o``: path of the JSON file to write
        (default ``docs/robin_api_from_source.json``).

    Returns:
        ``0`` after the file is written, ``1`` when ``python/mod.rs`` is
        missing under ``SRC``.
    """
    parser = argparse.ArgumentParser(
        description="Extract robin-sparkless API from Rust source"
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=Path("docs/robin_api_from_source.json"),
        help="Output JSON path",
    )
    args = parser.parse_args()

    python_dir = SRC / "python"
    mod_rs = python_dir / "mod.rs"
    if not mod_rs.exists():
        print("Source not found. Run from repo root.", file=sys.stderr)
        return 1

    # Source file -> (Rust struct, exported class name) pairs. The order here
    # fixes the key order of "classes" in the output.
    class_sources: list[tuple[str, list[tuple[str, str]]]] = [
        ("column.rs", [("PyColumn", "Column")]),
        (
            "dataframe.rs",
            [
                ("PyDataFrame", "DataFrame"),
                ("PyGroupedData", "GroupedData"),
                ("PyDataFrameStat", "DataFrameStat"),
                ("PyDataFrameNa", "DataFrameNa"),
            ],
        ),
        (
            "session.rs",
            [
                ("PySparkSession", "SparkSession"),
                ("PySparkSessionBuilder", "SparkSessionBuilder"),
            ],
        ),
    ]
    classes: dict[str, list[dict[str, Any]]] = {}
    for file_name, class_map in class_sources:
        rust_file = python_dir / file_name
        if rust_file.exists():
            classes.update(extract_class_methods_from_file(rust_file, class_map))

    result: dict[str, Any] = {
        "source": "robin_source",
        "functions": extract_module_functions(mod_rs),
        "classes": classes,
    }

    out_path = args.output
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding so the output does not depend on the platform default.
    with open(out_path, "w", encoding="utf-8") as out_file:
        json.dump(result, out_file, indent=2)

    n_funcs = len(result["functions"])
    n_with_args = sum(1 for fn in result["functions"] if fn.get("args"))
    n_classes = sum(len(v) for v in result["classes"].values())
    print(
        f"Wrote {out_path} ({n_funcs} functions, {n_with_args} with args; {n_classes} class methods)"
    )
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())