#pragma once
#include "megbrain/exception.h"
#include "megbrain/utils/hash.h"
#include "megbrain/utils/thin/function.h"
#include "megbrain/utils/thin/hash_table.h"
#include <algorithm>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <utility>
namespace mgb {
//! Per-type descriptor object. Each type that uses the MGB_TYPEINFO macros owns
//! one static instance, and type identity is decided by the address of that
//! instance rather than by its contents.
struct Typeinfo {
    //! type name; this is an empty string unless MGB_VERBOSE_TYPEINFO_NAME
    //! is enabled (see _MGB_TYPEINFO_CLASS_NAME)
    const char* const name;

    //! true iff this descriptor is the very object returned by T::typeinfo()
    template <typename T>
    bool is() const {
        return this == T::typeinfo();
    }
};
//! Base for objects that report their exact (final) class through a Typeinfo
//! pointer, so that type checks and downcasts are done by comparing Typeinfo
//! addresses instead of using dynamic_cast.
class DynTypeObj {
public:
    //! Typeinfo of the concrete class of this object (the overrides generated
    //! by MGB_DYN_TYPE_OBJ_FINAL_DECL are marked final)
    virtual Typeinfo* dyn_typeinfo() const = 0;

    //! unchecked downcast: the caller guarantees the dynamic type is exactly T
    template <class T>
    T& cast_final() {
        T* ptr = static_cast<T*>(this);
        return *ptr;
    }
    template <class T>
    const T& cast_final() const {
        const T* ptr = static_cast<const T*>(this);
        return *ptr;
    }

    //! checked downcast: fails through mgb_assert (reporting both type names)
    //! when the dynamic type is not exactly T
    template <class T>
    T& cast_final_safe() {
        mgb_assert(
                T::typeinfo() == dyn_typeinfo(), "can not convert type %s to %s",
                dyn_typeinfo()->name, T::typeinfo()->name);
        return cast_final<T>();
    }
    template <class T>
    const T& cast_final_safe() const {
        return const_cast<DynTypeObj*>(this)->cast_final_safe<T>();
    }

    //! downcast that yields nullptr instead of failing on a type mismatch
    template <class T>
    T* try_cast_final() {
        if (T::typeinfo() != dyn_typeinfo())
            return nullptr;
        return static_cast<T*>(this);
    }
    template <class T>
    const T* try_cast_final() const {
        if (T::typeinfo() != dyn_typeinfo())
            return nullptr;
        return static_cast<const T*>(this);
    }

    //! whether dyn_typeinfo() is the same Typeinfo object as T::typeinfo()
    template <class T>
    bool same_type() const {
        return dyn_typeinfo() == T::typeinfo();
    }

protected:
    //! protected and non-virtual: deleting through a DynTypeObj pointer is
    //! not allowed from outside the hierarchy
    ~DynTypeObj() = default;
};
// Prefix spliced in front of the out-of-class definitions produced by
// MGB_TYPEINFO_OBJ_IMPL and MGB_DYN_TYPE_OBJ_FINAL_IMPL; it expands to nothing
// here.
#define _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL
// Use inside a class body. Declares a public static typeinfo() accessor and the
// private static Typeinfo object it returns. The access specifier is left at
// private afterwards, and the last declaration carries no trailing ';', so the
// user must write it.
#define MGB_TYPEINFO_OBJ_DECL \
public: \
static inline ::mgb::Typeinfo* typeinfo() { return &sm_typeinfo; } \
\
private: \
static ::mgb::Typeinfo sm_typeinfo
// Same as MGB_TYPEINFO_OBJ_DECL, except that the static Typeinfo is annotated with
// MGE_WIN_DECLSPEC_DATA (presumably the Windows dllexport/dllimport decoration
// for data; confirm where it is defined).
#define MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT \
public: \
static inline ::mgb::Typeinfo* typeinfo() { return &sm_typeinfo; } \
\
private: \
static MGE_WIN_DECLSPEC_DATA ::mgb::Typeinfo sm_typeinfo
// Value stored in Typeinfo::name. It is the stringized class name when
// MGB_VERBOSE_TYPEINFO_NAME is enabled, and an empty string otherwise.
#if MGB_VERBOSE_TYPEINFO_NAME
#define _MGB_TYPEINFO_CLASS_NAME(_cls) #_cls
#else
#define _MGB_TYPEINFO_CLASS_NAME(_cls) ""
#endif
// Defines the static Typeinfo declared by MGB_TYPEINFO_OBJ_DECL(_WITH_EXPORT)
// for class _cls. It is an ordinary non-inline static data member definition,
// so it must appear in exactly one translation unit. No trailing ';' is
// included.
#define MGB_TYPEINFO_OBJ_IMPL(_cls) \
_MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL \
::mgb::Typeinfo _cls::sm_typeinfo { _MGB_TYPEINFO_CLASS_NAME(_cls) }
// Use inside the body of a class derived from DynTypeObj. Declares the
// dyn_typeinfo() override (marked final) together with the static typeinfo
// members from MGB_TYPEINFO_OBJ_DECL. Access is left at private.
#define MGB_DYN_TYPE_OBJ_FINAL_DECL \
public: \
::mgb::Typeinfo* dyn_typeinfo() const override final; \
MGB_TYPEINFO_OBJ_DECL
// Exported variant of MGB_DYN_TYPE_OBJ_FINAL_DECL. The override is annotated with
// MGE_WIN_DECLSPEC_FUC, and the static Typeinfo comes from
// MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT.
#define MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT \
public: \
MGE_WIN_DECLSPEC_FUC ::mgb::Typeinfo* dyn_typeinfo() const override final; \
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT
// Out-of-class definitions that pair with MGB_DYN_TYPE_OBJ_FINAL_DECL(_WITH_EXPORT).
// dyn_typeinfo() returns the class's own static Typeinfo, which is defined here
// through MGB_TYPEINFO_OBJ_IMPL. No trailing ';' is included.
#define MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) \
_MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL \
::mgb::Typeinfo* _cls::dyn_typeinfo() const { return &sm_typeinfo; } \
MGB_TYPEINFO_OBJ_IMPL(_cls)
//! Base class that makes derived classes non-copyable. Because no move
//! operations are declared either, a move falls back to the deleted copy
//! constructor/assignment, so derived classes are not movable through this
//! base. The deleted members are public so that misuse is reported as "use of
//! deleted function" rather than as an access error.
class NonCopyableObj {
public:
    NonCopyableObj() = default;
    NonCopyableObj(const NonCopyableObj&) = delete;
    NonCopyableObj& operator=(const NonCopyableObj&) = delete;
};
//! Range view that walks a mutable container back-to-front through its
//! rbegin()/rend(). Only a reference to the container is stored, so the
//! container must outlive the adaptor.
template <typename T>
class ReverseAdaptor {
    T& m_t;

public:
    //! intentionally non-explicit: reverse_adaptor() builds it via `return {t};`
    ReverseAdaptor(T& t) : m_t(t) {}

    //! const-qualified so that a const adaptor can also drive a range-for. The
    //! referenced container stays mutable because constness of the adaptor
    //! does not propagate through the reference member.
    typename T::reverse_iterator begin() const { return m_t.rbegin(); }
    typename T::reverse_iterator end() const { return m_t.rend(); }
};
//! Read-only counterpart of ReverseAdaptor. It walks a container back-to-front
//! through crbegin()/crend(). Only a reference is stored, so the container must
//! outlive the adaptor.
template <typename T>
class ConstReverseAdaptor {
    const T& m_t;

public:
    //! intentionally non-explicit: reverse_adaptor() builds it via `return {t};`
    ConstReverseAdaptor(const T& t) : m_t(t) {}

    //! const-qualified so that a const adaptor can also drive a range-for
    typename T::const_reverse_iterator begin() const { return m_t.crbegin(); }
    typename T::const_reverse_iterator end() const { return m_t.crend(); }
};
//! `for (auto&& x : reverse_adaptor(c))` iterates \p c from back to front.
//! Non-const containers yield a ReverseAdaptor and const ones yield a
//! ConstReverseAdaptor.
//! NOTE: the adaptor only references \p t. If a temporary container is passed
//! (this binds to the const overload), the adaptor dangles once that
//! full-expression ends. This includes the range-initializer of a range-for
//! before C++23.
template <typename T>
ReverseAdaptor<T> reverse_adaptor(T& t) {
    return ReverseAdaptor<T>(t);
}
template <typename T>
ConstReverseAdaptor<T> reverse_adaptor(const T& t) {
    return ConstReverseAdaptor<T>(t);
}
//! Stable insertion sort of [begin, end), using \p cmp as a strict "less than".
//! The running time is quadratic in the range length, so it is meant for short
//! ranges. Iter must support ++ and --; iterators are only ever compared with ==.
template <
        typename Iter,
        class Cmp = std::less<typename std::iterator_traits<Iter>::value_type>>
void small_sort(Iter begin, Iter end, const Cmp& cmp = {}) {
    if (begin == end)
        return;
    Iter cur = begin;
    for (++cur; !(cur == end); ++cur) {
        // lift the current element out, leaving a "hole" at its position
        auto key = std::move(*cur);
        Iter hole = cur;
        // Shift predecessors that compare greater one slot to the right.
        // Stopping at the first element that is not greater than key keeps
        // equal elements in their original order (stability).
        while (!(hole == begin)) {
            Iter prev = hole;
            --prev;
            if (!cmp(key, *prev))
                break;
            *hole = std::move(*prev);
            hole = prev;
        }
        *hole = std::move(key);
    }
}
//! Associative-container lookup that requires \p key to be present. A missing
//! key fails through mgb_assert; otherwise the iterator to the element is
//! returned.
template <class Container, class Key>
typename Container::iterator safe_find(Container& container, const Key& key) {
    using Iter = typename Container::iterator;
    Iter found = container.find(key);
    mgb_assert(found != container.end());
    return found;
}

//! const overload of safe_find for associative containers
template <class Container, class Key>
typename Container::const_iterator safe_find(
        const Container& container, const Key& key) {
    using Iter = typename Container::const_iterator;
    Iter found = container.find(key);
    mgb_assert(found != container.end());
    return found;
}

//! SmallVector overload: a linear search (std::find) for an element equal to
//! \p key, which must be present (checked by mgb_assert)
template <class T, class Key>
typename SmallVector<T>::iterator safe_find(SmallVector<T>& container, const Key& key) {
    using Iter = typename SmallVector<T>::iterator;
    Iter found = std::find(container.begin(), container.end(), key);
    mgb_assert(found != container.end());
    return found;
}

//! const SmallVector overload; same contract as the non-const one
template <class T, class Key>
typename SmallVector<T>::const_iterator safe_find(
        const SmallVector<T>& container, const Key& key) {
    using Iter = typename SmallVector<T>::const_iterator;
    Iter found = std::find(container.begin(), container.end(), key);
    mgb_assert(found != container.end());
    return found;
}
//! Linear search of \p vec for the first element equal to \p key. Returns
//! vec.end() when the key is not found.
template <class Key>
typename std::vector<Key>::iterator find(std::vector<Key>& vec, const Key& key) {
    const auto stop = vec.end();
    return std::find(vec.begin(), stop, key);
}

//! const overload of find(); returns vec.end() when the key is not found
template <class Key>
typename std::vector<Key>::const_iterator find(
        const std::vector<Key>& vec, const Key& key) {
    const auto stop = vec.cend();
    return std::find(vec.cbegin(), stop, key);
}
//! Hash of a std::vector. It starts from the hash of the element count, then
//! folds in each element's ::mgb::hash in order using hash_pair_combine, so the
//! result depends on both length and element order.
template <class Key>
struct HashTrait<std::vector<Key>> {
    static size_t eval(const std::vector<Key>& val) {
        size_t seed = ::mgb::hash(val.size());
        for (auto&& elem : val) {
            seed = hash_pair_combine(seed, ::mgb::hash(elem));
        }
        return seed;
    }
};
//! Look up \p key in \p map. Returns the mapped value if the key is present,
//! and \p default_ otherwise.
//! NOTE: the result is a reference that may alias \p default_. When the default
//! argument is used, default_ is a temporary that is destroyed at the end of the
//! caller's full-expression. Copy the result rather than binding it to a
//! reference that outlives the call expression.
template <class Map>
const typename Map::value_type::second_type& get_map_with_default(
        const Map& map, const typename Map::value_type::first_type& key,
        const typename Map::value_type::second_type& default_ = {}) {
    auto found = map.find(key);
    if (found != map.end())
        return found->second;
    return default_;
}
//! In-place storage for one Obj, with a fixed byte size SIZE and alignment
//! ALIGN. sizeof/alignof(Obj) and Obj's special members are only used inside
//! member-function bodies. Obj therefore only has to be complete where those
//! members are instantiated, not where the class template itself is
//! instantiated.
template <class Obj, size_t SIZE, size_t ALIGN>
class alignas(ALIGN) IncompleteObjStorage {
    //! raw bytes holding the Obj constructed by placement new
    uint8_t m_mem[SIZE];

public:
    //! default-constructs the Obj in place
    IncompleteObjStorage() {
        // This check lives in a function body, so it is evaluated only when
        // this constructor is instantiated, and Obj must be complete there.
        static_assert(
                sizeof(Obj) <= SIZE && !(ALIGN % alignof(Obj)),
                "SIZE and ALIGN do not match Obj");
        new (m_mem) Obj;
    }

    //! copy-constructs the Obj in place from the one held by \p rhs
    IncompleteObjStorage(const IncompleteObjStorage& rhs) {
        const Obj& src = rhs.get();
        new (m_mem) Obj(src);
    }

    //! move-constructs the Obj in place from the one held by \p rhs
    IncompleteObjStorage(IncompleteObjStorage&& rhs) noexcept {
        Obj& src = rhs.get();
        new (m_mem) Obj(std::move(src));
    }

    //! copy-assigns the held Obj from rhs's Obj
    IncompleteObjStorage& operator=(const IncompleteObjStorage& rhs) {
        Obj& dst = get();
        dst = rhs.get();
        return *this;
    }

    //! move-assigns the held Obj from rhs's Obj
    IncompleteObjStorage& operator=(IncompleteObjStorage&& rhs) noexcept {
        Obj& dst = get();
        dst = std::move(rhs.get());
        return *this;
    }

    //! explicitly runs Obj's destructor; the byte buffer itself needs no cleanup
    ~IncompleteObjStorage() noexcept { get().~Obj(); }

    //! access to the stored object
    Obj& get() { return *aliased_ptr<Obj>(m_mem); }
    const Obj& get() const { return const_cast<IncompleteObjStorage*>(this)->get(); }
};
//! IncompleteObjStorage whose SIZE and ALIGN come from \p Mock, a complete
//! stand-in type. IncompleteObjStorage's default constructor checks that Obj
//! actually fits.
template <typename Obj, typename Mock>
using IncompleteObjStorageMock =
        IncompleteObjStorage<Obj, sizeof(Mock), alignof(Mock)>;
//! Type-indexed collection of user data objects. A stored type T must derive
//! from UserDataContainer::UserData and provide a static T::typeinfo(); that
//! Typeinfo pointer is the lookup key.
class UserDataContainer {
public:
    //! polymorphic base class (virtual destructor) for every stored object
    class UserData {
    public:
        virtual ~UserData() = default;
    };

    MGE_WIN_DECLSPEC_FUC ~UserDataContainer() noexcept;

    //! store \p data under T::typeinfo() and return its raw pointer; the
    //! shared_ptr itself is moved into do_add()
    template <typename T>
    T* add_user_data(std::shared_ptr<T> data) {
        static_assert(
                std::is_base_of<UserData, T>::value, "must be derived from UserData");
        auto ptr = data.get();
        do_add(T::typeinfo(), std::move(data));
        return ptr;
    }

    //! remove user data registered under T::typeinfo() and forward the result
    //! of do_pop() (NOTE(review): the meaning of the int is defined by do_pop,
    //! which is not visible here; confirm)
    template <typename T>
    int pop_user_data() {
        static_assert(
                std::is_base_of<UserData, T>::value, "must be derived from UserData");
        return do_pop(T::typeinfo());
    }

    //! (pointer array, count) returned by do_get() for T::typeinfo(), viewed
    //! as pointers to T
    template <typename T>
    std::pair<T* const*, size_t> get_user_data() const {
        static_assert(
                std::is_base_of<UserData, T>::value, "must be derived from UserData");
        auto ret = do_get(T::typeinfo());
        return {reinterpret_cast<T* const*>(ret.first), ret.second};
    }

    //! Return the object registered for T. If m_storage has no entry for
    //! T::typeinfo() yet, the result of \p maker() is first added via do_add().
    template <typename T, typename Maker>
    T* get_user_data_or_create(Maker&& maker) {
        static_assert(
                std::is_base_of<UserData, T>::value, "must be derived from UserData");
        auto type = T::typeinfo();
        if (!m_storage.count(type)) {
            do_add(type, maker());
        }
        return static_cast<T*>(do_get_one(type));
    }

    //! as above, creating a value-initialized T through std::make_shared
    template <typename T>
    T* get_user_data_or_create() {
        // Wrap make_shared in a lambda rather than passing std::make_shared<T>
        // itself as an argument. The standard does not guarantee that a
        // pointer or reference to a standard library function can be formed
        // (make_shared is not an addressable function), and the array overloads
        // added in C++20 make deducing Maker from that name fragile.
        return get_user_data_or_create<T>([] { return std::make_shared<T>(); });
    }

    void clear_all_user_data();

    //! exchange all stored user data with \p other
    void swap(UserDataContainer& other) {
        m_refkeeper.swap(other.m_refkeeper);
        m_storage.swap(other.m_storage);
    }

private:
    MGE_WIN_DECLSPEC_FUC void do_add(Typeinfo* type, std::shared_ptr<UserData> ptr);
    MGE_WIN_DECLSPEC_FUC std::pair<void* const*, size_t> do_get(Typeinfo* type) const;
    MGE_WIN_DECLSPEC_FUC void* do_get_one(Typeinfo* type) const;
    MGE_WIN_DECLSPEC_FUC int do_pop(Typeinfo* type);

    //! owning references to stored objects (NOTE(review): presumably what keeps
    //! them alive while m_storage only holds raw pointers; confirm in do_add)
    std::unordered_set<std::shared_ptr<UserData>> m_refkeeper;
    //! maps each Typeinfo key to the raw pointers registered under it
    ThinHashMap<Typeinfo*, SmallVector<void*, 1>> m_storage;
};
template <typename... Args>
class ContinuationCtx {
public:
using Next = thin_function<void(Args...)>;
using Err = thin_function<void(std::exception&)>;
ContinuationCtx(const Next& next = {}, const Err& err = {})
: m_next{next}, m_err{err} {}
template <class... T>
void next(T&&... args) const {
if (m_next)
m_next(std::forward<T>(args)...);
}
void err(std::exception& exc) const {
if (m_err)
m_err(exc);
}
private:
Next m_next;
Err m_err;
};
//! Holds a list of no-argument callbacks registered with add(). Both add() and
//! the destructor are defined out of line. The destructor is declared
//! noexcept(false), so exceptions may propagate out of it.
//! NOTE(review): presumably the destructor invokes the registered callbacks;
//! confirm against the implementation.
class CleanupCallback {
public:
using Callback = thin_function<void()>;
//! register \p callback (presumably appended to m_callbacks; confirm)
void add(Callback callback);
//! explicitly allowed to throw (noexcept(false))
~CleanupCallback() noexcept(false);
private:
//! callbacks registered so far
SmallVector<Callback> m_callbacks;
};
}
// Opens the definition `class _name : public _base, <extra bases> {`. It adds a
// public alias `Super` for _base, written as `_tpl _base` so that the _TPL
// variant can prefix `typename` for dependent base types, and then leaves the
// access at private. The caller writes the rest of the class body, including
// the closing `};`. Only _base gets an implicit `public`; any extra bases passed
// through __VA_ARGS__ need their own access specifier. `##__VA_ARGS__` (a
// GNU/MSVC extension) drops the preceding comma when no extra bases are given.
#define MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \
class _name : public _base, ##__VA_ARGS__ { \
public: \
using Super = _tpl _base; \
\
private:
// Non-template form: Super is declared as plain `using Super = _base;`.
#define MGB_DEFINE_CLS_WITH_SUPER(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__)
// Template form: Super is declared as `using Super = typename _base;`. Use it
// when _base is a dependent type inside a class template.
#define MGB_DEFINE_CLS_WITH_SUPER_TPL(_name, _base, ...) \
MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__)