#pragma once
#include "megbrain_build_config.h"

#include <cstdlib>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
// MGB_IF_EXCEPTION(code...) expands to its argument tokens when
// MGB_ENABLE_EXCEPTION is set, and to nothing otherwise. The standard
// `...` / __VA_ARGS__ form is used instead of the GNU-only named variadic
// parameter (`x...`), so the macro is portable and still accepts arguments
// that contain top-level commas.
#if MGB_ENABLE_EXCEPTION
#define MGB_IF_EXCEPTION(...) __VA_ARGS__
#else
#define MGB_IF_EXCEPTION(...)
#endif
#if (defined(__GNUC__) && !defined(__ANDROID__) && !defined(ANDROID) && \
!defined(__APPLE__))
#include <cxxabi.h>
#define __MGB_HANDLE_FORCED_UNWIND MGB_CATCH(abi::__forced_unwind&, { throw; })
#else
#define __MGB_HANDLE_FORCED_UNWIND
#endif
/*!
 * Catch clauses that log an in-flight exception and store it instead of
 * letting it propagate.
 *
 * scope_name: C string naming the guarded scope; used in the log message.
 * eptr_out:   lvalue assigned std::current_exception() (presumably a
 *             std::exception_ptr; confirm at call sites).
 *
 * Clause order:
 * 1. std::exception-derived errors are logged together with what().
 * 2. __MGB_HANDLE_FORCED_UNWIND, when non-empty, rethrows a forced unwind.
 * 3. Anything else is logged as unknown.
 *
 * The trailing empty do/while lets the invocation end with a semicolon.
 */
#define MGB_CATCH_ALL_EXCEPTION(scope_name, eptr_out)                        \
    MGB_CATCH(const std::exception& _exc, {                                  \
        mgb_log_error(                                                       \
                "caught exception in %s; what(): %s", scope_name,            \
                _exc.what());                                                \
        eptr_out = std::current_exception();                                 \
    })                                                                       \
    __MGB_HANDLE_FORCED_UNWIND                                               \
    MGB_CATCH(..., {                                                         \
        mgb_log_error("caught unknown exception in %s", scope_name);         \
        eptr_out = std::current_exception();                                 \
    })                                                                       \
    do {                                                                     \
    } while (0)
/*!
 * Catch clauses for destructor bodies, where no exception may escape.
 *
 * Any caught exception is logged together with _scope_msg, and then the
 * process is aborted. Both branches abort so that an exception is never
 * silently swallowed. The unknown-exception branch used to log "abort due
 * to ..." without aborting.
 *
 * The trailing empty do/while lets the invocation end with a semicolon.
 */
#define MGB_HANDLE_EXCEPTION_DTOR(_scope_msg)                                 \
    MGB_CATCH(const std::exception& _exc, {                                   \
        mgb_log_error(                                                        \
                "abort due to exception in %s; what(): %s", _scope_msg,       \
                _exc.what());                                                 \
        std::abort();                                                         \
    })                                                                        \
    MGB_CATCH(..., {                                                          \
        mgb_log_error("abort due to unknown exception in %s", _scope_msg);    \
        std::abort();                                                         \
    })                                                                        \
    do {                                                                      \
    } while (0)
namespace mgb {

/*!
 * \brief base class of the MegBrain exception hierarchy
 *
 * Stores the message returned by what(). Throw sites can also attach an
 * optional, type-erased ExtraInfo payload via extra_info(ptr).
 */
class MegBrainError : public std::exception {
protected:
    std::string m_msg;

public:
    //! polymorphic base for payloads attached through extra_info()
    class ExtraInfo {
    public:
        virtual ~ExtraInfo() = default;
    };

    MegBrainError(const std::string& msg) : m_msg(msg) { init(); }

    const char* what() const noexcept override { return m_msg.c_str(); }

    //! the attached payload, or nullptr if none has been set
    const ExtraInfo* extra_info() const { return m_extra_info.get(); }

    /*!
     * \brief attach a payload and return *this (usable in a throw expression)
     *
     * \param ptr anything assignable to std::shared_ptr<ExtraInfo>, e.g. a
     *      shared_ptr or unique_ptr to an ExtraInfo subclass, or nullptr.
     *      The argument is perfectly forwarded. Rvalues are therefore moved
     *      instead of copied (no extra refcount bump), and move-only owners
     *      such as std::unique_ptr are accepted.
     */
    template <typename T>
    MegBrainError& extra_info(T&& ptr) {
        m_extra_info = std::forward<T>(ptr);
        return *this;
    }

    // No user-declared destructor (Rule of Zero). Declaring even a defaulted
    // destructor would suppress the implicit move constructor and move
    // assignment, so every move of the exception object would copy m_msg and
    // m_extra_info. The implicit destructor is still virtual (via
    // std::exception) and noexcept.

private:
    std::shared_ptr<ExtraInfo> m_extra_info;

    //! defined out of line; invoked by the constructor
    MGE_WIN_DECLSPEC_FUC void init();
};

//! common base of the memory-allocation and device-backend errors below
class SystemError : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

class MemAllocError : public SystemError {
public:
    using SystemError::SystemError;
};

// Device/backend errors. Their constructors and the static
// get_*_extra_info() helpers are defined out of line.
// NOTE(review): the helpers presumably collect backend-specific diagnostic
// text; confirm against their definitions.
class CudaError final : public SystemError {
public:
    static std::string get_cuda_extra_info();
    CudaError(const std::string& msg);
};

class EnFlameError final : public SystemError {
public:
    EnFlameError(const std::string& msg);
};

class AtlasError final : public SystemError {
public:
    AtlasError(const std::string& msg);
};

class ROCmError final : public SystemError {
public:
    static std::string get_rocm_extra_info();
    ROCmError(const std::string& msg);
};

class CnrtError final : public SystemError {
public:
    static std::string get_cnrt_extra_info();
    CnrtError(const std::string& msg);
};

class CndevError final : public SystemError {
public:
    CndevError(const std::string& msg);
};

class CnmlError final : public SystemError {
public:
    CnmlError(const std::string& msg);
};

// The following categories add no state or behavior to MegBrainError. Each
// one only inherits its constructor, so catch sites can tell categories
// apart by type.
class AssertionError final : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

class ConversionError final : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

class TensorCopyOverlapError final : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

class TensorReshapeError final : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

class SerializationError final : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

class MegDNNError final : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

class InternalError final : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

class TimeoutError final : public MegBrainError {
public:
    using MegBrainError::MegBrainError;
};

}  // namespace mgb
namespace mgb {
//! Declared here and defined out of line.
//! NOTE(review): presumably reports whether an exception is currently
//! propagating (in the spirit of std::uncaught_exceptions); confirm against
//! the definition.
bool has_uncaught_exception();
}  // namespace mgb