#pragma once
#include "./impl_common.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/utils/mempool.h"
namespace mgb {
namespace cg {
#if MGB_ENABLE_GRAD
class GradManager {
public:
struct VarVirtualReceiverDesc {
VarNodeArray inputs, outputs;
VarVirtualReceiverGrad grad;
};
GradManager(ComputingGraphImpl* graph);
~GradManager() noexcept;
VarNode* grad(VarNode* target, VarNode* wrt);
VarNode* current_grad_target() const {
return m_target_stack.empty() ? nullptr : m_target_stack.back();
}
void add_grad_transformer(VarNode* var, const GradTransformer& cb) {
m_grad_transformers[var].emplace_back(cb);
}
void add_extra_dep_for_grad(VarNode* inp, VarNode* out) {
m_extra_deps_inv_lookup[out].push_back(inp);
}
void add_var_virtual_receiver(const std::shared_ptr<VarVirtualReceiverDesc>& desc);
void clean_cache() {
for (auto&& i : m_target_context) {
i.second.cache.clear();
i.second.holistic_input_grads.clear();
}
}
private:
using VarMap = ThinHashMap<VarNode*, VarNode*>;
class StreamStrongPropInfer;
class ContextForTargetVar {
size_t m_virtual_receiver_version = 0;
ThinHashSet<OperatorNodeBase*> m_dep_oprs;
public:
VarMap cache;
ThinHashMap<OperatorNodeBase*, VarNodeArray> holistic_input_grads;
bool has_dep_opr(OperatorNodeBase* opr) const { return m_dep_oprs.count(opr); }
void init(GradManager* manager, VarNode* target);
};
using VarVirtualReceiverArray =
std::vector<std::shared_ptr<VarVirtualReceiverDesc>>;
struct VarReceiver {
OperatorNodeBase* const opr = nullptr;
VarVirtualReceiverDesc* const vrt = nullptr;
VarReceiver() = default;
VarReceiver(OperatorNodeBase* o) : opr{o} {}
VarReceiver(VarVirtualReceiverDesc* v) : vrt{v} {}
};
using VarReceiverArray = std::vector<VarReceiver>;
ComputingGraphImpl* const m_owner_graph;
std::unique_ptr<StreamStrongPropInfer> m_stream_strong_prop_infer;
ThinHashMap<VarNode*, ContextForTargetVar> m_target_context;
std::unordered_set<std::pair<VarNode*, VarNode*>, pairhash> m_in_stack;
ThinHashMap<VarNode*, std::vector<GradTransformer>> m_grad_transformers;
ThinHashMap<VarNode*, VarNodeArray> m_extra_deps_inv_lookup;
size_t m_virtual_receiver_version = 0;
ThinHashMap<VarNode*, VarVirtualReceiverArray> m_var2virtual_receiver;
ThinHashMap<VarNode*, std::vector<VarVirtualReceiverDesc*>>
m_var2virtual_receiver_inv;
std::vector<VarNode*> m_target_stack;
using DepSeq = std::vector<std::pair<VarNode*, VarReceiverArray>>;
struct GetDepSeqStackFrame;
DepSeq get_dep_seq(VarNode* start_var, const ContextForTargetVar& tgt_context);
VarNode* do_grad_with_cache(VarNode* target, VarNode* wrt);
VarNode* compute_grad_of_single_var(
VarNode* target, VarNode* wrt, ContextForTargetVar& context,
const VarReceiverArray& wrt_recv, VarNodeArray* tmp_var_arrs);
};
#else
//! Placeholder used when MGB_ENABLE_GRAD is disabled: construction accepts
//! the owning graph pointer and ignores it; no gradient API is provided.
class GradManager {
public:
    GradManager(ComputingGraphImpl* /* owner_graph, unused */)
    {
    }
};
#endif
} }