#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/small_vector.h"
namespace mgb::imperative::interpreter::intl {
class StackSnapshot;
class StackManager : public NonCopyableObj {
public:
class Node;
class Guard;
struct Frame;
class Trace;
private:
std::unique_ptr<Node> m_root = nullptr;
Node* m_current = nullptr;
SmallVector<uint64_t> m_trace_id_stack;
uint64_t m_last_trace_id = 0;
public:
StackManager();
std::pair<Node*, uint64_t> enter(std::string name);
void exit(std::string name);
Trace dump();
Node* current();
};
class StackManager::Node : public NonCopyableObj {
private:
std::string m_name;
std::unordered_map<std::string, std::unique_ptr<Node>> m_children;
std::unordered_map<std::string, size_t> m_id_table;
Node* m_parent = nullptr;
int64_t m_depth = -1;
uint64_t m_version = 0;
explicit Node(std::string name, Node* parent) : m_name{name}, m_parent{parent} {
if (parent) {
m_depth = parent->m_depth + 1;
}
}
public:
const std::string& name() const { return m_name; }
Node* operator[](const std::string& name) {
auto& child = m_children[name];
if (child == nullptr) {
child.reset(new Node(name, this));
}
return child.get();
}
Node* parent() { return m_parent; }
bool is_root() { return m_parent == nullptr; }
uint64_t version() const { return m_version; }
void update_version() {
++m_version;
for (auto&& [key, child] : m_children) {
child->reset_version();
}
m_id_table.clear();
}
void reset_version() {
m_version = 0;
m_id_table.clear();
}
int64_t depth() const { return m_depth; }
uint64_t next_id(std::string key) { return m_id_table[key]++; }
static std::unique_ptr<Node> make() {
return std::unique_ptr<Node>(new Node("", nullptr));
}
};
class StackManager::Guard {
private:
std::string m_name;
StackManager* m_manager;
public:
Guard(std::string name, StackManager* manager) : m_name{name}, m_manager{manager} {
if (m_manager) {
m_manager->enter(name);
}
}
~Guard() { release(); }
void release() {
if (m_manager) {
m_manager->exit(m_name);
m_manager = nullptr;
}
}
};
struct StackManager::Frame {
StackManager::Node* node;
uint64_t version;
};
class StackManager::Trace {
private:
SmallVector<StackManager::Frame> m_frames;
uint64_t m_id = 0;
public:
explicit Trace(StackManager::Node* top, uint64_t id) : m_id{id} {
int64_t nr_frames = top->depth() + 1;
m_frames = SmallVector<StackManager::Frame>(nr_frames);
StackManager::Node* node = top;
for (int64_t i = 0; i < nr_frames; ++i) {
m_frames[m_frames.size() - 1 - i] = {node, node->version()};
node = node->parent();
}
mgb_assert(node->is_root());
}
Trace() = default;
std::string to_string() const {
std::string buffer;
for (auto&& [node, version] : m_frames) {
if (!buffer.empty()) {
buffer.append(".");
}
buffer.append(node->name());
if (version != 0) {
buffer.append(ssprintf("[%zu]", version));
}
}
return buffer;
}
const SmallVector<StackManager::Frame>& frames() const { return m_frames; }
uint64_t id() const { return m_id; }
};
inline StackManager::StackManager() {
m_root = Node::make();
m_current = m_root.get();
}
inline std::pair<StackManager::Node*, uint64_t> StackManager::enter(std::string name) {
m_current = (*m_current)[name];
m_trace_id_stack.push_back(++m_last_trace_id);
return {m_current, m_current->version()};
}
inline void StackManager::exit(std::string name) {
mgb_assert(m_current->name() == name, "scope name mismatch");
m_current = m_current->parent();
m_trace_id_stack.pop_back();
m_current->update_version();
}
inline StackManager::Trace StackManager::dump() {
return Trace(m_current, m_trace_id_stack.empty() ? 0 : m_trace_id_stack.back());
}
inline StackManager::Node* StackManager::current() {
return m_current;
}
}