from __future__ import annotations
import re
import sys
from pathlib import Path
# Repository root, resolved as the grandparent directory of this script;
# the Rust sources that get rewritten live under REPO/src.
REPO = Path(__file__).resolve().parent.parent
SRC = REPO / "src"

# Names of the crate's top-level modules. A child module/item carrying one of
# these names is never rewritten, because ``crate::<name>`` then genuinely
# refers to the top-level module rather than to the child.
TOP_LEVEL_MODULES = {
    "adapters",
    "backends",
    "cli",
    "codegen",
    "core",
    "docs",
    "e2e",
    "extract",
    "publish",
    "readme",
    "scaffold",
    "snippets",
}

# A public item declared at the start of a line in a ``mod.rs``.
#   group 1: optional restricted visibility, e.g. ``(crate)``, ``(super)``,
#            ``(in crate::x)``
#   group 2: the item's name (for ``pub use path as alias`` the alias)
# Qualifier prefixes (``const fn``, ``async fn``, ``unsafe fn``,
# ``extern "C" fn``) and ``static mut`` are skipped so the real name is
# captured. The ``\b`` after the keyword keeps ``pub const fn_ptr: ...`` from
# being split into keyword ``fn`` + name ``_ptr``.
_PUB_ITEM = re.compile(
    r"^pub(\s*\([^)]*\))?\s+"
    r'(?:(?:const|async|unsafe|extern(?:\s+"[^"]*")?)\s+)*'
    r"(?:(?:fn|struct|enum|union|trait|type|const|static(?:\s+mut)?|mod)\b"
    r"|use\s+[A-Za-z_:][\w:]*\s+as\b)"
    r"\s*([A-Za-z_][A-Za-z0-9_]*)",
    re.MULTILINE,
)

# A single-path re-export ``pub[(...)] use a::b::Name;`` or
# ``pub[(...)] use a::b::Name as Alias;``; group 1 is the last path segment
# (the original name, not the alias).
_PUB_USE = re.compile(
    r"^pub(?:\s*\([^)]*\))?\s+use\s+"
    r"[A-Za-z_:][\w:]*::([A-Za-z_][A-Za-z0-9_]*)(?:\s*;|\s+as\b)",
    re.MULTILINE,
)

# First segment of a ``pub use [self::|crate::]name...;`` statement.
# NOTE(review): not referenced by any function in this script; kept as-is.
_PUB_USE_GLOB = re.compile(
    r"^pub\s+use\s+(?:self::|crate::)?([A-Za-z_][A-Za-z0-9_]*)(?:::[^;]*)?\s*;",
    re.MULTILINE,
)
def child_modules(module_dir: Path) -> set[str]:
    """Return the names that ``crate::<name>`` could denote inside *module_dir*.

    Names are collected from:

    * each ``*.rs`` file directly inside *module_dir* (its stem), except
      ``mod.rs`` and ``lib.rs``;
    * each direct subdirectory that contains a ``mod.rs`` (its name), except
      one named ``templates``;
    * if *module_dir*/``mod.rs`` exists, the public item names matched by
      ``_PUB_ITEM`` (group 2) and the re-exported names matched by
      ``_PUB_USE`` (group 1).

    Args:
        module_dir: Directory of the module to inspect.

    Returns:
        The set of candidate child names (order is irrelevant).
    """
    # ``templates`` is excluded by name regardless of whether it is a file or
    # a directory, exactly like the module-root files ``mod.rs``/``lib.rs``.
    skipped = {"mod.rs", "lib.rs", "templates"}
    names: set[str] = set()
    for p in module_dir.iterdir():
        if p.name in skipped:
            continue
        if p.is_file() and p.suffix == ".rs":
            names.add(p.stem)
        elif p.is_dir() and (p / "mod.rs").exists():
            names.add(p.name)

    mod_rs = module_dir / "mod.rs"
    if mod_rs.exists():
        # Rust sources are UTF-8; don't depend on the platform's locale encoding.
        text = mod_rs.read_text(encoding="utf-8")
        names.update(m.group(2) for m in _PUB_ITEM.finditer(text))
        names.update(m.group(1) for m in _PUB_USE.finditer(text))
    return names
def rewrite_module(module_dir: Path, module_name: str) -> int:
    """Qualify ``crate::<child>`` paths under *module_dir* with *module_name*.

    Every ``*.rs`` file below *module_dir* (recursively) is scanned. A path
    whose first segment after ``crate::`` is one of ``child_modules(module_dir)``
    -- and is not also in ``TOP_LEVEL_MODULES`` -- is rewritten to
    ``crate::<module_name>::<child>``. Three forms are handled:

    * ``crate::child::rest``  -> ``crate::<module_name>::child::rest``
    * ``crate::child`` when not followed by ``::``
    * items of a single-level ``[pub[(...)]] use crate::{...};`` group whose
      first segment is a child; the group's other items and its original
      layout (line breaks, spacing, trailing comma) are preserved.

    A file is written back (and counted) only if its text actually changed.

    Args:
        module_dir: Directory of the module whose sources are rewritten.
        module_name: Path of that module from the crate root, e.g. ``"core"``
            or ``"backends::<name>"``.

    Returns:
        Number of files that were modified.
    """
    safe_children = child_modules(module_dir) - TOP_LEVEL_MODULES
    if not safe_children:
        return 0

    # Longest names first; the trailing ``::`` / ``\b`` already prevents
    # prefix collisions, this just keeps the alternation unambiguous.
    alternation = "|".join(
        re.escape(c) for c in sorted(safe_children, key=len, reverse=True)
    )
    pattern_path = re.compile(rf"\bcrate::({alternation})::")
    pattern_end = re.compile(rf"\bcrate::({alternation})\b(?!::)")
    # ``[^{}]*`` deliberately rejects nested groups such as
    # ``use crate::{a::{X, Y}, b};`` -- those are left untouched.
    use_brace_block = re.compile(
        r"^[ \t]*(?:pub(?:\s*\([^)]*\))?\s+)?use\s+crate::\{(?P<body>[^{}]*)\};",
        re.MULTILINE,
    )
    leading_ident = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
    prefix = f"{module_name}::"

    def rewrite_brace_block(match: re.Match[str]) -> str:
        """Prefix the child items of one ``use crate::{...};`` group in place."""
        items = match.group("body").split(",")
        changed = False
        for i, item in enumerate(items):
            stripped = item.lstrip()
            head = leading_ident.match(stripped)
            if head is not None and head.group(0) in safe_children:
                # Insert the prefix after the item's leading whitespace so
                # the group's original formatting survives.
                offset = len(item) - len(stripped)
                items[i] = item[:offset] + prefix + stripped
                changed = True
        whole = match.group(0)
        if not changed:
            # Untouched groups are returned verbatim: no spurious reformat.
            return whole
        start = match.start("body") - match.start()
        end = match.end("body") - match.start()
        return whole[:start] + ",".join(items) + whole[end:]

    count = 0
    # NOTE(review): rglob also walks a ``templates/`` subdirectory, which
    # child_modules excludes from the child names; confirm those files are
    # meant to be rewritten as well.
    for rs in module_dir.rglob("*.rs"):
        text = rs.read_text(encoding="utf-8")
        # Brace groups first: their output is ``crate::{...}``, which the two
        # path patterns below can no longer match, so nothing is prefixed twice.
        new = use_brace_block.sub(rewrite_brace_block, text)
        new = pattern_path.sub(rf"crate::{module_name}::\1::", new)
        new = pattern_end.sub(rf"crate::{module_name}::\1", new)
        if new != text:
            rs.write_text(new, encoding="utf-8")
            count += 1
    return count
def main() -> None:
    """Run ``rewrite_module`` over every known module directory under ``SRC``.

    Directories directly under ``SRC`` are visited in sorted order. The
    ``backends`` directory is not rewritten as a whole; each of its
    subdirectories is rewritten as module ``backends::<name>`` instead. Any
    other directory is rewritten only if its name is in ``TOP_LEVEL_MODULES``.
    A line is printed for every module in which files changed, followed by
    a grand total.
    """

    def targets():
        # Lazily yield (directory, module path, report label) so directory
        # listings happen in the same order as the rewrites.
        for top in sorted(SRC.iterdir()):
            if not top.is_dir():
                continue
            if top.name == "backends":
                for sub in sorted(top.iterdir()):
                    if sub.is_dir():
                        yield sub, f"backends::{sub.name}", f"backends/{sub.name}"
            elif top.name in TOP_LEVEL_MODULES:
                yield top, top.name, top.name

    grand_total = 0
    for directory, module_path, label in targets():
        changed_files = rewrite_module(directory, module_path)
        if changed_files:
            print(f"{label}: rewrote {changed_files} files")
        grand_total += changed_files
    print(f"\nTOTAL: rewrote {grand_total} files")


if __name__ == "__main__":
    main()