#pragma once
#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <exception>
#include <functional>
#include <future>
#include <limits>
#include <list>
#include <memory>
#include <mutex>
#include <stack>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_set>
#include <variant>
#include <vector>

#include "megbrain/comp_node.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/utils/mempool.h"

#include "./commands.h"
#include "./option_manager.h"
#include "./stack_manager.h"
#include "./tensor_info.h"

#include "../profiler/events.h"
namespace mgb::imperative::interpreter::intl {
using Handle = Interpreter::Handle;
struct InterpreterImpl : Interpreter {
std::unique_ptr<Channel> create_channel() override;
};
//! Implementation of Interpreter::Channel.
//!
//! Public overrides are only declared here; their definitions live outside
//! this header. Work is handed to the nested WorkQueue (`m_worker`), whose
//! per-command hook forwards to ChannelImpl::process_one_task.
//!
//! NOTE: member declaration order is significant. Members are constructed
//! top-to-bottom and destroyed in reverse, so `m_worker` is destroyed after
//! `m_channel_state`, `m_worker_state` and `m_dtr`. Do not reorder casually.
struct ChannelImpl : Interpreter::Channel {
    ChannelImpl();
    ~ChannelImpl() override;

    Handle put(const HostTensorND& value, bool no_cache) override;
    Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;

    void del(Handle) override;
    void drop(Handle) override;

    SmallVector<Handle> apply_op(
            std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) override;

    HostTensorND get_value(Handle) override;
    TensorShape get_shape(Handle) override;
    DType get_dtype(Handle) override;
    CompNode get_device(Handle) override;
    DeviceTensorND get_dev_tensor(Handle) override;

    bool check_available() override;
    void sync() override;
    void close() override;

    size_t get_option(std::string name) override;
    void set_option(std::string name, size_t value) override;
    void clear_candidates() override;

    void start_profile() override;
    void stop_profile() override;

    void push_scope(std::string) override;
    void pop_scope(std::string) override;

private:
    struct WorkQueue;
    struct State;

    // TensorInfo lifetime management (definitions out of line).
    TensorInfo* alloc();
    void init(TensorInfo*, LogicalTensorDesc desc);
    void free(TensorInfo*);
    void real_free(TensorInfo*);
    void recursive_free(TensorInfo*);
    void do_drop(TensorInfo*, bool);
    void detach_users(TensorInfo*);

    // Implementations behind the public API.
    TensorInfo* put_impl(const HostTensorND& value, bool no_cache);
    TensorInfo* put_impl(const DeviceTensorND& value, const HostTensorND& hvalue);
    void del_impl(Handle);
    void sync_impl();
    SmallVector<Handle> apply_op_impl(
            std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs);

    TensorPtr wait_tensor(TensorInfo* info, profiler::TensorProp prop);
    void notify_tensor_unsafe(TensorInfo* info);

    // Called by WorkQueue::process_one_task for each queued command.
    void process_one_task(Command&);

    void check_worker_exc_unsafe();
    void produce_tensor(TensorInfo* dest, TensorPtr ptr);
    void release_tensor(TensorInfo* dest);

    void regenerate(TensorInfo* dest);
    void flush_apply_stack();
    void do_apply_op(const ApplyOp& cmd, std::string reason);

    void dispatch_default_cpu(
            std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);
    void dispatch_kernel(
            std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);

    void push_scope(std::string, State&);
    void pop_scope(std::string, State&);

    void assert_in_channel();
    void assert_in_worker();
    std::thread::id get_worker_tid();

    void sample_on_device(CompNode device, bool force);

    std::unordered_set<TensorInfo*> collect_valid_tensors();

    // Synchronization primitives. What each one guards is decided by the
    // out-of-line definitions -- NOTE(review): document per-lock ownership.
    std::mutex m_mutex;
    Spinlock m_spin;
    std::condition_variable m_cv;
    MemPool<TensorInfo> m_pool;
    // Presumably the set of handles handed out and not yet deleted -- confirm.
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
    Spinlock m_pool_spin;
    Spinlock m_info_spin;
    uint64_t m_waitee_id = 0;
    // Exception captured from the worker side; presumably rethrown by
    // check_worker_exc_unsafe() -- confirm in the implementation.
    std::exception_ptr m_worker_exc;
    std::function<void(std::string, std::string)> m_profile_dump_callback;
    size_t m_storage_id = 0;

    std::stack<std::tuple<ApplyOp, size_t, TensorInfo*, std::string>> m_apply_stack;
    bool m_applying = false;
    bool m_closed = false;

    //! Single-consumer async command queue. Each command is forwarded to the
    //! owning ChannelImpl.
    struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
        // Base arguments (0, 10000): presumably max_spin = 0 (no busy-wait)
        // and a default item limit of 10000 that the environment variable
        // below overrides via update_max_items() -- confirm in AsyncQueueSC.
        WorkQueue(ChannelImpl* owner)
                : AsyncQueueSC<Command, WorkQueue>(0, 10000), m_owner(owner) {
            // This constructor body runs on the thread that constructs the
            // channel, so it names *that* thread "interpreter".
            sys::set_thread_name("interpreter");
            if (const char* env_val = MGB_GETENV("MEGENGINE_ASYNC_QUEUE_SIZE")) {
                update_max_items(parse_queue_size(env_val));
            }
        }

        void process_one_task(Command& icmd) { m_owner->process_one_task(icmd); }

        void on_async_queue_worker_thread_start() override;

    private:
        //! Parse a non-empty string of decimal digits into a size_t.
        //!
        //! The previous sscanf-based parsing accepted an empty string and then
        //! used an uninitialized value, and overflowed silently (undefined
        //! behavior) on oversized input. Here empty input, non-digit characters
        //! and values that do not fit in size_t all trip mgb_assert instead.
        static size_t parse_queue_size(const char* text) {
            mgb_assert(*text != '\0', "async queue size should be an integer");
            constexpr size_t max_value = std::numeric_limits<size_t>::max();
            size_t value = 0;
            for (const char* p = text; *p != '\0'; ++p) {
                mgb_assert(
                        *p >= '0' && *p <= '9',
                        "async queue size should be an integer");
                const size_t digit = static_cast<size_t>(*p - '0');
                // value * 10 + digit <= max_value  <=>  value <= (max_value - digit) / 10
                mgb_assert(
                        value <= (max_value - digit) / 10,
                        "async queue size is out of range");
                value = value * 10 + digit;
            }
            return value;
        }

        ChannelImpl* m_owner;
    } m_worker;

    // Default async level; meaning of the value is defined by the
    // out-of-line code -- NOTE(review): document the accepted levels.
    int m_async_level = 2;

    //! Per-thread state: the owning thread's id and its option set.
    struct State {
        std::thread::id tid;
        OptionManager options;
    };

    //! Channel-side state additionally carries the scope stack manager.
    struct ChannelState : State {
        StackManager stack_manager;
    };

    struct WorkerState : State {};

    ChannelState m_channel_state;
    WorkerState m_worker_state;

    //! Bookkeeping for tensor eviction/recomputation (presumably DTR, judging
    //! by auto_evict and the update_dsu_after_{recompute,evict} hooks --
    //! confirm). All non-inline members are defined out of line.
    struct DynamicSublinear {
        TensorInfo* find_best_tensor(bool);
        double estimate_neighbor_cost(TensorInfo* ptr);
        void update_used_time(TensorInfo* ptr);
        void merge(std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y);
        std::shared_ptr<DsuNode> find_father(std::shared_ptr<DsuNode>& x);
        void update_dsu_after_recompute(TensorInfo* ptr);
        void update_dsu_after_evict(TensorInfo* ptr);
        void pin(const SmallVector<TensorInfo*>& vec);
        void unpin(const SmallVector<TensorInfo*>& vec, WorkerState& state);
        void insert_candidate(TensorInfo* ptr);
        void erase_candidate(TensorInfo* ptr);

        double estimate_timestamp = 0;
        CompNode comp_node;
        SmallVector<TensorInfo*> candidates;

        //! True iff `op_name` appears in `op_blacklist`.
        bool is_bad_op(const std::string& op_name) const {
            return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) !=
                   op_blacklist.end();
        }

        // Ops reported by is_bad_op(). Presumably excluded from
        // recomputation because re-running them is not safe (RNG,
        // collective, in-place/param-pack ops) -- confirm at call sites.
        std::vector<std::string> op_blacklist = {
                "CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat",
                "GaussianRNG",    "UniformRNG", "GammaRNG",       "PermutationRNG",
                "PoissonRNG",     "BetaRNG"};
    } m_dtr;

    bool auto_evict(size_t);

    void alloc_tensor_with_evict(Blob*);

    ChannelState& get_channel_state();
    WorkerState& get_worker_state();
};
}