from __future__ import annotations
import re
import sys
import textwrap
from pathlib import Path
CFG_TEST_RE = re.compile(r"^[ \t]*#\[cfg\(test\)\][ \t]*$")
MOD_LINE_RE = re.compile(
r"^[ \t]*(?:pub(?:\([^)]*\))?[ \t]+)?mod[ \t]+([A-Za-z_][A-Za-z0-9_]*)[ \t]*\{[ \t]*$"
)
def find_matching_brace(source: str, open_idx: int) -> int | None:
assert source[open_idx] == "{"
depth = 0
i = open_idx
n = len(source)
while i < n:
c = source[i]
if c == "/" and i + 1 < n and source[i + 1] == "*":
depth_block = 1
i += 2
while i < n and depth_block > 0:
if source[i] == "/" and i + 1 < n and source[i + 1] == "*":
depth_block += 1
i += 2
elif source[i] == "*" and i + 1 < n and source[i + 1] == "/":
depth_block -= 1
i += 2
else:
i += 1
continue
if c == "/" and i + 1 < n and source[i + 1] == "/":
nl = source.find("\n", i)
i = n if nl == -1 else nl + 1
continue
if c == "r" and i + 1 < n and source[i + 1] in ('"', "#"):
j = i + 1
hashes = 0
while j < n and source[j] == "#":
hashes += 1
j += 1
if j < n and source[j] == '"':
terminator = '"' + "#" * hashes
end = source.find(terminator, j + 1)
if end == -1:
return None
i = end + len(terminator)
continue
if c == "b" and i + 1 < n and source[i + 1] in ('"', "r"):
i += 1
continue
if c == '"':
j = i + 1
while j < n:
if source[j] == "\\":
j += 2
continue
if source[j] == '"':
break
j += 1
i = j + 1
continue
if c == "'":
if i + 1 < n and (source[i + 1].isalpha() or source[i + 1] == "_"):
j = i + 1
while j < n and (source[j].isalnum() or source[j] == "_"):
j += 1
if j < n and source[j] == "'":
i = j + 1
else:
i = j
continue
j = i + 1
if j < n and source[j] == "\\":
j += 2
while j < n and source[j] != "'":
j += 1
else:
j = i + 2
if j < n and source[j] == "'":
i = j + 1
continue
i += 1
continue
if c == "{":
depth += 1
elif c == "}":
depth -= 1
if depth == 0:
return i
i += 1
return None
def extract_file(path: Path) -> tuple[bool, str | None]:
source = path.read_text(encoding="utf-8")
lines = source.splitlines(keepends=True)
byte_offsets: list[int] = []
off = 0
for ln in lines:
byte_offsets.append(off)
off += len(ln)
cfg_indices = [i for i, ln in enumerate(lines) if CFG_TEST_RE.match(ln)]
if not cfg_indices:
return False, None
if len(cfg_indices) > 1:
return False, f"multiple #[cfg(test)] attributes — manual review"
i = cfg_indices[0]
j = i + 1
while j < len(lines):
stripped = lines[j].strip()
if stripped == "" or stripped.startswith("//"):
j += 1
continue
break
if j >= len(lines):
return False, "#[cfg(test)] at EOF"
m = MOD_LINE_RE.match(lines[j])
if not m:
return False, f"unrecognized item after #[cfg(test)]: {lines[j].strip()[:60]!r}"
mod_name = m.group(1)
if mod_name != "tests":
return False, f"#[cfg(test)] mod `{mod_name}` (expected `tests`) — manual review"
mod_line_start = byte_offsets[j]
brace_open = source.find("{", mod_line_start)
if brace_open == -1:
return False, "no `{` on mod tests line"
brace_close = find_matching_brace(source, brace_open)
if brace_close is None:
return False, "no matching `}` for mod tests"
body = source[brace_open + 1 : brace_close]
if body.startswith("\n"):
body = body[1:]
body = body.rstrip() + "\n"
body = textwrap.dedent(body)
sibling = path.with_name(path.stem + "_tests.rs")
if sibling.exists():
return False, f"sibling test file already exists: {sibling}"
rel_src = path.name
header = (
f"//! Inline unit tests for `{rel_src}`, extracted into this\n"
f"//! sibling file so they can be excluded from the public mirror.\n"
f"//! Imported by `{rel_src}` via `#[cfg(test)] #[path = ...] mod tests;`.\n"
f"//!\n"
f"//! This file is gated by `#[cfg(test)]` at the parent module, so it\n"
f"//! is only compiled during `cargo test`.\n"
f"\n"
)
sibling.write_text(header + body, encoding="utf-8")
indent_match = re.match(r"[ \t]*", lines[i])
indent = indent_match.group(0) if indent_match else ""
replacement = (
f"{indent}#[cfg(test)]\n"
f"{indent}#[path = \"{sibling.name}\"]\n"
f"{indent}mod tests;\n"
)
strip_start = byte_offsets[i]
strip_end = brace_close + 1
if strip_end < len(source) and source[strip_end] == "\n":
strip_end += 1
new_source = source[:strip_start] + replacement + source[strip_end:]
path.write_text(new_source, encoding="utf-8")
return True, None
def main() -> int:
if len(sys.argv) != 2:
print(f"usage: {sys.argv[0]} <ROOT_DIR>", file=sys.stderr)
return 1
root = Path(sys.argv[1])
if not root.is_dir():
print(f"not a directory: {root}", file=sys.stderr)
return 1
extracted = 0
skipped = 0
errors: list[str] = []
for rs in sorted(root.rglob("*.rs")):
if "target" in rs.parts:
continue
if rs.stem.endswith("_tests"):
continue
changed, err = extract_file(rs)
if err:
errors.append(f"{rs}: {err}")
continue
if changed:
extracted += 1
print(f" extracted {rs} → {rs.with_name(rs.stem + '_tests.rs').name}")
else:
skipped += 1
print(f"\nExtracted: {extracted}, skipped: {skipped}")
if errors:
print("\nErrors:", file=sys.stderr)
for e in errors:
print(f" {e}", file=sys.stderr)
return 1
return 0
if __name__ == "__main__":
sys.exit(main())