megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
#!/usr/bin/env python3
# This file is part of MegBrain.
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import argparse
import os
import re
import subprocess
import tempfile
from functools import partial
from multiprocessing import Manager

from tqdm.contrib.concurrent import process_map

# Change the working directory to the MegBrain root dir (the parent of the
# directory that contains this script), so relative paths given on the
# command line are resolved from there.
os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

# Shared list (a multiprocessing.Manager proxy). process_file() appends the
# name of every file whose formatted output differs from its current contents;
# main() inspects it once all files have been processed.
# NOTE(review): process_file() reaches this proxy through the module global,
# which presumably relies on worker processes inheriting it (fork-style
# start) -- confirm before running where fresh interpreters are spawned.
failed_files = Manager().list()


def process_file(file, clang_format, write):
    """Run clang-format on one source file and either rewrite or check it.

    The file content is piped to ``clang_format`` on stdin. ``MGB_DEFINE...``
    constructs are temporarily rewritten into ``class MGB_DEFINE...{`` before
    formatting and restored afterwards (presumably so that clang-format lays
    them out like class definitions).

    Args:
        file: Path of the source file to format.
        clang_format: Name or path of the clang-format executable.
        write: If true, replace ``file`` with the formatted text. Otherwise
            print a diff against the formatted text and, when they differ,
            record ``file`` in the module-level ``failed_files`` list.

    Raises:
        subprocess.CalledProcessError: If clang-format exits with a non-zero
            status.
        RuntimeError: If ``diff`` reports trouble (exit status 2) in check
            mode.
    """
    # Read explicitly as UTF-8 (and close the handle): the text is sent to
    # clang-format and written back as UTF-8, so decoding with the locale
    # encoding could corrupt non-ASCII content.
    with open(file, "r", encoding="utf-8") as src_file:
        source = src_file.read()

    # `MGB_DEFINE... // {`  ->  `class MGB_DEFINE... {`
    source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source)
    # `MGB_DEFINE...   \` (not directly after `#define `)  ->  `class MGB_DEFINE...{\`
    # `count` remembers whether the inverse rewrite is needed after formatting.
    source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source)

    result = subprocess.check_output(
        [
            clang_format,
            "-style=file",
            "-verbose",
            # The content comes from stdin; this tells clang-format which
            # file name to assume for it.
            "-assume-filename={}".format(file),
        ],
        input=source.encode("utf-8"),
    )

    result = result.decode("utf-8")
    # Undo the temporary rewrites, in reverse order.
    if count:
        result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2       \\", result)
    result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result)

    if write:
        # NamedTemporaryFile creates the file with mode 0600; remember the
        # permission bits of the original so the replacement keeps them.
        original_mode = os.stat(file).st_mode & 0o7777
        # Write next to the target so the final replace stays on one
        # filesystem.
        with tempfile.NamedTemporaryFile(
            dir=os.path.dirname(os.path.abspath(file)), delete=False
        ) as tmp_file:
            tmp_file.write(result.encode("utf-8"))
        try:
            os.chmod(tmp_file.name, original_mode)
            # os.replace overwrites an existing destination on every platform.
            os.replace(tmp_file.name, file)
        except OSError:
            # Do not leave a stray temporary file beside the source.
            os.unlink(tmp_file.name)
            raise
    else:
        ret_code = subprocess.run(
            ["diff", "--color=always", file, "-"],
            input=result.encode("utf-8"),
        ).returncode

        # man diff: 0 for same, 1 for different, 2 if trouble.
        if ret_code == 2:
            raise RuntimeError("format process (without overwrite) failed")
        if ret_code != 0:
            print(file)
            # Only mutated, never rebound, so no `global` statement is needed.
            failed_files.append(file)


def main():
    """Command-line entry point: format or check the given files/directories.

    Parses the command line, collects the files to process (directories are
    searched recursively for C/C++/CUDA sources), verifies that the configured
    clang-format reports the supported version, and then runs process_file()
    on every file through process_map().

    Raises:
        ValueError: If a given path is neither a file nor a directory.
        RuntimeError: If the clang-format version is not the supported one, or
            if any file was recorded in ``failed_files`` as not properly
            formatted.
    """
    parser = argparse.ArgumentParser(
        # The advertised version must match the one enforced below.
        description="Format source files using clang-format, eg: `./tools/format.py src -w`. "
        "Require clang-format version == 12.0.1",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "path", nargs="+", help="file name or path based on MegBrain root dir."
    )
    parser.add_argument(
        "-w",
        "--write",
        action="store_true",
        help="use formatted file to replace original file.",
    )
    parser.add_argument(
        "--clang-format",
        default=os.getenv("CLANG_FORMAT", "clang-format"),
        help="clang-format executable name; it can also be "
        "modified via the CLANG_FORMAT environment var",
    )
    args = parser.parse_args()

    # Extensions picked up when a directory is searched.
    format_type = (".cpp", ".c", ".h", ".cu", ".cuh", ".inl")

    def getfiles(path):
        """Recursively list regular, non-symlink source files below path."""
        found = []
        for entry in os.listdir(path):
            entry = os.path.join(path, entry)
            if os.path.isdir(entry):
                found.extend(getfiles(entry))
            elif (
                os.path.isfile(entry)
                and not os.path.islink(entry)
                and os.path.splitext(entry)[1] in format_type
            ):
                found.append(entry)
        return found

    files = []
    for path in args.path:
        if os.path.isdir(path):
            files.extend(getfiles(path))
        elif os.path.isfile(path):
            # Files named explicitly are taken as-is, without the extension
            # filter applied to directory searches.
            files.append(path)
        else:
            raise ValueError("Invalid path {}".format(path))

    # check version, we only support 12.0.1 now
    version = subprocess.check_output([args.clang_format, "--version"]).decode("utf-8")

    need_version = "12.0.1"
    if need_version not in version:
        print('We only support {} now, please install {} version, find version: {}'
                .format(need_version, need_version, version))
        raise RuntimeError('clang-format version not equal {}'.format(need_version))

    # Fan the per-file work out through process_map; every file gets the same
    # clang-format executable and write flag.
    process_map(
        partial(process_file, clang_format=args.clang_format, write=args.write),
        files,
        chunksize=10,
    )

    # process_file() records badly formatted files here when not writing.
    if failed_files:
        raise RuntimeError("above files are not properly formatted!")


# Run the formatter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()