#pragma once
#include "./halide_header.h"
#if MGB_JIT_HALIDE
#include "megbrain/graph.h"
#include "megbrain/utils/metahelper.h"
namespace mgb {
namespace jit {
namespace ast_hl {
//! forward declaration so the pointer/array aliases below can name AstNode
struct AstNode;
//! shared ownership handle for AST nodes (nodes reference their inputs
//! through m_inputs, so a node may be reachable from several parents)
using AstNodePtr = std::shared_ptr<AstNode>;
//! list of child (input) nodes of an AstNode
using AstNodeArray = mgb::SmallVector<AstNodePtr>;
/*!
 * \brief abstract base of every node in the Halide AST (namespace ast_hl)
 *
 * Derives from DynTypeObj so that the concrete node type can be queried at
 * runtime via same_type / cast_final (as done by try_cast_as_op in this
 * header). Concrete node types are expected to implement init().
 */
struct AstNode : public DynTypeObj {
//! input nodes of this node
AstNodeArray m_inputs;
//! Halide function for this node
//! (NOTE(review): presumably computes this node's output -- confirm)
Halide::Func m_func;
//! tensor layout associated with this node
//! (NOTE(review): presumably the layout of the output -- confirm)
megdnn::TensorLayout m_layout;
//! virtual so derived nodes are destroyed correctly through an AstNode
//! pointer
virtual ~AstNode() = default;
//! populate this node from the graph operator \p opr; pure virtual, so every
//! concrete node type must provide it
virtual void init(cg::OperatorNodeBase* opr) = 0;
};
/*!
 * \brief declare a final AstNode subclass named \p _cls
 *
 * The generated struct registers itself with the dynamic type system through
 * MGB_DYN_TYPE_OBJ_FINAL_DECL, declares an override of AstNode::init() (to be
 * defined out of line), and then contains the extra member declarations
 * passed as the remaining macro arguments.
 *
 * The extra members are pasted verbatim, commas included, followed by a
 * single ';'. Separate several members with ';', e.g.
 * `AST_NODE_DECL(Foo, int m_a; float m_b)`. The member part may be omitted
 * entirely.
 *
 * Uses the standard `...` / `__VA_ARGS__` form instead of the GNU-only
 * named variadic parameter (`_mem...`), so it also builds on compilers
 * without that extension (e.g. MSVC). The expansion is unchanged.
 *
 * The `public:` line restores public access in case
 * MGB_DYN_TYPE_OBJ_FINAL_DECL leaves the access specifier at private.
 */
#define AST_NODE_DECL(_cls, ...)                        \
    struct _cls final : public AstNode {                \
        MGB_DYN_TYPE_OBJ_FINAL_DECL;                    \
                                                        \
    public:                                             \
        void init(cg::OperatorNodeBase* opr) override;  \
        __VA_ARGS__;                                    \
    }
//! node with no extra members
//! (NOTE(review): name suggests a value derived from a host-side shape --
//! confirm in the implementation)
AST_NODE_DECL(InputHostValueShapeOp);
//! input node that additionally owns a Halide buffer in m_buffer
//! (NOTE(review): presumably wraps the device input tensor -- confirm)
AST_NODE_DECL(InputDevValueOp, Halide::Buffer<> m_buffer);
//! node with no extra members (presumably an elementwise operator -- confirm)
AST_NODE_DECL(ElemwiseOp);
//! node with no extra members (presumably a dtype conversion -- confirm)
AST_NODE_DECL(TypeCvtOp);
//! node carrying a second Halide function, m_comp, besides the inherited
//! m_func (NOTE(review): presumably the reduction's update/intermediate
//! stage -- confirm)
AST_NODE_DECL(ReduceOp, Halide::Func m_comp);
//! scalar immediate node; m_val stores the value either as int32 (iv) or as
//! float (fv). The struct holds no tag of its own saying which member is
//! active (NOTE(review): presumably selected by m_layout's dtype -- confirm)
AST_NODE_DECL(
ScalarImmOp,
union Val {
int32_t iv;
float fv;
};
Val m_val);
//! node with no extra members (presumably a broadcast operator -- confirm)
AST_NODE_DECL(BroadcastOp);
/*!
 * \brief view \p node as the concrete node type \p Op when its dynamic type
 *      matches
 *
 * \tparam Op concrete AstNode subclass to test for
 * \param node node to inspect; may be null
 * \return pointer to \p node as an \p Op, or nullptr when \p node is null or
 *      `node->same_type<Op>()` is false
 */
template <class Op>
inline Op* try_cast_as_op(AstNode* node) {
    // guard null so a missing node yields nullptr instead of undefined
    // behaviour from dereferencing it
    if (node == nullptr || !node->same_type<Op>()) {
        return nullptr;
    }
    return &node->cast_final<Op>();
}
AstNodePtr make_from_opr(cg::OperatorNodeBase* opr);
} } }
#endif