dapz 0.0.1

AI-friendly DAP compression proxy — token-efficient Debug Adapter Protocol proxy
Documentation
#!/usr/bin/env python3
"""
dapz 文档生成脚本 — SSOT Framework

从代码注释(/// / //!)生成 .gen.md 文档。
支持 --check 模式用于 CI 漂移检测。

SSOT 原则: 代码是唯一的真相源,文档都是生成物。

使用方法:
    python scripts/gen-docs.py           # 生成所有文档
    python scripts/gen-docs.py --check   # 检查文档是否过时
"""

import argparse
import hashlib
import re
import sys
from pathlib import Path


# Repository root: this script lives one directory below it (scripts/gen-docs.py).
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Output directories for the generated documentation, all under docs/.
DOCS_API, DOCS_REF, DOCS_SPECS = (
    PROJECT_ROOT / "docs" / section for section in ("api", "reference", "specs")
)


def find_rust_files(root: Path) -> list[Path]:
    """Return every ``.rs`` file below ``root/src`` and ``root/examples``.

    Directories that do not exist are skipped. The result is de-duplicated
    and sorted.
    """
    found = {
        rs_path
        for base in (root / "src", root / "examples")
        if base.exists()
        for rs_path in base.rglob("*.rs")
    }
    return sorted(found)


def extract_module_doc(content: str) -> str | None:
    """Extract //! module-level documentation."""
    lines = content.split("\n")
    doc_lines = []
    for line in lines:
        stripped = line.strip()
        if stripped.startswith("//!"):
            doc_lines.append(stripped[3:].strip())
        elif not stripped.startswith("//!") and doc_lines:
            break
    return "\n".join(doc_lines) if doc_lines else None


def _attr_bracket_delta(line: str) -> int:
    """Return the count of ``[`` minus ``]`` in *line*, ignoring brackets inside "..." literals."""
    code = re.sub(r'"(?:\\.|[^"\\])*"', '""', line)
    return code.count("[") - code.count("]")


def extract_item_docs(content: str) -> list[dict]:
    """Extract ``///`` item-level documentation paired with the item it documents.

    Consecutive ``///`` lines are collected. The next non-blank line that is
    neither a ``//`` comment nor an outer attribute becomes the item's
    definition line. Attributes (``#[...]``) may span several lines.

    Attributes are skipped because Rust places them between the doc comment
    and the item (e.g. ``#[derive(Debug)]``). Without skipping them, the
    attribute would be recorded instead of the item.

    A blank line or a plain ``//`` comment discards any pending doc lines.

    Returns:
        A list of ``{"doc": str, "code": str}`` dicts in source order.
    """
    items = []
    current_doc: list[str] = []
    attr_depth = 0  # unclosed '[' while inside a multi-line `#[...]` attribute
    for line in content.split("\n"):
        stripped = line.strip()
        if attr_depth > 0:
            # Continuation of a multi-line attribute: consume it, keep the doc.
            attr_depth += _attr_bracket_delta(stripped)
            continue
        if stripped.startswith("///"):
            current_doc.append(stripped[3:].strip())
        elif current_doc and stripped.startswith("#["):
            # Attribute between doc and item: skip it (and its continuation lines).
            attr_depth = max(_attr_bracket_delta(stripped), 0)
        elif current_doc and stripped and not stripped.startswith("//"):
            items.append({"doc": "\n".join(current_doc), "code": stripped})
            current_doc = []
        else:
            current_doc = []
    return items


def _api_page_stem(module_name: str, stem: str, stem_counts: dict[str, int]) -> str:
    """Pick the base name of a module's ``.gen.md`` API page.

    The bare file stem is kept when it is unique among generated pages, so
    existing page names stay stable. Two cases fall back to the relative
    module path with ``/`` replaced by ``-``, so no page overwrites another:

    - a stem shared by several files (e.g. ``mod.rs`` / ``main.rs`` in
      different directories);
    - the stem ``modules``, which is reserved for the index page.
    """
    if stem == "modules" or stem_counts.get(stem, 0) > 1:
        return module_name.replace("/", "-")
    return stem


def _render_module_index(modules: list[tuple[str, str]]) -> str:
    """Render the Markdown module index from ``(module_path, module_doc)`` pairs."""
    lines = [
        "# 模块索引",
        "",
        "> 自动从 `//!` 模块级注释生成。编辑源码注释后运行 `just gen-api-docs` 刷新。",
        "",
    ]
    for module_path, doc in modules:
        # Summarise with the first non-blank doc line; a leading bare `//!`
        # would otherwise leave the description empty.
        summary = next((ln for ln in doc.split("\n") if ln.strip()), "")
        lines.append(f"- **`{module_path}`**: {summary}")
    lines.append("")
    return "\n".join(lines)


def _render_api_page(module_name: str, items: list[dict]) -> str:
    """Render one module's Markdown API page from its extracted ``///`` items."""
    lines = [
        f"# API: `{module_name}`",
        "",
        "> 自动从 `///` 注释生成。编辑源码注释后运行 `just gen-api-docs` 刷新。",
        "",
    ]
    for item in items:
        lines.append(item["doc"])
        lines.append("")
        if item["code"]:
            lines.extend(["```rust", item["code"], "```", ""])
    return "\n".join(lines)


def generate_api_docs() -> dict[str, str]:
    """Generate API documentation from source comments.

    Produces two kinds of output:

    - ``docs/api/modules.gen.md``: an index of every Rust file with a
      non-empty ``//!`` module doc.
    - One ``docs/api/<name>.gen.md`` page per file with ``///`` item docs.
      ``<name>`` is chosen by ``_api_page_stem`` so pages never collide.

    Returns:
        Mapping of output path (as ``str``) to generated Markdown; empty when
        no Rust files are found.
    """
    outputs: dict[str, str] = {}
    rust_files = find_rust_files(PROJECT_ROOT)
    if not rust_files:
        return outputs

    # Read every file once; both the index and the per-module pages need it.
    modules: list[tuple[str, str]] = []
    pages: list[tuple[str, str, list[dict]]] = []  # (module_name, stem, items)
    for filepath in rust_files:
        content = filepath.read_text(encoding="utf-8")
        module_name = filepath.relative_to(PROJECT_ROOT).with_suffix("").as_posix()
        module_doc = extract_module_doc(content)
        if module_doc:
            modules.append((module_name, module_doc))
        items = extract_item_docs(content)
        if items:
            pages.append((module_name, filepath.stem, items))

    if modules:
        outputs[str(DOCS_API / "modules.gen.md")] = _render_module_index(modules)

    stem_counts: dict[str, int] = {}
    for _name, stem, _items in pages:
        stem_counts[stem] = stem_counts.get(stem, 0) + 1

    for module_name, stem, items in pages:
        page = _api_page_stem(module_name, stem, stem_counts)
        outputs[str(DOCS_API / f"{page}.gen.md")] = _render_api_page(module_name, items)

    return outputs


def generate_config_docs() -> dict[str, str]:
    """Generate the config reference from the ``pub struct`` items in ``src/config.rs``.

    Each brace-bodied struct becomes a ``##`` section containing its ``///``
    doc comment. That is followed by one bullet per ``pub`` field, showing
    name, type and, when present, the field's ``///`` doc.

    Single-line attributes (``#[derive(...)]``, ``#[serde(...)]``) between a
    doc comment and its item are skipped, so the doc is kept.

    Returns:
        ``{path_of_config.gen.md: markdown}``, or ``{}`` when ``src/config.rs``
        is missing or contains no matching struct.
    """
    outputs: dict[str, str] = {}
    config_path = PROJECT_ROOT / "src" / "config.rs"
    if not config_path.exists():
        return outputs

    content = config_path.read_text(encoding="utf-8")
    struct_pattern = re.compile(
        # Group 1 wraps the repetition so it captures *all* consecutive doc
        # lines; a bare `(///...\n)*` group would keep only the last one.
        r"((?:^[ \t]*///[^\n]*\n)*)"
        # Attribute lines allowed between the doc comment and the struct.
        r"(?:^[ \t]*#\[[^\n]*\n)*"
        # `[^{;(]*` keeps a unit/tuple struct from swallowing the next struct's body.
        r"^[ \t]*pub struct (\w+)[^{;(]*\{([^}]*)\}",
        re.MULTILINE,
    )
    # Matched per line, so generic types with commas (`HashMap<String, u32>`)
    # are kept whole; accepts `pub(crate)`-style visibility and raw identifiers.
    field_pattern = re.compile(r"^pub(?:\([^)]*\))?\s+((?:r#)?\w+)\s*:\s*(.+?)\s*,?$")

    header = [
        "# 配置参考",
        "",
        "> 自动从 `src/config.rs` 生成。编辑源码后运行 `just gen-config-docs` 刷新。",
        "",
    ]
    lines = list(header)
    for m in struct_pattern.finditer(content):
        struct_name, body = m.group(2), m.group(3)
        doc = "\n".join(ln.strip()[3:].strip() for ln in m.group(1).splitlines()).strip()
        lines.append(f"## `{struct_name}`")
        lines.append("")
        if doc:
            lines.extend(doc.split("\n"))
            lines.append("")

        field_doc: list[str] = []
        for raw in body.split("\n"):
            text = raw.strip()
            if text.startswith("///"):
                field_doc.append(text[3:].strip())
                continue
            if not text or text.startswith("#["):
                continue  # blank lines and attributes keep the pending field doc
            # Drop a trailing `// ...` comment before matching the field.
            field = field_pattern.match(text.split("//", 1)[0].strip())
            if field:
                f_name, f_type = field.group(1), field.group(2)
                f_doc = " ".join(part for part in field_doc if part)
                doc_text = f" — {f_doc}" if f_doc else ""
                lines.append(f"- **`{f_name}`**: `{f_type}`{doc_text}")
            field_doc = []
        lines.append("")

    # Emit only when at least one struct was rendered below the header.
    if len(lines) > len(header):
        outputs[str(DOCS_REF / "config.gen.md")] = "\n".join(lines)
    return outputs


def generate_error_docs() -> dict[str, str]:
    """Generate the error type reference from the ``pub enum`` items in ``src/error.rs``.

    Each ``pub enum`` whose header line ends in ``{`` becomes a ``##``
    section. The section lists the enum's unit and tuple variants, including
    their payload, followed by the variant's ``///`` doc when present.

    Returns:
        ``{path_of_error-types.gen.md: markdown}``, or ``{}`` when
        ``src/error.rs`` is missing or contains no matching enum.
    """
    outputs: dict[str, str] = {}
    error_path = PROJECT_ROOT / "src" / "error.rs"
    if not error_path.exists():
        return outputs

    content = error_path.read_text(encoding="utf-8")

    header = [
        "# 错误类型参考",
        "",
        "> 自动从 `src/error.rs` 生成。编辑源码后运行 `just gen-error-docs` 刷新。",
        "",
    ]
    lines_out = list(header)

    enum_start = re.compile(r"^pub enum (\w+)\s*\{")
    # A unit variant `Name` or a tuple variant `Name(...)` (one level of nested
    # parentheses allowed in the payload), with an optional trailing comma.
    variant_pat = re.compile(
        r"^\s*(\w+)(\([^)]*(?:\([^)]*\))?[^)]*\))?\s*,?\s*$"
    )
    # String literals are blanked before counting braces, so messages such as
    # #[error("unexpected {")] cannot unbalance the body depth.
    string_lit = re.compile(r'"(?:\\.|[^"\\])*"')

    src_lines = content.split("\n")
    i = 0
    while i < len(src_lines):
        m = enum_start.search(src_lines[i])
        if not m:
            i += 1
            continue
        lines_out.append(f"## `{m.group(1)}`")
        lines_out.append("")

        # The header line's `{` opened the body; read lines until it closes.
        depth = 1
        enum_body = []
        i += 1
        while i < len(src_lines) and depth > 0:
            code = string_lit.sub('""', src_lines[i])
            depth += code.count("{") - code.count("}")
            if depth > 0:
                enum_body.append(src_lines[i])
            i += 1

        pending_doc: list[str] = []
        for line in enum_body:
            stripped = line.strip()
            if stripped.startswith("///"):
                pending_doc.append(stripped[3:].strip())
                continue
            if not stripped or stripped.startswith("#[") or stripped.startswith("//"):
                continue  # blank lines, attributes and comments keep the pending doc
            v = variant_pat.search(stripped)
            if v:
                v_name = v.group(1)
                v_params = v.group(2) or ""
                v_doc = " ".join(part for part in pending_doc if part)
                suffix = f" — {v_doc}" if v_doc else ""
                lines_out.append(f"- **`{v_name}{v_params}`**{suffix}")
            pending_doc = []
        lines_out.append("")

    # Emit only when at least one enum was rendered below the header.
    if len(lines_out) > len(header):
        outputs[str(DOCS_REF / "error-types.gen.md")] = "\n".join(lines_out)
    return outputs


def content_hash(content: str) -> str:
    """Return the hex MD5 digest of *content* encoded as UTF-8.

    MD5 is used only as a non-cryptographic fingerprint. The
    ``usedforsecurity=False`` flag keeps this working on FIPS-restricted
    Python builds, where plain ``hashlib.md5()`` raises ``ValueError``.
    """
    return hashlib.md5(content.encode("utf-8"), usedforsecurity=False).hexdigest()


def write_outputs(outputs: dict[str, str], check: bool = False) -> bool:
    """Write generated files, or in check mode compare them with the files on disk.

    Every output is normalised to end with exactly one newline before it is
    written or compared. Check mode only reads the filesystem; it never
    creates directories or files.

    Args:
        outputs: Mapping of output path (as ``str``, under ``PROJECT_ROOT``)
            to file content.
        check: When true, report ``OK`` / ``DRIFT`` / ``MISSING`` per file
            instead of writing.

    Returns:
        ``True`` if every file matched (always ``True`` in write mode).
    """
    all_match = True
    for path_str, content in sorted(outputs.items()):
        path = Path(path_str)
        rel = path.relative_to(PROJECT_ROOT)
        cleaned = content.rstrip("\n") + "\n"

        if not check:
            # Directories are created only when writing; a CI check must not
            # modify the working tree.
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(cleaned, encoding="utf-8")
            print(f"  WRITTEN {rel}")
            continue

        if not path.exists():
            print(f"  MISSING {rel} — will be generated on write")
            all_match = False
        elif path.read_text(encoding="utf-8") == cleaned:
            print(f"  OK {rel}")
        else:
            print(f"  DRIFT {rel} — content differs")
            all_match = False

    return all_match


def main():
    """Run every doc generator, then write its output or, with ``--check``, verify it.

    In generate mode, a generator exception is reported and re-raised.

    In check mode, a generator exception is reported and the remaining
    generators still run. The script then exits with status 1, as it does on
    any drift, so CI cannot pass on a partial check.
    """
    parser = argparse.ArgumentParser(description="dapz 文档生成 (SSOT Framework)")
    parser.add_argument("--check", action="store_true", help="检查模式:只验证不写入")
    args = parser.parse_args()

    mode = "CHECK" if args.check else "GENERATE"
    print(f"[gen-docs] Mode: {mode}")

    generators = [
        ("API docs", generate_api_docs),
        ("Config docs", generate_config_docs),
        ("Error docs", generate_error_docs),
    ]

    all_outputs = {}
    failed: list[str] = []
    for name, gen_fn in generators:
        print(f"\n[{name}]")
        try:
            result = gen_fn()
        except Exception as e:
            print(f"  ERROR: {e}")
            if not args.check:
                raise
            # Keep checking the other generators, but remember the failure.
            failed.append(name)
            continue
        all_outputs.update(result)
        if not result:
            print("  (no source code found yet)")

    ok = True
    if all_outputs:
        print(
            f"\nTotal: {len(all_outputs)} files to {'check' if args.check else 'generate'}"
        )
        ok = write_outputs(all_outputs, check=args.check)
    elif not failed:
        print("\nNo source code found. This is expected before MVP implementation.")
        return

    if failed:
        # Only reachable in check mode; generate mode re-raised above.
        print(f"\nERROR: generator(s) failed: {', '.join(failed)}")
        sys.exit(1)

    if args.check and not ok:
        print(
            "\nERROR: Documentation drift detected! Run `just gen-docs` to regenerate."
        )
        sys.exit(1)
    elif not args.check:
        print("\nDone. Run `just gen-check` for CI drift detection.")
    else:
        print("\nAll documentation is up to date.")


# Entry point: run the generator only when executed as a script, not on import.
if __name__ == "__main__":
    main()