import argparse
import hashlib
import re
import sys
from pathlib import Path
# Project root: two levels up from this script's resolved location
# (the script sits one directory below the root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Output directories for the generated Markdown files.
DOCS_API = PROJECT_ROOT.joinpath("docs", "api")
DOCS_REF = PROJECT_ROOT.joinpath("docs", "reference")
DOCS_SPECS = PROJECT_ROOT.joinpath("docs", "specs")
def find_rust_files(root: Path) -> list[Path]:
    """Return every ``*.rs`` file below ``root/src`` and ``root/examples``.

    Subdirectories that do not exist are skipped. The result is
    de-duplicated and sorted.
    """
    found: set[Path] = set()
    for sub in ("src", "examples"):
        base = root / sub
        if base.exists():
            found.update(base.rglob("*.rs"))
    return sorted(found)
def extract_module_doc(content: str) -> str | None:
lines = content.split("\n")
doc_lines = []
for line in lines:
stripped = line.strip()
if stripped.startswith("//!"):
doc_lines.append(stripped[3:].strip())
elif not stripped.startswith("//!") and doc_lines:
break
return "\n".join(doc_lines) if doc_lines else None
def extract_item_docs(content: str) -> list[dict]:
    """Pair each run of ``///`` outer doc comments with the item line it documents.

    Returns a list of ``{"doc": str, "code": str}`` dicts. ``doc`` is the
    newline-joined doc text with markers stripped. ``code`` is the first
    following non-empty, non-comment, non-attribute line, stripped.

    Single-line ``#[...]`` attributes between the docs and the item (e.g.
    ``#[derive(Debug)]``) are skipped, so the real item line is captured.
    A blank line or a plain ``//`` comment discards pending docs.
    Attributes spanning several lines are not recognized.
    """
    items: list[dict] = []
    pending: list[str] = []
    for line in content.split("\n"):
        stripped = line.strip()
        # In Rust, `////` (four or more slashes) is an ordinary comment, not a doc comment.
        if stripped.startswith("///") and not stripped.startswith("////"):
            pending.append(stripped[3:].strip())
        elif pending and stripped.startswith("#["):
            # Attribute between docs and item: keep the docs attached to the item.
            continue
        elif pending and stripped and not stripped.startswith("//"):
            items.append({"doc": "\n".join(pending), "code": stripped})
            pending = []
        else:
            pending = []
    return items
def _render_module_index(modules: list[dict]) -> str:
    """Render the module index page: one bullet per module with its first doc line."""
    lines = [
        "# 模块索引",
        "",
        "> 自动从 `//!` 模块级注释生成。编辑源码注释后运行 `just gen-api-docs` 刷新。",
        "",
    ]
    for module in modules:
        first_line = module["doc"].split("\n")[0] if module["doc"] else ""
        lines.append(f"- **`{module['path']}`**: {first_line}")
    lines.append("")
    return "\n".join(lines)


def _render_item_page(module_name: str, items: list[dict]) -> str:
    """Render one per-file API page from its ``///``-documented items."""
    lines = [
        f"# API: `{module_name}`",
        "",
        "> 自动从 `///` 注释生成。编辑源码注释后运行 `just gen-api-docs` 刷新。",
        "",
    ]
    for item in items:
        lines.append(item["doc"])
        lines.append("")
        if item["code"]:
            lines.extend(["```rust", item["code"], "```", ""])
    return "\n".join(lines)


def _api_page_stem(stem: str, module_name: str, stem_counts: dict[str, int]) -> str:
    """Choose the output file stem for a per-file API page.

    The bare file stem is kept when it is unique, so existing output names
    are unchanged. A stem shared by several files (e.g. multiple ``mod.rs``),
    or the reserved ``modules`` stem that would overwrite the index page,
    falls back to the relative module path with ``/`` replaced by ``.``.
    """
    if stem == "modules" or stem_counts.get(stem, 0) > 1:
        return module_name.replace("/", ".")
    return stem


def generate_api_docs() -> dict[str, str]:
    """Build Markdown API pages from the Rust sources under src/ and examples/.

    Returns a mapping of output path (as ``str``) to file content:
    ``modules.gen.md`` lists every file that has a ``//!`` module doc, and one
    ``<stem>.gen.md`` page is produced per file with ``///``-documented items.
    Returns an empty dict when no Rust files are found.
    """
    outputs: dict[str, str] = {}
    rust_files = find_rust_files(PROJECT_ROOT)
    if not rust_files:
        return outputs

    modules: list[dict] = []
    pages: list[tuple[Path, str, list[dict]]] = []
    # Read each file once; both the index and the per-file pages need its text.
    for filepath in rust_files:
        content = filepath.read_text(encoding="utf-8")
        module_name = filepath.relative_to(PROJECT_ROOT).with_suffix("").as_posix()
        module_doc = extract_module_doc(content)
        if module_doc:
            modules.append({"path": module_name, "name": filepath.stem, "doc": module_doc})
        items = extract_item_docs(content)
        if items:
            pages.append((filepath, module_name, items))

    if modules:
        outputs[str(DOCS_API / "modules.gen.md")] = _render_module_index(modules)

    # Count stems among pages so that colliding names are disambiguated
    # instead of silently overwriting each other in `outputs`.
    stem_counts: dict[str, int] = {}
    for filepath, _, _ in pages:
        stem_counts[filepath.stem] = stem_counts.get(filepath.stem, 0) + 1

    for filepath, module_name, items in pages:
        stem = _api_page_stem(filepath.stem, module_name, stem_counts)
        outputs[str(DOCS_API / f"{stem}.gen.md")] = _render_item_page(module_name, items)
    return outputs
def _cfg_docs_above(lines: list[str], line_no: int) -> list[str]:
    """Return the ``///`` doc lines directly above ``lines[line_no]``, in order.

    ``#[...]`` attribute lines between the docs and the item are skipped, so
    ``#[derive(...)]`` / ``#[serde(...)]`` no longer hide the documentation.
    The walk stops at the first line that is neither a doc comment nor an
    attribute (blank line, code, or plain ``//`` comment).
    """
    docs: list[str] = []
    for raw in reversed(lines[:line_no]):
        text = raw.strip()
        # `////` is an ordinary comment in Rust, not a doc comment.
        if text.startswith("///") and not text.startswith("////"):
            docs.append(text[3:].strip())
        elif not text.startswith("#["):
            break
    docs.reverse()
    return docs


def _cfg_brace_body(text: str, start: int) -> str:
    """Return ``text[start:]`` up to (not including) the ``}`` that closes depth 1.

    *start* is the index just after an opening ``{``. If no matching brace
    is found, the rest of *text* is returned.
    """
    depth = 1
    for idx in range(start, len(text)):
        ch = text[idx]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:idx]
    return text[start:]


def _cfg_field_type(body: str, start: int) -> str:
    """Return the field type starting at ``body[start]``, whitespace-normalized.

    Scanning stops at the first comma outside ``<>``/``()``/``[]`` nesting,
    at a ``//`` comment, or at the end of *body*, so generic types such as
    ``HashMap<String, u32>`` are kept whole.
    """
    depth = 0
    end = len(body)
    for idx in range(start, len(body)):
        ch = body[idx]
        if body.startswith("//", idx):
            end = idx
            break
        if ch in "<([":
            depth += 1
        elif ch in ")]" or (ch == ">" and body[idx - 1] != "-"):
            # `->` in fn-pointer / Fn-trait types is not a closing angle bracket.
            depth -= 1
        elif ch == "," and depth <= 0:
            end = idx
            break
    return " ".join(body[start:end].split())


def generate_config_docs() -> dict[str, str]:
    """Render the ``pub struct`` definitions of ``src/config.rs`` as Markdown.

    Each struct gets a ``##`` heading, its ``///`` docs (all lines), and one
    bullet per ``pub`` field showing the field type and doc text.

    Returns ``{path: content}`` for ``docs/reference/config.gen.md``. Returns
    an empty dict when the source file is missing or defines no
    ``pub struct`` with a braced body.
    """
    outputs: dict[str, str] = {}
    config_path = PROJECT_ROOT / "src" / "config.rs"
    if not config_path.exists():
        return outputs
    content = config_path.read_text(encoding="utf-8")
    content_lines = content.split("\n")
    # The header must start a line and reach `{` before any `;`, which
    # excludes tuple/unit structs instead of running into the next struct.
    struct_pattern = re.compile(r"^[ \t]*pub struct (\w+)[^{;]*\{", re.MULTILINE)
    field_pattern = re.compile(r"\bpub\s+(\w+)\s*:")
    lines = [
        "# 配置参考",
        "",
        "> 自动从 `src/config.rs` 生成。编辑源码后运行 `just gen-config-docs` 刷新。",
        "",
    ]
    found_struct = False
    for m in struct_pattern.finditer(content):
        found_struct = True
        struct_line = content.count("\n", 0, m.start())
        doc = "\n".join(_cfg_docs_above(content_lines, struct_line)).strip()
        lines.append(f"## `{m.group(1)}`")
        lines.append("")
        if doc:
            lines.extend(doc.split("\n"))
            lines.append("")
        body = _cfg_brace_body(content, m.end())
        body_lines = body.split("\n")
        for f in field_pattern.finditer(body):
            line_start = body.rfind("\n", 0, f.start()) + 1
            if body[line_start:f.start()].lstrip().startswith("//"):
                continue  # "pub x:" text inside a comment, not a real field
            field_line = body.count("\n", 0, f.start())
            f_doc = " ".join(d for d in _cfg_docs_above(body_lines, field_line) if d)
            f_type = _cfg_field_type(body, f.end())
            doc_text = f" — {f_doc}" if f_doc else ""
            lines.append(f"- **`{f.group(1)}`**: `{f_type}`{doc_text}")
        lines.append("")
    # Emit only if at least one struct was documented; the previous
    # `lines[3:]` test was always true because the header has four lines.
    if found_struct:
        outputs[str(DOCS_REF / "config.gen.md")] = "\n".join(lines)
    return outputs
def _err_brace_delta(line: str) -> int:
    """Return the net ``{``/``}`` nesting change contributed by *line*."""
    return line.count("{") - line.count("}")


def _err_strip_line_comment(text: str) -> str:
    """Remove a trailing ``//`` comment from *text* and strip whitespace."""
    cut = text.find("//")
    if cut >= 0:
        text = text[:cut]
    return text.strip()


def generate_error_docs() -> dict[str, str]:
    """List the variants of every ``pub enum`` in ``src/error.rs`` as Markdown.

    Unit and tuple variants are shown as written (e.g. ``Io(std::io::Error)``).
    Struct-like variants are shown as ``Name { .. }``. Attribute and comment
    lines are ignored, as are lines nested inside a struct-like variant.
    Variants declared on the enum's header line itself are not listed.

    Returns ``{path: content}`` for ``docs/reference/error-types.gen.md``, or
    an empty dict when the file is missing or contains no ``pub enum``.
    """
    outputs: dict[str, str] = {}
    error_path = PROJECT_ROOT / "src" / "error.rs"
    if not error_path.exists():
        return outputs
    content = error_path.read_text(encoding="utf-8")
    lines_out = [
        "# 错误类型参考",
        "",
        "> 自动从 `src/error.rs` 生成。编辑源码后运行 `just gen-error-docs` 刷新。",
        "",
    ]
    # Allow indentation and generics / where-clauses between the name and `{`.
    enum_start = re.compile(r"^\s*pub enum (\w+)[^{;]*\{")
    # Unit or tuple variant (one level of nested parens), optional trailing comma.
    variant_pat = re.compile(
        r"^\s*(\w+)(\([^)]*(?:\([^)]*\))?[^)]*\))?\s*,?\s*$"
    )
    # Opening line of a struct-like variant: `Name {`.
    struct_variant_pat = re.compile(r"^(\w+)\s*\{")
    src_lines = content.split("\n")
    found_enum = False
    i = 0
    while i < len(src_lines):
        m = enum_start.search(src_lines[i])
        if not m:
            i += 1
            continue
        found_enum = True
        lines_out.append(f"## `{m.group(1)}`")
        lines_out.append("")
        # Start from the header line's own net braces, so a one-line
        # `pub enum E { A }` closes at once instead of swallowing the file.
        depth = _err_brace_delta(src_lines[i])
        i += 1
        while i < len(src_lines) and depth > 0:
            line = src_lines[i]
            line_depth = depth  # nesting depth at the start of this line
            depth += _err_brace_delta(line)
            i += 1
            if line_depth != 1:
                continue  # inside a struct-like variant's fields
            stripped = line.strip()
            if stripped.startswith("#["):
                continue
            stripped = _err_strip_line_comment(stripped)
            if depth <= 0:
                # This line closes the enum; drop the closing brace.
                cut = stripped.rfind("}")
                if cut >= 0:
                    stripped = stripped[:cut].strip()
            if not stripped:
                continue
            sv = struct_variant_pat.match(stripped)
            if sv:
                lines_out.append(f"- **`{sv.group(1)} {{ .. }}`**")
                continue
            v = variant_pat.search(stripped)
            if v:
                lines_out.append(f"- **`{v.group(1)}{v.group(2) or ''}`**")
        lines_out.append("")
    # Emit only if an enum was found; the previous `lines_out[3:]` test was
    # always true because the header has four lines.
    if found_enum:
        outputs[str(DOCS_REF / "error-types.gen.md")] = "\n".join(lines_out)
    return outputs
def content_hash(content: str) -> str:
    """Return the hex MD5 digest of *content* encoded as UTF-8.

    MD5 is used only as a change fingerprint, not for security.
    ``usedforsecurity=False`` keeps it usable on FIPS-restricted Python builds,
    where plain ``hashlib.md5()`` raises. The digest value is unchanged.
    """
    return hashlib.md5(content.encode("utf-8"), usedforsecurity=False).hexdigest()
def write_outputs(outputs: dict[str, str], check: bool = False) -> bool:
    """Write generated files, or compare them against disk when *check* is true.

    Args:
        outputs: Mapping of output path (as ``str``) to file content. Paths
            must lie under ``PROJECT_ROOT``; otherwise ``relative_to`` raises
            ``ValueError``.
        check: When true, nothing on disk is created or modified. Each file
            is reported as OK, DRIFT or MISSING.

    Returns:
        True if every file matched in check mode; always True in write mode.
    """
    all_match = True
    for path_str, content in sorted(outputs.items()):
        path = Path(path_str)
        rel = path.relative_to(PROJECT_ROOT)
        # Normalize to exactly one trailing newline so output is stable.
        cleaned = content.rstrip("\n") + "\n"
        if not check:
            # Create directories only in write mode; check mode must stay read-only.
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(cleaned, encoding="utf-8")
            print(f" WRITTEN {rel}")
        elif not path.exists():
            print(f" MISSING {rel} — will be generated on write")
            all_match = False
        elif path.read_text(encoding="utf-8") == cleaned:
            print(f" OK {rel}")
        else:
            print(f" DRIFT {rel} — content differs")
            all_match = False
    return all_match
def main():
    """Command-line entry point: run all generators, then write or verify outputs.

    ``--check`` compares the generated content with files on disk without
    writing anything. In check mode the process exits with status 1 when any
    output is missing or drifted, or when a generator raised. In generate
    mode a generator exception is reported and then re-raised.
    """
    parser = argparse.ArgumentParser(description="dapz 文档生成 (SSOT Framework)")
    parser.add_argument("--check", action="store_true", help="检查模式:只验证不写入")
    args = parser.parse_args()
    mode = "CHECK" if args.check else "GENERATE"
    print(f"[gen-docs] Mode: {mode}")
    generators = [
        ("API docs", generate_api_docs),
        ("Config docs", generate_config_docs),
        ("Error docs", generate_error_docs),
    ]
    all_outputs = {}
    # Generators that raised in check mode. They must fail the CI run rather
    # than be silently ignored, which previously allowed exit status 0.
    failed: list[str] = []
    for name, gen_fn in generators:
        try:
            print(f"\n[{name}]")
            result = gen_fn()
            all_outputs.update(result)
            if not result:
                print(" (no source code found yet)")
        except Exception as e:
            print(f" ERROR: {e}")
            if not args.check:
                raise
            failed.append(name)
    if not all_outputs and not failed:
        print("\nNo source code found. This is expected before MVP implementation.")
        return
    ok = True
    if all_outputs:
        print(
            f"\nTotal: {len(all_outputs)} files to {'check' if args.check else 'generate'}"
        )
        ok = write_outputs(all_outputs, check=args.check)
    if not args.check:
        print("\nDone. Run `just gen-check` for CI drift detection.")
        return
    if not ok:
        print(
            "\nERROR: Documentation drift detected! Run `just gen-docs` to regenerate."
        )
    if failed:
        print(f"\nERROR: Generator(s) failed: {', '.join(failed)}")
    if not ok or failed:
        sys.exit(1)
    print("\nAll documentation is up to date.")
if __name__ == "__main__":
    # main() returns None on success, which SystemExit maps to exit status 0;
    # failures inside main() call sys.exit(1) themselves.
    raise SystemExit(main())