#include "shared_library.h"
#include "filesystem/api.h"
#include "mutex"
#include "triton/common/logging.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif
#ifdef _WIN32
#define NOMINMAX
#include <Windows.h>
#else
#include <dlfcn.h>
#endif
namespace triton { namespace core {
// File-scope mutex shared by every SharedLibrary object in this translation
// unit. It is locked in SharedLibrary::Acquire() and unlocked in
// SharedLibrary::~SharedLibrary().
static std::mutex mu_;
// Creates a SharedLibrary object and stores it in '*slib'.
//
// The file-scope mutex 'mu_' is locked here and is intentionally still
// locked when this function returns successfully: it is unlocked by
// ~SharedLibrary(). Other threads calling Acquire() therefore block until
// the object stored in '*slib' has been destroyed.
//
// slib: receives ownership of the newly created object.
// Returns Status::Success.
Status
SharedLibrary::Acquire(std::unique_ptr<SharedLibrary>* slib)
{
  // Hold the lock through a guard while the object is being created so that
  // an exception thrown by 'new' cannot leave 'mu_' locked with no
  // SharedLibrary object around to unlock it.
  std::unique_lock<std::mutex> guard(mu_);
  slib->reset(new SharedLibrary());

  // From here on the new object's destructor is responsible for unlocking
  // 'mu_', so detach the guard without unlocking.
  (void)guard.release();
  return Status::Success;
}
// Unlocks the file-scope mutex 'mu_'. Acquire() locks 'mu_' right before it
// creates a SharedLibrary object, so destroying that object is what releases
// the lock.
// NOTE(review): std::mutex has to be unlocked by the thread that locked it,
// so this presumably expects the object to be destroyed on the thread that
// called Acquire() -- confirm against callers.
SharedLibrary::~SharedLibrary()
{
mu_.unlock();
}
// Passes 'path' to SetDllDirectory() on Windows builds. On every other
// platform the function does nothing and reports success.
//
// path: directory handed to SetDllDirectory().
// Returns Status::Success, or a NOT_FOUND status carrying the system error
// text when SetDllDirectory() fails.
Status
SharedLibrary::SetLibraryDirectory(const std::string& path)
{
#ifdef _WIN32
  LOG_VERBOSE(1) << "SetLibraryDirectory: path = " << path;

  const BOOL dir_set = SetDllDirectory(path.c_str());
  if (dir_set == 0) {
    const DWORD error_code = GetLastError();
    const DWORD format_flags = FORMAT_MESSAGE_ALLOCATE_BUFFER |
                               FORMAT_MESSAGE_FROM_SYSTEM |
                               FORMAT_MESSAGE_IGNORE_INSERTS;

    // With FORMAT_MESSAGE_ALLOCATE_BUFFER the API allocates the message and
    // writes its address through the buffer argument, which is why the
    // address of 'message' is passed. The buffer is freed with LocalFree().
    LPSTR message = nullptr;
    const size_t message_len = FormatMessageA(
        format_flags, nullptr, error_code,
        MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
        reinterpret_cast<LPSTR>(&message), 0, nullptr);
    const std::string reason(message, message_len);
    LocalFree(message);

    return Status(
        Status::Code::NOT_FOUND,
        "unable to set dll path " + path + ": " + reason);
  }
#endif

  return Status::Success;
}
// Calls SetDllDirectory() with a null argument on Windows builds, undoing a
// directory installed through SetLibraryDirectory(). On every other platform
// the function does nothing and reports success.
//
// Returns Status::Success, or a NOT_FOUND status carrying the system error
// text when SetDllDirectory() fails.
Status
SharedLibrary::ResetLibraryDirectory()
{
#ifdef _WIN32
  LOG_VERBOSE(1) << "ResetLibraryDirectory";

  const BOOL dir_reset = SetDllDirectory(nullptr);
  if (dir_reset == 0) {
    const DWORD error_code = GetLastError();
    const DWORD format_flags = FORMAT_MESSAGE_ALLOCATE_BUFFER |
                               FORMAT_MESSAGE_FROM_SYSTEM |
                               FORMAT_MESSAGE_IGNORE_INSERTS;

    // The message buffer is allocated by FormatMessageA() and must be
    // released with LocalFree() once it has been copied.
    LPSTR message = nullptr;
    const size_t message_len = FormatMessageA(
        format_flags, nullptr, error_code,
        MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
        reinterpret_cast<LPSTR>(&message), 0, nullptr);
    const std::string reason(message, message_len);
    LocalFree(message);

    return Status(
        Status::Code::NOT_FOUND, "unable to reset dll path: " + reason);
  }
#endif

  return Status::Success;
}
// Loads the shared library found at 'path' and stores its handle in
// '*handle'.
//
// path: file path of the library to load.
// handle: receives the value returned by LoadLibrary() on Windows or by
//   dlopen() on other platforms.
// Returns Status::Success when the library was loaded and NOT_FOUND (with
// the system / dlerror() text) when it was not. On Windows an error from
// SetLibraryDirectory() or ResetLibraryDirectory() is returned as-is.
Status
SharedLibrary::OpenLibraryHandle(const std::string& path, void** handle)
{
  LOG_VERBOSE(1) << "OpenLibraryHandle: " << path;

#ifdef TRITON_ENABLE_GPU
  // NOTE(review): neither the returned error code nor the device count is
  // used, so this call appears to exist only for its side effect of touching
  // the CUDA runtime before the library is loaded -- confirm the intent.
  // The result is discarded explicitly to show that ignoring it is deliberate.
  int device_count = 0;
  (void)cudaGetDeviceCount(&device_count);
#endif

#ifdef _WIN32
  // Set the DLL directory to whatever DirName() yields for 'path' while the
  // library is being loaded (presumably the directory containing the library
  // -- DirName() is defined elsewhere).
  const std::string library_dir = DirName(path);
  RETURN_IF_ERROR(SetLibraryDirectory(library_dir));

  LOG_VERBOSE(1) << "OpenLibraryHandle: path = " << path;
  *handle = LoadLibrary(path.c_str());

  // Record the failure code immediately. ResetLibraryDirectory() below logs
  // and calls SetDllDirectory(), either of which may overwrite the calling
  // thread's last-error value before it would otherwise be read.
  const DWORD load_error =
      (*handle == nullptr) ? GetLastError() : ERROR_SUCCESS;

  // Undo the DLL directory change whether or not the load succeeded.
  RETURN_IF_ERROR(ResetLibraryDirectory());

  if (*handle == nullptr) {
    LPSTR err_buffer = nullptr;
    size_t size = FormatMessageA(
        FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
            FORMAT_MESSAGE_IGNORE_INSERTS,
        NULL, load_error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
        (LPSTR)&err_buffer, 0, NULL);
    std::string errstr(err_buffer, size);
    LocalFree(err_buffer);

    return Status(
        Status::Code::NOT_FOUND, "unable to load shared library: " + errstr);
  }
#else
  *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
  if (*handle == nullptr) {
    // dlerror() is allowed to return NULL, and building a std::string from a
    // null pointer is undefined behaviour, so fall back to a fixed text.
    const char* dl_error = dlerror();
    return Status(
        Status::Code::NOT_FOUND,
        "unable to load shared library: " +
            std::string((dl_error != nullptr) ? dl_error : "unknown error"));
  }
#endif

  return Status::Success;
}
// Unloads a library previously loaded into 'handle'.
//
// handle: value to pass to FreeLibrary() (Windows) or dlclose() (other
//   platforms). A null handle is accepted and treated as success.
// Returns Status::Success, or an INTERNAL status carrying the system /
// dlerror() text when the unload call reports failure.
Status
SharedLibrary::CloseLibraryHandle(void* handle)
{
  // Nothing to unload for a null handle.
  if (handle == nullptr) {
    return Status::Success;
  }

#ifdef _WIN32
  const BOOL unloaded = FreeLibrary(static_cast<HMODULE>(handle));
  if (unloaded == 0) {
    const DWORD error_code = GetLastError();
    const DWORD format_flags = FORMAT_MESSAGE_ALLOCATE_BUFFER |
                               FORMAT_MESSAGE_FROM_SYSTEM |
                               FORMAT_MESSAGE_IGNORE_INSERTS;

    // The message buffer is allocated by FormatMessageA() and must be
    // released with LocalFree() once it has been copied.
    LPSTR message = nullptr;
    const size_t message_len = FormatMessageA(
        format_flags, nullptr, error_code,
        MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
        reinterpret_cast<LPSTR>(&message), 0, nullptr);
    const std::string reason(message, message_len);
    LocalFree(message);

    return Status(
        Status::Code::INTERNAL, "unable to unload shared library: " + reason);
  }
#else
  const int close_rc = dlclose(handle);
  if (close_rc != 0) {
    const std::string reason(dlerror());
    return Status(
        Status::Code::INTERNAL, "unable to unload shared library: " + reason);
  }
#endif

  return Status::Success;
}
// Looks up the symbol 'name' in the library identified by 'handle'.
//
// handle: library handle passed to GetProcAddress() / dlsym().
// name: symbol name to look up.
// optional: when true, a symbol that cannot be found is not an error.
// befn: always reset to nullptr first; set to the symbol address when the
//   lookup succeeds.
// Returns Status::Success when the symbol was found, or when it was not
// found and 'optional' is true (in which case '*befn' is nullptr). Returns
// NOT_FOUND when a non-optional symbol cannot be found.
Status
SharedLibrary::GetEntrypoint(
    void* handle, const std::string& name, const bool optional, void** befn)
{
  *befn = nullptr;

#ifdef _WIN32
  void* symbol = reinterpret_cast<void*>(
      GetProcAddress(static_cast<HMODULE>(handle), name.c_str()));
  const bool required = !optional;
  if (required && (symbol == nullptr)) {
    const DWORD error_code = GetLastError();
    const DWORD format_flags = FORMAT_MESSAGE_ALLOCATE_BUFFER |
                               FORMAT_MESSAGE_FROM_SYSTEM |
                               FORMAT_MESSAGE_IGNORE_INSERTS;

    // The message buffer is allocated by FormatMessageA() and must be
    // released with LocalFree() once it has been copied.
    LPSTR message = nullptr;
    const size_t message_len = FormatMessageA(
        format_flags, nullptr, error_code,
        MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
        reinterpret_cast<LPSTR>(&message), 0, nullptr);
    const std::string reason(message, message_len);
    LocalFree(message);

    return Status(
        Status::Code::NOT_FOUND,
        "unable to find '" + name +
            "' entrypoint in custom library: " + reason);
  }
#else
  // Discard any error left over from an earlier dl* call so that the
  // dlerror() result read after dlsym() belongs to this lookup.
  dlerror();
  void* symbol = dlsym(handle, name.c_str());
  const char* lookup_error = dlerror();

  // The lookup counts as failed when dlerror() reports a problem or when the
  // resolved address is null.
  const bool lookup_failed = (lookup_error != nullptr) || (symbol == nullptr);
  if (lookup_failed) {
    if (optional) {
      return Status::Success;
    }

    if (lookup_error != nullptr) {
      const std::string reason(lookup_error);
      return Status(
          Status::Code::NOT_FOUND, "unable to find required entrypoint '" +
                                       name + "' in shared library: " + reason);
    }

    return Status(
        Status::Code::NOT_FOUND,
        "unable to find required entrypoint '" + name + "' in shared library");
  }
#endif

  *befn = symbol;
  return Status::Success;
}
}}