megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
#!/usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

"""
purpose: test whether a model has good parallelism; a model with good
parallelism gets a large performance improvement from running multithreaded.
"""
import argparse
import logging
import os
import re
import subprocess

# Remote test device used by this script. main() reads "login_name", "ip" and
# "port" to open the ssh connection; "name" and "thread_number" are not read
# by the code in this file.
device = dict(
    name="hwmt40p",
    login_name="hwmt40p-K9000-maliG78",
    ip="box86.br.megvii-inc.com",
    port=2200,
    thread_number=3,
)

class SshConnector:
    """Copy files to, and run commands on, a remote host via ``rsync``/``ssh``.

    Connection parameters are supplied through :meth:`setup`; until it has
    been called the attributes below are ``None``.

    Local processes are started from argument lists (no local shell), so
    source paths containing spaces and remote commands containing quotes,
    ``$`` or backticks reach rsync/ssh unmodified. As a consequence, globs
    and ``~`` in ``copy`` sources are no longer expanded by a local shell.
    """

    # Connection parameters; populated by setup().
    ip = None
    port = None
    login_name = None

    def setup(self, login_name, ip, port):
        """Store the remote login name, host address and ssh port."""
        self.ip = ip
        self.login_name = login_name
        self.port = port

    def _remote(self):
        """Return the ``user@host`` string for the configured device."""
        return "{}@{}".format(self.login_name, self.ip)

    def _rsync_argv(self, src, dst_dir):
        """Return the rsync argument list that copies *src* into *dst_dir*."""
        return [
            "rsync",
            "--progress",
            "-a",
            # rsync splits this single argument into the remote-shell command.
            "-e",
            "ssh -p {}".format(self.port),
            src,
            "{}:{}".format(self._remote(), dst_dir),
        ]

    def _ssh_argv(self, sub_cmd):
        """Return the ssh argument list that runs *sub_cmd* on the remote host."""
        return ["ssh", "-p", str(self.port), self._remote(), sub_cmd]

    def copy(self, src_list, dst_dir):
        """Copy every path in *src_list* to *dst_dir* on the remote host.

        Args:
            src_list: list of local source paths.
            dst_dir: destination directory on the remote host.

        Raises:
            AssertionError: if *src_list* is not a list or *dst_dir* not a str.
            subprocess.CalledProcessError: if an rsync invocation fails.
        """
        assert isinstance(src_list, list), "code issue happened!!"
        assert isinstance(dst_dir, str), "code issue happened!!"
        for src in src_list:
            argv = self._rsync_argv(src, dst_dir)
            logging.debug("ssh run cmd: %s", argv)
            subprocess.check_call(argv)

    def cmd(self, cmd):
        """Run each command string in *cmd* on the remote host, in order.

        Args:
            cmd: list of command strings; each one is run in its own ssh
                invocation, so shell state does not carry over between them.

        Returns:
            The UTF-8 decoded stdout of all commands, concatenated in order.

        Raises:
            AssertionError: if *cmd* is not a list.
            subprocess.CalledProcessError: if a command exits non-zero.
        """
        assert isinstance(cmd, list), "code issue happened!!"
        outputs = []
        for sub_cmd in cmd:
            argv = self._ssh_argv(sub_cmd)
            logging.debug("ssh run cmd: %s", argv)
            outputs.append(subprocess.check_output(argv).decode("utf-8"))
        return "".join(outputs)


def get_finally_bench_resulut_from_log(raw_log) -> float:
    """Return the last ``avg_time`` value, in milliseconds, found in *raw_log*.

    The log is searched for entries of the form ``avg_time=<number>ms``
    (e.g. ``avg_time=23.331ms`` -> ``23.331``). The entry may be followed by
    a space, a newline or the end of the log.

    Args:
        raw_log: text output of one or more benchmark runs.

    Returns:
        The number from the last ``avg_time`` entry as a float.

    Raises:
        ValueError: if *raw_log* contains no ``avg_time=<number>ms`` entry.
    """
    # Capture only the number so no slicing on fixed offsets is needed.
    values = re.findall(
        r"avg_time=\s*([-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)\s*ms", raw_log
    )
    if not values:
        raise ValueError("no 'avg_time=<number>ms' entry found in benchmark log")
    # Several runs may be concatenated; the final one is the result of interest.
    return float(values[-1])


def main():
    """Benchmark a model on the test device with 1, 2 and 4 threads.

    Parses ``--model_file`` and ``--load_and_run_file``, copies both files
    into a workspace directory on the device described by ``device``, runs the
    benchmark binary once per thread count and prints the speed-up of the
    2-thread and 4-thread runs relative to the single-thread run.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model_file", help="model file", required=True)
    parser.add_argument(
        "--load_and_run_file", help="path for load_and_run", required=True
    )
    args = parser.parse_args()

    def remote_name(local_path):
        # The file is copied into the workspace directory, so on the device it
        # is addressed by its base name. Both "/" and "\\" are treated as
        # separators so POSIX and Windows style paths work.
        return re.split(r"[\\/]", local_path)[-1]

    # init device
    ssh = SshConnector()
    ssh.setup(device["login_name"], device["ip"], device["port"])

    # create test dir and upload the benchmark binary and the model
    workspace = "model_parallelism_test"
    ssh.cmd(["mkdir -p {}".format(workspace)])
    ssh.copy([args.load_and_run_file], workspace)
    ssh.copy([args.model_file], workspace)

    m = remote_name(args.model_file)
    # Run the binary that was actually uploaded; fall back to the
    # conventional name if the given path has no base name.
    runner = remote_name(args.load_and_run_file) or "load_and_run"

    result = []
    thread_number = [1, 2, 4]
    for b in thread_number:
        # First invocation: a single iteration with --fast-run (presumably
        # fills fastrun.cache for the timed run -- confirm against
        # load_and_run's documentation).
        cmd1 = "cd {} && ./{} {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format(
            workspace, runner, m, b
        )
        # Second invocation: the timed run (20 iterations after 5 warm-ups).
        cmd2 = "cd {} && ./{} {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format(
            workspace, runner, m, b
        )
        raw_log = ssh.cmd([cmd1, cmd2])
        # The parser returns the last avg_time in the combined output, which
        # is expected to come from the timed run (cmd2).
        ret = get_finally_bench_resulut_from_log(raw_log)
        logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret))
        result.append(ret)

    # Speed-up factors relative to the single-thread average time.
    thread_2 = result[0] / result[1]
    thread_4 = result[0] / result[2]
    if thread_2 > 1.6 or thread_4 > 3.0:
        print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
    else:
        print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))


if __name__ == "__main__":
    # Script entry point: send DEBUG-level, timestamped records to the root
    # logger before running the benchmark so every ssh/rsync command is shown.
    LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
    DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
    logging.basicConfig(
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        level=logging.DEBUG,
    )

    main()