import os
import re
from typing import Optional
from .core._imperative_rt.common import CompNode, DeviceType
from .core._imperative_rt.common import (
get_cuda_compute_capability as _get_cuda_compute_capability,
)
from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config
from .core._imperative_rt.common import what_is_xpu as _what_is_xpu
from .core._imperative_rt.utils import _try_coalesce_all_free_memory
# Names exported by ``from <this module> import *``.
# NOTE(review): is_cambricon_available, is_atlas_available, is_rocm_available
# and what_is_xpu are defined in this module but are not listed here, so they
# are not picked up by a star import — confirm whether that is intentional.
__all__ = [
    "is_cuda_available",
    "get_device_count",
    "get_default_device",
    "set_default_device",
    "get_mem_status_bytes",
    "get_cuda_compute_capability",
    "get_allocated_memory",
    "get_reserved_memory",
    "get_max_reserved_memory",
    "get_max_allocated_memory",
    "reset_max_memory_stats",
    "set_prealloc_config",
    "coalesce_free_memory",
    "DeviceType",
]
class _stream_helper:
def __init__(self):
self.stream = 1
def get_next(self):
out = self.stream
self.stream = self.stream + 1
return out
_sh = _stream_helper()
def _valid_device(inp):
if isinstance(inp, str) and re.match(
"^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$", inp
):
return True
return False
def _str2device_type(type_str: str, allow_unspec: bool = True):
type_str = type_str.upper()
if type_str == "CPU":
return DeviceType.CPU
elif type_str == "GPU" or type_str == "CUDA":
return DeviceType.CUDA
elif type_str == "CAMBRICON":
return DeviceType.CAMBRICON
elif type_str == "ATLAS":
return DeviceType.ATLAS
elif type_str == "ROCM" or type_str == "AMDGPU":
return DeviceType.ROCM
else:
assert (
allow_unspec and type_str == "XPU"
), "device type can only be cpu, gpu or xpu"
return DeviceType.UNSPEC
_device_type_set = {"cpu", "gpu", "xpu", "rocm"}
def get_device_count(device_type: str) -> int:
assert device_type in _device_type_set, "device must be one of {}".format(
_device_type_set
)
device_type = _str2device_type(device_type)
return CompNode._get_device_count(device_type, False)
def is_cuda_available() -> bool:
    """Return whether ``CompNode`` reports at least one CUDA (gpu) device."""
    return 0 < CompNode._get_device_count(_str2device_type("gpu"), False)
def is_cambricon_available() -> bool:
    """Return whether ``CompNode`` reports at least one Cambricon device."""
    return 0 < CompNode._get_device_count(_str2device_type("cambricon"), False)
def is_atlas_available() -> bool:
    """Return whether ``CompNode`` reports at least one Atlas device."""
    return 0 < CompNode._get_device_count(_str2device_type("atlas"), False)
def is_rocm_available() -> bool:
    """Return whether ``CompNode`` reports at least one ROCm device."""
    return 0 < CompNode._get_device_count(_str2device_type("rocm"), False)
def set_default_device(device: str = "xpux"):
    """Set the default device via ``CompNode._set_default_device``.

    Args:
        device: device name such as ``"gpu0"``, ``"cpu1:2"`` or ``"xpux"``.
            It must be accepted by ``_valid_device``.

    Raises:
        AssertionError: if ``device`` is not a valid device name. This is
            raised explicitly (not via ``assert``) so an invalid name is still
            rejected under ``python -O`` instead of reaching the runtime.
    """
    if not _valid_device(device):
        raise AssertionError("Invalid device name {}".format(device))
    CompNode._set_default_device(device)
def get_default_device() -> str:
    """Return the current default device name as reported by ``CompNode``."""
    current_default = CompNode._get_default_device()
    return current_default
def get_mem_status_bytes(device: Optional[str] = None):
    """Return a ``(total, free)`` memory pair for ``device``.

    The pair is unpacked from the ``CompNode``'s ``get_mem_status_bytes``
    attribute, so anything other than a 2-element result raises.

    Args:
        device: device name; ``None`` means the current default device.
    """
    target = get_default_device() if device is None else device
    total_bytes, free_bytes = CompNode(target).get_mem_status_bytes
    return total_bytes, free_bytes
def get_cuda_compute_capability(device: int, device_type=DeviceType.CUDA) -> int:
    """Return the compute capability of device number ``device``.

    This is a direct forwarder to the runtime binding. How the returned int
    encodes major/minor versions is not visible here — TODO confirm.
    """
    capability = _get_cuda_compute_capability(device, device_type)
    return capability
def get_allocated_memory(device: Optional[str] = None):
    """Return the ``get_used_memory`` value of ``device``'s ``CompNode``.

    The unit is presumably bytes — TODO confirm against the runtime binding.

    Args:
        device: device name; ``None`` means the current default device.
    """
    node = CompNode(get_default_device() if device is None else device)
    return node.get_used_memory
def get_reserved_memory(device: Optional[str] = None):
    """Return the ``get_reserved_memory`` value of ``device``'s ``CompNode``.

    The unit is presumably bytes — TODO confirm against the runtime binding.

    Args:
        device: device name; ``None`` means the current default device.
    """
    node = CompNode(get_default_device() if device is None else device)
    return node.get_reserved_memory
def get_max_reserved_memory(device: Optional[str] = None):
    """Return the ``get_max_reserved_memory`` value of ``device``'s ``CompNode``.

    The unit is presumably bytes — TODO confirm against the runtime binding.

    Args:
        device: device name; ``None`` means the current default device.
    """
    node = CompNode(get_default_device() if device is None else device)
    return node.get_max_reserved_memory
def get_max_allocated_memory(device: Optional[str] = None):
    """Return the ``get_max_used_memory`` value of ``device``'s ``CompNode``.

    The unit is presumably bytes — TODO confirm against the runtime binding.

    Args:
        device: device name; ``None`` means the current default device.
    """
    node = CompNode(get_default_device() if device is None else device)
    return node.get_max_used_memory
def reset_max_memory_stats(device: Optional[str] = None):
    """Call ``CompNode.reset_max_memory_stats`` for ``device``.

    Args:
        device: device name; ``None`` means the current default device.
    """
    target = get_default_device() if device is None else device
    CompNode.reset_max_memory_stats(target)
# At import time, take the default device from the MGE_DEFAULT_DEVICE
# environment variable, falling back to "xpux" when it is unset.
set_default_device(os.environ.get("MGE_DEFAULT_DEVICE", "xpux"))
def set_prealloc_config(
    alignment: int = 1,
    min_req: int = 32 * 1024 * 1024,
    max_overhead: int = 0,
    growth_factor: float = 2.0,
    device_type=DeviceType.CUDA,
):
    """Validate and forward a pre-allocation configuration to the runtime.

    Parameter meanings below are inferred from their names and should be
    confirmed against the runtime binding.

    Args:
        alignment: presumably the allocation alignment in bytes; must be > 0.
        min_req: presumably the minimum request size in bytes; must be > 0.
            Defaults to 32 MiB.
        max_overhead: presumably the maximum overhead above the requested
            size in bytes; must be >= 0.
        growth_factor: presumably the growth ratio for successive requests;
            must be >= 1.
        device_type: ``DeviceType`` the configuration applies to.

    Raises:
        AssertionError: if any argument is out of range. The checks are
            explicit raises (not ``assert``) so they are not stripped under
            ``python -O`` and invalid values never reach the native call.
            ``not x > 0`` is used instead of ``x <= 0`` so NaN is rejected,
            as it was by the original asserts.
    """
    if not alignment > 0:
        raise AssertionError("alignment must be > 0, got {!r}".format(alignment))
    if not min_req > 0:
        raise AssertionError("min_req must be > 0, got {!r}".format(min_req))
    if not max_overhead >= 0:
        raise AssertionError(
            "max_overhead must be >= 0, got {!r}".format(max_overhead)
        )
    if not growth_factor >= 1:
        raise AssertionError(
            "growth_factor must be >= 1, got {!r}".format(growth_factor)
        )
    _set_prealloc_config(alignment, min_req, max_overhead, growth_factor, device_type)
def what_is_xpu():
    """Return the lower-cased ``name`` of the value from the runtime's ``what_is_xpu``.

    Presumably this is the concrete device type that ``"xpu"`` resolves to
    (e.g. ``"cuda"`` or ``"cpu"``) — TODO confirm.
    """
    resolved = _what_is_xpu()
    return resolved.name.lower()
def coalesce_free_memory():
    """Forward to the runtime's ``_try_coalesce_all_free_memory`` and return its result."""
    outcome = _try_coalesce_all_free_memory()
    return outcome